"""Morphology reconstruction data structure."""
import math
from collections import deque
import numpy as np
from treem.tree import Tree
from treem.io import SWC, load_swc, save_swc
from treem.utils.geom import rotation_matrix, norm
class Node(Tree):
    """Morphology data storage.

    A tree node whose value ``v`` is one morphology record. The record is
    indexed with the ``SWC`` column constants (``SWC.I``, ``SWC.T``,
    ``SWC.XYZ``, ``SWC.R``, ``SWC.P``).
    """

    def __init__(self, value=None):
        """Inits Node with value."""
        super().__init__()
        self.v = value  # pylint: disable=invalid-name

    def __str__(self):
        """String representation of Node value."""
        return str(self.v)

    def walk(self, reverse=False):
        """Iterates through the tree nodes starting from the current node.

        Iteration is terminated when root is reached if reverse is
        True. If False, full tree is traversed downstream in pre-order.

        Args:
            reverse (bool): walk the tree in ascending order if True.

        Returns:
            sequence of tree nodes (generator object).
        """
        iterator = Tree.ascendorder if reverse else Tree.preorder
        return iterator(self)

    def is_stem(self):
        """Returns True if node is a stem node.

        A stem is a non-root node attached directly to the root whose
        point type is not ``SWC.SOMA``.
        """
        # Point types are plain ints: compare by value, not by identity.
        return (not self.is_root()
                and self.parent.is_root()
                and self.type() != SWC.SOMA)

    def order(self):
        """Returns branch order (int). A primary neurite has order 1.

        The root has order 0; for any other node the order is one plus
        the number of forks reported by ``forks()`` in ascending order.
        """
        if self.is_root():
            return 0
        return sum(1 for _ in self.forks(iterator=Tree.ascendorder)) + 1

    def ident(self):
        """Returns node ID (int)."""
        return int(self.v[SWC.I])

    def parent_ident(self):
        """Returns node's parent ID (int)."""
        return int(self.v[SWC.P])

    def type(self):
        """Returns point type of the node (int)."""
        return int(self.v[SWC.T])

    def point(self):
        """Returns point data of the node (x,y,z,r) (NumPy ndarray[4])."""
        return self.v[SWC.XYZR]

    def coord(self):
        """Returns coordinates of the node (x,y,z) (NumPy ndarray[3])."""
        return self.v[SWC.XYZ]

    def dist(self, origin=(0.0, 0.0, 0.0)):
        """Returns Euclidean distance of the node to origin (float).

        Args:
            origin (float[3]): reference point (defaults to 0,0,0).
        """
        return np.linalg.norm(self.v[SWC.XYZ] - origin)

    def radius(self):
        """Returns radius of the node (float)."""
        return self.v[SWC.R]

    def diam(self):
        """Returns diameter of the node (float)."""
        return 2 * self.radius()

    def length(self):
        """Returns segment length at the node (float).

        The segment spans from the parent's point to this node's point;
        for the root the length is zero.
        """
        # pylint: disable=invalid-name
        a = self.coord()
        b = a if self.is_root() else self.parent.coord()
        return norm(a - b)

    def area(self):
        """Returns segment area at the node (float).

        Lateral surface of the truncated cone between the parent's
        radius and this node's radius; the root uses its own radius at
        both ends.
        """
        # pylint: disable=invalid-name
        h = self.length()
        a = self.radius()
        b = a if self.is_root() else self.parent.radius()
        slant = math.sqrt((a - b) * (a - b) + h * h)
        return math.pi * (a + b) * slant

    def volume(self):
        """Returns segment volume at the node (float).

        Volume of the truncated cone between the parent's radius and
        this node's radius; the root uses its own radius at both ends.
        """
        # pylint: disable=invalid-name
        h = self.length()
        a = self.radius()
        b = a if self.is_root() else self.parent.radius()
        return math.pi / 3.0 * (a * a + a * b + b * b) * h

    def section(self, reverse=False):
        """Iterates through the nodes of the section.

        Iteration starts with the current node and proceeds until the
        end of the section, i.e. until a fork, a leaf or the root is
        met (checked on the node itself when descending and on its
        parent when ascending).

        Args:
            reverse (bool): ascending order if True (defaults to False).

        Yields:
            sequence of nodes (generator object).
        """
        iterator = Tree.ascendorder if reverse else Tree.preorder
        for node in iterator(self):
            yield node
            term = node.parent if reverse else node
            if term is None:
                # Ascending from the root itself: there is no parent to
                # inspect, the section consists of the root only.
                break
            if term.is_fork() or term.is_leaf() or term.is_root():
                break

    def sections(self):
        """Iterates through the sections in descending order.

        Iterations traverse entire branch starting with the current node.

        Yields:
            sequence of sections (generator object); every section is a
            list of nodes.
        """
        queue = deque((self.section(),))
        while queue:
            sec = list(queue.pop())
            last = sec[-1]
            # Push child sections in reverse so that the first child is
            # popped (and therefore yielded) first.
            queue.extend(child.section()
                         for child in reversed(last.siblings))
            yield sec
