#lang scheme/base
(require (for-syntax scheme/base
"../base.ss")
scheme/contract
scheme/match
(only-in srfi/1/list append-map)
srfi/26/cut
(planet untyped/unlib:3/symbol)
"../base.ss"
"../era/era.ss"
"sql-struct.ss")
;; Flatten a source into the list of aliases it is built from.
;; A join contributes the aliases of its left side followed by those of its
;; right side; an entity or query alias contributes just itself.
;; Any other value falls through every clause and raises a match error.
(define (source->sources source)
  (match source
    [(? join?)
     (append (source->sources (join-left source))
             (source->sources (join-right source)))]
    [(or (? query-alias?) (? entity-alias?))
     (list source)]))
;; Collect the columns made available by a source, as two values:
;;   1. "local" columns, which come from entity aliases;
;;   2. "imported" columns, which come from query (subquery) aliases.
;; For a join, each list is the left side's columns followed by the right's.
(define (source->columns source)
  (match source
    [(? join?)
     (let-values ([(left-local left-imported)
                   (source->columns (join-left source))]
                  [(right-local right-imported)
                   (source->columns (join-right source))])
       (values (append left-local right-local)
               (append left-imported right-imported)))]
    [(? query-alias?)
     (values null (source-alias-columns source))]
    [(? entity-alias?)
     (values (source-alias-columns source) null)]))
;; Default what-argument for a from clause: every alias of a join
;; (as a list), or the single alias itself when `from` is not a join.
(define (make-default-what-argument from)
  (cond [(join? from) (source->sources from)]
        [else (let ([only-sources (source->sources from)])
                (car only-sources))]))
;; Expand a what-argument, which may be a single item or a list of items.
;; Returns the same two values as expand-what-list / expand-what-item.
(define (expand-what-argument argument)
  (let ([expand (if (list? argument) expand-what-list expand-what-item)])
    (expand argument)))
;; Expand each item of a what-clause list and concatenate the results.
;; Returns two values:
;;   - the flat list of selected columns, in item order;
;;   - the flat list of info terms, where each item's info term is passed
;;     through listify before being joined.
;; Per-item results are consed onto reversed accumulators and joined once at
;; the end. Appending onto a growing accumulator for every item would copy
;; the accumulator each time, which is quadratic in the list length.
(define (expand-what-list argument)
  (define-values (what-parts info-parts)
    (for/fold ([what-parts null] [info-parts null])
              ([arg argument])
      (define-values (what-term info-term)
        (expand-what-item arg))
      (values (cons what-term what-parts)
              (cons (listify info-term) info-parts))))
  (values (apply append (reverse what-parts))
          (apply append (reverse info-parts))))
