# Source code for treem.commands.modify

"""Implementation of CLI modify command."""

import math

from itertools import chain

import numpy as np

from treem.morph import Morph
from treem.io import SWC
from treem.utils.geom import rotation


def modify(args):
    """Modifies selected parts of morphology reconstruction.

    Loads ``args.file`` into a :class:`Morph`, selects section-start nodes
    (or explicit node ids, see ``_select_nodes``), applies every requested
    operation in the fixed order below and saves the result to ``args.out``.

    * ``scale_radius`` -- multiply the radii of each selected section by
      ``abs(scale_radius)``.
    * ``scale`` -- multiply the coordinates of each selected section by
      ``scale``; the section start stays in place and the subtrees attached
      to the section end follow it.
    * ``seed`` -- seed NumPy's global RNG used by the random operations.
    * ``jitter`` -- randomly perturb each section (per coordinate, or along a
      single random direction when ``args.sec`` is set), then restore its
      original path length and start point.
    * ``twist`` -- call ``morph.rotate`` at each selected node about the
      parent->node axis by a random angle in ``[-twist, twist]`` degrees.
    * ``stretch`` -- push section nodes along the start->centroid direction,
      then restore the original path length and start point.
    * ``smooth`` -- apply ``smooth`` neighbour-summing passes to sections of
      three or more nodes, then restore path length and start point.
    * ``swap`` -- exchange the subtrees of randomly paired selected nodes.
    * ``prune`` -- prune the selected nodes (ignored when ``swap`` is set).

    After ``prune`` or ``swap`` the morphology is rebuilt from ``morph.data``
    (presumably to refresh the node structure -- confirm in treem.morph).
    """
    morph = Morph(args.file)
    nodes = _select_nodes(morph, args)

    if args.scale_radius:
        _scale_radii(morph, nodes, args.scale_radius)

    if args.scale:
        _scale_sections(morph, nodes, args.scale)

    # `is not None` so that an explicit seed of 0 is honoured; a bare
    # truthiness test would silently leave the RNG unseeded.
    # NOTE(review): assumes the CLI default for --seed is None -- confirm.
    if args.seed is not None:
        np.random.seed(args.seed)

    if args.jitter:
        _jitter_sections(morph, nodes, args.jitter, args.sec)

    if args.twist:
        _twist_subtrees(morph, nodes, args.twist)

    if args.stretch:
        _stretch_sections(morph, nodes, args.stretch)

    if args.smooth:
        _smooth_sections(morph, nodes, args.smooth)

    if args.swap:
        _swap_subtrees(morph, nodes)

    if args.prune and not args.swap:
        for node in nodes:
            morph.prune(node)

    if args.prune or args.swap:
        morph = Morph(data=morph.data)

    morph.save(args.out)


def _select_nodes(morph, args):
    """Return the list of nodes that the modify operations act on.

    With ``args.ids`` set, every node of ``morph.root.walk()`` whose
    ``ident()`` is listed is chosen. Otherwise the first node of each section
    of each stem is chosen, filtered by ``args.type`` (all ``SWC.TYPES`` when
    unset) and, when given, by ``args.order`` and ``args.breadth``.
    """
    if args.ids:
        # NOTE(review): the type/order/breadth filters are applied only when
        # no explicit ids are given -- confirm against the CLI help text.
        return [x for x in morph.root.walk() if x.ident() in args.ids]
    sections = chain.from_iterable(stem.sections() for stem in morph.stems())
    nodes = (sec[0] for sec in sections)
    types = args.type if args.type else SWC.TYPES
    nodes = (x for x in nodes if x.type() in types)
    if args.order:
        nodes = (x for x in nodes if x.order() in args.order)
    if args.breadth:
        nodes = (x for x in nodes if x.breadth() in args.breadth)
    return list(nodes)


def _reattach_children(morph, sec, tail):
    """Translate what hangs off ``sec[-1]`` by how far it moved from ``tail``.

    Iterates ``sec[-1].siblings`` (presumably the child nodes of the section's
    last node -- confirm in treem.morph) and calls ``morph.translate`` on each.
    """
    for child in sec[-1].siblings:
        morph.translate(sec[-1].coord() - tail, child)


def _reanchor_section(morph, sec, coords, head, tail):
    """Shift ``coords`` so ``sec[0]`` is back at ``head``; drag child subtrees."""
    coords += head - sec[0].coord()
    _reattach_children(morph, sec, tail)


def _refit_section(morph, sec, coords, head, tail, length):
    """Restore a deformed section's path length and start point.

    ``coords`` is scaled so that ``morph.length(sec[1:])`` equals ``length``
    again, then re-anchored at ``head`` with its child subtrees dragged along.
    ``coords`` must be the in-place view returned by ``morph.coords(sec)``.
    """
    coords *= length / morph.length(sec[1:])
    _reanchor_section(morph, sec, coords, head, tail)


def _scale_radii(morph, nodes, factor):
    """Multiply, in place, the radii of each node's section by ``abs(factor)``."""
    factor = np.abs(factor)
    for node in nodes:
        radii = morph.radii(list(node.section()))
        radii *= factor


