(module remove-intersect mzscheme
(require "types.ss" "types-aux.ss" "infer.ss" "subtype.ss" "planet-requires.ss" "subst.ss" "infer2.ss")
(require-libs)
(define (overlap t1 t2)
  ;; Conservatively decide whether types T1 and T2 may share a value.
  ;; Answers #f only in the specific disjoint cases recognised below;
  ;; every other combination answers #t.
  (match (list t1 t2)
    ;; Univ overlaps anything.
    [(($ univ) _) #t]
    [(_ ($ univ)) #t]
    ;; Recursive types: unfold and try again.
    [((? mu?) _) (overlap (unfold t1) t2)]
    [(_ (? mu?)) (overlap t1 (unfold t2))]
    ;; A union overlaps the other type when some member does.
    [(($ union members) other)
     (ormap (lambda (m) (overlap m other)) (set:elements members))]
    [(other ($ union members))
     (ormap (lambda (m) (overlap other m)) (set:elements members))]
    ;; Two polymorphic types are assumed to overlap.
    [(($ poly _ _) ($ poly _ _)) #t]
    ;; Base types overlap only when they carry the same (eq?) name.
    [(($ base-type n1) ($ base-type n2)) (eq? n1 n2)]
    ;; A literal value type overlaps a base type when it is a subtype of it.
    [(($ base-type _) ($ value _)) (subtype t2 t1)]
    [(($ value _) ($ base-type _)) (subtype t1 t2)]
    ;; Any other pairing involving a base type is treated as disjoint.
    [(($ base-type _) _) #f]
    [(_ ($ base-type _)) #f]
    ;; Unknown combination: assume they may overlap.
    [_ #t]))
(define (restrict t1 t2)
  ;; Compute a type for values of type T1 that are also known to have
  ;; type T2.  The cond clauses below are tried in order; the first
  ;; one with a true test supplies the result.

  ;; When POLY is a poly type, infer its type variables from TY,
  ;; substitute the solution into its body and restrict TY by that
  ;; instantiated body.  Yields #f if POLY is not a poly type or
  ;; inference fails.
  (define (restrict/infer ty poly)
    (and (poly? poly)
         (let* ([body (poly-type poly)]
                [solution (infer body ty (poly-var poly))])
           (and solution
                (restrict ty (subst-all solution body))))))

  ;; When POLY is a poly type, rename its bound variables to freshly
  ;; gensym'd type variables and try to unify TY with the renamed body.
  ;; Yields TY on success, #f otherwise.
  (define (unify/poly ty poly)
    (and (poly? poly)
         (let* ([old-names (poly-var poly)]
                [fresh-names (map gensym old-names)]
                [renaming (map (lambda (old new) (list old (make-tvar new)))
                               old-names fresh-names)]
                [fresh-body (subst-all renaming (poly-type poly))])
           (and (unify (list (list ty fresh-body))) ty))))

  (cond [(subtype t1 t2) t1]                ; already at least as precise
        [(restrict/infer t1 t2)]
        [(unify/poly t1 t2)]
        [(union? t1)                        ; restrict each member separately
         (apply Un (map (lambda (e) (restrict e t2))
                        (set:elements (union-elems t1))))]
        [(mu? t1) (restrict (unfold t1) t2)]
        [(mu? t2) (restrict t1 (unfold t2))]
        [(subtype t2 t1) t2]
        [(overlap t1 t2) t2]                ; complex relationship: punt to T2
        [else (Un)]))                       ; disjoint: the empty type
(define (intersect t1 t2)
  ;; Approximate the intersection of types T1 and T2.
  ;; If either type is a subtype of the other, the smaller one is the
  ;; result.  Otherwise a few structural cases are handled; any other
  ;; combination falls back to returning T1 unchanged (an over-approximation).
  ;; NOTE(review): this function is not listed in the module's `provide`.
  (cond [(subtype t1 t2) t1]
        [(subtype t2 t1) t2]
        [else
         (match (list t1 t2)
           ;; Unfold a recursive type before inspecting it, as the other
           ;; functions in this module do.  Recursing on the raw mu body
           ;; would presumably leave its bound variable unresolved.
           [((? mu?) _) (intersect (unfold t1) t2)]
           ;; Union vs. union: keep the members of T1 that also appear in T2.
           [(($ union l1) ($ union l2))
            (make-union* (set:filter (lambda (e) (set:member? e l2)) l1))]
           ;; Union vs. other: keep the members of T1 that are subtypes of T2.
           [(($ union l) t)
            (make-union* (set:filter (lambda (e) (subtype e t)) l))]
           [_ t1])]))
(define (remove old rem)
  ;; Compute the type of values of type OLD once type REM has been
  ;; excluded.  If the computed type is no more precise than OLD
  ;; (OLD is a subtype of it), OLD itself is returned.

  ;; Remove each type in the list ELEMS from TY, left to right.  This is
  ;; an explicit loop, so REMOVE always receives (accumulator, element)
  ;; in that order.  It does not rely on a fold's argument convention.
  (define (remove-each ty elems)
    (if (null? elems)
        ty
        (remove-each (remove ty (car elems)) (cdr elems))))

  (define initial
    (if (subtype old rem)
        (Un)                                ; everything removed: empty type
        (match (list old rem)
          ;; Removing from a union: remove REM from every member.
          ;; (This clause matches any union OLD, whatever REM is.)
          [(($ union l) _)
           (apply Un (map (lambda (e) (remove e rem)) (set:elements l)))]
          ;; Removing a union: remove each of its members from OLD in turn.
          [(t ($ union l2))
           (remove-each t (set:elements l2))]
          [((? mu?) t) (remove (unfold old) t)]
          [(($ poly v b) _) (make-poly v (remove b rem))]
          [_ old])))
  (if (subtype old initial) old initial))
(provide restrict remove overlap)
)