#lang scheme/base
(require (for-syntax scheme/base)
mzlib/kw
scheme/contract
scheme/match
srfi/19/time
srfi/26/cut
(planet untyped/unlib:3/list)
(planet untyped/unlib:3/symbol)
"../base.ss"
"../era/era.ss"
"sql-struct.ss"
"sql-util.ss")
; Creates an alias. Accepted argument shapes:
;   symbol entity                 -> entity-alias
;   symbol query                  -> query-alias
;   symbol non-alias-expression   -> expression-alias
;   entity-alias attribute        -> attribute-alias
; Anything else raises exn:fail:contract.
(define sql:alias
  (lambda args
    (match args
      [(list (? symbol? id) (? entity? item))
       (make-entity-alias id item)]
      [(list (? symbol? id) (? query? item))
       (make-query-alias id item)]
      [(list (? symbol? id) (? non-alias-expression? item))
       (make-expression-alias id item)]
      [(list (? entity-alias? alias) (? attribute? attr))
       (let ([entity (source-alias-value alias)]
             [owner  (attribute-entity attr)])
         ; The attribute must belong to the aliased entity itself or to
         ; entity:persistent-struct.
         (if (or (eq? owner entity)
                 (eq? owner entity:persistent-struct))
             (make-attribute-alias alias attr)
             (raise-exn exn:fail:contract
               (format "Entity does not contain that attribute: ~a ~a" entity attr))))]
      [other
       (raise-exn exn:fail:contract
         (format "Bad arguments to sql:alias: ~s" other))])))
; Wraps an entity in an entity-alias.
;   (sql:entity entity)    -> alias whose id is a gensym based on the entity name
;   (sql:entity id entity) -> alias with the supplied id
(define sql:entity
  (case-lambda
    [(entity)
     (let ([fresh-id (gensym (entity-name entity))])
       (make-entity-alias fresh-id entity))]
    [(id entity)
     (make-entity-alias id entity)]))
; Looks up attr+name on the entity behind alias (via entity-attribute) and
; wraps the resulting attribute in an attribute-alias bound to alias.
(define (sql:attr alias attr+name)
  (let* ([entity (source-alias-value alias)]
         [attr   (entity-attribute entity attr+name)])
    (make-attribute-alias alias attr)))
