#lang scheme/unit
(require mzlib/etc
srfi/13/string
(planet untyped/unlib:3/symbol)
"../base.ss"
"../era/era.ss"
"../generic/sql-data-sig.ss"
"../generic/sql-name-sig.ss"
"../generic/sql-update-helpers-sig.ss"
"../generic/sql-update-sig.ss")
(import sql-data^
sql-name^)
(export sql-update^
sql-update-helpers^)
;; create-sql : entity -> string
;;
;; Produces the DDL that creates ENTITY's table together with the sequence
;; its id column draws from. The sequence is named after the table with a
;; "_seq" suffix and is created first, because the table's id DEFAULT
;; (built by create-fields-sql) calls nextval on it.
(define (create-sql entity)
  (let* ([table (entity-table-name entity)]
         [seq   (symbol-append table '_seq)])
    (format "CREATE SEQUENCE ~a; CREATE TABLE ~a (~a);"
            (escape-name seq)
            (escape-name table)
            (create-fields-sql entity))))
;; drop-sql : (U entity symbol) -> string
;;
;; Produces SQL that drops a table and its companion "<table>_seq" sequence.
;; IF EXISTS makes the statements tolerate either object being absent.
;; TABLE may be an entity (its table name is used) or a bare table-name
;; symbol. Anything else raises exn:fail:snooze.
(define (drop-sql table)
  (define table-name
    (cond [(entity? table) (entity-table-name table)]
          [(symbol? table) table]
          [else (raise-exn exn:fail:snooze
                  (format "Expected (U entity symbol), received ~s" table))]))
  (define sequence-name (symbol-append table-name '_seq))
  ;; Drop the table before the sequence: the table's id column DEFAULT
  ;; (see create-fields-sql) calls nextval on the sequence.
  (format "DROP TABLE IF EXISTS ~a; DROP SEQUENCE IF EXISTS ~a;"
          (escape-name table-name)
          (escape-name sequence-name)))
;; insert-sql : persistent-struct [boolean] -> string
;;
;; Single-row convenience wrapper around insert-multiple-sql. It wraps STRUCT
;; in a one-element list and returns the lone statement produced.
(define (insert-sql struct [preserve-ids? #f])
  (let ([statements (insert-multiple-sql (list struct) preserve-ids?)])
    (car statements)))
;; insert-multiple-sql : (listof persistent-struct) [boolean] -> (listof string)
;;
;; Builds a single multi-row "INSERT INTO t (cols) VALUES (...), (...);"
;; statement covering every struct in STRUCTS, returned as a one-element list.
;; - An empty list yields null.
;; - A non-list raises exn:fail:contract.
;;
;; All structs must share the first struct's entity (check-insert-structs).
;; Unless PRESERVE-IDS? is true:
;; - every struct must have an #f id;
;; - the first two columns (id and revision) are omitted, so the column
;;   DEFAULTs from create-fields-sql supply them.
(define (insert-multiple-sql structs [preserve-ids? #f])
  ;; Keeps or drops the leading id/revision entries of a per-column list.
  (define (insertable items)
    (if preserve-ids? items (cddr items)))
  ;; Renders one struct as a parenthesised VALUES tuple.
  (define (row->sql row)
    (string-append "(" (string-join (insertable (column-values row)) ", ") ")"))
  (cond [(not (list? structs))
         (raise-exn exn:fail:contract
           (format "Expected (listof persistent-struct), received ~s" structs))]
        [(null? structs) null]
        [else
         (let ([entity (struct-entity (car structs))])
           (check-insert-structs entity structs preserve-ids?)
           (let* ([table-sql  (escape-name (entity-table-name entity))]
                  [column-sql (string-join (insertable (column-names entity)) ", ")]
                  [values-sql (string-join (map row->sql structs) ", ")])
             (list (string-append "INSERT INTO " table-sql
                                  " (" column-sql ") VALUES " values-sql ";"))))]))
;; update-sql : persistent-struct -> string
;;
;; Builds "UPDATE t SET col = val, ... WHERE id = <id>;" for STRUCT.
;; Every attribute except the first (id) is written, so the second column
;; (revision, per create-fields-sql) is included in the SET clause.
;; Raises exn:fail:snooze when the struct's id is #f.
(define (update-sql struct)
  (define entity (struct-entity struct))
  (define id (struct-id struct))
  (unless id
    (raise-exn exn:fail:snooze
      (format "ID must be non-#f to perform an UPDATE: ~s" struct)))
  (let* ([table-sql  (escape-name (entity-table-name entity))]
         [assign     (lambda (attr value)
                       (let ([column (attribute-column-name attr)]
                             [kind   (attribute-type attr)])
                         (string-append (escape-name column) " = "
                                        (escape-value kind value))))]
         [set-clause (string-join (map assign
                                       (cdr (entity-attributes entity))
                                       (cdr (struct-attributes struct)))
                                  ", ")])
    (string-append "UPDATE " table-sql " SET " set-clause
                   " WHERE " (escape-name 'id) " = " (escape-value type:id id) ";")))
;; delete-sql : guid -> string
;;
;; Builds "DELETE FROM t WHERE id = <id>;" for the row GUID identifies.
;; The table name is taken from the guid's entity.
(define (delete-sql guid)
  (let ([table-sql (escape-name (entity-table-name (guid-entity guid)))]
        [id-sql    (escape-value type:id (guid-id guid))])
    (string-append "DELETE FROM " table-sql
                   " WHERE " (escape-name 'id) " = " id-sql ";")))
;; check-insert-structs : entity (listof persistent-struct) [boolean] -> void
;;
;; Validates a batch before insertion. Raises exn:fail:snooze if any struct
;; is not an instance of ENTITY, or (unless SKIP-ID-CHECK?) if any struct
;; already has a non-#f id. Structs are checked in order; the first failure
;; found is the one reported.
(define (check-insert-structs entity structs [skip-id-check? #f])
  (define (check-one candidate)
    (cond [(not (eq? entity (struct-entity candidate)))
           (raise-exn exn:fail:snooze
             (format "All structs must be of the same type: ~s" structs))]
          [(and (struct-id candidate) (not skip-id-check?))
           (raise-exn exn:fail:snooze
             (format "ID must be #f to perform an INSERT: ~s" candidate))]
          [else (void)]))
  (for-each check-one structs))
;; create-fields-sql : entity -> string
;;
;; Produces the comma-separated column definitions for ENTITY's CREATE TABLE
;; statement. The first two columns are generated here rather than from the
;; entity's attribute types:
;;   - id       : INTEGER PRIMARY KEY DEFAULT nextval('<table>_seq')
;;   - revision : INTEGER NOT NULL DEFAULT 0
;; The remaining attributes (all after the first two) are rendered as
;; "<column> <type-definition-sql>".
(define (create-fields-sql entity)
  (define seq-name (symbol-append (entity-table-name entity) '_seq))
  (define id-column
    (string-append (escape-name 'id) " INTEGER PRIMARY KEY "
                   "DEFAULT nextval('" (escape-name seq-name) "')"))
  ;; The leading space separates the column name from its type; without it
  ;; the output would read e.g. `revisionINTEGER`, which is invalid DDL.
  (define revision-column
    (string-append (escape-name 'revision) " INTEGER NOT NULL DEFAULT 0"))
  (define (attribute-column attr)
    (string-append (escape-name (attribute-column-name attr))
                   " "
                   (type-definition-sql (attribute-type attr))))
  (string-join (list* id-column
                      revision-column
                      (map attribute-column (cddr (entity-attributes entity))))
               ", "))
;; column-names : entity -> (listof string)
;;
;; The escaped column name of every attribute of ENTITY, in attribute order.
(define (column-names entity)
  (define (attr->column-sql attr)
    (escape-name (attribute-column-name attr)))
  (map attr->column-sql (entity-attributes entity)))
;; column-values : persistent-struct -> (listof string)
;;
;; The SQL literal for each attribute value of STRUCT, in attribute order.
;; Each value is escaped according to its attribute's declared type.
(define (column-values struct)
  (let ([attrs        (entity-attributes (struct-entity struct))]
        [field-values (struct-attributes struct)])
    (map (lambda (attr value)
           (escape-value (attribute-type attr) value))
         attrs
         field-values)))
;; type-definition-sql : type -> string
;;
;; Renders a column definition (everything after the column name) for TYPE,
;; as space-separated parts:
;;   1. the base SQL type;
;;   2. "NOT NULL", when the type disallows nulls;
;;   3. a DEFAULT clause built from the type's default value.
;; Raises exn:fail:snooze for an unrecognised type.
(define (type-definition-sql type)
  ;; The base column type for TYPE.
  (define base-sql
    (cond [(guid-type? type)
           ;; Foreign key. SQL syntax is REFERENCES <table> (<column>); the
           ;; dotted form "<x>.id" would be parsed as schema x, table id.
           ;; Use entity-table-name so the name matches the table that
           ;; create-sql actually creates.
           (format "INTEGER REFERENCES ~a (~a)"
                   (escape-name (entity-table-name (guid-type-entity type)))
                   (escape-name 'id))]
          [(boolean-type? type) "BOOLEAN"]
          [(integer-type? type) "INTEGER"]
          [(real-type? type) "REAL"]
          [(string-type? type) (string-type-definition-sql (string-type-max-length type))]
          [(symbol-type? type) (string-type-definition-sql (symbol-type-max-length type))]
          [(time-tai-type? type) "TIMESTAMP WITHOUT TIME ZONE"]
          [(time-utc-type? type) "TIMESTAMP WITHOUT TIME ZONE"]
          [else (raise-exn exn:fail:snooze (format "Unrecognised type: ~a" type))]))
  (define null-sql
    (if (type-allows-null? type) null (list "NOT NULL")))
  (define default-sql
    (string-append "DEFAULT " (escape-value type (type-default type))))
  (string-join (append (list base-sql) null-sql (list default-sql))
               " "))
;; string-type-definition-sql : (U natural #f) -> string
;;
;; SQL type for a textual column: bounded CHARACTER VARYING when MAX-LENGTH
;; is given, otherwise unbounded TEXT.
(define (string-type-definition-sql max-length)
  (cond [max-length => (lambda (limit)
                         (format "CHARACTER VARYING (~a)" limit))]
        [else "TEXT"]))