#lang racket/base
(require racket/class
racket/list
racket/math
ffi/unsafe
"../generic/interfaces.rkt"
"../generic/prepared.rkt"
"../generic/sql-data.rkt"
"../generic/sql-convert.rkt"
"ffi.rkt"
"ffi-constants.rkt"
"dbsystem.rkt")
(provide connection%
handle-status*
dbsystem)
;; connection% : ODBC connection class implementing connection<%> on top of
;; the transactions% base class.
(define connection%
(class* transactions% (connection<%>)
;; Private constructor state:
;;   db             - connection handle (freed as SQL_HANDLE_DBC in disconnect);
;;                    set to #f once disconnected
;;   env            - environment handle (freed as SQL_HANDLE_ENV in disconnect)
;;   notice-handler - procedure called with (sqlstate message) for notices
;;   char-mode      - one of 'wchar, 'utf-8, 'latin-1; selects string marshalling
(init-private db
env
notice-handler
char-mode)
(init strict-parameter-types?)
;; Weak table whose keys are the prepared statements created by this connection.
(define statement-table (make-weak-hasheq))
;; NOTE(review): not referenced elsewhere in the visible class body.
(define lock (make-semaphore 1))
;; True only when strict parameter types were requested and SQLGetFunctions
;; reports SQL_API_SQLDESCRIBEPARAM as supported.
(define use-describe-param?
(and strict-parameter-types?
(let-values ([(status supported?) (SQLGetFunctions db SQL_API_SQLDESCRIBEPARAM)])
(handle-status 'odbc-connect status db)
supported?)))
;; Methods and state provided by the superclass.
(inherit call-with-lock
call-with-lock*
add-delayed-call!
check-valid-tx-status)
(inherit-field tx-status)
;; get-db : symbol -> handle
;; Returns the connection handle. When the connection has already been
;; closed (db is #f), reports a not-connected error on behalf of fsym.
(define/public (get-db fsym)
  (when (not db)
    (error/not-connected fsym))
  db)
;; get-dbsystem : -> dbsystem
;; Returns the module-level dbsystem value for this backend.
(define/public (get-dbsystem)
  dbsystem)
;; connected? : -> boolean
;; #t while the connection handle is still held, #f after disconnect.
(define/override (connected?)
  (if db #t #f))
;; query : symbol (or string statement-binding) -> recordset or simple-result
;; Runs stmt under the connection lock (after validating the transaction
;; status), then notifies the statement that execution finished and packages
;; the outcome: a recordset when the statement has result columns, otherwise
;; an empty simple-result.
(define/public (query fsym stmt)
  (define (run-locked)
    (check-valid-tx-status fsym)
    (query1 fsym stmt))
  (define-values (bound-stmt field-dvecs result-rows)
    (call-with-lock fsym run-locked))
  (statement:after-exec bound-stmt)
  (if (pair? field-dvecs)
      (recordset (map field-dvec->field-info field-dvecs) result-rows)
      (simple-result '())))
;; query1 : symbol (or string statement-binding)
;;          -> (values statement-binding dvecs rows)
;; A string is prepared (close-on-exec) and bound with no parameters; a
;; statement-binding is used as given. Verifies statement ownership and that
;; every result column type is supported before executing.
(define/private (query1 fsym stmt)
  (define binding
    (cond [(string? stmt)
           (send (prepare1 fsym stmt #t) bind fsym null)]
          [(statement-binding? stmt)
           stmt]))
  (define pst (statement-binding-pst binding))
  (define params (statement-binding-params binding))
  (send pst check-owner fsym this binding)
  ;; Reject unsupported result types up front, before touching the driver.
  (for ([dvec (in-list (send pst get-result-dvecs))])
    (let ([typeid (field-dvec->typeid dvec)])
      (unless (supported-typeid? typeid)
        (error/unsupported-type fsym typeid))))
  (define-values (dvecs rows) (query1:inner fsym pst params))
  (values binding dvecs rows))
;; query1:inner : symbol prepared-statement (listof param) -> (values dvecs rows)
;; Binds the parameters, executes the statement, fetches all rows when the
;; statement has result columns (rows is #f otherwise), then closes the
;; cursor and resets the parameter bindings.
(define/private (query1:inner fsym pst params)
  (define dbc (get-db fsym))
  (define hstmt (send pst get-handle))
  ;; Parameters are numbered from 1.
  (define bound-buffers
    (for/list ([index (in-naturals 1)]
               [param (in-list params)]
               [param-typeid (in-list (send pst get-param-typeids))])
      (load-param fsym dbc hstmt index param param-typeid)))
  (handle-status fsym (SQLExecute hstmt) hstmt)
  ;; Reference the bound buffers after SQLExecute so they stay reachable
  ;; through the call.
  (strong-void bound-buffers)
  (define dvecs (send pst get-result-dvecs))
  (define rows
    (if (pair? dvecs)
        (fetch* fsym hstmt (map field-dvec->typeid dvecs))
        #f))
  (handle-status fsym (SQLFreeStmt hstmt SQL_CLOSE) hstmt)
  (handle-status fsym (SQLFreeStmt hstmt SQL_RESET_PARAMS) hstmt)
  (values dvecs rows))
;; load-param : symbol handle handle nat any typeid -> buffer(s)
;; Marshals param into a buffer and binds it as input parameter number i of
;; stmt. typeid is the parameter type reported by the driver, or
;; SQL_UNKNOWN_TYPE, in which case a SQL type is chosen from the Racket value.
;; Returns the buffers passed to SQLBindParameter (data buffer consed with the
;; length buffer, or just the length buffer for NULL) so the caller can keep
;; them reachable until the statement has executed.
(define/private (load-param fsym db stmt i param typeid)
  ;; bind : ctype sqltype (or buffer #f) -> buffer(s)
  ;; A #f buf binds SQL NULL via the length/indicator buffer.
  (define (bind ctype sqltype buf)
    (let* ([lenbuf
            (int->buffer (if buf (bytes-length buf) SQL_NULL_DATA))]
           [status
            (SQLBindParameter stmt i SQL_PARAM_INPUT ctype sqltype 0 0 buf lenbuf)])
      (handle-status fsym status stmt)
      (if buf (cons buf lenbuf) lenbuf)))
  (define unknown-type? (= typeid SQL_UNKNOWN_TYPE))
  (cond [(string? param)
         (case char-mode
           ((wchar)
            (bind SQL_C_WCHAR (if unknown-type? SQL_WVARCHAR typeid)
                  (case WCHAR-SIZE
                    ((2) (cpstr2 param))
                    ((4) (cpstr4 param)))))
           ((utf-8)
            (bind SQL_C_CHAR (if unknown-type? SQL_VARCHAR typeid)
                  (copy-buffer (string->bytes/utf-8 param))))
           ((latin-1)
            ;; Characters outside Latin-1 are replaced by #\?.
            (bind SQL_C_CHAR (if unknown-type? SQL_VARCHAR typeid)
                  (copy-buffer (string->bytes/latin-1 param (char->integer #\?))))))]
        [(bytes? param)
         (bind SQL_C_BINARY (if unknown-type? SQL_BINARY typeid)
               (copy-buffer param))]
        ;; A pair is (mantissa . exponent-byte): packed as precision, scale,
        ;; sign (1 = non-negative, 0 = negative) followed by the magnitude as
        ;; four little-endian 32-bit words.
        [(pair? param)
         (bind SQL_C_NUMERIC typeid
               (copy-buffer
                (let ([ma (car param)]
                      [ex (cdr param)])
                  (apply bytes-append
                         (bytes (if (zero? ma) 1 (+ 1 (order-of-magnitude (abs ma))))
                                ex
                                (if (negative? ma) 0 1))
                         (let loop ([i 0] [ma (abs ma)])
                           (if (< i 4)
                               (let-values ([(q r) (quotient/remainder ma (expt 2 32))])
                                 (cons (integer->integer-bytes r 4 #f #f)
                                       (loop (add1 i) q)))
                               null))))))]
        [(real? param)
         (cond [(or (= typeid SQL_NUMERIC) (= typeid SQL_DECIMAL))
                (bind SQL_C_CHAR typeid
                      (copy-buffer (marshal-decimal fsym i param)))]
               [(or (and unknown-type? (int32? param))
                    (= typeid SQL_INTEGER)
                    (= typeid SQL_SMALLINT)
                    (= typeid SQL_BIGINT)
                    (= typeid SQL_TINYINT))
                (if (= typeid SQL_BIGINT)
                    (bind SQL_C_SBIGINT SQL_BIGINT
                          (copy-buffer (integer->integer-bytes param 8 #t)))
                    (bind SQL_C_LONG (if unknown-type? SQL_INTEGER typeid)
                          (copy-buffer (integer->integer-bytes param 4 #t))))]
               [else
                (bind SQL_C_DOUBLE (if unknown-type? SQL_DOUBLE typeid)
                      (copy-buffer
                       (real->floating-point-bytes (exact->inexact param) 8)))])]
        [(boolean? param)
         (bind SQL_C_LONG SQL_BIT
               (copy-buffer (int->buffer (if param 1 0))))]
        [(sql-date? param)
         (bind SQL_C_TYPE_DATE SQL_TYPE_DATE
               (copy-buffer
                (let* ([x param]
                       [y (sql-date-year x)]
                       [m (sql-date-month x)]
                       [d (sql-date-day x)])
                  (bytes-append (integer->integer-bytes y 2 #t)
                                (integer->integer-bytes m 2 #f)
                                (integer->integer-bytes d 2 #f)))))]
        [(sql-time? param)
         (bind SQL_C_TYPE_TIME SQL_TYPE_TIME
               (copy-buffer
                (let* ([x param]
                       [h (sql-time-hour x)]
                       [m (sql-time-minute x)]
                       [s (sql-time-second x)])
                  (bytes-append (integer->integer-bytes h 2 #f)
                                (integer->integer-bytes m 2 #f)
                                (integer->integer-bytes s 2 #f)))))]
        [(sql-timestamp? param)
         (bind SQL_C_TYPE_TIMESTAMP
               (if unknown-type? SQL_TYPE_TIMESTAMP typeid)
               (copy-buffer
                (let ([x param])
                  (bytes-append
                   ;; The year is a signed 16-bit field, as in the date case
                   ;; above; the remaining fields are unsigned.
                   (integer->integer-bytes (sql-timestamp-year x) 2 #t)
                   (integer->integer-bytes (sql-timestamp-month x) 2 #f)
                   (integer->integer-bytes (sql-timestamp-day x) 2 #f)
                   (integer->integer-bytes (sql-timestamp-hour x) 2 #f)
                   (integer->integer-bytes (sql-timestamp-minute x) 2 #f)
                   (integer->integer-bytes (sql-timestamp-second x) 2 #f)
                   (integer->integer-bytes (sql-timestamp-nanosecond x) 4 #f)))))]
        [(sql-null? param)
         (bind SQL_C_CHAR SQL_VARCHAR #f)]
        [else (error/internal fsym "cannot convert to typeid ~a: ~e" typeid param)]))
;; fetch* : symbol handle (listof typeid) -> (listof row)
;; Fetches rows until fetch reports no more data, returning them in fetch
;; order. A single 50-byte scratch buffer is shared by all column reads.
(define/private (fetch* fsym stmt result-typeids)
  (define scratch (make-bytes 50))
  (let gather ([acc null])
    (let ([row (fetch fsym stmt result-typeids scratch)])
      (if row
          (gather (cons row acc))
          (reverse acc)))))
;; fetch : symbol handle (listof typeid) bytes -> (or vector #f)
;; Advances the cursor by one row. Returns #f when SQLFetch reports
;; SQL_NO_DATA, otherwise a vector holding one value per result column.
;; SQL_SUCCESS_WITH_INFO also carries a row: its diagnostics are routed
;; through handle-status and the row is then read like any other. Errors
;; are raised by handle-status; any other status is reported as an internal
;; error instead of being returned as if it were a row.
(define/private (fetch fsym stmt result-typeids scratchbuf)
  (let ([s (SQLFetch stmt)])
    (cond [(= s SQL_NO_DATA) #f]
          [else
           ;; Raises on SQL_ERROR; forwards notices on SQL_SUCCESS_WITH_INFO.
           (handle-status fsym s stmt)
           (cond [(or (= s SQL_SUCCESS) (= s SQL_SUCCESS_WITH_INFO))
                  (let* ([column-count (length result-typeids)]
                         [vec (make-vector column-count)])
                    ;; Column numbers passed to get-column are 1-based.
                    (for ([i (in-range column-count)]
                          [typeid (in-list result-typeids)])
                      (vector-set! vec i (get-column fsym stmt (add1 i) typeid scratchbuf)))
                    vec)]
                 [else
                  (error/internal fsym "unexpected SQLFetch status: ~e" s)])])))
;; get-column : symbol handle nat typeid bytes -> any
;; Reads column i (1-based) of the current row of stmt and converts it to a
;; Racket value according to typeid. SQL NULL is returned as sql-null.
;; scratchbuf is a reusable buffer; larger buffers are allocated on demand.
(define/private (get-column fsym stmt i typeid scratchbuf)
  ;; Reads a fixed-size numeric value into scratchbuf and decodes the first
  ;; `size` bytes with `convert`.
  (define-syntax-rule (get-num size ctype convert convert-arg ...)
    (let-values ([(status ind) (SQLGetData stmt i ctype scratchbuf 0)])
      (handle-status fsym status stmt)
      (cond [(= ind SQL_NULL_DATA) sql-null]
            [else (convert scratchbuf convert-arg ... 0 size)])))
  (define (get-int size ctype)
    (get-num size ctype integer-bytes->integer #t (system-big-endian?)))
  (define (get-real ctype)
    (get-num 8 ctype floating-point-bytes->real (system-big-endian?)))
  ;; Reads a packed structure and splits it into unsigned integers of the
  ;; given byte sizes (1, 2 or 4). Returns sql-null for NULL, else a list.
  (define (get-int-list sizes ctype)
    (let* ([buflen (apply + sizes)]
           [buf (if (<= buflen (bytes-length scratchbuf)) scratchbuf (make-bytes buflen))])
      (let-values ([(status ind) (SQLGetData stmt i ctype buf 0)])
        (handle-status fsym status stmt)
        (cond [(= ind SQL_NULL_DATA) sql-null]
              [else (let ([in (open-input-bytes buf)])
                      (for/list ([size (in-list sizes)])
                        (case size
                          ((1) (read-byte in))
                          ((2) (integer-bytes->integer (read-bytes 2 in) #f))
                          ((4) (integer-bytes->integer (read-bytes 4 in) #f))
                          (else (error/internal
                                 'get-int-list "bad size: ~e" size)))))]))))
  ;; Reads variable-length data. ntlen is the size of the terminator the
  ;; driver appends (0 for binary data). `convert` receives the buffer, the
  ;; data length, and whether the buffer was freshly built from chunks.
  (define (get-varbuf ctype ntlen convert)
    (define (loop buf start rchunks)
      (let-values ([(status len-or-ind) (SQLGetData stmt i ctype buf start)])
        (handle-status fsym status stmt #:ignore-ok/info? #t)
        (cond [(= len-or-ind SQL_NULL_DATA) sql-null]
              [(= len-or-ind SQL_NO_TOTAL)
               ;; Remaining length unknown: stash this chunk (minus the
               ;; terminator) and refill the same buffer from the start.
               (let* ([data-end (- (bytes-length buf) ntlen)])
                 (loop buf 0 (cons (subbytes buf 0 data-end) rchunks)))]
              [else
               (let ([len (+ start len-or-ind)])
                 (cond [(<= 0 len (- (bytes-length buf) ntlen))
                        ;; Everything fits in buf.
                        (cond [(pair? rchunks)
                               (let* ([chunk (subbytes buf 0 (+ len ntlen))]
                                      [chunks (append (reverse rchunks) (list chunk))]
                                      [complete (apply bytes-append chunks)])
                                 (convert complete (- (bytes-length complete) ntlen) #t))]
                              [else
                               (convert buf len #f)])]
                       [else
                        ;; Data was truncated: allocate a buffer of the full
                        ;; size, keep what was read, and fetch the rest.
                        (let* ([len-got (- (bytes-length buf) ntlen)]
                               [newbuf (make-bytes (+ len ntlen))])
                          (bytes-copy! newbuf 0 buf start len-got)
                          (loop newbuf len-got rchunks))]))])))
    (loop scratchbuf 0 null))
  (define (get-string/latin-1)
    (get-varbuf SQL_C_CHAR 1
                (lambda (buf len _fresh?)
                  (bytes->string/latin-1 buf #f 0 len))))
  (define (get-string/utf-8)
    (get-varbuf SQL_C_CHAR 1
                (lambda (buf len _fresh?)
                  (bytes->string/utf-8 buf #f 0 len))))
  (define (get-string)
    (case char-mode
      ((wchar)
       (get-varbuf SQL_C_WCHAR WCHAR-SIZE (case WCHAR-SIZE ((2) mkstr2) ((4) mkstr4))))
      ((utf-8)
       (get-string/utf-8))
      ((latin-1)
       (get-string/latin-1))))
  (define (get-bytes)
    (get-varbuf SQL_C_BINARY 0
                (lambda (buf len fresh?)
                  ;; Reuse a freshly built buffer when it is exactly the data;
                  ;; otherwise copy, since buf may be the shared scratch buffer.
                  (if (and fresh? (= len (bytes-length buf)))
                      buf
                      (subbytes buf 0 len)))))
  (cond [(or (= typeid SQL_CHAR)
             (= typeid SQL_VARCHAR)
             (= typeid SQL_LONGVARCHAR)
             (= typeid SQL_WCHAR)
             (= typeid SQL_WVARCHAR)
             (= typeid SQL_WLONGVARCHAR))
         (get-string)]
        [(or (= typeid SQL_DECIMAL)
             (= typeid SQL_NUMERIC))
         ;; Fields: precision, scale, sign, then four 32-bit magnitude words
         ;; (least significant first).
         (let ([fields (get-int-list '(1 1 1 4 4 4 4) SQL_C_NUMERIC)])
           (cond [(list? fields)
                  (let* ([scale (second fields)]
                         [sign (case (third fields) ((0) -1) ((1) 1))]
                         [ma (let loop ([lst (cdddr fields)])
                               (if (pair? lst)
                                   (+ (* (loop (cdr lst)) (expt 2 32))
                                      (car lst))
                                   0))])
                    (* sign ma (expt 10 (- scale))))]
                 [(sql-null? fields) sql-null]))]
        [(or (= typeid SQL_SMALLINT)
             (= typeid SQL_INTEGER)
             (= typeid SQL_TINYINT))
         (get-int 4 SQL_C_LONG)]
        [(or (= typeid SQL_BIGINT))
         (get-int 8 SQL_C_SBIGINT)]
        [(or (= typeid SQL_REAL)
             (= typeid SQL_FLOAT)
             (= typeid SQL_DOUBLE))
         (get-real SQL_C_DOUBLE)]
        [(or (= typeid SQL_BIT))
         ;; NULL must stay sql-null; any value other than 0 or 1 is an error
         ;; rather than a silently returned message string.
         (let ([bit (get-int 4 SQL_C_LONG)])
           (cond [(sql-null? bit) sql-null]
                 [(eqv? bit 0) #f]
                 [(eqv? bit 1) #t]
                 [else (error 'get-column "internal error: SQL_BIT")]))]
        [(or (= typeid SQL_BINARY)
             (= typeid SQL_VARBINARY))
         (get-bytes)]
        [(= typeid SQL_TYPE_DATE)
         (let ([fields (get-int-list '(2 2 2) SQL_C_TYPE_DATE)])
           (cond [(list? fields) (apply sql-date fields)]
                 [(sql-null? fields) sql-null]))]
        [(= typeid SQL_TYPE_TIME)
         ;; The time structure has no fractional seconds or time zone.
         (let ([fields (get-int-list '(2 2 2) SQL_C_TYPE_TIME)])
           (cond [(list? fields) (apply sql-time (append fields (list 0 #f)))]
                 [(sql-null? fields) sql-null]))]
        [(= typeid SQL_TYPE_TIMESTAMP)
         (let ([fields (get-int-list '(2 2 2 2 2 2 4) SQL_C_TYPE_TIMESTAMP)])
           (cond [(list? fields) (apply sql-timestamp (append fields (list #f)))]
                 [(sql-null? fields) sql-null]))]
        ;; Any other type is read as a string.
        [else (get-string)]))
;; prepare : symbol string boolean -> prepared-statement
;; Prepares stmt while holding the connection lock, after checking that the
;; current transaction (if any) is still valid.
(define/public (prepare fsym stmt close-on-exec?)
  (define (prepare-locked)
    (check-valid-tx-status fsym)
    (prepare1 fsym stmt close-on-exec?))
  (call-with-lock fsym prepare-locked))
;; prepare1 : symbol string boolean -> prepared-statement
;; Allocates a statement handle, prepares sql on it, and records the
;; parameter and result-column descriptions in a new prepared-statement%
;; object, which is registered in statement-table. If preparing or
;; describing the statement raises, the handle is freed before the exception
;; is re-raised so that it does not leak.
(define/private (prepare1 fsym sql close-on-exec?)
  (let*-values ([(db) (get-db fsym)]
                [(status stmt) (SQLAllocHandle SQL_HANDLE_STMT db)])
    (handle-status fsym status db)
    (let-values ([(param-typeids result-dvecs)
                  (with-handlers ([(lambda (e) #t)
                                   (lambda (e)
                                     (SQLFreeHandle SQL_HANDLE_STMT stmt)
                                     (raise e))])
                    (handle-status fsym (SQLPrepare stmt sql) stmt)
                    ;; Parameters are described before result columns.
                    (values (describe-params fsym stmt)
                            (describe-result-columns fsym stmt)))])
      (let ([pst (new prepared-statement%
                      (handle stmt)
                      (close-on-exec? close-on-exec?)
                      (owner this)
                      (param-typeids param-typeids)
                      (result-dvecs result-dvecs))])
        (hash-set! statement-table pst #t)
        pst))))
;; describe-params : symbol handle -> (listof typeid)
;; One type id per statement parameter. The driver is asked for the type
;; only when use-describe-param? is set; otherwise every parameter is
;; reported as SQL_UNKNOWN_TYPE.
(define/private (describe-params fsym stmt)
  (define-values (count-status param-count) (SQLNumParams stmt))
  (handle-status fsym count-status stmt)
  (for/list ([index (in-range 1 (add1 param-count))])
    (if use-describe-param?
        (let-values ([(status type size digits nullable)
                      (SQLDescribeParam stmt index)])
          (handle-status fsym status stmt)
          type)
        SQL_UNKNOWN_TYPE)))
;; describe-result-columns : symbol handle -> (listof field-dvec)
;; Returns one vector #(name type size digits) per result column of stmt,
;; in column order. A 200-byte buffer is handed to SQLDescribeCol.
(define/private (describe-result-columns fsym stmt)
  (define-values (count-status column-count) (SQLNumResultCols stmt))
  (define name-buffer (make-bytes 200))
  (handle-status fsym count-status stmt)
  (for/list ([column (in-range 1 (add1 column-count))])
    (let-values ([(status name type size digits nullable)
                  (SQLDescribeCol stmt column name-buffer)])
      (handle-status fsym status stmt)
      (vector name type size digits))))
;; disconnect : -> void
;; Closes the connection if it is still open: clears the db/env/statement
;; state first, frees every statement recorded in statement-table, then
;; disconnects and frees the connection and environment handles. Does
;; nothing when already disconnected.
(define/public (disconnect)
  (define (shut-down)
    (define old-db db)
    (define old-env env)
    (when old-db
      (let ([open-statements (hash-map statement-table (lambda (pst _flag) pst))])
        ;; Mark the connection closed before releasing anything.
        (set! db #f)
        (set! env #f)
        (set! statement-table #f)
        (for-each (lambda (pst) (free-statement* 'disconnect pst))
                  open-statements)
        (handle-status 'disconnect (SQLDisconnect old-db) old-db)
        (handle-status 'disconnect (SQLFreeHandle SQL_HANDLE_DBC old-db))
        (handle-status 'disconnect (SQLFreeHandle SQL_HANDLE_ENV old-env))
        (void))))
  (call-with-lock* 'disconnect shut-down shut-down #f))
;; free-statement : prepared-statement -> void
;; Releases pst's statement handle via call-with-lock*; the same thunk is
;; supplied for both procedure arguments.
(define/public (free-statement pst)
  (let ([release (lambda () (free-statement* 'free-statement pst))])
    (call-with-lock* 'free-statement release release #f)))
;; free-statement* : symbol prepared-statement -> void
;; If pst still has a handle, detaches it from pst, closes its cursor and
;; frees it. Errors are always reported under the name 'free-statement.
(define/private (free-statement* fsym pst)
  (define hstmt (send pst get-handle))
  (when hstmt
    (send pst set-handle #f)
    (handle-status 'free-statement (SQLFreeStmt hstmt SQL_CLOSE) hstmt)
    (handle-status 'free-statement (SQLFreeHandle SQL_HANDLE_STMT hstmt) hstmt)
    (void)))
;; transaction-status : symbol -> tx-status
;; Returns the inherited tx-status field while holding the lock. get-db is
;; called only for its not-connected check.
(define/public (transaction-status fsym)
  (define (read-status)
    (get-db fsym)
    tx-status)
  (call-with-lock fsym read-status))
;; start-transaction : symbol (or symbol #f) -> void
;; Starts a transaction: picks the requested isolation level (falling back to
;; the driver default, or serializable when the default is reported as 0),
;; checks it against the levels the driver offers, applies it, and switches
;; autocommit off. Fails if a transaction is already in progress.
(define/public (start-transaction fsym isolation)
  ;; Queries one SQLGetInfo value, checking the status before returning it.
  (define (get-info dbc key)
    (let-values ([(status value) (SQLGetInfo dbc key)])
      (handle-status fsym status dbc)
      value))
  (define (begin-tx)
    (define dbc (get-db fsym))
    (when tx-status
      (error/already-in-tx fsym))
    (let* ([ok-levels (get-info dbc SQL_TXN_ISOLATION_OPTION)]
           [default-level (get-info dbc SQL_DEFAULT_TXN_ISOLATION)]
           [requested-level
            (cond [(eq? isolation 'serializable) SQL_TXN_SERIALIZABLE]
                  [(eq? isolation 'repeatable-read) SQL_TXN_REPEATABLE_READ]
                  [(eq? isolation 'read-committed) SQL_TXN_READ_COMMITTED]
                  [(eq? isolation 'read-uncommitted) SQL_TXN_READ_UNCOMMITTED]
                  [(zero? default-level) SQL_TXN_SERIALIZABLE]
                  [else default-level])])
      ;; ok-levels is a bitmask of the supported isolation levels.
      (when (zero? (bitwise-and requested-level ok-levels))
        (uerror fsym "requested isolation level ~a is not available" isolation))
      (handle-status fsym
                     (SQLSetConnectAttr dbc SQL_ATTR_TXN_ISOLATION requested-level)
                     dbc))
    (handle-status fsym
                   (SQLSetConnectAttr dbc SQL_ATTR_AUTOCOMMIT SQL_AUTOCOMMIT_OFF)
                   dbc)
    (set! tx-status #t)
    (void))
  (call-with-lock fsym begin-tx))
;; end-transaction : symbol (or 'commit 'rollback) -> void
;; Commits or rolls back the current transaction, then turns autocommit back
;; on and clears tx-status. The transaction-validity check is skipped for
;; rollback.
(define/public (end-transaction fsym mode)
  (define (finish-tx)
    (when (not (eq? mode 'rollback))
      (check-valid-tx-status fsym))
    (let* ([dbc (get-db fsym)]
           [completion-type
            (case mode
              ((commit) SQL_COMMIT)
              ((rollback) SQL_ROLLBACK))])
      (handle-status fsym (SQLEndTran dbc completion-type) dbc)
      (handle-status fsym
                     (SQLSetConnectAttr dbc SQL_ATTR_AUTOCOMMIT SQL_AUTOCOMMIT_ON)
                     dbc)
      (set! tx-status #f)
      (void)))
  (call-with-lock fsym finish-tx))
;; get-tables : symbol catalog schema table -> recordset
;; Runs SQLTables with the given arguments on a temporary statement handle
;; and returns the resulting rows as a recordset. The statement handle is
;; passed to handle-status so that SQLTables failures report the driver's
;; diagnostic, and the handle is freed if anything raises before the normal
;; cleanup is reached.
(define/public (get-tables fsym catalog schema table)
  (define-values (dvecs rows)
    (call-with-lock fsym
      (lambda ()
        (let* ([db (get-db fsym)]
               [stmt (let-values ([(status stmt) (SQLAllocHandle SQL_HANDLE_STMT db)])
                       (handle-status fsym status db)
                       stmt)])
          (let-values ([(result-dvecs rows)
                        (with-handlers ([(lambda (e) #t)
                                         (lambda (e)
                                           (SQLFreeHandle SQL_HANDLE_STMT stmt)
                                           (raise e))])
                          (handle-status fsym (SQLTables stmt catalog schema table) stmt)
                          (let* ([result-dvecs (describe-result-columns fsym stmt)]
                                 [rows (fetch* fsym stmt
                                               (map field-dvec->typeid result-dvecs))])
                            (values result-dvecs rows)))])
            (handle-status fsym (SQLFreeStmt stmt SQL_CLOSE) stmt)
            (handle-status fsym (SQLFreeHandle SQL_HANDLE_STMT stmt) stmt)
            (values result-dvecs rows))))))
  (recordset (map field-dvec->field-info dvecs)
             rows))
;; add-notice! : sqlstate message -> any
;; Queues a call to notice-handler with the given sqlstate and message via
;; add-delayed-call!, rather than calling the handler immediately.
(define (add-notice! sqlstate message)
  (add-delayed-call!
   (lambda ()
     (notice-handler sqlstate message))))
;; handle-status : symbol status [handle] #:ignore-ok/info? boolean -> status
;; Connection-aware wrapper around handle-status*: notices are forwarded to
;; add-notice!, and when an exn:fail is raised while still connected and
;; inside a transaction, the transaction is marked 'invalid before the
;; exception is re-raised.
(define/private (handle-status who s [handle #f]
                               #:ignore-ok/info? [ignore-ok/info? #f])
  (with-handlers ([exn:fail?
                   (lambda (failure)
                     (when (and db tx-status)
                       (set! tx-status 'invalid))
                     (raise failure))])
    (handle-status* who s handle
                    #:on-notice add-notice!
                    #:ignore-ok/info? ignore-ok/info?)))
;; Run the superclass initialization.
(super-new)
;; Register a finalizer that disconnects this connection object.
(register-finalizer this (lambda (obj) (send obj disconnect)))))
;; handle-status* : symbol status [handle]
;;                  #:ignore-ok/info? boolean #:on-notice procedure -> status
;; Interprets an ODBC return status.
;;  - SQL_ERROR: raises, using the handle's diagnostic record when a handle
;;    is given, otherwise (or if that returns) a generic error.
;;  - SQL_SUCCESS_WITH_INFO: reports the diagnostic to on-notice unless there
;;    is no handle or ignore-ok/info? is set; returns the status.
;;  - anything else: returns the status unchanged.
(define (handle-status* who s [handle #f]
                        #:ignore-ok/info? [ignore-ok/info? #f]
                        #:on-notice [on-notice void])
  (cond [(= s SQL_ERROR)
         (when handle
           (diag-info who handle 'error #f))
         (uerror who "unknown error (no diagnostic returned)")]
        [(= s SQL_SUCCESS_WITH_INFO)
         (unless (or (not handle) ignore-ok/info?)
           (diag-info who handle 'notice on-notice))
         s]
        [else s]))
;; diag-info : symbol handle (or 'error 'notice) (or procedure #f) -> any
;; Fetches the first diagnostic record for handle. In 'error mode it raises
;; a SQL error carrying the sqlstate, message and native error code; in
;; 'notice mode it passes the sqlstate and message to on-notice.
(define (diag-info who handle mode on-notice)
  (define handle-type
    (cond [(sqlhenv? handle) SQL_HANDLE_ENV]
          [(sqlhdbc? handle) SQL_HANDLE_DBC]
          [(sqlhstmt? handle) SQL_HANDLE_STMT]
          [else
           (error/internal 'diag-info "unknown handle type: ~e" handle)]))
  (define-values (status sqlstate native-errcode message)
    (SQLGetDiagRec handle-type handle 1))
  (cond [(eq? mode 'error)
         (raise-sql-error who sqlstate message
                          (list (cons 'code sqlstate)
                                (cons 'message message)
                                (cons 'native-errcode native-errcode)))]
        [(eq? mode 'notice)
         (on-notice sqlstate message)]))
;; field-dvec->field-info : vector -> alist
;; Converts a column description vector #(name typeid size digits) into an
;; association list keyed by 'name, 'typeid, 'size and 'digits.
(define (field-dvec->field-info dvec)
  (list (cons 'name (vector-ref dvec 0))
        (cons 'typeid (vector-ref dvec 1))
        (cons 'size (vector-ref dvec 2))
        (cons 'digits (vector-ref dvec 3))))
;; field-dvec->typeid : vector -> typeid
;; Extracts the type id, stored at index 1 of a column description vector.
(define (field-dvec->typeid dvec)
  (let ([typeid (vector-ref dvec 1)])
    typeid))