; any -> boolean
; True for expressions that are neither expression-aliases nor attribute-aliases.
(define (non-alias-expression? item)
  (cond [(not (expression? item))  #f]
        [(expression-alias? item)  #f]
        [(attribute-alias? item)   #f]
        [else                      #t]))
; Keyword front-end for sql:select/internal: #:from is mandatory, every other
; clause is optional and defaults to "absent" (#f, or null for #:group/#:order).
(define sql:select
  (lambda (#:what     [what     #f]
           #:distinct [distinct #f]
           #:from     from
           #:where    [where    #f]
           #:group    [group    null]
           #:order    [order    null]
           #:having   [having   #f]
           #:limit    [limit    #f]
           #:offset   [offset   #f])
    ; Hand the clauses over in the positional order expected by the internal constructor.
    (sql:select/internal what distinct from where group order having limit offset)))
; Positional constructor behind sql:select.
;
; Normalises the arguments, validates each clause and builds a query struct:
;   what*     : expression(s)/source-alias(es) or #f (#f means "derive from the FROM clause")
;   distinct* : expression, list of expressions, #t or #f
;   from*     : source or query (quoted via quote-argument)
;   where     : expression or #f
;   group*    : list, expanded via expand-group-argument
;   order     : list of order structs
;   having    : expression or #f
;   limit     : integer or #f
;   offset    : integer or #f
;
; Raises exn:fail:contract (via raise-select-exn) when a clause has the wrong shape.
; The check-*-clause helpers are defined elsewhere and may raise their own errors.
(define (sql:select/internal what* distinct* from* where group* order having limit offset)
  (define from
    (quote-argument from*))
  (define-values (what expand-info)
    (expand-what-argument (if what* what* (make-default-what-argument from))))
  ; Internal representation of DISTINCT: #f when absent, otherwise a (possibly
  ; empty) list of expressions. A list of expressions is accepted as advertised
  ; by the error message below and by the exported contract.
  (define distinct
    (cond [(expression? distinct*) (list distinct*)]
          [(eq? distinct* #t)      (list)]
          [(eq? distinct* #f)      #f]
          [(and (list? distinct*) (andmap expression? distinct*)) distinct*]
          [else (raise-select-exn '#:distinct "(U expression (listof expression) #t #f)" distinct*)]))
  (define group
    (expand-group-argument group*))
  ; NOTE: keywords are quoted so that they are passed to raise-select-exn as
  ; ordinary positional values; unquoted they would be parsed as keyword
  ; arguments, which raise-select-exn does not accept.
  (unless (source? from)
    (raise-select-exn '#:from "source" from))
  (unless (or (expression? where) (not where))
    (raise-select-exn '#:where "(U expression #f)" where))
  (unless (and (list? group) (andmap expression? group))
    (raise-select-exn '#:group "(listof expression)" group))
  (unless (and (list? order) (andmap order? order))
    (raise-select-exn '#:order "(listof order)" order))
  (unless (or (expression? having) (not having))
    (raise-select-exn '#:having "(U expression #f)" having))
  (unless (or (integer? limit) (not limit))
    (raise-select-exn '#:limit "(U integer #f)" limit))
  (unless (or (integer? offset) (not offset))
    (raise-select-exn '#:offset "(U integer #f)" offset))
  (let*-values ([(sources)                        (source->sources from)]
                [(local-columns imported-columns) (source->columns from)]
                [(columns)                        (append local-columns imported-columns)]
                ; GROUP, ORDER and HAVING are additionally checked against the
                ; columns introduced by the WHAT clause.
                [(columns*)                       (append what columns)])
    (check-what-clause     what     sources columns)
    (check-distinct-clause distinct sources columns)
    (check-where-clause    where    sources columns)
    (check-group-clause    group    sources columns*)
    (check-order-clause    order    sources columns*)
    (check-having-clause   having   sources columns*)
    (make-query what distinct from where group order having limit offset local-columns imported-columns expand-info)))
; Raises exn:fail:contract describing a bad argument to select.
;   kw       : value naming the offending clause (displayed with ~a)
;   expected : description of the accepted values (displayed with ~a)
;   received : the offending value (written with ~s)
(define (raise-select-exn kw expected received)
  (let ([message (format "~a argument to select: expected ~a, received ~s" kw expected received)])
    (raise-exn exn:fail:contract message)))
; Builds an 'inner join of left and right, using the quoted op as the join condition.
(define (sql:inner left right op)
  (let ([on (quote-argument op)])
    (create-join 'inner left right on)))
; Builds a 'left join of left and right, using the quoted op as the join condition.
(define (sql:left left right op)
  (let ([on (quote-argument op)])
    (create-join 'left left right on)))
; Builds a 'right join of left and right, using the quoted op as the join condition.
(define (sql:right left right op)
  (let ([on (quote-argument op)])
    (create-join 'right left right on)))
; Builds an 'outer join of left and right; no join condition is supplied (#f).
(define (sql:outer left right)
  (let ([no-condition #f])
    (create-join 'outer left right no-condition)))
; Shared join constructor: quotes both operands, builds the join struct,
; runs check-join on it and returns it.
(define (create-join op left right on)
  (let ([join (make-join op (quote-argument left) (quote-argument right) on)])
    (check-join join)
    join))
; Integer-typed 'count aggregate over the quoted argument.
(define (sql:count arg)
  (define quoted (quote-argument arg))
  (make-aggregate type:integer 'count (list quoted)))
; Integer-typed 'count* aggregate.
;   (sql:count*)       -> aggregate with no arguments
;   (sql:count* alias) -> aggregate whose single argument is alias (not quoted)
(define sql:count*
  (let ([count*-of (lambda (args) (make-aggregate type:integer 'count* args))])
    (case-lambda
      [()      (count*-of null)]
      [(alias) (count*-of (list alias))])))
; 'min aggregate over the quoted argument; the aggregate has the argument's type.
(define (sql:min arg)
  (define quoted (quote-argument arg))
  (make-aggregate (expression-type quoted) 'min (list quoted)))
; 'max aggregate over the quoted argument; the aggregate has the argument's type.
(define (sql:max arg)
  (define quoted (quote-argument arg))
  (make-aggregate (expression-type quoted) 'max (list quoted)))
; Real-typed 'average aggregate over the quoted argument.
; Raises a type error (exn:fail:contract) when the argument is not a numeric expression.
(define (sql:average alias)
  (let ([arg (quote-argument alias)])
    (if (numeric-expression? arg)
        (make-aggregate type:real 'average (list arg))
        ; raise-type-error needs (name expected-description bad-value);
        ; the previous two-argument call failed with an arity error instead.
        (raise-type-error 'average "numeric expression" arg))))
; Boolean connectives.
; NOTE(review): define-function is defined outside this file view (presumably in
; sql-util.ss). Each clause looks like [guard-over-arguments result-type] - confirm
; against the macro definition.

; sql:and : variadic; all arguments boolean expressions -> type:boolean.
(define-function (sql:and . args)
[(andmap boolean-expression? args) type:boolean])
; sql:or : variadic; all arguments boolean expressions -> type:boolean.
(define-function (sql:or . args)
[(andmap boolean-expression? args) type:boolean])
; sql:not : single boolean expression -> type:boolean.
(define-function (sql:not arg)
[(boolean-expression? arg) type:boolean])
; Arithmetic functions.
; NOTE(review): define-function is defined outside this file view; clauses appear
; to be tried in order as [guard result-type], so the narrower integer clause
; precedes the numeric one - confirm against the macro definition.

; sql:+ : variadic over integer, numeric, time-tai or temporal expressions.
(define-function (sql:+ . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
; sql:- : same clauses as sql:+.
(define-function (sql:- . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
; sql:* : variadic over integer or numeric expressions.
(define-function (sql:* . args)
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real])
; sql:/ : two numeric expressions; the result is always type:real.
(define-function (sql:/ arg1 arg2)
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:real])
; sql:abs : result type follows the argument (integer or real).
(define-function (sql:abs arg)
[(integer-expression? arg) type:integer]
[(numeric-expression? arg) type:real])
; sql:floor, sql:ceiling, sql:round : numeric expression -> type:integer.
(define-function (sql:floor arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:ceiling arg)
[(numeric-expression? arg) type:integer])
(define-function (sql:round arg)
[(numeric-expression? arg) type:integer])
; String and regular-expression functions.
; NOTE(review): define-function is defined outside this file view; clauses appear
; to be [guard result-type] pairs - confirm against the macro definition.

; Pattern-matching predicates: two character expressions -> type:boolean.
(define-function (sql:like arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
(define-function (sql:regexp-match arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
(define-function (sql:regexp-match-ci arg1 arg2)
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean])
; sql:string-append : variadic; all-symbol arguments give type:symbol, otherwise
; all-character arguments give type:string.
(define-function (sql:string-append . args)
[(andmap symbol-expression? args) type:symbol]
[(andmap character-expression? args) type:string])
; Replacement functions: three arguments; all-symbol arguments give type:symbol,
; otherwise all-character arguments give type:string.
(define-function (sql:string-replace arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace-ci arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace* arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
(define-function (sql:regexp-replace*-ci arg1 arg2 arg3)
[(andmap symbol-expression? (list arg1 arg2 arg3)) type:symbol]
[(andmap character-expression? (list arg1 arg2 arg3)) type:string])
; Comparison functions; every clause yields type:boolean.
; NOTE(review): define-function is defined outside this file view; clauses appear
; to be [guard result-type] pairs - confirm against the macro definition.

; sql:= and sql:<> place no type constraint on their arguments (guard is #t).
(define-function (sql:= arg1 arg2)
[#t type:boolean])
(define-function (sql:<> arg1 arg2)
[#t type:boolean])
; Ordering comparisons: both arguments must fall in the same family
; (boolean, numeric, character or temporal).
(define-function (sql:< arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:> arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:>= arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
(define-function (sql:<= arg1 arg2)
[(and (boolean-expression? arg1) (boolean-expression? arg2)) type:boolean]
[(and (numeric-expression? arg1) (numeric-expression? arg2)) type:boolean]
[(and (character-expression? arg1) (character-expression? arg2)) type:boolean]
[(and (temporal-expression? arg1) (temporal-expression? arg2)) type:boolean])
; sql:null? accepts an argument of any type (guard is #t).
(define-function (sql:null? arg)
[#t type:boolean])
; Coalescing and conversion functions.
; NOTE(review): define-function is defined outside this file view; clauses appear
; to be [guard result-type] pairs - confirm against the macro definition.

; sql:coalesce : variadic; all arguments must share one type family, which
; determines the result type.
; NOTE(review): this uses string-expression? where sql:string-append and sql:if
; use character-expression? - confirm whether the difference is intentional.
(define-function (sql:coalesce . args)
[(andmap boolean-expression? args) type:boolean]
[(andmap integer-expression? args) type:integer]
[(andmap numeric-expression? args) type:real]
[(andmap symbol-expression? args) type:symbol]
[(andmap string-expression? args) type:string]
[(andmap time-tai-expression? args) type:time-tai]
[(andmap temporal-expression? args) type:time-utc])
; sql:->string / sql:->symbol : only arg2 is type-checked (must be a string
; expression); arg1 is unconstrained. Presumably arg2 is a format/template
; string - verify against the SQL generator.
(define-function (sql:->string arg1 arg2)
[(string-expression? arg2) type:string])
(define-function (sql:->symbol arg1 arg2)
[(string-expression? arg2) type:symbol])
; Builds a boolean-typed 'in function expression.
;   arg1 : quotable value (quoted via quote-argument)
;   arg2 : a list of quotable values (each quoted) or a query
;
; Raises exn:fail:contract when:
;   - the list elements are not all type-compatible with the first element;
;   - the subquery does not have exactly one column;
;   - arg2 is neither a proper list nor a query;
;   - the type of arg1 is not compatible with the element/column type of arg2.
(define (sql:in arg1 arg2)
  (let* ([arg1  (quote-argument arg1)]
         [arg2  (if (list? arg2)
                    (map quote-argument arg2)
                    arg2)]
         [type1 (expression-type arg1)]
         [type2 (cond
                  ; An empty list imposes no type constraint of its own.
                  [(null? arg2) type1]
                  ; Non-empty proper list: the first element sets the type.
                  [(list? arg2)
                   (let ([type2 (expression-type (car arg2))])
                     (unless (andmap (cut type-compatible? type2 <>)
                                     (map expression-type arg2))
                       (raise-exn exn:fail:contract
                         (format "sql:in: list elements must all be of the same type: ~a" arg2)))
                     type2)]
                  ; Subquery: the type of its single column.
                  [(query? arg2)
                   (let ([columns (query-what arg2)])
                     (unless (= (length columns) 1)
                       (raise-exn exn:fail:contract
                         (format "sql:in: subquery must have exactly one column: ~a" arg2)))
                     (expression-type (car columns)))]
                  ; Previously there was no fall-through clause, so an unsupported
                  ; arg2 produced a void type and an obscure failure later on.
                  [else
                   (raise-exn exn:fail:contract
                     (format "sql:in: expected a query or a list of expressions, received ~s" arg2))])])
    (unless (type-compatible? type1 type2)
      (raise-exn exn:fail:contract
        (format "sql:in: type mismatch: argument types do not match: ~a ~a" (type-name type1) (type-name type2))))
    (make-function type:boolean 'in (list arg1 arg2))))
; Builds an 'if function expression.
;   (sql:if test pos)     -> the negative branch defaults to a NULL of pos's type
;   (sql:if test pos neg) -> all three arguments are quoted via quote-argument
;
; test must be a boolean expression and pos/neg must share one type family,
; which determines the result type. Otherwise exn:fail:snooze is raised.
(define sql:if
  (case-lambda
    [(test pos)
     (define pos* (quote-argument pos))
     (sql:if test pos* (sql:null (expression-type pos*)))]
    [(test pos neg)
     (let* ([test  (quote-argument test)]
            [pos   (quote-argument pos)]
            [neg   (quote-argument neg)]
            ; True when the test is boolean and both branches satisfy pred.
            [both? (lambda (pred)
                     (and (boolean-expression? test) (pred pos) (pred neg)))]
            ; Clause order matters: narrower families are tried first.
            [type  (cond [(both? boolean-expression?)   type:boolean]
                         [(both? integer-expression?)   type:integer]
                         [(both? numeric-expression?)   type:real]
                         [(both? symbol-expression?)    type:symbol]
                         [(both? character-expression?) type:string]
                         [(both? time-tai-expression?)  type:time-tai]
                         [(both? temporal-expression?)  type:time-utc]
                         [else (raise-exn exn:fail:snooze
                                 (format "Function not defined for the supplied argument types: ~a"
                                         ; Report the real function name; this used to be
                                         ; the placeholder symbol 'id.
                                         (cons 'if (map type-name (map expression-type (list test pos neg))))))])])
       (make-function type 'if (list test pos neg)))]))
; Multi-way conditional that expands into nested sql:if calls.
;   (sql:cond [c e])              => (sql:if c e)
;   (sql:cond [c e] [else alt])   => (sql:if c e alt)
;   (sql:cond [c e] [c2 e2] ...)  => (sql:if c e (sql:cond [c2 e2] ...))
(define-syntax sql:cond
  (syntax-rules (else)
    [(_ [condition consequent])
     (sql:if condition consequent)]
    [(_ [condition consequent] [else alternative])
     (sql:if condition consequent alternative)]
    [(_ [condition consequent] [next-condition next-consequent] ...)
     (sql:if condition
             consequent
             (sql:cond [next-condition next-consequent] ...))]))
; sql:literal is a direct alias for make-literal (defined outside this file view).
(define sql:literal make-literal)
; Creates a NULL value of the given type by delegating to make-null.
(define sql:null
  (lambda (type)
    (make-null type)))
; Creates an order struct from the quoted expression and the direction dir.
(define (sql:order expr dir)
  (let ([quoted (quote-argument expr)])
    (make-order quoted dir)))
; Shorthand for (sql:order expr 'asc).
(define (sql:asc expr)
  (sql:order expr 'asc))
; Shorthand for (sql:order expr 'desc).
(define (sql:desc expr)
  (sql:order expr 'desc))
; Exports without contracts: sql:alias is exported as alias and the sql:cond
; macro as cond.
(provide (rename-out [sql:alias alias]
[sql:cond cond]))
; Contracted exports. Each sql:-prefixed procedure is exported under its
; unprefixed name.
(provide/contract
; Query construction: keyword form and positional form.
[rename sql:select select (->* (#:from (or/c source? query?))
(#:what (or/c expression? source-alias? (listof (or/c expression? source-alias?)) false/c)
#:distinct (or/c expression? (listof expression?) boolean?)
#:where (or/c expression? false/c)
#:group (listof (or/c column? source-alias?))
#:order (listof order?)
#:having (or/c expression? false/c)
#:limit (or/c integer? false/c)
#:offset (or/c integer? false/c))
query?)]
; Positional arguments: what distinct from where group order having limit offset.
[rename sql:select/internal select/internal (-> (or/c expression? source-alias? (listof (or/c expression? source-alias?)) false/c)
(or/c expression? (listof expression?) boolean?)
(or/c source? query?)
(or/c expression? false/c)
(listof (or/c column? source-alias?))
(listof order?)
(or/c expression? false/c)
(or/c integer? false/c)
(or/c integer? false/c)
query?)]
; Aliases.
[rename sql:entity entity (case-> (-> entity? entity-alias?)
(-> symbol? entity? entity-alias?))]
[rename sql:attr attr (-> entity-alias? (or/c attribute? symbol?) attribute-alias?)]
; Aggregates.
; NOTE(review): count is restricted to attribute-alias? while min/max/average
; accept any quotable? - confirm this is intentional.
[rename sql:count count (-> attribute-alias? aggregate?)]
[rename sql:count* count* (->* () ((or/c entity-alias? query-alias?)) aggregate?)]
[rename sql:min min (-> quotable? aggregate?)]
[rename sql:max max (-> quotable? aggregate?)]
[rename sql:average average (-> quotable? aggregate?)]
; Joins.
[rename sql:inner inner (-> source+query? source+query? quotable? join?)]
[rename sql:left left (-> source+query? source+query? quotable? join?)]
[rename sql:right right (-> source+query? source+query? quotable? join?)]
[rename sql:outer outer (-> source+query? source+query? join?)]
; Functions, literals and ordering.
[rename sql:in in (-> quotable? (or/c query? (listof quotable?)) function?)]
[rename sql:if if (->* (quotable? quotable?) (quotable?) function?)]
[rename sql:literal literal (-> quotable? literal?)]
[rename sql:null null (-> type? literal?)]
[rename sql:order order (-> quotable? (symbols 'asc 'desc) order?)]
[rename sql:asc asc (-> quotable? order?)]
[rename sql:desc desc (-> quotable? order?)])