# Source code for zss.compare

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#Authors: Tim Henderson and Steve Johnson
#Email: tim.tadh@gmail.com, steve@steveasleep.com
#For licensing see the LICENSE file in the top level directory.

from __future__ import absolute_import
from six.moves import range

import collections

# ``zeros(dim, type)`` builds a zero-filled 2-D table.  numpy's
# implementation is used when numpy can be imported; otherwise a nested-list
# stand-in with the same call shape is defined.
try:
    import numpy as np
except ImportError:
    def py_zeros(dim, pytype):
        """Return a ``dim[0]`` x ``dim[1]`` list of lists filled with ``pytype()``.

        :param dim: A length-2 sequence ``(rows, columns)``.
        :param pytype: A zero-argument callable producing the fill value
            (e.g. ``float`` gives ``0.0``).
        """
        assert len(dim) == 2
        row_count, col_count = dim[0], dim[1]
        return [[pytype() for _ in range(col_count)]
                for _ in range(row_count)]
    zeros = py_zeros
else:
    zeros = np.zeros

# Pick a label distance function: prefer ``editdist``, then ``editdistance``,
# and fall back to a plain equal / not-equal comparison when neither package
# can be imported.
try:
    from editdist import distance as strdist
except ImportError:
    try:
        from editdistance import eval as strdist
    except ImportError:
        def strdist(a, b):
            """Return 0 when ``a == b`` and 1 otherwise."""
            return 0 if a == b else 1

from zss.simple_tree import Node


class AnnotatedTree(object):
    """A tree annotated with the bookkeeping used by the distance computation.

    After construction the instance exposes:

    * ``nodes`` -- the nodes of the tree in post-order,
    * ``ids`` -- for each entry of ``nodes``, the id assigned to that node
      during the initial traversal,
    * ``lmds`` -- for each entry of ``nodes``, the post-order index of its
      left-most leaf descendant (a leaf is its own left-most descendant),
    * ``keyroots`` -- the sorted post-order indices ``k`` for which no
      ``k' > k`` has ``lmd(k) == lmd(k')``.

    :param root: The root node of the tree.
    :param get_children: A function ``get_children(node)`` returning an
        iterable of the node's children, in left-to-right order.  Any
        iterable is accepted (list, tuple, generator, ...); it is consumed
        exactly once per node.
    """

    def __init__(self, root, get_children):
        self.get_children = get_children

        self.root = root
        self.nodes = list()  # a post-order enumeration of the nodes in the tree
        self.ids = list()    # a matching list of ids
        self.lmds = list()   # left most descendents
        self.keyroots = None
            # k and k' are nodes specified in the post-order enumeration.
            # keyroots = {k | there exists no k'>k such that lmd(k) == lmd(k')}
            # see paper for more on keyroots

        # Pass 1: depth-first walk that assigns an id to every node and
        # records, for each node, the ids of its ancestors (nearest first).
        # Children are pushed left-to-right and therefore popped
        # right-to-left, so ``pstack`` ends up in reverse post-order.
        stack = [(root, collections.deque())]
        pstack = list()
        j = 0
        while stack:
            n, anc = stack.pop()
            nid = j
            # Leaf-ness is decided by whether the iteration yields anything,
            # not by the truth value of the returned object, so generators
            # and other non-sequence iterables are classified correctly.
            is_leaf = True
            for c in self.get_children(n):
                is_leaf = False
                a = collections.deque(anc)
                a.appendleft(nid)
                stack.append((c, a))
            pstack.append((n, nid, anc, is_leaf))
            j += 1

        # Pass 2: consume ``pstack`` from the end, i.e. in post-order, and
        # fill in the left-most descendants and the keyroots.
        lmds = dict()      # node id -> post-order index of left-most leaf
        keyroots = dict()  # lmd -> highest post-order index sharing that lmd
        i = 0
        while pstack:
            n, nid, anc, is_leaf = pstack.pop()
            self.nodes.append(n)
            self.ids.append(nid)
            if is_leaf:
                lmd = i
                # Propagate this leaf upwards until reaching an ancestor
                # that already has a left-most descendant recorded.
                for a in anc:
                    if a in lmds:
                        break
                    lmds[a] = i
            else:
                # The left-most leaf below this node was visited earlier in
                # post-order and has already registered itself here.
                lmd = lmds[nid]
            self.lmds.append(lmd)
            keyroots[lmd] = i
            i += 1
        self.keyroots = sorted(keyroots.values())


class Operation(object):
    """
    Record of a single edit operation.

    ``type`` holds one of the class-level codes ``remove``, ``insert``,
    ``update`` or ``match``.  ``arg1`` and ``arg2`` hold whatever the creator
    supplied as the operands of the operation; both default to ``None``.
    """
    remove = 0
    insert = 1
    update = 2
    match = 3

    def __init__(self, op, arg1=None, arg2=None):
        self.type = op
        self.arg1 = arg1
        self.arg2 = arg2

    def __repr__(self):
        # Any type code other than remove/insert/update is shown as a match.
        if self.type == self.remove:
            return '<Operation Remove>'
        elif self.type == self.insert:
            return '<Operation Insert>'
        elif self.type == self.update:
            return '<Operation Update>'
        else:
            return '<Operation Match>'

    def __eq__(self, other):
        """Compare type and both arguments.

        ``None`` compares unequal; any other non-``Operation`` object raises
        ``TypeError``.
        """
        if other is None: return False
        if not isinstance(other, Operation):
            raise TypeError("Must compare against type Operation")
        return self.type == other.type and self.arg1 == other.arg1 and \
            self.arg2 == other.arg2

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this two
        # equal operations would also compare as not-equal there.
        return not self.__eq__(other)


# Module-level shorthands for the operation type codes defined on Operation.
REMOVE, INSERT, UPDATE, MATCH = (
    Operation.remove,
    Operation.insert,
    Operation.update,
    Operation.match,
)


def simple_distance(A, B, get_children=Node.get_children,
                    get_label=Node.get_label, label_dist=strdist,
                    return_operations=False):
    """Computes the exact tree edit distance between trees A and B.

    Use this function if both of these things are true:

    * The cost to insert a node is equivalent to ``label_dist('', new_label)``
    * The cost to remove a node is equivalent to ``label_dist(new_label, '')``

    Otherwise, use :py:func:`zss.distance` instead.

    :param A: The root of a tree.
    :param B: The root of a tree.

    :param get_children:
        A function ``get_children(node) == [node children]``.  Defaults to
        :py:func:`zss.Node.get_children`.

    :param get_label:
        A function ``get_label(node) == 'node label'``. All labels are assumed
        to be strings at this time. Defaults to :py:func:`zss.Node.get_label`.

    :param label_dist:
        A function
        ``label_dist(get_label(node1), get_label(node2)) >= 0``.
        This function should take the output of ``get_label(node)`` and return
        an integer greater or equal to 0 representing how many edits to
        transform the label of ``node1`` into the label of ``node2``. By
        default, this is string edit distance (if available). 0 indicates that
        the labels are the same. A number N represents that it takes N changes
        to transform one label into the other.

    :param return_operations:
        if True, return a tuple (cost, operations) where operations is a list
        of the operations to transform A into B.

    :return: The distance, a number in [0, inf+); or the tuple described
        under ``return_operations``.
    """
    # Insert and remove costs are expressed as label edits against the empty
    # string, and the whole computation is delegated to ``distance``.
    return distance(
        A, B, get_children,
        insert_cost=lambda node: label_dist('', get_label(node)),
        remove_cost=lambda node: label_dist(get_label(node), ''),
        update_cost=lambda a, b: label_dist(get_label(a), get_label(b)),
        return_operations=return_operations
    )


