{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook presents example usage of the RuleKit package for solving a classification problem on the `seismic-bumps` dataset. You can access the dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/seismic-bumps/seismic-bumps.arff).\n", "\n", "This tutorial covers the following topics:\n", "- training a model\n", "- changing model hyperparameters\n", "- hyperparameter tuning\n", "- calculating model metrics\n", "- getting RuleKit's built-in ruleset statistics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary of the dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
genergygimpulsgoenergygoimpulsnbumpsnbumps2nbumps3nbumps4nbumps5nbumps6nbumps7nbumps89senergymaxenergyclass
count2.584000e+032584.0000002584.0000002584.0000002584.0000002584.0000002584.0000002584.0000002584.0000002584.02584.02584.02584.0000002584.0000002584.000000
mean9.024252e+04538.57933412.3757744.5089010.8595200.3935760.3928020.0677240.0046440.00.00.04975.2708984278.8506190.065789
std2.292005e+05562.65253680.31905163.1665561.3646160.7837720.7697100.2790590.0680010.00.00.020450.83322219357.4548820.247962
min1.000000e+022.000000-96.000000-96.0000000.0000000.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
25%1.166000e+04190.000000-37.000000-36.0000000.0000000.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
50%2.548500e+04379.000000-6.000000-6.0000000.0000000.0000000.0000000.0000000.0000000.00.00.00.0000000.0000000.000000
75%5.283250e+04669.00000038.00000030.2500001.0000001.0000001.0000000.0000000.0000000.00.00.02600.0000002000.0000000.000000
max2.595650e+064518.0000001245.000000838.0000009.0000008.0000007.0000003.0000001.0000000.00.00.0402000.000000400000.0000001.000000
\n", "
" ], "text/plain": [ " genergy gimpuls goenergy goimpuls nbumps \\\n", "count 2.584000e+03 2584.000000 2584.000000 2584.000000 2584.000000 \n", "mean 9.024252e+04 538.579334 12.375774 4.508901 0.859520 \n", "std 2.292005e+05 562.652536 80.319051 63.166556 1.364616 \n", "min 1.000000e+02 2.000000 -96.000000 -96.000000 0.000000 \n", "25% 1.166000e+04 190.000000 -37.000000 -36.000000 0.000000 \n", "50% 2.548500e+04 379.000000 -6.000000 -6.000000 0.000000 \n", "75% 5.283250e+04 669.000000 38.000000 30.250000 1.000000 \n", "max 2.595650e+06 4518.000000 1245.000000 838.000000 9.000000 \n", "\n", " nbumps2 nbumps3 nbumps4 nbumps5 nbumps6 nbumps7 \\\n", "count 2584.000000 2584.000000 2584.000000 2584.000000 2584.0 2584.0 \n", "mean 0.393576 0.392802 0.067724 0.004644 0.0 0.0 \n", "std 0.783772 0.769710 0.279059 0.068001 0.0 0.0 \n", "min 0.000000 0.000000 0.000000 0.000000 0.0 0.0 \n", "25% 0.000000 0.000000 0.000000 0.000000 0.0 0.0 \n", "50% 0.000000 0.000000 0.000000 0.000000 0.0 0.0 \n", "75% 1.000000 1.000000 0.000000 0.000000 0.0 0.0 \n", "max 8.000000 7.000000 3.000000 1.000000 0.0 0.0 \n", "\n", " nbumps89 senergy maxenergy class \n", "count 2584.0 2584.000000 2584.000000 2584.000000 \n", "mean 0.0 4975.270898 4278.850619 0.065789 \n", "std 0.0 20450.833222 19357.454882 0.247962 \n", "min 0.0 0.000000 0.000000 0.000000 \n", "25% 0.0 0.000000 0.000000 0.000000 \n", "50% 0.0 0.000000 0.000000 0.000000 \n", "75% 0.0 2600.000000 2000.000000 0.000000 \n", "max 0.0 402000.000000 400000.000000 1.000000 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "from rulekit.arff import read_arff\n", "\n", "DATASET_URL: str = (\n", " 'https://raw.githubusercontent.com/'\n", " 'adaa-polsl/RuleKit/refs/heads/master/data/seismic-bumps/'\n", " 'seismic-bumps.arff'\n", ")\n", "\n", "df_full: pd.DataFrame = read_arff(DATASET_URL)\n", "df_full['class'] = df_full['class'].astype(int)\n", "df_full.describe()" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### Decision class distribution" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# sort by class label so that sizes and labels are taken in the same order\n",
"groups = df_full['class'].value_counts().sort_index()\n",
"sizes = groups.to_list()\n",
"labels = [str(e) for e in groups.index]\n",
"\n",
"fig1, ax1 = plt.subplots()\n",
"ax1.pie(sizes, labels=labels, autopct='%1.1f%%', shadow=True, startangle=90)\n",
"ax1.axis('equal')\n",
"ax1.set_title('Decision class distribution')\n",
"\n",
"plt.show()"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions for calculating metrics" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"import math\n",
"from sklearn import metrics\n",
"import numpy as np\n",
"from rulekit.classification import RuleClassifier\n",
"from rulekit.rules import RuleSet\n",
"\n",
"\n",
"X: pd.DataFrame = df_full.drop(['class'], axis=1)\n",
"y: pd.Series = df_full['class']\n",
"\n",
"\n",
"def get_prediction_metrics(\n",
"    measure: str,\n",
"    y_pred: np.ndarray,\n",
"    y_true: pd.Series,\n",
"    classification_metrics: dict\n",
") -> tuple[pd.DataFrame, np.ndarray]:\n",
"    \"\"\"Compute binary classification measures for one model's predictions.\n",
"\n",
"    Args:\n",
"        measure: name of the evaluated model; used as the row index.\n",
"        y_pred: predicted class labels.\n",
"        y_true: true class labels.\n",
"        classification_metrics: dict returned by\n",
"            `RuleClassifier.predict(..., return_metrics=True)`; its\n",
"            'rules_per_example' and 'voting_conflicts' entries are copied.\n",
"\n",
"    Returns:\n",
"        A one-row DataFrame of measures indexed by `measure`, and the\n",
"        confusion matrix. Ratios with a zero denominator (e.g. PPV when\n",
"        no example is predicted as positive) are NaN.\n",
"    \"\"\"\n",
"    confusion_matrix: np.ndarray = metrics.confusion_matrix(y_true, y_pred)\n",
"    tn, fp, fn, tp = confusion_matrix.ravel()\n",
"    # The counts are numpy integers, so 0 / 0 evaluates to NaN; silence the\n",
"    # RuntimeWarning because NaN is the intended value of an undefined ratio.\n",
"    with np.errstate(divide='ignore', invalid='ignore'):\n",
"        sensitivity: float = tp / (tp + fn)\n",
"        specificity: float = tn / (tn + fp)\n",
"        npv: float = tn / (tn + fn)\n",
"        ppv: float = tp / (tp + fp)\n",
"        fall_out: float = fp / (fp + tn)\n",
"        lift: float = ppv / ((tp + fn) / (tp + tn + fp + fn))\n",
"        f_measure: float = 2 * tp / (2 * tp + fp + fn)\n",
"\n",
"    row = {\n",
"        'Measure': measure,\n",
"        'Accuracy': metrics.accuracy_score(y_true, y_pred),\n",
"        'MAE': metrics.mean_absolute_error(y_true, y_pred),\n",
"        'Kappa': metrics.cohen_kappa_score(y_true, y_pred),\n",
"        'Balanced accuracy': metrics.balanced_accuracy_score(y_true, y_pred),\n",
"        # computed from hard 0/1 predictions, not class probabilities\n",
"        'Logistic loss': metrics.log_loss(y_true, y_pred),\n",
"        # precision of the positive class (the same quantity as PPV)\n",
"        'Precision': ppv,\n",
"        'Sensitivity': sensitivity,\n",
"        'Specificity': specificity,\n",
"        'NPV': npv,\n",
"        'PPV': ppv,\n",
"        'psep': ppv + npv - 1,\n",
"        'Fall-out': fall_out,\n",
"        \"Youden's J statistic\": sensitivity + specificity - 1,\n",
"        'Lift': lift,\n",
"        'F-measure': f_measure,\n",
"        # classification Fowlkes-Mallows index sqrt(PPV * TPR); sklearn's\n",
"        # fowlkes_mallows_score measures clustering agreement instead\n",
"        'Fowlkes-Mallows index': math.sqrt(ppv * sensitivity),\n",
"        'False positive': fp,\n",
"        'False negative': fn,\n",
"        'True positive': tp,\n",
"        'True negative': tn,\n",
"        'Rules per example': classification_metrics['rules_per_example'],\n",
"        'Voting conflicts': classification_metrics['voting_conflicts'],\n",
"        'Geometric mean': math.sqrt(specificity * sensitivity),\n",
"    }\n",
"    return pd.DataFrame.from_records([row], index='Measure'), confusion_matrix\n",
"\n",
"\n",
"def get_ruleset_stats(\n",
"    measure: str,\n",
"    model: RuleSet\n",
") -> pd.DataFrame:\n",
"    \"\"\"Return the statistics of a trained ruleset as a one-row DataFrame.\n",
"\n",
"    Args:\n",
"        measure: name of the evaluated model; used as the row index.\n",
"        model: trained ruleset (`RuleClassifier.model`) whose `stats`\n",
"            attributes become the columns.\n",
"    \"\"\"\n",
"    return pd.DataFrame.from_records(\n",
"        [{'Measure': measure, **model.stats.__dict__}],\n",
"        index='Measure'\n",
"    )"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Rule induction on full dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
time_total_stime_growing_stime_pruning_srules_countconditions_per_ruleinduced_conditions_per_ruleavg_rule_coverageavg_rule_precisionavg_rule_qualitypvalueFDR_pvalueFWER_pvaluefraction_significantfraction_FDR_significantfraction_FWER_significant
Measure
C21.2185501.0362290.111458354.74285722.1428570.2594100.6707930.3221250.0057290.0058790.0106400.9714290.9714290.885714
Correlation0.4714750.3397090.123223215.14285751.6666670.3066120.4691570.2017720.0168410.0173450.0266550.9047620.9047620.857143
RSS1.5140441.3272090.171176143.71428664.4285710.4737950.4845640.2532490.0418920.0440680.0680630.7857140.7857140.714286
\n", "
" ], "text/plain": [ " time_total_s time_growing_s time_pruning_s rules_count \\\n", "Measure \n", "C2 1.218550 1.036229 0.111458 35 \n", "Correlation 0.471475 0.339709 0.123223 21 \n", "RSS 1.514044 1.327209 0.171176 14 \n", "\n", " conditions_per_rule induced_conditions_per_rule \\\n", "Measure \n", "C2 4.742857 22.142857 \n", "Correlation 5.142857 51.666667 \n", "RSS 3.714286 64.428571 \n", "\n", " avg_rule_coverage avg_rule_precision avg_rule_quality \\\n", "Measure \n", "C2 0.259410 0.670793 0.322125 \n", "Correlation 0.306612 0.469157 0.201772 \n", "RSS 0.473795 0.484564 0.253249 \n", "\n", " pvalue FDR_pvalue FWER_pvalue fraction_significant \\\n", "Measure \n", "C2 0.005729 0.005879 0.010640 0.971429 \n", "Correlation 0.016841 0.017345 0.026655 0.904762 \n", "RSS 0.041892 0.044068 0.068063 0.785714 \n", "\n", " fraction_FDR_significant fraction_FWER_significant \n", "Measure \n", "C2 0.971429 0.885714 \n", "Correlation 0.904762 0.857143 \n", "RSS 0.785714 0.714286 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
AccuracyMAEKappaBalanced accuracyLogistic lossPrecisionSensitivitySpecificityNPVPPV...LiftF-measureFowlkes-Mallows indexFalse positiveFalse negativeTrue positiveTrue negativeRules per exampleVoting conflictsGeometric mean
Measure
C20.9326630.0673370.4336130.7096932.4270882.4270880.4529410.9664460.9616650.487342...7.4075950.4695120.92870381937723339.0793342146.00.661622
Correlation0.8273990.1726010.2466890.7299096.2211576.2211570.6176470.8421710.9690180.216049...3.2839510.3201220.8237573816510520336.4388541993.00.721224
RSS0.7883130.2116870.2074580.7253947.6299847.6299840.6529410.7978460.9702770.185309...2.8166940.2886870.7898004885911119266.6331272085.00.721766
\n", "

