changeset 1594:e58bcfa66093

cleanups and a few fixed cases
author drewp@bigasterisk.com
date Sun, 05 Sep 2021 01:15:55 -0700
parents b0df43d5494c
children 413a280828bf
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py
diffstat 2 files changed, 165 insertions(+), 148 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sat Sep 04 23:23:55 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 05 01:15:55 2021 -0700
@@ -2,17 +2,15 @@
 copied from reasoning 2021-08-29. probably same api. should
 be able to lib/ this out
 """
-from collections import defaultdict
 import itertools
 import logging
-from dataclasses import dataclass
+from collections import defaultdict
+from dataclasses import dataclass, field
 from decimal import Decimal
 from typing import Dict, Iterator, List, Set, Tuple, Union, cast
-from urllib.request import OpenerDirector
 
 from prometheus_client import Summary
-from rdflib import BNode, Graph, Literal, Namespace, URIRef, RDF
-from rdflib.collection import Collection
+from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Node, Variable
 
@@ -22,6 +20,7 @@
 Triple = Tuple[Node, Node, Node]
 Rule = Tuple[Graph, Node, Graph]
 BindableTerm = Union[Variable, BNode]
+ReadOnlyWorkingSet = ReadOnlyGraphAggregate
 
 READ_RULES_CALLS = Summary('read_rules_calls', 'calls')
 
@@ -30,23 +29,13 @@
 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
 
 
-@dataclass
-class _RuleMatch:
-    """one way that a rule can match the working set"""
-    vars: Dict[Variable, Node]
+class EvaluationFailed(ValueError):
+    """e.g. we were given (5 math:greaterThan 6)"""
 
 
-ReadOnlyWorkingSet = ReadOnlyGraphAggregate
-
-filterFuncs = {
-    MATH['greaterThan'],
-}
-
-
+@dataclass
 class CandidateBinding:
-
-    def __init__(self, binding: Dict[BindableTerm, Node]):
-        self.binding = binding  # mutable!
+    binding: Dict[BindableTerm, Node]
 
     def __repr__(self):
         b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
@@ -54,39 +43,48 @@
 
     def apply(self, g: Graph) -> Iterator[Triple]:
         for stmt in g:
-            stmt = list(stmt)
-            for i, term in enumerate(stmt):
-                if isinstance(term, (Variable, BNode)):
-                    if term in self.binding:
-                        stmt[i] = self.binding[term]
-            else:
-                yield cast(Triple, stmt)
+            yield (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2]))
 
-    def applyFunctions(self, lhs):
+    def _applyTerm(self, term: Node):
+        if isinstance(term, (Variable, BNode)):
+            if term in self.binding:
+                return self.binding[term]
+        return term
+
+    def applyFunctions(self, lhs) -> Graph:
         """may grow the binding with some results"""
         usedByFuncs = Graph()
         while True:
-            before = len(self.binding)
-            delta = 0
-            for ev in Evaluation.findEvals(lhs):
-                log.debug(f'{INDENT*3} found Evaluation')
-
-                newBindings, usedGraph = ev.resultBindings(self.binding)
-                usedByFuncs += usedGraph
-                for k, v in newBindings.items():
-                    if k in self.binding and self.binding[k] != v:
-                        raise ValueError(
-                            f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
-                    self.binding[k] = v
-                delta = len(self.binding) - before
-                log.debug(f'{INDENT*4} rule {graphDump(usedGraph)} made {delta} new bindings')
+            delta = self._applyFunctionsIteration(lhs, usedByFuncs)
             if delta == 0:
                 break
         return usedByFuncs
 
+    def _applyFunctionsIteration(self, lhs, usedByFuncs: Graph):
+        before = len(self.binding)
+        delta = 0
+        for ev in Evaluation.findEvals(lhs):
+            log.debug(f'{INDENT*3} found Evaluation')
+
+            newBindings, usedGraph = ev.resultBindings(self.binding)
+            usedByFuncs += usedGraph
+            self._addNewBindings(newBindings)
+            delta = len(self.binding) - before
+            dump = "(...)"
+            if log.isEnabledFor(logging.DEBUG) and cast(int, usedGraph.__len__()) < 20:
+                dump = graphDump(usedGraph)
+            log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings')
+        return delta
+
+    def _addNewBindings(self, newBindings):
+        for k, v in newBindings.items():
+            if k in self.binding and self.binding[k] != v:
+                raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
+            self.binding[k] = v
+
     def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool:
         """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
