(module random mzscheme
(require (lib "contract.ss")
(lib "etc.ss")
(lib "plt-match.ss")
(planet "combinators.ss" ("cce" "combinators.plt" 1 4))
(prefix schematics:
(planet "random.ss" ("schematics" "random.plt" 1 0))))
(require-for-syntax
(planet "syntax-utils.ss" ("cce" "syntax-utils.plt" 1 1)))
;; A generator wraps a procedure `proc` that is called with no arguments each
;; time the generator is run (see generate-one).  define-struct also supplies
;; make-generator, generator?, generator-proc and set-generator-proc!; only
;; generator? is exported below.
(define-struct generator (proc))
;; gen/c : contract -> contract
;; Accepts either a generator or a plain value satisfying T.  A generator is
;; accepted as-is: the values it will later produce are not checked here.
(define (gen/c T) (or/c generator? T))
;; Natural numbers, as defined by contract.ss's natural-number/c.
(define nat/c natural-number/c)
;; Positive integers.  Note that integer? also admits inexact integers
;; such as 2.0.
(define pos-int/c (and/c integer? positive?))
;; Probabilities: reals between 0 and 1 (real-in is inclusive at both ends).
(define prob/c (real-in 0 1))
;; Weights used by choose-weighted / choose-weighted*: strictly positive.
(define weight/c (>/c 0))
;; A procedure of any arity; its result is checked only against any/c.
(define fun/c (unconstrained-domain-> any/c))
;; A procedure of any arity whose result is a generator or a plain value.
(define gen-fun/c (unconstrained-domain-> (gen/c any/c)))
;; A flat, recursive contract for an alternating list
;; (weight gen weight gen ...), as taken by choose-weighted.
(define weighted-gens/c
(flat-rec-contract weighted-gens/c
null?
(cons/c weight/c (cons/c (gen/c any/c) weighted-gens/c))))
;; integer-above/c : integer -> flat-contract
;; Builds a flat contract that accepts exactly the integers n with n >= lo.
;; It is used by the ->r contracts below so the upper bound of a range is
;; checked against the lower bound actually supplied by the caller.
(define integer-above/c
  (lambda (lo)
    (let ([accepts? (lambda (n)
                      (and (integer? n)
                           (<= lo n)))])
      (flat-named-contract (format "an integer >= ~s" lo) accepts?))))
;; Exported procedures and their contracts.  Every choose-* procedure draws a
;; value immediately.  Its random-* counterpart takes the same arguments and
;; returns a generator that draws such a value each time it is run.
(provide/contract
;; Attempt limit used by `generate` when no explicit count is given.
[default-generate-attempts (parameter/c pos-int/c)]
[generator? (-> any/c boolean?)]
;; generate: a generator (or plain value), an optional acceptance
;; predicate, and an optional maximum number of attempts.
[generate (opt-> [(gen/c any/c)] [(-> any/c any/c) pos-int/c] any/c)]
[nonrandom (-> any/c generator?)]
;; Integer ranges.  The ->r contract requires hi >= lo, although the
;; implementation itself orders the bounds with min/max.
[choose-int-between (->r ([lo integer?]
[hi (integer-above/c lo)])
integer?)]
[random-int-between (->r ([lo integer?]
[hi (integer-above/c lo)])
generator?)]
;; Sizes: optional minimum and average.
[choose-size (opt-> [] [nat/c pos-int/c] nat/c)]
[random-size (opt-> [] [nat/c pos-int/c] generator?)]
;; Booleans: optional probability of #t.
[choose-boolean (opt-> [] [prob/c] boolean?)]
[random-boolean (opt-> [] [prob/c] generator?)]
;; Characters: optional source of character codes.
[choose-char (opt-> [] [(gen/c nat/c)] char?)]
[random-char (opt-> [] [(gen/c nat/c)] generator?)]
;; Homogeneous collections: a combining procedure (for the -group-of
;; forms), an element source, and an optional length source.
[choose-group-of (opt-> [fun/c (gen/c any/c)] [(gen/c nat/c)] any/c)]
[random-group-of (opt-> [fun/c (gen/c any/c)] [(gen/c nat/c)] generator?)]
[choose-list-of (opt-> [(gen/c any/c)] [(gen/c nat/c)] list?)]
[random-list-of (opt-> [(gen/c any/c)] [(gen/c nat/c)] generator?)]
[choose-vector-of (opt-> [(gen/c any/c)] [(gen/c nat/c)] vector?)]
[random-vector-of (opt-> [(gen/c any/c)] [(gen/c nat/c)] generator?)]
[choose-string (opt-> [] [(gen/c char?) (gen/c nat/c)] string?)]
[random-string (opt-> [] [(gen/c char?) (gen/c nat/c)] generator?)]
[choose-bytes (opt-> [] [(gen/c byte?) (gen/c nat/c)] bytes?)]
[random-bytes (opt-> [] [(gen/c byte?) (gen/c nat/c)] generator?)]
;; Variadic forms: (->* [required ...] rest-contract [result ...]).
;; Each extra argument is a generator or plain value.
[choose-apply (->* [fun/c] (listof (gen/c any/c)) [any/c])]
[random-apply (->* [fun/c] (listof (gen/c any/c)) [generator?])]
[choose-list (->* [] (listof (gen/c any/c)) [list?])]
[random-list (->* [] (listof (gen/c any/c)) [generator?])]
[choose-vector (->* [] (listof (gen/c any/c)) [vector?])]
[random-vector (->* [] (listof (gen/c any/c)) [generator?])]
;; Symbols: optional source of name strings.
[choose-symbol (opt-> [] [(gen/c string?)] symbol?)]
[random-symbol (opt-> [] [(gen/c string?)] generator?)]
;; Choice among alternatives.  Note that these contracts admit an empty
;; argument list / empty list of pairs.
[choose-uniform (->* [] (listof (gen/c any/c)) [any/c])]
[random-uniform (->* [] (listof (gen/c any/c)) [generator?])]
[choose-weighted (->* [] weighted-gens/c [any/c])]
[random-weighted (->* [] weighted-gens/c [generator?])]
[choose-weighted* (-> (listof (cons/c weight/c (gen/c any/c))) any/c)]
[random-weighted* (-> (listof (cons/c weight/c (gen/c any/c))) generator?)]
;; Memoized random functions built from a generator-producing procedure.
[choose-function (-> gen-fun/c fun/c)]
[random-function (-> gen-fun/c generator?)]
)
;; Syntactic forms are exported separately because provide/contract only
;; applies to values, not macros.
(provide let*-random
         define-generator
         random-recursive
         choose-recursive)
;; make-distribution : (any ... -> any) -> (any ... -> generator)
;; Lifts a choose-* procedure into its random-* counterpart.  The result
;; accepts the same arguments and returns a generator.  Each time that
;; generator is run, it applies the original procedure to those saved
;; arguments.
(define make-distribution
  (lambda (f)
    (lambda args
      (let ([thunk (lambda () (apply f args))])
        (make-generator thunk)))))
;; Number of draws `generate` makes before giving up when no explicit
;; count is supplied.
(define default-generate-attempts (make-parameter 100))
;; generate-one : (gen/c any) -> any
;; Runs a generator once by calling its procedure with no arguments.  Any
;; other value is returned unchanged, so plain values act as constant
;; generators.
(define (generate-one gen)
  (cond
    [(generator? gen)
     (let ([proc (generator-proc gen)])
       (proc))]
    [else gen]))
;; generate : (gen/c any) [(any -> any) pos-int] -> any
;; Draws from `gen` (see generate-one) until a value satisfies `pred`, which
;; by default accepts everything.  At most `attempts` draws are made; the
;; default is the current value of (default-generate-attempts).  If no draw
;; is accepted, raises an error listing the rejected values, most recent
;; first.
(define generate
  (opt-lambda (gen [pred (constant #t)]
                   [attempts (default-generate-attempts)])
    (let retry ([remaining attempts]
                [rejected null])
      (cond
        [(< remaining 1)
         (error 'generate
                "could not find a satisfactory random value; tried:\n~s"
                rejected)]
        [else
         (let ([candidate (generate-one gen)])
           (if (pred candidate)
               candidate
               (retry (sub1 remaining) (cons candidate rejected))))]))))
;; nonrandom : any -> generator
;; A degenerate generator that produces `v` every time it is run.
(define nonrandom
  (lambda (v)
    (make-generator (lambda ignored v))))
;; choose-int-between : integer integer -> integer
;; Picks an integer from the closed interval spanned by i and j, using
;; schematics:random-integer.  The two bounds may be given in either order.
(define (choose-int-between i j)
  (let ([lo (min i j)])
    (+ lo (schematics:random-integer (add1 (- (max i j) lo))))))
;; random-int-between : integer integer -> generator
(define random-int-between (make-distribution choose-int-between))
;; choose-size : [nat pos-int] -> nat
;; Starts at `minimum` (default 0) and keeps adding one until a coin with
;; probability 1/(average+1) of stopping comes up.  The number of increments
;; past `minimum` is therefore geometric with mean `average` (default 4),
;; assuming choose-boolean is a fair Bernoulli trial.
(define choose-size
  (opt-lambda ([minimum 0] [average 4])
    (let ([stop-prob (/ 1 (add1 average))])
      (do ([size minimum (add1 size)])
          ((choose-boolean stop-prob) size)))))
;; random-size : [nat pos-int] -> generator
(define random-size (make-distribution choose-size))
;; choose-boolean : [prob] -> boolean
;; Returns #t when a draw from schematics:random-real falls below `prob`
;; (default 1/2), and #f otherwise.
(define choose-boolean
  (opt-lambda ([prob 1/2])
    (< (schematics:random-real) prob)))
;; random-boolean : [prob] -> generator
(define random-boolean (make-distribution choose-boolean))
;; choose-char : [(gen/c nat)] -> char
;; Draws a character code from `code-source` and converts it with
;; integer->char.  The default source draws from the codes 65 through 90,
;; i.e. #\A through #\Z.
(define choose-char
  (opt-lambda ([code-source (random-int-between 65 90)])
    (let ([code (generate code-source)])
      (integer->char code))))
;; random-char : [(gen/c nat)] -> generator
(define random-char (make-distribution choose-char))
;; choose-group-of : (any ... -> any) (gen/c any) [(gen/c nat)] -> any
;; Draws a length from `length-source` (default: (random-size)), then draws
;; that many elements from `element-source`.  Returns the result of
;; applying `combine` to the elements, e.g. list, vector or string.
(define choose-group-of
  (opt-lambda (combine element-source [length-source (random-size)])
    (let* ([count (generate length-source)]
           [elements (build-list count
                                 (lambda (ignored) (generate element-source)))])
      (apply combine elements))))
;; random-group-of : same arguments -> generator
(define random-group-of (make-distribution choose-group-of))
;; choose-list-of / choose-vector-of : (gen/c any) [(gen/c nat)] -> list/vector
;; choose-group-of specialised to build lists and vectors respectively.
(define-values (choose-list-of choose-vector-of)
  (values (curry choose-group-of list)
          (curry choose-group-of vector)))
;; Generator-returning counterparts of the two procedures above.
(define-values (random-list-of random-vector-of)
  (values (make-distribution choose-list-of)
          (make-distribution choose-vector-of)))
;; choose-string : [(gen/c char) (gen/c nat)] -> string
;; Builds a string whose characters come from `char-source` (default:
;; (random-char)) and whose length comes from `length-source` (default:
;; (random-size)).
(define choose-string
  (opt-lambda ([char-source (random-char)]
               [length-source (random-size)])
    (choose-group-of string char-source length-source)))
;; random-string : same arguments -> generator
(define random-string (make-distribution choose-string))
;; choose-bytes : [(gen/c byte) (gen/c nat)] -> bytes
;; Builds a byte string whose bytes come from `byte-source` (default: an
;; integer between 0 and 255) and whose length comes from `length-source`
;; (default: (random-size)).
(define choose-bytes
  (opt-lambda ([byte-source (random-int-between 0 255)]
               [length-source (random-size)])
    (choose-group-of bytes byte-source length-source)))
;; random-bytes : same arguments -> generator
(define random-bytes (make-distribution choose-bytes))
;; choose-apply : procedure (gen/c any) ... -> any
;; Draws one value from each argument source, then applies `f` to the drawn
;; values.
(define choose-apply
  (lambda (f . gens)
    (let ([args (map generate gens)])
      (apply f args))))
;; random-apply : same arguments -> generator
(define random-apply (make-distribution choose-apply))
;; choose-list : (gen/c any) ... -> list
;; Returns a list holding one value drawn from each argument, in order.
(define choose-list
  (lambda gens
    (map generate gens)))
;; random-list : same arguments -> generator
(define random-list (make-distribution choose-list))
;; choose-vector : (gen/c any) ... -> vector
;; Returns a fresh vector holding one value drawn from each argument, in order.
(define choose-vector
  (lambda gens
    (list->vector (map generate gens))))
;; random-vector : same arguments -> generator
(define random-vector (make-distribution choose-vector))
;; choose-symbol : [(gen/c string)] -> symbol
;; Interns a string drawn from `name-source`.  By default that is a string
;; of (random-char) characters whose length is drawn from (random-size 1).
(define choose-symbol
  (opt-lambda ([name-source (random-string (random-char) (random-size 1))])
    (let ([name (generate name-source)])
      (string->symbol name))))
;; random-symbol : [(gen/c string)] -> generator
(define random-symbol (make-distribution choose-symbol))
;; choose-uniform : (gen/c any) ... -> any
;; Picks one of the given generators (or plain values) at random, using an
;; index drawn with schematics:random-integer, and runs it.
;; The exported contract admits an empty argument list.  There is then
;; nothing to choose from, so a descriptive error is raised instead of
;; letting random-integer fail on a zero range.
(define (choose-uniform . gens)
  (when (null? gens)
    (error 'choose-uniform "expected at least one generator, given none"))
  (generate (list-ref gens (schematics:random-integer (length gens)))))
;; random-uniform : same arguments -> generator
(define random-uniform (make-distribution choose-uniform))
;; cons-weights-and-gens : (list weight gen weight gen ...) -> (listof (cons weight gen))
;; Pairs up an alternating list of weights and generators.  An odd-length
;; list has no matching pattern and signals a match failure.
(define (cons-weights-and-gens alternation)
  (match alternation
    [(list-rest w g more)
     (cons (cons w g) (cons-weights-and-gens more))]
    [(list) null]))
;; choose-weighted : weight gen weight gen ... -> any
;; Variadic front end to choose-weighted*.  It takes its weights and
;; generators alternately as arguments.
(define choose-weighted
  (lambda weights-and-gens
    (let ([pairs (cons-weights-and-gens weights-and-gens)])
      (choose-weighted* pairs))))
;; random-weighted : same arguments -> generator
(define random-weighted (make-distribution choose-weighted))
;; choose-weighted* : (listof (cons weight gen)) -> any
;; Selects one generator with probability proportional to its weight, then
;; runs it.  A single draw from schematics:random-real is compared against
;; the running cumulative fraction of the total weight.  The last pair is
;; taken unconditionally, which guards against rounding at the upper end.
;; The exported contract admits an empty list (reached e.g. through
;; choose-weighted with no arguments, or a generator form with no clauses).
;; Instead of failing inside `match`, that case raises a descriptive error.
(define (choose-weighted* pairs)
  (when (null? pairs)
    (error 'choose-weighted*
           "expected at least one weighted generator, given none"))
  (let* ([total (apply + (map car pairs))]
         [choice (schematics:random-real)])
    (let loop ([base 0]
               [pairs pairs])
      (match pairs
        ;; Last remaining alternative: always chosen.
        [(list (cons weight gen)) (generate gen)]
        [(list-rest (cons weight gen) rest)
         (let* ([sum (+ base weight)]
                ;; Fraction of total weight covered up to and including
                ;; this alternative.
                [prob (/ sum total)])
           (if (<= choice prob)
               (generate gen)
               (loop sum rest)))]))))
;; random-weighted* : (listof (cons weight gen)) -> generator
(define random-weighted* (make-distribution choose-weighted*))
;; choose-function : (any ... -> (gen/c any)) -> (any ... -> any)
;; Returns a memoized random function.  The first call with a given argument
;; list draws a value from (apply f args) and records it in an equal?-keyed
;; hash table.  Later calls with equal arguments return that recorded value.
(define (choose-function f)
  (define memo (make-hash-table 'equal))
  (define (draw-and-remember args)
    (let ([value (generate (apply f args))])
      (hash-table-put! memo args value)
      value))
  (lambda args
    (hash-table-get memo args (lambda () (draw-and-remember args)))))
;; random-function : same argument -> generator of fresh memoized functions
(define random-function (make-distribution choose-function))
;; (random-recursive self [weight gen] ...)
;; Creates a weighted generator bound to `self` inside the clause
;; expressions, so alternatives can refer to the generator being defined.
;; The generator is created over a not-yet-filled clause list.  The weights
;; and generators are then evaluated once, in scope of `self`, and installed
;; with set!.
(define-syntax (random-recursive stx)
  (syntax-case stx ()
    [(_kw self [w g] ...)
     (syntax/loc stx
       (let ([clauses null])
         (let ([self (make-generator (lambda () (choose-weighted* clauses)))])
           (set! clauses (list (cons w g) ...))
           self)))]))
;; (choose-recursive self [weight gen] ...)
;; Builds the same generator as random-recursive and immediately draws one
;; value from it with `generate`.
(define-syntax (choose-recursive stx)
  (syntax-case stx ()
    [(_kw self [w g] ...)
     (syntax/loc stx
       (generate (random-recursive self [w g] ...)))]))
;; let*-random: sequential binding form for random values.
;;   (let*-random ([var gen] ...) body ...)
;;   (let*-random ([var gen pred] ...) body ...)
;;   (let*-random ([var gen pred count] ...) body ...)
;; Each var is bound in order to (generate gen (lambda (var) pred) count).
;; `pred` may mention `var` and defaults to #t.  `count` defaults to
;; (default-generate-attempts).  Later clauses and the body see earlier vars.
(define-syntax (let*-random stx)
(syntax-case stx ()
;; Two-element clause: fill in an always-true predicate and re-expand.
[(lr ([var gen] . rest) . body)
(syntax/loc stx (lr ([var gen #t] . rest) . body))]
;; Three-element clause: fill in the default attempt count and re-expand.
[(lr ([var gen pred] . rest) . body)
(syntax/loc stx
(lr ([var gen pred (default-generate-attempts)] . rest) . body))]
;; Fully specified clause: bind var, then expand the remaining clauses
;; inside its scope.  `gen` and `count` are evaluated outside var's scope;
;; `pred` is wrapped in a one-argument lambda that binds var.
[(lr ([var gen pred count] . rest) . body)
(syntax/loc stx
(let* ([var (generate gen (lambda (var) pred) count)])
(lr rest . body)))]
;; No clauses left: evaluate the body.
[(lr () . body)
(syntax/loc stx (let* () . body))]))
;; (define-generator (fn param ...) (weight gen) ...)
;; Defines `fn` as a procedure returning a weighted generator.  Every run of
;; that generator re-evaluates the weight and generator expressions, with
;; the params in scope, and picks among them via choose-weighted*.
(define-syntax (define-generator stx)
  (syntax-case stx ()
    [(_kw (fn param ...) (w g) ...)
     (syntax/loc stx
       (define fn
         (lambda (param ...)
           (make-generator
            (lambda ()
              (choose-weighted* (list (cons w g) ...)))))))]))
)