#lang scheme/unit
(require "vec-sig.ss"
"ring-sig.ss"
)
(import ring^)
(export vec^)
;; Matrices wrap a list of rows, each row being a list of entries.
;; list->mat / mat->list / mat? are the public names for the struct's
;; constructor, accessor and predicate.
(define-struct mat_ (lst))
(define-values (list->mat mat->list mat?)
  (values make-mat_ mat_-lst mat_?))
;; Transpose of m: row i of the result is column i of m.
;; A matrix with no rows is returned as a matrix with no rows.  Without this
;; guard, (apply map list '()) would call map with no list arguments and
;; raise an arity error.
(define (transpose m)
  (let ((row-lists (mat->list m)))
    (if (null? row-lists)
        (list->mat '())
        (list->mat (apply map list row-lists)))))
;; Apply fn entry-wise across one or more matrices of the same shape.
;; Corresponding rows are paired first, then corresponding entries within
;; those rows; map raises an error if the shapes disagree.
(define (mat-map fn . ms)
  (define (map-row . row-entries)
    (apply map fn row-entries))
  (list->mat (apply map map-row (map mat->list ms))))
;; Entry-wise sum of two matrices using the ring's add.
;; Mismatched shapes are rejected by the underlying map calls.
(define mat-add
  (lambda (a b)
    (mat-map add a b)))
;; Matrix product a*b over the ring.
;; Entry (i, j) is the ring sum of pairwise products of row i of a with
;; column j of b.  The columns of b are obtained by transposing it once, up front.
(define (mat-mul a b)
  (define rows-of-a (mat->list a))
  (define columns-of-b (mat->list (transpose b)))
  (define (dot row column)
    (sum-list (map mul row column)))
  (list->mat
   (map (lambda (row)
          (map (lambda (column) (dot row column))
               columns-of-b))
        rows-of-a)))
;; Gauss-Jordan elimination with partial pivoting over the imported ring.
;;   abs : maps a matrix entry to a magnitude, used only to choose pivots.
;;   >   : comparator on those magnitudes.  Rows are sorted with it, so
;;         (assuming > is a strict "greater than") the row whose leading
;;         entry has the largest magnitude becomes the pivot.
;;   m   : the matrix to reduce.
;; Each step does three things: it normalizes one pivot row, clears that
;; column from every other row, and drops the column.  After all rows are
;; processed, the first (number of rows) columns are gone.  So for an
;; augmented matrix [A | B] with A square and invertible, the result is
;; A^-1 B, with one row per column of A, in order.
;; NOTE(review): there is no singularity check.  A zero pivot is passed
;; straight to `inv`, and what happens then depends on the ring^
;; implementation.  A matrix with more rows than columns fails when `car`
;; reaches an empty row.
(define (mat-gauss-jordan abs > m)
  ;; Reorder rows so the one whose leading entry has the greatest magnitude
  ;; (per abs and >) comes first.  Racket's sort is stable, so ties keep
  ;; their relative order.
  (define (ll-pivot ll)
    (sort ll > #:key
          (lambda (row) (abs (car row)))))
  ;; Eliminate the leading column using pivot-row.  Returns three values:
  ;;   1. top-rows with that column cleared and dropped;
  ;;   2. the pivot row's tail scaled by the inverse of its leading entry;
  ;;   3. bottom-rows with that column cleared and dropped.
  (define (ll-elim-column top-rows pivot-row bottom-rows)
    (let ((inv-p (inv (car pivot-row))))
      ;; Multiply by the inverse of the pivot, i.e. "divide by the pivot".
      (define (/p x) (mul x inv-p))
      (let ((1-row (map /p (cdr pivot-row))))
        ;; Computes row - head * 1-row, without the leading entry (which
        ;; this subtraction would make zero).
        (define (->zero row)
          (let ((head (car row)))
            (let ((_h (neg head)))
              (map (lambda (a b)
                     (add a (mul b _h)))
                   (cdr row)
                   1-row))))
        (values
         (map ->zero top-rows) 1-row
         (map ->zero bottom-rows)))))
  ;; Process one column per row.  r-top holds the already-reduced rows in
  ;; reverse order (most recent pivot first) and is reversed at the end.
  (define (ll-elim-all ll)
    (let next ((r-top '())
               (bottom ll))
      (if (null? bottom) (reverse r-top)
          (let ((bottom/p (ll-pivot bottom)))
            (let-values (((t p b)
                          (ll-elim-column r-top
                                          (car bottom/p)
                                          (cdr bottom/p))))
              (next (cons p t) b))))))
  (list->mat (ll-elim-all (mat->list m))))
;; (for/mat ((row col) (nrows ncols)) body ...) builds an nrows x ncols
;; matrix.  The entry at (row, col) is the value of the body, with row and
;; col bound to the 0-based indices.  ncols is re-evaluated for every row.
(define-syntax-rule (for/mat ((row col) (nrows ncols)) body ...)
  (list->mat
   (for/list ((row (in-range nrows)))
     (for/list ((col (in-range ncols)))
       body ...))))
;; Returns a procedure that takes a size and builds the size x size matrix
;; with num on the diagonal and the ring's zero everywhere else.
(define (number->mat num)
  (lambda (size)
    (for/mat ((r c) (size size))
      (cond ((= r c) num)
            (else zero)))))
;; Square diagonal matrix with the elements of lst on the diagonal (in order)
;; and the ring's zero everywhere else.  Row i pairs element i of lst with
;; its index, so the list is walked once rather than indexed per row.
(define (list->mat-diag lst)
  (let ((n (length lst)))
    (list->mat
     (for/list ((d lst) (i (in-naturals)))
       (for/list ((j (in-range 0 n)))
         (if (= i j) d zero))))))
;; Size-indexed constant matrices built by number->mat:
;; (mat-one n) is the n x n identity and (mat-zero n) is the n x n zero matrix.
(define-values (mat-one mat-zero)
  (values (number->mat one)
          (number->mat zero)))
;; Stack matrices vertically: all rows of the first matrix, then all rows of
;; the second, and so on.
(define (mat-cat-rows . ms)
  (list->mat
   (for*/list ((m ms)
               (row (mat->list m)))
     row)))
;; Place matrices side by side.  This is done by stacking their transposes
;; vertically and transposing the result back.
(define (mat-cat-columns . ms)
  (let ((transposed (map transpose ms)))
    (transpose (apply mat-cat-rows transposed))))
;; Number of rows of m.
(define (mat-nb-rows m) (length (mat->list m)))
;; Number of columns of m, read from its first row.  A matrix with no rows has
;; no first row, so it is reported as having 0 columns instead of failing
;; on (car '()).
(define (mat-nb-columns m)
  (let ((row-lists (mat->list m)))
    (if (null? row-lists)
        0
        (length (car row-lists)))))
;; Vectors wrap a plain list of entries.
;; list->vec / vec->list / vec? are the public names for the struct's
;; constructor, accessor and predicate.
(define-struct vec_ (lst))
(define-values (list->vec vec->list vec?)
  (values make-vec_ vec_-lst vec_?))
;; Dimension (number of entries) of a vector.
(define (vec-dim vec)
  (length (vec->list vec)))
;; Apply fn entry-wise across one or more vectors, producing a new vector.
(define (vec-map fn . vs)
  (let ((entry-lists (map vec->list vs)))
    (list->vec (apply map fn entry-lists))))
;; A vector as a 1 x n matrix: a single row holding its entries.
(define (vec->row vec)
  (let ((entries (vec->list vec)))
    (list->mat (list entries))))
;; A vector as an n x 1 matrix: each entry wrapped in its own one-element row.
;; This equals the transpose of vec->row.
(define (vec->column vec)
  (list->mat (map list (vec->list vec))))
;; Rows of m as a list of Racket vectors.  These are built with
;; list->vector, not the vec type defined in this unit.
;; NOTE(review): if vec values were intended, list->vec is the matching
;; constructor -- confirm what callers expect before changing.
(define (rows m)
  (for/list ((r (mat->list m)))
    (list->vector r)))
;; Columns of m, in the same representation as rows.
(define (columns m)
  (map list->vector (mat->list (transpose m))))
;; Inner (dot) product of two vectors: the ring sum of their pairwise products.
;; It is computed directly rather than as a 1 x n by n x 1 matrix product.
;; That approach built and transposed two throw-away matrices and raised an
;; error for empty vectors.  For empty vectors the result is (sum-list '()),
;; presumably the ring's zero -- confirm against ring^.
;; Vectors of different dimensions still raise an error (from map).
(define (inner-product v1 v2)
  (sum-list (map mul (vec->list v1) (vec->list v2))))