-        boundLhs = list(self.apply(lhs._g))
+        boundLhs = list(self.apply(lhs.graph))
         boundUsedByFuncs = list(self.apply(usedByFuncs))
 
         self.logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
@@ -94,11 +92,7 @@
         for stmt in boundLhs:
             log.debug(f'{INDENT*4} check for {stmt}')
 
-            if stmt[1] in filterFuncs:
-                if not mathTest(*stmt):
-                    log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not true')
-                    return False
-            elif stmt in boundUsedByFuncs:
+            if stmt in boundUsedByFuncs:
                 pass
             elif stmt in workingSet:
                 pass
@@ -125,26 +119,32 @@
         log.debug(f'{INDENT*4}\\')
 
 
+@dataclass
 class Lhs:
+    graph: Graph
+
+    staticRuleStmts: Graph = field(default_factory=Graph)
+    lhsBindables: Set[BindableTerm] = field(default_factory=set)
+    lhsBnodes: Set[BNode] = field(default_factory=set)
 
-    def __init__(self, existingGraph):
-        self._g = existingGraph
+    def __post_init__(self):
+        for ruleStmt in self.graph:
+            varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))]
+            self.lhsBindables.update(varsAndBnodesInStmt)
+            self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode))
+            if not varsAndBnodesInStmt:
+                self.staticRuleStmts.add(ruleStmt)
 
     def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
-        nodesToBind = self.nodesToBind()
-        log.debug(f'{INDENT*2} nodesToBind: {nodesToBind}')
+        log.debug(f'{INDENT*2} nodesToBind: {self.lhsBindables}')
 
         if not self.allStaticStatementsMatch(workingSet):
             return
 
         candidateTermMatches: Dict[BindableTerm, Set[Node]] = self.allCandidateTermMatches(workingSet)
 
-        # for n in nodesToBind:
-        #     if n not in candidateTermMatches:
-        #         candidateTermMatches[n] = set()
-
         orderedVars, orderedValueSets = organize(candidateTermMatches)
 
         self.logCandidates(orderedVars, orderedValueSets)
@@ -156,35 +156,18 @@
             log.debug('')
             log.debug(f'{INDENT*3}*trying {binding}')
 
-            usedByFuncs = binding.applyFunctions(self)
+            try:
+                usedByFuncs = binding.applyFunctions(self)
+            except EvaluationFailed:
+                continue
 
             if not binding.verify(self, workingSet, usedByFuncs):
                 log.debug(f'{INDENT*3} this binding did not verify')
                 continue
             yield binding
 
-    def nodesToBind(self) -> List[BindableTerm]:
-        nodes: Set[BindableTerm] = set()
-        staticRuleStmts = Graph()
-        for ruleStmt in self._g:
-            varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
-            nodes.update(varsInStmt)
-            if (not varsInStmt  # ok
-                    #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
-               ):
-                staticRuleStmts.add(ruleStmt)
-        return sorted(nodes)
-
     def allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
-        staticRuleStmts = Graph()
-        for ruleStmt in self._g:
-            varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
-            if (not varsInStmt  # ok
-                    #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
-               ):
-                staticRuleStmts.add(ruleStmt)
-
-        for ruleStmt in staticRuleStmts:
+        for ruleStmt in self.staticRuleStmts:
             if ruleStmt not in workingSet:
                 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule')
                 return False
