#lang racket
(require "data-structures.rkt")
(require "smt-interface.rkt")
(require "debug.rkt")
(require rackunit)
(provide (all-defined-out))
;; Unit-clause check used by initial-bcp: when both watched literals of
;; CLAUSE are the same literal (per literal-eq?) and that literal is still
;; unassigned, propagate it with CLAUSE as its reason.  Otherwise SMT is
;; returned unchanged.
(define (bcp-clause smt clause)
  (define w1 (clause-watched1 clause))
  (cond
    [(and (literal-eq? w1 (clause-watched2 clause))
          (literal-unassigned? w1))
     (propagate-assignment smt w1 clause)]
    [else smt]))
;; Run bcp-clause over every clause in (SMT-clauses smt), in vector order,
;; threading the solver state through each call.  The clause vector is taken
;; from the SMT passed in, not from the intermediate states.
(define (initial-bcp smt)
  (for/fold ([state smt])
            ([c (in-vector (SMT-clauses smt))])
    (bcp-clause state c)))
;; Propagate the first watched literal of the clause at the head of
;; (SMT-learned-clauses smt), with that clause as the reason.
;; NOTE(review): presumably the head is the most recently learned clause
;; (resolve-conflict! learns a clause right before backjump! calls this).
(define (learned-bcp smt)
  (define newest (car (SMT-learned-clauses smt)))
  (propagate-assignment smt (clause-watched1 newest) newest))
;; Make LIT true and propagate the consequences.
;;
;;   1. SMT-satisfy-literal! assigns LIT.
;;   2. LIT's implication-graph node records the current decision level and
;;      CLAUSE, the reason.  CLAUSE is a clause for Boolean propagation, or a
;;      procedure from SMT to explanation literals for theory implications
;;      (see propagate-T-implications).
;;   3. Theory propagation runs (propagate-T-implications).
;;   4. Every clause on the watch list of LIT's negation is revisited with
;;      update-watchedness.
;;
;; update-watchedness may raise (unsat-exn, or bail-exn via
;; resolve-conflict!), so this procedure does not always return.
(define (propagate-assignment smt lit clause)
  (define assigned (SMT-satisfy-literal! smt lit))
  (set-literal-igraph-node! lit
                            (node (SMT-decision-level assigned) clause))
  (define after-theory (propagate-T-implications assigned lit))
  (define falsified (negate-literal lit))
  ;; The watch list is read only after theory propagation, which may itself
  ;; have moved watches.
  (for/fold ([state after-theory])
            ([watching (in-list (literal-watched falsified))])
    (update-watchedness state watching falsified)))
;; Ask the theory solver, via the (T-Propagate) procedure, what LIT implies,
;; store the returned theory state in SMT, and assert each implied DIMACS
;; literal with propagate-assignment.
;;
;; The reason attached to each implied literal is a procedure.  Given an SMT
;; state, it calls (T-Explain) for that literal and converts the result to a
;; clause's literals, so explanations are computed only when needed.
(define (propagate-T-implications smt lit)
  (define-values (t-state implied)
    ((T-Propagate) (SMT-T-State smt) (SMT-strength smt) (literal->dimacs lit)))
  (for/fold ([smt (new-T-State smt t-state)])
            ([dimacs-lit (in-list implied)])
    (define implied-literal
      ((dimacs-lit->literal (SMT-variables smt)) dimacs-lit))
    (define (explain smt)
      (clause-literals
       ((dimacs-lits->clause (SMT-variables smt))
        ((T-Explain) (SMT-T-State smt) (SMT-strength smt) dimacs-lit))))
    (propagate-assignment smt implied-literal explain)))
