Coverage for rulekit/rules.py: 86%
152 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1"""Contains classes representing rules and rulesets.
2"""
3from typing import Generic
4from typing import TypeVar
5from typing import Union
7import numpy as np
8from jpype import JObject
10from rulekit.kaplan_meier import KaplanMeierEstimator
11from rulekit.params import Measures
12from rulekit.stats import RuleSetStatistics
13from rulekit.stats import RuleStatistics
class InductionParameters:
    """Induction parameters.

    Read-only Python view of the parameters object attached to a Java rule set.
    Scalar settings are copied into plain attributes when the wrapper is built;
    the measure properties query the Java object on every access.
    """

    def __init__(self, java_object):
        self._java_object = java_object

        source = self._java_object
        # Scalar induction settings, read once from the Java parameters object.
        self.minimum_covered: float = source.getMinimumCovered()
        self.maximum_uncovered_fraction: float = source.getMaximumUncoveredFraction()
        self.ignore_missing: bool = source.isIgnoreMissing()
        self.pruning_enabled: bool = source.isPruningEnabled()
        self.max_growing_condition: float = source.getMaxGrowingConditions()

    @property
    def induction_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for induction, or the string
            ``"UserDefined"`` when the Java measure reports that name.
        """
        java_measure = self._java_object.getInductionMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @property
    def pruning_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for pruning, or the string
            ``"UserDefined"`` when the Java measure reports that name.
        """
        java_measure = self._java_object.getPruningMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @property
    def voting_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for voting, or the string
            ``"UserDefined"`` when the Java measure reports that name.
        """
        java_measure = self._java_object.getVotingMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @staticmethod
    def _get_measure_str(measure) -> Union[Measures, str]:
        # Translate a Java measure into a ``Measures`` member looked up by name;
        # the name "UserDefined" is passed through as a plain string instead.
        measure_name: str = measure.getName()
        if measure_name != "UserDefined":
            return Measures[measure_name]
        return "UserDefined"

    def __str__(self):
        # Delegate to the Java object's own textual representation.
        return str(self._java_object.toString())
class BaseRule:
    """Base class for a single rule, wrapping the underlying Java rule object."""

    def __init__(self, java_object):
        """:meta private:"""
        self._java_object = java_object
        # Filled in lazily by the ``stats`` property on first access.
        self._stats: RuleStatistics = None

    @property
    def weight(self) -> float:
        """Rule weight"""
        java_rule = self._java_object
        return java_rule.getWeight()

    @property
    def weighted_p(self) -> float:
        """Number of positives covered by the rule (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_p()

    @property
    def weighted_n(self) -> float:
        """Number of negatives covered by the rule (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_n()

    @property
    def weighted_P(self) -> float:  # pylint: disable=invalid-name
        """Number of positives in the training set (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_P()

    @property
    def weighted_N(self) -> float:  # pylint: disable=invalid-name
        """Number of negatives in the training set (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_N()

    @property
    def pvalue(self) -> float:
        """Rule significance."""
        java_rule = self._java_object
        return java_rule.getPValue()

    @property
    def stats(self) -> RuleStatistics:
        """Rule statistics (computed on first access, then cached)."""
        if self._stats is not None:
            return self._stats
        self._stats = RuleStatistics(self)
        return self._stats

    def get_covering_information(self) -> dict:
        """Returns information about rule covering

        Returns
        -------
        covering_data : dict
            Dictionary with the keys ``weighted_n``, ``weighted_p``,
            ``weighted_N`` and ``weighted_P`` mapped to the matching properties.
        """
        # Keys equal the property names; the tuple order fixes the dict order.
        return {
            attribute: getattr(self, attribute)
            for attribute in ("weighted_n", "weighted_p", "weighted_N", "weighted_P")
        }

    def print_stats(self):
        """Prints rule statistics as formatted text."""
        print(self.stats)

    def __str__(self):
        """Returns string representation of the rule."""
        java_text = self._java_object.toString()
        return str(java_text)
class ClassificationRule(BaseRule):
    """Class representing classification rule"""

    def __init__(self, java_object):
        super().__init__(java_object)
        # Convert the Java label once so the property always yields a plain str.
        java_label = self._java_object.getClassLabel()
        self._decision_class: str = str(java_label)

    @property
    def decision_class(self) -> str:
        """Decision class of the rule (the class label read from the Java rule)."""
        return self._decision_class
class RegressionRule(BaseRule):
    """Class representing regression rule"""

    def __init__(self, java_object):
        super().__init__(java_object)
        # Converted once to a Python float; the annotation matches the
        # ``conclusion_value`` property's declared return type.
        self._conclusion_value: float = float(self._java_object.getConsequenceValue())

    @property
    def conclusion_value(self) -> float:
        """Value from the rule's conclusion"""
        return self._conclusion_value
class SurvivalRule(BaseRule):
    """Class representing survival rule"""

    def __init__(self, java_object):
        super().__init__(java_object)
        # The estimator is wrapped eagerly, once per rule wrapper instance.
        self._kaplan_meier_estimator: KaplanMeierEstimator = KaplanMeierEstimator(
            self._java_object.getEstimator()
        )

    @property
    def kaplan_meier_estimator(self) -> KaplanMeierEstimator:
        """Kaplan-Meier estimator from the rule conclusion"""
        return self._kaplan_meier_estimator
def _rule_factory(java_object: JObject) -> BaseRule:
    """Wrap a Java rule object in the matching Python rule class.

    The choice is made by substring match on the lower-cased Java class name:
    "regression" is checked first, then "survival"; anything else is treated
    as a classification rule.
    """
    java_class_name: str = str(java_object.getClass().getName()).lower()
    # Checked in order; the first matching keyword wins.
    for keyword, rule_class in (
        ("regression", RegressionRule),
        ("survival", SurvivalRule),
    ):
        if keyword in java_class_name:
            return rule_class(java_object)
    return ClassificationRule(java_object)
T = TypeVar("T")


class RuleSet(Generic[T]):
    """Class representing ruleset.

    Thin Python view over a Java rule set object. ``T`` is the rule type used
    in the annotation of :attr:`rules`.
    """

    def __init__(self, java_object):
        """:meta private:"""
        self._java_object = java_object
        # Filled in lazily by the ``stats`` property on first access.
        self._stats: Union[RuleSetStatistics, None] = None

    @property
    def total_time(self) -> float:
        """Time of constructing the rule set in seconds"""
        return self._java_object.getTotalTime()

    @property
    def growing_time(self) -> float:
        """Time of growing in seconds"""
        return self._java_object.getGrowingTime()

    @property
    def pruning_time(self) -> float:
        """Time of pruning in seconds"""
        return self._java_object.getPruningTime()

    @property
    def is_voting(self) -> bool:
        """Value indicating whether rules are voting."""
        return self._java_object.getIsVoting()

    @property
    def parameters(self) -> InductionParameters:
        """Parameters used during rule set induction."""
        return InductionParameters(self._java_object.getParams())

    @property
    def stats(self) -> RuleSetStatistics:
        """Rule set statistics (computed on first access, then cached)."""
        if self._stats is None:
            self._stats = RuleSetStatistics(self)
        return self._stats

    def covering(self, example_set) -> np.ndarray:
        """Return, for each rule, the examples it covers in ``example_set``.

        The result is always a 1-D object array of length ``len(self.rules)``.
        Element ``i`` is a list built from what the i-th Java rule's
        ``coversUnlabelled(example_set)`` returns.

        :meta private:
        """
        # Iterate the Java rules directly: building the Python wrappers (which
        # for survival rules also builds a Kaplan-Meier estimator) is not needed
        # just to call ``coversUnlabelled``.
        coverage_per_rule = [
            list(java_rule.coversUnlabelled(example_set))
            for java_rule in self._java_object.getRules()
        ]
        # ``np.array(list_of_lists, dtype=object)`` silently becomes 2-D when all
        # lists happen to have equal length. Fill a 1-D object array element by
        # element so the shape is always (n_rules,).
        result = np.empty(len(coverage_per_rule), dtype=object)
        for index, covered in enumerate(coverage_per_rule):
            result[index] = covered
        return result

    @property
    def rules(self) -> list[T]:
        """List of rules objects."""
        return [_rule_factory(java_rule) for java_rule in self._java_object.getRules()]

    def calculate_conditions_count(self) -> float:
        """
        Returns
        -------
        count: float
            Number of conditions.
        """
        return self._java_object.calculateConditionsCount()

    def calculate_induced_conditions_count(self) -> float:
        """
        Returns
        -------
        count: float
            Number of induced conditions.
        """
        # NOTE(review): the "Condtions" spelling presumably mirrors the Java
        # method name; verify against the Java API before renaming.
        return self._java_object.calculateInducedCondtionsCount()

    def calculate_avg_rule_coverage(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule coverage.
        """
        return self._java_object.calculateAvgRuleCoverage()

    def calculate_avg_rule_precision(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule precision.
        """
        return self._java_object.calculateAvgRulePrecision()

    def calculate_avg_rule_quality(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule quality.
        """
        return self._java_object.calculateAvgRuleQuality()

    def calculate_significance(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Significance level passed to the Java calculation.

        Returns
        -------
        significance: dict
            Significance of the rule set, without multiple-testing correction.
            Dictionary contains two fields: *fraction* (fraction of rules
            significant at assumed level) and *p* (average p-value of all rules).
        """
        significance = self._java_object.calculateSignificance(alpha)
        return {"p": significance.p, "fraction": significance.fraction}

    def calculate_significance_fdr(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Significance level passed to the Java calculation.

        Returns
        -------
        significance: dict
            Significance of the rule set with false discovery rate correction.
            Dictionary contains two fields: *fraction* (fraction of rules
            significant at assumed level) and *p* (average p-value of all rules).
        """
        significance = self._java_object.calculateSignificanceFDR(alpha)
        return {"p": significance.p, "fraction": significance.fraction}

    def calculate_significance_fwer(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Significance level passed to the Java calculation.

        Returns
        -------
        significance: dict
            Significance of the rule set with family-wise error rate correction.
            Dictionary contains two fields: *fraction* (fraction of rules
            significant at assumed level) and *p* (average p-value of all rules).
        """
        significance = self._java_object.calculateSignificanceFWER(alpha)
        return {"p": significance.p, "fraction": significance.fraction}

    def __str__(self):
        """Returns string representation of the object."""
        return str(self._java_object.toString())