(module inference-control mzscheme
(provide (all-defined))
(require "inference-environments.ss")
(require "bindings.ss")
(require "patterns.ss")
(require "facts.ss")
(require "rulesets.ss")
(require "matches.ss")
(require "assertions.ss")
(require (lib "list.ss" "srfi" "1"))
;; assert: add a fact to working memory.
;;
;;   (assert fact)         -- same as (assert fact #t)
;;   (assert fact reason)  -- records REASON on the new assertion
;;
;; The new assertion is pushed onto the front of the assertion-index entry
;; keyed by (fact-first fact).  It is then offered to every data match node
;; registered under that key.  Returns the new assertion object.
(define assert
  (case-lambda
    ((fact)
     (assert fact #t))
    ((fact reason)
     (when (current-inference-trace)
       (printf "assert: ~a~n" fact))
     (let* ((new-assertion (make-assertion fact reason))
            (key (fact-first fact))
            (assertion-index (current-inference-assertion-index))
            (earlier (hash-table-get assertion-index key (lambda () '()))))
       (hash-table-put! assertion-index key (cons new-assertion earlier))
       (for-each (lambda (node) (match-node-assert node new-assertion))
                 (hash-table-get (current-inference-data-index)
                                 key (lambda () '())))
       new-assertion))))
;; retract: remove ASSERTION from working memory.
;; Deletes it from the assertion-index entry for its fact's first element.
;; Then every data match node registered under that key is notified, so any
;; matches built on the assertion are withdrawn.
;;
;; Both index lookups default to '(), as `assert` does.  Without the default,
;; retracting a fact whose key had no data match nodes raised an error.  That
;; includes keys that fix-goal-matches drops from the data index.
;; Deletion compares with eq?, like match-node-retract, so only this exact
;; assertion is removed.  srfi-1 delete! defaults to equal?, which could
;; also drop a distinct assertion of an equal fact.
;; NOTE(review): that depends on whether assertions compare field-wise under
;; equal?; see assertions.ss.
(define (retract assertion)
  (when (current-inference-trace)
    (printf "retract: ~a~n" (assertion-fact assertion)))
  (let ((key (fact-first (assertion-fact assertion))))
    (hash-table-put! (current-inference-assertion-index)
                     key
                     (delete! assertion
                              (hash-table-get (current-inference-assertion-index)
                                              key (lambda () '()))
                              eq?))
    (for-each
     (lambda (match-node)
       (match-node-retract match-node assertion))
     (hash-table-get (current-inference-data-index)
                     key (lambda () '())))))
;; modify: replace ASSERTION with a fresh assertion of FACT.
;;
;;   (modify assertion fact)         -- same as (modify assertion fact #t)
;;   (modify assertion fact reason)  -- retract ASSERTION, then assert FACT
;;                                      with REASON
;;
;; Returns whatever `assert` returns (the new assertion).
(define modify
  (case-lambda
    ((assertion fact)
     (modify assertion fact #t))
    ((assertion fact reason)
     (retract assertion)
     (assert fact reason))))
;; check: backward-chaining query for FACT.
;; FACT is wrapped in a temporary assertion (reason #f) and offered to every
;; goal match node indexed under (fact-first fact).  The escape continuation
;; RETURN is passed down through match-node-check.  When a goal rule node
;; completes, propagate-match-to-rule invokes it, and check then returns a
;; list holding that match.  Returns '() when no goal rule succeeds.
(define (check fact)
  (when (current-inference-trace)
    (printf "check: ~a~n" fact))
  (let/ec return
    (let ((assertion (make-assertion fact #f)))
      (for-each
       (lambda (match-node)
         (let ((match (match-node-check match-node assertion return)))
           ;; match-node-check normally yields '(), but it can yield #f when
           ;; a match constraint rejects the fact.  Only a non-empty list is
           ;; a real result; anything else means "try the next goal node".
           ;; Previously #f ended the search and was returned to callers,
           ;; which iterate over the result.
           (when (pair? match)
             (return match))))
       (hash-table-get (current-inference-goal-index)
                       (fact-first fact) (lambda () '())))
      '())))
;; query: find every current assertion whose fact unifies with PATTERN.
;; Only assertions stored under (pattern-first pattern) are examined.
;; Returns a list of (assertion . bindings) pairs, in the reverse of the
;; order the assertions appear in the index.
(define (query pattern)
  (fold (lambda (assertion found)
          (let ((bindings (pattern-unify (assertion-fact assertion) pattern '())))
            (if bindings
                (cons (cons assertion bindings) found)
                found)))
        '()
        (hash-table-get (current-inference-assertion-index)
                        (pattern-first pattern) (lambda () '()))))
;; node: common base structure for every node in the rule network.
;;   successors -- list of downstream nodes; link-nodes pushes new ones onto
;;                 the front.
;;   matches    -- list of matches stored at this node, or #f for nodes that
;;                 do not store matches (add-match-to-node-matches skips
;;                 storage when the field is #f; goal-rule nodes are built
;;                 with #f).
;; (make-inspector) is supplied as the struct inspector, presumably so
;; instances print with their fields.  TODO confirm.
(define-struct node
(successors matches)
(make-inspector))
;; match-node: a node that tests single assertions against one pattern.
;;   assertion-variable         -- variable to bind to the matching assertion
;;                                 itself, or #f
;;   pattern                    -- pattern unified with each asserted fact
;;                                 (the base pattern for data clauses, the raw
;;                                 goal pattern for goal nodes)
;;   match-constraint-predicate -- procedure applied to the binding values
;;                                 (via bindings-data), or #f for none
(define-struct (match-node node)
(assertion-variable pattern match-constraint-predicate)
(make-inspector))
;; join-node: combines matches from a left node with matches from a right
;; match node.
;;   left  -- node supplying left matches (a join node, or the goal match
;;            node at the head of a goal rule); #f for the initial join node
;;   right -- match node supplying right matches; #f for the initial join node
;;   join-constraint-predicate -- procedure over the merged binding values,
;;                                or #f for none
;;   existential? -- #f for a plain join, otherwise the quantifier symbol
;;                   taken from the clause: no, notany, any, notall or all
;;   match-counts -- association list mapping a left match to the number of
;;                   right matches it currently joins with (used by
;;                   quantified joins)
(define-struct (join-node node)
(left right join-constraint-predicate existential? match-counts)
(make-inspector))
;; rule-node: terminal node for one rule.
;;   rule   -- the rule this node represents
;;   join   -- the node whose matches feed this rule node, as recorded by
;;             activate-data-rule / activate-goal-rule
;;   action -- procedure applied to the binding values of a match, or #f
;;             when the rule has no actions
(define-struct (rule-node node)
(rule join action)
(make-inspector))
;; link-nodes: make SUCCESSOR a downstream node of PREDECESSOR.
;; The successor is pushed onto the front of the predecessor's successor list.
(define (link-nodes predecessor successor)
  (let ((existing (node-successors predecessor)))
    (set-node-successors! predecessor (cons successor existing))))
;; add-match-to-node-matches: push MATCH onto NODE's stored matches.
;; Nodes whose matches field is #f do not store matches; they are left
;; unchanged.
(define (add-match-to-node-matches match node)
  (let ((stored (node-matches node)))
    (when stored
      (set-node-matches! node (cons match stored)))))
;; remove-match-from-node-matches: destructively remove the first element of
;; NODE's matches that is eq? to MATCH.
;; Does nothing if MATCH is absent, or if NODE does not store matches (its
;; matches field is #f, as for goal-rule join and rule nodes).  The #f case
;; previously raised an error from (car #f).  A plain loop replaces the
;; full continuation the old version captured just to stop early.
(define (remove-match-from-node-matches match node)
  (let loop ((previous #f)
             (matches (node-matches node)))
    (cond ((not (pair? matches))
           (void))
          ((eq? match (car matches))
           ;; Splice the pair out: through the preceding pair, or by
           ;; resetting the field when MATCH is at the head.
           (if previous
               (set-cdr! previous (cdr matches))
               (set-node-matches! node (cdr matches))))
          (else
           (loop matches (cdr matches))))))
;; get-node-matches: return the list of matches NODE can offer to a join.
;; - If NODE stores matches (its matches field is a list), return that list.
;; - If NODE is a match node that does not store matches (matches is #f, as
;;   set by fix-goal-matches), instantiate its pattern with BINDINGS and ask
;;   the backward chainer (check) for matches.
;; - Otherwise return '().
;; Every caller iterates over the result with for-each.  A #f result, which
;; the old version produced for non-match nodes or when check yielded #f,
;; would raise an error there.
(define (get-node-matches node bindings)
  (let ((matches (node-matches node)))
    (cond (matches
           matches)
          ((match-node? node)
           (or (check (pattern-substitute
                       (match-node-pattern node) bindings))
               '()))
          (else
           '()))))
;; activate: build the rule network for every rule in RULESET.
;; Rules without goals become forward-chaining (data) rules.  They all hang
;; off one shared initial join node, whose only match, (() . ()), carries no
;; assertions and no bindings.  Rules with goals become backward-chaining
;; (goal) rules.  fix-goal-matches then switches goal-indexed match nodes to
;; on-demand matching.
(define (activate ruleset)
  (define initial-join-node
    (make-join-node '() '((())) #f #f #f #f '()))
  (for-each
   (lambda (rule)
     (cond ((null? (rule-goals rule))
            (activate-data-rule rule initial-join-node))
           (else
            (activate-goal-rule rule))))
   (ruleset-rules ruleset))
  (fix-goal-matches))
;; activate-data-rule: compile forward-chaining RULE into the rule network.
;;
;; Each precondition clause becomes a match node, indexed in the data index
;; under (pattern-first pattern).  That match node is joined onto the chain
;; that starts at INITIAL-JOIN-NODE.  Clause forms recognised:
;;   (VAR x PATTERN)      -- VAR (a variable) is bound to the matching
;;                           assertion; the middle element is not inspected
;;   (QUANTIFIER PATTERN) -- QUANTIFIER is no, notany, any, notall or all
;;   PATTERN              -- anything else
;; Variables from quantified clauses are not carried into later clauses or
;; into the action.  The resulting rule node is appended to
;; (current-inference-rule-nodes).  Returns void.
;;
;; The rule node is attached to PREVIOUS-JOIN-NODE rather than the JOIN-NODE
;; variable.  The two are the same node whenever the rule has preconditions.
;; For a rule with none, JOIN-NODE stayed #f and (link-nodes #f ...) raised
;; an error.  Such a rule now hangs off the initial join node and starts with
;; its single empty match.
(define (activate-data-rule rule initial-join-node)
  (let ((match-node #f)
        (join-node #f)
        (rule-node #f)
        (previous-join-node initial-join-node)
        (previous-variable-list '()))
    (for-each
     (lambda (clause)
       (let ((existential? #f)
             (assertion-variable #f)
             (pattern #f)
             (variable-list '()))
         (cond ((and (pair? clause)
                     (variable? (car clause)))
                (set! assertion-variable (car clause))
                (set! pattern (caddr clause)))
               ((and (pair? clause)
                     (memq (car clause) '(no notany any notall all)))
                (set! existential? (car clause))
                (set! pattern (cadr clause)))
               (else
                (set! pattern clause)))
         ;; Variables visible after this clause: the earlier ones first, then
         ;; any new ones (with the assertion variable, if present, leading).
         (set! variable-list
               (merge-variable-lists
                previous-variable-list
                (if assertion-variable
                    (cons assertion-variable (pattern-variables pattern))
                    (pattern-variables pattern))))
         ;; Match node: tests one fact.  Its constraint sees only the
         ;; pattern's own variables.
         (let ((match-constraints
                (pattern-match-constraints
                 pattern (pattern-variables pattern))))
           (set! match-node
                 (make-match-node
                  '() '() assertion-variable (pattern-base-pattern pattern)
                  (if (null? match-constraints)
                      #f
                      (eval `(lambda ,(pattern-variables pattern)
                               (and ,@match-constraints))))))
           (hash-table-put! (current-inference-data-index)
                            (pattern-first pattern)
                            (cons match-node
                                  (hash-table-get (current-inference-data-index)
                                                  (pattern-first pattern)
                                                  (lambda () '())))))
         ;; Join node: combines the chain so far with this match node.  For
         ;; no, notany and all, the node starts out holding the left node's
         ;; current matches.
         (let ((join-constraints
                (pattern-join-constraints
                 pattern (pattern-variables pattern))))
           (set! join-node
                 (make-join-node
                  '()
                  (if (or (memq existential? '(no notany)) (eq? existential? 'all))
                      (node-matches previous-join-node)
                      '())
                  previous-join-node match-node
                  (if (null? join-constraints)
                      #f
                      (eval `(lambda ,variable-list
                               (and ,@join-constraints))))
                  existential? '())))
         (link-nodes match-node join-node)
         (link-nodes previous-join-node join-node)
         (set! previous-join-node join-node)
         (if (not existential?)
             (set! previous-variable-list variable-list))))
     (rule-preconditions rule))
    (set! rule-node
          (make-rule-node
           '() (node-matches previous-join-node) rule previous-join-node
           (if (not (rule-actions rule))
               #f
               (eval #`(lambda #,previous-variable-list
                         (and #,@(rule-actions rule)))))))
    (link-nodes previous-join-node rule-node)
    (current-inference-rule-nodes
     (append! (current-inference-rule-nodes)
              (list rule-node))))
  (void))
;; activate-goal-rule: compile backward-chaining RULE into the rule network.
;;
;; The rule's first goal pattern becomes a goal match node.  That node is
;; registered in the goal index and does not store matches (matches = #f).
;; Each precondition clause then adds a data match node and a join node,
;; built the same way as in activate-data-rule.  The join and rule nodes here
;; also store no matches (#f).  The rule node is appended to
;; (current-inference-rule-nodes).  Returns void.
;;
;; Changes, all aligning this with activate-data-rule:
;; - Join constraints are wrapped in (and ...).  Splicing them straight into
;;   the lambda body meant only the last constraint decided the result.
;; - Clause classification guards with pair? before inspecting (car clause).
;; - Both indexes are keyed with pattern-first, the same key the data rules
;;   use.
(define (activate-goal-rule rule)
  (let ((match-node #f)
        (join-node #f)
        (rule-node #f)
        (previous-node #f)
        (previous-variable-list #f))
    (let ((goal-pattern (car (rule-goals rule))))
      (set! previous-node
            (make-match-node
             '() #f #f goal-pattern #f))
      (hash-table-put! (current-inference-goal-index)
                       (pattern-first goal-pattern)
                       (cons previous-node
                             (hash-table-get (current-inference-goal-index)
                                             (pattern-first goal-pattern)
                                             (lambda () '()))))
      (set! previous-variable-list
            (pattern-variables goal-pattern)))
    (for-each
     (lambda (clause)
       (let ((existential? #f)
             (assertion-variable #f)
             (pattern #f)
             (variable-list '()))
         (cond ((and (pair? clause)
                     (variable? (car clause)))
                (set! assertion-variable (car clause))
                (set! pattern (caddr clause)))
               ((and (pair? clause)
                     (memq (car clause) '(no notany any notall all)))
                (set! existential? (car clause))
                (set! pattern (cadr clause)))
               (else
                (set! pattern clause)))
         (set! variable-list
               (merge-variable-lists
                previous-variable-list
                (if assertion-variable
                    (cons assertion-variable (pattern-variables pattern))
                    (pattern-variables pattern))))
         (let ((match-constraints
                (pattern-match-constraints
                 pattern (pattern-variables pattern))))
           (set! match-node
                 (make-match-node
                  '() '() assertion-variable (pattern-base-pattern pattern)
                  (if (null? match-constraints)
                      #f
                      (eval `(lambda ,(pattern-variables pattern)
                               (and ,@match-constraints))))))
           (hash-table-put! (current-inference-data-index)
                            (pattern-first pattern)
                            (cons match-node
                                  (hash-table-get (current-inference-data-index)
                                                  (pattern-first pattern)
                                                  (lambda () '())))))
         (let ((join-constraints
                (pattern-join-constraints
                 pattern (pattern-variables pattern))))
           (set! join-node
                 (make-join-node
                  '() #f previous-node match-node
                  (if (null? join-constraints)
                      #f
                      (eval `(lambda ,variable-list
                               (and ,@join-constraints))))
                  existential? '())))
         (link-nodes match-node join-node)
         (link-nodes previous-node join-node)
         (set! previous-node join-node)
         (if (not existential?)
             (set! previous-variable-list variable-list))))
     (rule-preconditions rule))
    (set! rule-node
          (make-rule-node
           '() #f rule previous-node
           (if (not (rule-actions rule))
               #f
               (eval #`(lambda #,previous-variable-list
                         (and #,@(rule-actions rule)))))))
    (link-nodes previous-node rule-node)
    (current-inference-rule-nodes
     (append! (current-inference-rule-nodes)
              (list rule-node))))
  (void))
;; fix-goal-matches: switch goal-indexed data match nodes to on-demand mode.
;; For every data-index key that also appears in the goal index, each match
;; node under that key gets its matches field set to #f.  With matches = #f,
;; get-node-matches asks `check` instead of using stored matches.  The key
;; is then dropped from the data index, so `assert` no longer feeds those
;; nodes.
;;
;; Keys are collected first and removed after the traversal.  Removing
;; entries from a table while hash-table-for-each walks it is not
;; guaranteed to traverse consistently.
(define (fix-goal-matches)
  (let ((goal-keys '()))
    (hash-table-for-each
     (current-inference-data-index)
     (lambda (key value)
       (when (hash-table-get (current-inference-goal-index) key
                             (lambda () #f))
         (for-each
          (lambda (match-node)
            (set-node-matches! match-node #f))
          value)
         (set! goal-keys (cons key goal-keys)))))
    (for-each
     (lambda (key)
       (hash-table-remove! (current-inference-data-index) key))
     goal-keys)))
;; merge-variable-lists: append to LIST1 the elements of LIST2 that are not
;; already present, comparing with memq.  First occurrences are kept in
;; their LIST2 order.  When nothing new is added, LIST1 itself is returned.
(define (merge-variable-lists list1 list2)
  (let walk ((pending list2)
             (additions '()))
    (cond ((null? pending)
           (if (null? additions)
               list1
               (append list1 (reverse additions))))
          ((or (memq (car pending) list1)
               (memq (car pending) additions))
           (walk (cdr pending) additions))
          (else
           (walk (cdr pending) (cons (car pending) additions))))))
;; match-node-assert: offer a newly asserted ASSERTION to MATCH-NODE.
;; The fact is unified with the node's pattern.  If the unification succeeds
;; and the node's match constraint (if any) accepts the bindings, the
;; assertion variable (if any) is bound to the assertion.  The resulting
;; match ((assertion) . bindings) is stored and pushed to each successor.
;; Otherwise successors are told a non-matching fact arrived, which matters
;; to notall/all quantified joins.
;;
;; The constraint predicate is built as (lambda (pattern-variables ...) ...),
;; and that parameter list does not include the assertion variable.  So the
;; predicate is now tested BEFORE the assertion variable is bound.  Testing
;; it afterwards passed one argument too many.
(define (match-node-assert match-node assertion)
  (let ((bindings (pattern-unify
                   (assertion-fact assertion)
                   (match-node-pattern match-node) '())))
    (when (current-inference-trace)
      (printf "match-node-assert: fact = ~a; pattern = ~a; bindings = ~a~n"
              (assertion-fact assertion) (match-node-pattern match-node) bindings))
    (when bindings
      (when (and (match-node-match-constraint-predicate match-node)
                 (not (apply (match-node-match-constraint-predicate match-node)
                             (bindings-data bindings))))
        (set! bindings #f))
      (when (and bindings (match-node-assertion-variable match-node))
        (set! bindings
              (cons (cons (match-node-assertion-variable match-node)
                          assertion)
                    bindings))))
    (if bindings
        (let ((match (cons (list assertion) bindings)))
          (when (current-inference-trace)
            (printf "match-node-assert: match = ~a~n" match))
          (add-match-to-node-matches match match-node)
          (for-each
           (lambda (successor)
             (propagate-match-from-match-node match successor))
           (node-successors match-node)))
        (for-each
         (lambda (successor)
           (propagate-nonmatch-from-match-node successor))
         (node-successors match-node)))))
;; match-node-retract: withdraw ASSERTION from MATCH-NODE.
;; If a stored match was built from ASSERTION (its first assertion is eq? to
;; it), that match is spliced out of the node's list and withdrawn from every
;; successor.  Otherwise the assertion never matched here, and successors
;; are told a non-matching fact went away.
(define (match-node-retract match-node assertion)
  (when (current-inference-trace)
    (printf "match-node-retract: fact = ~a~n"
            (assertion-fact assertion)))
  (let scan ((prior #f)
             (remaining (node-matches match-node)))
    (if (null? remaining)
        (for-each
         (lambda (successor)
           (unpropagate-nonmatch-from-match-node successor))
         (node-successors match-node))
        (let ((candidate (car remaining)))
          (if (eq? assertion (caar candidate))
              (begin
                (if prior
                    (set-cdr! prior (cdr remaining))
                    (set-node-matches! match-node (cdr remaining)))
                (for-each
                 (lambda (successor)
                   (unpropagate-match-from-match-node candidate successor))
                 (node-successors match-node)))
              (scan remaining (cdr remaining)))))))
;; match-node-check: backward-chaining counterpart of match-node-assert.
;; The temporary ASSERTION built by `check` is unified with the node's
;; pattern.  On success, the match is stored (if the node stores matches)
;; and pushed to each successor, passing CONTINUATION along.  A rule node
;; that completes invokes CONTINUATION and escapes back to `check`.
;; Otherwise returns '().
;;
;; Fixes:
;; - A rejected match constraint now returns '() rather than #f, keeping
;;   the "'() means nothing found" contract that `check` relies on.
;; - The constraint is tested before the assertion variable is bound,
;;   because its parameters are the pattern variables only.
;; - Successors are dispatched by type.  A goal rule with no preconditions
;;   links its goal match node directly to a rule node, and
;;   propagate-match-from-join-node cannot handle a rule node.
(define (match-node-check match-node assertion continuation)
  (let/ec exit
    (let ((bindings (pattern-unify
                     (assertion-fact assertion)
                     (match-node-pattern match-node) '())))
      (when (current-inference-trace)
        (printf "match-node-check: fact = ~a; pattern = ~a; bindings = ~a~n"
                (assertion-fact assertion) (match-node-pattern match-node) bindings))
      (when bindings
        (when (and (match-node-match-constraint-predicate match-node)
                   (not (apply (match-node-match-constraint-predicate match-node)
                               (bindings-data bindings))))
          (exit '()))
        (when (match-node-assertion-variable match-node)
          (set! bindings
                (cons (cons (match-node-assertion-variable match-node)
                            assertion)
                      bindings)))
        (let ((match (cons (list assertion) bindings)))
          (when (current-inference-trace)
            (printf "match-node-check: match = ~a~n" match))
          (add-match-to-node-matches match match-node)
          (for-each
           (lambda (successor)
             (if (join-node? successor)
                 (propagate-match-from-join-node match successor continuation)
                 (propagate-match-to-rule match successor continuation)))
           (node-successors match-node)))))
    '()))
;; propagate-match-from-match-node: push a new right MATCH (from a match
;; node) into JOIN-NODE.
;; Runs only when JOIN-NODE stores matches.  For each left match:
;; - plain join: a successful join is propagated to the successors;
;; - quantified join: the left match's entry in match-counts is incremented
;;   when the new right match joins.  The left match itself is propagated or
;;   withdrawn when its count crosses the quantifier's threshold.
;;
;; Fixes:
;; - The `all` branch called the two-argument
;;   unpropagate-match-to-successors with an extra #f, an arity error.
;; - get-node-matches now receives the bindings (cdr match) rather than the
;;   assertion list (car match), as in propagate-match-from-join-node.
(define (propagate-match-from-match-node match join-node)
  ;; Number of facts currently asserted under the right pattern's first element.
  (define (right-fact-count)
    (length (hash-table-get
             (current-inference-assertion-index)
             (pattern-first
              (match-node-pattern
               (join-node-right join-node))))))
  (when (node-matches join-node)
    (for-each
     (lambda (left-match)
       (let ((joined-match (join left-match match join-node)))
         (if (join-node-existential? join-node)
             (let* ((count-association
                     (assq left-match (join-node-match-counts join-node)))
                    (count (if count-association
                               (cdr count-association)
                               0)))
               (when joined-match
                 (set! count (+ count 1))
                 (if count-association
                     (set-cdr! count-association count)
                     (set-join-node-match-counts!
                      join-node (cons (cons left-match count)
                                      (join-node-match-counts join-node)))))
               (case (join-node-existential? join-node)
                 ((no notany)
                  ;; first joining right match: withdraw the left match
                  (when (and joined-match (= count 1))
                    (unpropagate-match-to-successors left-match join-node)))
                 ((any)
                  ;; first joining right match: propagate the left match
                  (when (and joined-match (= count 1))
                    (propagate-match-to-successors left-match join-node #f)))
                 ((notall)
                  ;; non-joining fact arriving when the count was one less
                  ;; than the number of facts under the right key
                  (when (and (not joined-match)
                             (= count (- (right-fact-count) 1)))
                    (propagate-match-to-successors left-match join-node #f)))
                 ((all)
                  (when (and (not joined-match)
                             (= count (- (right-fact-count) 1)))
                    (unpropagate-match-to-successors left-match join-node)))))
             (when joined-match
               (propagate-match-to-successors joined-match join-node #f)))))
     (get-node-matches (join-node-left join-node) (cdr match)))))
;; propagate-nonmatch-from-match-node: tell JOIN-NODE that a fact arrived
;; under its right key but did not match the right pattern.
;; Only notall and all joins care.  For each left match whose count equals
;; the number of facts under the right key minus one, notall now holds (the
;; left match is propagated) and all no longer holds (it is withdrawn).
;; Runs only when JOIN-NODE stores matches.
(define (propagate-nonmatch-from-match-node join-node)
  (define (right-fact-count)
    (length (hash-table-get
             (current-inference-assertion-index)
             (pattern-first
              (match-node-pattern
               (join-node-right join-node))))))
  (when (node-matches join-node)
    (for-each
     (lambda (left-match)
       (let ((quantifier (join-node-existential? join-node)))
         (when (memq quantifier '(notall all))
           (let* ((entry (assq left-match (join-node-match-counts join-node)))
                  (count (if entry (cdr entry) 0)))
             (when (= count (- (right-fact-count) 1))
               (if (eq? quantifier 'notall)
                   (propagate-match-to-successors left-match join-node #f)
                   (unpropagate-match-to-successors left-match join-node)))))))
     (get-node-matches (join-node-left join-node) '()))))
;; propagate-match-from-join-node: push a new left MATCH into JOIN-NODE.
;; - Quantified join: count the right matches MATCH joins with and record
;;   the count in match-counts.  MATCH itself is propagated when the
;;   quantifier holds:
;;     no/notany: count = 0        any: count > 0
;;     notall:    count < N        all: count = N
;;   where N is the number of facts asserted under the right pattern's key.
;; - Plain join: each successful join with a right match (obtained through
;;   get-node-matches, which may call the backward chainer) is propagated.
;; CONTINUATION is passed along for backward chaining.
;;
;; Fixes:
;; - `any` required count > 1; it now requires count > 0, matching the 0->1
;;   transition used by propagate-match-from-match-node.
;; - N defaults to 0 when no fact under the right key was ever asserted.
;;   Previously hash-table-get raised an error in that case.
(define (propagate-match-from-join-node match join-node continuation)
  (define (right-fact-count)
    (length (hash-table-get
             (current-inference-assertion-index)
             (pattern-first
              (match-node-pattern
               (join-node-right join-node)))
             (lambda () '()))))
  (when (current-inference-trace)
    (printf "propagate-match-from-join-node: ~a~n" match))
  (if (join-node-existential? join-node)
      (let ((count 0))
        (for-each
         (lambda (right-match)
           (when (join match right-match join-node)
             (set! count (+ count 1))))
         (node-matches (join-node-right join-node)))
        (set-join-node-match-counts!
         join-node (cons (cons match count)
                         (join-node-match-counts join-node)))
        (case (join-node-existential? join-node)
          ((no notany)
           (when (= count 0)
             (propagate-match-to-successors match join-node continuation)))
          ((any)
           (when (> count 0)
             (propagate-match-to-successors match join-node continuation)))
          ((notall)
           (when (< count (right-fact-count))
             (propagate-match-to-successors match join-node continuation)))
          ((all)
           (when (= count (right-fact-count))
             (propagate-match-to-successors match join-node continuation)))))
      (for-each
       (lambda (right-match)
         (let ((joined-match (join match right-match join-node)))
           (when joined-match
             (propagate-match-to-successors joined-match join-node continuation))))
       (get-node-matches (join-node-right join-node) (cdr match)))))
;; propagate-match-to-successors: store MATCH at JOIN-NODE (when it stores
;; matches), then hand it to each successor.  Join-node successors receive
;; it as a left match; any other successor is treated as a rule node.
;; CONTINUATION is passed through unchanged.
(define (propagate-match-to-successors match join-node continuation)
  (add-match-to-node-matches match join-node)
  (let next ((successors (node-successors join-node)))
    (unless (null? successors)
      (let ((successor (car successors)))
        (if (join-node? successor)
            (propagate-match-from-join-node match successor continuation)
            (propagate-match-to-rule match successor continuation)))
      (next (cdr successors)))))
;; propagate-match-to-rule: deliver a complete MATCH to RULE-NODE.
;; A rule node that stores matches (forward chaining) just records the match
;; as a pending instance, which start-inference fires later.  A rule node
;; without storage (goal rule, matches = #f) runs its action right away.  It
;; then reports success by calling CONTINUATION with a one-element list of
;; the match, restricted to its first assertion.
;;
;; Fixes:
;; - The action is applied to the binding values (bindings-data (cdr match)),
;;   as start-inference does.  Passing the whole match also handed over the
;;   assertion list.
;; - CONTINUATION is only called when one was supplied.  Forward propagation
;;   passes #f.
(define (propagate-match-to-rule match rule-node continuation)
  (when (current-inference-trace)
    (printf "Match propagated to rule instance ~a with bindings ~a~n"
            (rule-name (rule-node-rule rule-node)) (cdr match)))
  (add-match-to-node-matches match rule-node)
  (when (not (node-matches rule-node))
    (when (current-inference-trace)
      (printf "Executing rule instance ~a with bindings ~a.~n"
              (rule-name (rule-node-rule rule-node))
              (cdr match)))
    (when (rule-node-action rule-node)
      (apply (rule-node-action rule-node)
             (bindings-data (cdr match))))
    (when continuation
      (continuation (list (cons (list (caar match)) (cdr match)))))))
;; unpropagate-match-from-match-node: withdraw right MATCH (its assertion was
;; retracted) from JOIN-NODE.
;; - Quantified join: for each left match, decrement its count if MATCH
;;   joined with it.  The left match is then propagated or withdrawn when
;;   the quantifier's truth value flips.
;; - Plain join: every stored joined match whose last assertion is MATCH's
;;   assertion is withdrawn from the successors.
;; Runs only when JOIN-NODE stores matches, mirroring
;; propagate-match-from-match-node.
;;
;; Fixes:
;; - The `any` branch tested `join-node` (always true) instead of
;;   `joined-match`.
;; - A left match without a match-counts entry is treated as count 0
;;   instead of failing on (cdr #f).
;; - Join nodes that store no matches (#f) are skipped instead of iterating
;;   over #f.
(define (unpropagate-match-from-match-node match join-node)
  (define (right-fact-count)
    (length (hash-table-get
             (current-inference-assertion-index)
             (pattern-first
              (match-node-pattern
               (join-node-right join-node)))
             (lambda () '()))))
  (when (node-matches join-node)
    (if (join-node-existential? join-node)
        (for-each
         (lambda (left-match)
           (let* ((count-association
                   (assq left-match (join-node-match-counts join-node)))
                  (count (if count-association
                             (cdr count-association)
                             0))
                  (joined-match (join left-match match join-node)))
             (when joined-match
               (set! count (- count 1))
               (when count-association
                 (set-cdr! count-association count)))
             (case (join-node-existential? join-node)
               ((no notany)
                ;; last joining right match gone: "no" holds again
                (when (and joined-match (= count 0))
                  (propagate-match-to-successors left-match join-node #f)))
               ((any)
                ;; last joining right match gone: "any" no longer holds
                (when (and joined-match (= count 0))
                  (unpropagate-match-to-successors left-match join-node)))
               ((notall)
                (when (and (not joined-match)
                           (= count (right-fact-count)))
                  (unpropagate-match-to-successors left-match join-node)))
               ((all)
                (when (and (not joined-match)
                           (= count (right-fact-count)))
                  (propagate-match-to-successors left-match join-node #f))))))
         (node-matches (join-node-left join-node)))
        (let ((assertion (caar match)))
          (for-each
           (lambda (joined)
             (when (eq? assertion (last (car joined)))
               (unpropagate-match-to-successors joined join-node)))
           (node-matches join-node))))))
;; unpropagate-nonmatch-from-match-node: tell quantified JOIN-NODE that a
;; fact under its right key, which never matched the right pattern, was
;; retracted.
;; For each left match whose count now equals the number of facts under
;; the right key, notall stops holding (the left match is withdrawn) and all
;; starts holding (it is propagated).
;;
;; Skips join nodes that store no matches (#f), as
;; propagate-nonmatch-from-match-node does.  In goal networks the left node
;; can be a goal match node whose matches are #f, and iterating it failed.
;; The fact count also defaults to 0 when the key is missing.
(define (unpropagate-nonmatch-from-match-node join-node)
  (define (right-fact-count)
    (length (hash-table-get
             (current-inference-assertion-index)
             (pattern-first
              (match-node-pattern
               (join-node-right join-node)))
             (lambda () '()))))
  (when (and (join-node-existential? join-node)
             (node-matches join-node))
    (for-each
     (lambda (left-match)
       (let* ((count-association
               (assq left-match (join-node-match-counts join-node)))
              (count (if count-association
                         (cdr count-association)
                         0)))
         (case (join-node-existential? join-node)
           ((notall)
            (when (= count (right-fact-count))
              (unpropagate-match-to-successors left-match join-node)))
           ((all)
            (when (= count (right-fact-count))
              (propagate-match-to-successors left-match join-node #f))))))
     (node-matches (join-node-left join-node)))))
;; unpropagate-match-from-join-node: withdraw left MATCH from JOIN-NODE.
;; First the first match-counts entry whose key is eq? to MATCH is spliced
;; out.  Then every stored match N for which (match-subset? match N) holds
;; is withdrawn from the node's successors.
(define (unpropagate-match-from-join-node match join-node)
  (let drop ((prior #f)
             (remaining (join-node-match-counts join-node)))
    (cond ((null? remaining)
           (void))
          ((eq? match (caar remaining))
           (if prior
               (set-cdr! prior (cdr remaining))
               (set-join-node-match-counts! join-node (cdr remaining))))
          (else
           (drop remaining (cdr remaining)))))
  (for-each
   (lambda (stored)
     (when (match-subset? match stored)
       (unpropagate-match-to-successors stored join-node)))
   (node-matches join-node)))
;; unpropagate-match-to-successors: remove MATCH from JOIN-NODE's stored
;; matches, then withdraw it from each successor.  Join-node successors see
;; it as a departing left match; any other successor is treated as a rule
;; node.
(define (unpropagate-match-to-successors match join-node)
  (remove-match-from-node-matches match join-node)
  (let next ((successors (node-successors join-node)))
    (unless (null? successors)
      (let ((successor (car successors)))
        (if (join-node? successor)
            (unpropagate-match-from-join-node match successor)
            (unpropagate-match-to-rule match successor)))
      (next (cdr successors)))))
;; unpropagate-match-to-rule: drop MATCH from RULE-NODE's pending instances,
;; tracing the removal when tracing is enabled.
(define (unpropagate-match-to-rule match rule-node)
  (if (current-inference-trace)
      (printf "Match unpropagated to rule instance ~a with bindings ~a~n"
              (rule-name (rule-node-rule rule-node))
              (cdr match)))
  (remove-match-from-node-matches match rule-node))
;; join: try to combine LEFT-MATCH and RIGHT-MATCH at JOIN-NODE.
;; A match is (assertions . bindings).  The join fails (#f) when a variable
;; bound on both sides has values that differ under equal?.  It also fails
;; when the node's join constraint rejects the merged binding values.
;; Otherwise returns (left-assertions ++ right-assertions . bindings), where
;; bindings are the left bindings followed by the right bindings for
;; variables the left side does not bind.
(define (join left-match right-match join-node)
  (let/ec give-up
    (let ((left-bindings (cdr left-match))
          (fresh '()))
      (for-each
       (lambda (right-binding)
         (let ((existing (assq (car right-binding) left-bindings)))
           (cond ((not existing)
                  (set! fresh (cons right-binding fresh)))
                 ((not (equal? (cdr right-binding) (cdr existing)))
                  (give-up #f)))))
       (cdr right-match))
      (let ((bindings (if (null? fresh)
                          left-bindings
                          (append left-bindings (reverse fresh))))
            (constraint (join-node-join-constraint-predicate join-node)))
        (if (and constraint
                 (not (apply constraint (bindings-data bindings))))
            #f
            (cons (append (car left-match) (car right-match))
                  bindings))))))
;; start-inference: run the forward-chaining agenda.
;; Asserts (start), then repeats the following:
;;   1. Scan (current-inference-rule-nodes) in order for the first rule node
;;      holding a pending match.
;;   2. Pop that match and apply the rule's action to its binding values.
;;   3. Restart the scan.
;; When a full scan finds nothing pending, returns zero values via (exit).
;; stop-inference / succeed / fail invoke the stored exit continuation to
;; return early with a value.
;;
;; Fixes:
;; - Goal rule nodes keep #f instead of a match list.  (not (null? #f)) was
;;   true, so the old scan failed on (car #f); only a non-empty list now
;;   counts as pending.
;; - A rule node without an action (#f) is skipped, as in
;;   propagate-match-to-rule.
;; - After a firing, the scan escapes with `break` and loops in tail
;;   position.  The old code re-entered the loop inside for-each, which grew
;;   the continuation on every firing.
(define (start-inference)
  (assert '(start))
  (let/cc exit
    (current-inference-exit exit)
    (let loop ()
      (let ((fired?
             (let/ec break
               (for-each
                (lambda (rule-node)
                  (let ((pending (node-matches rule-node)))
                    (when (pair? pending)
                      (let* ((rule (rule-node-rule rule-node))
                             (match (car pending))
                             (arguments (bindings-data (cdr match))))
                        (when (current-inference-trace)
                          (printf "Executing rule instance ~a with bindings ~a~n"
                                  (rule-name rule) (cdr match)))
                        (set-node-matches! rule-node (cdr pending))
                        (when (rule-node-action rule-node)
                          (apply (rule-node-action rule-node) arguments))
                        (break #t)))))
                (current-inference-rule-nodes))
               #f)))
        (if fired?
            (loop)
            (exit))))))
;; stop-inference: abort the running start-inference by invoking the exit
;; continuation stored in (current-inference-exit).
;;   (stop-inference)        -- the continuation receives zero values
;;   (stop-inference value)  -- the continuation receives VALUE
(define stop-inference
  (case-lambda
    (()
     ((current-inference-exit)))
    ((return-value)
     (let ((exit (current-inference-exit)))
       (exit return-value)))))
;; succeed: stop inference, returning #t.
(define succeed
  (lambda ()
    (stop-inference #t)))
;; fail: stop inference, returning #f.
(define fail
  (lambda ()
    (stop-inference #f)))
;; print-rule-network: dump every rule node's name and the network feeding
;; it to the current output port.
;; A rule node's join field is normally a join node.  For a goal rule with
;; no preconditions, activate-goal-rule stores the goal match node there.
;; That case is now printed as a match node instead of being passed to
;; print-join-node, which would fail on it.
(define (print-rule-network)
  (for-each
   (lambda (rule-node)
     (let ((feeder (rule-node-join rule-node)))
       (printf "----------~n")
       (printf "Rule: ~a~n~n" (rule-name (rule-node-rule rule-node)))
       (cond ((join-node? feeder)
              (print-join-node feeder))
             ((match-node? feeder)
              (print-match-node feeder)))))
   (current-inference-rule-nodes)))
;; print-join-node: print JOIN-NODE's upstream network, then its own state.
;; The left input is printed recursively as a join node when it is one, and
;; as a match node otherwise.  The right match node, if any, follows.  Then
;; the quantifier, match counts and stored matches are printed.
(define (print-join-node join-node)
  (let ((left (join-node-left join-node)))
    (cond ((not left) (void))
          ((join-node? left) (print-join-node left))
          (else (print-match-node left))))
  (let ((right (join-node-right join-node)))
    (if right
        (print-match-node right)))
  (printf "join node: existential? = ~a~n" (join-node-existential? join-node))
  (printf "join-node: match-counts = ~a~n" (join-node-match-counts join-node))
  (printf "join node: matches = ~a~n~n" (node-matches join-node)))
;; print-match-node: print MATCH-NODE's pattern and its stored matches.
(define (print-match-node match-node)
  (let ((pattern (match-node-pattern match-node))
        (matches (node-matches match-node)))
    (printf "match-node: pattern = ~a~n" pattern)
    (printf "match-node: matches = ~a~n~n" matches)))
)