Source code for treem.commands.modify

"""Implementation of CLI modify command."""

import math
from itertools import chain

import numpy as np

from treem.io import SWC
from treem.morph import Morph
from treem.utils.geom import rotation


def _scale_radii(morph, nodes, scale_radius):
    """Scales radii."""
    scale = np.abs(scale_radius)
    for node in nodes:
        sec = list(node.section())
        radii = morph.radii(sec)  # NOSONAR (S1481) "Necessary for in-place NumPy modification"
        radii *= scale


def _scale_coords(morph, nodes, scale_coords):
    """Scales X,Y,Z coordinates."""
    scale = np.abs(scale_coords)
    for node in nodes:
        sec = list(node.section())
        head = sec[0].coord().copy()
        tail = sec[-1].coord().copy()
        coords = morph.coords(sec)  # NOSONAR (S1481) "Necessary for in-place NumPy modification"
        coords *= np.array(scale)
        shift = head - sec[0].coord()
        coords += shift
        for child in sec[-1].siblings:
            shift = sec[-1].coord() - tail
            morph.translate(shift, child)


def _jitter_coords(morph, nodes, jitter, rng, args):
    """Adds random jitter to X,Y,Z coordinates."""
    for node in nodes:
        sec = list(node.section())
        if len(sec) > 1:
            head = sec[0].coord().copy()
            tail = sec[-1].coord().copy()
            length = morph.length(sec[1:])
            coords = morph.coords(sec)
            if not args.sec:
                rnd = rng.uniform(-1, 1, np.shape(coords))
                coords += jitter * rnd
            else:
                xlen = 0
                rnd = rng.uniform(-1, 1, 3)
                for node in sec:
                    xlen += node.length()
                    vec = jitter * rnd * xlen / length
                    morph.move(vec, node)
            scale = length / morph.length(sec[1:])
            coords *= scale
            shift = head - sec[0].coord()
            coords += shift
            for child in sec[-1].siblings:
                shift = sec[-1].coord() - tail
                morph.translate(shift, child)


def _twist_branches(morph, nodes, twist, rng):
    """Rotates branches by random angle."""
    for node in nodes:
        axis = node.coord() - node.parent.coord()
        angle = twist * rng.uniform(-1, 1) * math.pi / 180
        morph.rotate(axis, angle, node)


def _stretch_sections(morph, nodes, stretch):
    """Straighten sections by relative factor."""
    for node in nodes:
        sec = list(node.section())
        if len(sec) > 1:
            head = sec[0].coord().copy()
            tail = sec[-1].coord().copy()
            length = morph.length(sec[1:])
            coords = morph.coords(sec)
            vdir = coords.mean(axis=0) - head
            vdir /= np.linalg.norm(vdir)
            for secnode in sec[1:]:
                coord = secnode.coord()
                coord += vdir * np.linalg.norm(coord - head) * stretch
            scale = length / morph.length(sec[1:])
            coords *= scale
            shift = head - sec[0].coord()
            coords += shift
            for child in sec[-1].siblings:
                shift = sec[-1].coord() - tail
                morph.translate(shift, child)


def _smooth_sections(morph, nodes, smooth):
    """Smooth sections iteratively by low-pass filtering."""
    for node in nodes:
        sec = list(node.section())
        if len(sec) > 2:
            head = sec[0].coord().copy()
            tail = sec[-1].coord().copy()
            length = morph.length(sec[1:])
            coords = morph.coords(sec)  # NOSONAR (S1481) "Necessary for in-place NumPy modification"
            for _ in range(smooth):
                for secnode in sec[-1::-1]:
                    coord = secnode.coord()  # NOSONAR (S1481) "Necessary for in-place NumPy modification"
                    coord += secnode.parent.coord()
                scale = length / morph.length(sec[1:])
                coords *= scale
                shift = head - sec[0].coord()
                coords += shift
            for child in sec[-1].siblings:
                shift = sec[-1].coord() - tail
                morph.translate(shift, child)


def _swap_branches(morph, node1, node2):
    """Swap two branches."""
    parent1 = node1.parent
    parent2 = node2.parent
    dir1 = node1.coord() - parent1.coord()
    dir2 = node2.coord() - parent2.coord()
    tree1 = morph.copy(node1)
    tree2 = morph.copy(node2)
    coord1 = tree1.root.coord().copy()
    coord2 = tree2.root.coord().copy()
    tree1.translate(coord2 - coord1)
    axis, angle = rotation(dir1, dir2)
    tree1.rotate(axis, angle)
    tree2.translate(coord1 - coord2)
    axis, angle = rotation(dir2, dir1)
    tree2.rotate(axis, angle)
    morph.graft(tree1, parent2)
    morph.graft(tree2, parent1)
    morph.prune(node1)
    morph.prune(node2)



def _prune_branches(morph, nodes, args):
    """Prune branches if structure is not changed."""
    if not args.swap:
        for node in nodes:
            morph.prune(node)


def _collect_nodes(morph, args):
    """Collects nodes by given ids, default to section start nodes."""
    if args.ids:
        nodes = filter(lambda x: x.ident() in args.ids, morph.root.walk())
    else:
        sections = chain.from_iterable(x.sections() for x in morph.stems())
        nodes = chain(sec[0] for sec in sections)
    return nodes


def _filter_attr(nodes, args):
    """Selects nodes by given attributes."""
    types = args.type if args.type else SWC.TYPES
    nodes = filter(lambda x: x.type() in types, nodes)
    if args.order:
        nodes = filter(lambda x: x.order() in args.order, nodes)
    if args.breadth:
        nodes = filter(lambda x: x.breadth() in args.breadth, nodes)
    return list(nodes)


def _set_random_generator(args):
    """Initialize random generator."""
    if args.seed:
        rng = np.random.default_rng(seed=args.seed)
    else:
        rng = np.random.default_rng(0)
    return rng


[docs]def modify(args): """Modifies selected parts of morphology reconstruction.""" morph = Morph(args.file) rng = _set_random_generator(args) # collect nodes to operate on nodes = _collect_nodes(morph, args) nodes = _filter_attr(nodes, args) # manipulate morphology at specified nodes if args.scale_radius: _scale_radii(morph, nodes, args.scale_radius) if args.scale: _scale_coords(morph, nodes, args.scale) if args.jitter: _jitter_coords(morph, nodes, args.jitter, rng, args) if args.twist: _twist_branches(morph, nodes, args.twist, rng) if args.stretch: _stretch_sections(morph, nodes, args.stretch) if args.smooth: _smooth_sections(morph, nodes, args.smooth) if args.swap: rng.shuffle(nodes) node1, node2 = nodes[:2] _swap_branches(morph, node1, node2) if args.prune: _prune_branches(morph, nodes, args) # renumber nodes in restructured morphology if args.prune or args.swap: morph = Morph(data=morph.data) morph.save(args.out)