@@ -194,35 +177,43 @@
         """the total set of terms each variable could possibly match"""
 
         candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
-        lhsBnodes: Set[BNode] = set()
-        for lhsStmt in self._g:
+        for lhsStmt in self.graph:
             log.debug(f'{INDENT*3} possibles for this lhs stmt: {lhsStmt}')
             for i, trueStmt in enumerate(sorted(workingSet)):
                 log.debug(f'{INDENT*4} consider this true stmt ({i}): {trueStmt}')
-                bindingsFromStatement: Dict[Variable, Set[Node]] = {}
-                for lhsTerm, trueTerm in zip(lhsStmt, trueStmt):
-                    if isinstance(lhsTerm, BNode):
-                        lhsBnodes.add(lhsTerm)
-                    elif isinstance(lhsTerm, Variable):
-                        bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm)
-                    elif lhsTerm != trueTerm:
-                        break
-                else:
-                    for v, vals in bindingsFromStatement.items():
-                        candidateTermMatches[v].update(vals)
 
-        for trueStmt in itertools.chain(workingSet, self._g):
-            for b in lhsBnodes:
+                for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt):
+                    candidateTermMatches[v].update(vals)
+
+        for trueStmt in itertools.chain(workingSet, self.graph):
+            for b in self.lhsBnodes:
                 for t in [trueStmt[0], trueStmt[2]]:
                     if isinstance(t, (URIRef, BNode)):
                         candidateTermMatches[b].add(t)
         return candidateTermMatches
 
+    def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]:
+        """if these stmts match otherwise, what BNode or Variable mappings do we learn?
+        
+        e.g. stmt1=(?x B ?y) and stmt2=(A B C), then we yield (?x, {A}) and (?y, {C})
+        or   stmt1=(_:x B C) and stmt2=(A B C), then we yield (_:x, {A})
+        or   stmt1=(?x B C)  and stmt2=(A B D), then we yield nothing
+        """
+        bindingsFromStatement = {}
+        for term1, term2 in zip(stmt1, stmt2):
+            if isinstance(term1, (BNode, Variable)):
+                bindingsFromStatement.setdefault(term1, set()).add(term2)
+            elif term1 != term2:
+                break
+        else:
+            for v, vals in bindingsFromStatement.items():
+                yield v, vals
+
     def graphWithoutEvals(self, binding: CandidateBinding) -> Graph:
         g = Graph()
         usedByFuncs = binding.applyFunctions(self)
 
-        for stmt in self._g:
+        for stmt in self.graph:
             if stmt not in usedByFuncs:
                 g.add(stmt)
         return g
