Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/lhs_evaluation.py @ 1658:7ec2483d61b5
refactor inference_functions
author | drewp@bigasterisk.com |
---|---|
date | Sun, 19 Sep 2021 13:33:10 -0700 |
parents | 20474ad4968e |
children | 31f7dab6a60b |
comparison
equal
deleted
inserted
replaced
1657:274bb6c04627 | 1658:7ec2483d61b5 |
---|---|
17 ROOM = Namespace("http://projects.bigasterisk.com/room/") | 17 ROOM = Namespace("http://projects.bigasterisk.com/room/") |
18 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') | 18 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') |
19 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') | 19 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') |
20 | 20 |
21 | 21 |
22 def numericNode(n: Node): | 22 def _numericNode(n: Node): |
23 if not isinstance(n, Literal): | 23 if not isinstance(n, Literal): |
24 raise TypeError(f'expected Literal, got {n=}') | 24 raise TypeError(f'expected Literal, got {n=}') |
25 val = n.toPython() | 25 val = n.toPython() |
26 if not isinstance(val, (int, float, Decimal)): | 26 if not isinstance(val, (int, float, Decimal)): |
27 raise TypeError(f'expected number, got {val=}') | 27 raise TypeError(f'expected number, got {val=}') |
28 return val | 28 return val |
29 | 29 |
30 | 30 |
31 def parseList(graph: ChunkedGraph, subj: Node) -> Tuple[List[Node], Set[Triple]]: | 31 def _parseList(graph: ChunkedGraph, subj: Node) -> Tuple[List[Node], Set[Triple]]: |
32 """"Do like Collection(g, subj) but also return all the | 32 """"Do like Collection(g, subj) but also return all the |
33 triples that are involved in the list""" | 33 triples that are involved in the list""" |
34 out = [] | 34 out = [] |
35 used = set() | 35 used = set() |
36 cur = subj | 36 cur = subj |
48 | 48 |
49 cur = next | 49 cur = next |
50 return out, used | 50 return out, used |
51 | 51 |
52 | 52 |
53 registeredFunctionTypes: List[Type['Function']] = [] | 53 _registeredFunctionTypes: List[Type['Function']] = [] |
54 | 54 |
55 | 55 |
56 def register(cls: Type['Function']): | 56 def register(cls: Type['Function']): |
57 registeredFunctionTypes.append(cls) | 57 _registeredFunctionTypes.append(cls) |
58 return cls | 58 return cls |
59 | 59 |
60 | 60 |
61 class Function: | 61 class Function: |
62 """any rule stmt that runs a function (not just a statement match)""" | 62 """any rule stmt that runs a function (not just a statement match)""" |
72 raise NotImplementedError | 72 raise NotImplementedError |
73 | 73 |
74 def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: | 74 def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: |
75 out = [] | 75 out = [] |
76 for op in self.getOperandNodes(existingBinding): | 76 for op in self.getOperandNodes(existingBinding): |
77 out.append(numericNode(op)) | 77 out.append(_numericNode(op)) |
78 | 78 |
79 return out | 79 return out |
80 | 80 |
81 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: | 81 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: |
82 """either any new bindings this function makes (could be 0), or None if it doesn't match""" | 82 """either any new bindings this function makes (could be 0), or None if it doesn't match""" |
110 | 110 |
111 class ListFunction(Function): | 111 class ListFunction(Function): |
112 """function that takes an rdf list as input""" | 112 """function that takes an rdf list as input""" |
113 | 113 |
114 def usedStatements(self) -> Set[Triple]: | 114 def usedStatements(self) -> Set[Triple]: |
115 _, used = parseList(self.ruleGraph, self.chunk.primary[0]) | 115 _, used = _parseList(self.ruleGraph, self.chunk.primary[0]) |
116 return used | 116 return used |
117 | 117 |
118 def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: | 118 def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: |
119 operands, _ = parseList(self.ruleGraph, self.chunk.primary[0]) | 119 operands, _ = _parseList(self.ruleGraph, self.chunk.primary[0]) |
120 return [existingBinding.applyTerm(x) for x in operands] | 120 return [existingBinding.applyTerm(x) for x in operands] |
121 | 121 |
122 import inference_functions # calls register() on some classes | |
122 | 123 |
123 @register | 124 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in _registeredFunctionTypes) |
124 class Gt(SubjectObjectFunction): | |
125 pred = MATH['greaterThan'] | |
126 | |
127 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: | |
128 [x, y] = self.getNumericOperands(existingBinding) | |
129 if x > y: | |
130 return CandidateBinding({}) # no new values; just allow matching to keep going | |
131 | |
132 | |
133 @register | |
134 class AsFarenheit(SubjectFunction): | |
135 pred = ROOM['asFarenheit'] | |
136 | |
137 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: | |
138 [x] = self.getNumericOperands(existingBinding) | |
139 f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32)) | |
140 return self.valueInObjectTerm(f) | |
141 | |
142 | |
143 @register | |
144 class Sum(ListFunction): | |
145 pred = MATH['sum'] | |
146 | |
147 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: | |
148 f = Literal(sum(self.getNumericOperands(existingBinding))) | |
149 return self.valueInObjectTerm(f) | |
150 | |
151 | |
152 ### registration is done | |
153 | |
154 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes) | |
155 | 125 |
156 | 126 |
157 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]: | 127 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]: |
158 try: | 128 try: |
159 yield _byPred[pred] | 129 yield _byPred[pred] |
168 # usedByFuncs.update(cls(s, graph).usedStatements()) | 138 # usedByFuncs.update(cls(s, graph).usedStatements()) |
169 # return usedByFuncs | 139 # return usedByFuncs |
170 | 140 |
171 | 141 |
172 def rulePredicates() -> Set[URIRef]: | 142 def rulePredicates() -> Set[URIRef]: |
173 return set(c.pred for c in registeredFunctionTypes) | 143 return set(c.pred for c in _registeredFunctionTypes) |