Source code for tree

""" This module contains the data representation for the structure of trees.
Note that we separate the structure of the branched tree from its
realization, that would contain things such as charges, positions, etc.
that evolve even when the structure is fixed.
"""
from random import random as uniform
from scipy.sparse import lil_matrix, csr_matrix
from numpy import *

class Tree(object):
    """ Instances of the Tree class contain the topological information of a
    tree discharge: i.e. they encapsulate the relations between different
    segments in a tree but not about locations, conductivities etc.

    Segments are kept in ``self.segments`` in the order in which they were
    added; that position is the segment's ``index`` and is what addresses
    the per-segment data arrays (see :meth:`zeros`).  Because a segment can
    only be added after its parent, every parent index is smaller than the
    indices of its children. """

    def __init__(self):
        # We must carry a global (tree-level) index to access the parameter
        # data arrays
        self.n = 0
        self.segments = []
        self.root = None

    def add_segment(self, segment):
        """ Adds a *segment* to this tree.  Returns the index of the segment
        inside the tree.  The first segment ever added becomes the root. """
        if self.n == 0:
            self.root = segment
        self.segments.append(segment)
        index = self.n
        self.n += 1
        return index

    def parents(self, root_index=0):
        """ Builds an integer array with the index of each segment's parent.
        Segments without a parent (the root) get *root_index*. """
        p = zeros((self.n,), dtype='i')
        for i, segment in enumerate(self.segments):
            # Explicit None test rather than catching AttributeError: a parent
            # that lacks an ``index`` is a bug and must not be silently
            # mapped to the root.
            parent = segment.parent
            p[i] = root_index if parent is None else parent.index
        return p

    def make_root(self):
        """ Creates a segment node to be root of this tree and returns it. """
        root = Segment()
        root.set_tree(self)
        self.root = root
        return root

    def terminals(self):
        """ Finds all segments contained in the tree that do not have any
        children.  Returns an integer array with segment indices. """
        # dtype=int so that an empty result can still be used as an index.
        return array([i for i, segment in enumerate(self.segments)
                      if not segment.children], dtype=int)

    def branches(self):
        """ Finds all indices of segments that branch in the tree, i.e. that
        have more than one child.  Returns an integer array. """
        return array([i for i, segment in enumerate(self.segments)
                      if len(segment.children) > 1], dtype=int)

    def extend(self, indices):
        """ Extends the tree adding one child to each segment indexed by
        *indices*, in that order.  This is used to extend a propagating
        tree. """
        for i in indices:
            self.segments[i].add_child(Segment())

    def zeros(self, dim=None):
        """ Returns an array that can hold all the data needed for a variable
        in this tree's segments: shape ``(n,)``, or ``(n, dim)`` for
        multi-dimension data. """
        # ``zeros`` here is the numpy function: method names are not in scope
        # inside method bodies.
        shape = (self.n,) if dim is None else (self.n, dim)
        return zeros(shape)

    def lengths(self, endpoints):
        """ Returns an array with the segment lengths of the tree, given an
        array with the *endpoints* (one row per segment).  The root is its
        own parent here, so its length is 0. """
        parents = self.parents()
        return sqrt(((endpoints - endpoints[parents, :]) ** 2).sum(axis=1))

    def midpoints(self, endpoints):
        """ Returns an array with the segment midpoints of the tree, given an
        array with the *endpoints*. """
        parents = self.parents()
        return 0.5 * (endpoints + endpoints[parents, :])

    def ohm_matrix(self, endpoints, fix=()):
        """ Builds a matrix M that will provide the evolution of charges in
        every segment of the tree as dq/dt = M . phi, where phi is the
        potential at the center of each segment and '.' is the dot product.

        This function builds the matrix from scratch.  Usually it is much
        better to keep updating the matrix as the tree grows.

        * *endpoints* must contain an array with the endpoints.
        * *fix* contains indices of nodes with a fixed charge (usually the
          root node); their rows are zeroed.

        Returns a ``scipy.sparse.csr_matrix`` of shape ``(n, n)``. """
        l = self.lengths(endpoints)
        # The root always has length 0, giving an inf that is never read
        # (the root has no parent); don't warn about it on every call.
        with errstate(divide='ignore'):
            linv = 1.0 / l

        # We build the matrix in LIL format first, later we convert to a
        # format more efficient for matrix-vector multiplications
        M = lil_matrix((self.n, self.n))
        for segment in self:
            i = segment.index
            diag = 0.0
            for child in segment.children:
                j = child.index
                M[i, j] = linv[j]
                diag -= linv[j]
            if segment.parent is not None:
                M[i, segment.parent.index] = linv[i]
                diag -= linv[i]
            M[i, i] = diag

        for f in fix:
            M[f, :] = 0

        return csr_matrix(M)

    def branch_label(self, labels=None, label=1, segment=None):
        """ Returns an array with an integer for each node that is unique for
        the branch where it sits.  Child number ``i`` of a branch labelled
        ``label`` gets the label ``2*label + i``. """
        if labels is None:
            labels = zeros((self.n,), dtype='i')
        if segment is None:
            segment = self.root

        # Walk down the unbranched run of segments, labelling each one.
        while True:
            labels[segment.index] = label
            if len(segment.children) != 1:
                break
            segment = segment.children[0]

        for i, child in enumerate(segment.children):
            self.branch_label(labels, label=2 * label + i, segment=child)

        return labels

    def branch_distance(self, endpoints, dist=None, segment=None,
                        lengths=None):
        """ Returns an array with the distance of each node from the branching
        immediately above it.  The distance is calculated along the branch.
        """
        if dist is None:
            dist = zeros((self.n,), dtype='d')
        if segment is None:
            segment = self.root
        if lengths is None:
            lengths = self.lengths(endpoints)

        l = 0
        while True:
            dist[segment.index] = l
            l += lengths[segment.index]
            if len(segment.children) != 1:
                break
            segment = segment.children[0]

        for child in segment.children:
            self.branch_distance(endpoints, dist, segment=child,
                                 lengths=lengths)

        return dist

    def reconnects(self, endpoints, rmin=5e-4, dmin=1e-3):
        """ Finds reconnections in a tree.

        Returns True if some terminal lies within *rmin* of a node on a
        different branch, with both nodes farther than *dmin* (measured
        along their branch) from the branching point above them. """
        term = self.terminals()
        rterm = endpoints[term, :]
        labels = self.branch_label()
        lterm = labels[term]
        dist = self.branch_distance(endpoints)
        dterm = dist[term]

        # We look only at node pairs where one of the nodes is a terminal.
        r2 = ((rterm[newaxis, :, :] - endpoints[:, newaxis, :]) ** 2).sum(axis=2)
        dlabel = lterm[newaxis, :] - labels[:, newaxis]

        # Close pairs on different branches still include the nodes right
        # next to a branching point; the along-branch distance filter
        # discards those.
        s = logical_and(dlabel != 0, r2 <= rmin ** 2)
        t = logical_and(dterm[newaxis, :] > dmin, dist[:, newaxis] > dmin)

        return bool(logical_and(s, t).any())

    def save(self, fname):
        """ Saves the tree structure into file *fname* as two integer
        columns: segment index and parent index. """
        parents = self.parents()
        i = arange(self.n)
        savetxt(fname, c_[i, parents], fmt='%d')

    @staticmethod
    def loadtxt(fname):
        """ Loads a tree structure from a txt file written by :meth:`save`
        [DEBUG]. """
        # ndmin=2 keeps single-row files (a lone root) unpackable into arrays.
        _, parents = loadtxt(fname, unpack=True, ndmin=2)
        return Tree.from_parents(parents)

    @staticmethod
    def from_parents(parents):
        """ Builds a tree from a list (or array) of the parent indices.  Entry
        0 is taken to be the root; every other entry must refer to an
        earlier segment. """
        # loadtxt yields floats, and lists have no .shape: normalise to ints
        # so the values can index self.segments.
        parents = asarray(parents, dtype=int)
        t = Tree()
        for i, parent in enumerate(parents):
            if i == 0:
                t.make_root()
            else:
                t.segments[parent].add_child(Segment())
        return t

    def __iter__(self):
        return iter(self.segments)
