Coverage for rulekit/rules.py: 86%

152 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1"""Contains classes representing rules and rulesets. 

2""" 

3from typing import Generic 

4from typing import TypeVar 

5from typing import Union 

6 

7import numpy as np 

8from jpype import JObject 

9 

10from rulekit.kaplan_meier import KaplanMeierEstimator 

11from rulekit.params import Measures 

12from rulekit.stats import RuleSetStatistics 

13from rulekit.stats import RuleStatistics 

14 

15 

class InductionParameters:
    """Python-side view of the induction parameters held by a Java object.

    The scalar settings are read from the Java object once, when the wrapper
    is created, and kept as plain attributes. The measure-related settings
    are properties that query the Java object on every access.
    """

    def __init__(self, java_object):
        self._java_object = java_object

        # Short alias; every attribute below is a snapshot taken right now.
        params = self._java_object
        self.minimum_covered: float = params.getMinimumCovered()
        self.maximum_uncovered_fraction: float = params.getMaximumUncoveredFraction()
        self.ignore_missing: bool = params.isIgnoreMissing()
        self.pruning_enabled: bool = params.isPruningEnabled()
        self.max_growing_condition: float = params.getMaxGrowingConditions()

    @property
    def induction_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for induction
        """
        java_measure = self._java_object.getInductionMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @property
    def pruning_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for pruning
        """
        java_measure = self._java_object.getPruningMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @property
    def voting_measure(self) -> Union[Measures, str]:
        """
        Returns:
            Union[Measures, str]: Measure used for voting
        """
        java_measure = self._java_object.getVotingMeasure()
        return InductionParameters._get_measure_str(java_measure)

    @staticmethod
    def _get_measure_str(measure) -> Union[Measures, str]:
        """Translate a Java measure object into its Python counterpart.

        The measure's name is looked up in ``Measures``; the special name
        ``"UserDefined"`` is returned unchanged as a string instead.
        """
        name: str = measure.getName()
        return "UserDefined" if name == "UserDefined" else Measures[name]

    def __str__(self):
        """Returns the Java object's own string representation."""
        text = self._java_object.toString()
        return str(text)

69 

70 

