#lang scheme/base
(provide add-free-variables
map-variables
bindings
binding-ref
unify
empty
variable?
)
(require scheme/match "fail.ss"
(for-syntax scheme/base))
;; store: the state threaded through unification.
;;   equivs   - list of equivalence sets; each set is a list of variables
;;              that have been declared (or unified with each other) but
;;              not yet given a value.
;;   bindings - list of `binding` structs associating a variable with a value.
(define-struct store (equivs bindings))
;; empty : -> store
;; Builds a store with no equivalence sets and no bindings.
(define (empty)
  (let ((no-equivs (list))
        (no-bindings (list)))
    (make-store no-equivs no-bindings)))
;; bindings : store -> (listof binding)
;; Returns the store's binding list when no equivalence sets remain.
;; If any equivalence set is still present, calls `fail` (imported from
;; "fail.ss"; its behavior is not visible in this file).
(define (bindings s)
  (match s
    ((struct store ('() resolved)) resolved)
    ((struct store (_ _)) (fail))))
;; make-set : any ... -> list
;; Sets are represented as plain lists; this collects the arguments into one.
(define (make-set . members)
  members)
;; set-union : list ... -> list
;; Concatenates the given sets in order. No duplicate elimination is done.
(define (set-union . sets)
  (apply append sets))
;; set-remove : list any ... -> list
;; Removes each of `els` from `set`, one after another. Each removal uses
;; `remove`, so it compares with equal? and drops only the first match.
(define (set-remove set . els)
  (foldl (lambda (el remaining) (remove el remaining))
         set
         els))
;; set-map : procedure list ... -> list
;; Applies `fn` across the given set(s), exactly as `map` does.
(define (set-map fn . sets)
  (apply map fn sets))
;; set-member? : list any -> (or/c list #f)
;; Membership test with equal?. Like `member`, the result is the tail of
;; `set` starting at the match (truthy), or #f when `el` is absent.
(define set-member?
  (lambda (set el)
    (member el set)))
;; (set-ormap proc set ...) expands directly into (ormap proc set ...).
(define-syntax-rule (set-ormap arg ...)
  (ormap arg ...))
;; variable? : any -> boolean
;; A variable is a symbol whose printed name starts with `?`, other than
;; the bare symbol `?` itself (which `wildcard?` treats as the wildcard).
;; Non-symbols and the empty symbol `||` are not variables.
(define (variable? x)
  (and (symbol? x)
       (not (eq? x '?))
       (let ((name (symbol->string x)))
         ;; Guard the length first: string-ref raises on an empty string,
         ;; which is what the empty symbol converts to.
         (and (positive? (string-length name))
              (char=? #\? (string-ref name 0))))))
;; binding: associates a single variable (`var`) with the value (`val`)
;; it has been bound to.
(define-struct binding (var val))
;; bind : var val -> binding
;; Short name for the `binding` constructor.
(define (bind var val)
  (make-binding var val))
;; binding-ref : (listof binding) symbol -> any
;; Walks `set` from the front and returns the value of the first binding
;; whose variable is eq? to `ref-var`. Returns #f when no binding matches,
;; so a variable bound to #f cannot be told apart from an unbound one.
(define (binding-ref set ref-var)
  (let scan ((remaining set))
    (cond
      ((null? remaining) #f)
      (else
       (match (car remaining)
         ((struct binding (bound-var bound-val))
          (if (eq? ref-var bound-var)
              bound-val
              (scan (cdr remaining)))))))))
;; store-join-es : store set set -> store
;; Returns a store in which the equivalence sets `es1` and `es2` have been
;; removed from the equivs list and replaced by one merged set placed at
;; the front. The bindings are carried over unchanged.
(define (store-join-es s es1 es2)
  (match-let (((struct store (classes bound)) s))
    (let ((merged (set-union es1 es2))
          (others (set-remove classes es1 es2)))
      (make-store (set-union (make-set merged) others)
                  bound))))
;; store-bind : store set any -> store
;; Gives every variable in the equivalence set `es1` the value `data`:
;; `es1` is dropped from the equivs list and one binding per member is
;; prepended to the bindings list.
(define (store-bind s es1 data)
  (match-let (((struct store (classes bound)) s))
    (let ((remaining (set-remove classes es1))
          (fresh (set-map (lambda (member-var) (bind member-var data))
                          es1)))
      (make-store remaining
                  (set-union fresh bound)))))
;; store-var-es : store symbol -> (or/c set #f)
;; Finds the equivalence set in the store that contains `var`.
;; Returns that set, or #f when no set contains it.
(define (store-var-es s var)
  (let ((classes (store-equivs s)))
    (set-ormap (lambda (class)
                 (if (set-member? class var) class #f))
               classes)))
;; store-var-ref : store symbol -> any
;; Looks `var` up in the store's bindings; #f when it has no binding.
(define (store-var-ref s var)
  (let ((bound (store-bindings s)))
    (binding-ref bound var)))
;; store-defined? : store symbol -> any
;; Truthy when `var` is known to the store: the result is its equivalence
;; set if it has one, otherwise whatever `store-var-ref` returns (its bound
;; value, or #f when it is unknown).
(define (store-defined? s var)
  (cond
    ((store-var-es s var))
    (else (store-var-ref s var))))
;; store-declare : store symbol ... -> store
;; Declares each of `vars` as a fresh, unbound variable by adding a
;; singleton equivalence set for it in front of the existing sets (in
;; argument order). The bindings are carried over unchanged.
;;
;; Raises (error 'defined ...) when a variable is already defined in `s`
;; or when the same variable appears more than once in `vars`; without
;; the second check a repeated argument would produce two separate
;; equivalence sets for one variable.
(define (store-declare s . vars)
  (match s
    ((struct store (es det))
     ;; Validate left to right before building anything.
     (let check ((pending vars) (seen '()))
       (unless (null? pending)
         (let ((var (car pending)))
           (when (or (memq var seen) (store-defined? s var))
             (error 'defined "~s" var))
           (check (cdr pending) (cons var seen)))))
     (make-store (set-union (map (lambda (var) (make-set var)) vars)
                            es)
                 det))))
;; make-unify-error : any any -> (any any -> none)
;; Returns a two-argument procedure that, when called with y1 and y2,
;; raises a 'contradiction error whose message shows both the outer pair
;; (x1, x2) and the inner pair (y1, y2).
(define ((make-unify-error x1 x2) y1 y2)
  (error 'contradiction
         "~a=~a -> ~a=~a"
         x1 x2 y1 y2))
;; varref/value : store any -> (values (or/c set #f) any)
;; Resolves one side of a unification against the store.
;;   - For a non-variable: returns (values #f var/val).
;;   - For a variable: returns its equivalence set (or #f) and its bound
;;     value (or #f). Raises (error 'undefined ...) when both are #f.
(define (varref/value s var/val)
  (cond
    ((not (variable? var/val)) (values #f var/val))
    (else
     (let ((class (store-var-es s var/val))
           (value (store-var-ref s var/val)))
       (when (and (not class) (not value))
         (error 'undefined "~a" var/val))
       (values class value)))))
;; (do/ state (fn arg ...) ...)
;; Threads a value through a sequence of calls: each step is invoked as
;; (fn current arg ...) and its result becomes the current value for the
;; next step. With no steps, the result is `state` itself.
(define-syntax-rule (do/ state (fn arg ...) ...)
  (let* ((acc state)
         (acc (fn acc arg ...))
         ...)
    acc))
;; wildcard? : any -> boolean
;; True only for the bare symbol `?`.
(define wildcard?
  (lambda (x)
    (eq? '? x)))
;; unify : store any any -> store
;; Unifies `x1` with `x2` under store `s` and returns the resulting store.
;;
;; Each side is resolved with `varref/value` (which raises 'undefined for
;; a variable the store does not know). Cases, in priority order:
;;   1. either side is the wildcard `?`          -> store unchanged
;;   2. both sides are in the same equiv set     -> store unchanged
;;   3. both sides are in (different) equiv sets -> sets are merged
;;   4. one side is in an equiv set, the other has a value
;;                                               -> the set is bound to it
;;   5. both values are lists                    -> `fail` is called on a
;;      length mismatch, then elements are unified pairwise, threading
;;      the store through
;;   6. otherwise the values must be equal?, else `fail` is called with them.
;; `fail` comes from "fail.ss"; its behavior is not visible in this file.
;;
;; Every (x1 x2) pair seen during one top-level call is recorded, and a
;; pair that was already seen returns the current store unchanged.
(define (unify s x1 x2)
  (define visited (make-set))
  (define (step a b store)
    (define pair (list a b))
    (cond
      ((set-member? visited pair) store)
      (else
       (set! visited (set-union (make-set pair) visited))
       (let-values (((class-a val-a) (varref/value store a))
                    ((class-b val-b) (varref/value store b)))
         (cond
           ((or (wildcard? a) (wildcard? b)) store)
           ((and class-a (eq? class-a class-b)) store)
           ((and class-a class-b) (store-join-es store class-a class-b))
           ((and class-a val-b) (store-bind store class-a val-b))
           ((and class-b val-a) (store-bind store class-b val-a))
           ((and (list? val-a) (list? val-b))
            (unless (= (length val-a) (length val-b)) (fail))
            ;; foldl calls (step elem-a elem-b acc), matching step's
            ;; parameter order.
            (foldl step store val-a val-b))
           ((equal? val-a val-b) store)
           (else (fail val-a val-b)))))))
  (step x1 x2 s))
;; swap : (a b -> c) -> (b a -> c)
;; Returns a two-argument procedure that calls `fn` with its arguments
;; in reverse order.
(define ((swap fn) x y)
  (fn y x))
;; add-free-variables : store any -> store
;; Walks `expr` (car before cdr) and declares every variable it meets via
;; `store-declare`, threading the store through. Anything that is neither
;; a variable nor a pair leaves the store unchanged.
;; NOTE(review): `store-declare` raises 'defined for a variable the store
;; already knows, so an `expr` mentioning the same variable twice raises
;; here. Confirm that is intended.
(define (add-free-variables s expr)
  (cond
    ((variable? expr) (store-declare s expr))
    ((pair? expr)
     (let ((after-head (add-free-variables s (car expr))))
       (add-free-variables after-head (cdr expr))))
    (else s)))
;; map-variables : (symbol -> any) any -> any
;; Rebuilds `expr` with every variable replaced by (fn variable).
;; Pairs are rebuilt recursively; every other datum is returned as is.
(define (map-variables fn expr)
  (let walk ((node expr))
    (cond
      ((variable? node) (fn node))
      ((pair? node)
       (let* ((head (walk (car node)))
              (tail (walk (cdr node))))
         (cons head tail)))
      (else node))))
;; deref : (or/c store (listof binding)) (listof symbol) -> (values any ...)
;; Looks up each variable in `lst` and returns the results as multiple
;; values, in order; a variable with no binding yields #f.
;;
;; `s` may be a raw list of bindings (what `binding-ref` needs) or a store
;; struct, whose binding list is then used. Elsewhere in this file `s`
;; always names a store, and passing one straight to `binding-ref` would
;; fail on `car`.
(define (deref s lst)
  (let ((bound (if (store? s) (store-bindings s) s)))
    (apply values
           (map (lambda (v) (binding-ref bound v))
                lst))))