comparison service/mqtt_to_rdf/lhs_evaluation.py @ 1658:7ec2483d61b5

refactor inference_functions
author drewp@bigasterisk.com
date Sun, 19 Sep 2021 13:33:10 -0700
parents 20474ad4968e
children 31f7dab6a60b
comparison
equal deleted inserted replaced
1657:274bb6c04627 1658:7ec2483d61b5
17 ROOM = Namespace("http://projects.bigasterisk.com/room/") 17 ROOM = Namespace("http://projects.bigasterisk.com/room/")
18 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') 18 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
19 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') 19 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
20 20
21 21
22 def numericNode(n: Node): 22 def _numericNode(n: Node):
23 if not isinstance(n, Literal): 23 if not isinstance(n, Literal):
24 raise TypeError(f'expected Literal, got {n=}') 24 raise TypeError(f'expected Literal, got {n=}')
25 val = n.toPython() 25 val = n.toPython()
26 if not isinstance(val, (int, float, Decimal)): 26 if not isinstance(val, (int, float, Decimal)):
27 raise TypeError(f'expected number, got {val=}') 27 raise TypeError(f'expected number, got {val=}')
28 return val 28 return val
29 29
30 30
31 def parseList(graph: ChunkedGraph, subj: Node) -> Tuple[List[Node], Set[Triple]]: 31 def _parseList(graph: ChunkedGraph, subj: Node) -> Tuple[List[Node], Set[Triple]]:
32 """"Do like Collection(g, subj) but also return all the 32 """"Do like Collection(g, subj) but also return all the
33 triples that are involved in the list""" 33 triples that are involved in the list"""
34 out = [] 34 out = []
35 used = set() 35 used = set()
36 cur = subj 36 cur = subj
48 48
49 cur = next 49 cur = next
50 return out, used 50 return out, used
51 51
52 52
53 registeredFunctionTypes: List[Type['Function']] = [] 53 _registeredFunctionTypes: List[Type['Function']] = []
54 54
55 55
56 def register(cls: Type['Function']): 56 def register(cls: Type['Function']):
57 registeredFunctionTypes.append(cls) 57 _registeredFunctionTypes.append(cls)
58 return cls 58 return cls
59 59
60 60
61 class Function: 61 class Function:
62 """any rule stmt that runs a function (not just a statement match)""" 62 """any rule stmt that runs a function (not just a statement match)"""
72 raise NotImplementedError 72 raise NotImplementedError
73 73
74 def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: 74 def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]:
75 out = [] 75 out = []
76 for op in self.getOperandNodes(existingBinding): 76 for op in self.getOperandNodes(existingBinding):
77 out.append(numericNode(op)) 77 out.append(_numericNode(op))
78 78
79 return out 79 return out
80 80
81 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: 81 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
82 """either any new bindings this function makes (could be 0), or None if it doesn't match""" 82 """either any new bindings this function makes (could be 0), or None if it doesn't match"""
110 110
111 class ListFunction(Function): 111 class ListFunction(Function):
112 """function that takes an rdf list as input""" 112 """function that takes an rdf list as input"""
113 113
114 def usedStatements(self) -> Set[Triple]: 114 def usedStatements(self) -> Set[Triple]:
115 _, used = parseList(self.ruleGraph, self.chunk.primary[0]) 115 _, used = _parseList(self.ruleGraph, self.chunk.primary[0])
116 return used 116 return used
117 117
118 def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: 118 def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
119 operands, _ = parseList(self.ruleGraph, self.chunk.primary[0]) 119 operands, _ = _parseList(self.ruleGraph, self.chunk.primary[0])
120 return [existingBinding.applyTerm(x) for x in operands] 120 return [existingBinding.applyTerm(x) for x in operands]
121 121
122 import inference_functions # calls register() on some classes
122 123
123 @register 124 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in _registeredFunctionTypes)
124 class Gt(SubjectObjectFunction):
125 pred = MATH['greaterThan']
126
127 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
128 [x, y] = self.getNumericOperands(existingBinding)
129 if x > y:
130 return CandidateBinding({}) # no new values; just allow matching to keep going
131
132
133 @register
134 class AsFarenheit(SubjectFunction):
135 pred = ROOM['asFarenheit']
136
137 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
138 [x] = self.getNumericOperands(existingBinding)
139 f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32))
140 return self.valueInObjectTerm(f)
141
142
143 @register
144 class Sum(ListFunction):
145 pred = MATH['sum']
146
147 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
148 f = Literal(sum(self.getNumericOperands(existingBinding)))
149 return self.valueInObjectTerm(f)
150
151
152 ### registration is done
153
154 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
155 125
156 126
157 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]: 127 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
158 try: 128 try:
159 yield _byPred[pred] 129 yield _byPred[pred]
168 # usedByFuncs.update(cls(s, graph).usedStatements()) 138 # usedByFuncs.update(cls(s, graph).usedStatements())
169 # return usedByFuncs 139 # return usedByFuncs
170 140
171 141
172 def rulePredicates() -> Set[URIRef]: 142 def rulePredicates() -> Set[URIRef]:
173 return set(c.pred for c in registeredFunctionTypes) 143 return set(c.pred for c in _registeredFunctionTypes)