Coverage for tests/test_survival.py: 99%
144 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1import os
2import threading
3import unittest
5import numpy as np
6import pandas as pd
8from rulekit import survival
9from rulekit.arff import read_arff
10from rulekit.events import RuleInductionProgressListener
11from rulekit.kaplan_meier import KaplanMeierEstimator
12from rulekit.main import RuleKit
13from rulekit.rules import SurvivalRule
14from tests.utils import assert_rules_are_equals
15from tests.utils import dir_path
16from tests.utils import get_test_cases
class TestKaplanMeierEstimator(unittest.TestCase):
    """Check the element types exposed by a rule's Kaplan-Meier estimator.

    ``setUp`` fits a ``SurvivalRules`` model on the first
    ``SurvivalLogRankSnCTest`` case and stores the estimator attached to the
    first induced rule in ``self.km``; every test inspects that estimator.
    """

    survival_rules: "survival.SurvivalRules"

    def setUp(self):
        """Fit a survival model and keep the first rule's estimator."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]

        self.survival_rules = survival.SurvivalRules(
            survival_time_attr=test_case.survival_time
        )
        example_set = test_case.example_set
        self.survival_rules.fit(example_set.values, example_set.labels)

        first_rule = self.survival_rules.model.rules[0]
        self.km: KaplanMeierEstimator = first_rule.kaplan_meier_estimator

    def test_accessing_probabilities(self):
        """Every probability is a float inside the closed interval [0, 1]."""
        self.assertTrue(
            all(0.0 <= p <= 1.0 for p in self.km.probabilities),
            "All probabilities should be in range [0, 1]",
        )
        self.assertTrue(
            all(isinstance(p, float) for p in self.km.probabilities),
            "All probabilities should be Pythonic floats",
        )

    def test_accessing_events_count(self):
        """Every event count is an ``np.int_`` value."""
        # The check is against np.int_, so the failure message names NumPy
        # integers rather than built-in ints.
        self.assertTrue(
            all(isinstance(count, np.int_) for count in self.km.events_count),
            "All event counts should be NumPy integers (np.int_)",
        )

    def test_accessing_at_risk_count(self):
        """Every at-risk count is an ``np.int_`` value."""
        self.assertTrue(
            all(isinstance(count, np.int_) for count in self.km.at_risk_count),
            "All at-risk counts should be NumPy integers (np.int_)",
        )
class TestSurvivalRules(unittest.TestCase):
    """Tests for ``survival.SurvivalRules``.

    Most tests use the cases returned by
    ``get_test_cases("SurvivalLogRankSnCTest")``; ``test_max_rule_count``
    reads ``bmt-train-0.arff`` from the test resources directory.
    """

    def test_induction_progress_listener(self):
        """Both listener callbacks are invoked as many times as rules exist."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]

        surv = survival.SurvivalRules(survival_time_attr=test_case.survival_time)
        example_set = test_case.example_set

        class EventListener(RuleInductionProgressListener):
            """Counts how often each callback is invoked."""

            # NOTE(review): the lock suggests callbacks may arrive from
            # another thread - confirm against the listener contract.
            lock = threading.Lock()
            induced_rules_count = 0
            on_progress_calls_count = 0

            def on_new_rule(self, rule: SurvivalRule):
                # ``with`` releases the lock even if the body raises.
                with self.lock:
                    self.induced_rules_count += 1

            def on_progress(
                self, total_examples_count: int, uncovered_examples_count: int
            ):
                with self.lock:
                    self.on_progress_calls_count += 1

        listener = EventListener()
        surv.add_event_listener(listener)
        surv.fit(example_set.values, example_set.labels)

        rules_count = len(surv.model.rules)
        self.assertEqual(rules_count, listener.induced_rules_count)
        self.assertEqual(rules_count, listener.on_progress_calls_count)

    def test_compare_with_java_results(self):
        """Induced rules match the reference report of every test case."""
        test_cases = get_test_cases("SurvivalLogRankSnCTest")

        for index, test_case in enumerate(test_cases):
            # subTest identifies the failing case when the runner supports it.
            with self.subTest(case=index):
                params = test_case.induction_params
                tree = survival.SurvivalRules(
                    **params, survival_time_attr=test_case.survival_time
                )
                example_set = test_case.example_set
                tree.fit(example_set.values, example_set.labels)

                expected = test_case.reference_report.rules
                actual = [str(rule) for rule in tree.model.rules]
                assert_rules_are_equals(expected, actual)

    def test_fit_and_predict_on_boolean_columns(self):
        """``fit`` and ``predict`` accept a frame containing a bool column.

        Runs once with the labels as provided by the test case and once with
        the labels wrapped in a ``pd.Series``.
        """
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        clf = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        X, y = test_case.example_set.values, test_case.example_set.labels

        # Work on a copy so the extra column never leaks into the test-case
        # data, and seed the generator so every run sees the same column.
        X = X.copy()
        rng = np.random.default_rng(0)
        X["boolean_column"] = rng.integers(
            low=0, high=2, size=X.shape[0]
        ).astype(bool)

        clf.fit(X, y)
        clf.predict(X)

        y = pd.Series(y)
        clf.fit(X, y)
        clf.predict(X)

    def test_passing_survival_time_column_to_fit_method(self):
        """Passing survival time to ``fit`` equals keeping it as a column."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv1 = survival.SurvivalRules(**params)
        surv2 = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        X, y = test_case.example_set.values, test_case.example_set.labels
        survival_time_col: pd.Series = X[test_case.survival_time]
        # ``columns=`` already selects the axis; no separate ``axis`` needed.
        X_without_time_col: pd.DataFrame = X.drop(
            columns=[test_case.survival_time]
        )

        surv1.fit(X_without_time_col, y, survival_time=survival_time_col)
        surv2.fit(X, y)

        assert_rules_are_equals(
            [str(r) for r in surv1.model.rules],
            [str(r) for r in surv2.model.rules],
        )

    def test_ibs_calculation(self):
        """``score`` is the same whether time is a column or an argument."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        X, y = test_case.example_set.values, test_case.example_set.labels
        survival_time_col: pd.Series = X[test_case.survival_time]
        X_without_time_col: pd.DataFrame = X.drop(
            columns=[test_case.survival_time]
        )
        surv.fit(X, y)

        ibs: float = surv.score(X, y)
        ibs2: float = surv.score(
            X_without_time_col, y, survival_time=survival_time_col
        )

        self.assertEqual(ibs, ibs2)

    def test_getting_training_dataset_kaplan_meier_estimator(self):
        """The training-set estimator has one entry per unique time value."""
        test_case = get_test_cases("SurvivalLogRankSnCTest")[0]
        params = test_case.induction_params
        surv = survival.SurvivalRules(
            **params, survival_time_attr=test_case.survival_time
        )
        example_set = test_case.example_set
        unique_times_counts: int = example_set.values[
            test_case.survival_time
        ].nunique()
        surv.fit(example_set.values, example_set.labels)

        training_km: KaplanMeierEstimator = surv.get_train_set_kaplan_meier()
        # assertIsInstance also rejects None and reports the offending value.
        self.assertIsInstance(
            training_km,
            KaplanMeierEstimator,
            "Should return KaplanMeierEstimator instance fitted on whole training set",
        )
        self.assertEqual(
            len(training_km.times),
            unique_times_counts,
            "Estimator should contain probabilities for each unique time from the dataset",
        )

    def test_max_rule_count(self):
        """``max_rule_count`` caps the number of induced rules."""
        MAX_RULE_COUNT = 3
        df: pd.DataFrame = read_arff(
            os.path.join(dir_path, "resources", "data", "bmt-train-0.arff")
        )
        X, y = df.drop("survival_status", axis=1), df["survival_status"]
        clf = survival.SurvivalRules(
            survival_time_attr="survival_time",
            max_rule_count=MAX_RULE_COUNT,
        )
        clf.fit(X, y)
        self.assertLessEqual(
            len(clf.model.rules),
            MAX_RULE_COUNT,
            f"Ruleset should contain no more than {MAX_RULE_COUNT} rules according to max_rule_count parameter",
        )
