(module infer-unit mzscheme
(require "planet-requires.ss" (lib "unit.ss") "signatures.ss"
"type-def-structs.ss")
(require-libs)
(define-unit infer@
(import subst^ type-structs^ type-equal^)
(export infer^)
;; fv : Type -> Set
;; Collects the names of the free type variables of T into a set.  Every set
;; returned is built with set:make-equal / set:list->equal, so sets coming
;; from different branches combine consistently under set:union.
;; Binders are respected: (mu v t) removes v, (poly vars t) removes vars.
;; Base types, value types, dynamic and univ contain no variables.
;; Any type form not listed below falls through the match and signals an error.
(define (fv t)
  ;; Union of the free variables of every type in TS.
  ;; The seed is an *equal*-based empty set, matching every other branch
  ;; (it used to be set:make-eq, so fv of an empty union/values type
  ;; returned a set with a different equality than everything else).
  (define (fv-of-list ts)
    (foldl (lambda (e acc) (set:union (fv e) acc)) (set:make-equal) ts))
  (match t
    [($ tvar v) (set:make-equal v)]
    ;; Arrow: argument types, result type and optional rest type.
    ;; The two trailing (then/else effect) fields are not scanned.
    [($ arr ts t rest _ _)
     (foldr set:union
            (if rest
                (set:union (fv rest) (fv t))
                (fv t))
            (map fv ts))]
    [($ funty ts) (foldr set:union (set:make-equal) (map fv ts))]
    [($ vec t) (fv t)]
    [($ mu v t) (set:remove v (fv t))]
    [($ poly vars t) (set:difference (fv t) (set:list->equal vars))]
    [($ union elems) (fv-of-list (set:elements elems))]
    [($ values-ty elems) (fv-of-list elems)]
    [($ pair-ty a d) (set:union (fv a) (fv d))]
    ;; Only the field types are scanned; name and parent are ignored.
    [($ struct-ty name parent flds)
     (foldr set:union (set:make-equal) (map fv flds))]
    [(or ($ base-type _)
         ($ value _)
         ($ dynamic)
         ($ univ))
     (set:make-equal)]))
;; fv/list : Type -> (listof Var)
;; The free type variables of TYPE (as computed by fv), returned as a plain
;; list via set:elements rather than as a set.
(define (fv/list type)
  (let ([free-vars (fv type)])
    (set:elements free-vars)))
;; unfold : Type -> Type
;; Unrolls one level of a recursive type.  For (mu v b) it calls
;; (subst v <the whole mu type> b), presumably replacing v by the mu type
;; inside the body.  Any non-mu argument signals an "unfold failed" error
;; that also reports (type? t) and the offending value.
(define (unfold t)
  (match t
    [($ mu bound-var body) (subst bound-var t body)]
    [not-a-mu (error "unfold failed" (type? not-a-mu) not-a-mu)]))
;; unify : (listof (list Type Type)) -> (U (listof (list Var Type)) #f)
;; Entry point: solve the constraint list CL starting from an empty
;; binding list.  Returns the bindings, or #f if unification fails.
(define (unify cl)
  (let ([no-bindings-yet (list)])
    (unify/acc cl no-bindings-yet)))
;; unify1 : Type Type -> (U (listof (list Var Type)) #f)
;; Unify exactly two types: shorthand for calling unify with the single
;; constraint (t1 t2).
(define (unify1 t1 t2)
  (let ([only-constraint (list t1 t2)])
    (unify (list only-constraint))))
;; unify/acc : (listof (list Type Type)) (listof (list Var Type))
;;             -> (U (listof (list Var Type)) #f)
;; First-order unification.  CONSTRAINT-LIST is a list of two-element lists
;; (t1 t2), each meaning "t1 must equal t2".  ACC accumulates the solution as
;; (var type) bindings, newest first.  Returns the final binding list on
;; success, or #f when a constraint cannot be satisfied, including the shapes
;; this unifier does not support (differently-shaped unions, poly types, and
;; anything not matched below).
(define (unify/acc constraint-list acc)
  ;; Does type variable V occur free in type T?  Binding V to such a T
  ;; would create a cyclic (infinite) substitution.
  (define (occurs? v t)
    (and (member v (fv/list t)) #t))
  ;; Record the binding V := T, apply it to the remaining constraints REST
  ;; and continue.  Fails (#f) on an occurs-check violation.
  (define (bind v t rest)
    (if (occurs? v t)
        #f
        (unify/acc (map (lambda (p) (map (lambda (e) (subst v t e)) p)) rest)
                   (cons (list v t) acc))))
  ;; Non-linear patterns (a variable repeated in one pattern) compare with
  ;; type-equal? instead of the default equality.
  (parameterize ([match-equality-test type-equal?])
    (match constraint-list
      [() acc]
      ;; Identical types: nothing to solve.
      [((t t) rest ...) (unify/acc rest acc)]
      [((($ tvar v) t) rest ...) (bind v t rest)]
      [((t ($ tvar v)) rest ...) (bind v t rest)]
      ;; Function types: same number of arrows, same arity per arrow, rest
      ;; arguments present at the same positions, and equivalent then/else
      ;; effects; then unify arguments, rest types and results pointwise.
      [((($ funty (($ arr ts t t-rest t-thn-eff t-els-eff) ...))
         ($ funty (($ arr ss s s-rest s-thn-eff s-els-eff) ...)))
        rest ...)
       (let ()
         (define (compatible-rest t-rest s-rest)
           (andmap (lambda (x y) (or (and x y) (and (not x) (not y)))) t-rest s-rest))
         ;; Without this check flatten/zip would pair the wrong arguments
         ;; (or map would receive lists of different lengths).
         (define (same-arities? ts ss)
           (andmap (lambda (a b) (= (length a) (length b))) ts ss))
         (define (flatten/zip x y) (map list (apply append x) (apply append y)))
         (if (and (= (length ts) (length ss))
                  (same-arities? ts ss)
                  (compatible-rest t-rest s-rest)
                  (equiv? t-thn-eff s-thn-eff)
                  (equiv? t-els-eff s-els-eff))
             (let ([ret-constraints (map list t s)]
                   [rest-constraints (map list (filter values t-rest) (filter values s-rest))]
                   [arg-constraints (flatten/zip ts ss)])
               (unify/acc (append arg-constraints rest-constraints ret-constraints rest) acc))
             #f))]
      [((($ vec t) ($ vec s)) rest ...) (unify/acc (cons (list t s) rest) acc)]
      [((($ pair-ty t1 t2) ($ pair-ty s1 s2)) rest ...)
       (unify/acc (list* (list t1 s1) (list t2 s2) rest) acc)]
      ;; dynamic on either side imposes no constraint.
      [((or (_ ($ dynamic)) (($ dynamic) _)) rest ...) (unify/acc rest acc)]
      ;; Same struct name and parent (non-linear match): unify field types.
      [((($ struct-ty nm p elems) ($ struct-ty nm p elems*)) rest ...)
       (unify/acc (append rest (map list elems elems*)) acc)]
      ;; Unions of equal size are unified element-by-element.
      ;; NOTE(review): the pairing follows set:elements order, which is not
      ;; guaranteed to line up corresponding members.
      [((($ union e1) ($ union e2)) rest ...)
       (let ([l1 (set:elements e1)]
             [l2 (set:elements e2)])
         (and (= (length l1) (length l2))
              (unify/acc (append (map list l1 l2) rest) acc)))]
      [((or (($ union _) _) (_ ($ union _))) rest ...)
       (printf "FIXME: union type ~n~a~n---------~n~a~n in unifier~n"
               (caar constraint-list)
               (cadar constraint-list))
       #f]
      ;; Recursive types: two mus go through rename (defined elsewhere);
      ;; a mu against anything else is unfolded one level.  These three
      ;; clauses cover every constraint that contains a mu.
      [(((? mu? s) (? mu? t)) rest ...)
       (unify/acc (cons (rename s t) rest) acc)]
      [((t (? mu? s)) rest ...) (unify/acc (cons (list t (unfold s)) rest) acc)]
      [(((? mu? s) t) rest ...) (unify/acc (cons (list (unfold s) t) rest) acc)]
      ;; Polymorphic types are not supported by this unifier.
      [((or (($ poly a b) _) (_ ($ poly a b))) rest ...)
       #f]
      [else #f])))
)
(provide (all-defined))
)