#lang scheme/base
(require (only-in srfi/1/list lset-difference)
scheme/match
scheme/list
"../syntax/ast-core.ss"
"../syntax/ast-utils.ss"
"../syntax/token.ss"
"hoist-monad.ss")
;; unique-vars : funs vars -> list of identifiers
;; Returns `vars` with duplicates removed (compared with Identifier=?) and
;; with every identifier that equals the name of a declaration in `funs`
;; dropped.
(define (unique-vars funs vars)
  (let* ([distinct-vars (remove-duplicates vars Identifier=?)]
         [fun-names (map FunctionDeclaration-name funs)])
    (lset-difference Identifier=? distinct-vars fun-names)))
;; to-location : any -> location or #f
;; #f stays #f, a value satisfying position? is returned unchanged, and
;; anything else is converted with ast-location.
(define (to-location x)
  (and x
       (if (position? x)
           x
           (ast-location x))))
;; optional-statement->statement : (or statement #f) any -> statement
;; Passes a present statement through untouched; when `stmt` is #f an
;; EmptyStatement located at (to-location loc) is substituted.
(define (optional-statement->statement stmt loc)
  (if stmt
      stmt
      (make-EmptyStatement (to-location loc))))
;; wrap-in-implicit-block : loc stmts let-funs let-vars -> one-element list
;; Packs the statements and their block-level hoisted names into a single
;; BlockStatement/hoisted and returns it as a singleton list.
(define (wrap-in-implicit-block loc stmts let-funs let-vars)
  (let ([block (make-BlockStatement/hoisted loc stmts let-funs let-vars)])
    (list block)))
;; hoist-function-expression : FunctionExpression -> FunctionExpression/hoisted
;; Runs the hoisting computation over the function body (via `execute`) and
;; rebuilds the expression as a FunctionExpression/hoisted whose body is one
;; implicit block, together with the hoisted functions, the variables not
;; shadowed by those functions, and the imports/exports reported by `execute`.
(define (hoist-function-expression expr)
  (match expr
    [(struct FunctionExpression (loc name args body))
     (let-values ([(stmts block-vars block-funs fun-vars fun-decls imports exports)
                   (execute (hoist-source-elements body))])
       (let* ([wrapped-body (wrap-in-implicit-block loc stmts block-funs block-vars)]
              [own-vars (unique-vars fun-decls fun-vars)])
         (make-FunctionExpression/hoisted loc
                                          name
                                          args
                                          wrapped-body
                                          fun-decls
                                          own-vars
                                          imports
                                          exports)))]))
;; hoist-function-declaration : FunctionDeclaration -> FunctionDeclaration/hoisted
;; Declaration counterpart of hoist-function-expression: hoists the body,
;; wraps the resulting statements in one implicit block, and records the
;; hoisted functions, the remaining unique variables, imports and exports.
(define (hoist-function-declaration decl)
  (match decl
    [(struct FunctionDeclaration (location name args body))
     (let-values ([(hoisted-stmts block-vars block-funs fun-vars fun-decls imports exports)
                   (execute (hoist-source-elements body))])
       (let* ([wrapped-body (wrap-in-implicit-block location hoisted-stmts block-funs block-vars)]
              [own-vars (unique-vars fun-decls fun-vars)])
         (make-FunctionDeclaration/hoisted location
                                           name
                                           args
                                           wrapped-body
                                           fun-decls
                                           own-vars
                                           imports
                                           exports)))]))
;; hoist-program-unit : list of source elements
;;                      -> (values funs vars imports exports body)
;; Hoists a whole program unit.  Returns five values: the hoisted functions,
;; the unique variables not shadowed by those functions, the imports, the
;; exports, and the program statements wrapped in a single implicit block
;; (as a one-element list).
;;
;; The implicit block normally spans from the first to the last source
;; element.  An empty program has no elements to take a span from --
;; `first`/`last` would raise on '() -- so it gets #f as its location.
;; NOTE(review): #f is already used as a "no location" value elsewhere in
;; this file (see to-location); confirm the block constructor accepts it.
(define (hoist-program-unit elts)
  (let-values ([(stmts let-vars let-funs vars funs imports exports)
                (execute (hoist-source-elements elts))])
    (let ([loc (and (pair? elts)
                    (@ (first elts) (last elts)))])
      (values funs
              (unique-vars funs vars)
              imports
              exports
              (wrap-in-implicit-block loc stmts let-funs let-vars)))))
;; hoist-source-elements : list of source elements -> hoist computation
;; Applies hoist-source-element to every element of `elts` through map/m
;; (presumably the monadic map from hoist-monad.ss -- confirm there).
(define (hoist-source-elements elts)
(map/m hoist-source-element elts))
;; hoist-optional-expression : (or expression #f) -> hoist computation
;; Hoists `expr` when one is present; an absent expression (#f) is passed
;; through as (return #f).
(define (hoist-optional-expression expr)
  (cond
    [expr (hoist-expression expr)]
    [else (return #f)]))
;; hoist-substatements : list of statements -> hoist computation
;; Hoists a nested statement list; currently a plain alias for
;; hoist-source-elements.
(define (hoist-substatements stmts)
(hoist-source-elements stmts))
;; hoist-var-initializers : list of VariableInitializer, loc -> hoist computation
;; Hoists each `var` binding with hoist-var-initializer, keeping only the
;; non-#f results (the assignment expressions).  A single assignment is
;; returned as is; zero or several are bundled in a ListExpression at `loc`.
(define (hoist-var-initializers decls loc)
  (begin-hoist
    (assignments <- (filter-map/m hoist-var-initializer decls))
    (return (match assignments
              [(list only) only]
              [_ (make-ListExpression loc assignments)]))))
;; hoist-let-var-initializers : list of VariableInitializer, loc -> hoist computation
;; `let` counterpart of hoist-var-initializers: hoists each binding with
;; hoist-let-var-initializer and combines the resulting assignments.  Exactly
;; one assignment is returned bare; otherwise they go into a ListExpression.
(define (hoist-let-var-initializers decls loc)
  (begin-hoist
    (assignments <- (filter-map/m hoist-let-var-initializer decls))
    (return (let ([singleton? (and (pair? assignments)
                                   (null? (cdr assignments)))])
              (if singleton?
                  (car assignments)
                  (make-ListExpression loc assignments))))))
;; hoist-case-clauses : list of CaseClause -> hoist computation
;; Applies hoist-case-clause to every clause of a switch statement via map/m.
(define (hoist-case-clauses cases)
(map/m hoist-case-clause cases))
;; hoist-catch-clauses : list of CatchClause -> hoist computation
;; Applies hoist-catch-clause to every catch clause of a try statement via
;; map/m.
(define (hoist-catch-clauses catches)
(map/m hoist-catch-clause catches))
;; hoist-source-element : source element -> hoist computation
;; Dispatches on the kind of element: declarations go to hoist-declaration,
;; everything else is treated as a statement.
(define (hoist-source-element src)
  (cond
    [(Declaration? src) (hoist-declaration src)]
    [else (hoist-statement src)]))
;; hoist-declaration : Declaration -> hoist computation yielding a statement
;; Replaces a declaration with the statement left behind at its original
;; position once the declaration itself has been hoisted:
;;   - function declaration: hoisted (already processed) with tag 'function;
;;     an EmptyStatement remains.
;;   - var / let declaration: the initializers become an ExpressionStatement
;;     of assignments (the names are hoisted by the helpers).
;;   - import / export declaration: hoisted with tag 'top; an EmptyStatement
;;     remains.
;; Raises an error for any other kind of declaration.
(define (hoist-declaration decl0)
  (cond
    [(FunctionDeclaration? decl0)
     (begin-hoist
       (hoist 'function (hoist-function-declaration decl0))
       (return (make-EmptyStatement (Term-location decl0))))]
    [(VariableDeclaration? decl0)
     (begin-hoist
       (expr <- (hoist-variable-declaration decl0))
       (return (make-ExpressionStatement (Term-location decl0) expr)))]
    [(LetDeclaration? decl0)
     (begin-hoist
       (expr <- (hoist-let-declaration decl0))
       (return (make-ExpressionStatement (Term-location decl0) expr)))]
    [(or (ImportDeclaration? decl0) (ExportDeclaration? decl0))
     (begin-hoist
       (hoist 'top decl0)
       (return (make-EmptyStatement (Term-location decl0))))]
    ;; Hand the format string and its argument straight to `error`.
    ;; Pre-formatting with `format` and passing the result as the message
    ;; makes `error` re-interpret it as a format string, which fails when
    ;; the printed declaration contains a `~`.
    [else (error 'hoist-declaration "unrecognized declaration: ~v" decl0)]))
;; hoist-variable-declaration : VariableDeclaration -> hoist computation
;; Unpacks a `var` declaration and hoists its bindings, yielding the
;; expression built by hoist-var-initializers.
(define (hoist-variable-declaration decl0)
  (match decl0
    [(struct VariableDeclaration (decl-loc initializers))
     (hoist-var-initializers initializers decl-loc)]))
;; hoist-let-declaration : LetDeclaration -> hoist computation
;; Unpacks a `let` declaration and hoists its bindings, yielding the
;; expression built by hoist-let-var-initializers.
(define (hoist-let-declaration decl0)
  (match decl0
    [(struct LetDeclaration (decl-loc initializers))
     (hoist-let-var-initializers initializers decl-loc)]))
;; hoist-substatement : source element -> hoist computation
;; Hoists a single nested statement (loop body, branch, ...); currently a
;; plain alias for hoist-source-element, so a declaration in substatement
;; position is handled as well.
(define (hoist-substatement src0)
(hoist-source-element src0))
;; hoist-var-initializer : VariableInitializer -> hoist computation
;; Hoists the bound identifier with tag 'function.  A binding without an
;; initializer yields #f; otherwise the initializer is hoisted and the
;; binding is rewritten as the assignment `id = init`.
(define (hoist-var-initializer decl)
  (match decl
    [(struct VariableInitializer (loc id init))
     (begin-hoist
       (hoist 'function id)
       (if init
           (begin-hoist
             (rhs <- (hoist-expression init))
             (return (make-AssignmentExpression loc
                                                (make-VarReference (Term-location id) id)
                                                '=
                                                rhs)))
           (return #f)))]))
;; hoist-let-var-initializer : VariableInitializer -> hoist computation
;; Same as hoist-var-initializer, but the identifier is hoisted with tag
;; 'block.  Yields #f for a binding without an initializer, otherwise the
;; assignment `id = init` with the initializer hoisted.
(define (hoist-let-var-initializer decl)
  (match decl
    [(struct VariableInitializer (loc id init))
     (begin-hoist
       (hoist 'block id)
       (if init
           (begin-hoist
             (rhs <- (hoist-expression init))
             (return (make-AssignmentExpression loc
                                                (make-VarReference (Term-location id) id)
                                                '=
                                                rhs)))
           (return #f)))]))
;; hoist-case-clause : CaseClause -> hoist computation yielding a CaseClause
;; Hoists the clause's question expression, if it has one (a #f question
;; is kept as #f), then its answer statements, and rebuilds the clause.
(define (hoist-case-clause case)
  (match case
    [(struct CaseClause (loc question answer))
     (begin-hoist
       (hoisted-question <- (hoist-optional-expression question))
       (hoisted-answer <- (hoist-substatements answer))
       (return (make-CaseClause loc hoisted-question hoisted-answer)))]))
;; hoist-catch-clause : CatchClause -> hoist computation yielding a CatchClause
;; Hoists the handler body; the location and the bound identifier are kept
;; unchanged.
(define (hoist-catch-clause catch)
  (match catch
    [(struct CatchClause (clause-loc exn-id handler))
     (begin-hoist
       (hoisted-handler <- (hoist-statement handler))
       (return (make-CatchClause clause-loc exn-id hoisted-handler)))]))
;; hoist-statement : statement -> hoist computation yielding a statement
;; Rebuilds `stmt` with all nested expressions and substatements hoisted.
;; Statement kinds without a dedicated clause are returned unchanged.
;;
;; Notable cases:
;;   - BlockStatement: whatever the substatements hoist under tag 'block is
;;     captured here and stored on the resulting BlockStatement/hoisted,
;;     split into identifiers (vars) and everything else (funs).
;;   - for (let ...) and for (let x in ...): desugared into a LetStatement
;;     wrapping an ordinary for / for-in loop, then hoisted again.
;;   - Substatements pass through optional-statement->statement, so a
;;     missing result is replaced by an EmptyStatement at the original
;;     statement's location.
(define (hoist-statement stmt)
  (match stmt
    [(struct BlockStatement (loc stmts))
     (begin-hoist
       (result <- (capture 'block (hoist-substatements stmts)))
       (match-let ([(list hs stmts) result])
         (let-values ([(vars funs) (partition Identifier? (map hoisted-element hs))])
           (return (make-BlockStatement/hoisted loc stmts funs vars)))))]
    [(struct ExpressionStatement (loc expr))
     (begin-hoist
       (expr <- (hoist-expression expr))
       (return (make-ExpressionStatement loc expr)))]
    ;; if without else
    [(struct IfStatement (loc test consequent0 #f))
     (begin-hoist
       (test <- (hoist-expression test))
       (consequent <- (hoist-substatement consequent0))
       (return (make-IfStatement loc test (optional-statement->statement consequent consequent0) #f)))]
    ;; if with else; the else branch is wrapped like every other substatement
    [(struct IfStatement (loc test consequent0 alternate0))
     (begin-hoist
       (test <- (hoist-expression test))
       (consequent <- (hoist-substatement consequent0))
       (alternate <- (hoist-substatement alternate0))
       (return (make-IfStatement loc
                                 test
                                 (optional-statement->statement consequent consequent0)
                                 (optional-statement->statement alternate alternate0))))]
    [(struct DoWhileStatement (loc body0 test))
     (begin-hoist
       (body <- (hoist-substatement body0))
       (test <- (hoist-expression test))
       (return (make-DoWhileStatement loc (optional-statement->statement body body0) test)))]
    [(struct WhileStatement (loc test body0))
     (begin-hoist
       (test <- (hoist-expression test))
       (body <- (hoist-substatement body0))
       (return (make-WhileStatement loc test (optional-statement->statement body body0))))]
    ;; for (var ...; test; incr): the declaration becomes an init expression
    [(struct ForStatement (loc (? VariableDeclaration? init) test incr body0))
     (begin-hoist
       (init <- (hoist-variable-declaration init))
       (test <- (hoist-optional-expression test))
       (incr <- (hoist-optional-expression incr))
       (body <- (hoist-substatement body0))
       (return (make-ForStatement loc init test incr (optional-statement->statement body body0))))]
    ;; for (let ...; test; incr): rewritten as
    ;;   let (names without initializers) for (assignments; test; incr) body
    ;; Only bindings that actually carry an initializer produce an
    ;; assignment.  A bare `let x` must not turn into `x = #f`, which would
    ;; be an AssignmentExpression with no right-hand side.  With no
    ;; initialized bindings the init is an empty ListExpression, the same
    ;; thing hoist-var-initializers produces for `for (var x; ...)`.
    [(struct ForStatement (loc (struct LetDeclaration (loc* bindings)) test incr body0))
     (let* ([bindings* (for/list ([binding bindings])
                         (struct-copy VariableInitializer binding [init #f]))]
            [assignments
             (for/list ([binding bindings]
                        #:when (match binding
                                 [(struct VariableInitializer (_ _ init)) init]))
               (match binding
                 [(struct VariableInitializer (b-loc id init))
                  (let ([var (make-VarReference (Term-location id) id)])
                    (make-AssignmentExpression b-loc var '= init))]))]
            [inits* (make-ListExpression loc* assignments)])
       (hoist-statement (make-LetStatement loc bindings* (make-ForStatement loc inits* test incr body0))))]
    ;; for with an optional plain expression as init
    [(struct ForStatement (loc init test incr body0))
     (begin-hoist
       (init <- (hoist-optional-expression init))
       (test <- (hoist-optional-expression test))
       (incr <- (hoist-optional-expression incr))
       (body <- (hoist-substatement body0))
       (return (make-ForStatement loc init test incr (optional-statement->statement body body0))))]
    [(struct ForInStatement (loc (? Expression? lhs) container body0))
     (begin-hoist
       (lhs <- (hoist-expression lhs))
       (container <- (hoist-expression container))
       (body <- (hoist-substatement body0))
       (return (make-ForInStatement loc lhs container (optional-statement->statement body body0))))]
    ;; for (var x in ...): hoist x with tag 'function, loop over a reference to it
    [(struct ForInStatement (loc (struct VariableDeclaration (_ (list (struct VariableInitializer (v-loc id #f))))) container body0))
     (begin-hoist
       (hoist 'function id)
       (container <- (hoist-expression container))
       (body <- (hoist-substatement body0))
       (return (make-ForInStatement loc
                                    (make-VarReference v-loc id)
                                    container
                                    (optional-statement->statement body body0))))]
    ;; for (let x in ...): rewritten as  let (x) for (x in ...) body
    [(struct ForInStatement (loc (struct LetDeclaration (_ (list (struct VariableInitializer (v-loc id #f))))) container body0))
     (hoist-statement (make-LetStatement loc
                                         (list (make-VariableInitializer v-loc id #f))
                                         (make-ForInStatement loc
                                                              (make-VarReference (Term-location id) id)
                                                              container
                                                              body0)))]
    [(struct ReturnStatement (loc expr))
     (begin-hoist
       (expr <- (hoist-optional-expression expr))
       (return (make-ReturnStatement loc expr)))]
    [(struct LetStatement (loc (list (struct VariableInitializer (binding-locs binding-names binding-inits)) ...) body0))
     (begin-hoist
       (binding-inits <- (map/m hoist-optional-expression binding-inits))
       (body <- (hoist-substatement body0))
       (return (make-LetStatement loc
                                  (map make-VariableInitializer binding-locs binding-names binding-inits)
                                  (optional-statement->statement body body0))))]
    [(struct WithStatement (loc obj body0))
     (begin-hoist
       (obj <- (hoist-expression obj))
       (body <- (hoist-substatement body0))
       (return (make-WithStatement loc obj (optional-statement->statement body body0))))]
    [(struct SwitchStatement (loc expr cases))
     (begin-hoist
       (expr <- (hoist-expression expr))
       (cases <- (hoist-case-clauses cases))
       (return (make-SwitchStatement loc expr cases)))]
    [(struct LabelledStatement (loc label stmt0))
     (begin-hoist
       (stmt <- (hoist-substatement stmt0))
       (return (make-LabelledStatement loc label (optional-statement->statement stmt stmt0))))]
    [(struct ThrowStatement (loc expr))
     (begin-hoist
       (expr <- (hoist-expression expr))
       (return (make-ThrowStatement loc expr)))]
    [(struct TryStatement (loc body0 catch0 finally0))
     (begin-hoist
       (body <- (hoist-statement body0))
       (catch <- (hoist-catch-clauses catch0))
       (if finally0
           (begin-hoist
             (finally <- (hoist-statement finally0))
             (return (make-TryStatement loc body catch finally)))
           (return (make-TryStatement loc body catch #f))))]
    ;; every other statement kind is returned as is
    [_ (return stmt)]))
;; hoist-expression-block : ExpressionBlock -> hoist computation yielding an ExpressionBlock
;; Hoists the block's statement body and then its optional tail expression,
;; keeping the original location.
(define (hoist-expression-block block)
  (match block
    [(struct ExpressionBlock (block-loc stmt-part tail-part))
     (begin-hoist
       (hoisted-body <- (hoist-statement stmt-part))
       (hoisted-tail <- (hoist-optional-expression tail-part))
       (return (make-ExpressionBlock block-loc hoisted-body hoisted-tail)))]))
;; hoist-expression : expression -> hoist computation yielding an expression
;; Rebuilds `expr` with every sub-expression hoisted, left to right.
;; Function expressions are hoisted as self-contained units by
;; hoist-function-expression.  Expression kinds without a dedicated clause
;; are returned unchanged.
(define (hoist-expression expr)
  (match expr
    ;; array elements may be absent (#f), hence the optional variant
    [(struct ArrayLiteral (loc elts))
     (begin-hoist
       (elts* <- (map/m hoist-optional-expression elts))
       (return (make-ArrayLiteral loc elts*)))]
    ;; only the property values are expressions; the keys are kept
    [(struct ObjectLiteral (loc (list (cons keys inits) ...)))
     (begin-hoist
       (inits* <- (map/m hoist-expression inits))
       (return (make-ObjectLiteral loc (map cons keys inits*))))]
    [(struct BracketReference (loc container key))
     (begin-hoist
       (container* <- (hoist-expression container))
       (key* <- (hoist-expression key))
       (return (make-BracketReference loc container* key*)))]
    [(struct DotReference (loc container id))
     (begin-hoist
       (container* <- (hoist-expression container))
       (return (make-DotReference loc container* id)))]
    [(struct NewExpression (loc ctor args))
     (begin-hoist
       (ctor* <- (hoist-expression ctor))
       (args* <- (map/m hoist-expression args))
       (return (make-NewExpression loc ctor* args*)))]
    [(struct PostfixExpression (loc operand op))
     (begin-hoist
       (operand* <- (hoist-expression operand))
       (return (make-PostfixExpression loc operand* op)))]
    [(struct PrefixExpression (loc op operand))
     (begin-hoist
       (operand* <- (hoist-expression operand))
       (return (make-PrefixExpression loc op operand*)))]
    [(struct InfixExpression (loc left op right))
     (begin-hoist
       (left* <- (hoist-expression left))
       (right* <- (hoist-expression right))
       (return (make-InfixExpression loc left* op right*)))]
    [(struct ConditionalExpression (loc test consequent alternate))
     (begin-hoist
       (test* <- (hoist-expression test))
       (consequent* <- (hoist-expression consequent))
       (alternate* <- (hoist-expression alternate))
       (return (make-ConditionalExpression loc test* consequent* alternate*)))]
    [(struct AssignmentExpression (loc left op right))
     (begin-hoist
       (left* <- (hoist-expression left))
       (right* <- (hoist-expression right))
       (return (make-AssignmentExpression loc left* op right*)))]
    [(struct BlockLiteral (loc args body))
     (begin-hoist
       (body* <- (hoist-expression-block body))
       (return (make-BlockLiteral loc args body*)))]
    ;; a nested function is hoisted on its own and returned as a plain value
    [(? FunctionExpression?)
     (return (hoist-function-expression expr))]
    [(struct LetExpression (loc (list (struct VariableInitializer (binding-locs binding-names binding-inits)) ...) body))
     (begin-hoist
       (binding-inits* <- (map/m hoist-optional-expression binding-inits))
       (body* <- (hoist-expression body))
       (return (make-LetExpression loc
                                   (map make-VariableInitializer binding-locs binding-names binding-inits*)
                                   body*)))]
    [(struct CallExpression (loc method args))
     (begin-hoist
       (method* <- (hoist-expression method))
       (args* <- (map/m hoist-expression args))
       (return (make-CallExpression loc method* args*)))]
    [(struct ParenExpression (loc inner))
     (begin-hoist
       (inner* <- (hoist-expression inner))
       (return (make-ParenExpression loc inner*)))]
    [(struct ListExpression (loc exprs))
     (begin-hoist
       (exprs* <- (map/m hoist-expression exprs))
       (return (make-ListExpression loc exprs*)))]
    [(struct DoExpression (loc block))
     (begin-hoist
       (block* <- (hoist-expression-block block))
       (return (make-DoExpression loc block*)))]
    ;; every other expression kind is returned as is
    [_ (return expr)]))
;; Public interface of this module: the two hoisting entry points plus the
;; three "hoisted" struct types used by the results they build.
(provide hoist-program-unit
hoist-function-expression
(struct-out FunctionDeclaration/hoisted)
(struct-out FunctionExpression/hoisted)
(struct-out BlockStatement/hoisted))