class Morph():
    """Neuron morphology representation.

    Keeps the morphology both as a data array (``self.data``, one row per
    point) and as a linked tree of ``Node`` objects (``self.root``,
    ``self.nodes``).
    """

    def __init__(self, source=None, data=None):
        """Initializes Morph from source file or data.

        When built from ``data``, the records are renumbered in place so
        that IDs run from 1 in row order and the first record becomes the
        root (parent ID -1).

        Args:
            source (str): source file (swc).
            data (NumPy ndarray): morphology data (N, 7).
        """
        self.data = None
        self.root = None
        self.nodes = []
        if source:
            self.load(source)
        elif data is not None:
            data[0][SWC.P] = -1
            # Map every original ID to its 1-based row position.
            idmap = {int(rec[SWC.I]): new
                     for new, rec in enumerate(data, 1)}
            idmap[-1] = -1
            for rec in data:
                rec[SWC.I] = idmap[int(rec[SWC.I])]
                rec[SWC.P] = idmap[int(rec[SWC.P])]
            self.load(data=data)

    def load(self, source=None, data=None):
        """Fill-in Morph from source file or data.

        Any previously loaded nodes are discarded. Records are expected
        to carry IDs equal to their 1-based row position, since nodes
        are looked up by ``ident() - 1``.

        Args:
            source (str): source file (swc).
            data (NumPy ndarray): morphology data (N, 7).
        """
        self.data = load_swc(source) if source else data
        # Rebuild the node list from scratch so that a repeated load()
        # does not mix nodes of different data sets.
        self.nodes = [Node(row) for row in self.data]
        self.root = self.nodes[0]
        for node in self.nodes[1:]:
            parent = node.parent_ident() - 1
            child = node.ident() - 1
            self.nodes[parent].add(self.nodes[child])

    def save(self, target):
        """Writes morphology to file (str)."""
        save_swc(target, self.data)

    def node(self, ident):
        """Returns node by its ID.

        Raises:
            IndexError: if no node with the given ID is found.
        """
        found = [node for node in self.root.walk() if node.ident() == ident]
        return found[0]

    def stems(self):
        """Iterates through stem nodes.

        Returns:
            sequence of stem nodes (generator object): children of the
            root whose point type is not ``SWC.SOMA``.
        """
        return filter(lambda x: x.type() != SWC.SOMA, self.root.siblings)

    @staticmethod
    def _block(sec):
        """Returns the row slice of the data array covered by section.

        Assumes that the nodes of the section occupy consecutive rows
        from the first node's to the last node's ID.
        """
        # possibly unsafe addressing, see repair_branch()
        return slice(sec[0].ident() - 1, sec[-1].ident())

    def coords(self, sec):
        """Returns reference to section coordinates."""
        return self.data[self._block(sec), SWC.XYZ]

    def radii(self, sec):
        """Returns reference to section radii."""
        return self.data[self._block(sec), SWC.RADII]

    def points(self, sec):
        """Returns reference to section data."""
        return self.data[self._block(sec), SWC.XYZR]

    def length(self, sec):
        """Returns section length (float)."""
        return sum(node.length() for node in sec)

    def area(self, sec):
        """Returns section area (float)."""
        return sum(node.area() for node in sec)

    def volume(self, sec):
        """Returns section volume (float)."""
        return sum(node.volume() for node in sec)

    def move(self, shift, node):
        """Shifts node coordinates by 3D vector (float[3])."""
        node.v[SWC.XYZ] += shift

    def translate(self, shift, node=None):
        """Shifts coordinates of the branch at the given node.

        Branch is traversed downstream from the given node.

        Args:
            shift (float[3]): translation vector.
            node (treem.Node): starting node (defaults to root).
        """
        node = node if node else self.root
        for sec in node.sections():
            points = self.coords(sec)
            points += shift  # in-place update of the data array

    def rotate(self, axis, angle, node=None):
        """Rotates branch at the node.

        Branch is traversed downstream from the given node. The starting
        node keeps its position: after rotation the branch is shifted
        back so that the node returns to its original coordinates.

        Args:
            axis (float[3]): rotation axis.
            angle (float): rotation angle in degrees.
            node (treem.Node): starting node (defaults to root).
        """
        node = node if node else self.root
        head = node.coord().copy()
        # The matrix does not depend on the section: build it once.
        rotation = rotation_matrix(axis, angle)
        for sec in node.sections():
            block = self._block(sec)
            points = self.data[block, SWC.XYZ]
            self.data[block, SWC.XYZ] = np.dot(rotation, points.T).T
        shift = head - node.coord()
        self.translate(shift, node)

    def copy(self, node=None):
        """Copies branch at the node (defaults to root)."""
        node = node if node else self.root
        data = np.array([x.v for x in node.walk()])
        return Morph(data=data)

    # Programming notes:
    # 1) __renumber() changes internal container data;
    # 2) delete(), insert(), prune() and graft() desynchronize
    #    the data and the linked list;
    # 3) constructor Morph(data=new_data) updates the linked list.

    def _max_ident(self):
        """Returns the largest node ID present in the linked tree (int).

        The tree rather than ``self.data`` is inspected because the data
        array is renumbered by the editing methods while the nodes keep
        their previous IDs.
        """
        return max(node.ident() for node in self.root.walk())

    def __renumber(self):
        """Renumbers morphology nodes in tree traversal order.

        Builds a new data array from the linked tree; the node values
        themselves are left untouched.
        """
        data = np.array([x.v for x in self.root.walk()])
        idmap = {int(rec[SWC.I]): new for new, rec in enumerate(data, 1)}
        idmap[-1] = -1
        for rec in data:
            rec[SWC.I] = idmap[int(rec[SWC.I])]
            rec[SWC.P] = idmap[int(rec[SWC.P])]
        self.data = data

    def delete(self, node):
        """Delete node.

        Children of the deleted node are re-attached to its parent.
        """
        parent = node.parent
        parent.siblings.remove(node)
        for child in node.siblings:
            parent.add(child)
            child.v[SWC.P] = parent.ident()
        self.__renumber()

    def insert(self, new_node, node):
        """Inserts new node before the given node."""
        parent = node.parent
        # Pick a free ID while the whole tree is still connected.
        maxid = self._max_ident()
        parent.siblings.remove(node)
        new_node.v[SWC.I] = maxid + 1
        new_node.v[SWC.P] = parent.ident()
        node.v[SWC.P] = new_node.ident()
        parent.add(new_node)
        new_node.add(node)
        self.__renumber()

    def prune(self, node):
        """Prunes branch at the given node."""
        node.parent.siblings.remove(node)
        self.__renumber()

    def graft(self, tree, node=None):
        """Grafts tree at the given node (defaults to root).

        IDs of the grafted tree are shifted in place (in ``tree.data``)
        past the largest ID of this morphology.
        """
        node = node if node else self.root
        maxid = self._max_ident()
        # Strided slice picking the ID and the parent ID columns only.
        tree.data[:, slice(SWC.I, SWC.P + 1, SWC.P)] += maxid
        tree.data[0][SWC.P] = node.ident()
        node.add(tree.root)
        # The data array is rebuilt from the linked tree.
        self.__renumber()
