Coverage for rulekit/survival.py: 82%

131 statements

coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

"""Module containing classes for survival analysis and prediction."""

from __future__ import annotations

from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import pandas as pd
from jpype import JClass
from pydantic import BaseModel  # pylint: disable=no-name-in-module

from rulekit._helpers import ExampleSetFactory
from rulekit._helpers import PredictionResultMapper
from rulekit._operator import BaseOperator
from rulekit._operator import Data
from rulekit._operator import ExpertKnowledgeOperator
from rulekit._problem_types import ProblemType
from rulekit.kaplan_meier import KaplanMeierEstimator
from rulekit.params import ContrastSetModelParams
from rulekit.params import DEFAULT_PARAMS_VALUE
from rulekit.params import ExpertModelParams
from rulekit.rules import RuleSet
from rulekit.rules import SurvivalRule

_DEFAULT_SURVIVAL_TIME_ATTR: str = "survival_time"


class _SurvivalModelsParams(BaseModel):
    survival_time_attr: Optional[str]
    minsupp_new: Optional[float] = DEFAULT_PARAMS_VALUE["minsupp_new"]
    max_growing: Optional[float] = DEFAULT_PARAMS_VALUE["max_growing"]
    enable_pruning: Optional[bool] = DEFAULT_PARAMS_VALUE["enable_pruning"]
    ignore_missing: Optional[bool] = DEFAULT_PARAMS_VALUE["ignore_missing"]
    max_uncovered_fraction: Optional[float] = DEFAULT_PARAMS_VALUE[
        "max_uncovered_fraction"
    ]
    select_best_candidate: Optional[bool] = DEFAULT_PARAMS_VALUE[
        "select_best_candidate"
    ]
    complementary_conditions: Optional[bool] = DEFAULT_PARAMS_VALUE[
        "complementary_conditions"
    ]
    max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"]


class _SurvivalExpertModelParams(_SurvivalModelsParams, ExpertModelParams):
    pass


class _BaseSurvivalRulesModel:

    model: RuleSet[SurvivalRule]

    def get_train_set_kaplan_meier(self) -> KaplanMeierEstimator:
        """Return the Kaplan-Meier estimator calculated on the training set.

        Returns:
            KaplanMeierEstimator: estimator
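
        Example:
            A minimal, illustrative sketch (assumes ``X`` and ``y`` are a feature
            table and a survival-status column prepared by the caller, and that
            ``X`` contains a ``survival_time`` column; the names are hypothetical):

            >>> model = SurvivalRules(survival_time_attr="survival_time")
            >>> model = model.fit(X, y)
            >>> km = model.get_train_set_kaplan_meier()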

        """
        return KaplanMeierEstimator(
            self.model._java_object.getTrainingEstimator()  # pylint: disable=protected-access
        )


class SurvivalRules(BaseOperator, _BaseSurvivalRulesModel):
    """Survival model."""

    __params_class__ = _SurvivalModelsParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments
        self,
        survival_time_attr: str = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        survival_time_attr : str
            name of the column containing survival time data (use when the data
            passed to the model is a pandas DataFrame).
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for
            classification problems); default: 5.0
        max_growing : int = 0
            non-negative integer representing the maximum number of conditions which
            can be added to a rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit;
            default: 0
        enable_pruning : bool = True
            enable or disable pruning; default: True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of a given attribute is always considered as not fulfilling
            the condition built upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            floating-point number from the [0,1] interval representing the maximum
            fraction of examples that may remain uncovered by the rule set;
            default: 0.0.
        select_best_candidate : bool = False
            flag determining if the best candidate should be selected from the
            growing phase; default: False.
        complementary_conditions : bool = False
            if enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
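
        Examples
        --------
        A minimal, illustrative configuration (the values below are arbitrary
        placeholders, not recommended settings):

        >>> from rulekit.survival import SurvivalRules
        >>> surv = SurvivalRules(
        ...     survival_time_attr="survival_time",
        ...     minsupp_new=5.0,
        ...     max_growing=0,
        ... )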

        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        self._initialize_rulekit()
        self.set_params(
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:
        """Set model hyperparameters. Parameters are the same as in the constructor."""
        self.survival_time_attr = kwargs.get("survival_time_attr")
        return BaseOperator.set_params(self, **kwargs)

    @staticmethod
    def _append_survival_time_columns(
        values, survival_time: Union[pd.Series, np.ndarray, list]
    ) -> Optional[str]:
        survival_time_attr: str = _DEFAULT_SURVIVAL_TIME_ATTR
        if isinstance(survival_time, pd.Series):
            if survival_time.name is None:
                survival_time.name = survival_time_attr
            else:
                survival_time_attr = survival_time.name
            values[survival_time.name] = survival_time
        elif isinstance(survival_time, np.ndarray):
            # Note: np.append returns a new array rather than modifying `values`
            # in place, so the appended column is not visible to the caller.
            np.append(values, survival_time, axis=1)
        elif isinstance(survival_time, list):
            for index, row in enumerate(values):
                row.append(survival_time[index])
        else:
            raise ValueError(
                "Data values must be instance of either pandas DataFrame, numpy array"
                " or list"
            )
        return survival_time_attr

    def _prepare_survival_attribute(
        self, survival_time: Optional[Data], values: Data
    ) -> str:
        if self.survival_time_attr is None and survival_time is None:
            raise ValueError(
                'No "survival_time" attribute name was specified. '
                + "Specify it using the set_params method"
            )
        if survival_time is not None:
            return SurvivalRules._append_survival_time_columns(values, survival_time)
        return self.survival_time_attr

    def fit(
        self, values: Data, labels: Data, survival_time: Data = None
    ) -> SurvivalRules:  # pylint: disable=arguments-differ
        """Train the model on the given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time : :class:`rulekit.operator.Data`
            survival time data. May be omitted when the *survival_time_attr*
            parameter was specified.

        Returns
        -------
        self : SurvivalRules
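
        Examples
        --------
        An illustrative sketch, assuming ``df`` is a pandas DataFrame whose
        ``survival_status`` column holds the event indicator and whose
        ``survival_time`` column holds observation times (both column names are
        hypothetical):

        >>> surv = SurvivalRules(survival_time_attr="survival_time")
        >>> X = df.drop(columns=["survival_status"])
        >>> y = df["survival_status"]
        >>> surv = surv.fit(X, y)

        Alternatively, survival times can be passed explicitly instead of naming
        the column:

        >>> surv = SurvivalRules()
        >>> surv = surv.fit(
        ...     X.drop(columns=["survival_time"]), y,
        ...     survival_time=df["survival_time"],
        ... )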

        """
        survival_time_attribute = self._prepare_survival_attribute(
            survival_time, values
        )
        super().fit(values, labels, survival_time_attribute)
        return self

    def predict(self, values: Data) -> np.ndarray:
        """Perform prediction and return the estimated survival function for each example.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes

        Returns
        -------
        result : np.ndarray
            Each row represents a single example from the dataset and contains the
            estimated survival function for that example. The estimated survival
            function is returned as a dictionary containing times and the
            corresponding probabilities.
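
        Examples
        --------
        An illustrative sketch for a fitted model ``surv`` (the exact structure of
        each entry comes from ``PredictionResultMapper.map_survival`` and should be
        inspected on a real prediction):

        >>> predictions = surv.predict(X)
        >>> first = predictions[0]  # estimated survival function for the first example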

        """
        return PredictionResultMapper.map_survival(super().predict(values))

    def score(self, values: Data, labels: Data, survival_time: Data = None) -> float:
        """Return the Integrated Brier Score on the given dataset and labels
        (event status indicator).

        Integrated Brier Score (IBS) - the Brier score (BS) represents the squared
        difference between the true event status at time T and the predicted event
        status at that time; the Integrated Brier Score summarizes the prediction
        error over all observations and over all times in a test set.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time : :class:`rulekit.operator.Data`
            survival time data. May be omitted when the *survival_time_attr*
            parameter was specified.

        Returns
        -------
        score : float
            Integrated Brier Score of self.predict(values) with respect to labels.
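
        Examples
        --------
        An illustrative sketch for a fitted model ``surv`` and held-out data
        ``X_test`` / ``y_test`` (hypothetical names; lower IBS values indicate
        better predictions):

        >>> ibs = surv.score(X_test, y_test)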

        """
        survival_time_attribute = self._prepare_survival_attribute(
            survival_time, values
        )
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values, labels, survival_time_attribute=survival_time_attribute
        )

        predicted_example_set = (
            self.model._java_object.apply(  # pylint: disable=protected-access
                example_set
            )
        )

        IntegratedBrierScore = JClass(  # pylint: disable=invalid-name
            "adaa.analytics.rules.logic.performance.IntegratedBrierScore"
        )
        integrated_brier_score = IntegratedBrierScore()
        ibs = integrated_brier_score.countExample(predicted_example_set).getValue()
        return float(ibs)

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.SURVIVAL


class ExpertSurvivalRules(ExpertKnowledgeOperator, SurvivalRules):
    """Expert Survival model."""

    __params_class__ = _SurvivalExpertModelParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments,too-many-locals
        self,
        survival_time_attr: str = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        extend_using_preferred: bool = DEFAULT_PARAMS_VALUE["extend_using_preferred"],
        extend_using_automatic: bool = DEFAULT_PARAMS_VALUE["extend_using_automatic"],
        induce_using_preferred: bool = DEFAULT_PARAMS_VALUE["induce_using_preferred"],
        induce_using_automatic: bool = DEFAULT_PARAMS_VALUE["induce_using_automatic"],
        preferred_conditions_per_rule: int = DEFAULT_PARAMS_VALUE[
            "preferred_conditions_per_rule"
        ],
        preferred_attributes_per_rule: int = DEFAULT_PARAMS_VALUE[
            "preferred_attributes_per_rule"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        survival_time_attr : str
            name of the column containing survival time data (use when the data
            passed to the model is a pandas DataFrame).
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for
            classification problems); default: 5.0
        max_growing : int = 0
            non-negative integer representing the maximum number of conditions which
            can be added to a rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit;
            default: 0
        enable_pruning : bool = True
            enable or disable pruning; default: True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of a given attribute is always considered as not fulfilling
            the condition built upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            floating-point number from the [0,1] interval representing the maximum
            fraction of examples that may remain uncovered by the rule set;
            default: 0.0.
        select_best_candidate : bool = False
            flag determining if the best candidate should be selected from the
            growing phase; default: False.
        complementary_conditions : bool = False
            if enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
        extend_using_preferred : bool = False
            boolean indicating whether initial rules should be extended with a use of
            preferred conditions and attributes; default: False
        extend_using_automatic : bool = False
            boolean indicating whether initial rules should be extended with a use of
            automatic conditions and attributes; default: False
        induce_using_preferred : bool = False
            boolean indicating whether new rules should be induced with a use of
            preferred conditions and attributes; default: False
        induce_using_automatic : bool = False
            boolean indicating whether new rules should be induced with a use of
            automatic conditions and attributes; default: False
        preferred_conditions_per_rule : int = None
            maximum number of preferred conditions per rule; default: unlimited
        preferred_attributes_per_rule : int = None
            maximum number of preferred attributes per rule; default: unlimited.
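
        Examples
        --------
        An illustrative configuration enabling expert-guided induction (flag values
        are placeholders, not recommendations):

        >>> from rulekit.survival import ExpertSurvivalRules
        >>> expert_surv = ExpertSurvivalRules(
        ...     survival_time_attr="survival_time",
        ...     induce_using_preferred=True,
        ...     induce_using_automatic=True,
        ... )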

        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        self._initialize_rulekit()
        self.set_params(
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            extend_using_preferred=extend_using_preferred,
            extend_using_automatic=extend_using_automatic,
            induce_using_preferred=induce_using_preferred,
            induce_using_automatic=induce_using_automatic,
            preferred_conditions_per_rule=preferred_conditions_per_rule,
            preferred_attributes_per_rule=preferred_attributes_per_rule,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:  # pylint: disable=arguments-differ
        self.survival_time_attr = kwargs["survival_time_attr"]
        return ExpertKnowledgeOperator.set_params(self, **kwargs)

    def fit(  # pylint: disable=arguments-differ,too-many-arguments
        self,
        values: Data,
        labels: Data,
        survival_time: Data = None,
        expert_rules: list[Union[str, tuple[str, str]]] = None,
        expert_preferred_conditions: list[Union[str, tuple[str, str]]] = None,
        expert_forbidden_conditions: list[Union[str, tuple[str, str]]] = None,
    ) -> ExpertSurvivalRules:
        """Train the model on the given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time : :class:`rulekit.operator.Data`
            survival time data. May be omitted when the *survival_time_attr*
            parameter was specified.
        expert_rules : List[Union[str, Tuple[str, str]]]
            set of initial rules, either passed as a list of strings representing
            rules or as a list of tuples where the first element is the name of the
            rule and the second one is the rule string.
        expert_preferred_conditions : List[Union[str, Tuple[str, str]]]
            multiset of preferred conditions (used also for specifying preferred
            attributes by using the special value Any). Either passed as a list of
            strings representing rules or as a list of tuples where the first element
            is the name of the rule and the second one is the rule string.
        expert_forbidden_conditions : List[Union[str, Tuple[str, str]]]
            set of forbidden conditions (used also for specifying forbidden
            attributes by using the special value Any). Either passed as a list of
            strings representing rules or as a list of tuples where the first element
            is the name of the rule and the second one is the rule string.

        Returns
        -------
        self : ExpertSurvivalRules
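
        Examples
        --------
        An illustrative sketch of passing expert knowledge as named rule strings,
        assuming ``expert_surv`` is a configured :class:`ExpertSurvivalRules`
        instance and ``X`` / ``y`` are training attributes and survival status
        (the attribute names and conditions are hypothetical; consult the RuleKit
        documentation for the exact rule grammar):

        >>> expert_rules = [
        ...     ("rule-0", "IF [[treatment = {A}]] THEN survival_status = {NaN}"),
        ... ]
        >>> expert_preferred = [
        ...     ("pref-0", "1: IF [[age = Any]] THEN survival_status = {NaN}"),
        ... ]
        >>> expert_surv = expert_surv.fit(
        ...     X, y,
        ...     expert_rules=expert_rules,
        ...     expert_preferred_conditions=expert_preferred,
        ... )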

        """
        survival_time_attribute = SurvivalRules._prepare_survival_attribute(
            self, survival_time, values
        )
        return ExpertKnowledgeOperator.fit(
            self,
            values=values,
            labels=labels,
            survival_time_attribute=survival_time_attribute,
            expert_rules=expert_rules,
            expert_preferred_conditions=expert_preferred_conditions,
            expert_forbidden_conditions=expert_forbidden_conditions,
        )

    def predict(self, values: Data) -> np.ndarray:
        return PredictionResultMapper.map_survival(
            ExpertKnowledgeOperator.predict(self, values)
        )

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.SURVIVAL


class _SurvivalContrastSetModelParams(ContrastSetModelParams, _SurvivalModelsParams):
    pass


class ContrastSetSurvivalRules(BaseOperator, _BaseSurvivalRulesModel):
    """Contrast set survival model."""

    __params_class__ = _SurvivalContrastSetModelParams

    def __init__(  # pylint: disable=super-init-not-called,too-many-arguments
        self,
        minsupp_all: Tuple[float, float, float, float] = DEFAULT_PARAMS_VALUE[
            "minsupp_all"
        ],
        max_neg2pos: float = DEFAULT_PARAMS_VALUE["max_neg2pos"],
        max_passes_count: int = DEFAULT_PARAMS_VALUE["max_passes_count"],
        penalty_strength: float = DEFAULT_PARAMS_VALUE["penalty_strength"],
        penalty_saturation: float = DEFAULT_PARAMS_VALUE["penalty_saturation"],
        survival_time_attr: str = None,
        minsupp_new: float = DEFAULT_PARAMS_VALUE["minsupp_new"],
        max_growing: int = DEFAULT_PARAMS_VALUE["max_growing"],
        enable_pruning: bool = DEFAULT_PARAMS_VALUE["enable_pruning"],
        ignore_missing: bool = DEFAULT_PARAMS_VALUE["ignore_missing"],
        max_uncovered_fraction: float = DEFAULT_PARAMS_VALUE["max_uncovered_fraction"],
        select_best_candidate: bool = DEFAULT_PARAMS_VALUE["select_best_candidate"],
        complementary_conditions: bool = DEFAULT_PARAMS_VALUE[
            "complementary_conditions"
        ],
        max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"],
    ):
        """
        Parameters
        ----------
        minsupp_all : Tuple[float, float, float, float]
            a minimum positive support of a contrast set (p/P). When multiple values
            are specified, a metainduction is performed; the default and recommended
            sequence is: 0.8, 0.5, 0.2, 0.1
        max_neg2pos : float
            a maximum ratio of negative to positive supports (nP/pN); default: 0.5
        max_passes_count : int
            a maximum number of sequential covering passes for a single minsupp-all;
            default: 5
        penalty_strength : float
            (s) - penalty strength; default: 0.5
        penalty_saturation : float
            the value of p_new / P at which the penalty reward saturates;
            default: 0.2.
        survival_time_attr : str
            name of the column containing survival time data (use when the data
            passed to the model is a pandas DataFrame).
        minsupp_new : float = 5.0
            a minimum number (or fraction, if value < 1.0) of previously uncovered
            examples to be covered by a new rule (positive examples for
            classification problems); default: 5.0
        max_growing : int = 0
            non-negative integer representing the maximum number of conditions which
            can be added to a rule in the growing phase (use this parameter for large
            datasets if execution time is prohibitive); 0 indicates no limit;
            default: 0
        enable_pruning : bool = True
            enable or disable pruning; default: True.
        ignore_missing : bool = False
            boolean telling whether missing values should be ignored (by default, a
            missing value of a given attribute is always considered as not fulfilling
            the condition built upon that attribute); default: False.
        max_uncovered_fraction : float = 0.0
            floating-point number from the [0,1] interval representing the maximum
            fraction of examples that may remain uncovered by the rule set;
            default: 0.0.
        select_best_candidate : bool = False
            flag determining if the best candidate should be selected from the
            growing phase; default: False.
        complementary_conditions : bool = False
            if enabled, complementary conditions in the form a = !{value} for nominal
            attributes are supported.
        max_rule_count : int = 0
            maximum number of rules to be generated (for classification data sets it
            applies to a single class); 0 indicates no limit.
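
        Examples
        --------
        An illustrative configuration (values are placeholders, not
        recommendations):

        >>> from rulekit.survival import ContrastSetSurvivalRules
        >>> cs_surv = ContrastSetSurvivalRules(
        ...     minsupp_all=(0.8, 0.5, 0.2, 0.1),
        ...     survival_time_attr="survival_time",
        ... )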

        """
        self._params = None
        self._rule_generator = None
        self._configurator = None
        self.contrast_attribute: str = None
        self._initialize_rulekit()
        self.set_params(
            minsupp_all=minsupp_all,
            max_neg2pos=max_neg2pos,
            max_passes_count=max_passes_count,
            penalty_strength=penalty_strength,
            penalty_saturation=penalty_saturation,
            survival_time_attr=survival_time_attr,
            minsupp_new=minsupp_new,
            max_growing=max_growing,
            enable_pruning=enable_pruning,
            ignore_missing=ignore_missing,
            max_uncovered_fraction=max_uncovered_fraction,
            select_best_candidate=select_best_candidate,
            complementary_conditions=complementary_conditions,
            max_rule_count=max_rule_count,
        )
        self.model: RuleSet[SurvivalRule] = None

    def set_params(self, **kwargs) -> object:
        """Set model hyperparameters. Parameters are the same as in the constructor."""
        # params validation
        self.survival_time_attr = kwargs["survival_time_attr"]
        return BaseOperator.set_params(self, **kwargs)

    def fit(  # pylint: disable=arguments-renamed
        self,
        values: Data,
        labels: Data,
        contrast_attribute: str,
        survival_time: Data = None,
    ) -> ContrastSetSurvivalRules:
        """Train the model on the given dataset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        contrast_attribute : str
            group attribute
        survival_time : :class:`rulekit.operator.Data`
            survival time data. May be omitted when the *survival_time_attr*
            parameter was specified.

        Returns
        -------
        self : ContrastSetSurvivalRules
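
        Examples
        --------
        An illustrative sketch, assuming ``cs_surv`` is a configured
        :class:`ContrastSetSurvivalRules` instance and ``df`` contains a
        ``survival_status`` label, a ``survival_time`` column and a nominal
        ``group`` column used as the contrast attribute (all names are
        hypothetical):

        >>> X = df.drop(columns=["survival_status"])
        >>> y = df["survival_status"]
        >>> cs_surv = cs_surv.fit(X, y, contrast_attribute="group")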

        """
        survival_time_attribute = SurvivalRules._prepare_survival_attribute(  # pylint: disable=protected-access
            self, survival_time, values
        )
        super().fit(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self.contrast_attribute = contrast_attribute
        return self

    def predict(self, values: Data) -> np.ndarray:
        """Perform prediction and return the estimated survival function for each example.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes

        Returns
        -------
        result : np.ndarray
            Each row represents a single example from the dataset and contains the
            estimated survival function for that example. The estimated survival
            function is returned as a dictionary containing times and the
            corresponding probabilities.
        """
        return PredictionResultMapper.map_survival(super().predict(values))

    def score(self, values: Data, labels: Data, survival_time: Data = None) -> float:
        """Return the Integrated Brier Score on the given dataset and labels
        (event status indicator).

        Integrated Brier Score (IBS) - the Brier score (BS) represents the squared
        difference between the true event status at time T and the predicted event
        status at that time; the Integrated Brier Score summarizes the prediction
        error over all observations and over all times in a test set.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            attributes
        labels : :class:`rulekit.operator.Data`
            survival status
        survival_time : :class:`rulekit.operator.Data`
            survival time data. May be omitted when the *survival_time_attr*
            parameter was specified.

        Returns
        -------
        score : float
            Integrated Brier Score of self.predict(values) with respect to labels.
        """
        return SurvivalRules.score(self, values, labels, survival_time=survival_time)

    def _get_problem_type(self) -> ProblemType:
        return ProblemType.CONTRAST_SURVIVAL