diff service/mqtt_to_rdf/lhs_evaluation.py @ 1651:20474ad4968e

WIP - functions are broken as i move most layers to work in Chunks not Triples A Chunk is a Triple plus any rdf lists.
author drewp@bigasterisk.com
date Sat, 18 Sep 2021 23:57:20 -0700
parents 3059f31b2dfa
children 7ec2483d61b5
line wrap: on
line diff
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Sat Sep 18 23:53:59 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Sat Sep 18 23:57:20 2021 -0700
@@ -1,15 +1,14 @@
-from dataclasses import dataclass
 import logging
 from decimal import Decimal
-from candidate_binding import CandidateBinding
-from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
+from typing import (Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast)
 
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
-from rdflib.graph import Graph
-from rdflib.term import BNode, Node, Variable
+from rdflib.term import Node, Variable
 
+from candidate_binding import CandidateBinding
 from inference_types import BindableTerm, Triple
+from stmt_chunk import Chunk, ChunkedGraph
 
 log = logging.getLogger('infer')
 
@@ -29,7 +28,7 @@
     return val
 
 
-def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
+def parseList(graph: ChunkedGraph, subj: Node) -> Tuple[List[Node], Set[Triple]]:
     """"Do like Collection(g, subj) but also return all the 
     triples that are involved in the list"""
     out = []
@@ -63,9 +62,9 @@
     """any rule stmt that runs a function (not just a statement match)"""
     pred: URIRef
 
-    def __init__(self, stmt: Triple, ruleGraph: Graph):
-        self.stmt = stmt
-        if stmt[1] != self.pred:
+    def __init__(self, chunk: Chunk, ruleGraph: ChunkedGraph):
+        self.chunk = chunk
+        if chunk.predicate != self.pred:
             raise TypeError
         self.ruleGraph = ruleGraph
 
@@ -84,7 +83,7 @@
         raise NotImplementedError
 
     def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]:
-        objVar = self.stmt[2]
+        objVar = self.chunk.primary[2]
         if not isinstance(objVar, Variable):
             raise TypeError(f'expected Variable, got {objVar!r}')
         return CandidateBinding({cast(BindableTerm, objVar): value})
@@ -93,31 +92,31 @@
         '''stmts in self.graph (not including self.stmt, oddly) that are part of
         this function setup and aren't to be matched literally'''
         return set()
-    
+
 
 class SubjectFunction(Function):
     """function that depends only on the subject term"""
 
     def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
-        return [existingBinding.applyTerm(self.stmt[0])]
+        return [existingBinding.applyTerm(self.chunk.primary[0])]
 
 
 class SubjectObjectFunction(Function):
     """a filter function that depends on the subject and object terms"""
 
     def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
-        return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])]
+        return [existingBinding.applyTerm(self.chunk.primary[0]), existingBinding.applyTerm(self.chunk.primary[2])]
 
 
 class ListFunction(Function):
     """function that takes an rdf list as input"""
 
     def usedStatements(self) -> Set[Triple]:
-        _, used = parseList(self.ruleGraph, self.stmt[0])
+        _, used = parseList(self.ruleGraph, self.chunk.primary[0])
         return used
 
     def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
-        operands, _ = parseList(self.ruleGraph, self.stmt[0])
+        operands, _ = parseList(self.ruleGraph, self.chunk.primary[0])
         return [existingBinding.applyTerm(x) for x in operands]
 
 
@@ -149,9 +148,12 @@
         f = Literal(sum(self.getNumericOperands(existingBinding)))
         return self.valueInObjectTerm(f)
 
+
 ### registration is done
 
 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
+
+
 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
     try:
         yield _byPred[pred]
@@ -159,13 +161,13 @@
         return
 
 
-def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
-    usedByFuncs: Set[Triple] = set()  # don't worry about matching these
-    for s in graph:
-        for cls in functionsFor(pred=s[1]):
-            usedByFuncs.update(cls(s, graph).usedStatements())
-    return usedByFuncs
+# def lhsStmtsUsedByFuncs(graph: ChunkedGraph) -> Set[Chunk]:
+#     usedByFuncs: Set[Triple] = set()  # don't worry about matching these
+#     for s in graph:
+#         for cls in functionsFor(pred=s[1]):
+#             usedByFuncs.update(cls(s, graph).usedStatements())
+#     return usedByFuncs
 
 
 def rulePredicates() -> Set[URIRef]:
-    return set(c.pred for c in registeredFunctionTypes)
\ No newline at end of file
+    return set(c.pred for c in registeredFunctionTypes)