class BaseRule:
    """Base class representing single rule.

    Every property reads its value from the wrapped Java rule object on
    access; only the statistics object is cached on the Python side.
    """

    def __init__(self, java_object):
        """:meta private:"""
        # Filled in on first access of the ``stats`` property.
        self._stats: RuleStatistics = None
        self._java_object = java_object

    @property
    def weight(self) -> float:
        """Rule weight"""
        java_rule = self._java_object
        return java_rule.getWeight()

    @property
    def weighted_p(self) -> float:
        """Number of positives covered by the rule (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_p()

    @property
    def weighted_n(self) -> float:
        """Number of negatives covered by the rule (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_n()

    @property
    def weighted_P(self) -> float:  # pylint: disable=invalid-name
        """Number of positives in the training set (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_P()

    @property
    def weighted_N(self) -> float:  # pylint: disable=invalid-name
        """Number of negatives in the training set (accounting weights)."""
        java_rule = self._java_object
        return java_rule.getWeighted_N()

    @property
    def pvalue(self) -> float:
        """Rule significance."""
        java_rule = self._java_object
        return java_rule.getPValue()

    @property
    def stats(self) -> RuleStatistics:
        """Rule statistics, created on first access and reused afterwards."""
        cached = self._stats
        if cached is not None:
            return cached
        self._stats = RuleStatistics(self)
        return self._stats

    def get_covering_information(self) -> dict:
        """Returns information about rule covering

        Returns
        -------
        covering_data : dict
            Dictionary with the keys ``weighted_n``, ``weighted_p``,
            ``weighted_N`` and ``weighted_P``, each mapped to the value of
            the property of the same name.
        """
        field_names = ("weighted_n", "weighted_p", "weighted_N", "weighted_P")
        return {field: getattr(self, field) for field in field_names}

    def print_stats(self):
        """Prints rule statistics as formatted text."""
        statistics = self.stats
        print(statistics)

    def __str__(self):
        """Returns string representation of the rule."""
        text = self._java_object.toString()
        return str(text)

138 

139 

class ClassificationRule(BaseRule):
    """Class representing classification rule"""

    def __init__(self, java_object):
        super().__init__(java_object)

        # The class label is read from the Java rule once and kept as a
        # Python string.
        label = self._java_object.getClassLabel()
        self._decision_class: str = str(label)

    @property
    def decision_class(self) -> str:
        """Decision class of the rule"""
        return self._decision_class

152 

153 

class RegressionRule(BaseRule):
    """Class representing regression rule"""

    def __init__(self, java_object):
        super().__init__(java_object)

        # The consequence value is converted to a Python float once, here.
        # It is annotated as float to match both the conversion and the
        # return type of the ``conclusion_value`` property.
        self._conclusion_value: float = float(self._java_object.getConsequenceValue())

    @property
    def conclusion_value(self) -> float:
        """Value from the rule's conclusion"""
        return self._conclusion_value

166 

167 

class SurvivalRule(BaseRule):
    """Class representing survival rule"""

    def __init__(self, java_object):
        super().__init__(java_object)

        # Wrap the Java estimator once; the property returns this wrapper.
        java_estimator = java_object.getEstimator()
        self._kaplan_meier_estimator: KaplanMeierEstimator = KaplanMeierEstimator(
            java_estimator
        )

    @property
    def kaplan_meier_estimator(self) -> KaplanMeierEstimator:
        """Kaplan-Meier estimator from the rule conclusion"""
        return self._kaplan_meier_estimator

182 

183 

def _rule_factory(java_object: JObject) -> BaseRule:
    """Wrap a Java rule object in the matching Python rule class.

    The choice is made from the lower-cased Java class name: a name
    containing ``"regression"`` gives a ``RegressionRule``, otherwise one
    containing ``"survival"`` gives a ``SurvivalRule``, and anything else
    gives a ``ClassificationRule``.
    """
    java_class_name: str = str(java_object.getClass().getName()).lower()
    # Order matters: "regression" is checked before "survival".
    keyword_to_rule_type = (
        ("regression", RegressionRule),
        ("survival", SurvivalRule),
    )
    for keyword, rule_type in keyword_to_rule_type:
        if keyword in java_class_name:
            return rule_type(java_object)
    return ClassificationRule(java_object)

191 

192 

# Type of the rule objects held by a ``RuleSet``.
T = TypeVar("T")


class RuleSet(Generic[T]):
    """Class representing ruleset.

    Every accessor delegates to the wrapped Java rule set object; only the
    statistics object is cached on the Python side.
    """

    def __init__(self, java_object):
        """:meta private:"""
        self._java_object = java_object
        # Created lazily by the ``stats`` property.
        self._stats: Union[RuleSetStatistics, None] = None

    @property
    def total_time(self) -> float:
        """Time of constructing the rule set in seconds"""
        return self._java_object.getTotalTime()

    @property
    def growing_time(self) -> float:
        """Time of growing in seconds"""
        return self._java_object.getGrowingTime()

    @property
    def pruning_time(self) -> float:
        """Time of pruning in seconds"""
        return self._java_object.getPruningTime()

    @property
    def is_voting(self) -> bool:
        """Value indicating whether rules are voting."""
        return self._java_object.getIsVoting()

    @property
    def parameters(self) -> InductionParameters:
        """Parameters used during rule set induction.

        A new ``InductionParameters`` wrapper is created on every access.
        """
        return InductionParameters(self._java_object.getParams())

    @property
    def stats(self) -> RuleSetStatistics:
        """Rule set statistics, created on first access and reused afterwards."""
        if self._stats is None:
            self._stats = RuleSetStatistics(self)
        return self._stats

    def covering(self, example_set) -> np.ndarray:
        """Collect, for every rule, what ``coversUnlabelled`` reports.

        Each rule's Java object is asked which entries of ``example_set`` it
        covers; the answer is copied into a Python list. The per-rule lists
        are returned as an object-dtype array, in the order of ``rules``.

        :meta private:
        """
        covered_per_rule = [
            list(
                rule._java_object.coversUnlabelled(  # pylint: disable=protected-access
                    example_set
                )
            )
            for rule in self.rules
        ]
        # NOTE(review): with dtype=object NumPy yields a 1-D array of lists
        # only when the lists differ in length; equal-length lists produce a
        # 2-D object array instead - confirm callers accept both shapes.
        return np.array(covered_per_rule, dtype=object)

    @property
    def rules(self) -> list[T]:
        """List of rules objects.

        The list is rebuilt from the Java rule set on every access.
        """
        return [_rule_factory(java_rule) for java_rule in self._java_object.getRules()]

    def calculate_conditions_count(self) -> float:
        """
        Returns
        -------
        count: float
            Number of conditions.
        """
        return self._java_object.calculateConditionsCount()

    def calculate_induced_conditions_count(self) -> float:
        """
        Returns
        -------
        count: float
            Number of induced conditions.
        """
        # The spelling "Condtions" is the name of the Java method being called.
        return self._java_object.calculateInducedCondtionsCount()

    def calculate_avg_rule_coverage(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule coverage.
        """
        return self._java_object.calculateAvgRuleCoverage()

    def calculate_avg_rule_precision(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule precision.
        """
        return self._java_object.calculateAvgRulePrecision()

    def calculate_avg_rule_quality(self) -> float:
        """
        Returns
        -------
        count: float
            Average rule quality.
        """
        return self._java_object.calculateAvgRuleQuality()

    @staticmethod
    def _significance_to_dict(significance) -> dict:
        """Copy the ``p`` and ``fraction`` fields of a result into a dict."""
        return {"p": significance.p, "fraction": significance.fraction}

    def calculate_significance(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Assumed significance level.

        Returns
        -------
        significance: dict
            Significance of the rule set.
            Dictionary contains two fields: *fraction* and *p*
            (presumably with the same meaning as in the FDR/FWER variants
            - TODO confirm).
        """
        significance = self._java_object.calculateSignificance(alpha)
        return self._significance_to_dict(significance)

    def calculate_significance_fdr(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Assumed significance level.

        Returns
        -------
        significance: dict
            Significance of the rule set with false discovery rate correction.
            Dictionary contains two fields: *fraction* (fraction of rules significant
            at assumed level) and *p* (average p-value of all rules).
        """
        significance = self._java_object.calculateSignificanceFDR(alpha)
        return self._significance_to_dict(significance)

    def calculate_significance_fwer(self, alpha: float) -> dict:
        """
        Parameters
        ----------
        alpha : float
            Assumed significance level.

        Returns
        -------
        significance: dict
            Significance of the rule set with family-wise error rate correction.
            Dictionary contains two fields: *fraction* (fraction of rules significant
            at assumed level) and *p* (average p-value of all rules).
        """
        significance = self._java_object.calculateSignificanceFWER(alpha)
        return self._significance_to_dict(significance)

    def __str__(self):
        """Returns string representation of the object."""
        return str(self._java_object.toString())

340 return str(self._java_object.toString())