def get_segdata(morph):
    """Collects extended segment data.

    Args:
        morph (Morph): morphology with data and linked tree in sync.

    Returns:
        NumPy ndarray with one row per point, sorted by ID. Columns are
        ID, type, x, y, z, radius, parent ID, segment length, path
        length from the root, path length within the section, the same
        relative to the section length, Euclidean distance to the root,
        node degree, branch order, breadth and total length of the
        downstream branch (the layout named by ``SEG``).
    """
    # pylint: disable=invalid-name
    # pylint: disable=too-many-locals
    m = morph
    d = {}
    for i, t, x, y, z, r, p in m.data:
        d[int(i)] = {'t': int(t), 'x': float(x), 'y': float(y),
                     'z': float(z), 'r': float(r), 'p': int(p)}
    center = m.root.coord()
    # Soma points carry zero measures.
    for node in m.root.walk():
        if node.type() == SWC.SOMA:
            d[node.ident()].update({
                'length': 0.0, 'path': 0.0, 'xsec': 0.0, 'xsec_rel': 0.0,
                'dist': 0.0, 'degree': 0, 'order': 0, 'breadth': 0,
                'totlen': 0.0})
    for stem in m.stems():
        for sec in stem.sections():
            # Sections come parent first, so the record of the fork a
            # section hangs on is already complete. A section starting
            # at a non-root fork is one order deeper than that fork;
            # the stem section itself has order 1.
            origin = sec[0].parent
            if origin.is_fork() and origin != m.root:
                order = d[origin.ident()]['order'] + 1
            else:
                order = 1
            xsec = 0.0
            seclen = m.length(sec)
            for node in sec:
                ident = node.ident()
                length = node.length()
                xsec += length
                rec = d[ident]
                rec['length'] = length
                rec['path'] = d[node.parent.ident()]['path'] + length
                rec['xsec'] = xsec
                # NOTE(review): seclen is zero when all points of the
                # section coincide with its parent point - confirm how
                # that case should be reported.
                rec['xsec_rel'] = xsec / seclen
                rec['dist'] = np.linalg.norm(center - node.coord())
                rec['degree'] = node.degree()
                rec['order'] = order
                rec['breadth'] = 1
                rec['totlen'] = 0.0
    # Accumulate breadth and total length from the tips towards the root;
    # values of a node are final once the walk from the last tip of its
    # branch has passed through it.
    for term in m.root.leaves():
        for node in term.walk(reverse=True):
            if node.is_leaf():
                continue
            children = [d[child.ident()] for child in node.siblings]
            below = sum(child.length() for child in node.siblings)
            rec = d[node.ident()]
            rec['breadth'] = sum(child['breadth'] for child in children)
            rec['totlen'] = sum(child['totlen'] for child in children) + below
    return np.array([[i, d[i]['t'], d[i]['x'], d[i]['y'], d[i]['z'],
                      d[i]['r'], d[i]['p'],
                      d[i]['length'], d[i]['path'], d[i]['xsec'],
                      d[i]['xsec_rel'], d[i]['dist'], d[i]['degree'],
                      d[i]['order'], d[i]['breadth'], d[i]['totlen']]
                     for i in sorted(d)])
