(module union mzscheme
(require "type-rep.ss" "subtype.ss" "tc-utils.ss"
"type-effect-printer.ss" "rep-utils.ss"
"type-comparison.ss")
(require (lib "plt-match.ss") (lib "list.ss") (lib "trace.ss"))
;; Export the union constructor under the name `Un`.
;; Only one binding may be provided per external name: exporting both `Un`
;; and `*Un` renamed to `Un` is rejected by the module system ("identifier
;; already provided (as a different binding)"), so only the plain
;; constructor is exported here.
(provide Un)
;; make-union* : (listof Type) -> Type
;; Wraps a list of element types as a Union, except that a list holding
;; exactly one element is collapsed to that element itself.
(define (make-union* set)
  (if (and (pair? set) (null? (cdr set)))
      ;; exactly one member: no Union wrapper needed
      (car set)
      ;; zero members, several members, or anything else: build a Union
      (make-Union set)))
;; Un : Type ... -> Type
;; Builds the union of the argument types.
;;  - Arguments that are already Unions are flattened into their elements,
;;    then the combined list is sorted with type<? and de-duplicated.
;;  - No types at all yields the empty union.
;;  - If every type is a Values type, the result is a Values whose
;;    components are the position-wise unions of the inputs' components.
;;  - Mixing Values types with other types is reported through int-err.
;;  - Otherwise the types are folded together with union2, which drops
;;    members subsumed (per `subtype`) by the rest of the union.
(define (Un . args)
  ;; flat : Type -> (listof Type)
  ;; A Union contributes its elements; any other type is a singleton list.
  (define (flat t)
    (match t
      [(Union: es) es]
      [_ (list t)]))
  ;; Values-types : Values -> (listof Type)
  ;; Extracts the component types of a Values type.
  (define (Values-types t) (match t [(Values: ts) ts]))
  ;; union2 : Type (listof Type) -> (listof Type)
  ;; Adds `a` to the union represented by the element list `b`.
  ;; NOTE(review): when `a` is subsumed, the result is a one-element list
  ;; holding the already-built union b* rather than b's elements, so a later
  ;; `cons` can place a Union inside the element list -- confirm intended.
  (define (union2 a b)
    (define b* (make-union* b))
    (cond
      [(subtype a b*) (list b*)]   ; a adds nothing new
      [(subtype b* a) (list a)]    ; a covers everything collected so far
      [else (cons a b)]))
  (let ([types (remove-dups (sort (apply append (map flat args)) type<?))])
    (cond
      [(null? types) (make-union* null)]
      [(andmap Values? types)
       ;; `map` is applied across the component lists in parallel, so each
       ;; result component is the Un of the corresponding input components.
       (make-Values (apply map Un (map Values-types types)))]
      [(ormap Values? types)
       (int-err "Un: should not take the union of multiple values with some other type: ~a" types)]
      ;; make-union* accepts exactly one list of element types; passing it
      ;; two arguments raised an arity error on every call reaching here.
      [else (make-union* (foldr union2 null types))])))
;; Un-intern takes the list of argument types as a single parameter `args`.
;; The supplied constructor ignores its first argument and applies Un to
;; `args`; `args` is also passed as the final (key) expression.
;; NOTE(review): defintern is not defined in this file -- presumably it
;; caches constructor results by that key; confirm against its definition.
(defintern (Un-intern args) (lambda (_ args) (apply Un args)) args)
;; *Un : Type ... -> Type
;; Variadic front end for Un-intern: gathers all arguments into one list
;; and hands that list to Un-intern.
(define *Un
  (lambda types
    (Un-intern types)))
;; u-maker : (listof Type) -> Type
;; List-taking adapter around the variadic Un.
(define u-maker
  (lambda (args)
    (apply Un args)))
;; Register the adapter via set-union-maker! (defined outside this file).
(set-union-maker! u-maker)
)