def _scale_sections(morph, nodes, factors):
    """Scale each node's section coordinates by ``factors``, keeping its start.

    ``factors`` may presumably be a scalar or one value per axis (it is
    broadcast against the coordinate array -- confirm the CLI shape).
    """
    # NOTE(review): unlike the radius factor these are not passed through
    # np.abs, so a negative factor mirrors the section about its start
    # point -- confirm that mirroring is intended.
    factors = np.array(factors)
    for node in nodes:
        sec = list(node.section())
        head = sec[0].coord().copy()
        tail = sec[-1].coord().copy()
        coords = morph.coords(sec)
        coords *= factors
        _reanchor_section(morph, sec, coords, head, tail)


def _jitter_sections(morph, nodes, amplitude, whole_section):
    """Randomly perturb sections of two or more nodes, preserving length.

    Without ``whole_section`` every coordinate gets independent uniform noise
    in ``[-amplitude, amplitude]``. With it, one random direction is drawn per
    section and each node is moved along it in proportion to the cumulative
    length walked so far.
    """
    for node in nodes:
        sec = list(node.section())
        if len(sec) < 2:
            continue
        head = sec[0].coord().copy()
        tail = sec[-1].coord().copy()
        length = morph.length(sec[1:])
        coords = morph.coords(sec)
        if not whole_section:
            coords += amplitude * np.random.uniform(-1, 1, np.shape(coords))
        else:
            direction = np.random.uniform(-1, 1, 3)
            xlen = 0
            for secnode in sec:
                xlen += secnode.length()
                morph.move(amplitude * direction * xlen / length, secnode)
        _refit_section(morph, sec, coords, head, tail, length)


def _twist_subtrees(morph, nodes, max_degrees):
    """Rotate at each node about its parent->node axis by a random angle.

    The angle is uniform in ``[-max_degrees, max_degrees]`` degrees.
    """
    for node in nodes:
        axis = node.coord() - node.parent.coord()
        angle = max_degrees * np.random.uniform(-1, 1) * math.pi / 180
        morph.rotate(axis, angle, node)


def _stretch_sections(morph, nodes, factor):
    """Stretch sections of two or more nodes along start->centroid, keep length.

    Each node after the first moves along the unit start->centroid direction
    by ``factor`` times its distance from the section start. Sections whose
    centroid coincides with their start have no direction and are skipped.
    """
    for node in nodes:
        sec = list(node.section())
        if len(sec) < 2:
            continue
        head = sec[0].coord().copy()
        tail = sec[-1].coord().copy()
        length = morph.length(sec[1:])
        coords = morph.coords(sec)
        vdir = coords.mean(axis=0) - head
        norm = np.linalg.norm(vdir)
        if norm == 0:
            # Dividing by a zero norm would write NaN coordinates into the
            # morphology; leave such a degenerate section untouched instead.
            continue
        vdir /= norm
        for secnode in sec[1:]:
            coord = secnode.coord()
            coord += vdir * np.linalg.norm(coord - head) * factor
        _refit_section(morph, sec, coords, head, tail, length)


def _smooth_sections(morph, nodes, passes):
    """Smooth sections of three or more nodes, then restore length and start.

    Each pass walks the section from its end to its start, adding the
    parent's coordinate to each node's coordinate (in place).
    """
    for node in nodes:
        sec = list(node.section())
        if len(sec) < 3:
            continue
        head = sec[0].coord().copy()
        tail = sec[-1].coord().copy()
        length = morph.length(sec[1:])
        coords = morph.coords(sec)
        # NOTE(review): nodes are summed with their parent but never halved.
        # The final rescale absorbs the overall growth, but the fixed parent
        # of sec[0] loses relative weight with every pass -- confirm whether
        # a midpoint average (`coord /= 2`) was intended.
        for _ in range(passes):
            for secnode in reversed(sec):
                coord = secnode.coord()
                coord += secnode.parent.coord()
        _refit_section(morph, sec, coords, head, tail, length)


def _swap_subtrees(morph, nodes):
    """Shuffle ``nodes`` in place and swap the subtrees of consecutive pairs.

    Each subtree copy is moved to the other node's position, rotated from its
    own parent->node direction onto the other's, and grafted onto the other
    parent; the originals are then pruned. With an odd count, the last
    shuffled node is left unchanged.
    """
    np.random.shuffle(nodes)
    for node1, node2 in zip(nodes[0::2], nodes[1::2]):
        parent1 = node1.parent
        parent2 = node2.parent
        dir1 = node1.coord() - parent1.coord()
        dir2 = node2.coord() - parent2.coord()
        tree1 = morph.copy(node1)
        tree2 = morph.copy(node2)
        coord1 = tree1.root.coord().copy()
        coord2 = tree2.root.coord().copy()
        tree1.translate(coord2 - coord1)
        axis, angle = rotation(dir1, dir2)
        tree1.rotate(axis, angle)
        tree2.translate(coord1 - coord2)
        axis, angle = rotation(dir2, dir1)
        tree2.rotate(axis, angle)
        morph.graft(tree1, parent2)
        morph.graft(tree2, parent1)
        morph.prune(node1)
        morph.prune(node2)