#lang racket
(require (for-syntax syntax/parse)
racket/stxparam)
(provide network
prev
signal?
signal-*s
signal-+s
(struct-out network/s)
(contract-out [network-init (-> network/c procedure?)])
fixed-inputs
loop-ctr
simple-ctr
signal-samples
signal-nth)
;; network/s : a network description.
;;   ins   -- number of inputs the network's procedure takes
;;            (signal? treats ins = 0 as "no inputs")
;;   outs  -- number of values the network's procedure produces
;;   maker -- a thunk; network-init calls it to obtain the network's
;;            procedure (the makers built in this file allocate fresh
;;            state on every call)
(struct network/s (ins outs maker))
;; network/c : the contract on network-init's argument -- either a
;; network/s structure or any procedure.
(define network/c (or/c network/s? procedure?))
;; network-init : network/c -> procedure
;; Instantiate a network so it can be called.
;;   - A network/s is instantiated by calling its maker thunk.
;;   - A plain procedure is returned unchanged.
;; Raises exn:fail:contract (blaming network-init) for anything else. This
;; branch is reachable from calls inside this module, which are not covered
;; by the contract-out wrapper.
(define (network-init signal)
  (cond [(network/s? signal) ((network/s-maker signal))]
        [(procedure? signal) signal]
        [else (raise-argument-error
               'network-init
               "network or procedure" signal)]))
