# Source code for zss.simple_tree

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#Author: Tim Henderson
#For licensing see the LICENSE file in the top level directory.

from __future__ import absolute_import

import collections

class Node(object):
    """
    A simple node object that can be used to construct trees to be used with
    :py:func:`zss.distance`.

    Two nodes compare equal when their labels are equal; children are not
    taken into account.

    Example: ::

        Node("f")
            .addkid(Node("a")
                .addkid(Node("h"))
                .addkid(Node("c")
                    .addkid(Node("l"))))
            .addkid(Node("e"))
    """

    def __init__(self, label, children=None):
        """
        :param label: the label of this node.
        :param children: optional children of this node. A truthy value is
            stored as-is (not copied); ``None`` or an empty value is replaced
            by a fresh empty list.
        """
        self.label = label
        self.children = children or list()

    @staticmethod
    def get_children(node):
        """
        Default value of ``get_children`` argument of :py:func:`zss.distance`.

        :returns: ``node.children``.
        """
        return node.children

    @staticmethod
    def get_label(node):
        """
        Default value of ``get_label`` argument of :py:func:`zss.distance`.

        :returns: ``node.label``.
        """
        return node.label

    def addkid(self, node, before=False):
        """
        Add the given node as a child of this node.

        :param node: the child to add.
        :param before: if true, insert ``node`` as the first child instead of
            appending it as the last one.
        :returns: ``self``, so calls can be chained.
        """
        if before:
            self.children.insert(0, node)
        else:
            self.children.append(node)
        return self

    def get(self, label):
        """
        Find a node in this subtree (including this node) by label.

        The search is depth-first pre-order: this node first, then each child
        subtree in order. Labels of any type that support ``==`` are accepted.

        :returns: the first node in pre-order whose label equals ``label``,
            or ``None`` if no node in this subtree has that label.
        """
        # An explicit stack visits each node once and avoids Python's
        # recursion limit on deep trees.
        stack = [self]
        while stack:
            node = stack.pop()
            if node.label == label:
                return node
            # Push children reversed so the first child is popped next,
            # which keeps the visiting order pre-order.
            stack.extend(reversed(node.children))
        return None

    def iter(self):
        """
        Iterate over this node and all of its descendants in breadth-first
        (level) order, starting with this node.
        """
        queue = collections.deque([self])
        while queue:
            node = queue.popleft()
            # Children are enqueued before ``node`` is yielded, so changes the
            # caller makes to ``node.children`` during iteration do not affect
            # the traversal.
            queue.extend(node.children)
            yield node

    def __contains__(self, b):
        """
        Test whether any node in this subtree has the given label.

        :param b: a ``str`` label, or a ``Node`` (or any object with a
            ``label`` attribute) whose label is searched for.
        :returns: ``True`` if a node with a matching label exists.
        :raises TypeError: if ``b`` is neither a ``str`` nor has a ``label``.
        """
        if isinstance(b, str):
            target = b
        elif hasattr(b, "label"):
            target = b.label
        else:
            raise TypeError("Object %s is not of type str or Node" % repr(b))
        # Iterative traversal: no recursion limit, stops at the first match.
        return any(node.label == target for node in self.iter())

    # NOTE: __eq__ is defined without __hash__, so on Python 3 Node instances
    # are unhashable (cannot be set members or dict keys).
    def __eq__(self, b):
        """
        Compare labels with another ``Node``.

        :returns: ``False`` when ``b`` is ``None``; otherwise whether the
            labels are equal.
        :raises TypeError: if ``b`` is neither ``None`` nor a ``Node``.
        """
        if b is None:
            return False
        if not isinstance(b, Node):
            raise TypeError("Must compare against type Node")
        return self.label == b.label

    def __ne__(self, b):
        """Negation of :py:meth:`__eq__` (needed explicitly on Python 2)."""
        return not self.__eq__(b)

    def __repr__(self):
        # Splice the label in before the closing ">" of the default object
        # repr. The label is wrapped in a 1-tuple so that tuple labels are
        # formatted as a single value instead of being unpacked by "%".
        return super(Node, self).__repr__()[:-1] + " %s>" % (self.label,)

    def __str__(self):
        """
        Render this subtree one node per line in pre-order, each line being
        ``"<number of children>:<label>"``.
        """
        lines = ["%d:%s" % (len(self.children), self.label)]
        lines.extend(str(c) for c in self.children)
        return '\n'.join(lines)