class Segment(object):
    """ A single node of a :class:`Tree`.

    A segment records its ``parent`` (``None`` until it is attached below
    another segment), its list of ``children`` and the ``tree`` it belongs
    to.  Once registered with a tree it also carries ``index``, its
    position in that tree's per-segment arrays. """

    def __init__(self):
        self.tree = None
        self.parent = None
        self.children = []

    def set_tree(self, tree):
        # Registration with the tree is what hands out the global index.
        self.tree = tree
        self.index = tree.add_segment(self)

    def set_parent(self, parent):
        self.parent = parent
        # A child always lives in the same tree as its parent.
        self.set_tree(parent.tree)

    def get(self, a):
        """ Returns the entry of array *a* that belongs to this segment. """
        return a[self.index]

    def set(self, a, value):
        """ Stores *value* in the entry of array *a* that belongs to this
        segment. """
        a[self.index] = value

    def iter_adjacent(self):
        """ Yields every neighbouring segment: the parent first (when there
        is one), then each child in order. """
        if self.parent is not None:
            yield self.parent
        for neighbour in self.children:
            yield neighbour

    def add_child(self, other):
        """ Attaches the :class:`Segment` *other* below this segment. """
        other.set_parent(self)
        self.children.append(other)
def random_branching_tree(n, p):
    """ Builds a randomly branched tree.  Starting from a root, *n* growth
    steps are performed; each step takes the oldest pending leaf, gives it
    one child and, with probability *p*, a second one.  The result thus has
    at least n + 1 segments.  This produces nice pictures and can be useful
    for testing. """
    from collections import deque

    tree = Tree()
    # FIFO of leaves still waiting to grow; deque makes popleft O(1).
    leafs = deque([tree.make_root()])

    for _ in range(n):
        leaf = leafs.popleft()

        # Every leaf has at least one descendant
        child = Segment()
        leaf.add_child(child)
        leafs.append(child)

        # With probability p it has two children.  ``random`` is presumably
        # numpy.random, brought in by ``from numpy import *``.
        if random.uniform() < p:
            child = Segment()
            leaf.add_child(child)
            leafs.append(child)

    return tree