;; Undo every decision level above ABSOLUTE-LEVEL, then re-propagate the
;; newest learned clause.
;;
;; The steps are:
;;   - call SMT-slash-all-literals!;
;;   - pop (current level - ABSOLUTE-LEVEL) levels off the front of the
;;     partial assignment, calling obliterate! on every literal they hold;
;;   - set the decision level and the remaining partial assignment;
;;   - tell the theory solver, via (T-Backjump), how many literals were
;;     obliterated;
;;   - run learned-bcp.
(define (backjump! smt absolute-level)
  (SMT-slash-all-literals! smt)
  ;; Install the post-backjump state once enough levels have been popped.
  (define (finish remaining-levels n-obliterated)
    (let* ([smt (SMT-set-decision-level smt absolute-level)]
           [smt (SMT-set-partial-assignment smt remaining-levels)]
           [smt (new-T-State smt ((T-Backjump) (SMT-T-State smt) n-obliterated))])
      (learned-bcp smt)))
  (let pop-level ([to-pop (- (SMT-decision-level smt) absolute-level)]
                  [levels (SMT-partial-assignment smt)]
                  [n-obliterated 0])
    (if (zero? to-pop)
        (finish levels n-obliterated)
        (let ([top (first levels)])
          (for-each obliterate! top)
          (pop-level (sub1 to-pop)
                     (rest levels)
                     (+ (length top) n-obliterated))))))