@@ -241,18 +232,25 @@
     """some lhs statements need to be evaluated with a special function 
     (e.g. math) and then not considered for the rest of the rule-firing 
     process. It's like they already 'matched' something, so they don't need
-    to match a statement from the known-true working set."""
+    to match a statement from the known-true working set.
+    
+    One Evaluation instance is for one function call.
+    """
 
     @staticmethod
     def findEvals(lhs: Lhs) -> Iterator['Evaluation']:
-        for stmt in lhs._g.triples((None, MATH['sum'], None)):
-            # shouldn't be redoing this here
-            operands, operandsStmts = parseList(lhs._g, stmt[0])
+        for stmt in lhs.graph.triples((None, MATH['sum'], None)):
+            operands, operandsStmts = parseList(lhs.graph, stmt[0])
             g = Graph()
             g += operandsStmts
             yield Evaluation(operands, g, stmt)
 
-        for stmt in lhs._g.triples((None, ROOM['asFarenheit'], None)):
+        for stmt in lhs.graph.triples((None, MATH['greaterThan'], None)):
+            g = Graph()
+            g.add(stmt)
+            yield Evaluation([stmt[0], stmt[2]], g, stmt)
+
+        for stmt in lhs.graph.triples((None, ROOM['asFarenheit'], None)):
             g = Graph()
             g.add(stmt)
             yield Evaluation([stmt[0]], g, stmt)
@@ -260,13 +258,13 @@
     # internal, use findEvals
     def __init__(self, operands: List[Node], operandsStmts: Graph, stmt: Triple) -> None:
         self.operands = operands
-        self.operandsStmts = operandsStmts
+        self.operandsStmts = operandsStmts  # may grow
         self.stmt = stmt
 
     def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]:
         """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?"""
         pred = self.stmt[1]
-        objVar = self.stmt[2]
+        objVar: Node = self.stmt[2]
         boundOperands = []
         for o in self.operands:
             if isinstance(o, Variable):
@@ -277,43 +275,34 @@
 
             boundOperands.append(o)
 
-        if not isinstance(objVar, Variable):
-            raise TypeError(f'expected Variable, got {objVar!r}')
-
         if pred == MATH['sum']:
-            log.debug(f'{INDENT*4} sum {list(map(self.numericNode, boundOperands))}')
-            obj = cast(Literal, Literal(sum(map(self.numericNode, boundOperands))))
+            obj = Literal(sum(map(numericNode, boundOperands)))
             self.operandsStmts.add(self.stmt)
+            if not isinstance(objVar, Variable):
+                raise TypeError(f'expected Variable, got {objVar!r}')
             return {objVar: obj}, self.operandsStmts
         elif pred == ROOM['asFarenheit']:
             if len(boundOperands) != 1:
                 raise ValueError(":asFarenheit takes 1 subject operand")
-            f = Literal(Decimal(self.numericNode(boundOperands[0])) * 9 / 5 + 32)
-            g = Graph()
-            g.add(self.stmt)
-
-            log.debug('made 1 st graph')
-            return {objVar: f}, g
+            f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32)
+            if not isinstance(objVar, Variable):
+                raise TypeError(f'expected Variable, got {objVar!r}')
+            return {objVar: f}, self.operandsStmts
+        elif pred == MATH['greaterThan']:
+            if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])):
+                raise EvaluationFailed()
+            return {}, self.operandsStmts
         else:
-            raise NotImplementedError()
-
-    def numericNode(self, n: Node):
-        if not isinstance(n, Literal):
-            raise TypeError(f'expected Literal, got {n=}')
-        val = n.toPython()
-        if not isinstance(val, (int, float, Decimal)):
-            raise TypeError(f'expected number, got {val=}')
-        return val
+            raise NotImplementedError(repr(pred))
 
 
-# merge into evaluation, raising a Invalid for impossible stmts
-def mathTest(subj, pred, obj):
-    x = subj.toPython()
-    y = obj.toPython()
-    if pred == MATH['greaterThan']:
-        return x > y
-    else:
-        raise NotImplementedError(pred)
+def numericNode(n: Node):
+    if not isinstance(n, Literal):
+        raise TypeError(f'expected Literal, got {n=}')
+    val = n.toPython()
+    if not isinstance(val, (int, float, Decimal)):
+        raise TypeError(f'expected number, got {val=}')
+    return val
 
 
 class Inference:
@@ -334,7 +323,7 @@
         workingSet = Graph()
         workingSet += graph
 
-        # just the statements that came from rule RHS's.
+        # just the statements that came from RHS's of rules that fired.
         implied = ConjunctiveGraph()
 
         bailout_iterations = 100
@@ -353,24 +342,29 @@
 
     def _iterateAllRules(self, workingSet: Graph, implied: Graph):
         for i, r in enumerate(self.rules):
-            log.debug('')
-            log.debug(f'{INDENT*2} workingSet:')
-            for i, stmt in enumerate(sorted(workingSet)):
-                log.debug(f'{INDENT*3} ({i}) {stmt}')
-
-            log.debug('')
-            log.debug(f'{INDENT*2}-applying rule {i}')
-            log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}')
-            log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}')
+            self.logRuleApplicationHeader(workingSet, i, r)
             if r[1] == LOG['implies']:
                 applyRule(Lhs(r[0]), r[2], workingSet, implied)
             else:
                 log.info(f'{INDENT*2} {r} not a rule?')
 
+    def logRuleApplicationHeader(self, workingSet, i, r):
+        if not log.isEnabledFor(logging.DEBUG):
+            return
+
+        log.debug('')
+        log.debug(f'{INDENT*2} workingSet:')
+        for i, stmt in enumerate(sorted(workingSet)):
+            log.debug(f'{INDENT*3} ({i}) {stmt}')
+
+        log.debug('')
+        log.debug(f'{INDENT*2}-applying rule {i}')
+        log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}')
+        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}')
+
 
 def applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph):
     for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])):
-        # log.debug(f'        rule gave {binding=}')
         for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)):
             workingSet.add(lhsBoundStmt)
         for newStmt in binding.apply(rhs):
@@ -384,8 +378,7 @@
     out = []
     used = set()
     cur = subj
-    while True:
-        # bug: mishandles empty list
+    while cur != RDF.nil:
         out.append(graph.value(cur, RDF.first))
         used.add((cur, RDF.first, out[-1]))
 
@@ -393,16 +386,13 @@
         used.add((cur, RDF.rest, next))
 
         cur = next
-        if cur == RDF.nil:
-            break
     return out, used
 
 
 def graphDump(g: Union[Graph, List[Triple]]):
     if not isinstance(g, Graph):
         g2 = Graph()
-        for stmt in g:
-            g2.add(stmt)
+        g2 += g
         g = g2
     g.bind('', ROOM)
     g.bind('ex', Namespace('http://example.com/'))
--- a/service/mqtt_to_rdf/inference_test.py	Sat Sep 04 23:23:55 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Sun Sep 05 01:15:55 2021 -0700
@@ -1,12 +1,13 @@
 """
 also see https://github.com/w3c/N3/tree/master/tests/N3Tests
 """
+import itertools
 import unittest
-import itertools
-from rdflib import ConjunctiveGraph, Namespace, Graph, BNode
+
+from rdflib import RDF, BNode, ConjunctiveGraph, Graph, Literal, Namespace
 from rdflib.parser import StringInputSource
 
-from inference import Inference
+from inference import Inference, parseList
 
 
 def patchSlimReprs():
@@ -60,6 +61,7 @@
 
 patchBnodeCounter()
 
+EX = Namespace('http://example.com/')
 ROOM = Namespace('http://projects.bigasterisk.com/room/')
 
 
@@ -213,9 +215,34 @@
         inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .")
         self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 ."))
 
+    def test0Operands(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . () math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 0 ."))
+
 
 class TestInferenceWithCustomFunctions(WithGraphEqual):
 
     def testAsFarenheit(self):
         inf = makeInferenceWithRules("{ :a :b ?x . ?x room:asFarenheit ?f } => { :new :stmt ?f } .")
         self.assertGraphEqual(inf.infer(N3(":a :b 12 .")), N3(":new :stmt 53.6 ."))
+
+
+class TestParseList(unittest.TestCase):
+
+    def test0Elements(self):
+        g = N3(":a :b () .")
+        bn = g.value(EX['a'], EX['b'])
+        elems, used = parseList(g, bn)
+        self.assertEqual(elems, [])
+        self.assertFalse(used)
+
+    def test1Element(self):
+        g = N3(":a :b (0) .")
+        bn = g.value(EX['a'], EX['b'])
+        elems, used = parseList(g, bn)
+        self.assertEqual(elems, [Literal(0)])
+        used = sorted(used)
+        self.assertEqual(used, [
+            (bn, RDF.first, Literal(0)),
+            (bn, RDF.rest, RDF.nil),
+        ])