def sample_endpoints(tree):
    """ Gives endpoints to a tree structure.  Useful for plotting sample
    trees [DEBUG].

    Returns an ``(n, 3)`` array with one endpoint per segment; the root sits
    at the origin.  Only segments with one or two children are supported.
    """
    r = tree.zeros(dim=3)
    # Growth direction added per child, keyed by the number of siblings.
    deltav = {1: array([[0, 0, -1]]),
              2: array([[-1, 0, -1], [1, 0, -1]])}
    damping = array([0.9, 1.0, 0.95])

    def pending(parent, v):
        # Children are pushed in reverse so they are popped, and draw their
        # random numbers, in the same depth-first order as a recursive walk.
        n = len(parent.children)
        return [(child, parent, v, n, i)
                for i, child in reversed(list(enumerate(parent.children)))]

    # An explicit stack instead of recursion: deep, chain-like trees would
    # otherwise exceed Python's recursion limit.
    root = tree.root
    root.set(r, (0, 0, 0))
    stack = pending(root, array([0, 0, 0]))
    while stack:
        child, parent, v, n, i = stack.pop()
        lr = parent.get(r)
        vnew = (v * damping
                + (deltav[n][i] + random.uniform(-0.1, 0.1, size=3))
                * exp(lr[1] / 100))
        child.set(r, lr + vnew)
        stack.extend(pending(child, vnew))

    return r
def test():
    """ Draws a random sample tree with pylab, projected on the x-z plane
    [DEBUG]. """
    import pylab

    tree = random_branching_tree(1000, 0.05)
    r = sample_endpoints(tree)
    origin = array([0, 0, 0])

    for segment in tree:
        ep = segment.get(r)
        # The root is drawn from the origin.
        if segment.parent is None:
            ip = origin
        else:
            ip = segment.parent.get(r)
        pylab.plot([ip[0], ep[0]], [ip[2], ep[2]], lw=0.8, c='k')

    pylab.show()


if __name__ == '__main__':
    test()