Survival analysis

This notebook shows how to use the rulekit package to solve a survival analysis problem on the bmt dataset. The dataset is loaded below from the RuleKit GitHub repository.

This tutorial covers:
- training a model
- changing model hyperparameters
- tuning hyperparameters
- calculating model metrics

Install dependencies

[ ]:
%pip install matplotlib

Summary of the dataset

[ ]:
import pandas as pd
import numpy as np
from rulekit.arff import read_arff

DATASET_URL: str = (
    'https://raw.githubusercontent.com/'
    'adaa-polsl/RuleKit/master/data/bmt/'
    'bmt.arff'
)
data_df: pd.DataFrame = read_arff(DATASET_URL)
data_df['survival_status'] = data_df['survival_status'].astype(int).astype(str)
data_df

     Recipientgender  Stemcellsource   Donorage  Donorage35  IIIV  Gendermatch  DonorABO  RecipientABO  RecipientRh  ABOmatch  ...  extcGvHD  CD34kgx10d6   CD3dCD34  CD3dkgx10d8  Rbodymass  ANCrecovery  PLTrecovery  time_to_aGvHD_III_IV  survival_time  survival_status
  0                1               1  22.830137           0     1            0         1             1            1         0  ...         1         7.20   1.338760         5.38       35.0         19.0         51.0                  32.0          999.0                0
  1                1               0  23.342466           0     1            0        -1            -1            1         0  ...         1         4.50  11.078295         0.41       20.6         16.0         37.0             1000000.0          163.0                1
  2                1               0  26.394521           0     1            0        -1            -1            1         0  ...         1         7.94  19.013230         0.42       23.4         23.0         20.0             1000000.0          435.0                1
  3                0               0  39.684932           1     1            0         1             2            1         1  ...      None         4.25  29.481647         0.14       50.0         23.0         29.0                  19.0           53.0                1
  4                0               1  33.358904           0     0            0         1             2            0         1  ...         1        51.85   3.972255        13.05        9.0         14.0         14.0             1000000.0         2043.0                0
...              ...             ...        ...         ...   ...          ...       ...           ...          ...       ...  ...       ...          ...        ...          ...        ...          ...          ...                   ...            ...              ...
182                1               1  37.575342           1     1            0         1             1            0         0  ...         1        11.08   2.522750         4.39       44.0         15.0         22.0                  16.0          385.0                1
183                0               1  22.895890           0     0            0         1             0            1         1  ...         1         4.64   1.038858         4.47       44.5         12.0         30.0             1000000.0          634.0                1
184                0               1  27.347945           0     1            0         1            -1            1         1  ...         1         7.73   1.635559         4.73       33.0         16.0         16.0             1000000.0         1895.0                0
185                1               1  27.780822           0     1            0         1             0            1         1  ...         0        15.41   8.077770         1.91       24.0         13.0         14.0                  54.0          382.0                1
186                1               1  55.553425           1     1            0         1             2            1         1  ...         1         9.91   0.948135        10.45       37.0         18.0         20.0             1000000.0         1109.0                0

187 rows × 37 columns

[3]:
print("Dataset overview:")
print("Name: bmt")
print(f"Objects number: {data_df.shape[0]}; Attributes number: {data_df.shape[1]}")
print("Basic attribute statistics:")
data_df.describe()
Dataset overview:
Name: bmt
Objects number: 187; Attributes number: 37
Basic attribute statistics:
[3]:
         Donorage  Recipientage  CD34kgx10d6    CD3dCD34  CD3dkgx10d8   Rbodymass     ANCrecovery     PLTrecovery  time_to_aGvHD_III_IV  survival_time
count  187.000000    187.000000   187.000000  182.000000   182.000000  185.000000      187.000000      187.000000            187.000000     187.000000
mean    33.472068      9.931551    11.891781    5.385096     4.745714   35.801081    26752.866310    90937.919786         775408.042781     938.743316
std      8.271826      5.305639     9.914386    9.598716     3.859128   19.650922   161747.200525   288242.407688         418425.252689     849.589495
min     18.646575      0.600000     0.790000    0.204132     0.040000    6.000000        9.000000        9.000000             10.000000       6.000000
25%     27.039726      5.050000     5.350000    1.786683     1.687500   19.000000       13.000000       16.000000        1000000.000000     168.500000
50%     33.550685      9.600000     9.720000    2.734462     4.325000   33.000000       15.000000       21.000000        1000000.000000     676.000000
75%     40.117809     14.050000    15.415000    5.823565     6.785000   50.600000       17.000000       37.000000        1000000.000000    1604.000000
max     55.553425     20.200000    57.780000   99.560970    20.020000  103.400000  1000000.000000  1000000.000000        1000000.000000    3364.000000
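
The summary above reports a maximum of 1000000 for ANCrecovery, PLTrecovery and time_to_aGvHD_III_IV, and the same value already at the 25th percentile of time_to_aGvHD_III_IV, so the means of these columns are dominated by it. A plausible reading is that 1000000 is a placeholder for "not observed" rather than a measurement; that is an assumption here, not something this notebook states. Under that assumption, a minimal sketch of summarising such a column without the placeholder could look like this (the five values are the time_to_aGvHD_III_IV entries of rows 0 to 4 in the table printed earlier; `SENTINEL` and `mean_without_sentinel` are illustrative names):

```python
from statistics import mean

# Assumption: 1000000 marks "not observed" instead of a real duration.
SENTINEL = 1000000.0


def mean_without_sentinel(values, sentinel=SENTINEL):
    """Average only the entries that differ from the placeholder value."""
    observed = [v for v in values if v != sentinel]
    return mean(observed) if observed else None


# time_to_aGvHD_III_IV for rows 0-4 of the dataset preview
sample = [32.0, 1000000.0, 1000000.0, 19.0, 1000000.0]

raw_mean = mean(sample)                     # dominated by the placeholder
clean_mean = mean_without_sentinel(sample)  # mean of 32.0 and 19.0
print(raw_mean, clean_mean)
```

Whether such values should be masked before rule induction depends on how the dataset encodes missing events, so treat this only as a way to read the summary statistics.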

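Helper function for creating a DataFrame of ruleset characteristics

[4]:
def get_ruleset_stats(model) -> pd.DataFrame:
    # Copy the parameters so the model's own attributes are not modified
    params = dict(model.parameters.__dict__)
    # Drop the internal Java handle; it is not a ruleset characteristic
    params.pop('_java_object', None)
    return pd.DataFrame.from_records([{**params, **model.stats.__dict__}])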

Rule induction on full dataset

[5]:
X: pd.DataFrame = data_df.drop(['survival_status'], axis=1)
y: pd.Series = data_df['survival_status']
[6]:
from rulekit.survival import SurvivalRules
from rulekit.rules import RuleSet, SurvivalRule

srv = SurvivalRules(survival_time_attr='survival_time')

srv.fit(X, y)

ruleset: RuleSet[SurvivalRule] = srv.model

predictions: np.ndarray = srv.predict(X)


ruleset_stats = get_ruleset_stats(ruleset)


display(ruleset_stats)
   minimum_covered  maximum_uncovered_fraction  ignore_missing  pruning_enabled  max_growing_condition  time_total_s  time_growing_s  time_pruning_s  rules_count  conditions_per_rule  induced_conditions_per_rule  avg_rule_coverage  avg_rule_precision  avg_rule_quality    pvalue  FDR_pvalue  FWER_pvalue  fraction_significant  fraction_FDR_significant  fraction_FWER_significant
0             0.05                         0.0           False             True                    0.0      1.771417        0.797513        0.902853            5                  3.6                         65.2           0.308021                 1.0          0.999865  0.000135    0.000147     0.000184                   1.0                       1.0                        1.0

Plot predicted estimators for the first five examples

[7]:
import matplotlib.pyplot as plt

for i in range(5):
    plt.step(
        predictions[i]["times"],
        predictions[i]["probabilities"],
        label=f'Example {i}'
    )
plt.legend(title='Example index:')
[Figure: predicted survival curves for the first five examples]
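
Each prediction used in the plot above is indexed with "times" and "probabilities", that is, it describes a step-shaped survival curve. A minimal sketch of reading such a curve follows. It assumes the curve is right-continuous and equals 1.0 before the first time point; those conventions, the sample numbers, and the helper names `survival_at` and `median_survival_time` are assumptions made for illustration, not part of the rulekit API.

```python
from bisect import bisect_right


def survival_at(times, probabilities, t):
    """Value of the step curve at time t (1.0 before the first time point)."""
    idx = bisect_right(times, t)
    return 1.0 if idx == 0 else probabilities[idx - 1]


def median_survival_time(times, probabilities):
    """First time at which the curve drops to 0.5 or below, else None."""
    for time, prob in zip(times, probabilities):
        if prob <= 0.5:
            return time
    return None


# Hypothetical curve in the same shape as one element of `predictions`
curve = {
    "times": [10.0, 50.0, 200.0, 800.0],
    "probabilities": [0.9, 0.7, 0.45, 0.3],
}

p_100 = survival_at(curve["times"], curve["probabilities"], 100.0)
median = median_survival_time(curve["times"], curve["probabilities"])
print(p_100, median)
```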

Plot the rules' Kaplan-Meier estimators on top of the training set estimator

[8]:
from rulekit.kaplan_meier import KaplanMeierEstimator

# plot rules kaplan-meier curves
for i, rule in enumerate(ruleset.rules):
    rule_label: str = f'r{i + 1}'
    rule_km: KaplanMeierEstimator = rule.kaplan_meier_estimator
    plt.step(
        rule_km.times,
        rule_km.probabilities,
        label=rule_label
    )
    print(f'{rule_label}: {rule}')

# plot whole dataset kaplan-meier curve
train_km: KaplanMeierEstimator = srv.get_train_set_kaplan_meier()
plt.step(
    train_km.times,
    train_km.probabilities,
    label='Training set estimator'
)

plt.legend()
plt.show()
r1: IF Donorage = (-inf, 45.16) AND Relapse = {0} AND Recipientage = (-inf, 17.45) THEN
r2: IF Donorage = (-inf, 43.63) AND HLAmismatch = {0} AND Relapse = {1} THEN
r3: IF PLTrecovery = (-inf, 266) AND time_to_aGvHD_III_IV = <12.50, inf) AND ANCrecovery = <10.50, 19.50) AND Rbodymass = (-inf, 69) AND Donorage = (-inf, 44.06) AND Recipientage = <4.60, inf) AND CD34kgx10d6 = (-inf, 16.98) THEN
r4: IF Donorage = <37.16, inf) AND Recipientage = <5.15, inf) AND time_to_aGvHD_III_IV = <23.50, inf) AND CD3dCD34 = <0.90, 73.72) THEN
r5: IF Recipientage = <17.85, 18.85) THEN
[Figure: Kaplan-Meier curves of rules r1 to r5 and of the training set]
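
The estimators plotted above expose `times` and `probabilities`. For readers unfamiliar with how such a curve arises, here is a minimal sketch of the textbook Kaplan-Meier (product-limit) estimator in plain Python. It is not RuleKit's implementation; the function name, the convention that an event flag of 1 means "event observed" and 0 means "censored", and the sample data are all assumptions for illustration.

```python
def kaplan_meier(durations, events):
    """Return (times, probabilities) of the Kaplan-Meier survival estimate.

    events[i] is 1 if the event was observed at durations[i], 0 if censored.
    """
    pairs = sorted(zip(durations, events))
    at_risk = len(pairs)
    survival = 1.0
    times, probabilities = [], []
    i = 0
    while i < len(pairs):
        t = pairs[i][0]
        observed = 0  # events observed exactly at time t
        leaving = 0   # everyone leaving the risk set at time t
        while i < len(pairs) and pairs[i][0] == t:
            observed += pairs[i][1]
            leaving += 1
            i += 1
        if observed > 0:
            # Multiply by the fraction of at-risk subjects that survive t
            survival *= 1.0 - observed / at_risk
            times.append(t)
            probabilities.append(survival)
        at_risk -= leaving
    return times, probabilities


# Hypothetical sample: five subjects, two of them censored
km_times, km_probabilities = kaplan_meier(
    durations=[5, 8, 8, 12, 20],
    events=[1, 1, 0, 1, 0],
)
print(km_times, km_probabilities)
```

Censored subjects never lower the curve themselves; they only shrink the risk set used at later event times.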

Rules evaluation on full set

[9]:
integrated_brier_score = srv.score(X, y)
print(f'Integrated Brier Score: {integrated_brier_score}')
Integrated Brier Score: 0.19651358709002972
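
The Integrated Brier Score is an error measure: 0 corresponds to a perfect prediction and lower values are better. As a rough illustration of what is being averaged, the sketch below computes a simplified version that ignores censoring (no inverse-probability-of-censoring weights) and integrates with the trapezoidal rule over a hand-picked time grid. It is not RuleKit's implementation, and all function names and data are hypothetical.

```python
def brier_score_at(t, event_times, survival_fns):
    """Mean squared gap between 'still alive at t' and the predicted S_i(t)."""
    total = 0.0
    for event_time, survival_fn in zip(event_times, survival_fns):
        still_alive = 1.0 if event_time > t else 0.0
        total += (still_alive - survival_fn(t)) ** 2
    return total / len(event_times)


def integrated_brier_score(grid, event_times, survival_fns):
    """Trapezoidal average of the Brier score over the time grid."""
    scores = [brier_score_at(t, event_times, survival_fns) for t in grid]
    area = sum(
        (grid[k + 1] - grid[k]) * (scores[k] + scores[k + 1]) / 2
        for k in range(len(grid) - 1)
    )
    return area / (grid[-1] - grid[0])


# Hypothetical uncensored event times and evaluation grid
event_times = [10.0, 20.0, 30.0]
grid = [0.0, 10.0, 20.0, 30.0]

# A predictor that always says 0.5, and one that knows each event time
coin_flip = [lambda t: 0.5 for _ in event_times]
perfect = [lambda t, T=T: 1.0 if t < T else 0.0 for T in event_times]

ibs_coin_flip = integrated_brier_score(grid, event_times, coin_flip)
ibs_perfect = integrated_brier_score(grid, event_times, perfect)
print(ibs_coin_flip, ibs_perfect)
```

Under this simplified definition a constant 0.5 prediction scores 0.25, which gives a rough reference point for reading scores such as the one printed above.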

Stratified K-Folds cross-validation

[10]:
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

ruleset_stats = pd.DataFrame()
survival_metrics = []

for train_index, test_index in skf.split(X, y):
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    srv = SurvivalRules(
        survival_time_attr='survival_time'
    )
    srv.fit(X_train, y_train)

    ruleset = srv.model

    ibs: float = srv.score(X_test, y_test)

    survival_metrics.append(ibs)
    ruleset_stats = pd.concat([ruleset_stats, get_ruleset_stats(ruleset)])
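
The loop above relies on StratifiedKFold to keep the proportion of each survival_status value roughly equal across folds. The idea can be sketched in plain Python by dealing the indices of each class round-robin into the folds. This is only an illustration of stratification, not scikit-learn's algorithm (it does no shuffling, for example), and the labels and the function name `stratified_folds` are hypothetical.

```python
from collections import Counter, defaultdict


def stratified_folds(labels, n_splits):
    """Assign indices to folds so each fold gets a similar class mix."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    folds = [[] for _ in range(n_splits)]
    for indices in by_class.values():
        # Deal the members of one class across folds in turn
        for position, index in enumerate(indices):
            folds[position % n_splits].append(index)
    return folds


# Hypothetical status labels: six '1' and four '0'
labels = ['1'] * 6 + ['0'] * 4
folds = stratified_folds(labels, n_splits=2)

fold_counts = [Counter(labels[i] for i in fold) for fold in folds]
print(fold_counts)
```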

Ruleset characteristics (average)

[11]:
display(ruleset_stats.mean())
minimum_covered                 0.050000
maximum_uncovered_fraction      0.000000
ignore_missing                  0.000000
pruning_enabled                 1.000000
max_growing_condition           0.000000
time_total_s                    0.799019
time_growing_s                  0.296248
time_pruning_s                  0.477474
rules_count                     4.000000
conditions_per_rule             2.581667
induced_conditions_per_rule    59.825000
avg_rule_coverage               0.486613
avg_rule_precision              1.000000
avg_rule_quality                0.995955
pvalue                          0.004045
FDR_pvalue                      0.004061
FWER_pvalue                     0.004104
fraction_significant            0.980000
fraction_FDR_significant        0.980000
fraction_FWER_significant       0.980000
dtype: float64

Rules evaluation on dataset (average)

[12]:
print(f'Integrated Brier Score: {np.mean(survival_metrics)}')
Integrated Brier Score: 0.20178456199764142

Hyperparameter tuning

This step may take a while.

[13]:
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV
[14]:
def scorer(estimator: SurvivalRules, X: pd.DataFrame, y: pd.Series) -> float:
    return -1 * estimator.score(X, y)
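
GridSearchCV treats the value returned by a scoring callable as greater-is-better, while the Integrated Brier Score is an error measure where lower is better, which is why the scorer above multiplies it by -1. The selection logic can be sketched without scikit-learn; the scores below are made-up numbers, not results from this notebook.

```python
# Hypothetical cross-validated Integrated Brier Scores per minsupp_new value
ibs_by_minsupp = {1: 0.23, 3: 0.21, 5: 0.22}

# Negate so that "higher is better", mirroring what the scorer does
negated = {param: -ibs for param, ibs in ibs_by_minsupp.items()}

# Picking the maximum of the negated scores selects the lowest error
best_param = max(negated, key=negated.get)
best_ibs = -negated[best_param]  # flip the sign back for reporting
print(best_param, best_ibs)
```

The same sign flip applies when reading `best_score_` from a search that uses such a scorer: the reported value is the negated error.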
[18]:
# define models and parameters
model = SurvivalRules(survival_time_attr='survival_time')

# define grid search
grid = {
    'survival_time_attr': ['survival_time'],
    'minsupp_new': range(1, 10),
}

cv = StratifiedKFold(n_splits=3)
grid_search = GridSearchCV(estimator=model, param_grid=grid, cv=cv, scoring=scorer)
grid_result = grid_search.fit(X, y)

# summarize results; best_score_ is the negated Integrated Brier Score (see scorer)
print(
    'Best Integrated Brier Score: '
    f'{grid_result.best_score_} using {grid_result.best_params_}'
)
Best Integrated Brier Score: -0.21437408819868886 using {'minsupp_new': 3, 'survival_time_attr': 'survival_time'}

Building a model with tuned hyperparameters

Split the dataset into training and test sets (80%/20%)

[19]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, stratify=y)

srv = SurvivalRules(
    survival_time_attr='survival_time',
    minsupp_new=5  # note: the grid search above reported minsupp_new=3 as best
)
srv.fit(X_train, y_train)
ruleset: RuleSet[SurvivalRule] = srv.model
ruleset_stats: pd.DataFrame = get_ruleset_stats(ruleset)

Rules evaluation

[20]:
display(ruleset_stats.iloc[0])
minimum_covered                     5.0
maximum_uncovered_fraction          0.0
ignore_missing                    False
pruning_enabled                    True
max_growing_condition               0.0
time_total_s                   0.594173
time_growing_s                 0.244234
time_pruning_s                 0.312523
rules_count                           4
conditions_per_rule                2.25
induced_conditions_per_rule       55.25
avg_rule_coverage              0.389262
avg_rule_precision                  1.0
avg_rule_quality                    1.0
pvalue                              0.0
FDR_pvalue                          0.0
FWER_pvalue                         0.0
fraction_significant                1.0
fraction_FDR_significant            1.0
fraction_FWER_significant           1.0
Name: 0, dtype: object

Validate the model on the test dataset

[21]:
integrated_brier_score = srv.score(X_test, y_test)
print(f'Integrated Brier Score: {integrated_brier_score}')
Integrated Brier Score: 0.14054870564224475
[22]:
predictions = srv.predict(X_test)
[29]:
for i in range(5):
    plt.step(
        predictions[i]["times"],
        predictions[i]["probabilities"],
        label=f'Example {i}'
    )
plt.legend(title='Example indices:')
plt.title('Predicted Kaplan-Meier curves for 5 examples from test set')
[Figure: predicted Kaplan-Meier curves for 5 examples from the test set]