;; Expand one what-clause item into two values: the list of columns it
;; selects, and its info term:
;;   - attribute alias  -> itself, with the type stored in its first field;
;;   - expression alias -> itself, with its expression type;
;;   - other expression -> a fresh expression alias (gensym'd name) wrapping
;;                         it, with its expression type;
;;   - entity alias     -> its columns, with its alias value;
;;   - query alias      -> its columns, with query-extract-info of the query.
;; Clause order matters. The alias cases are tested before the generic
;; expression? case (presumably both alias kinds are expression? subtypes).
(define (expand-what-item item)
  (match item
    [(struct attribute-alias (item-type _ _ _))
     (values (list item) item-type)]
    [(? expression-alias?)
     (values (list item) (expression-type item))]
    [(? expression?)
     (values (list (make-expression-alias (gensym 'expr) item))
             (expression-type item))]
    [(? entity-alias?)
     (values (source-alias-columns item)
             (source-alias-value item))]
    [(? query-alias?)
     (values (source-alias-columns item)
             (query-extract-info (source-alias-value item)))]))
;; Expand a group-clause argument into a flat list of columns.
;; `group` may be a single item or a list of items, as the exported
;; contract (opt-listof ...) allows. A single item is wrapped with listify
;; first; previously append-map would fail on a lone non-list argument.
;; Each item is expanded by expand-group-item: a column yields itself, and a
;; source alias yields all of its columns.
(define (expand-group-argument group)
  (append-map expand-group-item (listify group)))
;; Expand one group-clause item: a column stands for itself, and a source
;; alias stands for all of its columns. Other values raise a match error.
(define (expand-group-item item)
  (match item
    [(? column?) (list item)]
    [(? source-alias?) (source-alias-columns item)]))
;; Check that every item of a what clause (a single item or a list) is valid
;; in the given scope.
;;   - Attribute aliases must be among `columns`.
;;   - Expression aliases that are already among `columns` are accepted.
;;     Otherwise the aliased expression itself is checked against
;;     `sources` / `columns`.
;;   - Anything else raises a type error.
;; Raises exn:fail:contract (via the scope checks) on out-of-scope items.
(define (check-what-clause what sources columns)
  (define (check-item item)
    (cond [(attribute-alias? item)
           (check-attribute-in-scope 'what-clause item columns)]
          [(expression-alias? item)
           ;; Test membership directly (the same memq test that
           ;; check-expression-in-scope performs) rather than using an
           ;; exception handler for control flow. The old catch-all `exn?`
           ;; handler also intercepted exn:break (user breaks) and any
           ;; failure unrelated to scope.
           (unless (memq item columns)
             (check-expression 'what-clause
                               (expression-alias-value item)
                               sources
                               columns))]
          [else
           (raise-type-error 'what-clause "(U attribute-alias expression-alias)" item)]))
  (if (list? what)
      (for-each check-item what)
      (check-item what)))
;; Check each expression of a distinct clause against the scope.
;; A #f distinct clause means "no distinct" and is skipped.
(define (check-distinct-clause distinct sources columns)
  (if distinct
      (for-each (lambda (expr)
                  (check-expression 'distinct-clause expr sources columns))
                distinct)
      (void)))
;; Check the where-clause expression against the scope; #f means no clause.
(define (check-where-clause where sources columns)
  (if where
      (check-expression 'where-clause where sources columns)
      (void)))
;; Check every expression of the group clause (a list) against the scope.
(define (check-group-clause group sources columns)
  (for-each (lambda (expr)
              (check-expression 'group-clause expr sources columns))
            group))
;; Check the expression of every order term against the scope.
;; All order expressions are extracted first, then checked in turn.
(define (check-order-clause order sources columns)
  (let ([order-exprs (map order-expression order)])
    (for-each (lambda (expr)
                (check-expression 'order-clause expr sources columns))
              order-exprs)))
;; Check the having-clause expression against the scope; #f means no clause.
(define (check-having-clause having sources columns)
  (if having
      (check-expression 'having-clause having sources columns)
      (void)))
;; Build a one-argument predicate that accepts expressions whose type is
;; type-compatible? with at least one of `types`. Non-expressions give #f.
;; A true result is whatever type-compatible? returned for the first
;; matching type.
(define (make-expression-predicate . types)
  (lambda (candidate)
    (and (expression? candidate)
         (let ([candidate-type (expression-type candidate)])
           (ormap (lambda (allowed)
                    (type-compatible? candidate-type allowed))
                  types)))))
;; Predicates for expressions of a single type.
(define boolean-expression?   (make-expression-predicate type:boolean))
(define integer-expression?   (make-expression-predicate type:integer))
(define real-expression?      (make-expression-predicate type:real))

;; Predicates for expressions of any of several related types.
(define numeric-expression?   (make-expression-predicate type:integer type:real))

(define string-expression?    (make-expression-predicate type:string))
(define symbol-expression?    (make-expression-predicate type:symbol))
(define character-expression? (make-expression-predicate type:string type:symbol))

(define time-utc-expression?  (make-expression-predicate type:time-utc))
(define time-tai-expression?  (make-expression-predicate type:time-tai))
(define temporal-expression?  (make-expression-predicate type:time-utc type:time-tai))
;; define-function : syntax
;;
;; Defines an SQL function constructor and exports it under a contract.
;; The defined identifier must carry an "sql:" prefix (e.g. sql:foo). The
;; prefix is stripped to give the exported name, and that stripped name is
;; also passed (quoted) to make-function.
;;
;; Each [rule type] pair becomes a cond clause, evaluated with the argument
;; names bound to their quote-argument'ed values. The first rule that holds
;; selects the type passed to make-function. If no rule holds, an
;; exn:fail:snooze is raised whose message lists the (prefixed) identifier
;; and the type names of the arguments.
;;
;; Two forms are supported:
;;   (define-function (sql:id arg ...) [rule type] ...)   ; fixed arity
;;   (define-function (sql:id . args)  [rule type] ...)   ; variable arity
(define-syntax (define-function stx)
;; Strip the "sql:" prefix from a symbol at expansion time. Raises
;; exn:fail:snooze (at expansion time) if the prefix is missing.
(define (remove-prefix sym)
(let ([match (regexp-match #rx"^sql:(.*)$" (symbol->string sym))])
(if match
(string->symbol (cadr match))
(raise-exn exn:fail:snooze
(format "define-function identifier does not have 'sql:' prefix: ~a" sym)))))
;; NOTE(review): `else` is listed as a literal, but no pattern below uses it.
(syntax-case stx (else)
;; Fixed-arity form.
[(_ (id arg ...) [rule type] ...)
(identifier? #'id)
(with-syntax ([plain-id (remove-prefix (syntax->datum #'id))]
;; One quotable? contract per argument, used in provide/contract below.
[(arg-contract ...) (map (lambda _ #'quotable?) (syntax->list #'(arg ...)))])
#'(begin (define (id arg ...)
;; Rebind each argument to its quoted form, so the rules see quoted values.
(let ([arg (quote-argument arg)] ...)
(make-function (cond [rule type] ...
[else (raise-exn exn:fail:snooze
(format "Function not defined for the supplied argument types: ~a"
(cons 'id (map type-name (map expression-type (list arg ...))))))])
'plain-id
(list arg ...))))
;; Export under the unprefixed name: (-> quotable? ... function?).
(provide/contract [rename id plain-id (-> arg-contract ... function?)])))]
;; Variable-arity form: `args` is bound to the whole list of arguments.
;; NOTE(review): a mixed header such as (sql:id a . rest) also matches this
;; pattern, but `args` would then be (a . rest) and the generated let binding
;; below would be malformed. Only a bare rest identifier appears to be
;; supported.
[(_ (id . args) [rule type] ...)
(identifier? #'id)
(with-syntax ([plain-id (remove-prefix (syntax->datum #'id))])
#'(begin (define (id . args)
;; Rebind the rest list to the list of quoted arguments.
(let ([args (map quote-argument args)])
(make-function (cond [rule type] ...
[else (raise-exn exn:fail:snooze
(format "Function not defined for the supplied argument types: ~a"
(cons 'id (map type-name (map expression-type args)))))])
'plain-id
args)))
;; Export under the unprefixed name; every rest argument must be quotable?.
(provide/contract [rename id plain-id (->* () () #:rest (listof quotable?) function?)])))]))
;; Wrap a non-list value in a one-element list. Pairs and the empty list are
;; returned unchanged.
(define (listify item)
  (cond [(pair? item) item]
        [(null? item) item]
        [else (list item)]))
;; Raise exn:fail:contract if two sources share the same alias name.
;; Names are compared with memq. The error reports the first alias whose
;; name reappears later in the list.
(define (check-repeated-sources sources)
  (let check-next ([remaining sources]
                   [remaining-names (map source-alias-name sources)])
    (unless (null? remaining)
      (let ([source (car remaining)]
            [name (car remaining-names)])
        (when (memq name (cdr remaining-names))
          (raise-exn exn:fail:contract
                     (format "~a: source selected more than once: ~a ~s" 'from-clause name source)))
        (check-next (cdr remaining) (cdr remaining-names))))))
;; Raise exn:fail:contract if two columns share the same name.
;; Names are compared with memq. The error reports the first column whose
;; name reappears later in the list.
(define (check-repeated-columns columns)
  (define all-names (map column-name columns))
  (let scan ([cols columns] [names all-names])
    (cond [(null? cols) (void)]
          [(memq (car names) (cdr names))
           (raise-exn exn:fail:contract
                      (format "~a: column selected more than once: ~a ~s" 'from-clause (car names) (car cols)))]
          [else (scan (cdr cols) (cdr names))])))
;; Raise exn:fail:contract unless `attr` is among `columns`.
;; Membership uses `member` (equal?), not memq, unlike the other scope
;; checks in this file.
(define (check-attribute-in-scope name attr columns)
  (if (member attr columns)
      (void)
      (raise-exn exn:fail:contract
                 (format "~a: attribute not in scope: ~s" name attr))))
;; Raise exn:fail:contract unless `expr` is (memq-identical to) one of
;; `columns`.
(define (check-expression-in-scope name expr columns)
  (if (memq expr columns)
      (void)
      (raise-exn exn:fail:contract
                 (format "~a: expression not in scope: ~s" name expr))))
;; Raise exn:fail:contract unless `source` is (memq-identical to) one of
;; `sources`.
(define (check-source-in-scope name source sources)
  (if (memq source sources)
      (void)
      (raise-exn exn:fail:contract
                 (format "~a: source not in scope: ~s" name source))))
;; Validate a join:
;;   - no alias name appears twice among its (flattened) sources;
;;   - no column name appears twice among its local and imported columns;
;;   - if it has an on-clause, that expression is in scope and contains no
;;     aggregates. The join's operator is used as the name in error messages.
;; Only this join's own on-clause is checked here.
(define (check-join j)
  (let*-values ([(sources) (source->sources j)]
                [(local-cols imported-cols) (source->columns j)]
                [(columns) (append local-cols imported-cols)])
    (match j
      [(struct join (operator _ _ on-expr))
       (check-repeated-sources sources)
       (check-repeated-columns columns)
       (when on-expr
         (check-expression operator on-expr sources columns)
         (check-no-aggregates operator on-expr))])))
;; Recursively check that an expression only refers to columns and sources
;; in scope. `name` labels any resulting error.
;;   - attribute / expression aliases must be among `columns`;
;;   - function arguments are checked recursively, except that for an `in`
;;     whose second argument is a query, only the first argument is checked;
;;   - literals are always fine;
;;   - source aliases must be among `sources`;
;;   - lists are checked element by element;
;;   - anything else is accepted without checking.
(define (check-expression name expr sources columns)
  (define (check-sub sub)
    (check-expression name sub sources columns))
  (cond [(attribute-alias? expr)
         (check-attribute-in-scope name expr columns)]
        [(expression-alias? expr)
         (check-expression-in-scope name expr columns)]
        [(function? expr)
         (let* ([args (function-args expr)]
                [in-subquery? (and (eq? (function-op expr) 'in)
                                   (query? (cadr args)))])
           (for-each check-sub (if in-subquery? (list (car args)) args)))]
        [(literal? expr) (void)]
        [(source-alias? expr)
         (check-source-in-scope name expr sources)]
        [(list? expr) (for-each check-sub expr)]
        [else (void)]))
;; Raise exn:fail:snooze if `expr` is, or contains, an aggregate.
;; Function arguments and list elements are searched recursively. Any other
;; value is accepted.
(define (check-no-aggregates name expr)
  (match expr
    [(? aggregate?)
     (raise-exn exn:fail:snooze
                (format "~a: aggregates not allowed: ~s" name expr))]
    [(? function?)
     (for-each (lambda (arg) (check-no-aggregates name arg))
               (function-args expr))]
    [(? list?)
     (for-each (lambda (item) (check-no-aggregates name item))
               expr)]
    [_ (void)]))
;; Contract combinator: accepts either a single value satisfying `item/c`
;; or a list of such values.
(define opt-listof
  (lambda (item/c)
    (or/c item/c
          (listof item/c))))
;; Validation procedures for joins and query clauses.
(provide check-join
         check-what-clause
         check-distinct-clause
         check-where-clause
         check-group-clause
         check-order-clause
         check-having-clause
         check-no-aggregates)

;; Expression type predicates.
(provide boolean-expression?
         integer-expression?
         real-expression?
         string-expression?
         symbol-expression?
         time-utc-expression?
         time-tai-expression?
         numeric-expression?
         character-expression?
         temporal-expression?)

;; Syntax for defining SQL function constructors.
(provide define-function)

;; Source and argument expansion helpers, exported with contracts.
(provide/contract
 [source->sources (-> source? (listof source/c))]
 [source->columns (-> source? (values (listof column?) (listof column?)))]
 [make-default-what-argument (-> source? (opt-listof source/c))]
 [expand-what-argument (-> (opt-listof (or/c expression? source/c))
                           (values (listof column?)
                                   (opt-listof (or/c entity? type?))))]
 [expand-group-argument (-> (opt-listof (or/c expression? source/c))
                            (listof column?))])