#lang typed/racket
(require "util.rkt")
(provide: [whirlpool (Bytes -> Exact-Nonnegative-Integer)])
(define-type Matrix Exact-Nonnegative-Integer)
(define rounds 10)
(define: (matrix-index [i : Byte] [j : Byte]) : Integer
(- 512
(+ (* 64 i) (* 8 j))
8))
(define: (matrix-ref [m : Matrix]
[i : Byte]
[j : Byte])
: Byte
(bitwise-and (arithmetic-shift m (- (matrix-index i j)))
#xff))
(define: (matrix-set [m : Matrix]
[i : Byte]
[j : Byte]
[v : Byte])
: Matrix
(bitwise-ior (bitwise-and m
(bitwise-not
(arithmetic-shift #xff (matrix-index i j))))
(arithmetic-shift v (matrix-index i j))))
(define-syntax first-2value
(syntax-rules ()
[(_ e) (let-values ([(x y) e]) x)]))
(define: (matrix-map [f : (Byte -> Byte)] [m : Matrix])
: Matrix
(make-matrix
(λ (i j)
(f (matrix-ref m i j)))))
(define: (make-matrix [proc : (Byte Byte -> Byte)]) : Matrix
(for*/fold: ([m : Matrix 0])
([i : Index (in-range 0 8)]
[j : Index (in-range 0 8)])
(bitwise-ior (proc (assert i byte?) (assert j byte?))
(arithmetic-shift m 8))))
(define: (print-matrix [m : Matrix]) : Void
(printf "~a" (format-matrix m)))
(define: (format-matrix [m : Matrix]) : String
(with-output-to-string
(λ ()
(for ([i (in-range 0 8)])
(printf "| ")
(for ([j (in-range 0 8)])
(let ([b (matrix-ref m (assert i byte?) (assert j byte?))])
(if (> b #xf)
(printf "~x " b)
(printf " ~x " b))))
(printf "|~%"))
(printf "~%"))))
(define: (reverse-minibox [box : (Vectorof Byte)])
: (Vectorof Byte)
(let: ([antibox : (Vectorof Byte) (make-vector 16 0)])
(for: ([i : Index (in-range 0 16)])
(vector-set! antibox (vector-ref box i) (assert i byte?)))
(vector->immutable-vector antibox)))
(define: C-box : (Vector Byte Byte Byte Byte Byte Byte Byte Byte)
#(#x1 #x1 #x4 #x1 #x8 #x5 #x2 #x9))
(define: S-box : (Vectorof Byte)
#(#x18 #x23 #xc6 #xE8 #x87 #xB8 #x01 #x4F #x36 #xA6 #xd2 #xF5 #x79 #x6F #x91 #x52 #x60 #xBc #x9B #x8E #xA3 #x0c #x7B #x35 #x1d #xE0 #xd7 #xc2 #x2E #x4B #xFE #x57 #x15 #x77 #x37 #xE5 #x9F #xF0 #x4A #xdA #x58 #xc9 #x29 #x0A #xB1 #xA0 #x6B #x85 #xBd #x5d #x10 #xF4 #xcB #x3E #x05 #x67 #xE4 #x27 #x41 #x8B #xA7 #x7d #x95 #xd8 #xFB #xEE #x7c #x66 #xdd #x17 #x47 #x9E #xcA #x2d #xBF #x07 #xAd #x5A #x83 #x33 #x63 #x02 #xAA #x71 #xc8 #x19 #x49 #xd9 #xF2 #xE3 #x5B #x88 #x9A #x26 #x32 #xB0 #xE9 #x0F #xd5 #x80 #xBE #xcd #x34 #x48 #xFF #x7A #x90 #x5F #x20 #x68 #x1A #xAE #xB4 #x54 #x93 #x22 #x64 #xF1 #x73 #x12 #x40 #x08 #xc3 #xEc #xdB #xA1 #x8d #x3d #x97 #x00 #xcF #x2B #x76 #x82 #xd6 #x1B #xB5 #xAF #x6A #x50 #x45 #xF3 #x30 #xEF #x3F #x55 #xA2 #xEA #x65 #xBA #x2F #xc0 #xdE #x1c #xFd #x4d #x92 #x75 #x06 #x8A #xB2 #xE6 #x0E #x1F #x62 #xd4 #xA8 #x96 #xF9 #xc5 #x25 #x59 #x84 #x72 #x39 #x4c #x5E #x78 #x38 #x8c #xd1 #xA5 #xE2 #x61 #xB3 #x21 #x9c #x1E #x43 #xc7 #xFc #x04 #x51 #x99 #x6d #x0d #xFA #xdF #x7E #x24 #x3B #xAB #xcE #x11 #x8F #x4E #xB7 #xEB #x3c #x81 #x94 #xF7 #xB9 #x13 #x2c #xd3 #xE7 #x6E #xc4 #x03 #x56 #x44 #x7F #xA9 #x2A #xBB #xc1 #x53 #xdc #x0B #x9d #x6c #x31 #x74 #xF6 #x46 #xAc #x89 #x14 #xE1 #x16 #x3A #x69 #x09 #x70 #xB6 #xd0 #xEd #xcc #x42 #x98 #xA4 #x28 #x5c #xF8 #x86))
(define: E-box : (Vectorof Byte)
#(#x1 #xB #x9 #xC #xD #x6 #xF #x3 #xE #x8 #x7 #x4 #xA #x2 #x5 #x0))
(define: R-box : (Vectorof Byte)
#(#x7 #xC #xB #xD #xE #x4 #x9 #xF #x6 #x3 #x8 #xA #x2 #x5 #x1 #x0))
(define: E-antibox : (Vectorof Byte)
(reverse-minibox E-box))
(define: R-antibox : (Vectorof Byte)
(reverse-minibox R-box))
(define: (C-ref [i : Byte] [j : Byte]) : Byte
(vector-ref C-box (modulo (+ j (- i)) 8)))
(define: C : Matrix
(make-matrix (λ (i j) (C-ref i j))))
(define: (γ [m : Matrix]) : Matrix
(matrix-map S m))
(define: (S [b : Byte]) : Byte
(vector-ref S-box b))
(define: gf2^8+ : (Byte Byte -> Byte)
bitwise-xor)
(define: (gf2^8* [a : Byte] [b : Byte]) : Byte
(let: loop : Byte ([a : Byte a]
[b : Byte b]
[p : Byte 0])
(if (or (zero? a) (zero? b))
p
(let ([a-shift (bitwise-and (arithmetic-shift a 1) #xff)])
(loop (if (bitwise-bit-set? a 7)
(bitwise-xor a-shift #b00011101)
a-shift)
(assert (arithmetic-shift b -1) byte?)
(if (bitwise-bit-set? b 0)
(bitwise-xor p a)
p))))))
(define: (θ [m : Matrix]) : Matrix
(make-matrix
(λ (i j)
(for/fold: ([sum : Byte 0])
([k : Index (in-range 0 8)])
(let ([k (assert k byte?)])
(gf2^8+ sum
(gf2^8* (matrix-ref m i k)
(matrix-ref C k j))))))))
(define: cr : (Vectorof Matrix)
(vector->immutable-vector
(ann
(list->vector
(ann (for/list: ([r : Integer (in-range 1 (add1 rounds))])
(make-matrix
(λ: ([i : Byte] [j : Byte])
(if (zero? i)
(S (assert (+ (* 8 (sub1 r)) j)
byte?))
0))))
(Listof Matrix)))
(Vectorof Matrix))))
(define: σ : (Matrix -> Matrix -> Matrix)
(curry bitwise-xor))
(define: (π [m : Matrix]) : Matrix
(make-matrix
(λ (i j)
(matrix-ref m
(modulo (- i j) 8)
j))))
(define: (ρ [k : Matrix]) : (Matrix -> Matrix)
(compose (σ k)
(compose θ
(compose π
γ))))
(define: (K [m : Matrix] [r : Exact-Nonnegative-Integer]) : Matrix
((ρ (vector-ref cr (sub1 r))) m))
(define: (W [m : Matrix]) : (Matrix -> Matrix)
((inst compose Matrix Matrix Matrix)
(let: loop : (Matrix -> Matrix)
([acc : (Matrix -> Matrix) identity]
[Kr : Matrix m]
[r : Exact-Nonnegative-Integer 0])
(if (>= r rounds)
acc
(let* ([next-r (add1 r)]
[next-Kr (K Kr next-r)])
(loop (compose (ρ next-Kr) acc)
next-Kr
next-r))))
(σ m)))
(define: (bytes->matrix [b : Bytes]) : Matrix
(for/fold: ([acc : Matrix 0])
([byte : Byte (in-bytes b)])
(+ (arithmetic-shift acc 8) byte)))
(define: (length->bytes [n : Exact-Nonnegative-Integer]) : Bytes
(let ([b (integer->bytes n 'big-endian)])
(bytes-append (make-bytes (- 32 (bytes-length b)) 0) b)))
(define: (pad-whirlpool-bytes [b : Bytes]) : Bytes
(let* ([missingno (modulo (- 32 (remainder (bytes-length b) 64))
64)]
[padding (cons #x80 (make-list (sub1 missingno) 0))]
[len (length->bytes (* 8 (bytes-length b)))])
(bytes-append b (list->bytes padding) len)))
(define: (bytes->message [b : Bytes]) : (Listof Matrix)
(let: ([pb : Bytes (pad-whirlpool-bytes b)])
(reverse
(for/fold: ([acc : (Listof Matrix) (list)])
([i : Exact-Nonnegative-Integer
(in-range 0 (quotient (bytes-length pb) 64))])
(cons (bytes->matrix (subbytes pb (* i 64) (* (add1 i) 64)))
acc)))))
(define: IV : Matrix
0)
(define: (H [message : (Listof Matrix)]) : Matrix
(for/fold: ([acc : Matrix IV])
([η : Matrix (in-list message)])
(bitwise-xor ((W acc) η)
acc
η)))
(define (whirlpool msg)
(H (bytes->message msg)))