#lang racket
(require "coord.rkt")
(provide
matrix?
matrix-vals
m-lines
m-cols
m-frame
m-cs
m-n
m-rotation
m-scaling
m-translation
m-line
m-column
m-translation-c
m-neg
+m
-m
*m
m*p
m-transform
m-zero
m-identity
list<-matrix
vector-lines<-matrix
vector-cols<-matrix)
;; Dot product of two numeric vectors: the sum of their element-wise
;; products, starting from exact 0 (so two empty vectors give 0).
;; vector-map requires equal lengths, so mismatched inputs raise an error.
(define (v*v v1 v2)
  (for/fold ((total 0))
            ((product (in-vector (vector-map * v1 v2))))
    (+ product total)))
;; Custom printer for matrix values (installed through prop:custom-write).
;; Writes the 16 entries as four space-separated rows of four: the first
;; row is prefixed with "m(", the others with " ", rows are separated by
;; newlines, and a ")" closes the output.  The print mode is ignored.
(define (matrix-write m port mode)
  (define vals (matrix-vals m))
  ;; Row r of the matrix rendered with ~a and joined by single spaces.
  (define (row-text r)
    (string-join
     (for/list ((e (in-vector vals (* r 4) (+ (* r 4) 4))))
       (format "~a" e))
     " "))
  (for ((r (in-range 4)))
    (write-string (if (zero? r) "m(" " ") port)
    (write-string (row-text r) port)
    (unless (= r 3)
      (newline port)))
  (write-string ")" port))
