datatype.ss
#lang scheme
(require mzlib/pconvert-prop
         (for-syntax scheme/list))

(provide define-type type-case)

; Searches lst from the front and returns the first ELEMENT (not pred?'s
; result) for which pred? holds, or #f when no element satisfies it.
(define-for-syntax (ormap/orig pred? lst)
  (let search ([remaining lst])
    (cond
      [(null? remaining) #f]
      [(pred? (first remaining)) (first remaining)]
      [else (search (rest remaining))])))

;; Raises a syntax error attributed to the form named `id`, highlighting
;; stx-loc.  The message is built by applying `format` to format-string and
;; the remaining args.
(define-for-syntax (plai-syntax-error id stx-loc format-string . args)
  (let ([message (apply format format-string args)])
    (raise-syntax-error id message stx-loc)))

;; Run-time message used by an else-less type-case if no clause matched
;; (should be impossible once every variant has been checked for coverage).
(define bug:fallthru-no-else
  "You have encountered a bug in the PLAI code.  (Error: type-case fallthru on cond without an else clause.)")

;; Compile-time diagnostic messages.  Those containing ~a are format strings
;; that are filled in through plai-syntax-error.

;; Shared by define-type's binding checks.
(define-for-syntax bound-id
  "identifier is already bound in this scope (If you didn't define it, it was defined by the PLAI language.)")

;; Fallback messages when no more specific diagnosis is available.
(define-for-syntax define-type:generic
  "expected the name of a type, followed by a sequence of variants; search the Help Desk for `define-type' for assistance.")
(define-for-syntax type-case:generic
  "syntax error in type-case; search the Help Desk for `type-case' for assistance.")

;; define-type specific messages.
(define-for-syntax define-type:duplicate-variant
  "this identifier has already been used")
(define-for-syntax define-type:bound-id
  "this identifier is already bound")
(define-for-syntax assign-to-variant
  "This identifier references a variant of a define-type type")

;; type-case specific messages.
(define-for-syntax type-case:not-a-type
  "this must be a type defined with define-type")
(define-for-syntax type-case:not-a-variant
  "this is not a variant of the specified type")
(define-for-syntax type-case:argument-count
  "this variant has ~a fields, but you provided bindings for ~a fields")
(define-for-syntax type-case:missing-variant
  "syntax error; probable cause: you did not include a case for the ~a variant, or no else-branch was present")
(define-for-syntax type-case:expected-datatype-id
  "this must be an identifier that names the datatype you want to define")
(define-for-syntax type-case:unreachable-else
  "the else branch of this type-case is unreachable; you have matched all variants")
(define-for-syntax define-type:zero-variants
  "you must specify a sequence of variants after the type, ~a")


;; Curried: given a context syntax object, returns a function that rebuilds
;; name-stx's datum as an identifier carrying src-stx's lexical context.
(define-for-syntax (make-variant-name src-stx)
  (λ (name-stx)
    (let ([name (syntax->datum name-stx)])
      (datum->syntax src-stx name))))

; Curried: given a context syntax object, returns a function mapping an
; identifier `name` to the predicate identifier `name?` in src-stx's context.
(define-for-syntax (make-variant? src-stx)
  (λ (name-stx)
    (let* ([base (symbol->string (syntax->datum name-stx))]
           [predicate-sym (string->symbol (string-append base "?"))])
      (datum->syntax src-stx predicate-sym))))

;; Curried: given a context syntax object and a variant identifier, returns a
;; function mapping a field identifier to the accessor identifier
;; `variant-field`, carrying src-stx's lexical context.
(define-for-syntax (make-field-accessor-name src-stx variant-stx)
  (λ (field-stx)
    (let* ([variant-str (symbol->string (syntax-e variant-stx))]
           [field-str (symbol->string (syntax-e field-stx))]
           [accessor-sym (string->symbol (string-append variant-str "-" field-str))])
      (datum->syntax src-stx accessor-sym))))

;; Curried: given a context syntax object and a variant identifier, returns a
;; function mapping a field identifier to the mutator identifier
;; `set-variant-field!`, carrying src-stx's lexical context.
(define-for-syntax (make-field-mutator-name src-stx variant-stx)
  (λ (field-stx)
    (let* ([variant-str (symbol->string (syntax-e variant-stx))]
           [field-str (symbol->string (syntax-e field-stx))]
           [mutator-str (string-append "set-" variant-str "-" field-str "!")])
      (datum->syntax src-stx (string->symbol mutator-str)))))

;; Curried: given the name of the reporting form, returns a checker that
;; raises a syntax error if id-stx already has a binding (as reported by
;; identifier-binding); otherwise it returns void.
(define-for-syntax (assert-unbound stx-symbol)
  (λ (id-stx)
    (if (identifier-binding id-stx)
        (plai-syntax-error stx-symbol id-stx bound-id)
        (void))))

;; Raises a define-type syntax error at the first duplicated identifier in
;; the syntax list variant-stx (per check-duplicate-identifier); otherwise
;; returns void.
(define-for-syntax (assert-unique variant-stx)
  (cond
    [(check-duplicate-identifier (syntax->list variant-stx))
     => (λ (duplicate)
          (plai-syntax-error 'define-type duplicate
                             define-type:duplicate-variant))]
    [else (void)]))

; Returns the field positions (0 1 ... n-1) for a syntax list of n elements.
(define-for-syntax (count-stx list-stx)
  (let ([element-count (length (syntax->list list-stx))])
    (build-list element-count values)))

;; Returns a one-argument function that ignores its argument and yields x.
(define-for-syntax (const x)
  (λ (ignored) x))

; Fresh tag that define-type puts at the head of the compile-time value it
; binds to a type name.  validate-and-remove-type-symbol compares against it
; with eq?, so type-case can tell define-type bindings apart from other
; compile-time bindings.
(define-for-syntax type-symbol (gensym))

;; define-type binds a type name to the compile-time value
;;   (list type-symbol variant-table type-predicate-stx).
;; Given the result of syntax-local-value for a type-case type-id (#f when the
;; identifier has no compile-time binding), returns
;; (variant-table type-predicate-stx), or raises a type-case syntax error at
;; stx-loc when the value was not produced by define-type.
(define-for-syntax (validate-and-remove-type-symbol stx-loc lst)
  ; Check the full shape before calling first: '() is a list but has no first
  ; element, and a list that is too short would later crash callers that use
  ; (second info).
  (if (and (list? lst)
           (= (length lst) 3)
           (eq? type-symbol (first lst)))
      (rest lst)
      (plai-syntax-error 'type-case stx-loc type-case:not-a-type)))

;; Checks one define-type field spec, which must be (field-id contract-expr).
;; Returns void when it is well-formed; otherwise raises a syntax error.
(define-for-syntax (validate-field-spec stx)
  (syntax-case stx ()
    [(field contract)
     (identifier? #'field)
     (void)]
    [(field contract)
     (plai-syntax-error
      'define-type #'field
      "must be an identifier that names the field of this variant")]
    [_
     (plai-syntax-error
      'define-type stx
      "must be of the form (field-id contract-expr)")]))

;; Checks one define-type variant clause, which must be
;; (variant-id field-spec ...).  It validates every field spec and raises a
;; syntax error on the first problem found.
(define-for-syntax (validate-variant variant-stx)
  (syntax-case variant-stx ()
    ; Well-formed head and a proper list: check each field spec.
    [(variant field-spec ...)
     (identifier? #'variant)
     (for-each validate-field-spec (syntax->list #'(field-spec ...)))]
    ; Identifier head but an improper tail.
    [(variant . _)
     (identifier? #'variant)
     (plai-syntax-error
      'define-type variant-stx
      "this variant is malformed")]
    ; Any pair whose head is not an identifier.
    [(variant . _)
     (plai-syntax-error
      'define-type #'variant
      "this must be an identifier that names a variant of the type you are defining")]
    ; Not a pair at all.
    [_
     (plai-syntax-error
      'define-type variant-stx
      "this variant is malformed")]))

;; (define-type type-id [variant-id (field-id contract-expr) ...] ...)
;;
;; For each variant this defines:
;;   - a struct type;
;;   - the constructor variant-id.  It is a set!-transformer that checks
;;     arity and wraps the raw constructor in a contract built from the
;;     field contracts;
;;   - the predicate variant-id?;
;;   - field accessors variant-id-field-id;
;;   - contract-checked mutators set-variant-id-field-id!.
;; It also defines the type predicate type-id?, and binds type-id at compile
;; time to (list type-symbol variant-table type-id?-stx) for type-case to use.
(define-syntax (define-type stx)
  (syntax-case stx ()
    [(_ datatype-id [variant-name (field-name field/c) ...] ...)
     ; Ensures that everything that should be an identifier is an identifier.
     (and (identifier? #'datatype-id)
          (andmap identifier? (syntax->list #'(variant-name ...)))
          (andmap (λ (stx) (andmap identifier? (syntax->list stx)))
                  (syntax->list #'((field-name ...) ...))))
     (begin
       ; Ensure we have at least one variant.
       (when (empty? (syntax->list #'(variant-name ...)))
         (plai-syntax-error 'define-type stx define-type:zero-variants
                            (syntax-e #'datatype-id)))
       
       ; Ensure variant names are unique.
       (assert-unique #'(variant-name ...))
       ; Ensure each set of fields have unique names.
       (for-each assert-unique (syntax->list #'((field-name ...) ...)))
       
       ; Ensure type and variant names are unbound.
       (for-each (assert-unbound 'define-type)
                 (cons #'datatype-id (syntax->list #'(variant-name ...))))
       
       ; At this point, only contracts may be malformed.
       
       ; Introduce names
       (with-syntax ([datatype-id? ((make-variant? stx) #'datatype-id)]
                     [(variant ...) 
                      (map (make-variant-name stx) 
                           (syntax->list #'(variant-name ...)))]
                     [(variant? ...) 
                      (map (make-variant? stx) 
                           (syntax->list #'(variant-name ...)))]
                     [(struct:variant ...)
                      (generate-temporaries #'(variant-name ...))]
                     [(variant-proj ...)
                      (map syntax-local-introduce (generate-temporaries #'(variant-name ...)))]
                     ; Used only as the pattern-variable name inside each
                     ; constructor's transformer below.
                     [(num-fields ...) (generate-temporaries #'(variant-name ...))]
                     [(variant-set! ...)
                      (generate-temporaries #'(variant-name ...))]
                     ; Creates names for the field accessor functions.
                     [((field ...) ...)
                      (map (λ (variant/field-stx)
                             (map (make-field-accessor-name 
                                   stx (first (syntax->list variant/field-stx)))
                                  (syntax->list (second (syntax->list variant/field-stx)))))
                           (syntax->list #'((variant-name (field-name ...)) ...)))]
                     ; Creates names for the field mutator functions.
                     [((field-set! ...) ...)
                      (map (λ (variant/field-stx)
                             (map (make-field-mutator-name 
                                   stx (first (syntax->list variant/field-stx)))
                                  (syntax->list (second (syntax->list variant/field-stx)))))
                           (syntax->list #'((variant-name (field-name ...)) ...)))]
                     [((field-position ...) ...) 
                      (map count-stx (syntax->list #'((field-name ...) ...)))])
         ; Replicate each variant's per-variant temporaries once per field so
         ; they can be used at ellipsis depth 2 alongside the field names.
         (with-syntax ([(make-variant ...) (generate-temporaries #'(variant ...))]
                       [((variant-proj/f ...) ...)
                        (map (λ (proj fields) (map (const proj) (syntax->list fields)))
                             (syntax->list #'(variant-proj ...))
                             (syntax->list #'((field-name ...) ...)))]
                       [((variant?/f ...) ...)
                        (map (λ (v? fields) (map (const v?) (syntax->list fields)))
                             (syntax->list #'(variant? ...))
                             (syntax->list #'((field-name ...) ...)))]
                       [((variant-set!/f ...) ...)
                        (map (λ (mut fields) (map (const mut) (syntax->list fields)))
                             (syntax->list #'(variant-set! ...))
                             (syntax->list #'((field-name ...) ...)))])
           #`(begin
               ; For each variant, create the struct type.  Its constructor,
               ; predicate, accessor and mutator are bound to temporaries and
               ; exposed through the definitions below.  No custom writer is
               ; installed; only the print-convert constructor name is set.
               (define-values (struct:variant make-variant variant? 
                                              variant-proj variant-set!)
                 (make-struct-type 
                  'variant #f (length (list 'field-name ...)) 0 #f 
                  `((,prop:print-convert-constructor-name . ,'variant))
                  (make-inspector)))
               ...
               
               ; The type? predicate must come before the fields and
               ; constructors, to permit recursive data-structures.
               (define (datatype-id? x)
                 (or (variant? x) ...))
               
               (begin
                 ; Projections
                 (define field
                   ; We use field-name as the name symbol as
                   ; make-struct-field-accessor prepends the name of the
                   ; structure itself.
                   (make-struct-field-accessor variant-proj/f field-position
                                               'field-name))
                 ...
                 
                 ; Mutators, each wrapped in a contract that checks the
                 ; variant predicate and the field's contract.
                 ; NOTE(review): an earlier comment said these are hidden unless
                 ; provide-datatype-field-mutators? is set; nothing in this
                 ; macro enforces that, so confirm where it is handled.
                 (define (field-set! var val)
                   ((contract (variant?/f field/c . -> . any)
                              (make-struct-field-mutator variant-set!/f field-position
                                                         'field-name)
                              'field-name 'use) var val))
                 ...
                 
                 ; Constructor
                 (define-syntax variant
                   (make-set!-transformer
                    (λ (stx)
                      (syntax-case stx (set!)
                        [(set! id v) (plai-syntax-error 'set! #'id assign-to-variant)]
                        [(f . args)
                         (with-syntax ([num-fields (length (syntax->list #'(field/c ...)))]
                                       [num-args (length (syntax->list #'args))])
                           #`(if (= num-fields num-args)
                                 ((contract (field/c ... . -> . variant?)
                                            make-variant 'variant-name 'use (quote-syntax f))
                                  . args)
                                 (raise (make-exn:fail:contract:arity
                                         (format "use broke the contract on ~a; expected ~a arguments, given ~a"
                                                 'variant num-fields num-args)
                                         (current-continuation-marks)
                                         ))))]
                        [id #`(contract (field/c ... . -> . variant?)
                                        make-variant 'variant-name 'use (quote-syntax id))]))))
                 
                 )
               ...
               
               ; Compile-time description of the type, consumed by type-case.
               (define-syntax datatype-id
                 (list type-symbol
                       (list (list (quote-syntax variant-name) (list #'field ...) #'variant?)
                             ...)
                       #'datatype-id?))
               ))))]
    ;;; The rest of these attempt to give decent error messages for malformed
    ;;; cases.  If we made it this far, either the datatype-id is not an
    ;;; identifier or a clause is malformed.
    [(_ datatype-id . rest)
     (not (identifier? #'datatype-id))
     (plai-syntax-error 'define-type #'datatype-id 
                        type-case:expected-datatype-id)]
    [(_ datatype-id clause ...)
     (begin
       (for-each validate-variant (syntax->list #'(clause ...)))
       ; Every clause passed its individual check, yet the form as a whole did
       ; not match.  A transformer must return syntax, so report a generic
       ; error instead of falling off the end with void.
       (plai-syntax-error 'define-type stx define-type:generic))]
    [_ (plai-syntax-error 'define-type stx define-type:generic)]))

;;; Curried: given a type's variant table (entries whose first element is the
;;; variant's identifier), returns a checker that raises a type-case syntax
;;; error unless variant-id-stx is free-identifier=? to one of the variants.
(define-for-syntax (assert-variant type-info)
  (λ (variant-id-stx)
    (let* ([known-variants (map first type-info)]
           [known? (ormap (λ (known) (free-identifier=? variant-id-stx known))
                          known-variants)])
      (if known?
          (void)
          (plai-syntax-error 'type-case variant-id-stx type-case:not-a-variant)))))

;;; Curried: given a type's variant table, returns a checker that raises a
;;; type-case syntax error when a clause binds a different number of fields
;;; than its variant declares.  Run assert-variant first: the lookup below
;;; assumes variant-id-stx is present in the table.
(define-for-syntax (assert-field-count type-info)
  (λ (variant-id-stx field-stx)
    (let* ([expected (ormap (λ (entry)
                              (and (free-identifier=? (first entry) variant-id-stx)
                                   (length (second entry))))
                            type-info)]
           [given (length (syntax->list field-stx))])
      (unless (= expected given)
        (plai-syntax-error 'type-case variant-id-stx type-case:argument-count 
                           expected given)))))

;; Curried: given the whole type-case form (for error location) and the syntax
;; list of variants it has clauses for, returns a checker that raises a
;; syntax error naming `variant` when no clause covers it.
(define-for-syntax (ensure-variant-present stx-loc variants)
  (λ (variant)
    (let ([covered? (ormap (λ (clause-id) (free-identifier=? variant clause-id))
                           (syntax->list variants))])
      (if covered?
          (void)
          (plai-syntax-error 'type-case stx-loc type-case:missing-variant 
                             (syntax->datum variant))))))

;; Curried: returns a predicate that is #t when `variant` is not
;; free-identifier=? to any identifier in the syntax list `variants`.
;; stx-loc is accepted for symmetry with ensure-variant-present but unused.
(define-for-syntax (variant-missing? stx-loc variants)
  (λ (variant)
    (andmap (λ (clause-id) (not (free-identifier=? variant clause-id)))
            (syntax->list variants))))


;; (lookup-variant variant-id ((id (field ...) id?) ...))
;; Walks the variant table spliced in by type-case and expands to
;; (list (list field ...) id?) for the entry whose id is free-identifier=? to
;; variant-id.  Running out of entries is an internal error.
(define-syntax (lookup-variant stx)
  (syntax-case stx ()
    [(_ variant-id ())
     (error 'lookup-variant "variant ~a not found (bug in PLAI code)"
            (syntax-e #'variant-id))]
    [(_ variant-id ((id (field ...) id?) . more))
     (free-identifier=? #'variant-id #'id)
     #'(list (list field ...) id?)]
    [(_ variant-id (skipped . more))
     #'(lookup-variant variant-id more)]))

;; Gives a precise syntax error for a malformed type-case clause, which must
;; have the shape [variant-id (field-id ...) body-expr].  Returns #t when the
;; clause is well-formed.
(define-for-syntax (validate-clause clause-stx)
  (syntax-case clause-stx ()
    [(variant (field ...) body ...)
     (let ([body-count (length (syntax->list #'(body ...)))])
       (cond
         [(not (identifier? #'variant))
          (plai-syntax-error 'type-case #'variant
                             "this must be the name of a variant")]
         [(ormap (λ (stx) 
                   (and (not (identifier? stx)) stx))
                 (syntax->list #'(field ...)))
          => (λ (malformed-field)
               (plai-syntax-error
                'type-case malformed-field
                "this must be an identifier that names the value of a field"))]
         ; The pattern above also matches a clause with no body.  A separate
         ; (variant (field ...)) clause would be unreachable, so check here.
         [(zero? body-count)
          (plai-syntax-error
           'type-case clause-stx
           "this case is missing a body expression")]
         [(not (= body-count 1))
          (plai-syntax-error 
           'type-case clause-stx
           (string-append
            "there must be just one body expression in a clause, but you "
            "provided ~a body expressions.")
           body-count)]
         [else #t]))]
    [_ 
     (plai-syntax-error
      'type-case clause-stx
      "this case is missing a field list (possibly an empty field list)")]))

;; (bind-fields-in (binding-name ...) case-variant-id variant-table value-id body-expr)
;; Finds the table entry for case-variant-id and expands to a `let` that binds
;; each binding-name to the corresponding selector applied to value-id, then
;; evaluates body-expr.
(define-syntax (bind-fields-in stx)
  (syntax-case stx ()
    [(_ (binding-name ...) case-variant-id
        ((variant-id (selector-id ...) _) . remaining) value-id body-expr)
     (free-identifier=? #'case-variant-id #'variant-id)
     #'(let ([binding-name (selector-id value-id)] ...)
         body-expr)]
    [(_ (binding-name ...) case-variant-id
        ((_ (_ ...) _) . remaining) value-id body-expr)
     #'(bind-fields-in (binding-name ...) case-variant-id remaining value-id body-expr)]))

;; Compile-time checks shared by both forms of type-case.
;; - Looks up the define-type binding of type-id-stx.
;; - Checks that the clause variant names are unique and belong to the type.
;; - Checks that each clause's field names are unique and that each clause
;;   binds as many fields as its variant declares.
;; Returns (values type-info type?-stx): the variant table and the
;; type-predicate syntax stored by define-type.
(define-for-syntax (check-type-case-clauses type-id-stx variants-stx fields-stx)
  (let* ([info (validate-and-remove-type-symbol
                type-id-stx (syntax-local-value type-id-stx (λ () #f)))]
         [type-info (first info)]
         [type? (second info)]
         [variants (syntax->list variants-stx)]
         [field-lists (syntax->list fields-stx)])
    ; Ensure all names are unique.
    (assert-unique variants-stx)
    (for-each assert-unique field-lists)
    ; Ensure variants are valid.
    (for-each (assert-variant type-info) variants)
    ; Ensure field counts match.
    (for-each (assert-field-count type-info) variants field-lists)
    (values type-info type?)))

;; (type-case type-id test-expr [variant-id (field-id ...) body-expr] ... [else else-expr])
;; Evaluates test-expr, raises a run-time error if the result does not
;; satisfy the type's predicate, and otherwise evaluates the body of the
;; clause whose variant predicate accepts the value.  That body runs with the
;; clause's field ids bound to the value's fields.  Without an else clause
;; every variant must be covered; with one, at least one variant must be left
;; uncovered.
(define-syntax (type-case stx)
  (syntax-case stx (else)
    [(_ type-id test-expr [variant (field ...) case-expr] ... [else else-expr])
     ; Ensure that everything that should be an identifier is an identifier.
     (and (identifier? #'type-id)
          (andmap identifier? (syntax->list #'(variant ...)))
          (andmap (λ (stx) (andmap identifier? (syntax->list stx)))
                  (syntax->list #'((field ...) ...))))
     (let-values ([(type-info type?)
                   (check-type-case-clauses #'type-id #'(variant ...)
                                            #'((field ...) ...))])
       ; The else branch is reachable only if some variant is missing.
       (unless (ormap (variant-missing? stx #'(variant ...)) 
                      (map first type-info))
         (plai-syntax-error 'type-case stx type-case:unreachable-else))
       
       #`(let ([expr test-expr])
           (if (not (#,type? expr))
               #,(syntax/loc #'test-expr 
                   (error 'type-case "the value ~a is not of the specified type"
                          expr))
               (cond
                 [(let ([variant-info (lookup-variant variant #,type-info)])
                    ((second variant-info) expr))
                  (bind-fields-in (field ...) variant #,type-info expr case-expr)]
                 ...
                 [else else-expr]))))]
    [(_ type-id test-expr [variant (field ...) case-expr] ...)
     ; Ensure that everything that should be an identifier is an identifier.
     (and (identifier? #'type-id)
          (andmap identifier? (syntax->list #'(variant ...)))
          (andmap (λ (stx) (andmap identifier? (syntax->list stx)))
                  (syntax->list #'((field ...) ...))))
     (let-values ([(type-info type?)
                   (check-type-case-clauses #'type-id #'(variant ...)
                                            #'((field ...) ...))])
       ; Ensure all variants are covered.
       (for-each (ensure-variant-present stx #'(variant ...))
                 (map first type-info))
       
       #`(let ([expr test-expr])
           (if (not (#,type? expr))
               #,(syntax/loc #'test-expr 
                   (error 'type-case "this value (~a) is not of the specified type"
                          expr))
               (cond
                 [(let ([variant-info (lookup-variant variant #,type-info)])
                    ((second variant-info) expr))
                  (bind-fields-in (field ...) variant #,type-info expr case-expr)]
                 ...
                 [else (error 'type-case bug:fallthru-no-else)]))))]
    ;;; The remaining clauses are for error reporting only.  If we got this
    ;;; far, either the type-id or some clause is malformed, or the error is
    ;;; completely unintelligible.  Every path must raise: a transformer that
    ;;; returned the validation results would yield a non-syntax value.
    [(_ type-id test-expr clauses ...)
     (begin
       (unless (identifier? #'type-id)
         (plai-syntax-error 'type-case #'type-id type-case:not-a-type))
       (for-each validate-clause (syntax->list #'(clauses ...)))
       (plai-syntax-error 'type-case stx type-case:generic))]
    [_ (plai-syntax-error 'type-case stx type-case:generic)]))