;; Reset the variable underlying LIT: its value becomes 'unassigned and its
;; implication-graph node and timestamp are cleared to #f.
(define (obliterate! lit)
  (define v (literal-var lit))
  (set-var-value! v 'unassigned)
  (set-var-igraph-node! v #f)
  (set-var-timestamp! v #f))
;; Revisit CLAUSE after its watched literal DECIDED-LITERAL became false,
;; and act on bcp-4cases' verdict:
;;   'skip          -> nothing further to do; return SMT.
;;   'contradiction -> at decision level 0 raise unsat-exn, otherwise run
;;                     conflict analysis (resolve-conflict!).
;;   a literal      -> CLAUSE is unit; propagate that literal with CLAUSE
;;                     as the reason.
(define (update-watchedness smt clause decided-literal)
  (define outcome (bcp-4cases 0 clause #f #f decided-literal))
  (cond
    [(eq? outcome 'skip) smt]
    [(eq? outcome 'contradiction)
     (if (zero? (SMT-decision-level smt))
         (raise (unsat-exn smt))
         (resolve-conflict! smt clause))]
    [else (propagate-assignment smt outcome clause)]))
;; Return the literal in the (non-empty) vector LITERALS with the greatest
;; literal-timestamp.  On ties the earliest such literal in the vector wins.
(define (choose-latest-literal literals)
  (for/fold ([latest (vector-ref literals 0)])
            ([lit (in-vector literals 1)])
    (if (> (literal-timestamp lit) (literal-timestamp latest))
        lit
        latest)))
;; Analyse the conflict on clause C, learn a new clause, and leave through
;; an exception (both final branches raise; this never returns normally).
;;
;; 1. Start from C's literals (lemma->lits) and repeatedly resolve on the
;;    literal with the greatest literal-timestamp (choose-latest-literal)
;;    against that literal's explanation (literal-explanation) until
;;    asserting-literals? holds, i.e. at most one literal remains at the
;;    current decision level.  increase-scores! is applied to C's literals
;;    and to every explanation used.
;; 2. The learned clause watches its latest literal (watch1) plus its first
;;    literal, or its last literal when the first one is watch1 (watch2).
;; 3. The clause is added with SMT-learn-clause.  If asserting-level is
;;    negative, unsat-exn is raised.  Otherwise bail-exn is raised carrying
;;    the state returned by backjump!.
;;
;; NOTE(review): for a one-literal learned clause watch1 and watch2 are the
;; same literal and add-literal-watched! is called twice on it -- confirm
;; that is harmless.
;; NOTE(review): watch2 is not chosen as the literal with the highest
;; decision level among the rest, which two-watched-literal schemes usually
;; require of a learned clause -- confirm this is intended.
(define (resolve-conflict! smt C)
(let ((clause-lits (lemma->lits smt C)))
(increase-scores! clause-lits) (let ((literals-to-learn
;; First-UIP loop: resolve until only one current-level literal is left.
(let first-uip ((resolvent clause-lits))
(cond [(asserting-literals? smt resolvent)
resolvent]
[else (let* ((resolve-lit (choose-latest-literal resolvent))
(resolve-against (literal-explanation smt resolve-lit)))
(increase-scores! resolve-against) (first-uip (resolve-on-lit resolvent resolve-against resolve-lit)))]))))
(let* ((level-to-backjump-to (asserting-level smt literals-to-learn))
(watch1 (choose-latest-literal literals-to-learn))
;; Use the first literal as second watch unless it is watch1 itself.
(watch2 (if (literal-eq? (vector-ref literals-to-learn 0) watch1)
(vector-ref literals-to-learn (+ -1 (vector-length literals-to-learn)))
(vector-ref literals-to-learn 0)))
(learned (clause literals-to-learn watch1 watch2)))
(begin (add-literal-watched! learned watch1)
(add-literal-watched! learned watch2)
(let ((smt (SMT-learn-clause smt learned)))
;; A negative backjump level means the formula is unsatisfiable.
(if (0 . > . level-to-backjump-to)
(raise (unsat-exn smt)) (raise (bail-exn (backjump! smt level-to-backjump-to))))))))))
;; #t when at most one literal in the vector LITS sits at SMT's current
;; decision level, #f as soon as a second one is found.  The scan stops at
;; the second current-level literal.
(define (asserting-literals? smt lits)
  (define current-level (SMT-decision-level smt))
  (define n (vector-length lits))
  (let scan ([i 0] [seen-current? #f])
    (cond
      [(= i n) #t]
      [(= current-level (literal-dec-lev (vector-ref lits i)))
       (if seen-current?
           #f
           (scan (add1 i) #t))]
      [else (scan (add1 i) seen-current?)])))
;; Decision level to backjump to after learning the clause whose literals
;; are in the (non-empty) vector LITS.
;;
;; - If every literal in LITS is at the same decision level L, the result
;;   is L - 1.  A clause whose literals are all at level 0 therefore yields
;;   -1, which resolve-conflict! treats as UNSAT.
;; - Otherwise the result is the highest decision level among the literals
;;   that are NOT at SMT's current decision level.
;;
;; "Same level" is decided by comparing every literal with the first one,
;; not with a running maximum, so the answer does not depend on the order
;; of LITS.  At current level 5, levels #(5 3) and #(3 5) both give 3.
(define (asserting-level smt lits)
  (define current-level (SMT-decision-level smt))
  (define levels
    (for/list ([lit (in-vector lits)])
      (literal-dec-lev lit)))
  (define first-level (literal-dec-lev (vector-ref lits 0)))
  (if (andmap (lambda (lv) (= lv first-level)) levels)
      (sub1 first-level)
      ;; At least two distinct levels exist here and at most one of them is
      ;; the current level, so some non-current literal always contributes.
      (for/fold ([best -1])
                ([lv (in-list levels)]
                 #:unless (= lv current-level))
        (max best lv))))
;; Resolve the literal vectors C and D on the pivot RES-LIT.  RES-LIT and
;; its negation are removed from both sides, and the remainders are joined
;; with list-union.  The result is a fresh vector: C's remaining literals
;; first, then D's remaining literals that are not already among them.
;;
;; Pivot removal uses literal-eq?, the same equivalence used for the union
;; (and elsewhere in this file).  The negation is a literal newly built by
;; negate-literal, so the default equal? is not guaranteed to match the
;; corresponding literal stored in the clause.
(define (resolve-on-lit C D res-lit)
  (define pivots (list res-lit (negate-literal res-lit)))
  (define (strip-pivots lits)
    (remove* pivots (vector->list lits) literal-eq?))
  (list->vector (list-union (strip-pivots C)
                            (strip-pivots D)
                            literal-eq?)))
;; Scan CLAUSE, from literal index IDX, after its watched literal P has been
;; falsified.  Callers in this file start with IDX = 0 and both accumulators
;; set to #f.
;;
;; Accumulators:
;;   nonfalse-literal -- the non-false literal (other than P) to report or to
;;                       move the watch to.  Once one has been recorded, the
;;                       clause's other watched literal does not replace it.
;;   multiple         -- truthy once at least two non-false literals other
;;                       than P have been seen.  When truthy, nonfalse-literal
;;                       is not the other watched literal.
;;
;; Result:
;;   'skip          -- either the watch was moved from P to nonfalse-literal
;;                     (rem-literal-watched!, add-literal-watched!,
;;                     clause-watched-swap!), or the only non-false literal
;;                     is already assigned (presumably true).
;;   a literal      -- the only non-false literal, which is unassigned.  The
;;                     clause is unit and the caller propagates it.
;;   'contradiction -- every literal other than P is false.
(define (bcp-4cases idx clause nonfalse-literal multiple p)
(cond [(= idx (clause-size clause))
;; End of the scan: pick the case from the accumulators.
(if multiple (begin (rem-literal-watched! clause p)
(add-literal-watched! clause nonfalse-literal) (clause-watched-swap! clause p nonfalse-literal)
'skip) (if nonfalse-literal
(if (literal-unassigned? nonfalse-literal)
nonfalse-literal 'skip) 'contradiction))] [else
(let* ((literal (nth-literal clause idx))
(litval (literal-valuation literal)))
;; Skip P itself and every literal whose valuation is #f (false).
(if (or (literal-eq? literal p) (false? litval)) (bcp-4cases (+ 1 idx) clause nonfalse-literal multiple p)
;; The other watched literal, met after a non-false literal was already
;; recorded: keep the recorded one and note there are several.
(if (and (literal-eq? (clause-other-watched clause p)
literal) nonfalse-literal) (bcp-4cases (+ 1 idx) clause nonfalse-literal #t p)
;; Otherwise record LITERAL; the previously recorded literal (or #f)
;; becomes MULTIPLE, so MULTIPLE is truthy iff one was already seen.
(bcp-4cases (+ 1 idx) clause literal nonfalse-literal p))))]))
;; Membership test with a custom equivalence.  Returns the first truthy
;; result of (proc a b) for b in the list B, checked in order, or #f if
;; there is none (including when B is empty).
(define (memberf a B [proc equal?])
  (for/or ([b (in-list B)])
    (proc a b)))
;; Union of lists A and B under the equivalence PROC (default equal?).
;; The result keeps, in order, the elements of A that match no element of
;; B, followed by B itself.  Duplicates inside A are not collapsed.
(define (list-union A B [proc equal?])
  (define (in-B? a)
    (for/or ([b (in-list B)])
      (proc a b)))
  (append (filter-not in-B? A) B))
;; bcp-4cases on clause 0, (-1 2), with (literal var0 #f) -- presumably the
;; literal -1 -- as the falsified watch.  The expected result is the clause's
;; second literal object itself, i.e. the clause is reported as unit.
;; check-equal? compares the two values directly, so a failure shows both.
(let* ((smt (initialize (list 5 5 '((-1 2)
                                    (-1 3)
                                    (-2 4)
                                    (-3 -4)
                                    (1 -3 5)))
                        #f
                        0))
       (clause0 (vector-ref (SMT-clauses smt) 0)))
  (check-equal? (bcp-4cases 0 clause0 #f #f
                            (literal (vector-ref (SMT-variables smt) 0) #f))
                (vector-ref (clause-literals clause0) 1)))
;; bcp-4cases on clause 4, (1 -3 5), with (literal var0 #t) as the falsified
;; watch is expected to return 'skip.
(let* ((smt (initialize (list 5 5 '((-1 2)
                                    (-1 3)
                                    (-2 4)
                                    (-3 -4)
                                    (1 -3 5)))
                        #f
                        0))
       (clause4 (vector-ref (SMT-clauses smt) 4))
       (falsified (literal (vector-ref (SMT-variables smt) 0) #t)))
  (check-equal? (bcp-4cases 0 clause4 #f #f falsified) 'skip))