{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Raport_przezyciowy.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "rulekit", "display_name": "rulekit", "language": "python" }, "language_info": { "name": "python", "version": "3.8.6" }, "metadata": { "interpreter": { "hash": "62266c16fff41e971c13e9cb2ad3d47e4ef45d0678714c255381eb9fdcbd7032" } } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "-Uy-yBGsd9W1" }, "source": [ "# Survival analysis" ] }, { "source": [ "This notebook presents an example of using the RuleKit package to solve a survival analysis problem on the `bmt` dataset. You can download the dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/bmt/bmt.arff).\n", "\n", "This tutorial covers the following topics:\n", "- training a model\n", "- changing model hyperparameters\n", "- hyperparameter tuning\n", "- calculating metrics for the model\n", "- getting RuleKit's built-in ruleset statistics" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "markdown", "metadata": { "id": "KjtU7PA8eOTr" }, "source": [ "## Summary of the dataset" ] }, { "cell_type": "code", "metadata": { "id": "Tp1TpfCkd58n" }, "source": [ "from scipy.io import arff\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "datasets_path = \"\"\n", "\n", "file_name = 'bmt.arff'\n", "\n", "# `with` closes the file handle as soon as the ARFF file has been parsed\n", "with open(datasets_path + file_name, 'r', encoding=\"cp1252\") as arff_file:\n", "    data_df = pd.DataFrame(arff.loadarff(arff_file)[0])\n", "\n", "# object (nominal) columns come back as cp1252-encoded bytes - decode them to str\n", "decoded_df = data_df.select_dtypes([object]).stack().str.decode(\"cp1252\").unstack()\n", "for col in decoded_df:\n", "    data_df[col] = decoded_df[col]\n", "\n", "# '?' marks a missing value in this dataset\n", "data_df = data_df.replace({'?': None})" ], "execution_count": 1, "outputs": [] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Recipientgender Stemcellsource Donorage 
Donorage35 IIIV Gendermatch \\\n", "0 1 1 22.830137 0 1 0 \n", "1 1 0 23.342466 0 1 0 \n", "2 1 0 26.394521 0 1 0 \n", "3 0 0 39.684932 1 1 0 \n", "4 0 1 33.358904 0 0 0 \n", ".. ... ... ... ... ... ... \n", "182 1 1 37.575342 1 1 0 \n", "183 0 1 22.895890 0 0 0 \n", "184 0 1 27.347945 0 1 0 \n", "185 1 1 27.780822 0 1 0 \n", "186 1 1 55.553425 1 1 0 \n", "\n", " DonorABO RecipientABO RecipientRh ABOmatch ... extcGvHD CD34kgx10d6 \\\n", "0 1 1 1 0 ... 1 7.20 \n", "1 -1 -1 1 0 ... 1 4.50 \n", "2 -1 -1 1 0 ... 1 7.94 \n", "3 1 2 1 1 ... None 4.25 \n", "4 1 2 0 1 ... 1 51.85 \n", ".. ... ... ... ... ... ... ... \n", "182 1 1 0 0 ... 1 11.08 \n", "183 1 0 1 1 ... 1 4.64 \n", "184 1 -1 1 1 ... 1 7.73 \n", "185 1 0 1 1 ... 0 15.41 \n", "186 1 2 1 1 ... 1 9.91 \n", "\n", " CD3dCD34 CD3dkgx10d8 Rbodymass ANCrecovery PLTrecovery \\\n", "0 1.338760 5.38 35.0 19.0 51.0 \n", "1 11.078295 0.41 20.6 16.0 37.0 \n", "2 19.013230 0.42 23.4 23.0 20.0 \n", "3 29.481647 0.14 50.0 23.0 29.0 \n", "4 3.972255 13.05 9.0 14.0 14.0 \n", ".. ... ... ... ... ... \n", "182 2.522750 4.39 44.0 15.0 22.0 \n", "183 1.038858 4.47 44.5 12.0 30.0 \n", "184 1.635559 4.73 33.0 16.0 16.0 \n", "185 8.077770 1.91 24.0 13.0 14.0 \n", "186 0.948135 10.45 37.0 18.0 20.0 \n", "\n", " time_to_aGvHD_III_IV survival_time survival_status \n", "0 32.0 999.0 0.0 \n", "1 1000000.0 163.0 1.0 \n", "2 1000000.0 435.0 1.0 \n", "3 19.0 53.0 1.0 \n", "4 1000000.0 2043.0 0.0 \n", ".. ... ... ... \n", "182 16.0 385.0 1.0 \n", "183 1000000.0 634.0 1.0 \n", "184 1000000.0 1895.0 0.0 \n", "185 54.0 382.0 1.0 \n", "186 1000000.0 1109.0 0.0 \n", "\n", "[187 rows x 37 columns]" ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
RecipientgenderStemcellsourceDonorageDonorage35IIIVGendermatchDonorABORecipientABORecipientRhABOmatch...extcGvHDCD34kgx10d6CD3dCD34CD3dkgx10d8RbodymassANCrecoveryPLTrecoverytime_to_aGvHD_III_IVsurvival_timesurvival_status
01122.8301370101110...17.201.3387605.3835.019.051.032.0999.00.0
11023.342466010-1-110...14.5011.0782950.4120.616.037.01000000.0163.01.0
21026.394521010-1-110...17.9419.0132300.4223.423.020.01000000.0435.01.0
30039.6849321101211...None4.2529.4816470.1450.023.029.019.053.01.0
40133.3589040001201...151.853.97225513.059.014.014.01000000.02043.00.0
..................................................................
1821137.5753421101100...111.082.5227504.3944.015.022.016.0385.01.0
1830122.8958900001011...14.641.0388584.4744.512.030.01000000.0634.01.0
1840127.3479450101-111...17.731.6355594.7333.016.016.01000000.01895.00.0
1851127.7808220101011...015.418.0777701.9124.013.014.054.0382.01.0
1861155.5534251101211...19.910.94813510.4537.018.020.01000000.01109.00.0
\n

187 rows × 37 columns

\n
" }, "metadata": {}, "execution_count": 97 } ], "source": [ "data_df" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 378 }, "id": "y9uVi9SFeZSa", "outputId": "6809c06d-5d8c-48a0-9b6d-3c433574f7f7" }, "source": [ "print(\"Dataset overview:\")\n", "print(f\"Name: {file_name}\")\n", "print(f\"Objects number: {data_df.shape[0]}; Attributes number: {data_df.shape[1]}\")\n", "print(\"Basic attribute statistics:\")\n", "data_df.describe()" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Dataset overview:\nName: bmt.arff\nObjects number: 187; Attributes number: 37\nBasic attribute statistics:\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ " Donorage Recipientage CD34kgx10d6 CD3dCD34 CD3dkgx10d8 \\\n", "count 187.000000 187.000000 187.000000 182.000000 182.000000 \n", "mean 33.472068 9.931551 11.891781 5.385096 4.745714 \n", "std 8.271826 5.305639 9.914386 9.598716 3.859128 \n", "min 18.646575 0.600000 0.790000 0.204132 0.040000 \n", "25% 27.039726 5.050000 5.350000 1.786683 1.687500 \n", "50% 33.550685 9.600000 9.720000 2.734462 4.325000 \n", "75% 40.117809 14.050000 15.415000 5.823565 6.785000 \n", "max 55.553425 20.200000 57.780000 99.560970 20.020000 \n", "\n", " Rbodymass ANCrecovery PLTrecovery time_to_aGvHD_III_IV \\\n", "count 185.000000 187.000000 187.000000 187.000000 \n", "mean 35.801081 26752.866310 90937.919786 775408.042781 \n", "std 19.650922 161747.200525 288242.407688 418425.252689 \n", "min 6.000000 9.000000 9.000000 10.000000 \n", "25% 19.000000 13.000000 16.000000 1000000.000000 \n", "50% 33.000000 15.000000 21.000000 1000000.000000 \n", "75% 50.600000 17.000000 37.000000 1000000.000000 \n", "max 103.400000 1000000.000000 1000000.000000 1000000.000000 \n", "\n", " survival_time survival_status \n", "count 187.000000 187.000000 \n", "mean 938.743316 0.454545 \n", "std 849.589495 0.499266 \n", "min 6.000000 0.000000 \n", "25% 
168.500000 0.000000 \n", "50% 676.000000 0.000000 \n", "75% 1604.000000 1.000000 \n", "max 3364.000000 1.000000 " ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
DonorageRecipientageCD34kgx10d6CD3dCD34CD3dkgx10d8RbodymassANCrecoveryPLTrecoverytime_to_aGvHD_III_IVsurvival_timesurvival_status
count187.000000187.000000187.000000182.000000182.000000185.000000187.000000187.000000187.000000187.000000187.000000
mean33.4720689.93155111.8917815.3850964.74571435.80108126752.86631090937.919786775408.042781938.7433160.454545
std8.2718265.3056399.9143869.5987163.85912819.650922161747.200525288242.407688418425.252689849.5894950.499266
min18.6465750.6000000.7900000.2041320.0400006.0000009.0000009.00000010.0000006.0000000.000000
25%27.0397265.0500005.3500001.7866831.68750019.00000013.00000016.0000001000000.000000168.5000000.000000
50%33.5506859.6000009.7200002.7344624.32500033.00000015.00000021.0000001000000.000000676.0000000.000000
75%40.11780914.05000015.4150005.8235656.78500050.60000017.00000037.0000001000000.0000001604.0000001.000000
max55.55342520.20000057.78000099.56097020.020000103.4000001000000.0000001000000.0000001000000.0000003364.0000001.000000
\n
" }, "metadata": {}, "execution_count": 2 } ] }, { "cell_type": "markdown", "metadata": { "id": "al2J-WKIesF7" }, "source": [ "### Survival curve for the entire set (Kaplan Meier curve)" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 296 }, "id": "vQSEvAcRfES0", "outputId": "acd03cb0-a7a8-4d7c-dee9-7f0bb0bddd6f" }, "source": [ "from lifelines import KaplanMeierFitter\n", "\n", "# create a kmf object\n", "kmf = KaplanMeierFitter() \n", "\n", "# Fit the data into the model\n", "kmf.fit(data_df['survival_time'], data_df['survival_status'],label='Kaplan Meier Estimate')\n", "\n", "# Create an estimate\n", "kmf.plot(ci_show=False) " ], "execution_count": 3, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 3 }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-04-22T08:04:41.255915\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.4.1, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", "image/png": "\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": 
"r8TgXKGmmSJf" }, "source": [ "## Import and init RuleKit" ] }, { "cell_type": "code", "metadata": { "id": "w0hYM-8Ele2j" }, "source": [ "from rulekit import RuleKit\n", "from rulekit.survival import SurvivalRules\n", "from rulekit.params import Measures\n", "\n", "\n", "RuleKit.init()" ], "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dattexxGmaqJ" }, "source": [ "## Helper function for creating ruleset characteristics dataframe" ] }, { "cell_type": "code", "metadata": { "id": "aLCZkT_SmU4a" }, "source": [ "def get_ruleset_stats(model) -> pd.DataFrame:\n", "    \"\"\"Return a one-row DataFrame combining the model's parameters and ruleset statistics.\n", "\n", "    The parameters are copied into a new dict rather than deleted from\n", "    `model.parameters`, so the model object is left intact and the helper\n", "    can safely be called more than once on the same model.\n", "    \"\"\"\n", "    params = {\n", "        name: value\n", "        for name, value in vars(model.parameters).items()\n", "        if name != '_java_object'  # skip the internal Java object reference\n", "    }\n", "    return pd.DataFrame.from_records([{**params, **vars(model.stats)}])" ], "execution_count": 14, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "u4wOfecjme_d" }, "source": [ "## Rule induction on full dataset" ] }, { "cell_type": "code", "metadata": { "id": "TrO-LyN2mpiP" }, "source": [ "x = data_df.drop(['survival_status'], axis=1)\n", "y = data_df['survival_status']" ], "execution_count": 6, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "c5tmU4IHnFjw" }, "source": [ "srv = SurvivalRules(\n", " survival_time_attr = 'survival_time'\n", ")\n", "srv.fit(x, y)\n", "ruleset = srv.model\n", "predictions = srv.predict(x)\n", "\n", "ruleset_stats = get_ruleset_stats(ruleset)\n", "\n", "display(ruleset_stats)" ], "execution_count": 10, "outputs": [] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": {}, "execution_count": 13 }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-04-21T13:07:16.829628\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.4.1, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "plt.plot(predictions[0][\"times\"], predictions[0][\"probabilities\"])" ] }, { "cell_type": "markdown", 
"metadata": { "id": "abrtDQOtpVoL" }, "source": [ "### Generated rules" ] }, { "cell_type": "code", "metadata": { "id": "FskFiB6PpVI_" }, "source": [ "for rule in ruleset.rules:\n", " print(rule)" ], "execution_count": 14, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "IF Relapse = {0} AND Donorage = (-inf, 45.16) AND Recipientage = (-inf, 17.45) THEN survival_status = {NaN}\nIF HLAmismatch = {0} AND Donorage = <33.34, 42.14) AND Gendermatch = {0} AND RecipientRh = {1} AND Recipientage = <3.30, inf) THEN survival_status = {NaN}\nIF Relapse = {1} AND PLTrecovery = <15.50, inf) THEN survival_status = {NaN}\nIF PLTrecovery = (-inf, 266) THEN survival_status = {NaN}\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "GkNBc5iBpwmj" }, "source": [ "### Rules evaluation on the full set\n", "\n", "The model is scored on the same data it was trained on, so this is an optimistic estimate; the cross-validation below gives a fairer one." ] }, { "cell_type": "code", "metadata": { "id": "9UjrC8r-p59d" }, "source": [ "integrated_brier_score = srv.score(x, y)\n", "print(f'Integrated Brier Score: {integrated_brier_score}')" ], "execution_count": 16, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Integrated Brier Score: 0.2154545054362314\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "FpoSoaKdqAGQ" }, "source": [ "## Stratified K-Folds cross-validation" ] }, { "cell_type": "code", "metadata": { "id": "0nNv6a84qTsq" }, "source": [ "from sklearn.model_selection import StratifiedKFold\n", "\n", "skf = StratifiedKFold(n_splits=10)\n", "\n", "fold_stats = []        # one single-row ruleset statistics DataFrame per fold\n", "survival_metrics = []  # Integrated Brier Score of each fold's model on its held-out part\n", "\n", "for train_index, test_index in skf.split(x, y):\n", "    x_train, x_test = x.iloc[train_index], x.iloc[test_index]\n", "    y_train, y_test = y.iloc[train_index], y.iloc[test_index]\n", "\n", "    srv = SurvivalRules(\n", "        survival_time_attr = 'survival_time'\n", "    )\n", "    srv.fit(x_train, y_train)\n", "    ruleset = srv.model\n", "\n", "    survival_metrics.append(srv.score(x_test, y_test))\n", "    fold_stats.append(get_ruleset_stats(ruleset))\n", "\n", "# concatenate once after the loop instead of re-copying a growing DataFrame every fold\n", "cv_ruleset_stats = pd.concat(fold_stats, ignore_index=True)\n" ], "execution_count": 9, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "MfCOH_f3sICm" }, "source": [ "Ruleset characteristics (average over folds)" ] }, { "cell_type": "code", "metadata": { "id": "xzbazr51sRd3" }, "source": [ "# numeric_only=True skips any non-numeric columns, which pandas >= 2.0 would otherwise reject\n", "display(cv_ruleset_stats.mean(numeric_only=True))" ], "execution_count": 10, "outputs": [ { "output_type": "display_data", "data": { "text/plain": "minimum_covered 5.000000\nmaximum_uncovered_fraction 0.000000\nignore_missing 0.000000\npruning_enabled 1.000000\nmax_growing_condition 0.000000\ntime_total_s 39.927018\ntime_growing_s 35.729990\ntime_pruning_s 4.192022\nrules_count 5.700000\nconditions_per_rule 3.329405\ninduced_conditions_per_rule 68.789167\navg_rule_coverage 0.389676\navg_rule_precision 1.000000\navg_rule_quality 0.998029\npvalue 0.001971\nFDR_pvalue 0.002043\nFWER_pvalue 0.002328\nfraction_significant 1.000000\nfraction_FDR_significant 1.000000\nfraction_FWER_significant 1.000000\ndtype: float64" }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "_SmDJho4sVEO" }, "source": [ "Rules evaluation on the test folds (average)" ] }, { "cell_type": "code", "metadata": { "id": "Co-fNd9nshWB" }, "source": [ "print(f'Integrated Brier Score: {np.mean(survival_metrics)}')" ], "execution_count": 15, "outputs": [ { "output_type": "display_data", "data": { "text/plain": "0.24666778331912664" }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "d-GdQ-wUtzW9" }, "source": [ "## Hyperparameter tuning\n", "\n", "This may take a while: a separate model is trained for every `min_rule_covered` value in every fold." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.model_selection import GridSearchCV\n", "from rulekit.params import Measures" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def scorer(estimator, X, y):\n", "    \"\"\"Scoring callable for GridSearchCV: the negated Integrated Brier Score.\n", "\n", "    GridSearchCV picks the parameters with the highest score, while a lower\n", "    Integrated Brier Score is better, so the sign is flipped.\n", "    \"\"\"\n", "    return -estimator.score(X, y)" ] }, { "cell_type": "code", "metadata": { "id": "xNUji8U7t2wd" }, "source": [ "# define model and parameter grid\n", "model = SurvivalRules()\n", "min_rule_covered = range(3, 15)\n", "\n", "grid = {\n", "    'survival_time_attr': ['survival_time'],\n", "    'min_rule_covered': min_rule_covered,\n", "}\n", "\n", "cv = StratifiedKFold(n_splits=3)\n", "grid_search = GridSearchCV(estimator=model, param_grid=grid, cv=cv, scoring=scorer)\n", "grid_result = grid_search.fit(x, y)\n", "\n", "# scorer negates the IBS, so flip the sign back for reporting\n", "print(\"Best Integrated Brier Score: %f using %s\" % (-grid_result.best_score_, grid_result.best_params_))" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Best Integrated Brier Score: 0.228297 using {'min_rule_covered': 5, 'survival_time_attr': 'survival_time'}\n" ] } ] }, { "source": [ "## Building a model with the tuned hyperparameters\n", "\n", "### Split the dataset into training and test sets (80%/20%)" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "RANDOM_STATE = 42  # fixed seed so the split (and the scores below) are reproducible\n", "\n", "x_train, x_test, y_train, y_test = train_test_split(\n", "    x, y, test_size=0.2, shuffle=True, stratify=y, random_state=RANDOM_STATE\n", ")\n", "\n", "# reuse the parameters selected by the grid search instead of hard-coding them\n", "srv = SurvivalRules(**grid_result.best_params_)\n", "srv.fit(x_train, y_train)\n", "ruleset = srv.model\n", "ruleset_stats = get_ruleset_stats(ruleset)" ] }, { "source": [ "Ruleset characteristics" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "display(ruleset_stats.iloc[0])" ] }, { "source": [ "### Validate the model on the test set" ], "cell_type": "markdown", "metadata": {} }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Integrated Brier Score: 0.19112015424542664\n" ] } ], "source": [ "integrated_brier_score = srv.score(x_test, y_test)\n", "print(f'Integrated Brier Score: {integrated_brier_score}')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "predictions = srv.predict(x_test)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": {}, "execution_count": 80 }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-04-22T09:20:30.747190\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.4.1, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "plt.plot(predictions[0][\"times\"], predictions[0][\"probabilities\"])" ] } ] }