{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-Uy-yBGsd9W1"
},
"source": [
"# Survival analysis"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook presents an example of using the RuleKit package to solve a survival analysis problem on the `bmt` dataset. You can download the dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/bmt/bmt.arff). \n",
"\n",
"This tutorial will cover topics such as: \n",
"- training model \n",
"- changing model hyperparameters \n",
"- hyperparameters tuning \n",
"- calculating metrics for model \n",
"- getting RuleKit built-in ruleset statistics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KjtU7PA8eOTr"
},
"source": [
"## Summary of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Tp1TpfCkd58n"
},
"outputs": [],
"source": [
"from scipy.io import arff\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"datasets_path = \"\"\n",
"\n",
"file_name = 'bmt.arff'\n",
"\n",
"# Open the ARFF file in a context manager so the handle is closed as soon as parsing finishes.\n",
"with open(datasets_path + file_name, 'r', encoding=\"cp1252\") as arff_file:\n",
"    data_df = pd.DataFrame(arff.loadarff(arff_file)[0])\n",
"\n",
"# Fix the encoding problem: the object-typed columns hold bytes after loading,\n",
"# so decode them to str using the same cp1252 codec the file was opened with.\n",
"tmp_df = data_df.select_dtypes([object])\n",
"tmp_df = tmp_df.stack().str.decode(\"cp1252\").unstack()\n",
"for col in tmp_df:\n",
"    data_df[col] = tmp_df[col]\n",
"\n",
"# Treat '?' entries as missing values.\n",
"data_df = data_df.replace({'?': None})"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Recipientgender | \n",
" Stemcellsource | \n",
" Donorage | \n",
" Donorage35 | \n",
" IIIV | \n",
" Gendermatch | \n",
" DonorABO | \n",
" RecipientABO | \n",
" RecipientRh | \n",
" ABOmatch | \n",
" ... | \n",
" extcGvHD | \n",
" CD34kgx10d6 | \n",
" CD3dCD34 | \n",
" CD3dkgx10d8 | \n",
" Rbodymass | \n",
" ANCrecovery | \n",
" PLTrecovery | \n",
" time_to_aGvHD_III_IV | \n",
" survival_time | \n",
" survival_status | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1 | \n",
" 1 | \n",
" 22.830137 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 1 | \n",
" 7.20 | \n",
" 1.338760 | \n",
" 5.38 | \n",
" 35.0 | \n",
" 19.0 | \n",
" 51.0 | \n",
" 32.0 | \n",
" 999.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" 0 | \n",
" 23.342466 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" -1 | \n",
" -1 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 1 | \n",
" 4.50 | \n",
" 11.078295 | \n",
" 0.41 | \n",
" 20.6 | \n",
" 16.0 | \n",
" 37.0 | \n",
" 1000000.0 | \n",
" 163.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1 | \n",
" 0 | \n",
" 26.394521 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" -1 | \n",
" -1 | \n",
" 1 | \n",
" 0 | \n",
" ... | \n",
" 1 | \n",
" 7.94 | \n",
" 19.013230 | \n",
" 0.42 | \n",
" 23.4 | \n",
" 23.0 | \n",
" 20.0 | \n",
" 1000000.0 | \n",
" 435.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0 | \n",
" 0 | \n",
" 39.684932 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 1 | \n",
" 1 | \n",
" ... | \n",
" None | \n",
" 4.25 | \n",
" 29.481647 | \n",
" 0.14 | \n",
" 50.0 | \n",
" 23.0 | \n",
" 29.0 | \n",
" 19.0 | \n",
" 53.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0 | \n",
" 1 | \n",
" 33.358904 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 0 | \n",
" 1 | \n",
" ... | \n",
" 1 | \n",
" 51.85 | \n",
" 3.972255 | \n",
" 13.05 | \n",
" 9.0 | \n",
" 14.0 | \n",
" 14.0 | \n",
" 1000000.0 | \n",
" 2043.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | 182 | \n",
" 1 | \n",
" 1 | \n",
" 37.575342 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 0 | \n",
" ... | \n",
" 1 | \n",
" 11.08 | \n",
" 2.522750 | \n",
" 4.39 | \n",
" 44.0 | \n",
" 15.0 | \n",
" 22.0 | \n",
" 16.0 | \n",
" 385.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 183 | \n",
" 0 | \n",
" 1 | \n",
" 22.895890 | \n",
" 0 | \n",
" 0 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" ... | \n",
" 1 | \n",
" 4.64 | \n",
" 1.038858 | \n",
" 4.47 | \n",
" 44.5 | \n",
" 12.0 | \n",
" 30.0 | \n",
" 1000000.0 | \n",
" 634.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 184 | \n",
" 0 | \n",
" 1 | \n",
" 27.347945 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" -1 | \n",
" 1 | \n",
" 1 | \n",
" ... | \n",
" 1 | \n",
" 7.73 | \n",
" 1.635559 | \n",
" 4.73 | \n",
" 33.0 | \n",
" 16.0 | \n",
" 16.0 | \n",
" 1000000.0 | \n",
" 1895.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 185 | \n",
" 1 | \n",
" 1 | \n",
" 27.780822 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 1 | \n",
" ... | \n",
" 0 | \n",
" 15.41 | \n",
" 8.077770 | \n",
" 1.91 | \n",
" 24.0 | \n",
" 13.0 | \n",
" 14.0 | \n",
" 54.0 | \n",
" 382.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
" | 186 | \n",
" 1 | \n",
" 1 | \n",
" 55.553425 | \n",
" 1 | \n",
" 1 | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 1 | \n",
" 1 | \n",
" ... | \n",
" 1 | \n",
" 9.91 | \n",
" 0.948135 | \n",
" 10.45 | \n",
" 37.0 | \n",
" 18.0 | \n",
" 20.0 | \n",
" 1000000.0 | \n",
" 1109.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
187 rows × 37 columns
\n",
"
"
],
"text/plain": [
" Recipientgender Stemcellsource Donorage Donorage35 IIIV Gendermatch \\\n",
"0 1 1 22.830137 0 1 0 \n",
"1 1 0 23.342466 0 1 0 \n",
"2 1 0 26.394521 0 1 0 \n",
"3 0 0 39.684932 1 1 0 \n",
"4 0 1 33.358904 0 0 0 \n",
".. ... ... ... ... ... ... \n",
"182 1 1 37.575342 1 1 0 \n",
"183 0 1 22.895890 0 0 0 \n",
"184 0 1 27.347945 0 1 0 \n",
"185 1 1 27.780822 0 1 0 \n",
"186 1 1 55.553425 1 1 0 \n",
"\n",
" DonorABO RecipientABO RecipientRh ABOmatch ... extcGvHD CD34kgx10d6 \\\n",
"0 1 1 1 0 ... 1 7.20 \n",
"1 -1 -1 1 0 ... 1 4.50 \n",
"2 -1 -1 1 0 ... 1 7.94 \n",
"3 1 2 1 1 ... None 4.25 \n",
"4 1 2 0 1 ... 1 51.85 \n",
".. ... ... ... ... ... ... ... \n",
"182 1 1 0 0 ... 1 11.08 \n",
"183 1 0 1 1 ... 1 4.64 \n",
"184 1 -1 1 1 ... 1 7.73 \n",
"185 1 0 1 1 ... 0 15.41 \n",
"186 1 2 1 1 ... 1 9.91 \n",
"\n",
" CD3dCD34 CD3dkgx10d8 Rbodymass ANCrecovery PLTrecovery \\\n",
"0 1.338760 5.38 35.0 19.0 51.0 \n",
"1 11.078295 0.41 20.6 16.0 37.0 \n",
"2 19.013230 0.42 23.4 23.0 20.0 \n",
"3 29.481647 0.14 50.0 23.0 29.0 \n",
"4 3.972255 13.05 9.0 14.0 14.0 \n",
".. ... ... ... ... ... \n",
"182 2.522750 4.39 44.0 15.0 22.0 \n",
"183 1.038858 4.47 44.5 12.0 30.0 \n",
"184 1.635559 4.73 33.0 16.0 16.0 \n",
"185 8.077770 1.91 24.0 13.0 14.0 \n",
"186 0.948135 10.45 37.0 18.0 20.0 \n",
"\n",
" time_to_aGvHD_III_IV survival_time survival_status \n",
"0 32.0 999.0 0.0 \n",
"1 1000000.0 163.0 1.0 \n",
"2 1000000.0 435.0 1.0 \n",
"3 19.0 53.0 1.0 \n",
"4 1000000.0 2043.0 0.0 \n",
".. ... ... ... \n",
"182 16.0 385.0 1.0 \n",
"183 1000000.0 634.0 1.0 \n",
"184 1000000.0 1895.0 0.0 \n",
"185 54.0 382.0 1.0 \n",
"186 1000000.0 1109.0 0.0 \n",
"\n",
"[187 rows x 37 columns]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_df"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 378
},
"id": "y9uVi9SFeZSa",
"outputId": "6809c06d-5d8c-48a0-9b6d-3c433574f7f7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset overview:\n",
"Name: bmt.arff\n",
"Objects number: 187; Attributes number: 37\n",
"Basic attribute statistics:\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Donorage | \n",
" Recipientage | \n",
" CD34kgx10d6 | \n",
" CD3dCD34 | \n",
" CD3dkgx10d8 | \n",
" Rbodymass | \n",
" ANCrecovery | \n",
" PLTrecovery | \n",
" time_to_aGvHD_III_IV | \n",
" survival_time | \n",
" survival_status | \n",
"
\n",
" \n",
" \n",
" \n",
" | count | \n",
" 187.000000 | \n",
" 187.000000 | \n",
" 187.000000 | \n",
" 182.000000 | \n",
" 182.000000 | \n",
" 185.000000 | \n",
" 187.000000 | \n",
" 187.000000 | \n",
" 187.000000 | \n",
" 187.000000 | \n",
" 187.000000 | \n",
"
\n",
" \n",
" | mean | \n",
" 33.472068 | \n",
" 9.931551 | \n",
" 11.891781 | \n",
" 5.385096 | \n",
" 4.745714 | \n",
" 35.801081 | \n",
" 26752.866310 | \n",
" 90937.919786 | \n",
" 775408.042781 | \n",
" 938.743316 | \n",
" 0.454545 | \n",
"
\n",
" \n",
" | std | \n",
" 8.271826 | \n",
" 5.305639 | \n",
" 9.914386 | \n",
" 9.598716 | \n",
" 3.859128 | \n",
" 19.650922 | \n",
" 161747.200525 | \n",
" 288242.407688 | \n",
" 418425.252689 | \n",
" 849.589495 | \n",
" 0.499266 | \n",
"
\n",
" \n",
" | min | \n",
" 18.646575 | \n",
" 0.600000 | \n",
" 0.790000 | \n",
" 0.204132 | \n",
" 0.040000 | \n",
" 6.000000 | \n",
" 9.000000 | \n",
" 9.000000 | \n",
" 10.000000 | \n",
" 6.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 25% | \n",
" 27.039726 | \n",
" 5.050000 | \n",
" 5.350000 | \n",
" 1.786683 | \n",
" 1.687500 | \n",
" 19.000000 | \n",
" 13.000000 | \n",
" 16.000000 | \n",
" 1000000.000000 | \n",
" 168.500000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 50% | \n",
" 33.550685 | \n",
" 9.600000 | \n",
" 9.720000 | \n",
" 2.734462 | \n",
" 4.325000 | \n",
" 33.000000 | \n",
" 15.000000 | \n",
" 21.000000 | \n",
" 1000000.000000 | \n",
" 676.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" | 75% | \n",
" 40.117809 | \n",
" 14.050000 | \n",
" 15.415000 | \n",
" 5.823565 | \n",
" 6.785000 | \n",
" 50.600000 | \n",
" 17.000000 | \n",
" 37.000000 | \n",
" 1000000.000000 | \n",
" 1604.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" | max | \n",
" 55.553425 | \n",
" 20.200000 | \n",
" 57.780000 | \n",
" 99.560970 | \n",
" 20.020000 | \n",
" 103.400000 | \n",
" 1000000.000000 | \n",
" 1000000.000000 | \n",
" 1000000.000000 | \n",
" 3364.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Donorage Recipientage CD34kgx10d6 CD3dCD34 CD3dkgx10d8 \\\n",
"count 187.000000 187.000000 187.000000 182.000000 182.000000 \n",
"mean 33.472068 9.931551 11.891781 5.385096 4.745714 \n",
"std 8.271826 5.305639 9.914386 9.598716 3.859128 \n",
"min 18.646575 0.600000 0.790000 0.204132 0.040000 \n",
"25% 27.039726 5.050000 5.350000 1.786683 1.687500 \n",
"50% 33.550685 9.600000 9.720000 2.734462 4.325000 \n",
"75% 40.117809 14.050000 15.415000 5.823565 6.785000 \n",
"max 55.553425 20.200000 57.780000 99.560970 20.020000 \n",
"\n",
" Rbodymass ANCrecovery PLTrecovery time_to_aGvHD_III_IV \\\n",
"count 185.000000 187.000000 187.000000 187.000000 \n",
"mean 35.801081 26752.866310 90937.919786 775408.042781 \n",
"std 19.650922 161747.200525 288242.407688 418425.252689 \n",
"min 6.000000 9.000000 9.000000 10.000000 \n",
"25% 19.000000 13.000000 16.000000 1000000.000000 \n",
"50% 33.000000 15.000000 21.000000 1000000.000000 \n",
"75% 50.600000 17.000000 37.000000 1000000.000000 \n",
"max 103.400000 1000000.000000 1000000.000000 1000000.000000 \n",
"\n",
" survival_time survival_status \n",
"count 187.000000 187.000000 \n",
"mean 938.743316 0.454545 \n",
"std 849.589495 0.499266 \n",
"min 6.000000 0.000000 \n",
"25% 168.500000 0.000000 \n",
"50% 676.000000 0.000000 \n",
"75% 1604.000000 1.000000 \n",
"max 3364.000000 1.000000 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Print a short textual overview (file name and DataFrame shape) ahead of the statistics table.\n",
"print(\"Dataset overview:\")\n",
"print(f\"Name: {file_name}\")\n",
"print(f\"Objects number: {data_df.shape[0]}; Attributes number: {data_df.shape[1]}\")\n",
"print(\"Basic attribute statistics:\")\n",
"# describe() is the last expression of the cell, so its result is rendered as the cell output.\n",
"data_df.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "al2J-WKIesF7"
},
"source": [
"### Survival curve for the entire set (Kaplan Meier curve)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 296
},
"id": "vQSEvAcRfES0",
"outputId": "acd03cb0-a7a8-4d7c-dee9-7f0bb0bddd6f"
},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from lifelines import KaplanMeierFitter\n",
"\n",
"# Create the Kaplan-Meier estimator object\n",
"kmf = KaplanMeierFitter() \n",
"\n",
"# Fit it on the whole dataset: survival_time holds the durations, survival_status the event indicator\n",
"kmf.fit(data_df['survival_time'], data_df['survival_status'],label='Kaplan Meier Estimate')\n",
"\n",
"# Plot the fitted curve; ci_show=False hides the confidence interval\n",
"kmf.plot(ci_show=False) "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r8TgXKGmmSJf"
},
"source": [
"## Import RuleKit"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "w0hYM-8Ele2j"
},
"outputs": [],
"source": [
"from rulekit.survival import SurvivalRules\n",
"from rulekit.params import Measures"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dattexxGmaqJ"
},
"source": [
"## Helper function for creating ruleset characteristics dataframe"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "aLCZkT_SmU4a"
},
"outputs": [],
"source": [
"def get_ruleset_stats(model) -> pd.DataFrame:\n",
"    \"\"\"Build a one-row DataFrame describing a trained ruleset.\n",
"\n",
"    The row combines the attributes of `model.parameters` (without the internal\n",
"    `_java_object` entry) with the attributes of `model.stats`.\n",
"\n",
"    The parameters are read into a new dict, so `model.parameters` is left\n",
"    untouched and the function can be called repeatedly on the same model.\n",
"\n",
"    Args:\n",
"        model: object exposing `parameters` and `stats` attributes, e.g. the\n",
"            `model` attribute of a fitted RuleKit estimator.\n",
"\n",
"    Returns:\n",
"        A DataFrame with a single row; on a name clash the value from\n",
"        `model.stats` wins.\n",
"    \"\"\"\n",
"    # Deleting the key from `model.parameters.__dict__` directly would remove the\n",
"    # attribute from the live object and raise KeyError on a second call.\n",
"    params = {\n",
"        name: value\n",
"        for name, value in model.parameters.__dict__.items()\n",
"        if name != '_java_object'\n",
"    }\n",
"    return pd.DataFrame.from_records([{**params, **model.stats.__dict__}])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u4wOfecjme_d"
},
"source": [
"## Rule induction on full dataset"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "TrO-LyN2mpiP"
},
"outputs": [],
"source": [
"# Only the status column is dropped: 'survival_time' stays in X, and SurvivalRules is told its name via survival_time_attr.\n",
"X = data_df.drop(['survival_status'], axis=1)\n",
"# The survival status column is the label passed to fit() and score().\n",
"y = data_df['survival_status']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "c5tmU4IHnFjw"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" minimum_covered | \n",
" maximum_uncovered_fraction | \n",
" ignore_missing | \n",
" pruning_enabled | \n",
" max_growing_condition | \n",
" time_total_s | \n",
" time_growing_s | \n",
" time_pruning_s | \n",
" rules_count | \n",
" conditions_per_rule | \n",
" induced_conditions_per_rule | \n",
" avg_rule_coverage | \n",
" avg_rule_precision | \n",
" avg_rule_quality | \n",
" pvalue | \n",
" FDR_pvalue | \n",
" FWER_pvalue | \n",
" fraction_significant | \n",
" fraction_FDR_significant | \n",
" fraction_FWER_significant | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 5.0 | \n",
" 0.0 | \n",
" False | \n",
" True | \n",
" 0.0 | \n",
" 3.486387 | \n",
" 1.575069 | \n",
" 1.873982 | \n",
" 4 | \n",
" 3.0 | \n",
" 86.25 | \n",
" 0.485294 | \n",
" 1.0 | \n",
" 0.999678 | \n",
" 0.000322 | \n",
" 0.000375 | \n",
" 0.00048 | \n",
" 1.0 | \n",
" 1.0 | \n",
" 1.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" minimum_covered maximum_uncovered_fraction ignore_missing \\\n",
"0 5.0 0.0 False \n",
"\n",
" pruning_enabled max_growing_condition time_total_s time_growing_s \\\n",
"0 True 0.0 3.486387 1.575069 \n",
"\n",
" time_pruning_s rules_count conditions_per_rule \\\n",
"0 1.873982 4 3.0 \n",
"\n",
" induced_conditions_per_rule avg_rule_coverage avg_rule_precision \\\n",
"0 86.25 0.485294 1.0 \n",
"\n",
" avg_rule_quality pvalue FDR_pvalue FWER_pvalue fraction_significant \\\n",
"0 0.999678 0.000322 0.000375 0.00048 1.0 \n",
"\n",
" fraction_FDR_significant fraction_FWER_significant \n",
"0 1.0 1.0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train a survival ruleset on the whole dataset; survival_time_attr names the column of X that holds the survival times.\n",
"srv = SurvivalRules(\n",
" survival_time_attr = 'survival_time'\n",
")\n",
"srv.fit(X, y)\n",
"ruleset = srv.model\n",
"# Predictions for the same data the model was trained on.\n",
"predictions = srv.predict(X)\n",
"\n",
"# Gather induction parameters and ruleset statistics into a one-row DataFrame.\n",
"ruleset_stats = get_ruleset_stats(ruleset)\n",
"\n",
"display(ruleset_stats)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(predictions[0][\"times\"], predictions[0][\"probabilities\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "abrtDQOtpVoL"
},
"source": [
"### Generated rules"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "FskFiB6PpVI_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IF Relapse = {0} AND Donorage = (-inf, 45.16) AND Recipientage = (-inf, 17.45) THEN \n",
"IF HLAmismatch = {0} AND Donorage = <33.34, 42.14) AND Gendermatch = {0} AND RecipientRh = {1} AND Recipientage = <3.30, inf) THEN \n",
"IF Relapse = {1} AND PLTrecovery = <15.50, inf) THEN \n",
"IF PLTrecovery = (-inf, 266) THEN \n"
]
}
],
"source": [
"# Print the induced rules one per line, using each rule's string form.\n",
"for rule in ruleset.rules:\n",
" print(rule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GkNBc5iBpwmj"
},
"source": [
"### Rules evaluation on full set"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "9UjrC8r-p59d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Integrated Brier Score: 0.2154545054362314\n"
]
}
],
"source": [
"# Score on the same data the model was trained on; the value is reported as the Integrated Brier Score.\n",
"integrated_brier_score = srv.score(X, y)\n",
"print(f'Integrated Brier Score: {integrated_brier_score}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FpoSoaKdqAGQ"
},
"source": [
"## Stratified K-Folds cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "0nNv6a84qTsq"
},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedKFold\n",
"\n",
"# 10-fold cross-validation, stratified on the survival status passed as y.\n",
"skf = StratifiedKFold(n_splits=10)\n",
"\n",
"# Per-fold results are gathered in plain lists and combined once after the loop;\n",
"# growing a DataFrame with pd.concat inside the loop re-copies all earlier rows on every fold.\n",
"fold_stats = []\n",
"survival_metrics = []\n",
"\n",
"for train_index, test_index in skf.split(X, y):\n",
"    x_train, x_test = X.iloc[train_index], X.iloc[test_index]\n",
"    y_train, y_test = y.iloc[train_index], y.iloc[test_index]\n",
"\n",
"    srv = SurvivalRules(\n",
"        survival_time_attr='survival_time'\n",
"    )\n",
"    srv.fit(x_train, y_train)\n",
"    ruleset = srv.model\n",
"\n",
"    # Score of this fold's model on the held-out part (reported later as the Integrated Brier Score).\n",
"    ibs = srv.score(x_test, y_test)\n",
"\n",
"    survival_metrics.append(ibs)\n",
"    fold_stats.append(get_ruleset_stats(ruleset))\n",
"\n",
"# One row per fold; every row keeps index 0, exactly as with the incremental concat.\n",
"ruleset_stats = pd.concat(fold_stats)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfCOH_f3sICm"
},
"source": [
"Ruleset characteristics (average)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "xzbazr51sRd3"
},
"outputs": [
{
"data": {
"text/plain": [
"minimum_covered 5.000000\n",
"maximum_uncovered_fraction 0.000000\n",
"ignore_missing 0.000000\n",
"pruning_enabled 1.000000\n",
"max_growing_condition 0.000000\n",
"time_total_s 2.550110\n",
"time_growing_s 1.032773\n",
"time_pruning_s 1.515146\n",
"rules_count 5.700000\n",
"conditions_per_rule 3.329405\n",
"induced_conditions_per_rule 68.789167\n",
"avg_rule_coverage 0.389676\n",
"avg_rule_precision 1.000000\n",
"avg_rule_quality 0.998029\n",
"pvalue 0.001971\n",
"FDR_pvalue 0.002043\n",
"FWER_pvalue 0.002328\n",
"fraction_significant 1.000000\n",
"fraction_FDR_significant 1.000000\n",
"fraction_FWER_significant 1.000000\n",
"dtype: float64"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(ruleset_stats.mean())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_SmDJho4sVEO"
},
"source": [
"Rules evaluation on dataset (average)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "Co-fNd9nshWB"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Integrated Brier Score: 0.24666778331912664\n"
]
}
],
"source": [
"print(f'Integrated Brier Score: {np.mean(survival_metrics)}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d-GdQ-wUtzW9"
},
"source": [
"## Hyperparameters tuning\n",
"\n",
"This one is going to take a while..."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.model_selection import GridSearchCV\n",
"from rulekit.params import Measures"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def scorer(estimator, X, y):\n",
"    \"\"\"Scoring callable for GridSearchCV: the estimator's score with its sign flipped.\n",
"\n",
"    Args:\n",
"        estimator: object exposing `score(X, y)`.\n",
"        X: features forwarded unchanged to `estimator.score`.\n",
"        y: labels forwarded unchanged to `estimator.score`.\n",
"\n",
"    Returns:\n",
"        The negated result of `estimator.score(X, y)`.\n",
"    \"\"\"\n",
"    negated_score = -estimator.score(X, y)\n",
"    return negated_score"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "xNUji8U7t2wd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best Integrated Brier Score: 0.215827 using {'minsupp_new': 3, 'survival_time_attr': 'survival_time'}\n"
]
}
],
"source": [
"# Base estimator; GridSearchCV sets its parameters from the grid for every candidate\n",
"model = SurvivalRules()\n",
"\n",
"# Parameter grid: survival_time_attr is fixed to a single value, minsupp_new takes the values 3..9\n",
"grid = {\n",
" 'survival_time_attr': ['survival_time'],\n",
" 'minsupp_new': range(3, 10),\n",
"}\n",
"\n",
"# 3-fold stratified CV; `scorer` returns the negated score, so the candidate with the lowest raw score is selected\n",
"cv = StratifiedKFold(n_splits=3)\n",
"grid_search = GridSearchCV(estimator=model, param_grid=grid, cv=cv, scoring=scorer) \n",
"grid_result = grid_search.fit(X, y)\n",
"\n",
"# Summarize results; best_score_ is the negated value, so the sign is flipped back for reporting\n",
"print(\"Best Integrated Brier Score: %f using %s\" % ( (-1)*grid_result.best_score_, grid_result.best_params_))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building model with tuned hyperparameters\n",
"\n",
"### Split dataset to train and test (80%/20%)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Stratified 80%/20% split of the dataset into train and test parts.\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, stratify=y)\n",
"\n",
"# Take minsupp_new from the grid search result (`grid_result`, computed in the tuning cell)\n",
"# instead of a hard-coded constant, so this model always matches the tuning outcome of the current run.\n",
"srv = SurvivalRules(\n",
"    survival_time_attr='survival_time',\n",
"    minsupp_new=grid_result.best_params_['minsupp_new']\n",
")\n",
"srv.fit(X_train, y_train)\n",
"ruleset = srv.model\n",
"ruleset_stats = get_ruleset_stats(ruleset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ruleset characteristics"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"minimum_covered 5.0\n",
"maximum_uncovered_fraction 0.0\n",
"ignore_missing False\n",
"pruning_enabled True\n",
"max_growing_condition 0.0\n",
"time_total_s 2.74258\n",
"time_growing_s 0.960948\n",
"time_pruning_s 1.78038\n",
"rules_count 7\n",
"conditions_per_rule 3.428571\n",
"induced_conditions_per_rule 71.142857\n",
"avg_rule_coverage 0.299137\n",
"avg_rule_precision 1.0\n",
"avg_rule_quality 0.994714\n",
"pvalue 0.005286\n",
"FDR_pvalue 0.005376\n",
"FWER_pvalue 0.005743\n",
"fraction_significant 1.0\n",
"fraction_FDR_significant 1.0\n",
"fraction_FWER_significant 1.0\n",
"Name: 0, dtype: object"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display(ruleset_stats.iloc[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Validate model on test dataset"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Integrated Brier Score: 0.22534751550121987\n"
]
}
],
"source": [
"# Score the model on the held-out test split; the value is reported as the Integrated Brier Score.\n",
"integrated_brier_score = srv.score(X_test, y_test)\n",
"print(f'Integrated Brier Score: {integrated_brier_score}')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"predictions = srv.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(predictions[0][\"times\"], predictions[0][\"probabilities\"])"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Raport_przezyciowy.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"metadata": {
"interpreter": {
"hash": "62266c16fff41e971c13e9cb2ad3d47e4ef45d0678714c255381eb9fdcbd7032"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}