(module subst-unit mzscheme
(require "planet-requires.ss" (lib "trace.ss") (lib "unit.ss")
"signatures.ss"
"type-def-structs.ss")
(require-libs)
(provide (all-defined))
(define-unit subst@
  (import type-structs^)
  (export subst^)

  ;; subst : symbol Type Type -> Type
  ;; Replaces every free occurrence of the type variable named `var` in `ty`
  ;; with `replacement`, rebuilding the type structurally.
  ;; - A `poly` or `mu` binder that binds `var` shadows it; such a type is
  ;;   returned unchanged.
  ;; - Any other `poly`/`mu` binder first has its bound variables renamed to
  ;;   fresh gensyms, so free type variables of `replacement` cannot be
  ;;   captured by the binder.
  ;; - Atomic types (base types, singleton values, dynamic, univ) are returned
  ;;   as-is.
  ;; Signals an error on any type form not handled below.
  (define (subst var replacement ty)
    (let ([sb (lambda (e) (subst var replacement e))])
      (match ty
        [($ tvar v) (if (eq? v var) replacement ty)]
        [($ funty ((? arr? elems) ...))
         (make-funty
          (map (match-lambda
                 [($ arr ins out rest thn-eff els-eff)
                  ;; `rest` is presumably #f (no rest argument) or a type --
                  ;; confirm.  It must be substituted like the other argument
                  ;; types, while #f is preserved as-is.
                  ;; NOTE(review): thn-eff / els-eff are passed through
                  ;; unchanged; if effects can mention types they may need
                  ;; substitution too -- confirm.
                  (make-arr (map sb ins)
                            (sb out)
                            (and rest (sb rest))
                            thn-eff
                            els-eff)])
               elems))]
        [($ vec t) (make-vec (sb t))]
        [($ pred-ty t) (make-pred-ty (sb t))]
        [($ poly vs t)
         (if (memq var vs)
             ;; `var` is rebound here, so it has no free occurrences inside.
             ty
             ;; Rename the bound variables to fresh names before substituting,
             ;; so `replacement` cannot be captured by this binder.
             (let* ([fresh-vs (map gensym vs)]
                    [fresh-t (subst-all (map list vs (map make-tvar fresh-vs)) t)])
               (make-poly fresh-vs (sb fresh-t))))]
        [($ mu v t)
         (if (eq? v var)
             ;; `var` is rebound by this recursive type; nothing to replace.
             ty
             ;; Same capture-avoiding rename as for `poly`, for a single binder.
             (let* ([v* (gensym v)]
                    [t* (subst v (make-tvar v*) t)])
               (make-mu v* (sb t*))))]
        [($ struct-ty nm par elems)
         (make-struct-ty nm par (map sb elems))]
        [($ union elems)
         (make-union (set:map sb elems))]
        [($ values-ty vs) (make-values-ty (map sb vs))]
        [($ pair-ty a d) (make-pair-ty (sb a) (sb d))]
        [(or ($ base-type _)
             ($ value _)
             ($ dynamic)
             ($ univ))
         ty]
        ;; Previously `(car ty)`, which crashed with a misleading "expects a
        ;; pair" message; report the unhandled type form explicitly instead.
        [_ (error 'subst "unknown type: ~a" ty)])))

  ;; subst-all : (listof (list symbol Type)) Type -> Type
  ;; Applies each (var replacement) pair of `s` to `t` with `subst`, starting
  ;; from the LAST pair (foldr).  This is sequential, not simultaneous,
  ;; substitution: a replacement introduced by a later pair is itself
  ;; subject to the substitutions of earlier pairs.
  (define (subst-all s t)
    (foldr (lambda (e acc) (subst (car e) (cadr e) acc)) t s))
  )
;; NOTE(review): debugging hook from (lib "trace.ss") that wraps `subst` so
;; each call is printed.  The `subst` defined inside `subst@` is local to
;; that unit, so this presumably resolves to some other module-level binding
;; of `subst` -- confirm it does, and consider removing this call outside of
;; debugging sessions.
(trace subst)
)