(module subst mzscheme
(require (lib "match.ss")
(prefix plt: (lib "plt-match.ss"))
(lib "list.ss"))
(provide plt-subst subst
all-vars variable subterm subterms constant build
subst/proc alpha-rename free-vars/memoize)
;; Keywords that may only appear inside the right-hand sides of `subst' /
;; `plt-subst' clauses, where make-subst's syntax-case lists them as
;; literals. They are exported so client code can refer to them; any use
;; outside a clause raises a syntax error naming the keyword.
(define-syntax all-vars
  (lambda (stx) (raise-syntax-error 'subst "all-vars out of context" stx)))
(define-syntax variable
  (lambda (stx) (raise-syntax-error 'subst "variable out of context" stx)))
(define-syntax subterm
  (lambda (stx) (raise-syntax-error 'subst "subterm out of context" stx)))
(define-syntax subterms
  (lambda (stx) (raise-syntax-error 'subst "subterms out of context" stx)))
(define-syntax constant
  (lambda (stx) (raise-syntax-error 'subst "constant out of context" stx)))
(define-syntax build
  (lambda (stx) (raise-syntax-error 'subst "build out of context" stx)))
;; make-subst : macro that defines a substitution macro.
;; (make-subst name match-form) expands into (define-syntax (name stx) ...).
;; Uses of the generated macro look like
;;   (name (pat rhs ...) ...)
;; and expand into a procedure (lambda (var val exp) ...) that calls
;; subst/proc with a `separate' dispatcher built from the clauses.
;; `match-form' is the matcher used to test each `pat' against the term;
;; this module instantiates it with match.ss's `match' and plt-match.ss's
;; `plt:match'.
;; Each clause's rhs must be one of:
;;   (constant)          -- expands to (constant-handler term)
;;   (variable)          -- expands to (variable-handler identity term)
;;   (all-vars e) (build b) piece ...
;;                       -- expands to (apply combine-handler b e pieces),
;;                          where each piece is (subterm binders body) or
;;                          (subterms binders list-of-terms).
;;                          The subterms list is checked at run time.
(define-syntax (make-subst stx)
(syntax-case stx ()
[(_ subst match)
(syntax
(define-syntax (subst stx)
(syntax-case stx ()
[(_ (pat rhs (... ...)) (... ...))
;; Identifiers for the parameters of the generated `separate' procedure.
;; They are shared by handle-rhs's output templates and the lambda below.
;; Note that the pattern variable sub-piece/arg stands for the identifier
;; `subpiece/arg'.
(with-syntax ([term/arg #'term/arg]
[constant/arg #'constant/arg]
[variable/arg #'variable/arg]
[combine/arg #'combine/arg]
[sub-piece/arg #'subpiece/arg])
;; handle-rhs : syntax -> syntax
;; Translates one clause's rhs forms into the expression the clause runs
;; once its pattern matches. Malformed rhs forms raise syntax errors here.
(define (handle-rhs rhs-stx)
(syntax-case rhs-stx (all-vars build subterm subterms variable constant)
[((all-vars all-vars-exp) (build build-exp) sub-pieces (... ...))
;; Each sub-piece becomes an expression that produces a *list* of
;; sub-piece handler results, so they can all be appended together.
(with-syntax ([(sub-pieces (... ...))
(map (lambda (subterm-stx)
(syntax-case subterm-stx (subterm subterms)
[(subterm vars body) (syntax (list (sub-piece/arg vars body)))]
[(subterms vars terms)
(syntax
(let ([terms-var terms])
(unless (list? terms-var)
(error 'subst
"expected a list of terms for `subterms' subclause, got: ~e"
terms-var))
(map (lambda (x) (sub-piece/arg vars x))
terms-var)))]
[else (raise-syntax-error
'subst
"unknown all-vars subterm"
stx
subterm-stx)]))
(syntax->list (syntax (sub-pieces (... ...)))))])
;; combine handler receives: the build (rebuild) procedure, the binder
;; list, then every sub-piece result as separate arguments.
(syntax
(apply combine/arg
build-exp
all-vars-exp
(append sub-pieces (... ...)))))]
;; The remaining all-vars clauses only report malformed uses.
[((all-vars) sub-pieces (... ...))
(raise-syntax-error 'subst "expected all-vars must have an argument" stx rhs-stx)]
[((all-vars all-vars-exp) not-build-clause anything (... ...))
(raise-syntax-error 'subst "expected build clause" (syntax not-build-clause))]
[((all-vars all-vars-exp))
(raise-syntax-error 'subst "missing build clause" (syntax (all-vars all-vars-exp)))]
[((constant))
(syntax (constant/arg term/arg))]
;; The identity procedure is passed as the variable's rebuild function.
[((variable))
(syntax (variable/arg (lambda (x) x) term/arg))]
[(unk unk-more (... ...))
(raise-syntax-error 'subst "unknown clause" (syntax unk))]))
(with-syntax ([(expanded-rhs (... ...))
(map handle-rhs (syntax->list (syntax ((rhs (... ...)) (... ...)))))])
(syntax
;; separate : term constant-h variable-h combine-h sub-piece-h -> any
;; Matches the term against the user's patterns, in order, and runs the
;; translated rhs of the first clause that matches.
(let ([separate
(lambda (term/arg constant/arg variable/arg combine/arg sub-piece/arg)
(match term/arg
[pat expanded-rhs] (... ...)
[else (error 'subst "no matching clauses for ~s\n" term/arg)]))])
(lambda (var val exp)
(subst/proc var val exp separate))))))])))]))
;; Two instantiations of the substitution macro, differing only in the
;; pattern matcher used by their clauses:
;;   plt-subst -- patterns are plt-match.ss patterns (imported as plt:match)
;;   subst     -- patterns are match.ss patterns
(make-subst plt-subst plt:match)
(make-subst subst match)
;; subst/proc : var term term separate -> term
;; Capture-avoiding substitution of `val' for the free occurrences of `var'
;; in `exp'. `separate' is the dispatcher generated by the subst/plt-subst
;; macros. It is called as (separate term constant-h variable-h complex-h
;; sub-piece-h) and runs the handler matching the term's shape.
;; Free-variable sets are memoized in an equal?-keyed hash table that is
;; local to this call.
;; If a binder of a compound form is free in `val', that binder is first
;; alpha-renamed and the renamed term is processed again. The fresh name
;; avoids:
;;   - the free variables of `val', so it cannot capture them again;
;;   - every name already occurring in `exp', bound or free, so it cannot
;;     capture a free occurrence or be shadowed by an inner binder;
;;   - `var' itself, so renamed occurrences are not then substituted for.
(define (subst/proc var val exp separate)
  ;; all-names : term -> (listof var)
  ;; Every variable name occurring in `term', binders included; may contain
  ;; duplicates.
  (define (all-names term)
    (separate term
              (lambda (c) '())
              (lambda (rebuild var-name) (list var-name))
              (lambda (maker vars . subpieces) (apply append vars subpieces))
              (lambda (binders subterm) (all-names subterm))))
  (let* ([free-vars-cache (make-hash-table 'equal)]
         [fv-val (free-vars/memoize free-vars-cache val separate)])
    (let loop ([exp exp])
      (let ([fv-exp (free-vars/memoize free-vars-cache exp separate)]
            [handle-constant
             (lambda (x) x)]
            [handle-variable
             (lambda (rebuild var-name)
               (if (equal? var-name var)
                   val
                   (rebuild var-name)))]
            [handle-complex
             (lambda (maker vars . subpieces)
               (cond
                 ;; Some binder here is free in `val': rename it first.
                 [(ormap (lambda (binder) (memq binder fv-val)) vars)
                  =>
                  (lambda (to-be-renamed-l)
                    (let* ([to-be-renamed (car to-be-renamed-l)]
                           [new-name
                            (pick-new-name
                             to-be-renamed
                             (cons var (append fv-val (all-names exp))))])
                      (loop
                       (alpha-rename to-be-renamed new-name exp separate))))]
                 [else
                  (apply maker
                         vars
                         (map (lambda (subpiece)
                                (let ([sub-term-binders (subpiece-binders subpiece)]
                                      [sub-term (subpiece-term subpiece)])
                                  ;; `var' is rebound for this sub-term:
                                  ;; leave it untouched.
                                  (if (memq var sub-term-binders)
                                      sub-term
                                      (loop sub-term))))
                              subpieces))]))])
        ;; Only descend when `var' actually occurs free.
        (if (member var fv-exp)
            (separate
             exp
             handle-constant
             handle-variable
             handle-complex
             make-subpiece)
            exp)))))
;; subpiece: one sub-term of a compound form, paired with the list of
;; variables (`binders') that the enclosing form binds within it.
;; subst/proc passes make-subpiece to `separate' as its sub-piece handler.
;; The struct is declared with (make-inspector) as its inspector.
(define-struct subpiece
  (binders term)
  (make-inspector))
;; alpha-rename : var var term separate -> term
;; Renames the binder `to-be-renamed' of the outermost form of `exp' to
;; `new-name'. It also renames the occurrences inside each sub-piece whose
;; binder list contains `to-be-renamed', and stops descending wherever a
;; nested sub-piece rebinds `to-be-renamed'. Variables and constants at the
;; top level are rebuilt unchanged.
;; No freshness check is done here; the caller must choose `new-name'.
(define (alpha-rename to-be-renamed new-name exp separate)
  ;; Replace the old name with the new one; leave other names alone.
  (define (swap-name x)
    (if (eq? x to-be-renamed) new-name x))
  ;; Rename references inside the renamed binder's scope.
  (define (rename-refs term)
    (separate term
              (lambda (c) c)
              (lambda (rebuild v) (rebuild (swap-name v)))
              (lambda (maker vars . pieces) (apply maker vars pieces))
              (lambda (binders body)
                ;; A nested rebinding shadows the old name: stop here.
                (if (memq to-be-renamed binders)
                    body
                    (rename-refs body)))))
  ;; Outermost form: rename the binder itself, then the pieces it scopes.
  (separate exp
            (lambda (c) c)
            (lambda (rebuild v) (rebuild v))
            (lambda (maker vars . pieces)
              (apply maker (map swap-name vars) pieces))
            (lambda (binders body)
              (if (memq to-be-renamed binders)
                  (rename-refs body)
                  body))))
;; free-vars/memoize : hash-table term separate -> (listof var)
;; Returns the free variables of `exp', looking them up in `cache' first.
;; On a miss it computes them with free-vars/compute and stores the result
;; under `exp' before returning it.
(define (free-vars/memoize cache exp separate)
  (define (compute-and-store!)
    (let ([fvs (free-vars/compute cache exp separate)])
      (hash-table-put! cache exp fvs)
      fvs))
  (hash-table-get cache exp compute-and-store!))
;; free-vars/compute : hash-table term separate -> (listof var)
;; Computes the free variables of `exp' by dispatching through `separate':
;;   - a constant contributes none;
;;   - a variable contributes itself;
;;   - a compound form contributes the concatenation of its sub-pieces'
;;     free variables, each with that sub-piece's binders removed.
;; Sub-terms go through free-vars/memoize, which shares `cache'. The result
;; list is not de-duplicated.
(define (free-vars/compute cache exp separate)
  (define (strip-binders binders subterm)
    (let loop ([remaining binders]
               [fvs (free-vars/memoize cache subterm separate)])
      (if (null? remaining)
          fvs
          (loop (cdr remaining) (remove-all (car remaining) fvs)))))
  (separate exp
            (lambda (c) '())
            (lambda (rebuild v) (list v))
            (lambda (maker vars . pieces) (apply append pieces))
            strip-binders))
;; remove-all : any list -> list
;; Drops every element eq? to `var' from `lst'. The surviving elements come
;; back in REVERSE order, because the result is built by consing onto an
;; accumulator.
(define (remove-all var lst)
  (foldl (lambda (item kept)
           (if (eq? item var)
               kept
               (cons item kept)))
         '()
         lst))
;; lc-direct-subst : symbol sexp sexp -> sexp
;; A direct (non-generic) capture-avoiding substitution of `val' for the
;; free occurrences of `var' in `exp'. It handles a small lambda calculus
;; whose terms are:
;;   numbers, symbols, (lambda (x ...) body), (let (x rhs) body), and
;;   applications (term ...).
;; When `var' is not free in `exp', `exp' is returned unchanged.
;; A binder is renamed only when it is free in `val', since otherwise it
;; would capture. Its new name avoids every symbol occurring in `exp', the
;; free variables of `val', `var', and names already chosen in the same
;; binder list.
;; Avoiding only the free variables of `val' is not enough. For example,
;; substituting x for y in (lambda (x) (y x@)) would rename x to x@ and
;; capture the free x@.
(define (lc-direct-subst var val exp)
  ;; All symbols anywhere in an s-expression (bound, free, or keywords).
  (define (symbols-in sexp)
    (cond
      [(symbol? sexp) (list sexp)]
      [(pair? sexp) (append (symbols-in (car sexp)) (symbols-in (cdr sexp)))]
      [else '()]))
  ;; New binder list. Binders that are free in `val' get a name outside
  ;; `avoid', and each chosen name is added to `avoid' so the new binders
  ;; stay distinct. Other binders are kept as-is.
  (define (fresh-binders binders avoid fv-val)
    (let loop ([binders binders] [avoid avoid])
      (cond
        [(null? binders) '()]
        [(memq (car binders) fv-val)
         (let ([new-name (pick-new-name (car binders) avoid)])
           (cons new-name (loop (cdr binders) (cons new-name avoid))))]
        [else (cons (car binders) (loop (cdr binders) avoid))])))
  (let ([fv-exp (lc-direct-free-vars exp)])
    (if (memq var fv-exp)
        (match exp
          [`(lambda ,vars ,body)
           (if (memq var vars)
               exp
               (let* ([fv-val (lc-direct-free-vars val)]
                      [vars1 (fresh-binders
                              vars
                              (cons var (append fv-val (symbols-in exp)))
                              fv-val)])
                 `(lambda ,vars1 ,(lc-direct-subst
                                   var
                                   val
                                   (lc-direct-subst/l vars
                                                      vars1
                                                      body)))))]
          ;; `rhs' (not `exp') so the whole term stays accessible above.
          [`(let (,l-var ,rhs) ,body)
           (if (eq? l-var var)
               ;; `var' is rebound: substitute only in the right-hand side.
               `(let (,l-var ,(lc-direct-subst var val rhs)) ,body)
               (let* ([fv-val (lc-direct-free-vars val)]
                      [l-var1 (car (fresh-binders
                                    (list l-var)
                                    (cons var (append fv-val (symbols-in exp)))
                                    fv-val))])
                 `(let (,l-var1 ,(lc-direct-subst var val rhs))
                    ,(lc-direct-subst
                      var
                      val
                      (lc-direct-subst
                       l-var
                       l-var1
                       body)))))]
          [(? number?) exp]
          [(and var1 (? symbol?))
           (if (eq? var1 var)
               val
               var1)]
          [`(,@(args ...))
           `(,@(map (lambda (arg) (lc-direct-subst var val arg)) args))])
        exp)))
;; lc-direct-subst/l : (listof symbol) (listof sexp) sexp -> sexp
;; Applies lc-direct-subst for each var/val pair, one after another, via
;; foldr. The LAST pair is substituted first and the first pair last.
;; `vars' and `vals' are expected to have the same length.
(define (lc-direct-subst/l vars vals exp)
  (foldr lc-direct-subst exp vars vals))
;; lc-direct-free-vars : sexp -> (listof symbol)
;; Free variables of a lambda-calculus term, without duplicates. Terms are
;; numbers, symbols, (lambda (x ...) body), (let (x rhs) body), and
;; applications (term ...). In a let, `x' is bound in `body' only, not in
;; `rhs'. The result is collected in a hash table, so its order is
;; unspecified.
(define (lc-direct-free-vars exp)
  (let ([free (make-hash-table)])
    (define (walk term bound)
      (match term
        [(? number?)
         (void)]
        [(? symbol?)
         (unless (memq term bound)
           (hash-table-put! free term #t))]
        [`(lambda ,params ,body)
         (walk body (append params bound))]
        [`(let (,x ,rhs) ,body)
         (walk rhs bound)
         (walk body (cons x bound))]
        [`(,@(parts ...))
         (for-each (lambda (part) (walk part bound)) parts)]))
    (walk exp '())
    (hash-table-map free (lambda (key ignored) key))))
;; pick-new-name : symbol list -> symbol
;; Returns `var' itself if it is not a member of `vars'. Otherwise it keeps
;; appending "@" (via prime) until the name is not in `vars'.
(define (pick-new-name var vars)
  (let loop ([candidate var])
    (if (member candidate vars)
        (loop (prime candidate))
        candidate)))
;; prime : symbol -> symbol
;; The symbol whose name is `var''s name with "@" appended (x -> x@).
(define (prime var)
  (let ([base (symbol->string var)])
    (string->symbol (string-append base "@")))))  ; last paren closes (module subst ...)