3 rows × 23 columns

\n", "
" ], "text/plain": [ " Accuracy MAE Kappa Balanced accuracy Logistic loss \\\n", "Measure \n", "C2 0.932663 0.067337 0.433613 0.709693 2.427088 \n", "Correlation 0.827399 0.172601 0.246689 0.729909 6.221157 \n", "RSS 0.788313 0.211687 0.207458 0.725394 7.629984 \n", "\n", " Precision Sensitivity Specificity NPV PPV ... \\\n", "Measure ... \n", "C2 2.427088 0.452941 0.966446 0.961665 0.487342 ... \n", "Correlation 6.221157 0.617647 0.842171 0.969018 0.216049 ... \n", "RSS 7.629984 0.652941 0.797846 0.970277 0.185309 ... \n", "\n", " Lift F-measure Fowlkes-Mallows index False positive \\\n", "Measure \n", "C2 7.407595 0.469512 0.928703 81 \n", "Correlation 3.283951 0.320122 0.823757 381 \n", "RSS 2.816694 0.288687 0.789800 488 \n", "\n", " False negative True positive True negative Rules per example \\\n", "Measure \n", "C2 93 77 2333 9.079334 \n", "Correlation 65 105 2033 6.438854 \n", "RSS 59 111 1926 6.633127 \n", "\n", " Voting conflicts Geometric mean \n", "Measure \n", "C2 2146.0 0.661622 \n", "Correlation 1993.0 0.721224 \n", "RSS 2085.0 0.721766 \n", "\n", "[3 rows x 23 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Confusion matrix - C2\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
0233381
19377
\n", "
" ], "text/plain": [ " 0 1\n", "0 2333 81\n", "1 93 77" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Confusion matrix - Correlation\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
02033381
165105
\n", "
" ], "text/plain": [ " 0 1\n", "0 2033 381\n", "1 65 105" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Confusion matrix - RSS\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
01926488
159111
\n", "
" ], "text/plain": [ " 0 1\n", "0 1926 488\n", "1 59 111" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from rulekit.classification import RuleClassifier\n", "from rulekit.rules import RuleSet, ClassificationRule\n", "from rulekit.params import Measures\n", "from IPython.display import display\n", "\n", "# C2\n", "clf = RuleClassifier(\n", " induction_measure=Measures.C2,\n", " pruning_measure=Measures.C2,\n", " voting_measure=Measures.C2,\n", ")\n", "clf.fit(X, y)\n", "c2_ruleset: RuleSet[ClassificationRule] = clf.model\n", "prediction, classification_metrics = clf.predict(X, return_metrics=True)\n", "\n", "prediction_metric, c2_confusion_matrix = get_prediction_metrics('C2', prediction, y, classification_metrics)\n", "model_stats = get_ruleset_stats('C2', clf.model)\n", "\n", "# Correlation\n", "clf = RuleClassifier(\n", " induction_measure=Measures.Correlation,\n", " pruning_measure=Measures.Correlation,\n", " voting_measure=Measures.Correlation,\n", ")\n", "clf.fit(X, y)\n", "corr_ruleset: RuleSet[ClassificationRule] = clf.model\n", "prediction, classification_metrics = clf.predict(X, return_metrics=True)\n", "\n", "tmp, corr_confusion_matrix = get_prediction_metrics('Correlation', prediction, y, classification_metrics)\n", "prediction_metric = pd.concat([prediction_metric, tmp])\n", "model_stats = pd.concat([model_stats, get_ruleset_stats('Correlation', clf.model)])\n", "\n", "# RSS\n", "clf = RuleClassifier(\n", " induction_measure=Measures.RSS,\n", " pruning_measure=Measures.RSS,\n", " voting_measure=Measures.RSS,\n", ")\n", "clf.fit(X, y)\n", "rss_ruleset: RuleSet[ClassificationRule] = clf.model\n", "prediction, classification_metrics = clf.predict(X, return_metrics=True)\n", "tmp, rss_confusion_matrix = get_prediction_metrics('RSS', prediction, y, classification_metrics)\n", "prediction_metric = pd.concat([prediction_metric, tmp])\n", "model_stats = pd.concat([model_stats, get_ruleset_stats('RSS', clf.model)])\n", "\n", 
"display(model_stats)\n", "display(prediction_metric)\n", "\n", "print('Confusion matrix - C2')\n", "display(pd.DataFrame(c2_confusion_matrix))\n", "\n", "print('Confusion matrix - Correlation')\n", "display(pd.DataFrame(corr_confusion_matrix))\n", "\n", "print('Confusion matrix - RSS')\n", "display(pd.DataFrame(rss_confusion_matrix))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### C2 Measure generated rules" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF gimpuls = (-inf, 32.50) THEN class = {0}\n", "IF nbumps = (-inf, 0.50) AND goenergy = <-84.50, inf) AND goimpuls = (-inf, -0.50) AND genergy = (-inf, 13675) THEN class = {0}\n", "IF genergy = (-inf, 17640) AND nbumps = (-inf, 0.50) AND goenergy = <-84.50, inf) THEN class = {0}\n", "IF genergy = <1625, 17640) AND maxenergy = (-inf, 3500) AND gimpuls = (-inf, 772.50) THEN class = {0}\n", "IF shift = {N} AND nbumps = (-inf, 0.50) AND goenergy = <-73.50, inf) THEN class = {0}\n", "IF shift = {N} AND senergy = (-inf, 6150) AND genergy = <1865, inf) AND goimpuls = (-inf, 230.50) THEN class = {0}\n", "IF senergy = (-inf, 550) AND gimpuls = (-inf, 380.50) AND goimpuls = (-inf, 96.50) AND goenergy = (-inf, 118) THEN class = {0}\n", "IF senergy = (-inf, 550) AND genergy = (-inf, 31790) AND goenergy = <-84.50, 114.50) THEN class = {0}\n", "IF senergy = (-inf, 550) AND goenergy = <-84.50, 87.50) AND gimpuls = (-inf, 1342.50) AND goimpuls = (-inf, 96) THEN class = {0}\n", "IF senergy = (-inf, 550) AND goimpuls = (-inf, 233.50) THEN class = {0}\n", "IF genergy = <1865, 28515) AND goenergy = (-inf, 105.50) AND nbumps = (-inf, 4.50) THEN class = {0}\n", "IF nbumps = <0.50, 1.50) AND gimpuls = (-inf, 1210) AND goimpuls = (-inf, 233.50) AND goenergy = <-72.50, inf) AND genergy = <12550, inf) THEN class = {0}\n", "IF gimpuls = (-inf, 514.50) AND nbumps = (-inf, 6.50) AND goimpuls = (-inf, 96.50) AND goenergy = <-84.50, inf) 
THEN class = {0}\n", "IF nbumps = (-inf, 2.50) AND gimpuls = (-inf, 1832.50) AND goimpuls = (-inf, 312) THEN class = {0}\n", "IF nbumps3 = (-inf, 2.50) AND nbumps = (-inf, 6.50) AND goenergy = <-88.50, inf) THEN class = {0}\n", "IF genergy = (-inf, 748755) AND goimpuls = (-inf, 95.50) AND maxenergy = (-inf, 55000) AND gimpuls = (-inf, 3096) THEN class = {0}\n", "IF nbumps3 = <3.50, inf) AND gimpuls = <364, 1459.50) AND senergy = <10150, inf) THEN class = {1}\n", "IF gimpuls = <2208.50, inf) AND genergy = <513615, 1005720) AND nbumps2 = <0.50, inf) AND nbumps = (-inf, 3.50) THEN class = {1}\n", "IF gimpuls = <1328, inf) AND goenergy = (-inf, -29.50) AND goimpuls = <-31.50, -14.50) AND nbumps4 = (-inf, 1.50) THEN class = {1}\n", "IF gimpuls = <1328, 2109) AND goimpuls = (-inf, -5.50) AND senergy = <350, 36350) AND genergy = <159155, 586025) AND nbumps2 = (-inf, 3.50) AND goenergy = <-41.50, inf) THEN class = {1}\n", "IF nbumps3 = <0.50, 1.50) AND gimpuls = <1408, 1959) AND goimpuls = <-20.50, 13.50) AND senergy = (-inf, 54950) THEN class = {1}\n", "IF senergy = <750, 38250) AND genergy = <254130, 1133675) AND goenergy = <-16.50, inf) AND gimpuls = <1438.50, inf) AND goimpuls = <-5, inf) THEN class = {1}\n", "IF nbumps = <4.50, inf) AND nbumps3 = <1.50, 4.50) AND gimpuls = <203.50, inf) AND senergy = <4300, 131700) AND goenergy = <-41.50, inf) THEN class = {1}\n", "IF nbumps = <2.50, 4.50) AND gimpuls = <740.50, inf) AND genergy = <38935, 127440) AND senergy = (-inf, 14750) AND goimpuls = (-inf, 68.50) THEN class = {1}\n", "IF nbumps = <2.50, inf) AND gimpuls = <379, 1742) AND senergy = (-inf, 31100) AND genergy = (-inf, 211170) AND goenergy = (-inf, 123.50) AND goimpuls = (-inf, 19.50) THEN class = {1}\n", "IF gimpuls = <1139.50, inf) AND goimpuls = <-46, 116.50) AND senergy = (-inf, 38250) AND genergy = <46580, 1877915) AND goenergy = (-inf, 183) AND shift = {W} AND nbumps3 = (-inf, 1.50) AND nbumps2 = (-inf, 2.50) THEN class = {1}\n", "IF nbumps = <1.50, 3.50) AND 
gimpuls = <521.50, 2344.50) AND nbumps2 = <0.50, inf) AND genergy = <34605, 656965) AND goenergy = (-inf, 137) AND goimpuls = <-39, 41.50) AND maxenergy = <450, inf) THEN class = {1}\n", "IF nbumps = <1.50, 3.50) AND nbumps2 = <0.50, inf) AND genergy = <18870, inf) AND nbumps3 = (-inf, 1.50) AND gimpuls = <160, inf) AND senergy = <550, inf) AND goimpuls = <-62.50, 8.50) AND goenergy = (-inf, -1.50) THEN class = {1}\n", "IF nbumps = <1.50, inf) AND gimpuls = <95, 1603.50) AND goenergy = (-inf, 131) AND goimpuls = <-70.50, 119) AND nbumps2 = (-inf, 4.50) AND nbumps3 = <0.50, inf) AND genergy = (-inf, 614380) AND maxenergy = (-inf, 25000) AND senergy = <2250, inf) THEN class = {1}\n", "IF goenergy = <-59.50, 186) AND genergy = <12415, 129940) AND gimpuls = <121.50, 793) AND senergy = <150, 1350) AND ghazard = {a} AND goimpuls = <-53.50, inf) THEN class = {1}\n", "IF genergy = <42215, 94300) AND gimpuls = <133.50, 813.50) AND ghazard = {a} AND goenergy = <-74.50, 160) AND senergy = (-inf, 11100) AND nbumps = (-inf, 3.50) AND nbumps3 = (-inf, 0.50) THEN class = {1}\n", "IF gimpuls = <537.50, 796) AND shift = {W} AND genergy = <17635, 36470) AND goimpuls = <-36.50, inf) AND goenergy = <-37.50, inf) AND nbumps = (-inf, 0.50) THEN class = {1}\n", "IF genergy = <18800, 52205) AND shift = {W} AND ghazard = {a} AND goimpuls = <-28.50, inf) AND goenergy = (-inf, 181) AND gimpuls = <380.50, 524.50) AND nbumps = (-inf, 0.50) THEN class = {1}\n", "IF gimpuls = <184.50, inf) AND goenergy = <-55.50, 128.50) AND genergy = <7265, inf) AND goimpuls = <-60.50, 37.50) AND nbumps = (-inf, 7.50) AND nbumps2 = (-inf, 4.50) AND maxenergy = (-inf, 25000) AND senergy = (-inf, 31350) THEN class = {1}\n", "IF gimpuls = <32.50, inf) AND goimpuls = <-74.50, inf) AND ghazard = {a} AND genergy = <1510, inf) AND goenergy = <-89.50, 124.50) THEN class = {1}\n" ] } ], "source": [ "for rule in c2_ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 
Correlation Measure generated rules" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF nbumps = (-inf, 1.50) AND gimpuls = (-inf, 1252.50) THEN class = {0}\n", "IF nbumps = (-inf, 2.50) AND gimpuls = (-inf, 1331) AND goimpuls = (-inf, 312) AND nbumps5 = (-inf, 0.50) THEN class = {0}\n", "IF nbumps = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps3 = (-inf, 2.50) AND nbumps = (-inf, 7) THEN class = {0}\n", "IF nbumps4 = (-inf, 2.50) THEN class = {0}\n", "IF nbumps2 = <0.50, 2.50) AND maxenergy = <1500, inf) AND senergy = (-inf, 36050) AND nbumps3 = <0.50, 4.50) AND goimpuls = <-34, 95) AND genergy = (-inf, 662435) AND gimpuls = <994.50, 1959) THEN class = {1}\n", "IF nbumps2 = <0.50, inf) AND maxenergy = <1500, inf) AND goimpuls = <-55, 95) AND nbumps = (-inf, 6.50) AND genergy = <61250, 662435) AND goenergy = (-inf, 96) AND nbumps3 = <0.50, inf) AND gimpuls = <712, 2257.50) AND senergy = (-inf, 31100) THEN class = {1}\n", "IF nbumps2 = <0.50, inf) AND genergy = <58310, 934630) AND goenergy = (-inf, 186) AND senergy = (-inf, 40650) AND maxenergy = <1500, inf) AND gimpuls = <538.50, inf) AND goimpuls = <-55, inf) THEN class = {1}\n", "IF nbumps = <1.50, inf) AND nbumps2 = <0.50, inf) AND gimpuls = <521.50, 2374) AND genergy = <58310, 799855) AND senergy = <650, 36050) AND goimpuls = <-71, 58.50) AND ghazard = {a} THEN class = {1}\n", "IF nbumps = <1.50, 6.50) AND nbumps2 = <0.50, inf) AND gimpuls = <521.50, 2374) AND genergy = <34360, inf) AND maxenergy = <350, inf) AND goimpuls = <-55, 95) AND senergy = <550, inf) AND nbumps4 = (-inf, 1.50) THEN class = {1}\n", "IF nbumps = <1.50, inf) AND gimpuls = <306, inf) AND genergy = <28325, inf) AND goimpuls = (-inf, 19.50) THEN class = {1}\n", "IF nbumps = <1.50, inf) AND nbumps2 = <0.50, inf) AND gimpuls = <153.50, 321) AND genergy = <14295, 36250) AND goimpuls = <-60.50, inf) AND senergy = (-inf, 40500) AND 
nbumps3 = (-inf, 3.50) THEN class = {1}\n", "IF genergy = <96260, 1062020) AND goimpuls = <-29, inf) AND senergy = <850, 7500) AND nbumps3 = (-inf, 1.50) AND gimpuls = <1404, 2965.50) AND nbumps = (-inf, 3.50) AND goenergy = (-inf, 69.50) THEN class = {1}\n", "IF gimpuls = <1253.50, inf) AND goenergy = <-50.50, 131.50) AND genergy = <46580, 1789250) AND nbumps = (-inf, 7.50) AND shift = {W} AND goimpuls = <-60.50, 118) AND senergy = (-inf, 95850) AND ghazard = {a} THEN class = {1}\n", "IF senergy = <550, inf) AND shift = {W} AND genergy = <10495, inf) AND gimpuls = <160, inf) AND goenergy = (-inf, 126) THEN class = {1}\n", "IF senergy = <350, inf) AND goimpuls = <-74.50, inf) AND gimpuls = <32.50, inf) AND goenergy = <-78.50, inf) AND maxenergy = <250, inf) THEN class = {1}\n", "IF genergy = <43150, inf) AND gimpuls = <133.50, inf) AND goenergy = (-inf, 176.50) THEN class = {1}\n", "IF shift = {W} AND genergy = <31760, 49585) AND gimpuls = <362.50, 771) AND goimpuls = <-27.50, inf) AND goenergy = <-3.50, inf) AND maxenergy = (-inf, 650) THEN class = {1}\n", "IF shift = {W} AND genergy = <20485, 43280) AND gimpuls = <380.50, 796) AND goimpuls = <-37, 142.50) AND goenergy = <-37.50, 181) AND nbumps = (-inf, 0.50) THEN class = {1}\n", "IF gimpuls = <177.50, inf) AND genergy = <7265, inf) AND goimpuls = (-inf, 241.50) AND goenergy = (-inf, 124.50) THEN class = {1}\n", "IF gimpuls = <54.50, 90) AND genergy = <1510, 4905) AND goimpuls = <-72.50, 28.50) AND seismoacoustic = {a} THEN class = {1}\n" ] } ], "source": [ "for rule in corr_ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RSS Measure generated rules" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF nbumps = (-inf, 1.50) AND genergy = (-inf, 126350) THEN class = {0}\n", "IF nbumps = (-inf, 1.50) AND gimpuls = (-inf, 2168) AND goimpuls = (-inf, 96.50) THEN class = {0}\n", "IF 
genergy = (-inf, 44750) AND nbumps3 = (-inf, 2.50) AND goenergy = (-inf, 105.50) THEN class = {0}\n", "IF gimpuls = (-inf, 725.50) AND nbumps3 = (-inf, 3.50) AND nbumps4 = (-inf, 2.50) AND goimpuls = (-inf, 117) AND goenergy = <-88.50, inf) THEN class = {0}\n", "IF nbumps2 = (-inf, 1.50) AND nbumps = (-inf, 4.50) THEN class = {0}\n", "IF goimpuls = (-inf, 312) AND nbumps5 = (-inf, 0.50) AND goenergy = <-88.50, inf) THEN class = {0}\n", "IF gimpuls = <521.50, inf) AND genergy = <57680, inf) THEN class = {1}\n", "IF nbumps = <1.50, inf) THEN class = {1}\n", "IF senergy = <550, inf) AND shift = {W} AND genergy = <10495, inf) THEN class = {1}\n", "IF nbumps = <0.50, inf) AND goimpuls = <-74.50, inf) AND gimpuls = <32.50, inf) AND goenergy = <-78.50, 124.50) THEN class = {1}\n", "IF genergy = <34315, 49585) AND ghazard = {a} AND gimpuls = <396, 1445.50) AND goenergy = <7, inf) AND goimpuls = <-19, inf) AND senergy = (-inf, 350) THEN class = {1}\n", "IF genergy = <26200, 78890) AND gimpuls = <133.50, 813.50) AND goenergy = <-74.50, 297.50) AND goimpuls = <-71, inf) AND nbumps = (-inf, 3.50) AND senergy = (-inf, 1850) THEN class = {1}\n", "IF genergy = <18585, 25305) AND shift = {W} AND gimpuls = <240, 588.50) AND goimpuls = <-42.50, 133) AND goenergy = <-45.50, inf) AND senergy = (-inf, 450) THEN class = {1}\n", "IF gimpuls = <54.50, inf) AND goimpuls = <-74.50, 28.50) AND genergy = <1510, inf) AND ghazard = {a} AND nbumps4 = (-inf, 1.50) AND senergy = (-inf, 92850) THEN class = {1}\n" ] } ], "source": [ "for rule in rss_ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stratified K-Folds cross-validation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
"from IPython.display import display\n",
"from sklearn.model_selection import StratifiedKFold\n",
"\n",
"N_SPLITS: int = 10\n",
"\n",
"skf = StratifiedKFold(n_splits=N_SPLITS)"
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
"fold_prediction_metrics: list[pd.DataFrame] = []\n",
"fold_ruleset_stats: list[pd.DataFrame] = []\n",
"c2_confusion_matrix = np.zeros((2, 2))\n",
"\n",
"for train_index, test_index in skf.split(X, y):\n",
"    x_train, x_test = X.iloc[train_index], X.iloc[test_index]\n",
"    y_train, y_test = y.iloc[train_index], y.iloc[test_index]\n",
"\n",
"    clf = RuleClassifier(\n",
"        induction_measure=Measures.C2,\n",
"        pruning_measure=Measures.C2,\n",
"        voting_measure=Measures.C2,\n",
"    )\n",
"    clf.fit(x_train, y_train)\n",
"    # separate name, so the full-dataset `c2_ruleset` induced earlier is not overwritten\n",
"    fold_ruleset = clf.model\n",
"    prediction, classification_metrics = clf.predict(x_test, return_metrics=True)\n",
"    fold_metrics, fold_confusion_matrix = get_prediction_metrics(\n",
"        'C2', prediction, y_test, classification_metrics\n",
"    )\n",
"\n",
"    fold_prediction_metrics.append(fold_metrics)\n",
"    fold_ruleset_stats.append(get_ruleset_stats('C2', fold_ruleset))\n",
"    c2_confusion_matrix += fold_confusion_matrix\n",
"\n",
"# concatenate once instead of growing the frames inside the loop\n",
"c2_prediction_metrics = pd.concat(fold_prediction_metrics)\n",
"c2_ruleset_stats = pd.concat(fold_ruleset_stats)\n",
"# mean confusion matrix; the fold count comes from the splitter so it always\n",
"# matches the number of folds actually run\n",
"c2_confusion_matrix /= skf.get_n_splits()"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ruleset statistics (mean over folds)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {},
"outputs": [ { "data": { "text/plain": [ "time_total_s 0.292132\n", "time_growing_s 0.227262\n", "time_pruning_s 0.047455\n", "rules_count 34.200000\n", "conditions_per_rule 4.720378\n", "induced_conditions_per_rule 21.124536\n", "avg_rule_coverage 0.239541\n", "avg_rule_precision 0.690010\n", "avg_rule_quality 0.337021\n", "pvalue 0.014757\n", "FDR_pvalue 0.015234\n", "FWER_pvalue 0.030014\n", "fraction_significant 0.909792\n", "fraction_FDR_significant 0.909792\n", "fraction_FWER_significant 0.872710\n", "dtype: float64" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(c2_ruleset_stats.mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rules evaluation (average)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Accuracy 0.855063\n", "MAE 0.144937\n", "Kappa 0.107556\n", "Balanced accuracy 0.567020\n", "Logistic loss 5.224070\n", "Precision 5.224070\n", "Sensitivity 0.235294\n", "Specificity 0.898747\n", "NPV 0.945594\n", "PPV 0.495355\n", "psep 0.443788\n", "Fall-out 0.101253\n", "Youden's J statistic 0.134041\n", "Lift 7.520255\n", "F-measure 0.145015\n", "Fowlkes-Mallows index 0.870605\n", "False positive 24.500000\n", "False negative 13.000000\n", "True positive 4.000000\n", "True negative 216.900000\n", "Rules per example 7.871479\n", "Voting conflicts 179.000000\n", "Geometric mean 0.344832\n", "dtype: float64" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(c2_prediction_metrics.mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Confusion matrix (average)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
0216.924.5
113.04.0
\n", "
" ], "text/plain": [ " 0 1\n", "0 216.9 24.5\n", "1 13.0 4.0" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(pd.DataFrame(c2_confusion_matrix))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameters tuning\n", "\n", "This one gonna take a while..." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best BAC: 0.626780 using {'induction_measure': , 'minsupp_new': 5, 'pruning_measure': , 'voting_measure': }\n" ] } ], "source": [ "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.model_selection import GridSearchCV\n", "from rulekit.params import Measures\n", "\n", "N_SPLITS: int = 3\n", "\n", "# define models and parameters\n", "model = RuleClassifier()\n", "minsupp_new = range(3, 15, 2)\n", "measures_choice = [Measures.C2, Measures.RSS, Measures.WeightedLaplace, Measures.Correlation]\n", "\n", "# define grid search\n", "grid = {\n", " 'minsupp_new': minsupp_new, \n", " 'induction_measure': measures_choice, \n", " 'pruning_measure': measures_choice, \n", " 'voting_measure': measures_choice\n", "}\n", "cv = StratifiedKFold(n_splits=N_SPLITS)\n", "grid_search = GridSearchCV(\n", " estimator=model, \n", " param_grid=grid, \n", " cv=cv, \n", " scoring='balanced_accuracy', \n", " n_jobs=3\n", ")\n", "grid_result = grid_search.fit(X, y)\n", "\n", "# summarize results\n", "print(\"Best BAC: %f using %s\" % (grid_result.best_score_, grid_result.best_params_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Building model with tuned hyperparameters\n", "\n", "### Split dataset to train and test (80%/20%)." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from IPython.display import display\n", "\n", "# Fixed seed so the split (and every metric reported below) survives Restart & Run All.\n", "SPLIT_RANDOM_STATE: int = 42\n", "\n", "# Stratify on the label: the positive class is rare, so an unstratified 20% hold-out\n", "# can end up with a noticeably different class ratio than the training data.\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.2, shuffle=True, stratify=y, random_state=SPLIT_RANDOM_STATE\n", ")\n", "\n", "# Refit on the training split with the best hyperparameters found by the grid search.\n", "clf = RuleClassifier(**grid_result.best_params_)\n", "clf.fit(X_train, y_train)\n", "ruleset: RuleSet[ClassificationRule] = clf.model\n", "ruleset_stats = get_ruleset_stats('Best', ruleset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rules characteristics" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "time_total_s 0.173054\n", "time_growing_s 0.120460\n", "time_pruning_s 0.029986\n", "rules_count 29.000000\n", "conditions_per_rule 2.689655\n", "induced_conditions_per_rule 15.310345\n", "avg_rule_coverage 0.491183\n", "avg_rule_precision 0.736226\n", "avg_rule_quality 1.309334\n", "pvalue 0.019831\n", "FDR_pvalue 0.019993\n", "FWER_pvalue 0.024284\n", "fraction_significant 0.931034\n", "fraction_FDR_significant 0.931034\n", "fraction_FWER_significant 0.931034\n", "dtype: float64" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(ruleset_stats.mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Validate model on test dataset" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Accuracy 0.808511\n", "MAE 0.191489\n", "Kappa 0.170010\n", "Balanced accuracy 0.679398\n", "Logistic loss 6.901976\n", "Precision 6.901976\n", "Sensitivity 0.533333\n", "Specificity 0.825462\n", "NPV 0.966346\n", "PPV 0.158416\n", "psep 0.124762\n", "Fall-out 0.174538\n", "Youden's J statistic 0.358795\n", "Lift 2.730033\n", "F-measure 0.244275\n", "Fowlkes-Mallows index 0.809997\n", "False positive 85.000000\n", "False negative 14.000000\n", "True positive 
16.000000\n", "True negative 402.000000\n", "Rules per example 14.034816\n", "Voting conflicts 360.000000\n", "Geometric mean 0.663511\n", "dtype: float64" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01
040285
11416
\n", "
" ], "text/plain": [ " 0 1\n", "0 402 85\n", "1 14 16" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "prediction, classification_metrics = clf.predict(X_test, return_metrics=True)\n", "prediction_metrics, confusion_matrix = get_prediction_metrics('Best', prediction, y_test, classification_metrics)\n", "\n", "display(prediction_metrics.mean())\n", "display(pd.DataFrame(confusion_matrix))" ] } ], "metadata": { "kernelspec": { "display_name": "tutorials_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.0" }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }