#lang racket/base
(require "label.rkt"
racket/list)
(provide (all-defined-out))
;; A suffix tree, represented by its root node (see new-suffix-tree).
;; define-struct binds make-suffix-tree, suffix-tree?, and suffix-tree-root.
(define-struct suffix-tree (root))
;; A node of the suffix tree. All fields are mutable (#:mutable), so the
;; setters set-node-up-label!, set-node-parent!, set-node-children! and
;; set-node-suffix-link! are defined alongside make-node and the accessors.
;;   up-label    - label on the edge from the parent down to this node
;;                 (the root made by new-suffix-tree gets an empty label)
;;   parent      - parent node, or #f for the root (see node-root?)
;;   children    - list of child nodes
;;   suffix-link - #f whenever a node is created in this file; presumably
;;                 set to another node by the construction code elsewhere
;;                 (NOTE(review): confirm against the builder)
(define-struct node (up-label parent children suffix-link) #:mutable)
;; new-suffix-tree : -> suffix-tree
;; Builds a tree consisting of a single root node. The root carries a
;; label made from an empty vector, has no parent, no children, and no
;; suffix link.
(define (new-suffix-tree)
  (define empty-label (make-label (make-vector 0)))
  (make-suffix-tree (make-node empty-label #f '() #f)))
;; node-root? : node -> boolean
;; A node is the root exactly when it has no parent (its parent field is #f).
(define (node-root? node)
  (not (node-parent node)))
;; node-add-leaf! : node label -> node
;; Creates a childless node with up-label LABEL and no suffix link, hangs it
;; under NODE (prepending it to NODE's children), and returns the new leaf.
(define (node-add-leaf! node label)
  (define leaf (make-node label node '() #f))
  (node-add-child! node leaf)
  leaf)
;; node-add-child! : node node -> void
;; Puts CHILD at the front of NODE's children list.
;; Only NODE is mutated; CHILD's parent field is left as it was.
(define (node-add-child! node child)
  (let ([siblings (node-children node)])
    (set-node-children! node (cons child siblings))))
;; node-remove-child! : node node -> void
;; Drops the first occurrence of CHILD (compared with eq?) from NODE's
;; children list. CHILD's own parent field is not touched.
(define (node-remove-child! node child)
  (let* ([old-children (node-children node)]
         [remaining (remq child old-children)])
    (set-node-children! node remaining)))
;; children-list is simply an alias for `list`. node-up-split! uses it to
;; build the one-element children list of the node it inserts.
(define children-list list)
;; node-leaf? : node -> boolean
;; True when NODE has no children.
(define (node-leaf? node)
  (null? (node-children node)))
;; node-find-child : node label-element -> (or/c node #f)
;; Returns the first child of NODE whose up-label begins with an element
;; that is label-element-equal? to LABEL-ELEMENT, or #f if none does.
;; Children are examined in list order. Every child is assumed to have a
;; non-empty up-label, since element 0 is read unconditionally.
(define (node-find-child node label-element)
  (findf (lambda (child)
           (label-element-equal? label-element
                                 (label-ref (node-up-label child) 0)))
         (node-children node)))
;; node-up-split! : node exact-nonnegative-integer -> node
;; Splits the edge above NODE at OFFSET by inserting a new node between NODE
;; and its parent:
;;   - the new node's up-label is the first OFFSET elements of NODE's old
;;     up-label, and NODE keeps the rest;
;;   - the new node replaces NODE among the parent's children (it is added
;;     at the front of the list);
;;   - the new node has NODE as its only child and no suffix link.
;; Returns the new node.
;;
;; Preconditions (checked):
;;   - NODE is not the root. The root has no parent to re-link.
;;   - 0 < OFFSET < (label-length (node-up-label NODE)).
;; Any other offset would leave one of the two edges with an empty label,
;; which node-find-child cannot handle (it reads element 0 of every child's
;; up-label). Such calls raise an argument error instead of corrupting the
;; tree.
(define (node-up-split! node offset)
  (let ([label (node-up-label node)]
        [parent (node-parent node)])
    (unless parent
      (raise-argument-error 'node-up-split! "a non-root node" 0 node offset))
    (unless (< 0 offset (label-length label))
      (raise-argument-error 'node-up-split!
                            "offset strictly between 0 and the node's up-label length"
                            1 node offset))
    (let* ([pre-label (sublabel label 0 offset)]
           [post-label (sublabel label offset)]
           [new-node (make-node pre-label parent (children-list node) #f)])
      ;; NODE keeps only the tail of its old edge, under the new node.
      (set-node-up-label! node post-label)
      (node-remove-child! parent node)
      (set-node-parent! node new-node)
      ;; The new node takes NODE's former place under PARENT.
      (node-add-child! parent new-node)
      new-node)))
;; node-up-splice-leaf! : node exact-nonnegative-integer label -> (values node node)
;; Splits the edge above NODE at OFFSET (via node-up-split!), then attaches a
;; new leaf labelled LEAF-LABEL under the inserted node.
;; Returns two values: the inserted split node and the new leaf.
(define (node-up-splice-leaf! node offset leaf-label)
  (define split-node (node-up-split! node offset))
  (define leaf (node-add-leaf! split-node leaf-label))
  (values split-node leaf))
;; tree-contains? : suffix-tree label -> boolean
;; True when LABEL can be followed all the way from the tree's root, ending
;; either exactly at a node or partway along an edge. False when the walk
;; hits a mismatch, whether at a node or inside an edge.
(define (tree-contains? tree label)
  (define (found . ignored) #t)
  (define (not-found . ignored) #f)
  (node-follow/k (suffix-tree-root tree)
                 label
                 found       ; matched at a node
                 found       ; matched inside an edge
                 not-found   ; mismatched at a node
                 not-found)) ; mismatched inside an edge
;; node-follow/k : node label k1 k2 k3 k4 -> any
;; Walks down from NODE, matching a copy of ORIGINAL-LABEL (made with
;; label-copy) element by element. The walk ends by tail-calling exactly
;; one continuation:
;;   (matched-at-node/k n)            the whole label was consumed exactly at node n
;;   (matched-in-edge/k n i)          the label ran out after i elements of n's up-label
;;   (mismatched-at-node/k n lbl pos) no child of n starts with element pos of lbl
;;   (mismatched-in-edge/k n i lbl pos)
;;                                    element i of n's up-label differs from element
;;                                    pos of lbl
;; In the mismatch cases, lbl is the copied label, not ORIGINAL-LABEL itself.
(define (node-follow/k node original-label
                       matched-at-node/k
                       matched-in-edge/k
                       mismatched-at-node/k
                       mismatched-in-edge/k)
  (define target (label-copy original-label))

  ;; Standing at CURRENT, with CONSUMED elements of TARGET already matched.
  (define (at-node current consumed)
    (cond
      [(= (label-length target) consumed)
       (matched-at-node/k current)]
      [(node-find-child current (label-ref target consumed))
       => (lambda (child) (along-edge child consumed))]
      [else
       (mismatched-at-node/k current target consumed)]))

  ;; Comparing TARGET, starting at index START, against EDGE-NODE's up-label.
  ;; Element 0 is compared again even though node-find-child already checked it.
  (define (along-edge edge-node start)
    (define edge (node-up-label edge-node))
    (let scan ([i 0])
      (define pos (+ start i))
      (cond
        [(= i (label-length edge))
         (at-node edge-node pos)]
        [(= pos (label-length target))
         (matched-in-edge/k edge-node i)]
        [(label-element-equal? (label-ref edge i) (label-ref target pos))
         (scan (add1 i))]
        [else
         (mismatched-in-edge/k edge-node i target pos)])))

  (at-node node 0))
;; node-position-at-end? : node exact-nonnegative-integer -> boolean
;; Delegates to label-ref-at-end? (from label.rkt) on NODE's up-label.
;; Presumably true when OFFSET is at the end of that edge label
;; (NOTE(review): confirm against label-ref-at-end?'s definition).
(define (node-position-at-end? node offset)
  (let ([edge (node-up-label node)])
    (label-ref-at-end? edge offset)))