(module cipher mzscheme
(require-for-syntax "stx-util.ss")
(require (lib "foreign.ss")
(lib "kw.ss")
(lib "plt-match.ss"))
(require "libcrypto.ss" "error.ss" "rand.ss" "util.ss")
(provide (all-defined))
;; Allocation and release of EVP cipher contexts.
;; * If libcrypto exports EVP_CIPHER_CTX_new, OpenSSL both allocates and
;;   frees the context; the allocation result is checked by pointer/error
;;   (error.ss).
;; * Otherwise a zero-filled 192-byte Scheme buffer serves as the context
;;   storage, and "freeing" only runs EVP_CIPHER_CTX_cleanup on it (its int
;;   result is discarded); the buffer itself is left to the GC.
;; NOTE(review): assumes 192 bytes >= sizeof(EVP_CIPHER_CTX) for every
;; libcrypto that takes the fallback path -- confirm.
(define-values (EVP_CIPHER_CTX_new EVP_CIPHER_CTX_free)
  (cond
    ((ffi-available? 'EVP_CIPHER_CTX_new)
     (values (lambda/ffi (EVP_CIPHER_CTX_new) -> _pointer : pointer/error)
             (lambda/ffi (EVP_CIPHER_CTX_free _pointer))))
    (else
     (values (lambda () (make-bytes 192))
             (lambda/ffi (EVP_CIPHER_CTX_cleanup _pointer) -> _int : void)))))
;; Binding for EVP_CipherInit_ex(ctx, cipher, engine, key, iv, enc).
;; The engine argument is fixed to #f (NULL), so callers pass five
;; arguments: the context, the EVP_CIPHER pointer, the key, the IV (or #f)
;; and a boolean direction (_bool: #t -> 1 = encrypt, #f -> 0 = decrypt).
;; The int result is passed to check-error (error.ss), which presumably
;; raises on an OpenSSL failure return.
(define/ffi (EVP_CipherInit_ex _pointer _pointer (_pointer = #f)
_pointer _pointer _bool)
-> _int : check-error)
;; Binding for EVP_CipherUpdate(ctx, out, &outl, in, inl).
;; Callers pass: context, output pointer, input pointer, input length.
;; The outl out-parameter is allocated by the FFI ((_ptr o _int)).  After
;; the int result has been passed to check-error, the number of output
;; bytes written (olen) is returned.
(define/ffi (EVP_CipherUpdate _pointer _pointer
(olen : (_ptr o _int)) _pointer _int)
-> _int : (lambda (f r) (check-error f r) olen))
;; Binding for EVP_CipherFinal_ex(ctx, out, &outl).
;; Callers pass: context, output pointer.  After the int result has been
;; passed to check-error, the number of bytes written to the output (olen,
;; an FFI-allocated out-parameter) is returned.
(define/ffi (EVP_CipherFinal_ex _pointer _pointer (olen : (_ptr o _int)))
-> _int : (lambda (f r) (check-error f r) olen))
;; Binding for EVP_CIPHER_CTX_set_padding(ctx, pad): a boolean turns block
;; padding on (#t) or off (#f).  The int result goes through check-error.
(define/ffi (EVP_CIPHER_CTX_set_padding _pointer _bool)
-> _int : check-error)
;; cipher:algo -- one cipher/mode combination, built by define-cipher:
;;   evp    : pointer returned by the corresponding EVP_<name> function
;;   size   : block size, as read by c->props
;;   keylen : key length, as read by c->props (number of bytes handed to
;;            random-bytes by generate-cipher-key)
;;   ivlen  : IV length, or #f when c->props sees a length of 0
(define-struct cipher:algo (evp size keylen ivlen))
;; cipher -- an in-progress encryption or decryption, built by cipher-init:
;;   algo     : the cipher:algo it was created from
;;   ctx      : EVP context pointer; set to #f by cipher-final
;;   olen     : extra output room needed beyond the input length;
;;              cipher-init stores the algorithm's block size here
;;   encrypt? : #t when encrypting, #f when decrypting
(define-struct cipher (algo ctx olen encrypt?))
;; Produce fresh keying material for ALGO.
;; Returns two values:
;;   1. a key of (cipher:algo-keylen algo) bytes from random-bytes;
;;   2. an IV of (cipher:algo-ivlen algo) bytes from pseudo-random-bytes,
;;      or #f when the algorithm takes no IV.
;; random-bytes / pseudo-random-bytes presumably come from rand.ss.
(define (generate-cipher-key algo)
  (let* ((key (random-bytes (cipher:algo-keylen algo)))
         (iv-size (cipher:algo-ivlen algo))
         (iv (and iv-size (pseudo-random-bytes iv-size))))
    (values key iv)))
;; Create a cipher struct for ALGO, initialised with KEY and IV (IV may be
;; #f).  ENC? selects encryption (#t) or decryption (#f); PAD? turns block
;; padding on or off.  The struct's olen slot is set to the algorithm's
;; block size.
;; The freshly allocated context is bound with let/error-fini (util.ss),
;; which presumably frees it if initialisation raises.  A finalizer is
;; registered so the context is freed if the cipher is collected without
;; cipher-final having run.
(define (cipher-init algo key iv enc? pad?)
  (let/error-fini ((ctx (EVP_CIPHER_CTX_new) EVP_CIPHER_CTX_free))
    (EVP_CipherInit_ex ctx (cipher:algo-evp algo) key iv enc?)
    (EVP_CIPHER_CTX_set_padding ctx pad?)
    (let ((new-cipher (make-cipher algo ctx (cipher:algo-size algo) enc?)))
      (register-finalizer
       new-cipher
       (lambda (dead)
         ;; The ctx slot is #f once cipher-final has already released it.
         (let ((live-ctx (cipher-ctx dead)))
           (when live-ctx
             (EVP_CIPHER_CTX_free live-ctx)))))
      new-cipher)))
;; Low-level update step: feed ILEN bytes at pointer IBS through cipher C,
;; writing the output at pointer OBS.  Returns the number of output bytes.
;; Raises an error if C has already been finalized (its ctx is #f).
(define (cipher-update c obs ibs ilen)
  (let ((ctx (cipher-ctx c)))
    (if ctx
        (EVP_CipherUpdate ctx obs ibs ilen)
        (error 'cipher-update "finalized context"))))
;; Low-level final step: finish cipher C, writing any remaining output
;; (last block, padding) at pointer OBS.  Returns the number of bytes
;; written.  Raises an error if C has already been finalized.
;;
;; The EVP context is released and C is marked finalized even when
;; EVP_CipherFinal_ex raises (its result goes through check-error, which
;; presumably raises on failure, e.g. bad padding on decrypt).  Without this,
;; a failed final would leave the context alive until the GC finalizer ran,
;; and C could be "finalized" again on a context that had already failed.
(define (cipher-final c obs)
  (cond
    ((cipher-ctx c)
     => (lambda (ctx)
          (dynamic-wind
            void
            (lambda () (EVP_CipherFinal_ex ctx obs))
            (lambda ()
              ;; Clear the slot before freeing, so the finalizer registered in
              ;; cipher-init never sees an already-freed context.
              (set-cipher-ctx! c #f)
              (EVP_CIPHER_CTX_free ctx)))))
    (else (error 'cipher-final "finalized context"))))
;; Validate KEY and IV against ALGO, then build a cipher via cipher-init.
;; * KEY must be at least (cipher:algo-keylen algo) bytes long.
;; * When the algorithm uses an IV, IV must be bytes of at least
;;   (cipher:algo-ivlen algo) length.
;; * When the algorithm uses no IV, IV is ignored and #f is passed on.
;; Raises 'cipher-new "bad key" / "bad iv" on validation failure.
(define (cipher-new algo key iv enc? pad?)
  (when (< (bytes-length key) (cipher:algo-keylen algo))
    (error 'cipher-new "bad key"))
  (let ((iv-size (cipher:algo-ivlen algo)))
    (when (and iv-size
               (or (not iv) (< (bytes-length iv) iv-size)))
      (error 'cipher-new "bad iv"))
    (cipher-init algo key (and iv-size iv) enc? pad?)))
;; Upper bound on the output produced for ILEN input bytes: the input
;; length plus the cipher's olen slack (its block size, see cipher-init).
(define (cipher-maxlen c ilen)
  (let ((slack (cipher-olen c)))
    (+ ilen slack)))
;; (cipher-encrypt algo key iv [#:padding pad?]) -> cipher
;; Create a cipher that encrypts with ALGO, KEY and IV.  Block padding is
;; enabled unless #:padding #f is given.  KEY/IV are checked by cipher-new.
(define/kw (cipher-encrypt algo key iv #:key (pad? #:padding #t))
(cipher-new algo key iv #t pad?))
;; (cipher-decrypt algo key iv [#:padding pad?]) -> cipher
;; Same as cipher-encrypt, but the cipher decrypts.
(define/kw (cipher-decrypt algo key iv #:key (pad? #:padding #t))
(cipher-new algo key iv #f pad?))
;; Feed input bytes through cipher C.  Returns two values: the output
;; buffer and the number of bytes written into it.
;;   (cipher-update! c ibs)
;;       allocates a fresh output buffer sized by cipher-maxlen.
;;   (cipher-update! c ibs obs)
;;       writes into OBS from offset 0; OBS must satisfy
;;       check-output-range for (cipher-maxlen c (bytes-length ibs)).
;;   (cipher-update! c ibs obs istart iend ostart oend)
;;       consumes IBS[istart, iend) and writes at OBS offset OSTART; both
;;       ranges are checked first (check-input-range / check-output-range).
(define cipher-update!
  (case-lambda
    ((c ibs)
     (let ((out (make-bytes (cipher-maxlen c (bytes-length ibs)))))
       (cipher-update! c ibs out)))
    ((c ibs obs)
     (let ((n (bytes-length ibs)))
       (check-output-range 'cipher-update obs (cipher-maxlen c n))
       (let ((written (cipher-update c obs ibs n)))
         (values obs written))))
    ((c ibs obs istart iend ostart oend)
     (check-input-range 'cipher-update ibs istart iend)
     (let ((n (- iend istart)))
       (check-output-range 'cipher-update obs ostart oend (cipher-maxlen c n))
       (let ((written (cipher-update c
                                     (ptr-add obs ostart)
                                     (ptr-add ibs istart)
                                     n)))
         (values obs written))))))
;; Finish cipher C.  Returns two values: the output buffer and the number
;; of bytes written into it.  After this, C's context is released.
;;   (cipher-final! c)
;;       allocates an output buffer of (cipher-olen c) bytes.
;;   (cipher-final! c obs)
;;       writes into OBS from offset 0 after check-output-range.
;;   (cipher-final! c obs ostart oend)
;;       writes at OBS offset OSTART after checking [ostart, oend).
(define cipher-final!
  (case-lambda
    ((c)
     (let ((out (make-bytes (cipher-olen c))))
       (cipher-final! c out)))
    ((c obs)
     (let ((need (cipher-olen c)))
       (check-output-range 'cipher-final obs need)
       (let ((written (cipher-final c obs)))
         (values obs written))))
    ((c obs ostart oend)
     (let ((need (cipher-olen c)))
       (check-output-range 'cipher-final obs ostart oend need)
       (let ((written (cipher-final c (ptr-add obs ostart))))
         (values obs written))))))
;; (define-cipher-prop NAME ACCESSOR) defines a procedure NAME that accepts
;; either a cipher or a cipher:algo and applies ACCESSOR to the underlying
;; cipher:algo.  Any other argument raises a type error naming NAME.
(define-syntax define-cipher-prop
  (syntax-rules ()
    ((_ prop accessor)
     (define (prop obj)
       (cond
         ((cipher? obj) (accessor (cipher-algo obj)))
         ((cipher:algo? obj) (accessor obj))
         (else
          (raise-type-error 'prop "cipher or cipher algorithm" obj)))))))
;; Block size, key length and IV length (#f when none) of a cipher or algo.
(define-cipher-prop cipher-block-size cipher:algo-size)
(define-cipher-prop cipher-key-length cipher:algo-keylen)
(define-cipher-prop cipher-iv-length cipher:algo-ivlen)
;; Stream everything readable from INP through CIPHER and write the result
;; to OUTP.  When INP reaches eof, the cipher is finished with cipher-final!
;; and the final output is written.  CIPHER must not be finalized yet; it is
;; finalized on return.  OUTP is neither flushed nor closed.  Returns void.
;;
;; Input is read in chunks of up to CHUNK bytes.  Chunk size does not affect
;; the output, because EVP buffers partial blocks itself.  The previous
;; version read one cipher block per iteration, which for modes OpenSSL
;; reports with a block size of 1 (cfb/ofb) meant one FFI round-trip per
;; input byte.  read-bytes-avail! returns as soon as any bytes are
;; available, so the larger buffer adds no latency on pipes.
(define (cipher-port cipher inp outp)
  (let* ((bsize (cipher-block-size cipher))
         (chunk (max 4096 bsize))
         ;; cipher-update! demands (cipher-maxlen c n) = n + block size bytes
         ;; of output room; the final step needs block-size bytes.
         (osize (+ chunk bsize))
         (ibuf (make-bytes chunk))
         (obuf (make-bytes osize)))
    (let loop ((icount (read-bytes-avail! ibuf inp)))
      (if (eof-object? icount)
          (let-values (((_buf ocount) (cipher-final! cipher obuf)))
            (void (write-bytes obuf outp 0 ocount)))
          (let-values (((_buf ocount)
                        (cipher-update! cipher ibuf obuf 0 icount 0 osize)))
            (write-bytes obuf outp 0 ocount)
            (loop (read-bytes-avail! ibuf inp)))))))
;; (define/cipher-port OP INIT) defines a port-level procedure OP.  INIT is
;; a cipher constructor called as (INIT algo key iv), so padding is always
;; left at the INIT default.  Forms of OP:
;;   (OP algo key iv)
;;       -> two values (rd wr): bytes written to wr are ciphered by a
;;          background thread and can be read from rd.  The caller must
;;          close wr so the thread sees eof and finishes the cipher.
;;   (OP algo key iv bytes)          -> ciphered bytes
;;   (OP algo key iv input-port)     -> input port yielding the ciphered
;;                                      data (produced by a background thread)
;;   (OP algo key iv bytes-or-input-port output-port)
;;       -> ciphers synchronously into the given output port
;; Type errors name OP.  In every form INIT is called in the caller's thread,
;; so key/IV validation errors are raised to the caller.
;; NOTE(review): in the threaded forms, if cipher-port raises inside the
;; thread (e.g. a failed final on decrypt), the close-*-port calls after it
;; are skipped and the reader of the returned port blocks indefinitely.
(define-syntax define/cipher-port
(syntax-rules ()
((_ op init)
(define op
(case-lambda
((algo key iv)
(let-values (((cipher) (init algo key iv))
((rd1 wr1) (make-pipe))
((rd2 wr2) (make-pipe)))
;; Worker: ciphers rd1 -> wr2 until eof on rd1, then closes both ends
;; it owns so the reader of rd2 sees eof.
(thread (lambda ()
(cipher-port cipher rd1 wr2)
(close-input-port rd1)
(close-output-port wr2)))
(values rd2 wr1)))
((algo key iv inp)
(cond
((bytes? inp)
(let ((outp (open-output-bytes)))
(cipher-port (init algo key iv) (open-input-bytes inp) outp)
(get-output-bytes outp)))
((input-port? inp)
(let-values (((cipher) (init algo key iv))
((rd wr) (make-pipe)))
;; Worker: ciphers inp -> wr, then closes wr so rd sees eof.
;; INP itself is not closed.
(thread (lambda ()
(cipher-port cipher inp wr)
(close-output-port wr)))
rd))
(else (raise-type-error 'op "bytes or input-port" inp))))
((algo key iv inp outp)
(unless (output-port? outp)
(raise-type-error 'op "output-port" outp))
(cond
((bytes? inp)
(cipher-port (init algo key iv) (open-input-bytes inp) outp))
((input-port? inp)
(cipher-port (init algo key iv) inp outp))
(else (raise-type-error 'op "bytes or input-port" inp)))))))))
;; Port-level encryption and decryption entry points.
(define/cipher-port encrypt cipher-encrypt)
(define/cipher-port decrypt cipher-decrypt)
;; Read the block size, key length and IV length of the EVP_CIPHER pointed
;; to by EVP.  The struct's first four int fields are read and the first
;; one is ignored.  Returns three values: size, keylen, and ivlen (#f
;; when the IV length is 0).
;; NOTE(review): relies on the leading layout of OpenSSL's EVP_CIPHER struct
;; (nid, block_size, key_len, iv_len); confirm for the libcrypto in use.
(define (c->props evp)
  (let* ((fields (ptr-ref evp (_list-struct _int _int _int _int)))
         (block-size (cadr fields))
         (key-size (caddr fields))
         (iv-size (cadddr fields)))
    (values block-size key-size (and (> iv-size 0) iv-size))))
;; Names (symbols) of the ciphers registered by define-cipher.  Each entry
;; is added only if the cipher's CBC variant was found in libcrypto.
(define *ciphers* '())
;; Return the list of registered cipher names.
(define available-ciphers
  (lambda () *ciphers*))
;; Phase-1 configuration for define-cipher: the modes generated for every
;; cipher, and the mode used for the unsuffixed cipher:NAME alias.
(define-for-syntax cipher-modes '(ecb cbc cfb ofb))
(define-for-syntax default-mode 'cbc)
;; (define-cipher NAME)
;;   For each mode M in cipher-modes, defines cipher:NAME-M: a cipher:algo
;;   for EVP_NAME_M (hyphens -> underscores), or #f if libcrypto lacks it.
;;   Also defines cipher:NAME as an alias for cipher:NAME-cbc, records NAME
;;   in *ciphers* when that alias is available, and defines a macro
;;   (provide:cipher:NAME) that provides cipher:NAME and all its mode
;;   variants.
;; (define-cipher NAME (K ...))
;;   Does the above for each NAME-K (e.g. aes-128), and defines
;;   (provide:cipher:NAME), which expands to (provide:cipher:NAME-K) for
;;   every K.  No unsuffixed cipher:NAME is defined in this form.
(define-syntax (define-cipher stx)
;; "des-ede-cbc" -> "des_ede_cbc", matching libcrypto's EVP_* symbol names.
(define (unhyphen what) (regexp-replace* "-" what "_"))
;; Phase-1 helper (unrelated to the phase-0 struct constructor of the same
;; name): builds the definition of cipher:MODE, where MODE is a string
;; such as "des-cbc".
(define (make-cipher mode)
(with-syntax
((evp (->stx stx (make-symbol "EVP_" (unhyphen mode))))
(cipher (->stx stx (make-symbol "cipher:" mode))))
#`(define cipher
(if (ffi-available? 'evp)
;; Call the EVP_* accessor once to get the EVP_CIPHER pointer, then
;; read its properties.
(let ((evpp ((lambda/ffi (evp) -> _pointer))))
(call/values
(lambda () (c->props evpp))
(lambda (size keylen ivlen)
(make-cipher:algo evpp size keylen ivlen))))
#f))))
;; Builds all definitions for one cipher NAME (a symbol, or a string such
;; as "aes-128"): the per-mode bindings, the cipher:NAME alias and the
;; provide:cipher:NAME macro.
(define (make name)
(with-syntax
((cipher (->stx stx (make-symbol "cipher:" name)))
(alias (->stx stx (make-symbol "cipher:" name "-" default-mode)))
(provider (->stx stx (make-symbol "provide:cipher:" name))))
(let ((modes (map (lambda (mode) (format "~a-~a" name mode))
cipher-modes)))
#`(begin
#,@(map make-cipher modes)
;; cipher:NAME is the default-mode variant; NAME is registered in
;; *ciphers* only when that variant exists.
(define cipher
(begin (when alias (push! *ciphers* '#,(make-symbol name)))
alias))
(define-syntax provider
(syntax-rules ()
((_)
(provide cipher
#,@(map (lambda (mode)
(->stx stx (make-symbol "cipher:" mode)))
modes)))))))))
;; Builds provide:cipher:NAME for the key-length form: it expands to
;; (provide:cipher:NAME-K) for each K in KS.
(define (make-meta-provider name ks)
(let* ((base (make-symbol "provide:cipher:" name))
(provs (map (lambda (k) (make-symbol base "-" k)) ks)))
(with-syntax ((provider (->stx stx base)))
#`(define-syntax provider
(syntax-rules ()
((_)
(begin #,@(map (lambda (p) (->stx stx (list p))) provs))))))))
(syntax-case stx ()
((_ c)
(make (->datum #'c)))
((_ c (klen ...))
(let ((name (->string (->datum #'c)))
(ks (map (lambda (k) (->datum k)) (syntax->list #'(klen ...)))))
#`(begin
#,@(map (lambda (k) (make (format "~a-~a" name k))) ks)
#,(make-meta-provider name ks))))))
;; Cipher families exposed by this module.  Each line defines cipher:NAME-MODE
;; bindings (or #f where libcrypto lacks the cipher) plus a provide:cipher:NAME
;; macro.  Each available NAME is added to *ciphers* via push! as its line
;; runs, so the order here determines the order of (available-ciphers)
;; (presumably reversed, if push! prepends -- confirm in util.ss).
(define-cipher des)
(define-cipher des-ede)
(define-cipher des-ede3)
(define-cipher idea)
(define-cipher bf)
(define-cipher cast5)
;; Key-length families: cipher:aes-128-cbc, cipher:camellia-256-ofb, etc.
(define-cipher aes (128 192 256))
(define-cipher camellia (128 192 256))
;; (provide:cipher) expands to provide forms for the public cipher API (the
;; procedures listed below) and, through the provide:cipher:NAME macros
;; generated by define-cipher, for every cipher:NAME / cipher:NAME-MODE
;; binding.
;; NOTE(review): presumably intended for use inside a module body that
;; wants to re-export this API -- confirm against call sites.
(define-syntax provide:cipher
(syntax-rules ()
((_)
(begin
(provide available-ciphers cipher? cipher-encrypt?
cipher-block-size cipher-key-length cipher-iv-length
cipher-encrypt cipher-decrypt cipher-update! cipher-final!
encrypt decrypt)
(provide:cipher:des)
(provide:cipher:des-ede)
(provide:cipher:des-ede3)
(provide:cipher:idea)
(provide:cipher:bf)
(provide:cipher:cast5)
(provide:cipher:aes)
(provide:cipher:camellia)))))
)