def distance(A, B, get_children, insert_cost, remove_cost, update_cost,
             return_operations=False):
    '''Computes the exact tree edit distance between trees A and B with a
    richer API than :py:func:`zss.simple_distance`.

    Use this function if either of these things are true:

    * The cost to insert a node is **not** equivalent to the cost of changing
      an empty node to have the new node's label
    * The cost to remove a node is **not** equivalent to the cost of changing
      it to a node with an empty label

    Otherwise, use :py:func:`zss.simple_distance`.

    :param A: The root of a tree.
    :param B: The root of a tree.

    :param get_children:
        A function ``get_children(node) == [node children]``.

    :param insert_cost:
        A function ``insert_cost(node) == cost to insert node >= 0``.

    :param remove_cost:
        A function ``remove_cost(node) == cost to remove node >= 0``.

    :param update_cost:
        A function ``update_cost(a, b) == cost to change a into b >= 0``.

    :param return_operations:
        if True, return a tuple (cost, operations) where operations is a list
        of the operations to transform A into B.

    :return: The distance, a number in [0, inf+) read from a table of
        floats; or the tuple described under ``return_operations``.
    '''
    A, B = AnnotatedTree(A, get_children), AnnotatedTree(B, get_children)
    size_a = len(A.nodes)
    size_b = len(B.nodes)
    # treedists[i][j] / operations[i][j]: cost / edit script between the
    # subtree rooted at post-order node i of A and the one at node j of B.
    treedists = zeros((size_a, size_b), float)
    operations = [[[] for _ in range(size_b)] for _ in range(size_a)]

    def treedist(i, j):
        # Fills treedists/operations for keyroots i (in A) and j (in B) using
        # a forest-distance table ``fd`` local to this call.
        Al = A.lmds
        Bl = B.lmds
        An = A.nodes
        Bn = B.nodes

        m = i - Al[i] + 2
        n = j - Bl[j] + 2
        fd = zeros((m,n), float)
        partial_ops = [[[] for _ in range(n)] for _ in range(m)]

        # Offsets translating fd indices into post-order node indices.
        ioff = Al[i] - 1
        joff = Bl[j] - 1

        # The edit scripts of the base cases accumulate in the same way as
        # their costs: reaching the empty forest from x nodes takes x
        # removals, not just the last one (likewise for insertions).
        for x in range(1, m): # δ(l(i1)..i, θ) = δ(l(1i)..1-1, θ) + γ(v → λ)
            node = An[x+ioff]
            fd[x][0] = fd[x-1][0] + remove_cost(node)
            partial_ops[x][0] = partial_ops[x-1][0] + [Operation(REMOVE, node)]
        for y in range(1, n): # δ(θ, l(j1)..j) = δ(θ, l(j1)..j-1) + γ(λ → w)
            node = Bn[y+joff]
            fd[0][y] = fd[0][y-1] + insert_cost(node)
            partial_ops[0][y] = partial_ops[0][y-1] + \
                [Operation(INSERT, arg2=node)]

        for x in range(1, m):  # the plus one is for the xrange impl
            for y in range(1, n):
                # x+ioff in the fd table corresponds to the same node as x in
                # the treedists table (same for y and y+joff)
                node1 = An[x+ioff]
                node2 = Bn[y+joff]
                # only need to check if x is an ancestor of i
                # and y is an ancestor of j
                if Al[i] == Al[x+ioff] and Bl[j] == Bl[y+joff]:
                    #                   +-
                    #                   | δ(l(i1)..i-1, l(j1)..j) + γ(v → λ)
                    # δ(F1 , F2 ) = min-+ δ(l(i1)..i , l(j1)..j-1) + γ(λ → w)
                    #                   | δ(l(i1)..i-1, l(j1)..j-1) + γ(v → w)
                    #                   +-
                    costs = [fd[x-1][y] + remove_cost(node1),
                             fd[x][y-1] + insert_cost(node2),
                             fd[x-1][y-1] + update_cost(node1, node2)]
                    fd[x][y] = min(costs)
                    # Ties are resolved in favour of the earliest entry:
                    # remove, then insert, then update/match.
                    min_index = costs.index(fd[x][y])

                    if min_index == 0:
                        op = Operation(REMOVE, node1)
                        partial_ops[x][y] = partial_ops[x-1][y] + [op]
                    elif min_index == 1:
                        op = Operation(INSERT, arg2=node2)
                        partial_ops[x][y] = partial_ops[x][y - 1] + [op]
                    else:
                        # A zero-cost update is reported as a match.
                        op_type = MATCH if fd[x][y] == fd[x-1][y-1] else UPDATE
                        op = Operation(op_type, node1, node2)
                        partial_ops[x][y] = partial_ops[x - 1][y - 1] + [op]

                    # Both forests are whole trees here, so the result is
                    # also a tree distance and is stored for later lookups.
                    operations[x + ioff][y + joff] = partial_ops[x][y]
                    treedists[x+ioff][y+joff] = fd[x][y]
                else:
                    #                   +-
                    #                   | δ(l(i1)..i-1, l(j1)..j) + γ(v → λ)
                    # δ(F1 , F2 ) = min-+ δ(l(i1)..i , l(j1)..j-1) + γ(λ → w)
                    #                   | δ(l(i1)..l(i)-1, l(j1)..l(j)-1)
                    #                   |                     + treedist(i1,j1)
                    #                   +-
                    p = Al[x+ioff]-1-ioff
                    q = Bl[y+joff]-1-joff
                    costs = [fd[x-1][y] + remove_cost(node1),
                             fd[x][y-1] + insert_cost(node2),
                             fd[p][q] + treedists[x+ioff][y+joff]]
                    fd[x][y] = min(costs)
                    min_index = costs.index(fd[x][y])
                    if min_index == 0:
                        op = Operation(REMOVE, node1)
                        partial_ops[x][y] = partial_ops[x-1][y] + [op]
                    elif min_index == 1:
                        op = Operation(INSERT, arg2=node2)
                        partial_ops[x][y] = partial_ops[x][y-1] + [op]
                    else:
                        # Reuse the previously stored script for the pair of
                        # subtrees rooted at these two nodes.
                        partial_ops[x][y] = partial_ops[p][q] + \
                            operations[x+ioff][y+joff]

    for i in A.keyroots:
        for j in B.keyroots:
            treedist(i, j)

    # The last post-order node of each tree is its root, so the bottom-right
    # cell holds the result for the two whole trees.
    if return_operations:
        return treedists[-1][-1], operations[-1][-1]
    else:
        return treedists[-1][-1]