;; prev : (prev id init)
;; Meaningful only in the input expressions of a `network` / `network/inr`
;; clause, which rebind it with syntax-parameterize.
;; Anywhere else, this default transformer rejects the use at expansion time.
;; The mistake is therefore reported when the module is compiled, not only
;; when the offending code happens to run.
(define-syntax-parameter prev
  (lambda (stx)
    (raise-syntax-error #f "can't use prev outside of a network definition" stx)))
;; network/inr : the core form that `network` expands into.
;;
;;   (network/inr (in ...) [(out ...) (node input ...)] ...+)
;;
;; Expands to (network/s num-ins num-outs maker), where
;;   num-ins  = the number of `in` identifiers,
;;   num-outs = the number of identifiers on the LAST clause's left-hand side,
;;   maker    = a thunk.
;;
;; Each call of maker creates fresh state:
;;   - one saved-value slot (initially #f) per clause output;
;;   - one procedure per clause, obtained with (network-init node).
;; Maker returns a procedure of the inputs. That procedure:
;;   - evaluates the clauses in order with let*-values;
;;   - stores every output into its slot;
;;   - returns the last clause's output(s).
;;
;; Inside clause inputs, `prev` is rebound:
;;   - on the first call, (prev id init) evaluates to init;
;;   - on later calls, it evaluates to the value id had at the end of the
;;     previous call.
;; The clause body is expanded under both bindings. A `prev` whose id is not
;; a clause output is therefore rejected by lookup-prev's syntax-case at
;; expansion time.
(define-syntax (network/inr stx)
  ;; Each clause: a parenthesized list of output ids, then a node applied
  ;; to its input expressions.
  (define-syntax-class network-clause
    #:description "network/inr clause"
    (pattern ((out:id ...) (node:expr input:expr ...))))
  (syntax-parse stx
    [(_ (in:id ...)
        clause:network-clause ...+)
     ;; Number of network inputs.
     (define num-ins (length (syntax->list #'(in ...))))
     ;; One syntax list of output ids per clause.
     (define lhses (syntax->list #'((clause.out ...) ...)))
     ;; The network's output count comes from the last clause only.
     (define num-outs (length (syntax->list (car (reverse lhses)))))
     ;; Every output id of every clause, in clause order.
     (define lhses/flattened (syntax->list #'(clause.out ... ...)))
     (with-syntax
         ;; A fresh slot identifier for each output id.
         ([(saved-val ...) (generate-temporaries lhses/flattened)]
          ;; A fresh identifier for each clause's instantiated node.
          [(signal-proc ...) (generate-temporaries #'(clause ...))]
          [(lhs ...) lhses/flattened]
          ;; What the network returns: the last clause's single output, or
          ;; all of its outputs as multiple values.
          [last-out
           (syntax-parse (car (reverse lhses))
             [(out:id) #'out]
             [(out:id out2:id ...+) #'(values out out2 ...)])])
       (with-syntax
           ;; Transformer for `prev` on the first call: (prev id init) => init.
           ;; The id is ignored, and the #:literals declaration is not
           ;; referenced by the pattern.
           ( [init-prev
              #`(lambda (stx)
                  (syntax-parse stx
                    #:literals (prev)
                    [(_ id init) #'init]))]
             ;; Transformer for `prev` on later calls: maps each output id to
             ;; its slot.
             ;; The ids are matched as syntax-case literals.
             [lookup-prev
              #'(lambda (stx)
                  (syntax-case stx (lhs ...)
                    [(_ lhs init) #'saved-val]
                    ...))]
             ;; Shared body of the per-sample procedure:
             ;;   1. evaluate the clauses in sequence;
             ;;   2. save every output;
             ;;   3. return the last clause's output(s).
             [fun-body
              #`(let*-values ([(clause.out ...) (signal-proc
                                                 clause.input ...)]
                              ...)
                  (begin
                    (set! saved-val lhs)
                    ...)
                  last-out)])
         (with-syntax
             ([maker
               #`(lambda ()
                   ;; Per-instance state.
                   (define saved-val #f)
                   ...
                   (define signal-proc (network-init clause.node))
                   ...
                   ;; Used from the second call onward: prev reads the slots.
                   (define (later-times-fun in ...)
                     (syntax-parameterize
                         ([prev lookup-prev])
                       fun-body))
                   ;; Used for the first call only. It first redirects
                   ;; first-time-fun to later-times-fun, then runs with prev
                   ;; yielding its init expressions.
                   (define (first-time-fun in ...)
                     (set! first-time-fun later-times-fun)
                     (syntax-parameterize
                         ([prev init-prev])
                       fun-body))
                   ;; Dispatches through the (mutable) first-time-fun binding.
                   (lambda (in ...)
                     (first-time-fun in ...)))])
           #`(network/s (quote #,num-ins)
                        (quote #,num-outs)
                        maker))))]))
;; network : surface syntax for defining a signal network.
;;
;;   (network (in ...) [outs rhs] ...+)
;;
;; outs is either a single identifier or a parenthesized list of identifiers.
;; rhs is one of:
;;   (prev id init)    -- rewritten to ((lambda (x) x) (prev id init))
;;   (node input ...)  -- used unchanged
;;   expr              -- any other expression, rewritten to ((lambda (x) x) expr)
;; Each clause is normalized to [(out ...) (node input ...)], and the whole
;; form is handed to network/inr.
(define-syntax (network stx)
  ;; A single identifier, or a parenthesized list of identifiers.
  (define-syntax-class oneormoreids
    #:description "id or (id ...)"
    (pattern out:id)
    (pattern (outs:id ...)))
  ;; Any of the three right-hand-side shapes listed above. The patterns are
  ;; tried in order.
  (define-syntax-class rhs
    #:description "network clause rhs"
    #:literals (prev)
    (pattern (prev named:id init:expr))
    (pattern (node:expr input:expr ...))
    (pattern input:expr))
  (define-syntax-class network-clause
    #:description "network clause"
    (pattern (outs:oneormoreids rhs:rhs)))
  ;; Turn a lone identifier into a one-element list; leave lists unchanged.
  (define (ensure-parens ids)
    (syntax-parse ids
      [out:id #'(out)]
      [(outs:id ...) #'(outs ...)]))
  ;; Put a right-hand side into (node input ...) shape.
  ;; A bare `prev` use or a non-application expression becomes the sole
  ;; input of an identity node.
  (define (maybe-wrap rhs)
    (syntax-parse rhs
      #:literals (prev)
      [(prev named:id init:expr) #'((lambda (x) x) (prev named init))]
      [(node:expr input:expr ...) #'(node input ...)]
      [node:expr #'((lambda (x) x) node)]))
  (syntax-parse stx
    [(_ (in:id ...)
        clause:network-clause ...+)
     (with-syntax ([((outs ...) ...)
                    (map ensure-parens (syntax->list #'(clause.outs ...)))]
                   [(rhs ...)
                    (map maybe-wrap (syntax->list #'(clause.rhs ...)))])
       #'(network/inr (in ...)
                      [(outs ...) rhs] ...))]))
;; signal? : any -> boolean
;; Is f something that can be sampled with no inputs?
;; True for a network/s whose ins field is 0, and for a procedure that
;; accepts zero arguments. False for everything else.
(define (signal? f)
  (cond
    [(and (network/s? f) (= (network/s-ins f) 0)) #t]
    [(procedure? f) (procedure-arity-includes? f 0)]
    [else #f]))
;; (fixed-inputs net arg ...)
;; Build a zero-input network with a single clause, whose output is `net`
;; applied to the given argument expressions on every sample.
(define-syntax (fixed-inputs stx)
  (syntax-parse stx
    [(_ network-expr arg-expr ...)
     (syntax (network () [out (network-expr arg-expr ...)]))]))
;; signal-*s : (listof signal) -> network/s
;; Combine zero-input signals into one zero-input, one-output network.
;; Each sample of the result is the product of one sample from every given
;; signal; an empty list gives a constant 1.
;; The component signals are instantiated (via network-init) each time the
;; combined network is.
;; Raises exn:fail:contract if los is not a list whose elements all satisfy
;; signal?.
(define (signal-*s los)
  ;; list? is checked first so that a non-list is reported against signal-*s
  ;; rather than by andmap.
  (unless (and (list? los) (andmap signal? los))
    (raise-argument-error 'signal-*s "list of signals" 0 los))
  (network/s 0 1 (lambda ()
                   (define sigfuns (map network-init los))
                   (lambda ()
                     (for/product ([fun sigfuns]) (fun))))))
;; signal-+s : (listof signal) -> network/s
;; Combine zero-input signals into one zero-input, one-output network.
;; Each sample of the result is the sum of one sample from every given
;; signal; an empty list gives a constant 0.
;; The component signals are instantiated (via network-init) each time the
;; combined network is.
;; Raises exn:fail:contract, blaming signal-+s, if los is not a list whose
;; elements all satisfy signal?.
(define (signal-+s los)
  ;; list? is checked first so that a non-list is reported against signal-+s
  ;; rather than by andmap.
  (unless (and (list? los) (andmap signal? los))
    (raise-argument-error 'signal-+s "list of signals" 0 los))
  (network/s 0 1 (lambda ()
                   (define sigfuns (map network-init los))
                   (lambda ()
                     (for/sum ([fun sigfuns]) (fun))))))
;; loop-ctr : real real -> network
;; A zero-input network whose first value is 0 and which advances by skip
;; on every sample, wrapping around so the value stays in [0, len)
;; (up to floating-point rounding).
;;
;; When len is positive, skip is first reduced into [0, len). After that, a
;; single subtraction of len in `increment` is always enough to wrap, even
;; when the caller's skip is negative or at least len.
(define (loop-ctr len skip)
  ;; An in-range skip is used exactly as given.
  ;; An out-of-range skip is reduced modulo len using floor. This works for
  ;; both exact and inexact reals, whereas `modulo` requires integers.
  (define step
    (if (and (positive? len)
             (or (negative? skip) (>= skip len)))
        (- skip (* len (floor (/ skip len))))
        skip))
  ;; Advance p by one step, wrapping back by len once it reaches len.
  (define (increment p)
    (define next-p (+ p step))
    (cond [(< next-p len) next-p]
          [else (- next-p len)]))
  (network ()
    [a (prev b 0)]
    [b (increment a)]
    [out a]))
;; simple-ctr : number number -> network
;; A zero-input network whose first value is init. Each later value is the
;; previous one plus skip, with no wrap-around.
(define (simple-ctr init skip)
  (network ()
    [current (prev upcoming init)]
    [upcoming (+ skip current)]
    [out current]))
;; signal-samples : network/c natural -> vector
;; Instantiate the signal once, then collect its first n samples, in order,
;; into a fresh vector.
(define (signal-samples signal n)
  (let ([next-sample (network-init signal)])
    (for/vector ([idx n])
      (next-sample))))
;; signal-nth : network/c natural -> any
;; Instantiate the signal and return its sample at zero-based index n.
;; The first n samples are drawn and thrown away.
(define (signal-nth signal n)
  (let ([next-sample (network-init signal)])
    ;; Discard samples 0 .. n-1.
    (for ([idx n])
      (next-sample))
    (next-sample)))