#lang racket
(require racket/async-channel
"util.rkt"
)
(provide (struct-out pool)
make-worker-pool
delete-worker-pool
with-worker-pool
add-job
get-results
)
;; A worker pool, as built by make-worker-pool.
;;   threads : list of the pool's worker threads (delete-worker-pool kills them)
;;   todo    : async channel of pending jobs; add-job puts thunks here and the
;;             workers take them off
;;   done    : async channel on which workers put each finished job's value;
;;             get-results reads from it
(struct pool
(threads todo done ))
;; make-worker-pool : exact-positive-integer? -> pool?
;; Starts `num-threads` worker threads that share one job channel (`todo`)
;; and one result channel (`done`), both unbounded async channels.
;; Each worker repeatedly takes a thunk from `todo`, calls it, and puts the
;; thunk's value on `done`.
;; If a job raises anything other than a break, the raised value (usually an
;; exn struct) is put on `done` in place of a result, and the worker keeps
;; running. Without this, the worker would die and never post a result, so
;; get-results would block forever.
;; A non-procedure on `todo` can only get there by bypassing add-job through
;; the exported pool-todo accessor. It is still a hard error that ends that
;; worker thread.
(define/contract (make-worker-pool num-threads)
  (exact-positive-integer? . -> . pool?)
  (define todo (make-async-channel))
  (define done (make-async-channel))
  ;; Call `job`; if it raises a non-break value, return that value instead.
  ;; Breaks are not caught, so break-thread can still stop a worker.
  (define (run-job job)
    (with-handlers ([(λ (e) (not (exn:break? e))) values])
      (job)))
  (define (worker-thread)
    (let loop ()
      (define job (async-channel-get todo))
      (unless (procedure? job)
        ;; `~e` is required: `error` treats this string as a format string,
        ;; and a directive-free format string given an argument would make
        ;; `format` fail instead of reporting this message.
        (error 'worker-thread "expected procedure, got ~e" job))
      (async-channel-put done (run-job job))
      (loop)))
  (pool (for/list ([n (in-range num-threads)])
          (thread worker-thread))
        todo
        done))
;; delete-worker-pool : pool? -> void?
;; Kills every worker thread of `p`. After this no worker is left to take
;; jobs from the pool's `todo` channel, and any unread values stay on `done`.
(define/contract (delete-worker-pool p)
  (pool? . -> . any)
  (for-each kill-thread (pool-threads p)))
;; with-worker-pool : exact-positive-integer? (pool? -> any) -> any
;; Creates a pool of `num-threads` workers, calls `proc` with it, and returns
;; whatever `proc` returns. The workers are killed whenever control leaves
;; `proc`: on a normal return, an exception, or any other continuation jump.
(define/contract (with-worker-pool num-threads proc)
  (exact-positive-integer? (pool? . -> . any) . -> . any)
  (let ([workers (make-worker-pool num-threads)])
    (dynamic-wind
     void
     (λ () (proc workers))
     (λ () (delete-worker-pool workers)))))
;; add-job : pool? (-> any/c) -> any
;; Queues the zero-argument thunk `proc` on the pool's `todo` channel, where a
;; worker thread picks it up, calls it, and posts its value to `done`.
;; The put does not block, because make-worker-pool creates `todo` unbounded.
(define/contract (add-job p proc)
  (pool? (-> any/c) . -> . any)
  (define todo (pool-todo p))
  (async-channel-put todo proc))
;; get-results : pool? exact-nonnegative-integer? -> list?
;; Blocks until `n` values have been taken from the pool's `done` channel and
;; returns them as a list, most recently received first (the reverse of the
;; order they were read). With several workers, that order need not match
;; job submission order.
(define/contract (get-results p n)
  (pool? exact-nonnegative-integer? . -> . any/c)
  (for/fold ([acc '()])
            ([_ (in-range n)])
    (cons (async-channel-get (pool-done p)) acc)))
(module+ test
  (require "run-suite.rkt")
  (def/run-test-suite
   (test-case
    "worker pool"
    ;; Submits `num-jobs` jobs to a fresh pool of `num-threads` workers.
    ;; Job i sleeps a short random time, so jobs finish out of order, then
    ;; returns i. Checks that every job's value comes back exactly once.
    ;; The sleeps are capped at 10 ms to keep the suite fast.
    (define (try num-threads num-jobs)
      (define results
        (with-worker-pool
         num-threads
         (lambda (pool)
           (for ([i (in-range num-jobs)])
             (add-job pool (lambda () (sleep (* 0.01 (random))) i)))
           (get-results pool num-jobs))))
      (check-equal? (sort results <)
                    (range num-jobs)))
    (try 4 20)
    (try 1 10)
    (try 10 1)
    ;; No jobs: get-results must return '() without blocking.
    (try 3 0))))