(module unit-utils mzscheme
(require (lib "unit.ss"))
(require-for-syntax (lib "unit-exptime.ss")
(lib "list.ss" "srfi" "1")
(lib "match.ss"))
(provide define-values/link-units/infer)
;; define-values/link-units/infer : links the given unit identifiers into a
;; compound unit with define-compound-unit/infer, then immediately invokes
;; it with define-values/invoke-unit/infer so the exported variables are
;; defined in the surrounding scope.
;;
;; Forms:
;;   (define-values/link-units/infer (export sig ...) unit-id ...)
;;     -- the compound unit exports exactly the listed signatures.
;;   (define-values/link-units/infer unit-id ...)
;;     -- the compound unit exports every signature exported by any unit.
;; In both forms the compound unit imports each signature that some linked
;; unit imports and no linked unit exports (listed once, even when several
;; units import it).
(define-syntax (define-values/link-units/infer stx)
  ;; Convert a (tag-or-#f . sig-id) pair, as returned by
  ;; unit-static-signatures, into a tagged-sig-id form: the bare sig-id when
  ;; there is no tag, otherwise (tag <tag> sig-id).
  (define (datum->sig-elem d)
    (if (car d)
        (quasisyntax/loc (cdr d) (tag #,(car d) #,(cdr d)))
        (cdr d)))
  ;; Static imports and exports of one unit identifier, as a 2-element list.
  (define (get-sigs id)
    (define-values (imps exps) (unit-static-signatures id id))
    (list imps exps))
  (define (flatten l) (apply append l))
  ;; All imports and all exports of the given unit identifiers, each
  ;; concatenated into a single list of (tag . sig-id) pairs.
  (define (get-all-sigs ids)
    (define imps/exps (map get-sigs ids))
    (define-values (imps exps) (unzip2 imps/exps))
    (values (flatten imps) (flatten exps)))
  ;; Build the expansion: define a compound unit named new-unit@ (fresh per
  ;; expansion, by hygiene) and invoke it in place.
  (define (mk imports exports units stx)
    (quasisyntax/loc stx
      (begin (define-compound-unit/infer new-unit@
               (import #,@imports)
               (export #,@exports)
               (link #,@units))
             (define-values/invoke-unit/infer new-unit@))))
  ;; Two (tag . sig-id) pairs name the same signature when their tags are
  ;; eq? and their identifiers are bound-identifier=?.
  (define (sig=? sig1 sig2)
    (and (eq? (car sig1) (car sig2))
         (bound-identifier=? (cdr sig1) (cdr sig2))))
  (define (imp-in-exps? imp exps)
    (s:member imp exps sig=?))
  ;; Compute the compound unit's imports (signatures needed by some unit but
  ;; supplied by none, without duplicates) and its full export list, both
  ;; already converted to tagged-sig-id syntax.
  (define (imps/exps-from-units units)
    (let-values ([(imps exps) (get-all-sigs units)])
      (let* ([external (filter (lambda (imp) (not (imp-in-exps? imp exps)))
                               imps)]
             ;; Several units may import the same signature; the compound
             ;; unit must import it only once.
             [unique (delete-duplicates external sig=?)])
        (values (map datum->sig-elem unique)
                (map datum->sig-elem exps)))))
  ;; True when stx is a proper syntax list whose elements are identifiers.
  (define (identifier-list? stx)
    (let ([l (syntax->list stx)])
      (and l (andmap identifier? l))))
  (syntax-case stx (import export)
    [(_ (export . sigs) . units)
     (and (syntax->list #'sigs) (identifier-list? #'units))
     (let*-values ([(units) (syntax->list #'units)]
                   ;; The caller lists the exports explicitly, so the
                   ;; computed export list is not used here.
                   [(imps ignored-exps) (imps/exps-from-units units)])
       (mk imps (syntax->list #'sigs) units stx))]
    [(_ . units)
     (identifier-list? #'units)
     (let*-values ([(units) (syntax->list #'units)]
                   [(imps exps) (imps/exps-from-units units)])
       (mk imps exps units stx))]))
;; Example signatures, each holding a single variable, listed in dependency
;; order: y@ consumes z^ to provide y^, and x@ consumes y^ to provide x^.
(define-signature z^ (z))
(define-signature y^ (y))
(define-signature x^ (x))
;; y@ : imports z^ and exports y^, defining y as twice the imported z.
(define-unit y@
  (import z^)
  (export y^)
  (define y (* z 2)))
;; x@ : imports y^ and exports x^, where x is a zero-argument procedure
;; returning one more than the imported y.
(define-unit x@
  (import y^)
  (export x^)
  (define x
    (lambda ()
      (+ 1 y))))
;; Example use of define-values/link-units/infer.
;; z is defined first: neither x@ nor y@ exports z^, so the linked unit
;; imports z^, and invoking it runs y@'s body, which reads z.
(define z 45)
;; Link x@ and y@ (y@ supplies the y^ that x@ imports) and export only x^,
;; so only x is defined here. With z = 45, y is 90 and (x) returns 91.
(define-values/link-units/infer (export x^) x@ y@)
)