#lang scheme
(require mzlib/pconvert-prop
(for-syntax scheme/list))
(provide define-type type-case)
;; ormap/orig : (any -> any) (listof any) -> any
;; Like `ormap', but yields the first *element* of `lst' that satisfies
;; `pred?' (instead of the predicate's result), or #f when none does.
(define-for-syntax (ormap/orig pred? lst)
  (let scan ([remaining lst])
    (cond
      [(null? remaining) #f]
      [else
       (let ([candidate (first remaining)])
         (if (pred? candidate)
             candidate
             (scan (rest remaining))))])))
;; plai-syntax-error : symbol syntax string any ... -> (raises)
;; Raises a syntax error attributed to `id' and located at `stx-loc'; the
;; message is `format-string' filled in with `args' via `format'.
(define-for-syntax (plai-syntax-error id stx-loc format-string . args)
  (let ([message (apply format format-string args)])
    (raise-syntax-error id message stx-loc)))
;; Run-time message used by the final `else' arm that `type-case' generates
;; when the user wrote no else clause. That arm only runs if a value passed
;; the type predicate but matched none of the variant predicates.
(define bug:fallthru-no-else
  (string-append "You have encountered a bug in the PLAI code. "
                 "(Error: type-case fallthru on cond without an else clause.)"))
;; ---------------------------------------------------------------------
;; Compile-time diagnostic messages for `define-type' and `type-case'.
;; Entries containing `~a' are format strings; their arguments are
;; supplied through `plai-syntax-error'.
;; ---------------------------------------------------------------------

;; Raised by `assert-unbound' when a name to be defined is already bound.
(define-for-syntax bound-id
  (string-append "identifier is already bound in this scope "
                 "(If you didn't define it, it was defined by the PLAI language.)"))

;; Fallback message for a `define-type' form no other check explains.
(define-for-syntax define-type:generic
  (string-append "expected the name of a type, followed by a sequence of variants; "
                 "search the Help Desk for `define-type' for assistance."))

;; Fallback message for a `type-case' form no other check explains.
(define-for-syntax type-case:generic
  (string-append "syntax error in type-case; "
                 "search the Help Desk for `type-case' for assistance."))

;; Raised by `assert-unique' on a repeated variant or field name.
(define-for-syntax define-type:duplicate-variant
  "this identifier has already been used")

;; No use of this message is visible in this section.
(define-for-syntax define-type:bound-id
  "this identifier is already bound")

;; Raised when a variant constructor name is the target of `set!'.
(define-for-syntax assign-to-variant
  "This identifier references a variant of a define-type type")

;; Raised when the name given to `type-case' carries no define-type info.
(define-for-syntax type-case:not-a-type
  "this must be a type defined with define-type")

;; Raised when a `type-case' clause names something that is not a variant.
(define-for-syntax type-case:not-a-variant
  "this is not a variant of the specified type")

;; Format args: declared field count, then number of bindings supplied.
(define-for-syntax type-case:argument-count
  "this variant has ~a fields, but you provided bindings for ~a fields")

;; Format arg: the name of the variant with no clause.
(define-for-syntax type-case:missing-variant
  "syntax error; probable cause: you did not include a case for the ~a variant, or no else-branch was present")

;; Used by `define-type' when its first operand is not an identifier.
(define-for-syntax type-case:expected-datatype-id
  (string-append "this must be an identifier that names "
                 "the datatype you want to define"))

;; Raised when every variant has a clause yet an else clause is present.
(define-for-syntax type-case:unreachable-else
  "the else branch of this type-case is unreachable; you have matched all variants")

;; Format arg: the type name that was given no variants.
(define-for-syntax define-type:zero-variants
  "you must specify a sequence of variants after the type, ~a")
;; make-variant-name : syntax -> (syntax -> syntax)
;; Curried: returns a function that re-creates a name identifier with the
;; same symbol but with the lexical context of `src-stx'.
(define-for-syntax (make-variant-name src-stx)
  (λ (name-stx)
    (let ([name (syntax->datum name-stx)])
      (datum->syntax src-stx name))))