;; A 4x4 matrix. `vals` is a flat vector of its entries in row-major order
;; (entries 0-3 form the first row, 4-7 the second, ...); every constructor
;; in this module builds a 16-entry vector.  Instances are printed by
;; matrix-write.
(struct matrix (vals)
#:property prop:custom-write matrix-write)
;; Affine matrix from three row vectors: entries 0-3 of v1, v2 and v3
;; become rows 0, 1 and 2 respectively; row 3 is fixed to 0 0 0 1.
(define (m-lines v1 v2 v3)
  ;; Entries 0..3 of a row vector, read in index order.
  (define (first-four row)
    (for/list ((k (in-range 4)))
      (vector-ref row k)))
  (matrix
   (list->vector
    (append (first-four v1)
            (first-four v2)
            (first-four v3)
            '(0 0 0 1)))))
;; Affine matrix built from columns.
;; - Four arguments: each is a column vector; only its entries 0-2
;;   (x, y, z) are read.
;; - One argument: a single vector read in column-major groups of four
;;   (entries 0-2, 4-6, 8-10, 12-14); entries 3, 7, 11 and 15 are ignored.
;; In both cases the last row of the result is 0 0 0 1.
(define m-cols
  (let ()
    ;; Entries base, base+1, base+2 of vec, read in that order.
    (define (column-xyz vec base)
      (list (vector-ref vec base)
            (vector-ref vec (+ base 1))
            (vector-ref vec (+ base 2))))
    ;; Transpose four (x y z) columns into three rows, then append 0 0 0 1.
    (define (from-columns cols)
      (matrix
       (list->vector
        (append (append* (apply map list cols))
                (list 0 0 0 1)))))
    (case-lambda
      ((v1 v2 v3 v4)
       (from-columns (list (column-xyz v1 0)
                           (column-xyz v2 0)
                           (column-xyz v3 0)
                           (column-xyz v4 0))))
      ((v)
       (from-columns (for/list ((base (in-list '(0 4 8 12))))
                       (column-xyz v base)))))))
;; Frame matrix with origin c.  The X axis is (norm x), the Z axis is
;; (norm (cross-c x y)), and the Y axis is (norm (cross-c Z X)).  These are
;; placed as columns 0-2 and c as column 3, via m-cols.
;; (cross-c / norm come from coord.rkt; presumably cross product and
;; normalization — confirm there.)
(define (m-frame c x y)
  (let* ((x-axis (norm x))
         (z-axis (norm (cross-c x y)))
         (y-axis (norm (cross-c z-axis x-axis))))
    (apply m-cols
           (map vector-of-coord (list x-axis y-axis z-axis c)))))
;; Like m-frame, but x and y are points: they are first turned into
;; directions relative to the origin c (via -c) before building the frame.
(define (m-cs c x y)
  (let ((x-dir (-c x c))
        (y-dir (-c y c)))
    (m-frame c x-dir y-dir)))
;; Frame matrix with origin c whose Z axis is (norm n).  The X axis is
;; derived from Z with collinear-cross-c (from coord.rkt), and Y is
;; (norm (cross-c Z X)).  Columns are X, Y, Z and then c.
(define (m-n c n)
  (let* ((z-axis (norm n))
         (x-axis (norm (collinear-cross-c z-axis)))
         (y-axis (norm (cross-c z-axis x-axis))))
    (apply m-cols
           (map vector-of-coord (list x-axis y-axis z-axis c)))))
;; Rotation matrix for angle a about axis n, using the axis-angle
;; (Rodrigues) formula on the normalized axis (norm of n's components).
;; When a is 0, or n equals u0 (via eq-c), the shared m-identity value is
;; returned instead.  The translation column and last row are those of the
;; identity.
(define (m-rotation a n)
  ;; Builds the matrix from the three components of the axis.
  (define (axis-angle-matrix ax ay az)
    (define cs (cos a))
    (define sn (sin a))
    (define one-minus-cos (- 1 cs))
    (define unit (norm (xyz ax ay az)))
    (define ux (xyz-x unit))
    (define uy (xyz-y unit))
    (define uz (xyz-z unit))
    (matrix
     (vector (+ (* one-minus-cos (^2 ux)) cs)
             (- (* one-minus-cos ux uy) (* sn uz))
             (+ (* one-minus-cos ux uz) (* sn uy))
             0
             (+ (* one-minus-cos ux uy) (* sn uz))
             (+ (* one-minus-cos (^2 uy)) cs)
             (- (* one-minus-cos uy uz) (* sn ux))
             0
             (- (* one-minus-cos ux uz) (* sn uy))
             (+ (* one-minus-cos uy uz) (* sn ux))
             (+ (* one-minus-cos (^2 uz)) cs)
             0
             0 0 0 1)))
  (cond
    ((or (= a 0) (eq-c n u0)) m-identity)
    (else (apply axis-angle-matrix (list-of-coord n)))))
;; Diagonal scaling matrix.  Takes either three factors (x y z) or a single
;; coordinate whose xyz-x / xyz-y / xyz-z components are used as the factors.
(define m-scaling
  (case-lambda
    ((c)
     (let ((sx (xyz-x c))
           (sy (xyz-y c))
           (sz (xyz-z c)))
       (m-scaling sx sy sz)))
    ((sx sy sz)
     (matrix
      (vector sx 0  0  0
              0  sy 0  0
              0  0  sz 0
              0  0  0  1)))))
;; Translation matrix: identity with the offsets in the last column.  Takes
;; either three offsets (x y z) or a single coordinate whose components are
;; used as the offsets.
(define m-translation
  (case-lambda
    ((c)
     (let ((dx (xyz-x c))
           (dy (xyz-y c))
           (dz (xyz-z c)))
       (m-translation dx dy dz)))
    ((dx dy dz)
     (matrix
      (vector 1 0 0 dx
              0 1 0 dy
              0 0 1 dz
              0 0 0 1)))))
;; Fresh 4-element vector holding row l (0-based) of m.
(define (m-line m l)
  (let* ((vals (matrix-vals m))
         (start (* l 4))
         (end (+ start 4)))
    (vector-copy vals start end)))
;; Fresh 4-element vector holding column c (0-based) of m, top to bottom.
(define (m-column m c)
  (define vals (matrix-vals m))
  (for/vector #:length 4 ((row (in-range 4)))
    (vector-ref vals (+ c (* row 4)))))
;; Translation part of m as a coordinate: the first three entries of the
;; last column (row-major indices 3, 7 and 11).
(define (m-translation-c m)
  (define vals (matrix-vals m))
  (xyz (vector-ref vals 3)
       (vector-ref vals 7)
       (vector-ref vals 11)))
;; Entry-wise negation of m (all 16 entries, including the last row).
(define (m-neg m)
  (matrix
   (for/vector ((e (in-vector (matrix-vals m))))
     (- e))))
;; Entry-wise sum of two matrices (all 16 entries, including the last row).
(define (+m m1 m2)
  (let ((a (matrix-vals m1))
        (b (matrix-vals m2)))
    (matrix (vector-map + a b))))
;; Entry-wise difference m1 - m2, computed as m1 plus the negation of m2.
(define (-m m1 m2)
  (matrix
   (vector-map (λ (a b) (+ a (- b)))
               (matrix-vals m1)
               (matrix-vals m2))))
;; Matrix product m1 · m2: entry (i, j) is the dot product of row i of m1
;; with column j of m2, stored in row-major order.
(define (*m m1 m2)
  (matrix
   (for*/vector #:length 16 ((i (in-range 4))
                             (row (in-value (m-line m1 i)))
                             (j (in-range 4)))
     (v*v row (m-column m2 j)))))
;; Applies m to the coordinate p, treated as the homogeneous column vector
;; (x y z w).  w is 1 by default, so the translation column applies; it is
;; 0 when the optional third argument is true, so translation is ignored.
;; Only rows 0-2 of m are used; the result is an xyz coordinate.
;; (The optional argument is renamed from `vector?` so that the builtin
;; vector? predicate is not shadowed; it is positional, so callers are
;; unaffected.)
(define (m*p m p (direction? #f))
  (define homogeneous
    (vector (xyz-x p) (xyz-y p) (xyz-z p) (if direction? 0 1)))
  (define (row-dot r)
    (v*v (m-line m r) homogeneous))
  (xyz (row-dot 0) (row-dot 1) (row-dot 2)))
;; Composite transform built from a translation t, a rotation by angle a
;; about axis n, and per-axis scale factors s, composed as (S · R) · T.
;; NOTE(review): m*p multiplies column vectors, so this order translates a
;; point first and then rotates and scales it (the translation itself gets
;; scaled and rotated).  The common T · R · S order would leave t intact.
;; Confirm that callers expect S · R · T before changing anything.
(define (m-transform t a n s)
  (let ((scale (m-scaling (xyz-x s) (xyz-y s) (xyz-z s)))
        (rotation (m-rotation a n))
        (translation (m-translation (xyz-x t) (xyz-y t) (xyz-z t))))
    (*m (*m scale rotation) translation)))
;; The 4x4 matrix whose 16 entries are all 0 (including the last row).
(define m-zero
  (matrix (make-vector 16 0)))
;; The 4x4 identity matrix.  In row-major order the diagonal entries sit at
;; indices 0, 5, 10 and 15, i.e. the multiples of 5.
(define m-identity
  (matrix
   (build-vector 16 (λ (k) (if (= (remainder k 5) 0) 1 0)))))
;; The 16 entries of m as a list, in row-major order.
(define (list<-matrix m)
  (for/list ((e (in-vector (matrix-vals m))))
    e))
;; Fresh mutable vector with the entries of m in row-major order (row by
;; row).  Being a copy, mutating it does not affect m.
(define (vector-lines<-matrix m)
  (let ((vals (matrix-vals m)))
    (for/vector #:length (vector-length vals) ((e (in-vector vals)))
      e)))
;; Fresh vector with the 16 entries of m in column-major order: entries 0-3
;; are column 0 from top to bottom, 4-7 are column 1, and so on.  This is the
;; row-major layout of the transpose of m, the column counterpart of
;; vector-lines<-matrix.
;; All 16 entries are read.  The previous version hard-coded the last row
;; as 0 0 0 1, which misrepresented non-affine matrices such as m-zero or
;; the results of +m; for affine matrices the output is unchanged.
(define (vector-cols<-matrix m)
  (let ((vals (matrix-vals m)))
    (for*/vector #:length 16 ((col (in-range 4))
                              (row (in-range 4)))
      (vector-ref vals (+ (* row 4) col)))))