Coverage for tests/test_survival.py: 99%

144 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1import os 

2import threading 

3import unittest 

4 

5import numpy as np 

6import pandas as pd 

7 

8from rulekit import survival 

9from rulekit.arff import read_arff 

10from rulekit.events import RuleInductionProgressListener 

11from rulekit.kaplan_meier import KaplanMeierEstimator 

12from rulekit.main import RuleKit 

13from rulekit.rules import SurvivalRule 

14from tests.utils import assert_rules_are_equals 

15from tests.utils import dir_path 

16from tests.utils import get_test_cases 

17 

18 

19class TestKaplanMeierEstimator(unittest.TestCase): 

20 

21 survival_rules: survival.SurvivalRules 

22 

23 def setUp(self): 

24 test_case = get_test_cases("SurvivalLogRankSnCTest")[0] 

25 

26 self.survival_rules = survival.SurvivalRules( 

27 survival_time_attr=test_case.survival_time 

28 ) 

29 example_set = test_case.example_set 

30 self.survival_rules.fit( 

31 example_set.values, 

32 example_set.labels, 

33 ) 

34 self.km: KaplanMeierEstimator = self.survival_rules.model.rules[ 

35 0 

36 ].kaplan_meier_estimator 

37 

38 def test_accessing_probabilities(self): 

39 self.assertTrue( 

40 all([p >= 0.0 and p <= 1.0 for p in self.km.probabilities]), 

41 "All probabilities should be in range [0, 1]", 

42 ) 

43 self.assertTrue( 

44 all([isinstance(p, float) for p in self.km.probabilities]), 

45 "All probabilities should be Pythonic floats", 

46 ) 

47 

48 def test_accessing_events_count(self): 

49 self.assertTrue( 

50 all([isinstance(p, np.int_) for p in self.km.events_count]), 

51 "All event counts should be Pythonic integers", 

52 ) 

53 

54 def test_accessing_at_risk_count(self): 

55 self.assertTrue( 

56 all([isinstance(p, np.int_) for p in self.km.at_risk_count]), 

57 "All risk count should be Pythonic integers", 

58 ) 

59 

60 

