#lang scheme/base
(require scheme/match
srfi/13/string
srfi/26/cut
(planet untyped/unlib:3/syntax)
"../base.ss"
(prefix-in sql: "sql-lang.ss")
"sql-struct.ss"
"sql-syntax-util.ss"
(for-template scheme/base
(prefix-in sql: "sql-lang.ss")))
(define (or-expand message [backtrace null])
(match-lambda*
[(list)
(lambda (stx)
(raise-exn exn:fail:snooze
(format "Expanding ~s: no match for rule ~s~n~a"
(syntax->datum stx)
message
(string-join (filter (lambda (item)
(not (regexp-match #rx": bad syntax" item)))
backtrace)
"\n"))))]
[(list-rest expand rest)
(lambda (stx)
(with-handlers
([exn? (lambda (exn)
(define backtrace*
(cons (exn-message exn) backtrace))
(define next-expand
(or-expand message backtrace*))
((apply next-expand rest) stx))])
(expand stx)))]))
(define (expand-unquote stx)
(syntax-case* stx (unquote) symbolic-identifier=?
[(unquote expr) #'expr]))
(define (expand-identifier stx)
(syntax-case stx ()
[id (identifier? #'id)
(if (sql-identifier? #'id)
#'id
(raise-syntax-error 'expand-identifier
(format "not an SQL identifier: ~a" (syntax->datum #'id))
stx
#'id))]))
(define (expand-literal stx)
(syntax-case* stx (quote quasiquote) symbolic-identifier=?
[(quote literal) (self-quoting-literal? #'literal) #'(sql:literal (quote literal))]
[(quasiquote literal) (self-quoting-literal? #'literal) #'(sql:literal (quasiquote literal))]
[literal (self-quoting-literal? #'literal) #'(sql:literal literal)]))
(define (expand-true+false stx)
(syntax-case* stx () symbolic-identifier=?
[#t #'#t]
[#f #'#f]))
(define expand-literal+unquote
((or-expand "literal+unquote")
expand-literal
expand-unquote))
(define (expand-integer stx)
(syntax-case* stx (quote quasiquote) symbolic-identifier=?
[(quote literal) (integer? (syntax->datum #'literal)) #'literal]
[(quasiquote literal) (integer? (syntax->datum #'literal)) #'literal]
[literal (integer? (syntax->datum #'literal)) #'literal]))
(define expand-integer+unquote
((or-expand "integer+unquote")
expand-integer
expand-unquote))
(define (expand-join stx)
(syntax-case* stx (outer inner left right) symbolic-identifier=?
[(outer a b) #`(sql:outer #,(expand-source+unquote #'a) #,(expand-source+unquote #'b))]
[(inner a b on) #`(sql:inner #,(expand-source+unquote #'a) #,(expand-source+unquote #'b) #,(expand-expression+unquote #'on))]
[(left a b on) #`(sql:left #,(expand-source+unquote #'a) #,(expand-source+unquote #'b) #,(expand-expression+unquote #'on))]
[(right a b on) #`(sql:right #,(expand-source+unquote #'a) #,(expand-source+unquote #'b) #,(expand-expression+unquote #'on))]))
(define expand-source
((or-expand "source")
expand-join
expand-identifier))
(define expand-source+unquote
((or-expand "source+unquote")
expand-source
expand-unquote))
(define (expand-function stx)
(syntax-case* stx (and or not + - * / abs floor ceiling round = <> < <= > >=
like regexp-match regexp-match-ci regexp-replace regexp-replace* regexp-replace-ci regexp-replace*-ci
string-append string-replace null? coalesce ->string ->symbol in if) symbolic-identifier=?
[(and arg ...) #`(sql:and #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(or arg ...) #`(sql:or #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(not arg) #`(sql:not #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(+ arg ...) #`(sql:+ #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(- arg ...) #`(sql:- #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(* arg ...) #`(sql:* #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(/ arg ...) #`(sql:/ #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(abs arg) #`(sql:abs #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(floor arg) #`(sql:floor #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(ceiling arg) #`(sql:ceiling #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(round arg) #`(sql:round #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(= arg1 arg2) #`(sql:= #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(<> arg1 arg2) #`(sql:<> #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(< arg1 arg2) #`(sql:< #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(<= arg1 arg2) #`(sql:<= #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(> arg1 arg2) #`(sql:> #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(>= arg1 arg2) #`(sql:>= #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(like arg1 arg2) #`(sql:like #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(regexp-match arg1 arg2) #`(sql:regexp-match #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(regexp-match-ci arg1 arg2) #`(sql:regexp-match-ci #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(regexp-replace arg1 arg2 arg3) #`(sql:regexp-replace #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]
[(regexp-replace* arg1 arg2 arg3) #`(sql:regexp-replace* #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]
[(regexp-replace-ci arg1 arg2 arg3) #`(sql:regexp-replace-ci #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]
[(regexp-replace*-ci arg1 arg2 arg3) #`(sql:regexp-replace*-ci #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]
[(string-append arg1 arg2) #`(sql:string-append #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(string-replace arg1 arg2 arg3) #`(sql:string-replace #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]
[(null? arg) #`(sql:null? #,@(map expand-expression+unquote (syntax->list #'(arg))))]
[(coalesce arg ...) #`(sql:coalesce #,@(map expand-expression+unquote (syntax->list #'(arg ...))))]
[(->string arg1 arg2) #`(sql:->string #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(->symbol arg1 arg2) #`(sql:->symbol #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(in arg1 arg2) #`(sql:in #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2))))]
[(if arg1 arg2 arg3) #`(sql:if #,@(map expand-expression+unquote (syntax->list #'(arg1 arg2 arg3))))]))
(define (expand-aggregate stx)
(syntax-case* stx (count* count min max average) symbolic-identifier=?
[(count*) #`(sql:count*)]
[(count* alias) #`(sql:count* #,(expand-identifier #'alias))]
[(count alias) #`(sql:count #,(expand-identifier #'alias))]
[(min alias) #`(sql:min #,(expand-identifier #'alias))]
[(max alias) #`(sql:max #,(expand-identifier #'alias))]
[(average alias) #`(sql:average #,(expand-identifier #'alias))]))
(define expand-expression
((or-expand "expression")
expand-function
expand-aggregate
expand-identifier
expand-literal))
(define expand-expression+unquote
((or-expand "expression+unquote")
expand-expression
expand-unquote))
(define expand-distinct
((or-expand "expression")
expand-true+false
expand-function
expand-aggregate
expand-identifier
expand-literal))
(define expand-distinct+unquote
((or-expand "distinct+unquote")
expand-distinct
expand-unquote))
(define expand-column
((or-expand "column")
expand-expression
expand-identifier))
(define expand-column+unquote
((or-expand "column+unquote")
expand-column
expand-unquote))
(define (expand-column-list stx)
#`(list #,@(map expand-column+unquote (syntax->list stx))))
(define expand-column-list+unquote
((or-expand "column-list+unquote")
expand-column-list
expand-unquote))
(define expand-column+column-list
((or-expand "column+column-list")
expand-column+unquote
expand-column-list))
(define expand-column+column-list+unquote
((or-expand "column+column-list")
expand-column+column-list
expand-unquote))
(define (expand-direction stx)
(syntax-case* stx (quote quasiquote asc desc) symbolic-identifier=?
[(quote asc) #'(quote asc)]
[(quote desc) #'(quote desc)]
[(quasiquote asc) #'(quasiquote asc)]
[(quasiquote desc) #'(quasiquote desc)]))
(define expand-direction+unquote
((or-expand "direction+unquote")
expand-direction
expand-unquote))
(define (expand-order stx)
(syntax-case* stx (asc desc ord order) symbolic-identifier=?
[(asc expr) #`(sql:asc #,(expand-expression+unquote #'expr))]
[(desc expr) #`(sql:desc #,(expand-expression+unquote #'expr))]
[(ord expr dir) #`(sql:order #,(expand-expression+unquote #'expr) #,(expand-direction+unquote #'dir))]
[(order expr dir) #`(sql:order #,(expand-expression+unquote #'expr) #,(expand-direction+unquote #'dir))]))
(define expand-order+unquote
((or-expand "order+unquote")
expand-order
expand-unquote))
(define (expand-order-list stx)
#`(list #,@(map expand-order+unquote (syntax->list stx))))
(define expand-order-list+unquote
((or-expand "order-list+unquote")
expand-order-list
expand-unquote))
(define (expand-select-arguments select-stx args)
(match args
[(list) null]
[(list (and key-stx (app syntax->datum key)))
(if (keyword? key)
(raise-syntax-error #f (format "Missing value for ~s" key) select-stx key-stx)
(raise-syntax-error #f (format "Expected keyword, received ~s" key) select-stx key-stx))]
[(list-rest (and key-stx (app syntax->datum key))
(and arg-stx (app syntax->datum arg))
rest)
(cond [(not (keyword? key)) (raise-syntax-error #f (format "Expected keyword, received ~s" key) select-stx key-stx)]
[(keyword? arg) (raise-syntax-error #f (format "Missing value for ~s" arg) select-stx arg-stx)]
[(eq? key '#:what) (list* key-stx
(expand-column+column-list+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:from) (list* key-stx
(expand-source+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:where) (list* key-stx
(expand-expression+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:group) (list* key-stx
(expand-column-list+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:order) (list* key-stx
(expand-order-list+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:having) (list* key-stx
(expand-expression+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:limit) (list* key-stx
(expand-integer+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:offset) (list* key-stx
(expand-integer+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[(eq? key '#:distinct) (list* key-stx
(expand-distinct+unquote arg-stx)
(expand-select-arguments select-stx rest))]
[else (raise-syntax-error #f (format "Unknown select keyword: ~s" key) select-stx key-stx)])]))
(define (expand-select stx)
(syntax-case* stx (select) symbolic-identifier=?
[(select kw+expr ...)
#`(sql:select #,@(expand-select-arguments stx (syntax->list #'(kw+expr ...))))]))
(define expand-top-level
((or-expand "sql")
expand-select
expand-column
expand-source
expand-expression
expand-order))
(provide expand-top-level)