Mercurial > code > home > repos > homeauto
changeset 1634:ba59cfc3c747
Hack math:sum support into the inference code. The test suite passes except for some slow-performing tests.
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 23:48:43 -0700 |
parents | 6107603ed455 |
children | 22d481f0a924 |
files | service/mqtt_to_rdf/candidate_binding.py service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py service/mqtt_to_rdf/lhs_evaluation.py |
diffstat | 4 files changed, 184 insertions(+), 79 deletions(-) [+] |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/candidate_binding.py Sun Sep 12 21:48:36 2021 -0700 +++ b/service/mqtt_to_rdf/candidate_binding.py Sun Sep 12 23:48:43 2021 -0700 @@ -7,7 +7,7 @@ from rdflib.term import Node, Variable from inference_types import BindableTerm, BindingUnknown, Triple -log = logging.getLogger() +log = logging.getLogger('cbind') INDENT = ' ' @dataclass
--- a/service/mqtt_to_rdf/inference.py Sun Sep 12 21:48:36 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Sun Sep 12 23:48:43 2021 -0700 @@ -7,16 +7,16 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) from prometheus_client import Histogram, Summary -from rdflib import BNode, Graph, Namespace +from rdflib import RDF, BNode, Graph, Namespace from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Literal, Node, Variable from candidate_binding import CandidateBinding -from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) -from lhs_evaluation import Decimal, Evaluation, numericNode +from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple) +from lhs_evaluation import Decimal, Evaluation, numericNode, parseList log = logging.getLogger('infer') INDENT = ' ' @@ -58,6 +58,7 @@ lhsStmt: Triple prev: Optional['StmtLooper'] workingSet: ReadOnlyWorkingSet + parent: 'Lhs' # just for lhs.graph, really def __repr__(self): return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' @@ -98,9 +99,29 @@ augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, returnBoundStatementsOnly=False)) - log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}') + log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') + + if self._advanceWithPlainMatches(augmentedWorkingSet): + return + + if self._advanceWithBoolRules(): + return + + curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) + [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) - log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented 
statements') + fullWorkingSet = self.workingSet + self.parent.graph + boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) + log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}') + + if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound): + return + + log.debug(f'{INDENT*6} {self} is past end') + self._pastEnd = True + + def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: + log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') for s in augmentedWorkingSet: log.debug(f'{INDENT*7} {s}') @@ -111,19 +132,38 @@ log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') continue - log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}') + log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') if outBinding.binding not in self._seenBindings: self._seenBindings.append(outBinding.binding.copy()) self._current = outBinding log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') - return - log.debug(f'yes we saw') + return True + return False - log.debug(f'{INDENT*6} {self} mines rules') + def _advanceWithBoolRules(self) -> bool: + log.debug(f'{INDENT*7} {self} mines bool rules') + if self.lhsStmt[1] == MATH['greaterThan']: + operands = [self.lhsStmt[0], self.lhsStmt[2]] + try: + boundOperands = self._boundOperands(operands) + except BindingUnknown: + return False + if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): + bindingDict: Dict[BindableTerm, + Node] = self._prevBindings().copy() # no new values; just allow matching to keep going + if bindingDict not in self._seenBindings: + self._seenBindings.append(bindingDict) + self._current = CandidateBinding(bindingDict) + log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}') + return True + return False + + def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, 
lhsStmtBound) -> bool: + log.debug(f'{INDENT*7} {self} mines rules') if self.lhsStmt[1] == ROOM['asFarenheit']: pb: Dict[BindableTerm, Node] = self._prevBindings() - log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') + log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') if self.lhsStmt[0] in pb: operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] @@ -136,17 +176,59 @@ if newBindings not in self._seenBindings: self._seenBindings.append(newBindings) self._current = CandidateBinding(newBindings) - return + return True + elif self.lhsStmt[1] == MATH['sum']: + + g = Graph() + for s in boundFullWorkingSet: + g.add(s) + log.debug(f' boundWorkingSet graph: {s}') + log.debug(f'_parseList subj = {lhsStmtBound[0]}') + operands, _ = parseList(g, lhsStmtBound[0]) + log.debug(f'********* {INDENT*7} {self} found list {operands=}') + try: + obj = Literal(sum(map(numericNode, operands))) + except TypeError: + log.debug('typeerr in operands') + pass + else: + objVar = lhsStmtBound[2] + log.debug(f'{objVar=}') - log.debug(f'{INDENT*6} {self} is past end') - self._pastEnd = True + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + newBindings: Dict[BindableTerm, Node] = {objVar: obj} + log.debug(f'{newBindings=}') + + self._current.addNewBindings(CandidateBinding(newBindings)) + log.debug(f'{self._seenBindings=}') + if newBindings not in self._seenBindings: + self._seenBindings.append(newBindings) + self._current = CandidateBinding(newBindings) + return True + + return False + + def _boundOperands(self, operands) -> List[Node]: + pb: Dict[BindableTerm, Node] = self._prevBindings() + + boundOperands: List[Node] = [] + for op in operands: + if isinstance(op, (Variable, BNode)): + if op in pb: + boundOperands.append(pb[op]) + else: + raise BindingUnknown() + else: + boundOperands.append(op) + return boundOperands def _totalBindingIfThisStmtWereTrue(self, newStmt: 
Triple) -> CandidateBinding: outBinding = self._prevBindings().copy() for rt, ct in zip(self.lhsStmt, newStmt): if isinstance(rt, (Variable, BNode)): if rt in outBinding and outBinding[rt] != ct: - raise Inconsistent() + raise Inconsistent(f'{rt=} {ct=} {outBinding=}') outBinding[rt] = ct return CandidateBinding(outBinding) @@ -245,7 +327,21 @@ """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all start out valid (or else raise NoOptions)""" - stmtsToAdd = list(self.graph) + usedByFuncs: Set[Triple] = set() # don't worry about matching these + stmtsToResolve = list(self.graph) + for i, s in enumerate(stmtsToResolve): + if s[1] == MATH['sum']: + _, used = parseList(self.graph, s[0]) + usedByFuncs.update(used) + + stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] + + # sort them by variable dependencies; don't just try all perms! + def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. + (s, p, o) = stmt + return p == MATH['sum'], p, s, o + + stmtsToAdd.sort(key=lightSortKey) for perm in itertools.permutations(stmtsToAdd): stmtStack: List[StmtLooper] = [] @@ -254,7 +350,7 @@ for s in perm: try: - elem = StmtLooper(s, prev, knownTrue) + elem = StmtLooper(s, prev, knownTrue, parent=self) except NoOptions: log.debug(f'{INDENT*6} permutation didnt work, try another') break @@ -540,7 +636,7 @@ log.debug('') log.debug(f'{INDENT*2}-applying rule {i}') log.debug(f'{INDENT*3} rule def lhs:') - for stmt in r.lhsGraph: + for stmt in sorted(r.lhsGraph, reverse=True): log.debug(f'{INDENT*4} {stmt}') log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
--- a/service/mqtt_to_rdf/inference_test.py Sun Sep 12 21:48:36 2021 -0700 +++ b/service/mqtt_to_rdf/inference_test.py Sun Sep 12 23:48:43 2021 -0700 @@ -165,9 +165,9 @@ self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :x .")) self.assertGraphEqual(inf.infer(N3(":any :any :any .")), N3(":new :stmt :x .")) -# def test2(self): -# inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .") -# self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 .")) + def test2(self): + inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .") + self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 .")) # @unittest.skip("too hard for now") # def test3(self): @@ -175,29 +175,33 @@ # self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :c .")) -# class TestInferenceWithMathFunctions(WithGraphEqual): +class TestInferenceWithMathFunctions(WithGraphEqual): -# def testBoolFilter(self): -# inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .") -# self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3("")) -# self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) -# self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) + def testBoolFilter(self): + inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .") + self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3("")) + self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) + self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) -# def testNonFiringMathRule(self): -# inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") -# self.assertGraphEqual(inf.infer(N3("")), N3("")) + def testNonFiringMathRule(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3("")), N3("")) -# def testStatementGeneratingRule(self): -# inf = makeInferenceWithRules("{ :a :b ?x . 
(?x 1) math:sum ?y } => { :new :stmt ?y } .") -# self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) + def testStatementGeneratingRule(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 3 .")) + + def test2Operands(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) -# def test3Operands(self): -# inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .") -# self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 .")) + def test3Operands(self): + inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 .")) -# def test0Operands(self): -# inf = makeInferenceWithRules("{ :a :b ?x . () math:sum ?y } => { :new :stmt ?y } .") -# self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 0 .")) + def test0Operands(self): + inf = makeInferenceWithRules("{ :a :b ?x . () math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 0 .")) class TestInferenceWithCustomFunctions(WithGraphEqual): @@ -241,11 +245,35 @@ # out = inf.infer(N3('[] a :MqttMessage ; :body "online" ; :topic ( "frontdoorlock" "status" ) .')) # self.assertIn((EX['frontDoorLockStatus'], EX['connectedStatus'], EX['Online']), out) - # def testPerformance0(self): + def testPerformance0(self): + inf = makeInferenceWithRules(''' + { + ?msg a :MqttMessage; + :topic :topic1; + :bodyFloat ?valueC . + ?valueC math:greaterThan -999 . + ?valueC room:asFarenheit ?valueF . + } => { + :airQualityIndoorTemperature :temperatureF ?valueF . + } . 
+ ''') + out = inf.infer( + N3(''' + <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ; + :body "23.9" ; + :bodyFloat 2.39e+01 ; + :topic :topic1 . + ''')) + + vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF'])) + valueF = cast(Decimal, vlit.toPython()) + self.assertAlmostEqual(float(valueF), 75.02) + + # def testPerformance1(self): # inf = makeInferenceWithRules(''' # { # ?msg a :MqttMessage; - # :topic :topic1; + # :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ); # :bodyFloat ?valueC . # ?valueC math:greaterThan -999 . # ?valueC room:asFarenheit ?valueF . @@ -258,36 +286,12 @@ # <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ; # :body "23.9" ; # :bodyFloat 2.39e+01 ; - # :topic :topic1 . - # ''')) - + # :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) . + # ''')) # vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF'])) # valueF = cast(Decimal, vlit.toPython()) # self.assertAlmostEqual(float(valueF), 75.02) -# def testPerformance1(self): -# inf = makeInferenceWithRules(''' -# { -# ?msg a :MqttMessage; -# :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ); -# :bodyFloat ?valueC . -# ?valueC math:greaterThan -999 . -# ?valueC room:asFarenheit ?valueF . -# } => { -# :airQualityIndoorTemperature :temperatureF ?valueF . -# } . -# ''') -# out = inf.infer( -# N3(''' -# <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ; -# :body "23.9" ; -# :bodyFloat 2.39e+01 ; -# :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) . -# ''')) -# vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF'])) -# valueF = cast(Decimal, vlit.toPython()) -# self.assertAlmostEqual(float(valueF), 75.02) - def testEmitBnodes(self): inf = makeInferenceWithRules(''' { ?s a :AirQualitySensor; :label ?name . 
} => { @@ -322,15 +326,15 @@ implied = inf.infer(N3(":a :b (:e0 :e1) .")) self.assertGraphEqual(implied, N3(":new :stmt :here .")) - # def testList3(self): - # inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2) . } => { :new :stmt :here } .") - # implied = inf.infer(N3(":a :b (:e0 :e1 :e2) .")) - # self.assertGraphEqual(implied, N3(":new :stmt :here .")) + def testList3(self): + inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2) . } => { :new :stmt :here } .") + implied = inf.infer(N3(":a :b (:e0 :e1 :e2) .")) + self.assertGraphEqual(implied, N3(":new :stmt :here .")) - # def testList4(self): - # inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2 :e3) . } => { :new :stmt :here } .") - # implied = inf.infer(N3(":a :b (:e0 :e1 :e2 :e3) .")) - # self.assertGraphEqual(implied, N3(":new :stmt :here .")) + def testList4(self): + inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2 :e3) . } => { :new :stmt :here } .") + implied = inf.infer(N3(":a :b (:e0 :e1 :e2 :e3) .")) + self.assertGraphEqual(implied, N3(":new :stmt :here .")) # def fakeStats():
--- a/service/mqtt_to_rdf/lhs_evaluation.py Sun Sep 12 21:48:36 2021 -0700 +++ b/service/mqtt_to_rdf/lhs_evaluation.py Sun Sep 12 23:48:43 2021 -0700 @@ -36,7 +36,7 @@ @staticmethod def findEvals(graph: Graph) -> Iterator['Evaluation']: for stmt in graph.triples((None, MATH['sum'], None)): - operands, operandsStmts = _parseList(graph, stmt[0]) + operands, operandsStmts = parseList(graph, stmt[0]) yield Evaluation(operands, stmt, operandsStmts) for stmt in graph.triples((None, MATH['greaterThan'], None)): @@ -98,17 +98,22 @@ return val -def _parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]: +def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]: """"Do like Collection(g, subj) but also return all the triples that are involved in the list""" out = [] used = set() cur = subj while cur != RDF.nil: - out.append(graph.value(cur, RDF.first)) + elem = graph.value(cur, RDF.first) + if elem is None: + raise ValueError('bad list') + out.append(elem) used.add((cur, RDF.first, out[-1])) next = graph.value(cur, RDF.rest) + if next is None: + raise ValueError('bad list') used.add((cur, RDF.rest, next)) cur = next