class TestSurvivalRules(unittest.TestCase):
    """Tests for ``survival.SurvivalRules`` fitting, scoring and events."""

    def test_induction_progress_listener(self):
        """Listener callbacks should fire once per induced rule."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]

        surv = survival.SurvivalRules(survival_time_attr=test_case.survival_time)
        example_set = test_case.example_set

        class EventListener(RuleInductionProgressListener):
            """Counts the callbacks received during induction."""

            # The counters are only updated while holding this lock.
            lock = threading.Lock()
            induced_rules_count = 0
            on_progress_calls_count = 0

            def on_new_rule(self, rule: SurvivalRule):
                # ``with`` releases the lock even if the update raises.
                with self.lock:
                    self.induced_rules_count += 1

            def on_progress(
                self, total_examples_count: int, uncovered_examples_count: int
            ):
                with self.lock:
                    self.on_progress_calls_count += 1

        listener = EventListener()
        surv.add_event_listener(listener)
        surv.fit(
            example_set.values,
            example_set.labels,
        )
        rules_count = len(surv.model.rules)
        self.assertEqual(rules_count, listener.induced_rules_count)
        self.assertEqual(rules_count, listener.on_progress_calls_count)

    def test_compare_with_java_results(self):
        """Induced rules should match the rules in each reference report."""
        test_cases = get_test_cases("SurvivalLogRankSnCTest")

        for index, test_case in enumerate(test_cases):
            # subTest reports which case failed and lets the others still run.
            with self.subTest(test_case_index=index):
                params = test_case.induction_params
                tree = survival.SurvivalRules(
                    **params, survival_time_attr=test_case.survival_time
                )
                example_set = test_case.example_set
                tree.fit(example_set.values, example_set.labels)
                expected = test_case.reference_report.rules
                actual = [str(rule) for rule in tree.model.rules]
                assert_rules_are_equals(expected, actual)

    def test_fit_and_predict_on_boolean_columns(self):
        """Fitting and predicting should work with a boolean feature column."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        clf = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        # Copy so the extra column is not written into the test case's own
        # data frame (in case that object is reused by other tests).
        X = test_case.example_set.values.copy()
        y = test_case.example_set.labels
        # Seeded generator keeps the test reproducible between runs.
        rng = np.random.default_rng(0)
        X["boolean_column"] = rng.integers(low=0, high=2, size=X.shape[0]).astype(
            bool
        )
        clf.fit(X, y)
        clf.predict(X)

        # Repeat with the labels wrapped in a pandas Series.
        y = pd.Series(y)
        clf.fit(X, y)
        clf.predict(X)

    def test_passing_survival_time_column_to_fit_method(self):
        """Passing survival time to ``fit`` should give the same rules as
        naming the survival time attribute in the constructor."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv1 = survival.SurvivalRules(**params)
        surv2 = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        X, y = test_case.example_set.values, test_case.example_set.labels
        survival_time_col: pd.Series = X[test_case.survival_time]
        X_without_time_col: pd.DataFrame = X.drop(
            columns=[test_case.survival_time]
        )
        surv1.fit(X_without_time_col, y, survival_time=survival_time_col)
        surv2.fit(X, y)

        assert_rules_are_equals(
            [str(r) for r in surv1.model.rules],
            [str(r) for r in surv2.model.rules],
        )

    def test_ibs_calculation(self):
        """``score`` should return the same value whether the survival time is
        part of ``X`` or passed separately."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        X, y = test_case.example_set.values, test_case.example_set.labels
        survival_time_col: pd.Series = X[test_case.survival_time]
        X_without_time_col: pd.DataFrame = X.drop(
            columns=[test_case.survival_time]
        )
        surv.fit(X, y)

        ibs: float = surv.score(X, y)
        ibs2: float = surv.score(X_without_time_col, y, survival_time=survival_time_col)

        self.assertEqual(ibs, ibs2)

    def test_getting_training_dataset_kaplan_meier_estimator(self):
        """The training-set estimator should have one entry per unique time."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        example_set = test_case.example_set
        unique_times_counts: int = example_set.values[test_case.survival_time].nunique()
        surv.fit(example_set.values, example_set.labels)

        training_km: KaplanMeierEstimator = surv.get_train_set_kaplan_meier()
        # assertIsInstance also fails for None, so no separate None check.
        self.assertIsInstance(
            training_km,
            KaplanMeierEstimator,
            "Should return KaplanMeierEstimator instance fitted on whole training set",
        )
        self.assertEqual(
            len(training_km.times),
            unique_times_counts,
            "Estimator should contain probabilities for each unique time from the dataset",
        )

    def test_max_rule_count(self):
        """``max_rule_count`` should cap the number of induced rules."""
        MAX_RULE_COUNT = 3
        df: pd.DataFrame = read_arff(
            os.path.join(dir_path, "resources", "data", "bmt-train-0.arff")
        )
        X, y = df.drop("survival_status", axis=1), df["survival_status"]
        clf = survival.SurvivalRules(
            survival_time_attr="survival_time",
            max_rule_count=MAX_RULE_COUNT,
        )
        clf.fit(X, y)
        self.assertLessEqual(
            len(clf.model.rules),
            MAX_RULE_COUNT,
            f"Ruleset should contain no more than {MAX_RULE_COUNT} rules according to max_rule_count parameter",
        )

204 

205 

class TestExpertSurvivalRules(unittest.TestCase):
    """Tests for ``survival.ExpertSurvivalRules``."""

    @classmethod
    def setUpClass(cls):
        """Initialise RuleKit once before the tests of this class run."""
        RuleKit.init()

    @staticmethod
    def _load_bmt_train_set():
        """Read ``bmt-train-0.arff`` and return ``(X, y)`` with
        ``survival_status`` as the label column."""
        arff_path = os.path.join(dir_path, "resources", "data", "bmt-train-0.arff")
        frame: pd.DataFrame = read_arff(arff_path)
        features = frame.drop("survival_status", axis=1)
        labels = frame["survival_status"]
        return features, labels

    @staticmethod
    def _build_refining_classifier():
        """Create the classifier used by the condition-refining tests, with
        preferred/automatic extension and induction all switched off."""
        return survival.ExpertSurvivalRules(
            complementary_conditions=True,
            extend_using_preferred=False,
            extend_using_automatic=False,
            induce_using_preferred=False,
            induce_using_automatic=False,
            preferred_conditions_per_rule=0,
            preferred_attributes_per_rule=0,
            survival_time_attr="survival_time",
        )

    def test_compare_with_java_results(self):
        """Induced rules should match the rules in each reference report."""
        for case in get_test_cases("SurvivalLogRankExpertSnCTest"):
            induction_params = case.induction_params
            surv = survival.ExpertSurvivalRules(
                **induction_params,
                ignore_missing=True,
                survival_time_attr=case.survival_time
            )
            data = case.example_set
            knowledge = case.knowledge
            surv.fit(
                data.values,
                data.labels,
                expert_rules=knowledge.expert_rules,
                expert_preferred_conditions=knowledge.expert_preferred_conditions,
                expert_forbidden_conditions=knowledge.expert_forbidden_conditions,
            )
            reference_rules = case.reference_report.rules
            induced_rules = [str(rule) for rule in surv.model.rules]
            assert_rules_are_equals(reference_rules, induced_rules)

    def test_refining_conditions_for_nominal_attributes(self):
        """Expert rules on a nominal attribute: a fixed condition, then an
        ``Any`` condition that is expected to be refined."""
        X, y = self._load_bmt_train_set()

        # Run experiment using python API
        clf = self._build_refining_classifier()

        fixed_rule = ("expert_rules-1", "IF CMVstatus @= {1} THEN")
        clf.fit(X, y, expert_rules=[fixed_rule])
        induced = list(map(str, clf.model.rules))
        self.assertEqual(
            ["IF [[CMVstatus = {1}]] THEN "],
            induced,
            "Ruleset should contain only a single rule configured by expert",
        )

        open_rule = ("expert_rules-1", "IF CMVstatus @= Any THEN")
        clf.fit(X, y, expert_rules=[open_rule])
        induced = list(map(str, clf.model.rules))
        self.assertEqual(
            ["IF [[CMVstatus = !{1}]] THEN "],
            induced,
            (
                "Ruleset should contain only a single rule configured by expert with "
                "a refined condition"
            ),
        )

    def test_refining_conditions_for_numerical_attributes(self):
        """Expert rule with an ``Any`` condition on a numerical attribute is
        expected to be refined to an interval."""
        X, y = self._load_bmt_train_set()

        # Run experiment using python API
        clf = self._build_refining_classifier()

        open_rule = ("expert_rules-1", "IF CD34kgx10d6 @= Any THEN")
        clf.fit(X, y, expert_rules=[open_rule])
        induced = list(map(str, clf.model.rules))
        self.assertEqual(
            ["IF [[CD34kgx10d6 = (-inf, 11.86)]] THEN "],
            induced,
            (
                "Ruleset should contain only a single rule configured by expert with "
                "a refined condition"
            ),
        )

296 

297 

# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()