#lang racket
(require racket/async-channel
         "util.rkt"
         )
(provide (struct-out pool)
         make-worker-pool
         delete-worker-pool
         with-worker-pool
         add-job
         get-results
         )
;; A worker pool: its worker threads plus the two async channels they share.
;;   threads -- list of the worker threads created by make-worker-pool
;;   todo    -- async channel of jobs (thunks) waiting to be run
;;   done    -- async channel onto which workers put the values jobs produce
(struct pool (threads todo done))
;; Create a pool of NUM-THREADS worker threads sharing one job channel (`todo`)
;; and one result channel (`done`), both unbounded async channels.
;;
;; Each worker repeatedly takes a thunk from `todo`, calls it, and puts the
;; value it returns on `done`.  If the thunk raises anything other than a
;; break, the raised value is put on `done` in place of a result and the
;; worker keeps serving jobs.  Otherwise a raising job would kill its worker
;; (silently shrinking the pool) and a caller waiting for that job's result
;; would block forever.
;;
;; A non-procedure taken from `todo` is a protocol violation: the worker
;; raises an error, which terminates that worker thread.
(define/contract (make-worker-pool num-threads)
  (exact-positive-integer? . -> . pool?)
  (define todo (make-async-channel))
  (define done (make-async-channel))
  ;; Run one job; a non-break raise becomes the job's "result" value.
  (define (run-job job)
    (with-handlers ([(lambda (e) (not (exn:break? e))) values])
      (job)))
  (define (worker-thread)
    (let loop ()
      (let ([p (async-channel-get todo)])
        (cond [(procedure? p)
               (async-channel-put done (run-job p))
               (loop)]
              [else
               ;; `error` treats the string as a format string, so the
               ;; offending value needs an explicit `~e` directive.
               (error 'worker-thread "expected procedure, got ~e" p)]))))
  (pool (for/list ([_ (in-range num-threads)])
          (thread worker-thread))
        todo
        done))
;; Kill every worker thread of the pool P, including any worker in the middle
;; of running a job.  Jobs still queued on the pool's `todo` channel are never
;; run afterwards.
(define/contract (delete-worker-pool p)
  (pool? . -> . any)
  (for-each kill-thread (pool-threads p)))
;; Create a pool of NUM-THREADS workers, call PROC with it, and return
;; whatever PROC returns.  The pool's threads are killed whenever control
;; leaves PROC, whether it returns normally or escapes (e.g. by raising).
(define/contract (with-worker-pool num-threads proc)
  (exact-positive-integer? (pool? . -> . any) . -> . any)
  (define the-pool (make-worker-pool num-threads))
  (define (use-pool) (proc the-pool))
  (define (tear-down) (delete-worker-pool the-pool))
  ;; Nothing to do on entry: the pool is created above.
  (dynamic-wind void use-pool tear-down))
;; Queue the thunk PROC on pool P's `todo` channel so that one of P's workers
;; will run it.  The value PROC produces is later collected with `get-results`.
(define/contract (add-job p proc)
  (pool? (-> any/c) . -> . any)
  (define queue (pool-todo p))
  (async-channel-put queue proc))
;; Take N values from pool P's `done` channel, blocking until each is
;; available, and return them as a list.  The list is in reverse order of
;; arrival.  Jobs finish in an unpredictable order, so callers should not rely
;; on any particular ordering.
(define/contract (get-results p n)
  (pool? exact-nonnegative-integer? . -> . any/c)
  (define results-channel (pool-done p))
  (for/fold ([acc '()])
            ([_ (in-range n)])
    (cons (async-channel-get results-channel) acc)))
;; Tests: run batches of jobs that each sleep a short random time (so they
;; finish out of order) and check that every job's value comes back exactly
;; once.
(module+ test
  (require "run-suite.rkt")
  (def/run-test-suite
    (test-case
     "worker pool"
     (define (try num-threads num-jobs)
       (define results
         (with-worker-pool
          num-threads
          (lambda (p)
            (for ([i (in-range num-jobs)])
              ;; Sleep up to 10 ms: enough to shuffle completion order
              ;; without making the suite take several seconds.
              (add-job p (lambda () (sleep (/ (random) 100)) i)))
            (get-results p num-jobs))))
       (check-equal? (sort results <) (range num-jobs)))
     (try 4 20)
     (try 1 10)
     (try 10 1))))