Coverage for rulekit/survival.py: 82%
131 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1"""Module containing classes for survival analysis and prediction.
2"""
3from __future__ import annotations
5from typing import Optional
6from typing import Tuple
7from typing import Union
9import numpy as np
10import pandas as pd
11from jpype import JClass
12from pydantic import BaseModel # pylint: disable=no-name-in-module
14from rulekit._helpers import ExampleSetFactory
15from rulekit._helpers import PredictionResultMapper
16from rulekit._operator import BaseOperator
17from rulekit._operator import Data
18from rulekit._operator import ExpertKnowledgeOperator
19from rulekit._problem_types import ProblemType
20from rulekit.kaplan_meier import KaplanMeierEstimator
21from rulekit.params import ContrastSetModelParams
22from rulekit.params import DEFAULT_PARAMS_VALUE
23from rulekit.params import ExpertModelParams
24from rulekit.rules import RuleSet
25from rulekit.rules import SurvivalRule
# Default name for the column appended to the training data when survival-time
# values are passed separately to ``fit`` (instead of being referenced through
# the ``survival_time_attr`` parameter).
_DEFAULT_SURVIVAL_TIME_ATTR: str = "survival_time"
class _SurvivalModelsParams(BaseModel):
    """Validation model (pydantic) for the hyperparameters shared by all
    survival rule models in this module."""

    # Name of the column holding survival time data; None is allowed because
    # the survival time may instead be passed directly to ``fit``.
    survival_time_attr: Optional[str]
    minsupp_new: Optional[float] = DEFAULT_PARAMS_VALUE["minsupp_new"]
    max_growing: Optional[float] = DEFAULT_PARAMS_VALUE["max_growing"]
    enable_pruning: Optional[bool] = DEFAULT_PARAMS_VALUE["enable_pruning"]
    ignore_missing: Optional[bool] = DEFAULT_PARAMS_VALUE["ignore_missing"]
    max_uncovered_fraction: Optional[float] = DEFAULT_PARAMS_VALUE[
        "max_uncovered_fraction"
    ]
    select_best_candidate: Optional[bool] = DEFAULT_PARAMS_VALUE[
        "select_best_candidate"
    ]
    complementary_conditions: Optional[bool] = DEFAULT_PARAMS_VALUE[
        "complementary_conditions"
    ]
    max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"]
class _SurvivalExpertModelParams(_SurvivalModelsParams, ExpertModelParams):
    """Validation model combining the common survival hyperparameters with the
    expert-knowledge induction hyperparameters."""

    pass
class _BaseSurvivalRulesModel:
    """Mixin with functionality common to every survival rule model."""

    # Trained rule set produced by ``fit``.
    model: RuleSet[SurvivalRule]

    def get_train_set_kaplan_meier(self) -> KaplanMeierEstimator:
        """Returns train set KaplanMeier estimator

        Returns:
            KaplanMeierEstimator: estimator
        """
        java_estimator = (
            self.model._java_object.getTrainingEstimator()  # pylint: disable=protected-access
        )
        return KaplanMeierEstimator(java_estimator)
