(module build-arity-table mzscheme
(require (prefix kernel: (lib "kerncase.ss" "syntax"))
(lib "contract.ss")
(lib "list.ss")
"arity-table.ss")
(provide/contract [build-arity-table (-> syntax? table?)])
;; test : any procedure any ... -> void
;; Applies P to ARGS and, when the result is not equal? to DESIRED, writes a
;; failure report to (current-error-port).  Mismatches are reported, never raised.
(define (test desired p . args)
  (let ([actual (apply p args)])
    (when (not (equal? desired actual))
      (fprintf (current-error-port)
               "test failed: desired: ~v\ngot: ~v\ntest: ~v\n"
               desired
               actual
               (cons p args)))))
;; build-arity-table : syntax -> table
;; Walks the (expanded) form STX, collecting raw (identifier arity-info)
;; entries, and hands them to coalesce-table from "arity-table.ss".
(define (build-arity-table stx)
  (let ([raw-entries (top-level-expr-iterator stx)])
    (coalesce-table raw-entries)))
;; top-level-expr-iterator : syntax -> (listof entry)
;; A fully expanded `module` form has each of its body forms walked with
;; module-level-expr-iterator, and the resulting entry lists are concatenated.
;; Any other form is delegated to general-top-level-expr-iterator.
(define (top-level-expr-iterator stx)
  (kernel:kernel-syntax-case stx #f
    [(module mod-id mod-lang (#%plain-module-begin . body-forms))
     (let ([form-list (syntax->list #'body-forms)])
       (apply append (map module-level-expr-iterator form-list)))]
    [other-form
     (general-top-level-expr-iterator stx)]))
;; module-level-expr-iterator : syntax -> (listof entry)
;; `provide` forms contribute no entries; every other module-level form is
;; handled by general-top-level-expr-iterator.
(define (module-level-expr-iterator stx)
  (kernel:kernel-syntax-case stx #f
    [(provide . specs)
     '()]
    [other-form
     (general-top-level-expr-iterator stx)]))
;; general-top-level-expr-iterator : syntax -> (listof entry)
;; Shared by top-level-expr-iterator and module-level-expr-iterator.
;; - define-values: walks the right-hand side, naming it only when exactly
;;   one variable is defined.
;; - define-syntaxes, require, require-for-syntax: contribute nothing.
;; - begin: each contained form is walked as a top-level form.
;; - anything else: walked as an unnamed expression.
(define (general-top-level-expr-iterator stx)
  (kernel:kernel-syntax-case stx #f
    [(define-values (id ...) rhs)
     (let* ([ids (syntax->list #'(id ...))]
            [name (if (= (length ids) 1) (car ids) #f)])
       (expr-iterator #'rhs name))]
    [(define-syntaxes (id ...) rhs)
     '()]
    [(begin . forms)
     (let ([form-list (syntax->list #'forms)])
       (apply append (map top-level-expr-iterator form-list)))]
    [(require . specs)
     '()]
    [(require-for-syntax . specs)
     '()]
    [other-form
     (expr-iterator stx #f)]))
;; expr-iterator : syntax (union identifier #f) -> (listof entry)
;; Walks the fully expanded expression STX and returns raw table entries.
;; An entry is (list identifier arity-list).  For lambda / case-lambda clauses,
;; arity-list is (list (list min max)) as computed by arity-of-arglist (max may
;; be 'inf).  For a set! target, arity-list is (unknown).
;; POTENTIAL-NAME is the identifier that will receive STX's value, or #f.  A
;; lambda or case-lambda reached in a value-producing (tail) position is
;; recorded under that name, with one entry per clause.  Single-variable
;; let/letrec bindings supply their own names to their right-hand sides.
;; Raises an error for expression forms it does not recognize.
(define (expr-iterator stx potential-name)
  (let* ([recur-tail (lambda (expr) (expr-iterator expr potential-name))]
         [recur-non-tail (lambda (expr) (expr-iterator expr #f))]
         [recur-with-name (lambda (expr name) (expr-iterator expr name))]
         ;; Walks a body sequence in source order.  Every body but the last is
         ;; evaluated only for effect; the last one produces the value and so
         ;; inherits POTENTIAL-NAME.
         [recur-body-sequence
          (lambda (body-list)
            (if (null? body-list)
                null
                (let ([reversed (reverse body-list)])
                  (append
                   (apply append (map recur-non-tail (reverse (cdr reversed))))
                   (recur-tail (car reversed))))))]
         [lambda-clause-abstraction
          (lambda (clause)
            (kernel:kernel-syntax-case clause #f
              [(arglist . bodies)
               ;; The bodies compute the procedure's result, not the value
               ;; bound to POTENTIAL-NAME, so they are walked without a name.
               (let ([rest (apply append (map recur-non-tail (syntax->list #'bodies)))])
                 (if potential-name
                     (cons
                      (list potential-name (list (arity-of-arglist #'arglist)))
                      rest)
                     rest))]
              [else
               (error 'expr-syntax-object-iterator
                      "unexpected (case-)lambda clause: ~a"
                      (syntax-object->datum stx))]))]
         [let-values-abstraction
          (lambda (let-stx)
            (kernel:kernel-syntax-case let-stx #f
              [(kwd (((variable ...) rhs) ...) . bodies)
               (let* ([clause-fn
                       ;; A right-hand side bound to exactly one variable is
                       ;; walked with that variable as its potential name.
                       (lambda (vars rhs-stx)
                         (let ([var-list (syntax->list vars)])
                           (cond [(= (length var-list) 1)
                                  (recur-with-name rhs-stx (car var-list))]
                                 [else
                                  (recur-non-tail rhs-stx)])))]
                      [rhs-entries
                       (apply append
                              (map clause-fn
                                   (syntax->list #'((variable ...) ...))
                                   (syntax->list #'(rhs ...))))])
                 ;; The last body is the value of the whole let(rec), so it is
                 ;; in tail position, just like the last expression of a begin.
                 (append rhs-entries
                         (recur-body-sequence (syntax->list #'bodies))))]
              [else
               (error 'expr-syntax-object-iterator
                      "unexpected let(rec) expression: ~a"
                      (syntax-object->datum let-stx))]))])
    (kernel:kernel-syntax-case stx #f
      [var-stx
       (identifier? (syntax var-stx))
       null]
      [(lambda . clause)
       (lambda-clause-abstraction #'clause)]
      [(case-lambda . clauses)
       (apply append (map lambda-clause-abstraction (syntax->list #'clauses)))]
      [(if test then)
       (append
        (recur-non-tail #'test)
        (recur-tail #'then))]
      [(if test then else)
       ;; Both branches are walked without the name, so neither arm's
       ;; arity is attributed to POTENTIAL-NAME.
       (append
        (recur-non-tail #'test)
        (recur-non-tail #'then)
        (recur-non-tail #'else))]
      [(begin . bodies)
       ;; The last expression produces the value; its entries come first.
       (let ([reversed (reverse (syntax->list #'bodies))])
         (apply append
                (recur-tail (car reversed))
                (map recur-non-tail (reverse (cdr reversed)))))]
      [(begin0 . bodies)
       ;; The first expression produces the value.
       (let ([body-list (syntax->list #'bodies)])
         (apply append
                (recur-tail (car body-list))
                (map recur-non-tail (cdr body-list))))]
      [(let-values . _)
       (let-values-abstraction stx)]
      [(letrec-values . _)
       (let-values-abstraction stx)]
      [(set! var val)
       ;; An assigned variable's arity cannot be tracked; mark it unknown.
       (cons (list #'var `(unknown))
             (recur-non-tail #'val))]
      [(quote _)
       null]
      [(quote-syntax _)
       null]
      [(with-continuation-mark key mark body)
       (append
        (recur-non-tail #'key)
        (recur-non-tail #'mark)
        (recur-tail #'body))]
      [(#%app . exprs)
       (apply append (map recur-non-tail (syntax->list #'exprs)))]
      [(#%datum . _)
       null]
      [(#%top . var)
       null]
      [else
       (error 'expr-iterator "unknown expr: ~a"
              (syntax-object->datum stx))])))
;; arity-of-arglist : syntax -> (list number (union number 'inf))
;; Computes (list min max) for a lambda formals list.  A proper list of n
;; identifiers gives (n n).  A bare identifier gives (0 inf).  An improper
;; list with n leading identifiers followed by a rest identifier gives (n inf).
(define (arity-of-arglist arglist-stx)
  (let loop ([formals arglist-stx] [required 0])
    (syntax-case formals ()
      [rest-id
       (identifier? #'rest-id)
       (list required 'inf)]
      [(id ...)
       (let ([total (+ required (length (syntax->list #'(id ...))))])
         (list total total))]
      [(id . more)
       (loop #'more (+ required 1))])))
;; incr-limit : (union number 'inf) -> (union number 'inf)
;; Adds one to an arity bound; the unbounded marker 'inf is returned as is.
;; The contract wrapper rejects any other kind of argument.
(define incr-limit
  (contract
   (-> (union number? (symbols 'inf)) any)
   (lambda (bound)
     (if (eq? bound 'inf)
         'inf
         (add1 bound)))
   'incr-limit
   'caller))
;; build-arity-table-test : any (union syntax datum) -> void
;; Expands STX, builds its arity table, and compares it against EXPECTED
;; using `test`.  Before comparison, each entry's identifier is reduced to its
;; symbol with syntax-e.
(define (build-arity-table-test expected stx)
  (let ([table (build-arity-table (expand stx))])
    (test expected
          map
          (lambda (entry)
            (let ([id (car entry)]
                  [arities (cadr entry)])
              (list (syntax-e id) arities)))
          table)))
;; Regression cases, run when the module is instantiated.  Each case is
;; (list expected-table source) and is passed to build-arity-table-test.
(for-each
 (lambda (test-case)
   (apply build-arity-table-test test-case))
 (list
  ;; a fixed-arity lambda bound by define
  (list '((a ((2 2))))
        '(define a (lambda (b c) b)))
  ;; a definition nested in a top-level begin
  (list '((a ((1 1))))
        '(begin (define (a x) 3)))
  ;; case-lambda: one arity per clause; a rest argument gives inf
  (list '((a ((3 3) (2 inf))))
        '(define a (case-lambda ((a b c) 3) ((a b . c) 3))))
  ;; neither arm of a two-armed if is attributed to a
  (list '()
        '(define a (if #t (lambda (b c) 3) (lambda (c) 3))))
  ;; single-variable let bindings are named; the two-valued (b c) binding is
  ;; not; d is assigned with set! and is expected to be absent from the table
  (list '((a ((2 2))))
        #'((let*-values ([(a) (lambda (b c) 3)]
                         [(b c) (values (lambda (b) 3) (lambda (x) 3))]
                         [(d) (begin (lambda (a b) 3) (lambda (a) 3))])
             (set! d (lambda (a b c d e) 3)))))
  ;; the same name bound in two separate lets yields two entries
  (list '((a ((1 1))) (a ((2 2))))
        #'((let ([a (lambda (x) x)]) 3)
           (let ([a (lambda (x y) x)]) 3)))
  ;; begin takes its last expression's value, begin0 its first
  (list '((a ((1 1))))
        '(define a (begin (lambda () 3) (begin0 (lambda (x) 3) (lambda () 3)))))
  ;; let-bound lambdas inside non-tail begin/begin0 positions are recorded
  (list '((a ((1 1))) (b ((1 1))))
        #'(+ (begin (let ([a (lambda (x) x)]) 3) 4)
             (begin0 4 (let ([b (lambda (x) x)]) 3))))
  ;; a set! target is expected to be absent from the table
  (list '()
        '(define (a x) (set! a (lambda (x y) 3))))))
)