Source code for treem.commands.repair

"""Implementation of CLI repair command."""

import math

from itertools import chain

import numpy as np

from treem import Morph, SWC
from treem.utils.geom import repair_branch, sample, norm, rotation


[docs]def repair(args): """Corrects morphology reconstruction at the given nodes.""" # pylint: disable=invalid-name # pylint: disable=too-many-locals # pylint: disable=too-many-branches # pylint: disable=too-many-statements # pylint: disable=cell-var-from-loop # pylint: disable=expression-not-assigned vprint = print if args.verbose else lambda *a, **k: None morph = Morph(args.file) morig = morph.copy() if args.rotate and args.cut else morph err = 0 if args.translate: morph.data[:, SWC.XYZ] += np.array(args.translate) if args.rotate: morph.rotate([1, 0, 0], args.rotate[0] / 180 * math.pi) morph.rotate([0, 1, 0], args.rotate[1] / 180 * math.pi) morph.rotate([0, 0, 1], args.rotate[2] / 180 * math.pi) if args.shrink_xy: scale = np.array([args.shrink_xy, args.shrink_xy, 1]) origin = morph.root.coord().copy() for node in morph.root.walk(): coord = node.coord() coord *= scale shift = origin - morph.root.coord() morph.translate(shift) if args.shrink: if args.bottom_up: bottom = max(x.coord()[2] for x in morph.root.walk()) else: bottom = min(x.coord()[2] for x in morph.root.walk()) for node in morph.root.walk(): z = node.coord()[2] z = bottom + args.shrink * (z - bottom) node.v[SWC.Z] = z if args.zjump: nodes = [x for x in morph.root.walk() if x.ident() in args.zjump] for node in nodes: jump = node.parent.coord()[2] - node.coord()[2] if args.zjump_mode == 'align': shift = [0, 0, jump] morph.translate(shift, node) elif args.zjump_mode == 'split': shift = [0, 0, jump / 2] jump_sec = list(node.section()) for split_node in jump_sec[:int(len(jump_sec) / 2)]: morph.move(shift, split_node) elif args.zjump_mode == 'tilt': parent = node.parent dist = max(norm(parent.coord() - jump_node.coord()) for jump_node in node.leaves()) leng = max(norm(node.coord() - jump_node.coord()) for jump_node in node.leaves()) leaf = [jump_node for jump_node in node.leaves() if norm(node.coord() - jump_node.coord()) == leng][0] vdir = leaf.coord() - node.parent.coord() shift = [0, 0, jump * leng / dist] morph.translate(shift, node) udir = leaf.coord() - node.coord() axis, angle = rotation(udir, vdir) morph.rotate(axis, angle, node) elif args.zjump_mode == 'join': parent = node.parent dist = max(norm(parent.coord() - jump_node.coord()) for jump_node in node.leaves()) leng = max(norm(node.coord() - jump_node.coord()) for jump_node in node.leaves()) leaf = [jump_node for jump_node in node.leaves() if norm(node.coord() - jump_node.coord()) == leng][0] vdir = leaf.coord() - node.parent.coord() shift = [0, 0, jump * leng / dist] morph.translate(shift, node) udir = leaf.coord() - node.coord() axis, angle = rotation(udir, vdir) morph.rotate(axis, angle, node) start = list(node.section(reverse=True))[-1].parent dist = max(norm(start.coord() - jump_node.coord()) for jump_node in node.leaves()) leng = morph.length(node.parent.section(reverse=True)) vdir = leaf.coord() - start.coord() shift = [0, 0, -jump * leng / dist] morph.translate(shift, node.parent) udir = leaf.coord() - node.coord() axis, angle = rotation(udir, vdir) morph.rotate(axis, angle, node) if args.pool: pool = [Morph(f) for f in args.pool] if args.diam: nodes = [x for x in morph.root.walk() if x.ident() in args.diam] if args.pool: types = {x.type() for x in nodes} if args.diam_mode == 'joint': for node in nodes: r = [] if node.parent.type() != SWC.SOMA: r.append(node.parent.radius()) if not node.is_fork() and not node.is_leaf(): r.append(node.siblings[0].radius()) if r: node.v[SWC.R] = np.mean(r) else: vprint(f'diam in node {node.ident()} not repaired') err += 1 elif args.diam_mode == 'sec': for node in nodes: sec = list(node.section(reverse=True)) sec = list(sec[-1].section()) r = morph.radii(sec).mean() node.v[SWC.R] = r elif args.diam_mode == 'order': for node in nodes: point_type = node.type() order = node.order() if args.pool: radii = [m.radii(sec).mean() for m in pool for sec in m.root.sections() if sec[0].type() in types and sec[0].order() == order] if radii: r = np.mean(radii) node.v[SWC.R] = r else: vprint(f'diam in node {node.ident()} (order {order}) ' f'not repaired') err += 1 else: r = np.array([morph.radii(sec).mean() for sec in morph.root.sections() if sec[0].type() == point_type and sec[0].order() == order]).mean() node.v[SWC.R] = r elif args.diam_mode == 'breadth': for node in nodes: point_type = node.type() breadth = node.breadth() if args.pool: radii = [m.radii(sec).mean() for m in pool for sec in m.root.sections() if sec[0].type() in types and sec[0].breadth() == breadth] if radii: r = np.mean(radii) node.v[SWC.R] = r else: vprint(f'diam in node {node.ident()} ' f'(breadth {breadth}) not repaired') err += 1 else: r = np.array([morph.radii(sec).mean() for sec in morph.root.sections() if sec[0].type() == point_type and sec[0].breadth() == breadth]).mean() node.v[SWC.R] = r if args.diam_mode == 'value': for node in nodes: node.v[SWC.R] = args.diam_value / 2 if args.seed: np.random.seed(args.seed) def is_intact(tree, cuts): leaves = [x.ident() for x in tree.leaves()] return set(leaves).isdisjoint(cuts) if args.cut: # pylint: disable=too-many-nested-blocks cuts = set(x for x in args.cut if morph.node(x).type() != SWC.SOMA) keep_radii = args.keep_radii if args.del_branch: stems = [] for cut in cuts: stems.extend(x for x in filter(lambda x: x.is_stem() and x.type() != SWC.SOMA, morph.node(cut).walk(reverse=True)) if x not in stems) for node in stems: for child in node.siblings: morph.prune(child) vprint('renumbering nodes, old node ids are lost') morph = Morph(data=morph.data) cuts = [x.ident() for x in morph.root.siblings if x.is_leaf()] vprint(f'reassigning cut points to {cuts}') graft_points = set() keep_radii = True else: graft_points = set(args.cut).difference(cuts) types = {x.type() for x in morph.root.walk() if x.ident() in cuts} for point_type in types: intact_branches = {} if args.pool: for rec in pool: sections = filter(lambda x: x[0].type() == point_type, rec.root.sections()) nodes = chain(x[0] for x in sections) for node in nodes: order = node.order() if order not in intact_branches: intact_branches[order] = [] intact_branches[order].append((rec, node)) else: sections = filter(lambda x: x[0].type() == point_type, morig.root.sections()) nodes = chain(x[0] for x in sections) nodes = filter(lambda x: is_intact(x, args.cut), nodes) for node in nodes: order = node.order() if order not in intact_branches: intact_branches[order] = [] intact_branches[order].append((morig, node)) nodes = [x for x in morph.root.walk() if x.type() == point_type and x.ident() in cuts] for node in nodes: order = node.order() vprint(f'repairing node {node.ident()} (order {order})', end=' ') if order in intact_branches: idx = np.random.choice(len(intact_branches[order])) rec, rep = intact_branches[order][idx] vprint(f'using {rep.ident()} (order {order}) ...', end=' ') done = repair_branch(morph, node, rec, rep, force=args.force_repair, keep_radii=keep_radii) err += 1 if not done else 0 vprint('done') if done else vprint('not repaired') elif order - 1 in intact_branches: idx = np.random.choice(len(intact_branches[order - 1])) rec, rep = intact_branches[order - 1][idx] vprint(f'using {rep.ident()} (order {order-1}) ...', end=' ') done = repair_branch(morph, node, rec, rep, force=args.force_repair, keep_radii=keep_radii) err += 1 if not done else 0 vprint('done') if done else vprint('not repaired') elif args.force_repair: if intact_branches: order = np.random.choice(list(intact_branches.keys())) idx = np.random.choice(len(intact_branches[order])) rec, rep = intact_branches[order][idx] vprint(f'using {rep.ident()} (order {order}) ...', end=' ') done = repair_branch(morph, node, rec, rep, force=True) err += 1 if not done else 0 vprint('done') if done else vprint('not repaired') else: err += 1 vprint('... no intact branches, not repaired') else: err += 1 vprint('... not repaired') if args.cut and graft_points: point_type = args.graft_point_type intact_branches = [] if args.pool: for rec in pool: sections = filter(lambda x: x[0].type() == point_type and x[0].order() == 1, rec.root.sections()) nodes = chain(x[0] for x in sections) for node in nodes: intact_branches.append((rec, node)) else: sections = filter(lambda x: x[0].type() == point_type and x[0].order() == 1, morig.root.sections()) nodes = chain(x[0] for x in sections) nodes = filter(lambda x: is_intact(x, args.cut), nodes) for node in nodes: intact_branches.append((morig, node)) vprint('grafting branch on to a soma node', end=' ') nodes = [x for x in morph.root.walk() if x.ident() in graft_points] for node in nodes: vprint(f'{node.ident()}', end=' ') if intact_branches: idx = np.random.choice(len(intact_branches)) rec, rep = intact_branches[idx] morph.graft(rec.copy(rep), node) vprint('done') else: err += 1 vprint('... no intact branches, not grafted') if args.delete and not args.cut: nodes = [x for x in morph.root.walk() if x.ident() in args.delete] for node in nodes: morph.delete(node) if args.delete or args.cut: morph = Morph(data=morph.data) if args.res: ident = 1 data = [] idmap = {-1: -1} for sec in filter(lambda x: x[0].type() == SWC.SOMA, morph.root.sections()): for node in sec: v = node.v.copy() i, p = v[SWC.I].astype(int), v[SWC.P].astype(int) v[SWC.I], v[SWC.P] = ident, idmap[p] idmap[i] = ident data.append(v) ident += 1 for sec in filter(lambda x: x[0].type() in set(SWC.TYPES).difference((SWC.SOMA,)), morph.root.sections()): length = morph.length(sec) points = morph.points(sec) # pylint: disable=unsubscriptable-object points = np.insert(points, 0, sec[0].parent.v[SWC.XYZR], axis=0) points = sample(points, np.ceil(length / args.res).astype(int)) points = points[1:] start = True for ident, point in enumerate(points, ident): x, y, z, r = point pid = idmap[sec[0].parent_ident()] if start else ident - 1 v = np.array([ident, sec[0].type(), x, y, z, r, pid]) start = False if start else start data.append(v) idmap[sec[-1].v[SWC.I]] = ident ident += 1 morph = Morph(data=np.array(data)) if args.flip: center = morph.root.coord().copy() if 'x' in args.flip: morph.data[:, SWC.X] *= -1 if 'y' in args.flip: morph.data[:, SWC.Y] *= -1 if 'z' in args.flip: morph.data[:, SWC.Z] *= -1 shift = morph.root.coord() - center morph.data[:, SWC.XYZ] -= shift if args.center: morph.data[:, SWC.XYZ] -= morph.root.coord() morph.save(args.out) return err