class SEG():  # pylint: disable=too-few-public-methods
    """Column indices of the extended segment data format."""
    # Point record columns.
    I, T, X, Y, Z, R, P = range(7)
    # Derived segment measures.
    LENGTH, PATH, XSEC, XSEC_REL = range(7, 11)
    DIST, DEGREE, ORDER, BREADTH, TOTLEN = range(11, 16)
class DGram(Morph):
    """Neuron dendrogram representation."""

    def __init__(self, morph=None, source=None, data=None, types=SWC.TYPES,
                 zorder=0.0, ystep=0.0, zstep=0.0):
        """Builds a dendrogram from a morphology, a source file or data.

        In the dendrogram the X coordinate of a point is its path length
        from the root, Y separates the terminal branches and Z is the
        constant ``zorder * zstep``.

        Args:
            morph (Morph): source morphology; its data is copied and the
                object itself is left unchanged.
            source (str): source file (swc), used if morph is not given.
            data (NumPy ndarray): morphology data, used if neither morph
                nor source is given.
            types: point types to keep; stems of other types are pruned.
            zorder (float): multiplier of zstep giving the Z coordinate.
            ystep (float): Y distance between successive tips; if 0.0,
                derived from the morphology.
            zstep (float): Z step; if 0.0, derived from the morphology.
        """
        # pylint: disable=too-many-arguments
        # pylint: disable=too-many-locals
        # pylint: disable=too-many-branches
        if morph is None:
            morph = Morph(source=source, data=data)
        else:
            # Work on a private copy: the data array is rewritten below.
            source_data = morph.data
            morph = Morph(data=None if source_data is None
                          else np.array(source_data))
        if morph.root is None:
            # Nothing to build from: create an empty dendrogram.
            super().__init__()
            return
        # prune() removes items from the list stems() iterates over;
        # take a snapshot first so that no stem is skipped.
        for stem in list(morph.stems()):
            if stem.type() not in types:
                morph.prune(stem)
        graph = Morph(data=morph.data)
        segdata = get_segdata(graph)
        for sec in graph.root.sections():
            for node in sec:
                row = node.ident() - 1
                graph.data[row][SWC.X] = segdata[row][SEG.PATH]
        if ystep == 0.0 or zstep == 0.0:
            # NOTE(review): when nothing was pruned, morph and graph share
            # one data array, so these distances already see the X values
            # written above - confirm this is intended.
            maxdist = max(node.dist() for node in morph.root.leaves())
            ntips = sum(1 for node in morph.root.leaves())
            dgram_step = maxdist / ntips
            ystep = ystep if ystep != 0.0 else dgram_step
            zstep = zstep if zstep != 0.0 else dgram_step
        graph.data[:, SWC.YZ] = [0.0, zorder * zstep]
        # Move every section so that its first point coincides with the
        # parent point.
        for stem in graph.stems():
            for sec in stem.sections():
                start_node = sec[0]
                parent = start_node.parent
                shift = start_node.coord() - parent.coord()
                graph.translate(-shift, start_node)
        # Spread the tips along Y; forks and the root take the mean Y of
        # their children.
        for index, term in enumerate(graph.root.leaves(), start=1):
            pos = index * ystep
            for node in term.walk(reverse=True):
                value = graph.data[node.ident() - 1]
                if node.is_fork() or node.is_root():
                    pos = np.mean([x.coord()[1] for x in node.siblings])
                value[SWC.Y] = pos
        super().__init__(data=graph.data)