Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1648:3059f31b2dfa
more performance work
author | drewp@bigasterisk.com |
---|---|
date | Fri, 17 Sep 2021 11:10:18 -0700 |
parents | 5403c6343fa4 |
children | bb5d2b5370ac |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Fri Sep 17 11:07:21 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Fri Sep 17 11:10:18 2021 -0700 @@ -61,7 +61,7 @@ parent: 'Lhs' # just for lhs.graph, really def __repr__(self): - return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' + return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})' def __post_init__(self): self._shortId = next(_stmtLooperShortId) @@ -70,6 +70,9 @@ self._current = CandidateBinding({}) self._pastEnd = False self._seenBindings: List[CandidateBinding] = [] + + log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})') + self.restart() def _myMatches(self, g: Graph) -> List[Triple]: @@ -119,7 +122,7 @@ try: outBinding = self._totalBindingIfThisStmtWereTrue(stmt) except Inconsistent: - log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') + log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings') continue log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') @@ -169,7 +172,8 @@ for rt, ct in zip(self.lhsStmt, newStmt): if isinstance(rt, (Variable, BNode)): if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct: - raise Inconsistent(f'{rt=} {ct=} {outBinding=}') + msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else '' + raise Inconsistent(msg) outBinding.addNewBindings(CandidateBinding({rt: ct})) return outBinding @@ -191,15 +195,37 @@ @dataclass class Lhs: - graph: Graph + graph: Graph # our full LHS graph, as input. See below for the statements partitioned into groups. def __post_init__(self): - pass + + usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) + + stmtsToMatch = list(self.graph - usedByFuncs) + self.staticStmts = [] + self.patternStmts = [] + for st in stmtsToMatch: + if all(isinstance(term, (URIRef, Literal)) for term in st): + self.staticStmts.append(st) + else: + self.patternStmts.append(st) + + # sort them by variable dependencies; don't just try all perms! + def lightSortKey(stmt): # Not this. + (s, p, o) = stmt + return p in rulePredicates(), p, s, o + + self.patternStmts.sort(key=lightSortKey) + + self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef)) + self.myPreds -= rulePredicates() + self.myPreds -= {RDF.first, RDF.rest} + self.myPreds = set(self.myPreds) def __repr__(self): return f"Lhs({graphDump(self.graph)})" - def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: + def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']: """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS""" if self.graph.__len__() == 0: @@ -211,10 +237,19 @@ stats['_checkPredicateCountsCulls'] += 1 return + if not all(st in knownTrue for st in self.staticStmts): + stats['staticStmtCulls'] += 1 + return + + if len(self.patternStmts) == 0: + # static only + yield BoundLhs(self, CandidateBinding({})) + return + log.debug(f'{INDENT*4} build new StmtLooper stack') try: - stmtStack = self._assembleRings(knownTrue) + stmtStack = self._assembleRings(knownTrue, stats) except NoOptions: log.debug(f'{INDENT*5} start up with no options; 0 bindings') return @@ -225,8 +260,8 @@ iterCount = 0 while True: iterCount += 1 - if iterCount > 10: - raise ValueError('stuck') + if iterCount > ruleStatementsIterationLimit: + raise ValueError('rule too complex') log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') @@ -249,32 +284,23 @@ def _checkPredicateCounts(self, knownTrue): """raise NoOptions quickly in some cases""" - myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef)) - myPreds -= rulePredicates() - myPreds -= {RDF.first, RDF.rest} - if any((None, p, None) not in knownTrue for p in set(myPreds)): + + if any((None, p, None) not in knownTrue for p in self.myPreds): return True + log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue') return False - def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: + def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]: """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all start out valid (or else raise NoOptions)""" - usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) - - stmtsToAdd = list(self.graph - usedByFuncs) - - # sort them by variable dependencies; don't just try all perms! - def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. - (s, p, o) = stmt - return p == MATH['sum'], p, s, o - - stmtsToAdd.sort(key=lightSortKey) - - for i, perm in enumerate(itertools.permutations(stmtsToAdd)): + log.info(f'{INDENT*2} stats={dict(stats)}') + log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}') + for i, perm in enumerate(itertools.permutations(self.patternStmts)): stmtStack: List[StmtLooper] = [] prev: Optional[StmtLooper] = None - log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') + if log.isEnabledFor(logging.DEBUG): + log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') for s in perm: try: @@ -286,8 +312,10 @@ prev = stmtStack[-1] else: return stmtStack + if i > 5000: + raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}') + log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') - raise NoOptions() def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: @@ -329,8 +357,8 @@ # self.rhsBnodeMap = {} - def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): - for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): + def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit): + for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit): log.debug(f'{INDENT*5} +rule has a working binding: {bound}') # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do @@ -362,13 +390,18 @@ implied.add(newStmt) +@dataclass class Inference: + rulesIterationLimit = 3 + ruleStatementsIterationLimit = 3 def __init__(self) -> None: - self.rules = [] + self.rules: List[Rule] = [] + self._nonRuleStmts: List[Triple] = [] def setRules(self, g: ConjunctiveGraph): - self.rules: List[Rule] = [] + self.rules = [] + self._nonRuleStmts = [] for stmt in g: if stmt[1] == LOG['implies']: self.rules.append(Rule(stmt[0], stmt[2])) @@ -391,28 +424,30 @@ # just the statements that came from RHS's of rules that fired. implied = ConjunctiveGraph() - bailout_iterations = 100 + rulesIterations = 0 delta = 1 stats['initWorkingSet'] = cast(int, workingSet.__len__()) - while delta > 0 and bailout_iterations > 0: + while delta > 0 and rulesIterations <= self.rulesIterationLimit: log.debug('') - log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') - bailout_iterations -= 1 + log.info(f'{INDENT*1}*iteration {rulesIterations}') + delta = -len(implied) self._iterateAllRules(workingSet, implied, stats) delta += len(implied) - stats['iterations'] += 1 + rulesIterations += 1 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') + stats['iterations'] = rulesIterations stats['timeSpent'] = round(time.time() - startTime, 3) stats['impliedStmts'] = len(implied) - log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:') - log.info(graphDump(implied)) + log.info(f'{INDENT*0} Inference done {dict(stats)}.') + log.debug('Implied:') + log.debug(graphDump(implied)) return implied def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats): for i, rule in enumerate(self.rules): self._logRuleApplicationHeader(workingSet, i, rule) - rule.applyRule(workingSet, implied, stats) + rule.applyRule(workingSet, implied, stats, self.ruleStatementsIterationLimit) def _logRuleApplicationHeader(self, workingSet, i, r: Rule): if not log.isEnabledFor(logging.DEBUG): @@ -420,8 +455,9 @@ log.debug('') log.debug(f'{INDENT*2} workingSet:') - for j, stmt in enumerate(sorted(workingSet)): - log.debug(f'{INDENT*3} ({j}) {stmt}') + # for j, stmt in enumerate(sorted(workingSet)): + # log.debug(f'{INDENT*3} ({j}) {stmt}') + log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}') log.debug('') log.debug(f'{INDENT*2}-applying rule {i}') @@ -431,7 +467,7 @@ log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') -def graphDump(g: Union[Graph, List[Triple]]): +def graphDump(g: Union[Graph, List[Triple]], oneLine=True): # this is very slow- debug only! if not log.isEnabledFor(logging.DEBUG): return "(skipped dump)" @@ -442,5 +478,7 @@ g.bind('', ROOM) g.bind('ex', Namespace('http://example.com/')) lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() - lines = [line.strip() for line in lines if not line.startswith('@prefix')] + lines = [line for line in lines if not line.startswith('@prefix')] + if oneLine: + lines = [line.strip() for line in lines] return ' '.join(lines)