ml-package.ss
#lang scheme/base
(require (for-syntax scheme/base
                     syntax/kerncase
                     syntax/boundmap
                     syntax/define))

(provide define-package
         package-begin
         
         open-package
         open*-package
         
         define*
         define*-values
         define*-syntax
         define*-syntaxes
         
         (for-syntax package?
                     package-exported-identifiers
                     package-original-identifiers))

;; do-define-* : syntax identifier -> syntax
;;
;; Shared expander for the internal `-define*-values` / `-define*-syntaxes`
;; markers. Accepts (_ (id ...) rhs), raises a syntax error naming the first
;; element of the binding list that is not an identifier, and otherwise
;; rewrites the form to (define-values-id (id ...) rhs), keeping the source
;; location of `stx`.
(define-for-syntax (do-define-* stx define-values-id)
  (syntax-case stx ()
    [(_ (id ...) rhs)
     (begin
       ;; Walk the binding names left to right so the first offender is reported.
       (let check ([names (syntax->list #'(id ...))])
         (unless (null? names)
           (let ([candidate (car names)])
             (unless (identifier? candidate)
               (raise-syntax-error #f
                                   "expected an identifier for definition"
                                   stx
                                   candidate))
             (check (cdr names)))))
       (with-syntax ([def-form define-values-id])
         (syntax/loc stx
           (def-form (id ...) rhs))))]))
;; Internal marker for value definitions made with define*. It expands to
;; plain `define-values` after checking the binding names. `do-define-package`
;; lists it in its stop list, so it sees the marker before this expansion.
(define-syntax -define*-values
  (lambda (stx)
    (do-define-* stx #'define-values)))
;; Internal marker for syntax definitions made with define*. It expands to
;; plain `define-syntaxes` after checking the binding names. `do-define-package`
;; lists it in its stop list, so it sees the marker before this expansion.
(define-syntax -define*-syntaxes
  (lambda (stx)
    (do-define-* stx #'define-syntaxes)))
;; (define*-values (id ...) rhs)
;;
;; Public form: rewrites to the internal `-define*-values` marker at the
;; user's source location. The result carries the 'certify-mode syntax
;; property set to 'transparent-binding.
;; NOTE(review): presumably this is so the expander's certificate handling
;; lets the marker's bindings through; confirm.
(define-syntax (define*-values stx)
  (syntax-case stx ()
    [(_ (id ...) rhs)
     (let ([expansion (syntax/loc stx (-define*-values (id ...) rhs))])
       (syntax-property expansion 'certify-mode 'transparent-binding))]))
;; (define*-syntaxes (id ...) rhs)
;;
;; Public form: rewrites to the internal `-define*-syntaxes` marker at the
;; user's source location. The result carries the 'certify-mode syntax
;; property set to 'transparent-binding.
;; NOTE(review): presumably this is so the expander's certificate handling
;; lets the marker's bindings through; confirm.
(define-syntax (define*-syntaxes stx)
  (syntax-case stx ()
    [(_ (id ...) rhs)
     (let ([expansion (syntax/loc stx (-define*-syntaxes (id ...) rhs))])
       (syntax-property expansion 'certify-mode 'transparent-binding))]))

;; (define* id expr) or (define* (id . formals) body ...)
;;
;; `normalize-definition` turns the function shorthand into an explicit
;; lambda. The result is a single-identifier `define*-values`.
(define-syntax (define* stx)
  (call-with-values
   (lambda () (normalize-definition stx #'lambda))
   (lambda (name value)
     (quasisyntax/loc stx
       (define*-values (#,name) #,value)))))
;; (define*-syntax id transformer-expr) or (define*-syntax (id . formals) body ...)
;;
;; `normalize-definition` turns the function shorthand into an explicit
;; lambda. The result is a single-identifier `define*-syntaxes`.
(define-syntax (define*-syntax stx)
  (call-with-values
   (lambda () (normalize-definition stx #'lambda))
   (lambda (name transformer)
     (quasisyntax/loc stx
       (define*-syntaxes (#,name) #,transformer)))))

(begin-for-syntax
  ;; Compile-time descriptor that `define-package` binds (with define-syntax)
  ;; to each package name. `exports` is a thunk that returns a list of
  ;; (exported-symbol . identifier) pairs.
  ;;
  ;; The prop:procedure handler makes the descriptor a procedure. If the
  ;; expander applies it as a transformer, because the package name appears
  ;; where a macro use or expression is expected, it raises a syntax error
  ;; instead of producing an expansion.
  (define-struct package
    (exports)
    #:omit-define-syntaxes
    #:property prop:procedure
    (lambda (self use-stx)
      (raise-syntax-error #f "misuse of a package name" use-stx))))

;; ml-name? : identifier -> boolean
;;
;; Returns true unless the identifier's symbol starts with `:`.
;; `do-define-package` uses it to decide whether a nested package's own name
;; is recorded as a definition of the enclosing package. The empty symbol `||`
;; has no first character. It is treated as an ordinary name rather than
;; failing inside `string-ref`.
(define-for-syntax (ml-name? id)
  (let ([name (symbol->string (syntax-e id))])
    (or (zero? (string-length name))
        (not (char=? (string-ref name 0) #\:)))))

;; do-define-package : syntax syntax -> syntax
;;
;; Shared expander behind `define-package` and `package-begin`.
;;   stx     - the user's original form; syntax errors below are reported
;;             against it.
;;   exp-stx - the normalized form (_ pack-id mode exports form ...). The
;;             callers in this file pass mode as one of:
;;             - #:only
;;             - #:all-defined, with an empty export list
;;             - #:all-defined-except
;;             - #:begin, with pack-id #f and an empty export list
;;
;; How it works:
;;   - Body forms are partially expanded one at a time in a fresh internal
;;     definition context.
;;   - Expansion stops at kernel forms, at the -define*-values /
;;     -define*-syntaxes markers, and at define-package.
;;   - Each definition is recorded by symbol so the export list can be
;;     resolved once the body is done.
;;
;; Result:
;;   - #:begin mode: the expanded body, as (let () ...) in an expression
;;     context and as (begin ...) otherwise.
;;   - other modes: (begin body ... (define-syntax pack-id (make-package ...))).
;;     The package's thunk returns (exported-symbol . renamed-identifier)
;;     pairs.
(define-for-syntax (do-define-package stx exp-stx)
  (syntax-case exp-stx ()
    [(_ pack-id mode exports . form)
     (let ([id #'pack-id]
           [exports #'exports]
           [mode (syntax-e #'mode)])
       ;; In #:begin mode (package-begin) the name slot holds #f, so the
       ;; identifier check only applies to named packages.
       (unless (eq? mode '#:begin)
         (unless (identifier? id)
           (raise-syntax-error #f
                               "expected an identifier"
                               stx
                               id)))
       ;; Validate the export (or exclusion) list. It must be a parenthesized
       ;; sequence of identifiers with no duplicates, as judged by
       ;; check-duplicate-identifier.
       (let ([exports
              (cond
                [(syntax->list exports)
                 => (lambda (l)
                      (for-each (lambda (i)
                                  (unless (identifier? i)
                                    (raise-syntax-error #f
                                                        "expected identifier to export"
                                                        stx
                                                        i)))
                                l)
                      (let ([dup-id (check-duplicate-identifier l)])
                        (when dup-id
                          (raise-syntax-error
                           #f
                           "duplicate export"
                           stx
                           dup-id)))
                      l)]
                [else (raise-syntax-error #f
                                          (format "expected a parenthesized sequence of identifiers ~a"
                                                  (case mode
                                                    [(#:only) "to export"]
                                                    [(#:all-defined-except) "to exclude from export"]
                                                    [else (format "for ~a" mode)]))
                                          stx
                                          exports)])])
         ;; Definition context that receives the package body's bindings.
         (let* ([def-ctx (syntax-local-make-definition-context)]
                ;; Expansion context for local-expand: a fresh gensym key
                ;; consed onto the enclosing context. The enclosing context is
                ;; kept only when it is a list, i.e. when we are already inside
                ;; an internal-definition context; otherwise null is used.
                [ctx (cons (gensym 'intdef)
                           (let ([orig-ctx (syntax-local-context)])
                             (if (pair? orig-ctx)
                                 orig-ctx
                                 null)))]
                ;; Removes the renamings of the given definition contexts from
                ;; an identifier.
                [pre-package-id (lambda (id def-ctxes)
                                  (identifier-remove-from-definition-context 
                                   id 
                                   def-ctxes))]
                ;; Stop list for partial expansion: the kernel forms plus the
                ;; -define*-values / -define*-syntaxes markers. This lets the
                ;; loop recognize define* definitions before they expand into
                ;; plain define-values / define-syntaxes.
                [kernel-forms (list*
                               #'-define*-values
                               #'-define*-syntaxes
                               (kernel-form-identifier-list))]
                [init-exprs (syntax->list #'form)]
                ;; Maps each defined symbol to the most recently registered
                ;; identifier with that symbol.
                [new-bindings (make-hasheq)]
                ;; Returns the identifiers stored as values of `bindings` that
                ;; are not bound-identifier=? to any identifier in `ids`.
                [complement (lambda (bindings ids)
                              (let ([tmp (make-bound-identifier-mapping)])
                                (hash-for-each bindings
                                               (lambda (k v)
                                                 (bound-identifier-mapping-put! tmp v #t)))
                                (for-each (lambda (id)
                                            (bound-identifier-mapping-put! tmp id #f))
                                          ids)
                                (filter
                                 values
                                 (bound-identifier-mapping-map tmp (lambda (k v) (and v k))))))])
           (let ([register-bindings!
                  (lambda (ids)
                    ;; Record each id under its symbol. Defining an id that is
                    ;; bound-identifier=? to the one already recorded for that
                    ;; symbol is a "duplicate binding" error.
                    (for-each (lambda (id)
                                (when (let ((nid (hash-ref new-bindings (syntax-e id) #f)))
                                        (and nid (bound-identifier=? id nid)))
                                  (raise-syntax-error #f 
                                                      "duplicate binding"
                                                      stx
                                                      id))
                                (hash-set! new-bindings (syntax-e id) id))
                              ids))]
                 ;; Returns a function that adds the lexical context of
                 ;; `def-ctxes` to a syntax object. It does this by partially
                 ;; expanding (quote stx) in those contexts, with quote as the
                 ;; only stop form, and taking the quoted part back out.
                 [add-package-context (lambda (def-ctxes)
                                        (lambda (stx)
                                          (let ([q (local-expand #`(quote #,stx)
                                                                 ctx
                                                                 (list #'quote)
                                                                 def-ctxes)])
                                            (syntax-case q ()
                                              [(_ stx) #'stx]))))])
             ;; Main expansion loop.
             ;;   exprs              - body forms still to process
             ;;   rev-forms          - expanded output forms so far, reversed
             ;;   def-ctxes          - definition contexts, innermost first;
             ;;                        each define* pushes a new one
             ;;   register-bindings? - #f while splicing in the expansion of a
             ;;                        nested define-package; definitions seen
             ;;                        then are not recorded in new-bindings
             (let loop ([exprs init-exprs]
                        [rev-forms null]
                        [def-ctxes (list def-ctx)]
                        [register-bindings? #t])
               (cond
                 ;; End of this package's own body: seal the contexts, resolve
                 ;; the exports, and build the result.
                 [(and register-bindings?
                       (null? exprs))
                  (for-each (lambda (def-ctx)
                              (internal-definition-context-seal def-ctx))
                            def-ctxes)
                  ;; Resolve each listed name to the identifier registered for
                  ;; its symbol. A listed name with no definition is an error.
                  (let* ([exports (map (lambda (id)
                                         (hash-ref new-bindings
                                                   (syntax-e id)
                                                   (lambda ()
                                                     (raise-syntax-error #f
                                                                         (format "no definition for ~a identifier"
                                                                                 (case mode
                                                                                   [(#:only) "exported"]
                                                                                   [(#:all-defined-except) "excluded"]))
                                                                         stx
                                                                         id))))
                                       exports)]
                         ;; The same identifiers, with the package body's
                         ;; context added.
                         [exports-renamed (map (add-package-context def-ctxes) exports)]
                         [defined-renamed (hash-map new-bindings (lambda (k v) v))])
                    ;; #:only and #:begin export exactly the listed names.
                    ;; #:all-defined and #:all-defined-except export every
                    ;; recorded definition except the listed ones
                    ;; (#:all-defined lists none).
                    (let-values ([(exports exports-renamed)
                                  (if (memq mode '(#:only #:begin))
                                      (values exports exports-renamed)
                                      (let ([all-exports-renamed (complement new-bindings exports-renamed)])
                                        ;; In case of define*, get only the last definition:
                                        (let ([tmp (make-bound-identifier-mapping)])
                                          (for-each (lambda (id)
                                                      (bound-identifier-mapping-put!
                                                       tmp
                                                       ((add-package-context def-ctxes) 
                                                        (pre-package-id id def-ctxes))
                                                       #t))
                                                    all-exports-renamed)
                                          (let* ([exports-renamed (bound-identifier-mapping-map tmp (lambda (k v) k))]
                                                 [exports (map (lambda (id) (pre-package-id id def-ctxes))
                                                               exports-renamed)])
                                            (values exports exports-renamed)))))])
                      (with-syntax ([(export ...) exports]
                                    [(renamed ...) (map identifier-prune-lexical-context exports-renamed)])
                        (let ([body (reverse rev-forms)])
                          ;; package-begin: just the body, as (let () ...) in an
                          ;; expression context and as (begin ...) otherwise.
                          (if (eq? mode '#:begin)
                              (if (eq? 'expression (syntax-local-context))
                                  (quasisyntax/loc stx (let () #,@body))
                                  (quasisyntax/loc stx (begin #,@body)))
                              (quasisyntax/loc stx
                                (begin
                                  #,@(if (eq? 'top-level (syntax-local-context))
                                         ;; declare all bindings before they are used:
                                         #`((define-syntaxes #,defined-renamed (values)))
                                         null)
                                  #,@body
                                  ;; Bind the package name to a descriptor whose
                                  ;; thunk lists (exported-symbol . renamed-identifier)
                                  ;; pairs.
                                  (define-syntax pack-id
                                    (make-package
                                     (lambda ()
                                       (list (cons (quote export)
                                                   (quote-syntax renamed))
                                             ...)))))))))))]
                 ;; End of a nested package's expansion (register-bindings? is
                 ;; #f). Hand the accumulated forms and contexts back to the
                 ;; define-package clause below.
                 [(null? exprs)
                  (values rev-forms def-ctxes)]
                 [else
                  ;; Partially expand the next form. The stop list also
                  ;; includes define-package, so nested packages reach the
                  ;; clause below unexpanded.
                  (let ([expr (local-expand (car exprs)
                                            ctx
                                            (cons #'define-package kernel-forms)
                                            def-ctxes)])
                    (syntax-case expr (begin define-package)
                      ;; Splice the contents of `begin` into the remaining forms.
                      [(begin . rest)
                       (loop (append (syntax->list #'rest) (cdr exprs))
                             rev-forms
                             def-ctxes
                             register-bindings?)]
                      ;; Syntax definition (plain or define*). The rhs is
                      ;; expanded as a phase-1 expression, and the ids are
                      ;; bound to it in a definition context.
                      ;; A define* form first gets a new child context. Its ids
                      ;; get that context's lexical information, and the context
                      ;; is pushed onto def-ctxes so later forms see it
                      ;; (presumably this is what lets a later define* rebind a
                      ;; name).
                      [(def (id ...) rhs)
                       (and (or (free-identifier=? #'def #'define-syntaxes)
                                (free-identifier=? #'def #'-define*-syntaxes))
                            (andmap identifier? (syntax->list #'(id ...))))
                       (with-syntax ([rhs (local-transformer-expand
                                           #'rhs
                                           'expression
                                           null)])
                         (let ([star? (free-identifier=? #'def #'-define*-syntaxes)]
                               [ids (syntax->list #'(id ...))])
                           (let* ([def-ctx (if star?
                                               (syntax-local-make-definition-context (car def-ctxes))
                                               (car def-ctxes))]
                                  [ids (if star? 
                                           (map (add-package-context (list def-ctx)) ids)
                                           ids)])
                             (syntax-local-bind-syntaxes ids #'rhs def-ctx)
                             (when register-bindings?
                               (register-bindings! ids))
                             (loop (cdr exprs)
                                   (cons #`(define-syntaxes #,ids rhs)
                                         rev-forms)
                                   (if star? (cons def-ctx def-ctxes) def-ctxes)
                                   register-bindings?))))]
                      ;; Value definition (plain or define*). The context
                      ;; handling is the same as for syntax definitions above,
                      ;; but the ids are bound as variables (#f transformer),
                      ;; and the rhs is left unexpanded here.
                      [(def (id ...) rhs)
                       (and (or (free-identifier=? #'def #'define-values)
                                (free-identifier=? #'def #'-define*-values))
                            (andmap identifier? (syntax->list #'(id ...))))
                       (let ([star? (free-identifier=? #'def #'-define*-values)]
                             [ids (syntax->list #'(id ...))])
                         (let* ([def-ctx (if star?
                                             (syntax-local-make-definition-context (car def-ctxes))
                                             (car def-ctxes))]
                                [ids (if star? 
                                         (map (add-package-context (list def-ctx)) ids)
                                         ids)])
                           (syntax-local-bind-syntaxes ids #f def-ctx)
                           (when register-bindings?
                             (register-bindings! ids))
                           (loop (cdr exprs)
                                 (cons #`(define-values #,ids rhs) rev-forms)
                                 (if star? (cons def-ctx def-ctxes) def-ctxes)
                                 register-bindings?)))]
                      ;; Nested package. Expand the define-package form itself
                      ;; (define-package is not in this stop list) and run its
                      ;; result through the loop with register-bindings? #f.
                      ;; Afterwards, record only the nested package's own name,
                      ;; and only when ml-name? accepts it.
                      [(define-package id . _)
                       (let-values (((rev-forms def-ctxes)
                                     (loop (list (local-expand expr
                                                               ctx
                                                               kernel-forms
                                                               def-ctxes))
                                           rev-forms
                                           def-ctxes
                                           #f)))
                         (when (and register-bindings?
                                    (ml-name? #'id))
                           (register-bindings! (list #'id)))
                         (loop (cdr exprs)
                               rev-forms
                               def-ctxes
                               register-bindings?))]
                      ;; Any other form is an expression. It is wrapped as a
                      ;; definition of zero values, except for the last
                      ;; remaining form in #:begin mode, which is kept as-is
                      ;; (presumably so it supplies the result value).
                      [else
                       (loop (cdr exprs) 
                             (cons (if (and (eq? mode '#:begin)
                                            (null? (cdr exprs)))
                                       expr
                                       #`(define-values () (begin #,expr (values))))
                                   rev-forms)
                             def-ctxes
                             register-bindings?)]))]))))))]))

;; Accepted shapes:
;;   (define-package name #:all-defined form ...)
;;   (define-package name #:all-defined-except (id ...) form ...)
;;   (define-package name #:only (id ...) form ...)
;;   (define-package name (id ...) form ...)   ; shorthand for #:only
;;
;; Each shape is normalized into the (_ name mode exports form ...) form that
;; `do-define-package` expects. The original `stx` is also passed along for
;; error reporting.
(define-syntax (define-package stx)
  (syntax-case stx ()
    [(_ name #:all-defined body ...)
     (do-define-package stx #'(define-package name #:all-defined () body ...))]
    [(_ name kw exports body ...)
     (memq (syntax-e #'kw) '(#:all-defined-except #:only))
     ;; Already in normalized form.
     (do-define-package stx stx)]
    [(_ name exports body ...)
     (do-define-package stx #'(define-package name #:only exports body ...))]))

;; (package-begin form ...)
;;
;; Expands the body like an anonymous package with no exports. The body is
;; delegated to `do-define-package` in #:begin mode, with #f as the package
;; name and an empty export list.
(define-syntax (package-begin stx)
  (syntax-case stx ()
    [(_ body ...)
     (let ([normalized #'(define-package #f #:begin () body ...)])
       (do-define-package stx normalized))]))

;; do-open : syntax identifier -> syntax
;;
;; Shared implementation of `open-package` and `open*-package`.
;;
;; Accepts (form pack-id) or (form pack-id ctx-id). `pack-id` must be an
;; identifier bound to a package. For each export, the result binds a name
;; with the lexical context of `ctx-id` (which defaults to `pack-id`). The
;; binding uses `define-syntaxes-id` and makes the name a rename transformer
;; for the package's recorded identifier. The package name is attached as a
;; 'disappeared-use.
;;
;; Raises a syntax error when pack-id is not an identifier or is not bound
;; to a package.
(define-for-syntax (do-open stx define-syntaxes-id)
  (syntax-case stx ()
    [(op pack-id)
     ;; Default ctx-id to pack-id. Keep the user's head identifier and source
     ;; location so later errors name `open-package` / `open*-package` and
     ;; point at the user's form, rather than at `_` in this file.
     (do-open (syntax/loc stx (op pack-id pack-id)) define-syntaxes-id)]
    [(_ pack-id ctx-id)
     (let ([id #'pack-id])
       (unless (identifier? id)
         (raise-syntax-error #f
                             "expected an identifier for a package"
                             stx
                             id))
       (let ([v (syntax-local-value id (lambda () #f))])
         (unless (package? v)
           (raise-syntax-error #f
                               "identifier is not bound to a package"
                               stx
                               id))
         ;; Force the exports thunk once, so the introduced names and their
         ;; targets come from the same list.
         (let ([entries ((package-exports v))])
           (with-syntax ([(intro ...)
                          (map (lambda (entry)
                                 (datum->syntax #'ctx-id
                                                (car entry)
                                                #f))
                               entries)]
                         [(defined ...)
                          (map (lambda (entry) (syntax-local-introduce (cdr entry)))
                               entries)])
             (syntax-property #`(#,define-syntaxes-id (intro ...)
                                                      (values (make-rename-transformer #'defined)
                                                              ...))
                              'disappeared-use
                              (syntax-local-introduce id))))))]))

;; (open-package pkg) or (open-package pkg ctx-id)
;;
;; Binds each of pkg's exported names with `define-syntaxes`, as rename
;; transformers for the package's internal bindings.
(define-syntax open-package
  (lambda (stx)
    (do-open stx #'define-syntaxes)))
;; (open*-package pkg) or (open*-package pkg ctx-id)
;;
;; Like `open-package`, but the bindings are introduced with
;; `define*-syntaxes`.
(define-syntax open*-package
  (lambda (stx)
    (do-open stx #'define*-syntaxes)))

;; package-exported-identifiers : identifier -> (listof identifier)
;;
;; Returns the exported names of the package bound to `id`. Each name is an
;; identifier with `id`'s lexical context and no source location.
;; Raises a type error (via raise-type-error) when `id` is not an identifier
;; bound to a package. This uses `syntax-local-value`, so it must run while
;; a transformer is executing.
(define-for-syntax (package-exported-identifiers id)
  (let ([pkg (if (identifier? id)
                 (syntax-local-value id (lambda () #f))
                 #f)])
    (if (package? pkg)
        (map (lambda (entry) (datum->syntax id (car entry) #f))
             ((package-exports pkg)))
        (raise-type-error 'package-exported-identifiers
                          "identifier bound to a package"
                          id))))

;; package-original-identifiers : identifier -> (listof identifier)
;;
;; Returns the identifier recorded for each export of the package bound to
;; `id`, i.e. the cdr of each pair produced by the package's exports thunk.
;; Raises a type error (via raise-type-error) when `id` is not an identifier
;; bound to a package. This uses `syntax-local-value`, so it must run while
;; a transformer is executing.
(define-for-syntax (package-original-identifiers id)
  (let ([v (and (identifier? id)
                (syntax-local-value id (lambda () #f)))])
    (unless (package? v)
      ;; Report the error under this function's own name.
      (raise-type-error 'package-original-identifiers "identifier bound to a package" id))
    (map cdr ((package-exports v)))))