(module discrete mzscheme
;; Record type underlying every discrete distribution.  It has five fields,
;; read by index through discrete-field-ref (see the discrete-n .. discrete-c
;; accessors): 0 = n (number of outcomes), 1 = a (alias table),
;; 2 = f (acceptance thresholds), 3 = p (probabilities), 4 = c (cumulative
;; probabilities).
(define-values (struct:discrete
                discrete-constructor
                discrete?
                discrete-field-ref
                set-discrete-field!)
  (let ((type-name 'discrete)
        (super-type #f)        ; no parent struct type
        (field-count 5)        ; n, a, f, p, c
        (auto-field-count 0))  ; every field is supplied to the constructor
    (make-struct-type type-name super-type field-count auto-field-count)))
(require (lib "contract.ss"))
;; Public interface of the module, with run-time contracts.
(provide/contract
 ;; Construction from a vector of non-negative weights, and recognition.
 (make-discrete
  (-> (vectorof (>=/c 0.0)) discrete?))
 (discrete?
  (-> any/c boolean?))
 ;; Probability mass / cumulative probability at an integer index.
 (discrete-pdf
  (-> discrete? integer? (real-in 0.0 1.0)))
 (discrete-cdf
  (-> discrete? integer? (real-in 0.0 1.0)))
 ;; Sampling: with an explicit random source, or with just the distribution.
 (random-discrete
  (case-> (-> random-source? discrete? natural-number/c)
          (-> discrete? natural-number/c))))
(require "../random-source.ss")
;; Field accessors for the discrete record type, one per field index:
;;   discrete-n -> field 0, number of outcomes
;;   discrete-a -> field 1, alias table
;;   discrete-f -> field 2, acceptance thresholds
;;   discrete-p -> field 3, probability of each outcome
;;   discrete-c -> field 4, cumulative probability of each outcome
(define-values (discrete-n discrete-a discrete-f discrete-p discrete-c)
  (let ((field-accessor
         (lambda (index field-name)
           (make-struct-field-accessor discrete-field-ref index field-name))))
    (values (field-accessor 0 'n)
            (field-accessor 1 'a)
            (field-accessor 2 'f)
            (field-accessor 3 'p)
            (field-accessor 4 'c))))
;; make-discrete : (vectorof (>=/c 0.0)) -> discrete
;; Builds a distribution over the indices 0 .. n-1 of the weight vector W,
;; in which index i has probability (vector-ref w i) / (sum of all weights).
;;
;; Besides the normalized probabilities (field p) and their running sums
;; (field c), it precomputes alias-method tables (fields a and f) so that
;; random-discrete can draw a sample from one uniform number.  The tables
;; are built by pairing each under-full column ("small", mass below 1/n)
;; with an over-full donor ("big"), in the style of Vose's alias method.
;;
;; Raises an error when W is empty or when every weight is zero, since no
;; distribution can be formed in either case.
(define (make-discrete w)
  (let ((n (vector-length w)))
    ;; Must be checked before computing (/ 1.0 n) below.
    (when (zero? n)
      (error 'make-discrete "expected a non-empty vector of weights"))
    (let ((a (make-vector n))     ; alias: outcome returned when the column's
                                  ; own outcome is rejected
          (f (make-vector n))     ; acceptance thresholds (rescaled at the end)
          (e (make-vector n))     ; scratch copy of probabilities, rebalanced
          (p (make-vector n))     ; normalized probabilities (pdf)
          (c (make-vector n))     ; cumulative probabilities (cdf)
          (mean (/ 1.0 n))        ; mass of a perfectly balanced column
          (sum 0.0)
          (cumm 0.0)
          (smalls '())            ; indices whose mass is below mean
          (bigs '()))             ; indices whose mass is at or above mean
      ;; Total weight.
      (do ((i 0 (+ i 1)))
          ((= i n))
        (set! sum (+ sum (vector-ref w i))))
      ;; All-zero weights would make every normalized weight 0/0.
      (unless (> sum 0.0)
        (error 'make-discrete "expected at least one positive weight"))
      ;; Normalize, accumulate the cdf, and classify each index.
      (do ((i 0 (+ i 1)))
          ((= i n))
        (let ((q (/ (vector-ref w i) sum)))
          (vector-set! e i q)
          (vector-set! p i q)
          (set! cumm (+ cumm q))
          ;; Floating-point rounding can push the running sum slightly above
          ;; 1.0, which would violate discrete-cdf's (real-in 0.0 1.0) result.
          (vector-set! c i (min cumm 1.0))
          (if (< q mean)
              (set! smalls (cons i smalls))
              (set! bigs (cons i bigs)))))
      ;; Top up each small column from a big donor.  The small column keeps
      ;; probability n*e[s] of yielding itself and otherwise yields the donor;
      ;; the donor gives up exactly the small's deficit and is re-classified.
      (let loop ()
        (unless (null? smalls)
          (let ((s (car smalls)))
            (set! smalls (cdr smalls))
            (if (null? bigs)
                ;; No donor left.  In exact arithmetic the smalls' deficit
                ;; equals the bigs' surplus, so this presumably only happens
                ;; through rounding; treat the column as full.
                (begin
                  (vector-set! a s s)
                  (vector-set! f s 1.0))
                (let ((b (car bigs))
                      (deficit (- mean (vector-ref e s))))
                  (set! bigs (cdr bigs))
                  (vector-set! a s b)
                  (vector-set! f s (* n (vector-ref e s)))
                  (vector-set! e s (+ (vector-ref e s) deficit))
                  (vector-set! e b (- (vector-ref e b) deficit))
                  (cond ((< (vector-ref e b) mean)
                         (set! smalls (cons b smalls)))
                        ((> (vector-ref e b) mean)
                         (set! bigs (cons b bigs)))
                        (else
                         (vector-set! a b b)
                         (vector-set! f b 1.0))))))
          (loop)))
      ;; Whatever is still big is (numerically) full: always keep itself.
      (for-each (lambda (b)
                  (vector-set! a b b)
                  (vector-set! f b 1.0))
                bigs)
      ;; Fold the column index into each threshold: f[i] := (f[i] + i) / n.
      ;; random-discrete then compares its uniform draw U directly against
      ;; f[floor(U*n)], reusing U both to pick the column and to decide
      ;; between the column's own outcome and its alias.
      (do ((i 0 (+ i 1)))
          ((= i n))
        (vector-set! f i (/ (+ (vector-ref f i) i) n)))
      (discrete-constructor n a f p c))))
;; random-discrete : [random-source] discrete -> natural
;; Draws one outcome index from D using random source R, or using
;; (current-random-source) when R is omitted.
;;
;; One uniform draw U picks the alias-table column floor(U * n).  The
;; column's threshold already has the column index folded in (it was stored
;; as (f + column) / n by make-discrete), so U is compared against it
;; directly.  Below the threshold, the column's own index is returned;
;; otherwise its alias is returned.
(define random-discrete
  (case-lambda
    ((d)
     (random-discrete (current-random-source) d))
    ((r d)
     (let* ((u (random-uniform r))
            (column (inexact->exact (floor (* u (discrete-n d)))))
            (threshold (vector-ref (discrete-f d) column)))
       (if (or (= threshold 1.0) (< u threshold))
           column
           (vector-ref (discrete-a d) column))))))
;; discrete-pdf : discrete integer -> (real-in 0.0 1.0)
;; Probability of outcome K under D: the normalized weight stored at index K,
;; or 0.0 for any K outside 0 .. n-1.
(define (discrete-pdf d k)
  (let* ((p (discrete-p d))
         (n (vector-length p)))
    (if (or (< k 0) (>= k n))
        0.0
        ;; The contract admits inexact integers such as 2.0, but vector-ref
        ;; only accepts exact indices.
        (vector-ref p (inexact->exact k)))))
;; discrete-cdf : discrete integer -> (real-in 0.0 1.0)
;; Probability that an outcome of D is <= K: 0.0 below the first index,
;; the stored cumulative probability for K in 0 .. n-1, and 1.0 from n on.
(define (discrete-cdf d k)
  (let* ((c (discrete-c d))
         (n (vector-length c)))
    (cond ((< k 0) 0.0)
          ;; n-1 is the last valid index, so K = n must be caught here;
          ;; it would otherwise fall through to an out-of-range vector-ref.
          ((>= k n) 1.0)
          ;; inexact->exact: the contract admits inexact integers such as
          ;; 2.0, but vector-ref only accepts exact indices.
          (else (vector-ref c (inexact->exact k))))))
)