(module typechecker mzscheme
(require (rename (lib "1.ss" "srfi") assoc* assoc)
(lib "kerncase.ss" "syntax")
(lib "struct.ss" "syntax")
(lib "etc.ss")
"subst.ss" "utils.ss" "types.ss" "infer.ss" "types-aux.ss" "subtype.ss" "prims.ss" "remove-intersect.ss" "planet-requires.ss" "type-env.ss" "parse-type.ss" "tc-utils.ss" "type-environments.ss" "base-env.ss" "lexical-env.ss" "type-annotation.ss" "type-name-env.ss" (lib "match.ss"))
(require-for-template "prims.ss" (lib "match.ss"))
(require-libs)
(require (planet "environment.ss" ("cobbe" "environment.plt" 3 0)))
(provide (all-defined))
(require-for-template mzscheme)
;; check-type : syntax type type -> void
;; Signals a tc-error (located at STX) unless E-TYPE is a subtype of TY.
(define (check-type stx e-type ty)
  (parameterize ([current-orig-stx stx])
    (if (subtype e-type ty)
        (void)
        (tc-error "body had type ~a, variable had type ~a" e-type ty))))
;; check-expr : syntax syntax type -> void
;; Computes the type of EXPR and requires it to be a subtype of TY,
;; reporting any mismatch at STX.
(define (check-expr stx expr ty)
  (let ([actual (tc-expr expr)])
    (check-type stx actual ty)))
;; tc-literal : syntax -> type
;; Computes the type of a literal datum (the payload of #%datum or quote).
;; Lists and vectors are typed element-wise as the union of their element
;; types; an unrecognized datum prints a diagnostic and is typed Univ.
(define (tc-literal v-stx)
  ;; types-of-literals : (listof syntax) -> type
  ;; Union of the types of the given literal elements.
  (define (types-of-literals es)
    (apply Un (map tc-literal es)))
  (define v (syntax-e v-stx))
  (cond
    [(number? v) N]
    [(boolean? v) (if v B (make-value #f))]
    ;; The empty list gets the polymorphic type (All (a) (Listof a)).
    ;; (A discarded (make-value null) expression used to precede this.)
    [(null? v) (make-poly '(a) (make-Listof (make-tvar 'a)))]
    [(symbol? v) Sym]
    [(string? v) String]
    [(keyword? v) Keyword]
    [(bytes? v) Bytes]
    ;; Elements are still syntax objects here, hence the recursive tc-literal.
    [(list? v) (make-Listof (types-of-literals v))]
    [(vector? v) (make-vec (types-of-literals (vector->list v)))]
    [else (begin (printf "checking literal : ~a~n" v) Univ)]))
;; tc-id : identifier -> type
;; If ID carries a 'defined-type-label syntax property, the type is parsed
;; from that label; otherwise it is looked up in the lexical environment.
(define (tc-id id)
  (let ([label (syntax-property id 'defined-type-label)])
    (if label
        (parse-type label)
        (lookup-type/lexical id))))
;; tc-args : (listof type) (listof type) (or type #f) (listof syntax) -> #t
;; Checks actual argument types ARG-TYPES against the fixed domain DOM-TYPES.
;; Surplus arguments are accepted only when REST-TYPE is non-#f, and each
;; must then be a subtype of REST-TYPE.  ARG-STXS parallels ARG-TYPES and is
;; used to locate per-argument errors.  Returns #t on success.
(define (tc-args arg-types dom-types rest-type arg-stxs)
  (define (too-few)
    (tc-error "Insufficient arguments to function application, expected ~a, got ~a"
              (length dom-types) (length arg-types)))
  (define (too-many)
    (tc-error "Too many arguments to function, expected ~a, got ~a"
              (length dom-types) (length arg-types)))
  (let check ([remaining arg-types] [expected dom-types] [sites arg-stxs])
    (cond
      ;; Ran out of actual arguments: fine only if no domains remain.
      [(null? remaining)
       (if (null? expected) #t (too-few))]
      ;; Fixed domains exhausted: the rest of the arguments go to the rest type.
      [(null? expected)
       (cond
         [(not rest-type) (too-many)]
         [(subtype (car remaining) rest-type)
          (check (cdr remaining) expected (cdr sites))]
         [else
          (tc-error/stx (car sites) "Rest argument had wrong type, expected: ~a and got: ~a"
                        rest-type (car remaining))])]
      [(subtype (car remaining) (car expected))
       (check (cdr remaining) (cdr expected) (cdr sites))]
      [else
       (tc-error/stx (car sites) "Wrong function argument type, expected ~a, got ~a"
                     (car expected) (car remaining))])))
;; tc/lambda-clause : syntax syntax -> arr
;; Typechecks one lambda / case-lambda clause.  ARGS is the formals syntax,
;; BODY the syntax list of body expressions.  Each formal's type comes from
;; get-type.  A rest formal declared with type T is bound as (Listof T) in
;; the body, and T is recorded as the arr's rest type.
(define (tc/lambda-clause args body)
  (syntax-case args ()
    [(formal ...)
     (let* ([ids (syntax->list #'(formal ...))]
            [doms (map get-type ids)])
       (with-lexical-env/extend
        ids doms
        (make-arr doms (tc-exprs (syntax->list body)))))]
    [(formal ... . rest-id)
     (let* ([ids (syntax->list #'(formal ...))]
            [doms (map get-type ids)]
            [rest-dom (get-type #'rest-id)])
       (with-lexical-env/extend
        (cons #'rest-id ids) (cons (make-Listof rest-dom) doms)
        (make-arr doms (tc-exprs (syntax->list body)) rest-dom)))]))
;; tc/mono-lambda : syntax syntax -> funty
;; Typechecks each (formals, body) clause pair and combines the resulting
;; arrows into a single function type.
(define (tc/mono-lambda formals bodies)
  (let ([clause-arrs (map tc/lambda-clause
                          (syntax->list formals)
                          (syntax->list bodies))])
    (make-funty clause-arrs)))
;; tc/lambda : syntax syntax syntax -> type
;; Routes to the polymorphic checker when FORM carries the
;; 'typechecker:plambda property, otherwise to the monomorphic one.
(define (tc/lambda form formals bodies)
  (cond
    [(syntax-property form 'typechecker:plambda)
     (tc/plambda form formals bodies)]
    [else
     (tc/mono-lambda formals bodies)]))
;; tc/plambda : syntax syntax syntax -> poly
;; Typechecks a lambda whose 'typechecker:plambda property lists type
;; variables.  Binds those variables as fresh tvars while checking the
;; clauses, then quantifies the resulting function type over them.
(define (tc/plambda form formals bodies)
  (with-syntax ([tvars (syntax-property form 'typechecker:plambda)])
    (let* ([tvar-names (map syntax-e (syntax->list #'tvars))]
           [scope (extend-env tvar-names (map make-tvar tvar-names) (current-tvars))]
           [body-ty (parameterize ([current-tvars scope])
                      (tc/mono-lambda formals bodies))])
      (make-poly tvar-names body-ty))))
;; tc/if-onearm : syntax syntax -> type
;; A one-armed `if` is checked as a two-armed one whose else branch is
;; (#%app void), located at BODY.
(define (tc/if-onearm tst body)
  (let ([else-branch (syntax/loc body (#%app void))])
    (tc/if-twoarm tst body else-branch)))
;; tc/if-twoarm : syntax syntax syntax -> type
;; Typechecks (if tst thn els) and returns the union of the branch types.
;; Two test shapes refine a variable's type per branch:
;;  - (p v), where v is an identifier and p's type is a pred-ty for t:
;;    v is restricted by t in THN and has t removed in ELS;
;;  - a bare identifier i: #f is removed from i's type in THN, and i's type
;;    is replaced by the type of #f in ELS.
;; Any other test checks both branches unrefined.
(define (tc/if-twoarm tst thn els)
  ;; ((tc/update f) var t expr) typechecks EXPR with VAR's lexical type
  ;; updated to (f old-type t).  If the updated type is a subtype of the
  ;; empty union (Un), the branch is treated as unreachable and (Un) is
  ;; returned by escaping through `return`.
  (define ((tc/update f) var t expr)
    (let/ec return
      (with-update-type/lexical (lambda (_ old)
                                  (let ([new-t (f old t)])
                                    (if (subtype new-t (Un))
                                        (return (Un))
                                        new-t)))
                                var
                                (tc-expr expr))))
  (define tc/remove (tc/update remove))
  (define tc/intersect (tc/update restrict))
  (define tc/replace (tc/update (lambda (old t) t)))
  ;; Typecheck the test itself; its type is not otherwise used here.
  (tc-expr tst)
  (kernel-syntax-case tst #f
    [(#%app p v)
     (identifier? #'v)
     (match (tc-expr #'p)
       [($ pred-ty t)
        (Un (tc/intersect #'v t thn) (tc/remove #'v t els))]
       [_ (Un (tc-expr thn) (tc-expr els))])]
    [i (identifier? #'i)
     (Un (tc/remove #'i (make-value #f) thn)
         (tc/replace #'i (make-value #f) els))]
    [_ (Un (tc-expr thn) (tc-expr els))]))
;; tc-expr : syntax -> type
;; Typechecks one fully-expanded expression FORM and returns its type.
;; If FORM carries a type annotation (type-annotation returns non-#f), the
;; computed type must be a subtype of the annotation, and the annotation is
;; returned; otherwise a tc-error is signalled.
(define (tc-expr form)
  ;; internal-tc-expr : syntax -> type
  ;; Dispatches on the kernel form of FORM, ignoring any annotation on it.
  (define (internal-tc-expr form)
    (kernel-syntax-case* form #f (values apply letrec-syntaxes+values)
      [(#%datum . val) (tc-literal #'val)]
      [(quote val) (tc-literal #'val)]
      [(lambda formals . body)
       (tc/lambda form #'(formals) #'(body))]
      [(case-lambda [formals . body] ...)
       (tc/lambda form #'(formals ...) #'(body ...))]
      ;; let-values: type each right-hand side and derive the bound variables'
      ;; types with get-type/infer.  Then, in the extended environment, check
      ;; each clause against those types and check the body.
      [(let-values ([(name ...) expr] ...) . body)
       (let* ([names (syntax->list #'((name ...) ...))]
              [exprs (syntax->list #'(expr ...))]
              [inferred-types (map tc-expr exprs)]
              [types (map get-type/infer names inferred-types)]
              [clauses (syntax-case form () [(lv cl . b) (syntax->list #'cl)])])
         (with-lexical-env/extend
          (map syntax->list names)
          types
          (for-each (lambda (stx e-type t) (check-type stx e-type t))
                    clauses
                    inferred-types
                    (map list->values-ty types))
          (tc-exprs (syntax->list #'body))))]
      ;; letrec-values: bound variables take their types from get-type; the
      ;; right-hand sides are checked inside the extended environment.
      [(letrec-values ([(name ...) expr] ...) . body)
       (let* ([names (map syntax->list (syntax->list #'((name ...) ...)))]
              [exprs (syntax->list #'(expr ...))]
              [types (map (lambda (l) (map get-type l)) names)]
              [clauses (syntax-case form () [(lv cl . b) (syntax->list #'cl)])])
         (with-lexical-env/extend
          names
          types
          (for-each (lambda (stx e t) (check-expr stx e t))
                    clauses
                    exprs
                    (map list->values-ty types))
          (tc-exprs (syntax->list #'body))))]
      [(set! id val)
       (let* ([id-t (tc-id #'id)]
              [val-t (tc-expr #'val)])
         (unless (subtype val-t id-t)
           (tc-error "Mutation only allowed with compatible types: ~a is not a subtype of ~a" val-t id-t))
         Void)]
      ;; No fender here: a #%top reference is always an error.
      [(#%top . id)
       (tc-error "#%top should never happen")]
      [x (identifier? #'x) (tc-id #'x)]
      [(#%variable-reference . _)
       (tc-error "do not use #%variable-reference")]
      [(#%app values . args)
       (let ([tys (map tc-expr (syntax->list #'args))])
         (list->values-ty tys))]
      ;; (apply f a ... lst): the leading argument types must match one of
      ;; f's domains and the last argument must be a list of its rest type.
      [(#%app apply f . args)
       (let* ([f-ty (tc-expr #'f)]
              [arg-tys0 (map tc-expr (syntax->list #'args))])
         ;; split : non-empty (listof type) -> (values (listof type) type)
         ;; Separates the leading argument types from the final one.
         (define (split l)
           (let loop ([l l] [acc '()])
             (if (null? (cdr l))
                 (values (reverse acc) (car l))
                 (loop (cdr l) (cons (car l) acc)))))
         (if (null? arg-tys0)
             ;; (apply f) has no list argument; split would fail on '().
             (tc-error "apply requires a final list argument")
             (let-values ([(arg-tys tail-ty) (split arg-tys0)])
               (match f-ty
                 [($ funty (($ arr doms rngs rests) ..1))
                  ;; Loop over doms* so the error still reports every domain.
                  (let loop ([doms* doms] [rngs rngs] [rests rests])
                    (cond [(null? doms*)
                           (tc-error "no function domain matched - domains were: ~a arguments were ~a" doms arg-tys0)]
                          [(and (subtypes arg-tys (car doms*))
                                (subtype tail-ty (make-Listof (car rests))))
                           (car rngs)]
                          [else (loop (cdr doms*) (cdr rngs) (cdr rests))]))]
                 [($ poly _ ($ funty _))
                  (tc-error "polymorphic functions not supported with apply")]
                 [else (tc-error "~a is not a function type" f-ty)]))))]
      [(#%app f args ...)
       (let* ([ftype (tc-expr #'f)]
              [argtypes (map tc-expr (syntax->list #'(args ...)))])
         (match ftype
           ;; A predicate applied to exactly one argument yields B.
           [($ pred-ty _)
            (if (= 1 (length argtypes))
                B
                (tc-error "wrong number of arguments to predicate ~a" (syntax-e #'f)))]
           [($ funty (($ arr doms rngs rests) ..1))
            (if (= 1 (length doms))
                ;; Single arrow: tc-args gives a precise error message.
                (begin
                  (tc-args argtypes (car doms) (car rests) (syntax->list #'(args ...)))
                  (car rngs))
                ;; case-lambda type: take the range of the first arrow whose
                ;; domain (plus rest type) accepts the arguments.
                (let loop ([doms* doms] [rngs rngs] [rests rests])
                  (cond [(null? doms*)
                         (tc-error "no function domain matched - domains were: ~a arguments were ~a" doms argtypes)]
                        [(subtypes/varargs argtypes (car doms*) (car rests)) (car rngs)]
                        [else (loop (cdr doms*) (cdr rngs) (cdr rests))])))]
           ;; Polymorphic function without rest arguments: rename its type
           ;; variables to fresh ones, then unify each domain with the actual
           ;; argument types.
           [($ poly vars ($ funty (($ arr doms rngs #f) ...)))
            (for-each (lambda (x)
                        (when (poly? x)
                          (tc-error "Polymorphic argument ~a to polymorphic function not allowed" x)))
                      argtypes)
            (let* ([fresh-names (map gensym vars)]
                   [fresh-vars (map make-tvar fresh-names)]
                   [renaming (map list vars fresh-vars)]
                   [fresh-doms (map (lambda (argl) (map (lambda (t) (subst-all renaming t)) argl)) doms)]
                   [fresh-rngs (map (lambda (l) (subst-all renaming l)) rngs)])
              (let loop ([doms* fresh-doms] [rngs* fresh-rngs])
                (cond [(null? doms*)
                       (tc-error "no polymorphic function domain matched - domains were: ~a arguments were ~a" doms argtypes)]
                      [(and (= (length (car doms*))
                               (length argtypes))
                            (unify (map list (car doms*) argtypes)))
                       ;; Apply the substitution to the range, then map any
                       ;; remaining fresh variables back to the original ones.
                       => (lambda (substitution)
                            (subst-all (map list fresh-names (map make-tvar vars))
                                       (subst-all substitution (car rngs*))))]
                      [else (loop (cdr doms*) (cdr rngs*))])))]
           [else (tc-error "~a is not a function type" ftype)]))]
      [(if tst body) (tc/if-onearm #'tst #'body)]
      [(if tst thn els) (tc/if-twoarm #'tst #'thn #'els)]
      [(letrec-syntaxes+values stxs vals . body)
       (tc-expr (syntax/loc form (letrec-values vals . body)))]
      [(begin e) (tc-expr #'e)]
      [(begin e es ...)
       (begin (tc-expr #'e)
              (tc-expr #'(begin es ...)))]
      [(begin) Void]
      [(begin0 e) (tc-expr #'e)]
      ;; begin0 has the type of its first subexpression; the rest are
      ;; checked only for errors.
      [(begin0 e es ...)
       (begin0 (tc-expr #'e)
               (tc-expr #'(begin es ...)))]
      [(begin0) Void]
      [_ (tc-error "cannot typecheck unknown form : ~a~n" (syntax-object->datum form))]))
  (parameterize ([current-orig-stx form])
    (unless (syntax? form)
      (tc-error "bad form input to tc-expr"))
    ;; Read the annotation only after FORM is known to be syntax.
    (let* ([ty-ann (type-annotation form)]
           [result (internal-tc-expr form)])
      (cond [(and ty-ann (subtype result ty-ann)) ty-ann]
            [ty-ann (tc-error "expression had type ~a, but was annotated with type ~a" result ty-ann)]
            [else result]))))
;; tc-exprs : (listof syntax) -> type
;; Typechecks a body sequence left to right and returns the type of the
;; last expression, or Void for an empty sequence.
(define (tc-exprs exprs)
  (if (null? exprs)
      Void
      (let loop ([current (car exprs)] [more (cdr exprs)])
        (if (null? more)
            (tc-expr current)
            (begin
              (tc-expr current)
              (loop (car more) (cdr more)))))))
;; tc/poly-struct : (listof identifier) syntax (listof identifier) (listof syntax) -> void
;; Registers the types for a polymorphic typed struct with type variables
;; VARS.  NM is either the struct name or (name parent); FLDS are the field
;; identifiers and TYS the field type syntaxes.  Registers the struct type
;; name (quantified over VARS), the constructor, the accessors for this
;; struct's own fields, and the predicate.
(define (tc/poly-struct vars nm flds tys)
  (with-syntax*
   ([(nm parent) (syntax-case nm ()
                   [nm (identifier? #'nm) #'(nm #f)]
                   [(nm par) #'(nm par)])]
    [(_ maker pred . getters) (build-struct-names #'nm flds #f #t)])
   (let* ([name (syntax-e #'nm)]
          [tvars (map syntax-e vars)]
          [new-tvars (map make-tvar tvars)]
          [parent (if (syntax-e #'parent) (parse-type #'parent) #f)])
     ;; The struct's type variables are in scope while parsing field types.
     (parameterize ([current-tvars (extend-env tvars new-tvars (current-tvars))])
       (let* ([name-tvar (make-tvar name)]
              ;; First parse, with the struct name bound to a placeholder, only
              ;; to detect whether the fields refer to the struct itself.
              [types (parameterize
                         ([current-tvars
                           (extend-env (list name) (list (make-poly tvars name-tvar)) (current-tvars))])
                       (map parse-type tys))]
              ;; NOTE(review): tc/struct tests membership of its name-tvar here,
              ;; whereas this tests the symbol NAME; confirm which one fv returns.
              [rec? (ormap (lambda (s) (set:member? name s)) (map fv types))]
              ;; For a recursive struct, register the name before re-parsing
              ;; (presumably so self-references resolve through this entry).
              [_ (if rec? (register-type-name #'nm (make-poly tvars name-tvar)))]
              [types (map parse-type tys)]
              ;; Inherited fields: a polymorphic parent has its type variables
              ;; replaced positionally by this struct's (map requires equal
              ;; counts).
              [parent-field-types (if parent
                                      (struct-ty-flds (if (poly? parent)
                                                          (subst-all (map list (poly-var parent) new-tvars)
                                                                     (poly-type parent))
                                                          parent))
                                      '())]
              [fld-types (append parent-field-types types)]
              [sty (make-struct-ty name parent fld-types tvars)]
              [sty (if rec?
                       (make-mu name sty)
                       sty)]
              [psty (make-poly* tvars sty)]
              [external-fld-types/no-parent (map (lambda (t) (subst name sty t)) types)]
              [external-fld-types (map (lambda (t) (subst name sty t)) fld-types)])
         (register-type-name #'nm psty)
         (register-type #'maker (make-poly tvars (->* external-fld-types sty)))
         ;; Accessors are generated only for this struct's own fields.
         (register-types (syntax->list #'getters)
                         (map (lambda (t) (make-poly tvars (->* (list sty) t)))
                              external-fld-types/no-parent))
         (register-type #'pred (make-pred-ty psty)))))))
;; tc/struct : syntax (listof identifier) (listof syntax) -> void
;; Registers the types for a monomorphic typed struct.  NM is either the
;; struct name or (name parent); FLDS are the field identifiers and TYS the
;; field type syntaxes.  Registers the struct type name, the constructor,
;; the accessors for this struct's own fields, and the predicate.
(define (tc/struct nm flds tys)
  (with-syntax*
   ([(nm parent) (syntax-case nm ()
                   [nm (identifier? #'nm) #'(nm #f)]
                   [(nm par) #'(nm par)])]
    [(_ maker pred . getters) (build-struct-names #'nm flds #f #t)])
   (let* ([name (syntax-e #'nm)]
          [name-tvar (make-tvar name)]
          [parent (if (syntax-e #'parent) (parse-type #'parent) #f)]
          ;; Field types are parsed with the struct's own name bound to a type
          ;; variable, so fields may refer to the struct being defined.
          [types (parameterize
                     ([current-tvars
                       (extend-env (list name) (list name-tvar) (current-tvars))])
                   (map parse-type tys))]
          ;; NOTE(review): tc/poly-struct tests membership of the symbol name
          ;; here rather than NAME-TVAR; confirm which one fv returns.
          [rec? (ormap (lambda (s) (set:member? name-tvar s)) (map fv types))]
          [parent-field-types (if parent (struct-ty-flds parent) '())]
          [fld-types (append parent-field-types types)]
          [sty (make-struct-ty name parent fld-types null)]
          [sty (if rec?
                   (make-mu name sty)
                   sty)]
          [external-fld-types/no-parent (map (lambda (t) (subst name sty t)) types)]
          [external-fld-types (map (lambda (t) (subst name sty t)) fld-types)])
     (register-type-name #'nm sty)
     (register-type #'maker (->* external-fld-types sty))
     ;; Accessors are generated only for this struct's own fields.
     (register-types (syntax->list #'getters)
                     (map (lambda (t) (->* (list sty) t)) external-fld-types/no-parent))
     (register-type #'pred (make-pred-ty sty)))))
;; list->values-ty : (listof type) -> type
;; A one-element list stands for its single type; any other length
;; (including zero) is packaged with make-values-ty.
(define (list->values-ty l)
  (cond
    [(and (pair? l) (null? (cdr l))) (car l)]
    [else (make-values-ty l)]))
;; tc-toplevel/pass1 : syntax -> any
;; First pass over a top-level form: records declarations so that pass 2
;; can check expressions against them.
;;  - require/typed markers register the imported name's type;
;;  - type-alias markers register type names (optionally parameterized);
;;  - typed-struct markers register struct, constructor, accessor and
;;    predicate types;
;;  - define-values with type-labelled variables registers those types, and
;;    any other define-values is an "Untyped definition" error.
;; Every other form is left for pass 2.
(define (tc-toplevel/pass1 form)
  (parameterize ([current-orig-stx form])
    (kernel-syntax-case* form #f (define-type-alias-internal define-typed-struct-internal require/typed-internal : tc/check)
      [stx
       (syntax-property form 'typechecker:ignore)
       (void)]
      [(begin0 (require/typed-internal nm ty))
       (register-type #'nm (parse-type #'ty))]
      [(begin0 (define-type-alias-internal (nm . vars) ty))
       (register-type-name #'nm (parse-type #'(All vars ty)))]
      [(begin0 (define-type-alias-internal nm ty))
       (register-type-name #'nm (parse-type #'ty))]
      [(begin0 (define-typed-struct-internal nm ([fld : ty] ...)))
       (tc/struct #'nm (syntax->list #'(fld ...)) (syntax->list #'(ty ...)))]
      [(begin0 (define-typed-struct-internal (vars ...) nm ([fld : ty] ...)))
       (tc/poly-struct (syntax->list #'(vars ...)) #'nm (syntax->list #'(fld ...)) (syntax->list #'(ty ...)))]
      [(define-values (var ...) expr)
       (andmap (lambda (s) (syntax-property s 'type-label)) (syntax->list #'(var ...)))
       (let* ([vars (syntax->list #'(var ...))]
              [ts (map get-type vars)])
         (for-each register-type vars ts))]
      [(define-values . _)
       (tc-error "Untyped definition")]
      ;; Everything else (requires, provides, syntax definitions, other
      ;; begin0 markers, expressions) needs nothing in this pass.  Clauses
      ;; that used to follow this catch-all were unreachable and have been
      ;; dropped.
      [_ (void)])))
;; tc-toplevel/pass2 : syntax -> any
;; Second pass over a top-level form:
;;  - checks each define-values right-hand side against the declared
;;    type(s) of its variables;
;;  - typechecks any other expression;
;;  - skips module-level declarations, the declaration markers handled in
;;    pass 1, and forms marked 'typechecker:ignore.
(define (tc-toplevel/pass2 form)
  (parameterize ([current-orig-stx form])
    (kernel-syntax-case* form #f
        (define-type-alias-internal define-typed-struct-internal require/typed-internal : tc/check)
      [ignored
       (syntax-property form 'typechecker:ignore)
       (void)]
      ;; Module-level declarations have nothing to check.
      [(require . _) (void)]
      [(require-for-syntax . _) (void)]
      [(require-for-template . _) (void)]
      [(provide . _) (void)]
      [(define-syntaxes . _) (void)]
      ;; Declaration markers (their types were registered in pass 1).
      [(begin0 (require/typed-internal . _)) (void)]
      [(begin0 (define-type-alias-internal . _)) (void)]
      [(begin0 (define-typed-struct-internal . _)) (void)]
      ;; NOTE(review): define-type-internal is not in the literal list above,
      ;; so here it is a pattern variable and this clause skips ANY form
      ;; shaped (begin0 (x . rest)).  Confirm that catch-all is intended.
      [(begin0 (define-type-internal . _)) (void)]
      [(define-values (var ...) expr)
       (let ([declared (map get-type (syntax->list #'(var ...)))])
         (check-expr form #'expr (list->values-ty declared)))]
      [_ (tc-expr form)])))
;; type-check : syntax -> syntax
;; Typechecks the syntax list FORMS in two passes. Pass 1 registers every
;; declaration before pass 2 checks any expression, so forms may refer to
;; declarations that appear later.  Returns a (begin ...) containing the
;; code produced by env-init-code and tname-env-init-code.
(define (type-check forms)
  (define form-list (syntax->list forms))
  (for-each tc-toplevel/pass1 form-list)
  (for-each tc-toplevel/pass2 form-list)
  #`(begin
      #,(env-init-code)
      #,(tname-env-init-code)))
)