(module matrix mzscheme
(require (lib "foreign.ss")
(lib "etc.ss")
(all-except (lib "contract.ss") ->)
(rename (lib "contract.ss") ->/c ->)
(all-except (lib "42.ss" "srfi") :)
(all-except (planet "srfi-4-comprehensions.ss" ("wmfarr" "srfi-4-comprehensions.plt")) :)
"blas-lapack.ss")
;; (list/length/c n) : flat contract accepting a list with exactly N elements.
(define (list/length/c n)
  (let ((has-length-n? (lambda (lst) (= n (length lst)))))
    (flat-named-contract (format "list of length ~a" n)
                         has-length-n?)))
;; (matrix-multiplication-compatible/c m) : flat contract for a right-hand
;; operand M2 of the product M * M2, i.e. a matrix whose row count equals the
;; column count of M.  Both dimensions of M are re-read at each check.
(define (matrix-multiplication-compatible/c m)
  (define (compatible? m2)
    (= (matrix-cols m) (matrix-rows m2)))
  (flat-named-contract
   (format "compatible for multiplication by a ~a by ~a matrix"
           (matrix-rows m) (matrix-cols m))
   compatible?))
;; (matrix-same-dimensions/c m) : flat contract accepting matrices with the
;; same number of rows and the same number of columns as M.
(define (matrix-same-dimensions/c m)
  (define (same-shape? m2)
    (and (= (matrix-rows m) (matrix-rows m2))
         (= (matrix-cols m) (matrix-cols m2))))
  (flat-named-contract
   (format "~a by ~a matrix" (matrix-rows m) (matrix-cols m))
   same-shape?))
;; (matrix-valid-row-index/c m) : flat contract accepting I with
;; 0 <= I < (matrix-rows m).  The row count is read once, when the contract
;; is built.
(define (matrix-valid-row-index/c m)
  (define row-count (matrix-rows m))
  (flat-named-contract
   (format "valid row index for a ~a by ~a matrix" row-count (matrix-cols m))
   (lambda (i)
     (and (<= 0 i)
          (< i row-count)))))
;; (matrix-valid-col-index/c m) : flat contract accepting J with
;; 0 <= J < (matrix-cols m).  The column count is read once, when the
;; contract is built.
(define (matrix-valid-col-index/c m)
  (define col-count (matrix-cols m))
  (flat-named-contract
   (format "valid column index for ~a by ~a matrix" (matrix-rows m) col-count)
   (lambda (j)
     (and (<= 0 j)
          (< j col-count)))))
;; (matrix-col-vector-compatible/c m) : flat contract accepting f64vectors whose
;; length equals the column count of M (valid right-hand factors of M * v).
(define (matrix-col-vector-compatible/c m)
  (define len (matrix-cols m))
  (flat-named-contract
   (format "column vector of length ~a" len)
   (lambda (v) (= len (f64vector-length v)))))
;; (matrix-row-vector-compatible/c m) : flat contract accepting f64vectors whose
;; length equals the row count of M (valid left-hand factors of v * M).
(define (matrix-row-vector-compatible/c m)
  (define len (matrix-rows m))
  (flat-named-contract
   (format "row vector of length ~a" len)
   (lambda (v) (= len (f64vector-length v)))))
;; matrix-square/c : flat contract accepting matrices with as many rows as
;; columns.  Non-matrix values are rejected by the predicate instead of making
;; matrix-rows raise an internal error.  This matters because the contract is
;; also used on its own (for matrix-solve and matrix-solve-many), and lets a
;; violation be reported with proper blame.
(define matrix-square/c
  (flat-named-contract
   "square matrix"
   (lambda (m)
     (and (matrix? m)
          (= (matrix-rows m) (matrix-cols m))))))
;; A dense matrix.  PTR is a C pointer to ROWS*COLS doubles stored in
;; column-major order: element (i, j) lives at index i + j*ROWS (see
;; matrix-ptr-index).  The #f inspector argument makes the structure
;; transparent.
(define-struct matrix
(ptr rows cols) #f)
;; Uncontracted exports: the matrix predicate, the dimension-dependent contract
;; constructors defined above, the matrix-ec / :matrix comprehension forms,
;; the _matrix foreign type, and the structure type descriptor.
(provide matrix? matrix-multiplication-compatible/c matrix-valid-row-index/c
matrix-valid-col-index/c matrix-square/c matrix-same-dimensions/c
matrix-col-vector-compatible/c matrix-row-vector-compatible/c
matrix-ec :matrix
_matrix
struct:matrix)
;; Contracted exports.  Most argument contracts are dependent (->r), so that
;; indices, vectors and second operands are checked against the dimensions of
;; the matrix they are used with.
(provide/contract
 ;; (make-matrix rows cols elt) : rows x cols matrix with every entry elt.
 (rename my-make-matrix make-matrix
         (->/c natural-number/c natural-number/c number? matrix?))
 ;; (matrix rows cols elt ...) : exactly rows*cols numbers, in column-major order.
 (rename my-matrix matrix
         (->r ((i natural-number/c)
               (j natural-number/c))
              elts (and/c (listof number?)
                          (list/length/c (* i j)))
              matrix?))
 (matrix-rows (->/c matrix? natural-number/c))
 (matrix-cols (->/c matrix? natural-number/c))
 ;; Indices must lie inside the matrix's dimensions.
 (matrix-ref (->r ((m matrix?)
                   (i (and/c natural-number/c
                             (matrix-valid-row-index/c m)))
                   (j (and/c natural-number/c
                             (matrix-valid-col-index/c m))))
                  number?))
 (matrix-set! (->r ((m matrix?)
                    (i (and/c natural-number/c
                              (matrix-valid-row-index/c m)))
                    (j (and/c natural-number/c
                              (matrix-valid-col-index/c m)))
                    (x number?))
                   any))
 ;; Element-wise operations need operands of identical shape.
 (matrix-add (->r ((m1 matrix?)
                   (m2 (and/c matrix?
                              (matrix-same-dimensions/c m1))))
                  matrix?))
 (matrix-sub (->r ((m1 matrix?)
                   (m2 (and/c matrix?
                              (matrix-same-dimensions/c m1))))
                  matrix?))
 (matrix-scale (->/c matrix? number? matrix?))
 ;; Products require (cols m1) = (rows m2), or a vector of matching length.
 (matrix-mul (->r ((m1 matrix?)
                   (m2 (and/c matrix?
                              (matrix-multiplication-compatible/c m1))))
                  matrix?))
 (matrix-f64vector-mul (->r ((m matrix?)
                             (v (and/c f64vector?
                                       (matrix-col-vector-compatible/c m))))
                            f64vector?))
 (f64vector-matrix-mul (->r ((v (and/c f64vector?
                                       (matrix-row-vector-compatible/c m)))
                             (m matrix?))
                            f64vector?))
 (matrix-inverse (->/c (and/c matrix? matrix-square/c) matrix?))
 (matrix-norm (->/c matrix? number?))
 (matrix-identity (->/c natural-number/c matrix?))
 (matrix-transpose (->/c matrix? matrix?))
 ;; Solvers: the coefficient matrix is checked with matrix? before
 ;; matrix-square/c (as for matrix-inverse), so a non-matrix argument is
 ;; reported as a contract violation rather than failing inside matrix-rows.
 (matrix-solve (->r ((m (and/c matrix? matrix-square/c))
                     (v (and/c f64vector?
                               (let ((r (matrix-rows m)))
                                 (flat-named-contract
                                  (format "column vector of length ~a" r)
                                  (lambda (v) (= (f64vector-length v) r)))))))
                    f64vector?))
 (matrix-solve-many (->r ((m1 (and/c matrix? matrix-square/c))
                          (m2 (and/c matrix?
                                     (matrix-multiplication-compatible/c m1))))
                         matrix?)))
;; Enable foreign.ss's unsafe operations for the rest of this module
;; (presumably covering malloc, memset, ptr-ref/ptr-set!, make-ctype and
;; get-ffi-obj as used below -- confirm against the foreign.ss documentation).
(unsafe!)
;; (my-make-matrix rows cols [elt]) : allocate a fresh ROWS x COLS matrix backed
;; by a C buffer of doubles.
;;   - With two arguments the buffer is memset to zero, so every entry is 0.0.
;;   - With three, every entry is then overwritten with ELT.
;; Exported (with a contract requiring ELT) as make-matrix.
(define my-make-matrix
  (case-lambda
    ((rows cols)
     (let ((n-elts (* rows cols)))
       (let ((buffer (malloc n-elts _double 'atomic)))
         (memset buffer 0 n-elts _double)
         (make-matrix buffer rows cols))))
    ((rows cols elt)
     (let* ((filled (my-make-matrix rows cols))
            (buffer (matrix-ptr filled)))
       (do-ec (:range k (* rows cols))
              (ptr-set! buffer _double* k elt))
       filled))))
;; (my-matrix i j elt ...) : an I x J matrix whose entries are taken from ELTS
;; in buffer (column-major) order.
;;   - Entries not supplied stay 0.0.
;;   - ELTS beyond I*J are ignored.
;; Exported (with a contract requiring exactly I*J numbers) as matrix.
(define (my-matrix i j . elts)
  (let* ((result (my-make-matrix i j))
         (buffer (matrix-ptr result))
         (limit (* i j)))
    (let fill ((k 0) (remaining elts))
      (when (and (< k limit) (pair? remaining))
        (ptr-set! buffer _double* k (car remaining))
        (fill (+ k 1) (cdr remaining))))
    result))
;; _matrix* : C type passing a matrix to foreign code as its raw data pointer.
;; It is input-only: converting a C result back to a matrix is an error.
(define _matrix*
  (let ((scheme->c matrix-ptr)
        (c->scheme (lambda (ignored)
                     (error '_matrix
                            "cannot convert C output to _matrix"))))
    (make-ctype _pointer scheme->c c->scheme)))
;; _matrix : _fun argument syntax for matrices.
;;   _matrix / (_matrix i)   -- pass an existing matrix's pointer (via _matrix*).
;;   (_matrix o rows cols)   -- allocate a zero-filled ROWS*COLS double buffer,
;;                              pass it to C, and wrap it as a new matrix after
;;                              the call.
;;   (_matrix io)            -- pass M's pointer; the post: clause yields M
;;                              itself after the call.
;; NOTE(review): the io form is not used in this file; confirm the bind:/pre:
;; keywords behave as intended with this version of define-fun-syntax.
(define-fun-syntax _matrix
(syntax-id-rules (i o io)
((_matrix i)
_matrix*)
;; Output matrix: fresh zeroed buffer, wrapped with the given dimensions.
((_matrix o rows cols)
(type: _pointer
pre: (let* ((n (* rows cols))
(p (malloc n _double 'atomic)))
(memset p 0 n _double)
p)
post: (p => (make-matrix p rows cols))))
;; In/out matrix: C receives M's own buffer.
((_matrix io)
(type: _pointer
bind: m
pre: (m => (matrix-ptr m))
post: m))
;; A bare _matrix behaves like (_matrix i).
(_matrix _matrix*)))
;; Offset (in doubles) of element (i, j) in M's column-major buffer.
(define (matrix-ptr-index m i j)
  (let ((stride (matrix-rows m)))
    (+ (* stride j) i)))
;; Read element (i, j) of M as a double.
(define (matrix-ref m i j)
  (let* ((buffer (matrix-ptr m))
         (offset (matrix-ptr-index m i j)))
    (ptr-ref buffer _double offset)))
;; Store ELT (any real, converted with _double*) as element (i, j) of M.
(define (matrix-set! m i j elt)
  (let* ((buffer (matrix-ptr m))
         (offset (matrix-ptr-index m i j)))
    (ptr-set! buffer _double* offset elt)))
;; Total number of entries (rows * cols) stored in M's buffer.
(define (matrix-length m)
  (let ((cols (matrix-cols m))
        (rows (matrix-rows m)))
    (* cols rows)))
;; (matrix-copy m) : a fresh matrix with M's dimensions and entries.
;; cblas_dcopy(N, X, incX, Y, incY) copies all rows*cols entries with unit
;; stride from M into a newly allocated output matrix.
(define matrix-copy
(get-ffi-obj 'cblas_dcopy *blas*
(_fun (m) ::
(_int = (matrix-length m))
(m : _matrix) (_int = 1)
(m-out : (_matrix o (matrix-rows m) (matrix-cols m))) (_int = 1) ->
_void ->
m-out)))
;; (matrix-add m1 m2) : M1 + M2 as a new matrix.
;; cblas_daxpy computes Y := alpha*X + Y over all entries.  Here X = M1,
;; alpha = 1.0, and Y is a fresh copy of M2, so neither argument is modified.
(define matrix-add
(get-ffi-obj 'cblas_daxpy *blas*
(_fun (m1 m2) ::
(_int = (matrix-length m1))
(_double = 1.0)
(_matrix = m1) (_int = 1)
(m-out : _matrix = (matrix-copy m2)) (_int = 1) ->
_void ->
m-out)))
;; (matrix-sub m1 m2) : M1 - M2 as a new matrix.
;; cblas_daxpy computes Y := alpha*X + Y.  Here X = M2, alpha = -1.0, and Y is
;; a fresh copy of M1, so neither argument is modified.
(define matrix-sub
(get-ffi-obj 'cblas_daxpy *blas*
(_fun (m1 m2) ::
(_int = (matrix-length m1))
(_double = -1.0)
(_matrix = m2) (_int = 1)
(m-out : _matrix = (matrix-copy m1)) (_int = 1) ->
_void ->
m-out)))
;; (matrix-scale m s) : S * M as a new matrix.
;; cblas_dscal scales its vector argument in place, so it is applied to a
;; fresh copy of M.  S may be any real; _double* converts it to a double.
(define matrix-scale
(get-ffi-obj 'cblas_dscal *blas*
(_fun (m s) ::
(_int = (matrix-length m))
(_double* = s)
(m-out : _matrix = (matrix-copy m)) (_int = 1) ->
_void ->
m-out)))
;; (matrix-mul m1 m2) : the product M1 * M2 as a new (rows m1) x (cols m2)
;; matrix, computed by cblas_dgemm (C := 1.0*A*B + 0.0*C) on column-major data.
;; BLAS requires every leading dimension (lda, ldb, ldc) to be at least 1, even
;; when the matching dimension is 0.  They are therefore clamped with (max 1 ...);
;; otherwise operands with zero rows make dgemm reject its arguments, which
;; depending on the BLAS implementation can abort the process.
(define matrix-mul
  (get-ffi-obj 'cblas_dgemm *blas*
    (_fun (m1 m2) ::
          (_cblas-order = 'col-major)
          (_cblas-transpose = 'no-trans)                        ; A = M1 as stored
          (_cblas-transpose = 'no-trans)                        ; B = M2 as stored
          (_int = (matrix-rows m1))                             ; M
          (_int = (matrix-cols m2))                             ; N
          (_int = (matrix-rows m2))                             ; K (= cols of M1)
          (_double = 1.0)                                       ; alpha
          (_matrix = m1)                                        ; A
          (_int = (max 1 (matrix-rows m1)))                     ; lda
          (_matrix = m2)                                        ; B
          (_int = (max 1 (matrix-rows m2)))                     ; ldb
          (_double = 0.0)                                       ; beta
          (m-out : (_matrix o (matrix-rows m1) (matrix-cols m2))) ; C (fresh)
          (_int = (max 1 (matrix-rows m1))) ->                  ; ldc
          _void ->
          m-out)))
;; (matrix-f64vector-mul m v) : the column vector M * V as a new f64vector of
;; length (rows m), computed by cblas_dgemv (y := 1.0*A*x + 0.0*y).
;; lda is clamped to at least 1, as BLAS requires lda >= max(1, M) even when M
;; has zero rows.
(define matrix-f64vector-mul
  (get-ffi-obj 'cblas_dgemv *blas*
    (_fun (m v) ::
          (_cblas-order = 'col-major)
          (_cblas-transpose = 'no-trans)
          (_int = (matrix-rows m))                 ; M
          (_int = (matrix-cols m))                 ; N
          (_double = 1.0)                          ; alpha
          (_matrix = m)                            ; A
          (_int = (max 1 (matrix-rows m)))         ; lda
          (_f64vector = v)                         ; x, length (cols m)
          (_int = 1)                               ; incx
          (_double = 0.0)                          ; beta
          (v-out : (_f64vector o (matrix-rows m))) ; y (fresh)
          (_int = 1) ->                            ; incy
          _void ->
          v-out)))
;; (f64vector-matrix-mul v m) : the row vector V * M (equivalently M^T * V) as
;; a new f64vector of length (cols m), computed by cblas_dgemv with 'trans.
;; lda is clamped to at least 1, as BLAS requires lda >= max(1, M) even when M
;; has zero rows.
(define f64vector-matrix-mul
  (get-ffi-obj 'cblas_dgemv *blas*
    (_fun (v m) ::
          (_cblas-order = 'col-major)
          (_cblas-transpose = 'trans)
          (_int = (matrix-rows m))                 ; M
          (_int = (matrix-cols m))                 ; N
          (_double = 1.0)                          ; alpha
          (_matrix = m)                            ; A
          (_int = (max 1 (matrix-rows m)))         ; lda
          (_f64vector = v)                         ; x, length (rows m)
          (_int = 1)                               ; incx
          (_double = 0.0)                          ; beta
          (v-out : (_f64vector o (matrix-cols m))) ; y (fresh)
          (_int = 1) ->                            ; incy
          _void ->
          v-out)))
;; (matrix-lu-decomp m) : LU-factor a copy of M with LAPACK dgetrf_.
;; Returns two values: the copy holding the factors, and the pivot indices
;; (a u32vector of length (rows m)).  M itself is not modified.
;; LDA is clamped to at least 1, since LAPACK requires LDA >= max(1, M) and
;; rejects 0 for an empty matrix.
;; NOTE(review): dgetrf_'s INFO result is discarded, so a singular M is not
;; reported here.
(define matrix-lu-decomp
  (get-ffi-obj 'dgetrf_ *lapack*
    (_fun (m) ::
          ((_ptr i _int) = (matrix-rows m))          ; M
          ((_ptr i _int) = (matrix-cols m))          ; N
          (m-out : _matrix = (matrix-copy m))        ; A (copy, factored in place)
          ((_ptr i _int) = (max 1 (matrix-rows m)))  ; LDA
          (ipiv : (_u32vector o (matrix-rows m)))    ; IPIV
          (_ptr o _int) ->                           ; INFO (ignored)
          _void ->
          (values m-out ipiv))))
;; (dgetri-lwork m ipiv) : workspace-size query for dgetri_ (LWORK = -1).
;; Returns two values: the workspace length reported in WORK(1), rounded to an
;; exact integer, and dgetri_'s INFO code.
;; LDA is clamped to at least 1, as LAPACK requires LDA >= max(1, N).
(define dgetri-lwork
  (get-ffi-obj 'dgetri_ *lapack*
    (_fun (m ipiv) ::
          (n : (_ptr i _int) = (matrix-rows m)) ; N
          (_matrix = m)                         ; A
          ((_ptr i _int) = (max 1 n))           ; LDA
          (_u32vector = ipiv)                   ; IPIV
          (lwork : (_ptr o _double))            ; WORK(1) receives the size
          ((_ptr i _int) = -1)                  ; LWORK = -1: query only
          (res : (_ptr o _int)) ->              ; INFO
          _void ->
          (values (inexact->exact (round lwork)) res))))
;; (dgetri/lwork m ipiv lwork) : invert a matrix from its LU factors (as
;; produced by matrix-lu-decomp) with dgetri_, using a workspace of LWORK
;; doubles.  A copy of M receives the inverse and is returned; M is untouched.
;; LAPACK requires LDA >= max(1, N) and LWORK >= max(1, N), so both are
;; clamped to at least 1.  Some LAPACK versions report a workspace size of 0
;; for N = 0, which the second check would otherwise reject.
;; NOTE(review): dgetri_'s INFO result is discarded.
(define dgetri/lwork
  (get-ffi-obj 'dgetri_ *lapack*
    (_fun (m ipiv lwork) ::
          (n : (_ptr i _int) = (matrix-rows m)) ; N
          (m-out : _matrix = (matrix-copy m))   ; A: factors in, inverse out
          ((_ptr i _int) = (max 1 n))           ; LDA
          (_u32vector = ipiv)                   ; IPIV
          (_f64vector o (max 1 lwork))          ; WORK
          ((_ptr i _int) = (max 1 lwork))       ; LWORK
          (_ptr o _int) ->                      ; INFO (ignored)
          _void ->
          m-out)))
;; (matrix-inverse m) : the inverse of the square matrix M, as a new matrix.
;; Steps:
;;   1. LU-factor a copy of M (dgetrf_).
;;   2. Ask dgetri_ for its preferred workspace size.
;;   3. Run dgetri_ on a copy of the factors.
;; NOTE(review): no LAPACK INFO code is checked here (the query's is bound but
;; unused), so a singular M is not reported as an error.
(define (matrix-inverse m)
  (let*-values (((lu pivots) (matrix-lu-decomp m))
                ((work-size query-info) (dgetri-lwork lu pivots)))
    (dgetri/lwork lu pivots work-size)))
;; (matrix-norm m) : the Euclidean norm of all rows*cols entries of M taken as
;; a single vector (i.e. the Frobenius norm), computed by cblas_dnrm2.
(define matrix-norm
(get-ffi-obj 'cblas_dnrm2 *blas*
(_fun (m) ::
(_int = (matrix-length m))
(_matrix = m)
(_int = 1) ->
_double)))
;; (matrix-transpose m) : a new (cols m) x (rows m) matrix T with
;; T(j, i) = M(i, j).
(define (matrix-transpose m)
  (let* ((rows (matrix-rows m))
         (cols (matrix-cols m))
         (result (my-make-matrix cols rows)))
    (do-ec (:range i rows)
           (:range j cols)
           (matrix-set! result j i (matrix-ref m i j)))
    result))
;; (matrix-solve-many m b) : solve M * X = B for X (one solution column per
;; column of B) with LAPACK dgesv_.  It works on copies, so M and B are
;; untouched, and the copy of B that receives X is returned.
;; LDA and LDB are clamped to at least 1, as LAPACK requires LDA, LDB >= max(1, N).
;; NOTE(review): dgesv_'s INFO result is discarded, so a singular M is not
;; reported.
(define matrix-solve-many
  (get-ffi-obj 'dgesv_ *lapack*
    (_fun (m b) ::
          ((_ptr i _int) = (matrix-rows m))          ; N
          ((_ptr i _int) = (matrix-cols b))          ; NRHS
          (_matrix = (matrix-copy m))                ; A (scratch copy)
          ((_ptr i _int) = (max 1 (matrix-rows m)))  ; LDA
          (_u32vector o (matrix-rows m))             ; IPIV (discarded)
          (x : _matrix = (matrix-copy b))            ; B in, X out
          ((_ptr i _int) = (max 1 (matrix-rows b)))  ; LDB
          (_ptr o _int) ->                           ; INFO (ignored)
          _void ->
          x)))
;; (matrix-solve m v) : solve M * x = V for the vector x with LAPACK dgesv_.
;; M is copied and V is copied into a fresh f64vector; that copy is
;; overwritten with x and returned.
;; LDA and LDB are clamped to at least 1, as LAPACK requires LDA, LDB >= max(1, N).
;; NOTE(review): dgesv_'s INFO result is discarded, so a singular M is not
;; reported.
(define matrix-solve
  (get-ffi-obj 'dgesv_ *lapack*
    (_fun (m v) ::
          ((_ptr i _int) = (matrix-rows m))              ; N
          ((_ptr i _int) = 1)                            ; NRHS
          (_matrix = (matrix-copy m))                    ; A (scratch copy)
          ((_ptr i _int) = (max 1 (matrix-rows m)))      ; LDA
          (_u32vector o (matrix-rows m))                 ; IPIV (discarded)
          ;; B: a copy of V, overwritten with the solution.
          (x : _f64vector = (f64vector-of-length-ec (f64vector-length v) (:f64vector x v) x))
          ((_ptr i _int) = (max 1 (f64vector-length v))) ; LDB
          (_ptr o _int) ->                               ; INFO (ignored)
          _void ->
          x)))
;; (matrix-identity n) : a fresh N x N identity matrix.
;; The two-argument my-make-matrix already zero-fills its buffer, so only the
;; diagonal is written.  The three-argument form used to fill every entry a
;; second time, a redundant O(n^2) pass.
(define (matrix-identity n)
  (let ((m (my-make-matrix n n)))
    (do-ec (:range i n) (matrix-set! m i i 1.0))
    m))
;; (matrix-ec rows cols qualifier ... expr) : SRFI-42 comprehension building a
;; ROWS x COLS matrix.  The values produced by the comprehension fill the
;; buffer in column-major order.  ROWS and COLS are evaluated once, before the
;; comprehension runs.
(define-syntax matrix-ec
  (syntax-rules ()
    ((matrix-ec row-count col-count clause ...)
     (apply my-matrix row-count col-count (list-ec clause ...)))))
;; (:matrix var m) / (:matrix var (index i j) m) : SRFI-42 generator binding
;; VAR to each entry of matrix M in column-major order (i varies fastest).
;; The second form also binds I and J to the entry's row and column.
;; The state (i, j) walks down a column and wraps to row 0 of the next column.
;; Because i always wraps back to 0, the column test (< j cols) is what ends
;; the loop.  The row test stops it immediately for a matrix with no rows.
;; Testing only (< i rows) would read past the last column.
(define-syntax :matrix
  (syntax-rules (index)
    ((:matrix cc var arg)
     (:matrix cc var (index i j) arg))
    ((:matrix cc var (index i j) arg)
     (:do cc
          (let ((m arg)
                (rows #f)
                (cols #f))
            (set! rows (matrix-rows m))
            (set! cols (matrix-cols m)))
          ((i 0) (j 0))
          (and (< i rows) (< j cols))
          (let ((i+1 (+ i 1))
                (j+1 (+ j 1))
                (wrapping? #f)
                (var (matrix-ref m i j)))
            ;; Wrap to the next column after the last row.
            (set! wrapping? (>= i+1 rows)))
          #t
          ((if wrapping? 0 i+1)
           (if wrapping? j+1 j))))))
;; (ptr->matrix ptr rows cols) : wrap an existing C pointer as a matrix without
;; any checks (it is the raw structure constructor).  It is exported only
;; through the unsafe interface (provide* (unsafe ...)).  Presumably clients
;; gain access via the matrix-unsafe! form created by define-unsafer -- confirm
;; against the foreign.ss documentation.
;; The final parenthesis below closes the enclosing module.
(define ptr->matrix make-matrix)
(provide* (unsafe ptr->matrix))
(define-unsafer matrix-unsafe!))