(module find-optimal-join mzscheme
(require (lib "etc.ss")
(lib "contract.ss")
(planet "comprehensions.ss" ("dyoo" "srfi-alias.plt" 1))
(planet "contract-utils.ss" ("cobbe" "contract-utils.plt" 3 0)))
;; make-table : nat nat -> (vectorof (vectorof number))
;; Allocates an m-row, n-column table as a vector of m distinct row
;; vectors, with every cell initialised to +inf.0.
(define (make-table m n)
  (let ([rows (make-vector m)])
    ;; Each row gets its own fresh vector so that mutating one row never
    ;; aliases another.
    (do ([r 0 (add1 r)])
        ((= r m) rows)
      (vector-set! rows r (make-vector n +inf.0)))))
;; table-ref : table nat nat -> any
;; Returns the cell at row i, column j of a vector-of-vectors table.
(define (table-ref a-table i j)
  (let ([row (vector-ref a-table i)])
    (vector-ref row j)))
;; table-set! : table nat nat any -> void
;; Destructively stores val at row i, column j of a vector-of-vectors table.
(define (table-set! a-table i j val)
  (let ([row (vector-ref a-table i)])
    (vector-set! row j val)))
;; along-diagonals : nat (nat nat -> any) -> void
;; Calls (f i j) for every 0 <= i <= j < N, one diagonal at a time:
;; first all cells with j - i = 0, then j - i = 1, and so on up to N - 1.
;; Within a diagonal, i increases. This order guarantees that every
;; strictly shorter range [i, j] has been visited before a longer one.
(define (along-diagonals N f)
  (do ([offset 0 (add1 offset)])
      ((>= offset N))
    (do ([row 0 (add1 row)])
        ((>= row (- N offset)))
      (f row (+ row offset)))))
;; compute-recurrence : (vectorof any) (any -> number) (number number -> number)
;;                      -> table
;; Bottom-up interval DP over forest-vec.
;;   cost[i][i] = (initial-cost-f (vector-ref forest-vec i))
;;   cost[i][j] = min over i <= k < j of (cost+ cost[i][k] cost[k+1][j])
;; It also records split[i][j], the first k (smallest index) attaining the
;; minimum. Ties keep the earlier k because the comparison is a strict <.
;; Only the split table is returned. Cells with i >= j in the split table,
;; and all cells below the diagonal, stay at their +inf.0 initial value.
(define (compute-recurrence forest-vec initial-cost-f cost+)
  (let* ([size (vector-length forest-vec)]
         [best-cost (make-table size size)]
         [best-split (make-table size size)])
    ;; Fill one cell. along-diagonals guarantees every shorter subrange is
    ;; already known when this runs.
    (define (fill-cell! lo hi)
      (if (= lo hi)
          (table-set! best-cost lo lo
                      (initial-cost-f (vector-ref forest-vec lo)))
          (let try-split ([k lo])
            (when (< k hi)
              (let ([candidate (cost+ (table-ref best-cost lo k)
                                      (table-ref best-cost (add1 k) hi))])
                ;; best-cost starts at +inf.0, so the first candidate
                ;; always wins.
                (when (< candidate (table-ref best-cost lo hi))
                  (table-set! best-cost lo hi candidate)
                  (table-set! best-split lo hi k)))
              (try-split (add1 k))))))
    (along-diagonals size fill-cell!)
    best-split))
;; concatenate-with-K : (vectorof any) (any any -> any) table -> any
;; Rebuilds a single value from forest-vec by recursively splitting the
;; range [0, len-1] at the index stored in K. For a range [lo, hi] with
;; lo < hi, it joins the result for [lo, K[lo][hi]] with the result for
;; [K[lo][hi]+1, hi] via join-f. The left side is built before the right,
;; and element order is preserved. A one-element range yields that element.
(define (concatenate-with-K forest-vec join-f K)
  (define (build lo hi)
    (if (= lo hi)
        (vector-ref forest-vec lo)
        (let ([split (table-ref K lo hi)])
          (join-f (build lo split)
                  (build (add1 split) hi)))))
  (build 0 (sub1 (vector-length forest-vec))))
;; join-forest : (nelistof any) (any any -> any) (any -> nat) -> any
;; Joins the forest pairwise, keeping element order, choosing the bracketing
;; that minimises the cost where joining two parts costs
;; 1 + max(left, right). This is the resulting depth, given that depth-f
;; supplies each element's own depth and each join adds one level.
(define (join-forest forest join-f depth-f)
  (join-forest/cost+ forest join-f depth-f
                     (lambda (left right) (add1 (max left right)))))
;; join-forest/cost+ : (nelistof any) (any any -> any) (any -> nat)
;;                     (nat nat -> nat) -> any
;; General form of join-forest. It finds the order-preserving binary
;; bracketing of forest that minimises the cost defined by initial-cost-f
;; (per element) and cost+ (per join). It then applies join-f according to
;; that bracketing and returns the result.
(define (join-forest/cost+ forest join-f initial-cost-f cost+)
  (let* ([forest-vec (list->vector forest)]
         [split-table (compute-recurrence forest-vec initial-cost-f cost+)])
    (concatenate-with-K forest-vec join-f split-table)))
;; Exported interface. Both entry points require a non-empty list.
(provide/contract
 [join-forest
  ;; forest, join-f, depth-f
  (-> (nelistof/c any/c)
      (-> any/c any/c any)
      (-> any/c natural-number/c)
      any)]
 [join-forest/cost+
  ;; forest, join-f, initial-cost-f, cost+
  (-> (nelistof/c any/c)
      (-> any/c any/c any)
      (-> any/c natural-number/c)
      (-> natural-number/c natural-number/c natural-number/c)
      any)])) ; the final paren closes module find-optimal-join