class SurvivalRules(BaseOperator, _BaseSurvivalRulesModel):
    """Survival model."""

    __params_class__ = _SurvivalModelsParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments
        self,
        survival_time_attr: Optional[str] = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        survival_time_attr : str
            name of column containing survival time data (use when data passed to model
            is pandas dataframe).
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for classification
            problems); default: 5,
        max_growing : int = 0
            non-negative integer representing maximum number of conditions which can be
            added to the rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit; default: 0
        enable_pruning : bool = True
            enable or disable pruning, default is True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of given attribute is always considered as not fulfilling the
            condition build upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            Floating-point number from [0,1] interval representing maximum fraction of
            examples that may remain uncovered by the rule set, default: 0.0.
        select_best_candidate : bool = False
            Flag determining if best candidate should be selected from growing phase;
            default: False.
        complementary_conditions : bool = False
            If enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            Maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        self._initialize_rulekit()
        self.set_params(
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        # Trained rule set; populated by ``fit``.
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:
        """Set models hyperparameters. Parameters are the same as in constructor."""
        self.survival_time_attr = kwargs.get("survival_time_attr")
        return BaseOperator.set_params(self, **kwargs)

    @staticmethod
    def _append_survival_time_columns(
        values, survival_time: Union[pd.Series, np.ndarray, list]
    ) -> Optional[str]:
        """Attach ``survival_time`` to ``values`` in place and return the name
        of the survival-time attribute."""
        survival_time_attr: str = _DEFAULT_SURVIVAL_TIME_ATTR
        if isinstance(survival_time, pd.Series):
            if survival_time.name is None:
                # unnamed Series: fall back to the default attribute name
                survival_time.name = survival_time_attr
            else:
                survival_time_attr = survival_time.name
            values[survival_time.name] = survival_time
        elif isinstance(survival_time, np.ndarray):
            # BUG FIX: the previous implementation called
            # ``np.append(values, survival_time, axis=1)`` and discarded the
            # result. ``np.append`` returns a *new* array and never modifies its
            # input, so the survival-time column was silently never attached and
            # training failed later on the missing attribute (1-D input even
            # raised an obscure numpy dimension error here). There is no way to
            # append a column to a caller-owned ndarray in place, so fail fast
            # with an actionable message instead.
            raise ValueError(
                "survival_time passed as a numpy array is not supported: "
                "pass it as a pandas Series or a list, or include it in the "
                'dataset and specify the "survival_time_attr" parameter'
            )
        elif isinstance(survival_time, list):
            # list-of-rows input: extend each row in place
            for index, row in enumerate(values):
                row.append(survival_time[index])
        else:
            raise ValueError(
                "Data values must be instance of either pandas DataFrame, numpy array"
                " or list"
            )
        return survival_time_attr

    def _prepare_survival_attribute(
        self, survival_time: Optional[Data], values: Data
    ) -> str:
        """Resolve the survival-time attribute name, attaching the separately
        passed survival-time data to ``values`` when necessary.

        Raises:
            ValueError: when neither ``survival_time`` data nor the
                ``survival_time_attr`` parameter was provided.
        """
        if self.survival_time_attr is None and survival_time is None:
            raise ValueError(
                'No "survival_time" attribute name was specified. '
                + "Specify it using method set_params"
            )
        if survival_time is not None:
            return SurvivalRules._append_survival_time_columns(values, survival_time)
        return self.survival_time_attr

    def fit(  # pylint: disable=arguments-differ
        self, values: Data, labels: Data, survival_time: Data = None
    ) -> SurvivalRules:
        """Train model on given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time: :class:`rulekit.operator.Data`
            data about survival time. Could be omitted when *survival_time_attr*
            parameter was specified.

        Returns
        -------
        self : SurvivalRules
        """
        survival_time_attribute = self._prepare_survival_attribute(
            survival_time, values
        )
        super().fit(values, labels, survival_time_attribute)
        return self

    def predict(self, values: Data) -> np.ndarray:
        """Perform prediction and return estimated survival function for each example.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes

        Returns
        -------
        result : np.ndarray
            Each row represent single example from dataset and contains estimated
            survival function for that example. Estimated survival function is returned
            as a dictionary containing times and corresponding probabilities.
        """
        return PredictionResultMapper.map_survival(super().predict(values))

    def score(self, values: Data, labels: Data, survival_time: Data = None) -> float:
        """Return the Integrated Brier Score on the given dataset and labels
        (event status indicator).

        Integrated Brier Score (IBS) - the Brier score (BS) represents the squared
        difference between true event status at time T and predicted event status at
        that time; the Integrated Brier score summarizes the prediction error over all
        observations and over all times in a test set.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time: :class:`rulekit.operator.Data`
            data about survival time. Could be omitted when *survival_time_attr*
            parameter was specified

        Returns
        -------
        score : float
            Integrated Brier Score of self.predict(values) wrt. labels.
        """
        survival_time_attribute = self._prepare_survival_attribute(
            survival_time, values
        )
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values, labels, survival_time_attribute=survival_time_attribute
        )
        # Run the trained model over the scoring data on the Java side.
        predicted_example_set = (
            self.model._java_object.apply(  # pylint: disable=protected-access
                example_set
            )
        )
        # IBS itself is computed by the RuleKit (Java) performance module.
        IntegratedBrierScore = JClass(  # pylint: disable=invalid-name
            "adaa.analytics.rules.logic.performance.IntegratedBrierScore"
        )
        integrated_brier_score = IntegratedBrierScore()
        ibs = integrated_brier_score.countExample(predicted_example_set).getValue()
        return float(ibs)

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.SURVIVAL
class ExpertSurvivalRules(ExpertKnowledgeOperator, SurvivalRules):
    """Expert Survival model."""

    __params_class__ = _SurvivalExpertModelParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments,too-many-locals
        self,
        survival_time_attr: Optional[str] = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        extend_using_preferred: bool = DEFAULT_PARAMS_VALUE["extend_using_preferred"],
        extend_using_automatic: bool = DEFAULT_PARAMS_VALUE["extend_using_automatic"],
        induce_using_preferred: bool = DEFAULT_PARAMS_VALUE["induce_using_preferred"],
        induce_using_automatic: bool = DEFAULT_PARAMS_VALUE["induce_using_automatic"],
        preferred_conditions_per_rule: int = DEFAULT_PARAMS_VALUE[
            "preferred_conditions_per_rule"
        ],
        preferred_attributes_per_rule: int = DEFAULT_PARAMS_VALUE[
            "preferred_attributes_per_rule"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for classification
            problems); default: 5,
        survival_time_attr : str
            name of column containing survival time data (use when data passed to model
            is pandas dataframe).
        max_growing : int = 0
            non-negative integer representing maximum number of conditions which can be
            added to the rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit; default: 0
        enable_pruning : bool = True
            enable or disable pruning, default is True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of given attribute is always considered as not fulfilling the
            condition build upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            Floating-point number from [0,1] interval representing maximum fraction of
            examples that may remain uncovered by the rule set, default: 0.0.
        select_best_candidate : bool = False
            Flag determining if best candidate should be selected from growing phase;
            default: False.
        complementary_conditions : bool = False
            If enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            Maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
        extend_using_preferred : bool = False
            boolean indicating whether initial rules should be extended with a use of
            preferred conditions and attributes; default is False
        extend_using_automatic : bool = False
            boolean indicating whether initial rules should be extended with a use of
            automatic conditions and attributes; default is False
        induce_using_preferred : bool = False
            boolean indicating whether new rules should be induced with a use of
            preferred conditions and attributes; default is False
        induce_using_automatic : bool = False
            boolean indicating whether new rules should be induced with a use of
            automatic conditions and attributes; default is False
        preferred_conditions_per_rule : int = None
            maximum number of preferred conditions per rule; default: unlimited,
        preferred_attributes_per_rule : int = None
            maximum number of preferred attributes per rule; default: unlimited.
        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        self._initialize_rulekit()
        self.set_params(
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            extend_using_preferred=extend_using_preferred,
            extend_using_automatic=extend_using_automatic,
            induce_using_preferred=induce_using_preferred,
            induce_using_automatic=induce_using_automatic,
            preferred_conditions_per_rule=preferred_conditions_per_rule,
            preferred_attributes_per_rule=preferred_attributes_per_rule,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        # Trained rule set; populated by ``fit``.
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:  # pylint: disable=arguments-differ
        """Set models hyperparameters. Parameters are the same as in constructor."""
        # BUG FIX: use .get() instead of kwargs[...] so that omitting
        # "survival_time_attr" does not raise KeyError; this also matches the
        # behavior of SurvivalRules.set_params.
        self.survival_time_attr = kwargs.get("survival_time_attr")
        return ExpertKnowledgeOperator.set_params(self, **kwargs)

    def fit(  # pylint: disable=arguments-differ,too-many-arguments
        self,
        values: Data,
        labels: Data,
        survival_time: Data = None,
        expert_rules: list[Union[str, tuple[str, str]]] = None,
        expert_preferred_conditions: list[Union[str, tuple[str, str]]] = None,
        expert_forbidden_conditions: list[Union[str, tuple[str, str]]] = None,
    ) -> ExpertSurvivalRules:
        """Train model on given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : Data
            survival status
        survival_time: :class:`rulekit.operator.Data`
            data about survival time. Could be omitted when *survival_time_attr*
            parameter was specified.
        expert_rules : List[Union[str, Tuple[str, str]]]
            set of initial rules, either passed as a list of strings representing rules
            or as list of tuples where first element is name of the rule and second one
            is rule string.
        expert_preferred_conditions : List[Union[str, Tuple[str, str]]]
            multiset of preferred conditions (used also for specifying preferred
            attributes by using special value Any). Either passed as a list of strings
            representing rules or as list of tuples where first element is name of the
            rule and second one is rule string.
        expert_forbidden_conditions : List[Union[str, Tuple[str, str]]]
            set of forbidden conditions (used also for specifying forbidden attributes
            by using special value Any). Either passed as a list of strings representing
            rules or as list of tuples where first element is name of the rule and
            second one is rule string.

        Returns
        -------
        self : ExpertSurvivalRules
        """
        survival_time_attribute = SurvivalRules._prepare_survival_attribute(
            self, survival_time, values
        )
        return ExpertKnowledgeOperator.fit(
            self,
            values=values,
            labels=labels,
            survival_time_attribute=survival_time_attribute,
            expert_rules=expert_rules,
            expert_preferred_conditions=expert_preferred_conditions,
            expert_forbidden_conditions=expert_forbidden_conditions,
        )

    def predict(self, values: Data) -> np.ndarray:
        """Perform prediction and return estimated survival function for each
        example (see :meth:`SurvivalRules.predict`)."""
        return PredictionResultMapper.map_survival(
            ExpertKnowledgeOperator.predict(self, values)
        )

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.SURVIVAL
class _SurvivalContrastSetModelParams(ContrastSetModelParams, _SurvivalModelsParams):
    """Validation model combining the contrast-set hyperparameters with the
    common survival hyperparameters."""

    pass
class ContrastSetSurvivalRules(BaseOperator, _BaseSurvivalRulesModel):
    """Contrast set survival model."""

    __params_class__ = _SurvivalContrastSetModelParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments
        self,
        minsupp_all: Tuple[float, float, float, float] = DEFAULT_PARAMS_VALUE[
            "minsupp_all"
        ],
        max_neg2pos: float = DEFAULT_PARAMS_VALUE["max_neg2pos"],
        max_passes_count: int = DEFAULT_PARAMS_VALUE["max_passes_count"],
        penalty_strength: float = DEFAULT_PARAMS_VALUE["penalty_strength"],
        penalty_saturation: float = DEFAULT_PARAMS_VALUE["penalty_saturation"],
        survival_time_attr: Optional[str] = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        minsupp_all: Tuple[float, float, float, float]
            a minimum positive support of a contrast set (p/P). When multiple values are
            specified, a metainduction is performed; Default and recommended sequence
            is: 0.8, 0.5, 0.2, 0.1
        max_neg2pos: float
            a maximum ratio of negative to positive supports (nP/pN); Default is 0.5
        max_passes_count: int
            a maximum number of sequential covering passes for a single minsupp-all;
            Default is 5
        penalty_strength: float
            (s) - penalty strength; Default is 0.5
        penalty_saturation: float
            the value of p_new / P at which penalty reward saturates; Default is 0.2.
        survival_time_attr : str
            name of column containing survival time data (use when data passed to model
            is pandas dataframe).
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for classification
            problems); default: 5,
        max_growing : int = 0
            non-negative integer representing maximum number of conditions which can be
            added to the rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit; default: 0
        enable_pruning : bool = True
            enable or disable pruning, default is True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of given attribute is always considered as not fulfilling the
            condition build upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            Floating-point number from [0,1] interval representing maximum fraction of
            examples that may remain uncovered by the rule set, default: 0.0.
        select_best_candidate : bool = False
            Flag determining if best candidate should be selected from growing phase;
            default: False.
        complementary_conditions : bool = False
            If enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            Maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        # Group attribute used for contrast-set induction; set by ``fit``.
        self.contrast_attribute: str = None
        self._initialize_rulekit()
        self.set_params(
            minsupp_all=minsupp_all,
            max_neg2pos=max_neg2pos,
            max_passes_count=max_passes_count,
            penalty_strength=penalty_strength,
            penalty_saturation=penalty_saturation,
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        # Trained rule set; populated by ``fit``.
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:
        """Set models hyperparameters. Parameters are the same as in constructor."""
        # BUG FIX: use .get() instead of kwargs[...] so that omitting
        # "survival_time_attr" does not raise KeyError; this also matches the
        # behavior of SurvivalRules.set_params.
        self.survival_time_attr = kwargs.get("survival_time_attr")
        return BaseOperator.set_params(self, **kwargs)

    def fit(  # pylint: disable=arguments-renamed
        self,
        values: Data,
        labels: Data,
        contrast_attribute: str,
        survival_time: Data = None,
    ) -> ContrastSetSurvivalRules:
        """Train model on given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        contrast_attribute: str
            group attribute
        survival_time: :class:`rulekit.operator.Data`
            data about survival time. Could be omitted when *survival_time_attr*
            parameter was specified.

        Returns
        -------
        self : ContrastSetSurvivalRules
        """
        survival_time_attribute = SurvivalRules._prepare_survival_attribute(  # pylint: disable=protected-access
            self, survival_time, values
        )
        super().fit(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self.contrast_attribute = contrast_attribute
        return self

    def predict(self, values: Data) -> np.ndarray:
        """Perform prediction and return estimated survival function for each example.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes

        Returns
        -------
        result : np.ndarray
            Each row represent single example from dataset and contains estimated
            survival function for that example. Estimated survival function is returned
            as a dictionary containing times and corresponding probabilities.
        """
        return PredictionResultMapper.map_survival(super().predict(values))

    def score(self, values: Data, labels: Data, survival_time: Data = None) -> float:
        """Return the Integrated Brier Score on the given dataset and
        labels (event status indicator).

        Integrated Brier Score (IBS) - the Brier score (BS) represents the squared
        difference between true event status at time T and predicted event status at
        that time; the Integrated Brier score summarizes the prediction error over all
        observations and over all times in a test set.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time: :class:`rulekit.operator.Data`
            data about survival time. Could be omitted when *survival_time_attr*
            parameter was specified

        Returns
        -------
        score : float
            Integrated Brier Score of self.predict(values) wrt. labels.
        """
        # Delegate to the plain survival model: IBS computation is identical.
        return SurvivalRules.score(self, values, labels, survival_time=survival_time)

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.CONTRAST_SURVIVAL