class TestExpertSurvivalRules(unittest.TestCase):
    """Tests for ``survival.ExpertSurvivalRules``.

    ``RuleKit.init()`` is called once for the whole class before any test
    runs.
    """

    @classmethod
    def setUpClass(cls):
        """Initialise RuleKit once for all tests in this class."""
        RuleKit.init()

    @staticmethod
    def _load_bmt_train_set():
        """Read ``bmt-train-0.arff`` and return ``(X, y)``.

        ``y`` is the ``survival_status`` column and ``X`` is every other
        column of the dataset.
        """
        df: pd.DataFrame = read_arff(
            os.path.join(dir_path, "resources", "data", "bmt-train-0.arff")
        )
        return df.drop("survival_status", axis=1), df["survival_status"]

    @staticmethod
    def _make_expert_only_classifier():
        """Build the classifier configuration shared by the refining tests.

        All ``extend_using_*`` / ``induce_using_*`` switches are off and the
        preferred conditions/attributes per rule are set to zero, with
        ``complementary_conditions`` enabled.
        """
        return survival.ExpertSurvivalRules(
            complementary_conditions=True,
            extend_using_preferred=False,
            extend_using_automatic=False,
            induce_using_preferred=False,
            induce_using_automatic=False,
            preferred_conditions_per_rule=0,
            preferred_attributes_per_rule=0,
            survival_time_attr="survival_time",
        )

    def test_compare_with_java_results(self):
        """Induced rules match the reference report of every test case."""
        test_cases = get_test_cases("SurvivalLogRankExpertSnCTest")

        for index, test_case in enumerate(test_cases):
            # subTest identifies the failing case when the runner supports it.
            with self.subTest(case=index):
                params = test_case.induction_params
                surv = survival.ExpertSurvivalRules(
                    **params,
                    ignore_missing=True,
                    survival_time_attr=test_case.survival_time
                )
                example_set = test_case.example_set
                knowledge = test_case.knowledge
                surv.fit(
                    example_set.values,
                    example_set.labels,
                    expert_rules=knowledge.expert_rules,
                    expert_preferred_conditions=knowledge.expert_preferred_conditions,
                    expert_forbidden_conditions=knowledge.expert_forbidden_conditions,
                )

                expected = test_case.reference_report.rules
                actual = [str(rule) for rule in surv.model.rules]
                assert_rules_are_equals(expected, actual)

    def test_refining_conditions_for_nominal_attributes(self):
        """Expert rules on the nominal ``CMVstatus`` attribute.

        The same classifier is fitted twice: first with an explicit value
        set, then with the ``Any`` placeholder.
        """
        X, y = self._load_bmt_train_set()

        # Run experiment using python API
        clf = self._make_expert_only_classifier()
        clf.fit(X, y, expert_rules=[("expert_rules-1", "IF CMVstatus @= {1} THEN")])

        self.assertEqual(
            ["IF [[CMVstatus = {1}]] THEN "],
            [str(r) for r in clf.model.rules],
            "Ruleset should contain only a single rule configured by expert",
        )

        clf.fit(X, y, expert_rules=[("expert_rules-1", "IF CMVstatus @= Any THEN")])
        self.assertEqual(
            ["IF [[CMVstatus = !{1}]] THEN "],
            [str(r) for r in clf.model.rules],
            (
                "Ruleset should contain only a single rule configured by expert with "
                "a refined condition"
            ),
        )

    def test_refining_conditions_for_numerical_attributes(self):
        """Expert rule with ``Any`` on the numerical ``CD34kgx10d6`` attribute."""
        X, y = self._load_bmt_train_set()

        # Run experiment using python API
        clf = self._make_expert_only_classifier()
        clf.fit(X, y, expert_rules=[("expert_rules-1", "IF CD34kgx10d6 @= Any THEN")])
        self.assertEqual(
            ["IF [[CD34kgx10d6 = (-inf, 11.86)]] THEN "],
            [str(r) for r in clf.model.rules],
            (
                "Ruleset should contain only a single rule configured by expert with "
                "a refined condition"
            ),
        )
# Run the unittest command-line runner when this module is executed directly.
if __name__ == "__main__":
    unittest.main()