;; make-variant? : syntax -> (syntax -> syntax)
;; Curried: returns a function mapping identifier `name' to identifier
;; `name?' in the lexical context of `src-stx'.
(define-for-syntax (make-variant? src-stx)
  (λ (name-stx)
    (let* ([base (symbol->string (syntax->datum name-stx))]
           [predicate-name (string->symbol (string-append base "?"))])
      (datum->syntax src-stx predicate-name))))
;; make-field-accessor-name : syntax identifier -> (identifier -> syntax)
;; Curried: returns a function mapping field identifier `f' to the
;; identifier `variant-f' in the lexical context of `src-stx'.
(define-for-syntax (make-field-accessor-name src-stx variant-stx)
  (λ (field-stx)
    (let ([variant-part (symbol->string (syntax-e variant-stx))]
          [field-part (symbol->string (syntax-e field-stx))])
      (datum->syntax
       src-stx
       (string->symbol (string-append variant-part "-" field-part))))))
;; make-field-mutator-name : syntax identifier -> (identifier -> syntax)
;; Curried: returns a function mapping field identifier `f' to the
;; identifier `set-variant-f!' in the lexical context of `src-stx'.
(define-for-syntax (make-field-mutator-name src-stx variant-stx)
  (λ (field-stx)
    (let* ([variant-part (symbol->string (syntax-e variant-stx))]
           [field-part (symbol->string (syntax-e field-stx))]
           [mutator (string-append "set-" variant-part "-" field-part "!")])
      (datum->syntax src-stx (string->symbol mutator)))))
;; assert-unbound : symbol -> (identifier -> void)
;; Curried: returns a checker that raises a syntax error, attributed to
;; `stx-symbol', when the identifier it is given already has a binding.
(define-for-syntax (assert-unbound stx-symbol)
  (λ (id-stx)
    (if (identifier-binding id-stx)
        (plai-syntax-error stx-symbol id-stx bound-id)
        (void))))
;; assert-unique : syntax -> void
;; `variant-stx' is a syntax list of identifiers. Raises a syntax error at
;; the duplicate reported by `check-duplicate-identifier', if there is one.
;; NOTE(review): the error is attributed to `define-type' even when
;; `type-case' is the caller.
(define-for-syntax (assert-unique variant-stx)
  (cond
    [(check-duplicate-identifier (syntax->list variant-stx))
     => (λ (dup-id)
          (plai-syntax-error 'define-type dup-id
                             define-type:duplicate-variant))]
    [else (void)]))
;; count-stx : syntax -> (listof natural)
;; For a syntax list of n elements, returns the 0-based positions
;; (0 1 ... n-1); the empty syntax list yields '().
(define-for-syntax (count-stx list-stx)
  (build-list (length (syntax->list list-stx))
              (λ (position) position)))
;; const : any -> (any -> any)
;; Returns a one-argument function that ignores its argument and yields `x'.
(define-for-syntax (const x)
  (λ (ignored) x))
