#lang scheme/base
(require (for-syntax scheme/base) "stx.ss"
srfi/1)
(provide (all-defined-out))
;; Syntactic negation: if STX is already a unary minus `(- e)`, return the
;; inner `e`; otherwise wrap STX in a unary minus.
(define (minus stx)
  (syntax-case stx (-)
    ((- negated) #'negated)
    (_ (quasisyntax (- #,stx)))))
;; Distribute multiplication over addition ("multiply out") in a syntax tree
;; built from binary `*`, `+`, binary `-` and unary `-` forms.
;; Forms that match none of the clauses below are returned unchanged and not
;; descended into (including n-ary forms; `sop` runs `u/b` first so that the
;; input is binary).
;; The result is a nest of binary `+` whose leaves contain no `+` under `*`.
(define (r/mul stx)
  ;; Binary sum whose summands are each distributed recursively.
  (define (sum a b) #`(+ #,(r/mul a) #,(r/mul b)))
  (define sum? (op? +))
  ;; Product of the distributed arguments.  If distributing an argument
  ;; produced a sum, the product has to be distributed again.
  (define (prod . args)
    (let ((args (map r/mul args)))
      (let ((stx #`(* #,@args)))
        (if (ormap sum? args)
            (r/mul stx)
            stx))))
  (syntax-case stx (* - +)
    ;; (a + b) * c  =  a*c + b*c
    ((* (+ a b) c) (sum #'(* a c) #'(* b c)))
    ;; a * (b + c)  =  a*b + a*c
    ((* a (+ b c)) (sum #'(* a b) #'(* a c)))
    ((* a b) (prod #'a #'b))
    ((+ a b) (sum #'a #'b))
    ;; a - b  =  a + (-b); the unary minus is turned into (* -1 b) below.
    ((- a b) (sum #'a #'(- b)))
    ((- a) (prod #'-1 #'a))
    (_ stx)))
;; Rewrite every list form `(op arg ...)` into nested unary/binary form,
;; recursing into the arguments (the operator position is left as is).
;; Three or more arguments are folded to the LEFT:
;;   (op a b c d)  =>  (op (op (op a b) c) d)
;; This matches Scheme's meaning for the non-associative operators:
;; (- a b c) is a - b - c.  For `+` and `*` the nesting direction is
;; irrelevant.  Note that this applies to ANY list form, not only arithmetic.
;; Atoms, `(op)` and improper forms are returned unchanged.
(define (u/b stx)
  ;; Rebuild `(op arg ...)` with every argument converted recursively.
  (define (ub op . args) #`(#,op #,@(map u/b args)))
  (syntax-case stx ()
    ((op a) (ub #'op #'a))
    ((op a b) (ub #'op #'a #'b))
    ;; Peel off the two leftmost arguments and retry with one fewer argument.
    ((op a b c d ...) (u/b #'(op (op a b) c d ...)))
    (a #'a)))
;; Build a procedure that flattens a tree of nested operator forms into a
;; list of operands.
;;   OP?  - predicate recognizing syntax that is an application of the
;;          operator to flatten (presumably built with `op?` from stx.ss).
;;   SUB  - applied to every leaf (non-operator subform); defaults to identity.
;; Every form accepted by OP? is descended into, whatever its number of
;; arguments (binary, unary or n-ary).  A leaf becomes a one-element list.
(define (flatten op? [sub (lambda (x) x)])
  (lambda (stx)
    (let walk ((stx stx))
      (if (op? stx)
          (syntax-case stx ()
            ((_ . args) (append-map walk (syntax->list #'args))))
          (list (sub stx))))))
;; Collapse a nest of `*` forms into a single `(* factor ...)`.
;; A non-product input becomes the one-factor form `(* stx)`.
(define (flatten/* stx)
  (let ((factors ((flatten (op? *)) stx)))
    #`(* #,@factors)))
;; Collapse a nest of `+` forms into a single `(+ term ...)`, where every
;; term is itself flattened with flatten/* (so each term is `(* ...)`).
(define (flatten/+ stx)
  (let* ((flatten-sum (flatten (op? +) flatten/*))
         (terms (flatten-sum stx)))
    #`(+ #,@terms)))
;; Fold helper: push STX onto LST only when STX is an identifier.
(define (identifier-cons stx lst)
  (cond
    ((not (identifier? stx)) lst)
    (else (cons stx lst))))
;; Identifier factors of a product term `(* factor ...)`.
;; Non-identifier factors (numbers, sub-forms) are skipped.  Repeated
;; identifiers are kept, and the result is in reverse order of appearance.
;; Any input not headed by `*` is rejected by syntax-case.
(define (term-variables stx)
  (syntax-case stx (*)
    ((* . factors)
     (let loop ((pending (syntax->list #'factors))
                (found '()))
       (cond
         ((null? pending) found)
         ((identifier? (car pending))
          (loop (cdr pending) (cons (car pending) found)))
         (else (loop (cdr pending) found)))))))
;; The distinct (by bound-identifier=?) variables of a product term,
;; computed as the SRFI-1 union of one singleton list per occurrence.
(define (term-variable-lset stx)
  (let ((singletons (map (lambda (v) (list v))
                         (term-variables stx))))
    (apply lset-union bound-identifier=? singletons)))
;; Number of identifier factors in a product term, counted with multiplicity.
;; This is the monomial degree if every non-identifier factor is a constant.
(define (term-order stx)
  (let ((vars (term-variables stx)))
    (length vars)))
;; Reorder the terms of a `(+ term ...)` form by decreasing term-order.
;; Racket's `sort` is stable, so terms of equal order keep their relative
;; position.
(define (sort-order sum-stx)
  (syntax-case sum-stx (+)
    ((+ . terms)
     (let* ((term-list (syntax->list #'terms))
            (by-degree (sort term-list
                             (lambda (lhs rhs)
                               (> (term-order lhs) (term-order rhs))))))
       #`(+ #,@by-degree)))))
;; Sum-of-products normal form of an arithmetic expression:
;; binarize the applications, distribute `*` over `+`, flatten into
;; `(+ (* ...) ...)`, then order the terms by decreasing term-order.
(define (sop stx)
  (let* ((binary   (u/b stx))
         (expanded (r/mul binary))
         (flat     (flatten/+ expanded)))
    (sort-order flat)))
;; Normalize a binary constraint `(rel lhs rhs)` with rel in = < > <= >=
;; into one of
;;   (constr:=  P)   meaning P =  0
;;   (constr:>  P)   meaning P >  0
;;   (constr:>= P)   meaning P >= 0
;; where P is the sum-of-products form (sop) of the moved-over expression.
;; `<`/`<=` are flipped to `>`/`>=`, and a non-zero right-hand side is moved
;; to the left.
;; Anything else raises a syntax error.  The original catch-all re-wrapped
;; unknown operators forever.
(define (nf stx)
  (syntax-case stx (= < > <= >=)
    ((< a b) (nf #'(> b a)))
    ((<= a b) (nf #'(>= b a)))
    ;; Checked before (= 0 a) so that (= 0 0) terminates instead of
    ;; rewriting to itself forever.
    ((= a 0) #`(constr:= #,(sop #'a)))
    ((= 0 a) (nf #'(= a 0)))
    ((> a 0) #`(constr:> #,(sop #'a)))
    ((>= a 0) #`(constr:>= #,(sop #'a)))
    ;; Move a non-zero right-hand side over: a rel b  =>  (a - b) rel 0.
    ((= a b) (nf #'(= (- a b) 0)))
    ((> a b) (nf #'(> (- a b) 0)))
    ((>= a b) (nf #'(>= (- a b) 0)))
    (_ (raise-syntax-error
        'nf "expected a binary =, <, >, <= or >= constraint" stx))))
;; Collect the union (by bound-identifier=?) of all variables that occur in
;; the terms of every `(+ term ...)` form in LST.  Each form must be headed
;; by `+`; anything else is rejected by syntax-case.
;; NOTE(review): despite the name, no matrix is built yet -- only the
;; variable set is returned (hence the FIXME suffix).
(define (forms->matrix_FIXME lst)
  ;; Add the variables of one product term to the accumulated set.
  (define (add-term-vars term acc)
    (lset-union bound-identifier=? acc (term-variable-lset term)))
  ;; Add the variables of every term of one sum form.
  (define (add-form-vars form acc)
    (syntax-case form (+)
      ((+ . terms)
       (foldl add-term-vars acc (syntax->list #'terms)))))
  (define variables
    (for/fold ((acc '())) ((form lst))
      (add-form-vars form acc)))
  variables)