;; Tag placed at the head of the compile-time value `define-type' binds to
;; a type name; `validate-and-remove-type-symbol' checks for it. It is a
;; fresh gensym, so it cannot be written literally by user code.
(define-for-syntax type-symbol (gensym))
;; validate-and-remove-type-symbol : syntax any -> list
;; `lst' is the compile-time value bound to the type name given to
;; `type-case' (#f when the name has none). If it is a list headed by
;; `type-symbol' -- the shape `define-type' produces -- returns the list
;; without that tag; otherwise raises a `type-case' syntax error at
;; `stx-loc'.
(define-for-syntax (validate-and-remove-type-symbol stx-loc lst)
  ;; `pair?' rejects the empty list: `(list? '())' is true, and `first'
  ;; on it would raise a raw contract error instead of the PLAI message.
  (if (and (pair? lst) (list? lst) (eq? type-symbol (first lst)))
      (rest lst)
      (plai-syntax-error 'type-case stx-loc type-case:not-a-type)))
;; validate-field-spec : syntax -> void
;; Checks one `(field-id contract-expr)' spec of a `define-type' variant.
;; Raises a syntax error if the spec is not a two-element form, or if its
;; first element is not an identifier.
(define-for-syntax (validate-field-spec stx)
  (syntax-case stx ()
    [(field-id contract-expr)
     (identifier? #'field-id)
     (void)]
    [(field-id contract-expr)
     (plai-syntax-error
      'define-type #'field-id
      "must be an identifier that names the field of this variant")]
    [_
     (plai-syntax-error
      'define-type stx
      "must be of the form (field-id contract-expr)")]))
;; validate-variant : syntax -> void
;; Checks one `[variant-id (field-id contract-expr) ...]' clause of a
;; `define-type', raising a syntax error at the first problem found.
(define-for-syntax (validate-variant variant-stx)
  (syntax-case variant-stx ()
    ;; Properly shaped clause with an identifier head: check each field spec.
    [(variant-id field-spec ...)
     (identifier? #'variant-id)
     (for-each validate-field-spec (syntax->list #'(field-spec ...)))]
    ;; Any pair whose head is not an identifier.
    [(variant-id . _)
     (not (identifier? #'variant-id))
     (plai-syntax-error
      'define-type #'variant-id
      "this must be an identifier that names a variant of the type you are defining")]
    ;; Not a pair, or an identifier head followed by an improper tail.
    [_
     (plai-syntax-error
      'define-type variant-stx
      "this variant is malformed")]))
;; define-type : (define-type Type-id [variant-id (field-id contract-expr) ...] ...)
;; Defines Type-id and its variants. For each variant this generates:
;;   - a struct type;
;;   - a contract-wrapped constructor, bound as a set!-transformer so the
;;     variant name cannot be assigned;
;;   - a predicate `variant-id?';
;;   - field accessors `variant-id-field-id';
;;   - contract-wrapped mutators `set-variant-id-field-id!'.
;; It also generates one type predicate `Type-id?'. Type-id itself is bound
;; at compile time to the information `type-case' reads. The clauses after
;; the first exist only to report targeted syntax errors.
(define-syntax (define-type stx)
(syntax-case stx ()
;; Well-formed use: identifier type name, identifier variant and field names.
[(_ datatype-id [variant-name (field-name field/c) ...] ...)
(and (identifier? #'datatype-id)
(andmap identifier? (syntax->list #'(variant-name ...)))
(andmap (λ (stx) (andmap identifier? (syntax->list stx)))
(syntax->list #`((field-name ...) ...))))
(begin
;; Compile-time checks:
;;   - at least one variant;
;;   - no duplicate variant names;
;;   - no duplicate field names within a variant;
;;   - neither Type-id nor any variant name is already bound.
;; NOTE(review): the two `map' calls below run only for their side effects
;; (raising errors); `for-each' would state that intent.
(when (empty? (syntax->list #'(variant-name ...)))
(plai-syntax-error 'define-type stx define-type:zero-variants
(syntax-e #'datatype-id)))
(assert-unique #'(variant-name ...))
(map assert-unique (syntax->list #'((field-name ...) ...)))
(map (assert-unbound 'define-type)
(cons #'datatype-id (syntax->list #'(variant-name ...))))
;; User-visible names are built in the lexical context of `stx'. The
;; per-variant internals (writer, struct type, projector, field count,
;; mutator) are fresh temporaries.
;; NOTE(review): only `variant-proj' goes through syntax-local-introduce;
;; the reason is not evident from this code.
(with-syntax ([datatype-id? ((make-variant? stx) #'datatype-id)]
[(variant ...)
(map (make-variant-name stx)
(syntax->list #'(variant-name ...)))]
[(variant? ...)
(map (make-variant? stx)
(syntax->list #'(variant-name ...)))]
[(make-write-variant ...)
(generate-temporaries #'(variant-name ...))]
[(struct:variant ...)
(generate-temporaries #'(variant-name ...))]
[(variant-proj ...)
(map syntax-local-introduce (generate-temporaries #'(variant-name ...)))]
[(num-fields ...) (generate-temporaries #'(variant-name ...))]
[(variant-set! ...)
(generate-temporaries #'(variant-name ...))]
;; Per variant:
;;   - field: accessor names;
;;   - field-set!: mutator names;
;;   - field-position: 0-based index of each field.
[((field ...) ...)
(map (λ (variant/field-stx)
(map (make-field-accessor-name
stx (first (syntax->list variant/field-stx)))
(syntax->list (second (syntax->list variant/field-stx)))))
(syntax->list #'((variant-name (field-name ...)) ...)))]
[((field-set! ...) ...)
(map (λ (variant/field-stx)
(map (make-field-mutator-name
stx (first (syntax->list variant/field-stx)))
(syntax->list (second (syntax->list variant/field-stx)))))
(syntax->list #'((variant-name (field-name ...)) ...)))]
[((field-position ...) ...)
(map count-stx (syntax->list #'((field-name ...) ...)))])
;; Repeat each per-variant identifier once per field, so that it can be
;; used at the field-level ellipsis depth in the template below.
(with-syntax ([(make-variant ...) (generate-temporaries #'(variant ...))]
[((variant-proj/f ...) ...)
(map (λ (proj fields) (map (const proj) (syntax->list fields)))
(syntax->list #'(variant-proj ...))
(syntax->list #'((field-name ...) ...)))]
[((variant?/f ...) ...)
(map (λ (v? fields) (map (const v?) (syntax->list fields)))
(syntax->list #'(variant? ...))
(syntax->list #'((field-name ...) ...)))]
[((variant-set!/f ...) ...)
(map (λ (mut fields) (map (const mut) (syntax->list fields)))
(syntax->list #'(variant-set! ...))
(syntax->list #'((field-name ...) ...)))])
#`(begin
;; Per variant: struct type and its raw constructor, predicate,
;; projector and mutator.
(begin
;; NOTE(review): this `num-fields' definition appears unused. The same
;; (substituted) name is rebound by the inner with-syntax in the
;; constructor transformer below.
(define num-fields (length (list 'field-name ...)))
;; NOTE(review): make-write-variant is not referenced in the visible
;; code. `(set! write? false)' also makes the "#<struct:" and ">"
;; branches below unreachable.
(define ((make-write-variant variant proj-count) val port write?)
(set! write? false)
(write-string (format (if write? "#<struct:~a" "(~a") variant) port)
(let ([field-vec (struct->vector val)])
(for ([index (in-range 1 (vector-length field-vec))])
(write-string " " port)
((if write? write display) (vector-ref field-vec index) port)))
(write-string (if write? ">" ")") port))
(define-values (struct:variant make-variant variant?
variant-proj variant-set!)
(make-struct-type
'variant #f (length (list 'field-name ...)) 0 #f
`((,prop:print-convert-constructor-name . ,'variant))
(make-inspector)))
)
...
;; Type predicate: true for a value of any of the variants.
(define (datatype-id? x)
(or (variant? x) ...))
;; Per variant: field accessors, contract-checked field mutators, and
;; the constructor binding.
(begin
(define field
(make-struct-field-accessor variant-proj/f field-position
'field-name))
...
(define (field-set! var val)
((contract (variant?/f field/c . -> . any)
(make-struct-field-mutator variant-set!/f field-position
'field-name)
'field-name 'use) var val))
...
;; The variant name is a set!-transformer with three uses:
;;   - `(set! variant v)' is a syntax error;
;;   - an application compares the argument count with the field count,
;;     raising exn:fail:contract:arity on a mismatch and otherwise calling
;;     the contract-wrapped constructor;
;;   - a bare reference yields the contract-wrapped constructor.
(define-syntax variant
(make-set!-transformer
(λ (stx)
(syntax-case stx (set!)
[(set! id v) (plai-syntax-error 'set! #'id assign-to-variant)]
[(f . args)
(with-syntax ([num-fields (length (syntax->list #'(field/c ...)))]
[num-args (length (syntax->list #'args))])
#`(if (= num-fields num-args)
((contract (field/c ... . -> . variant?)
make-variant 'variant-name 'use (quote-syntax f))
. args)
(raise (make-exn:fail:contract:arity
(format "use broke the contract on ~a; expected ~a arguments, given ~a"
'variant num-fields num-args)
(current-continuation-marks)
))))]
[id #`(contract (field/c ... . -> . variant?)
make-variant 'variant-name 'use (quote-syntax id))]))))
)
...
;; Compile-time info read by `type-case', of the form
;;   (list type-symbol
;;         (list (list variant-id (list accessor-id ...) predicate-id) ...)
;;         type-predicate-id)
(define-syntax datatype-id
(list type-symbol
(list (list (quote-syntax variant-name) (list #'field ...) #'variant?)
...)
#'datatype-id?))
))))]
;; Error-reporting clauses for malformed uses.
[(_ datatype-id . rest)
(not (identifier? #'datatype-id))
(plai-syntax-error 'define-type #'datatype-id
type-case:expected-datatype-id)]
;; NOTE(review): if every clause validates, `for-each' returns void, which
;; is not syntax. This looks unreachable, because validate-variant mirrors
;; the first pattern.
[(_ datatype-id clause ...)
(for-each validate-variant (syntax->list #'(clause ...)))]
[_ (plai-syntax-error 'define-type stx define-type:generic)]))
;; assert-variant : type-info -> (identifier -> void)
;; Curried: returns a checker that raises a `type-case' syntax error unless
;; the identifier is free-identifier=? to the head of some entry in
;; `type-info'.
(define-for-syntax (assert-variant type-info)
  (λ (variant-id-stx)
    (let ([known-variants (map first type-info)])
      (unless (ormap (λ (known) (free-identifier=? variant-id-stx known))
                     known-variants)
        (plai-syntax-error 'type-case variant-id-stx type-case:not-a-variant)))))
;; assert-field-count : type-info -> (identifier syntax -> void)
;; Curried: returns a checker comparing two counts for a variant:
;;   - the number of fields `type-info' records for `variant-id-stx';
;;   - the number of bindings in the syntax list `field-stx'.
;; It raises a `type-case' syntax error when the counts differ.
(define-for-syntax (assert-field-count type-info)
  (λ (variant-id-stx field-stx)
    ;; Field count of `entry' if it describes this variant, else #f.
    (define (field-count-of entry)
      (and (free-identifier=? (first entry) variant-id-stx)
           (length (second entry))))
    (let ([expected (ormap field-count-of type-info)]
          [given (length (syntax->list field-stx))])
      (unless (= expected given)
        (plai-syntax-error 'type-case variant-id-stx type-case:argument-count
                           expected given)))))
;; ensure-variant-present : syntax syntax -> (identifier -> void)
;; Curried: returns a checker that raises a `type-case' syntax error (at
;; `stx-loc') when the given variant identifier does not appear,
;; free-identifier=?, in the syntax list `variants'.
(define-for-syntax (ensure-variant-present stx-loc variants)
  (λ (variant)
    (define covered?
      (ormap (λ (id-stx) (free-identifier=? variant id-stx))
             (syntax->list variants)))
    (unless covered?
      (plai-syntax-error 'type-case stx-loc type-case:missing-variant
                         (syntax->datum variant)))))
;; variant-missing? : syntax syntax -> (identifier -> boolean)
;; Curried: returns a predicate that is #t when the given variant identifier
;; is not free-identifier=? to any element of the syntax list `variants'.
;; `stx-loc' is accepted but not used.
(define-for-syntax (variant-missing? stx-loc variants)
  (λ (variant)
    (let ([covered (syntax->list variants)])
      (andmap (λ (id-stx) (not (free-identifier=? variant id-stx)))
              covered))))
;; (lookup-variant variant-id ((id (field ...) id?) ...))
;; Expands to `(list (list field ...) id?)' for the first entry whose `id'
;; is free-identifier=? to `variant-id'. Signals an error during expansion
;; when no entry matches.
(define-syntax (lookup-variant stx)
  (syntax-case stx ()
    ;; Entries exhausted without a match.
    [(_ variant-id ())
     (error 'lookup-variant "variant ~a not found (bug in PLAI code)"
            (syntax-e #'variant-id))]
    ;; Head entry is the requested variant.
    [(_ variant-id ((id (field ...) id?) . remaining))
     (free-identifier=? #'variant-id #'id)
     #'(list (list field ...) id?)]
    ;; Otherwise skip the head entry and keep looking.
    [(_ variant-id (skipped . remaining))
     #'(lookup-variant variant-id remaining)]))
;; validate-clause : syntax -> #t
;; Checks one `[variant (field ...) body]' clause of a `type-case' form whose
;; overall shape did not match. Raises a syntax error describing the first
;; problem found, or returns #t if the clause itself is well formed.
(define-for-syntax (validate-clause clause-stx)
  (syntax-case clause-stx ()
    [(variant (field ...) body ...)
     (let ([bodies (syntax->list #'(body ...))])
       (cond
         [(not (identifier? #'variant))
          (plai-syntax-error 'type-case #'variant
                             "this must be the name of a variant")]
         [(ormap (λ (stx) (and (not (identifier? stx)) stx))
                 (syntax->list #'(field ...)))
          => (λ (malformed-field)
               (plai-syntax-error
                'type-case malformed-field
                "this must be an identifier that names the value of a field"))]
         ;; `body ...' also matches zero expressions, so a clause with no
         ;; body must be caught here. A separate `(variant (field ...))'
         ;; pattern placed after this one could never be reached.
         [(null? bodies)
          (plai-syntax-error
           'type-case clause-stx
           "this case is missing a body expression")]
         [(not (= (length bodies) 1))
          (plai-syntax-error
           'type-case clause-stx
           (string-append
            "there must be just one body expression in a clause, but you "
            "provided ~a body expressions.")
           (length bodies))]
         [else #t]))]
    [_
     (plai-syntax-error
      'type-case clause-stx
      "this case is missing a field list (possibly an empty field list)")]))
;; (bind-fields-in (binding-name ...) case-variant-id type-info value-id body-expr)
;; Walks `type-info', whose entries have the form
;; `(variant-id (selector-id ...) pred)', to the entry for `case-variant-id'.
;; It then expands to a `let' around `body-expr' that binds each
;; binding-name to the corresponding selector applied to `value-id'.
(define-syntax (bind-fields-in stx)
  (syntax-case stx ()
    ;; Head entry is the requested variant: bind fields positionally.
    [(_ (binding-name ...) case-variant-id
        ((variant-id (selector-id ...) pred-id) . remaining)
        value-id body-expr)
     (free-identifier=? #'case-variant-id #'variant-id)
     #'(let ([binding-name (selector-id value-id)] ...)
         body-expr)]
    ;; Same shape, different variant: recur on the remaining entries.
    [(_ (binding-name ...) case-variant-id
        ((variant-id (selector-id ...) pred-id) . remaining)
        value-id body-expr)
     #'(bind-fields-in (binding-name ...) case-variant-id remaining value-id body-expr)]))
;; type-case : (type-case Type-id test-expr [variant (field ...) case-expr] ... [else else-expr])
;; Evaluates test-expr once and raises a run-time error if the result does
;; not satisfy Type-id's type predicate. Otherwise it runs the clause for
;; the value's variant, with that clause's field names bound positionally
;; to the variant's fields.
;;
;; At compile time it checks that:
;;   - Type-id names a `define-type' type;
;;   - clause variants are real and not repeated;
;;   - field binding counts match the variant's field count;
;;   - without `else', every variant is covered;
;;   - with `else', at least one variant is left uncovered.
(define-syntax (type-case stx)
  (syntax-case stx (else)
    ;; Form with a trailing else clause.
    [(_ type-id test-expr [variant (field ...) case-expr] ... [else else-expr])
     (and (identifier? #'type-id)
          (andmap identifier? (syntax->list #'(variant ...)))
          (andmap (λ (stx) (andmap identifier? (syntax->list stx)))
                  (syntax->list #'((field ...) ...))))
     (let* ([info (validate-and-remove-type-symbol
                   #'type-id (syntax-local-value #'type-id (λ () #f)))]
            [type-info (first info)]
            [type? (second info)])
       ;; The following checks run only for the errors they raise.
       (assert-unique #'(variant ...))
       (for-each assert-unique (syntax->list #'((field ...) ...)))
       (for-each (assert-variant type-info) (syntax->list #'(variant ...)))
       (for-each (assert-field-count type-info)
                 (syntax->list #'(variant ...))
                 (syntax->list #'((field ...) ...)))
       ;; An else clause is an error when every variant already has a clause.
       (unless (ormap (variant-missing? stx #'(variant ...))
                      (map first type-info))
         (plai-syntax-error 'type-case stx type-case:unreachable-else))
       ;; NOTE(review): this run-time message differs slightly from the one
       ;; in the no-else form below; kept as-is in case callers match on it.
       #`(let ([expr test-expr])
           (if (not (#,type? expr))
               #,(syntax/loc #'test-expr
                   (error 'type-case "the value ~a is not of the specified type"
                          expr))
               (cond
                 [(let ([variant-info (lookup-variant variant #,type-info)])
                    ((second variant-info) expr))
                  (bind-fields-in (field ...) variant #,type-info expr case-expr)]
                 ...
                 [else else-expr]))))]
    ;; Form without else: every variant must have a clause.
    [(_ type-id test-expr [variant (field ...) case-expr] ...)
     (and (identifier? #'type-id)
          (andmap identifier? (syntax->list #'(variant ...)))
          (andmap (λ (stx) (andmap identifier? (syntax->list stx)))
                  (syntax->list #'((field ...) ...))))
     (let* ([info (validate-and-remove-type-symbol
                   #'type-id (syntax-local-value #'type-id (λ () #f)))]
            [type-info (first info)]
            [type? (second info)])
       (assert-unique #'(variant ...))
       (for-each assert-unique (syntax->list #'((field ...) ...)))
       (for-each (assert-variant type-info) (syntax->list #'(variant ...)))
       (for-each (assert-field-count type-info)
                 (syntax->list #'(variant ...))
                 (syntax->list #'((field ...) ...)))
       (for-each (ensure-variant-present stx #'(variant ...))
                 (map first type-info))
       #`(let ([expr test-expr])
           (if (not (#,type? expr))
               #,(syntax/loc #'test-expr
                   (error 'type-case "this value (~a) is not of the specified type"
                          expr))
               (cond
                 [(let ([variant-info (lookup-variant variant #,type-info)])
                    ((second variant-info) expr))
                  (bind-fields-in (field ...) variant #,type-info expr case-expr)]
                 ...
                 [else (error 'type-case bug:fallthru-no-else)]))))]
    ;; Malformed use: report the most specific problem found. A transformer
    ;; must return syntax, so every path here ends in a syntax error; the
    ;; old version returned the list produced by `map' when all clauses
    ;; looked fine.
    [(_ type-id test-expr clauses ...)
     (begin
       (unless (identifier? #'type-id)
         (plai-syntax-error 'type-case #'type-id type-case:not-a-type))
       (for-each validate-clause (syntax->list #'(clauses ...)))
       (plai-syntax-error 'type-case stx type-case:generic))]
    [_ (plai-syntax-error 'type-case stx type-case:generic)]))