{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook presents example usage of the RuleKit package for solving a regression problem on the `methane` dataset. You can download the training dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/methane/methane-train.arff) and the test dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/methane/methane-test.arff).\n", "\n", "This tutorial will cover topics such as: \n", "- training a model \n", "- changing model hyperparameters \n", "- hyperparameter tuning \n", "- calculating metrics for a model \n", "- getting RuleKit's built-in ruleset statistics " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary of the dataset" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from scipy.io import arff\n", "import pandas as pd\n", "\n", "train_file_name = \"methane-train.arff\"\n", "test_file_name = \"methane-test.arff\"\n", "\n", "train_df = pd.DataFrame(arff.loadarff(train_file_name)[0])\n", "test_df = pd.DataFrame(arff.loadarff(test_file_name)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train file" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train file overview:\n", "Name: methane-train.arff\n", "Objects number: 13368; Attributes number: 8\n", "Basic attribute statistics:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MM31MM116AS038PG072PDBA13DMM116MM116_pred
count13368.00000013368.00000013368.00000013368.00000013368.00000013368.00000013368.00000013368.00000
mean0.3639600.7750072.2947341.8356000.3085731073.443372-0.0000070.79825
std0.1171050.2693660.1425040.1066810.4619223.1628110.0435660.28649
min0.1700000.2000001.4000001.1000000.0000001067.000000-1.8000000.20000
25%0.2600000.5000002.3000001.8000000.0000001070.0000000.0000000.50000
50%0.3600000.8000002.3000001.8000000.0000001075.0000000.0000000.80000
75%0.4500001.0000002.4000001.9000001.0000001076.0000000.0000001.00000
max0.8200002.2000002.7000002.6000001.0000001078.0000000.8000002.20000
\n", "
" ], "text/plain": [ " MM31 MM116 AS038 PG072 PD \\\n", "count 13368.000000 13368.000000 13368.000000 13368.000000 13368.000000 \n", "mean 0.363960 0.775007 2.294734 1.835600 0.308573 \n", "std 0.117105 0.269366 0.142504 0.106681 0.461922 \n", "min 0.170000 0.200000 1.400000 1.100000 0.000000 \n", "25% 0.260000 0.500000 2.300000 1.800000 0.000000 \n", "50% 0.360000 0.800000 2.300000 1.800000 0.000000 \n", "75% 0.450000 1.000000 2.400000 1.900000 1.000000 \n", "max 0.820000 2.200000 2.700000 2.600000 1.000000 \n", "\n", " BA13 DMM116 MM116_pred \n", "count 13368.000000 13368.000000 13368.00000 \n", "mean 1073.443372 -0.000007 0.79825 \n", "std 3.162811 0.043566 0.28649 \n", "min 1067.000000 -1.800000 0.20000 \n", "25% 1070.000000 0.000000 0.50000 \n", "50% 1075.000000 0.000000 0.80000 \n", "75% 1076.000000 0.000000 1.00000 \n", "max 1078.000000 0.800000 2.20000 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Train file overview:\")\n", "print(f\"Name: {train_file_name}\")\n", "print(f\"Objects number: {train_df.shape[0]}; Attributes number: {train_df.shape[1]}\")\n", "print(\"Basic attribute statistics:\")\n", "train_df.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test file" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Test file overview:\n", "Name: methane-test.arff\n", "Objects number: 5728; Attributes number: 8\n", "Basic attribute statistics:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MM31MM116AS038PG072PDBA13DMM116MM116_pred
count5728.0000005728.0000005728.0000005728.0000005728.0000005728.0000005728.0000005728.000000
mean0.5566521.0069132.2366271.8192390.5384081072.691690-0.0000171.042458
std0.1146820.1679830.1049130.0788650.4985662.7995590.0468490.171393
min0.3500000.5000001.8000001.6000000.0000001067.000000-0.4000000.600000
25%0.4600000.9000002.2000001.8000000.0000001071.0000000.0000000.900000
50%0.5500001.0000002.2000001.8000001.0000001073.0000000.0000001.000000
75%0.6400001.1000002.3000001.9000001.0000001075.0000000.0000001.200000
max0.9800001.6000002.7000002.1000001.0000001078.0000000.3000001.600000
\n", "
" ], "text/plain": [ " MM31 MM116 AS038 PG072 PD \\\n", "count 5728.000000 5728.000000 5728.000000 5728.000000 5728.000000 \n", "mean 0.556652 1.006913 2.236627 1.819239 0.538408 \n", "std 0.114682 0.167983 0.104913 0.078865 0.498566 \n", "min 0.350000 0.500000 1.800000 1.600000 0.000000 \n", "25% 0.460000 0.900000 2.200000 1.800000 0.000000 \n", "50% 0.550000 1.000000 2.200000 1.800000 1.000000 \n", "75% 0.640000 1.100000 2.300000 1.900000 1.000000 \n", "max 0.980000 1.600000 2.700000 2.100000 1.000000 \n", "\n", " BA13 DMM116 MM116_pred \n", "count 5728.000000 5728.000000 5728.000000 \n", "mean 1072.691690 -0.000017 1.042458 \n", "std 2.799559 0.046849 0.171393 \n", "min 1067.000000 -0.400000 0.600000 \n", "25% 1071.000000 0.000000 0.900000 \n", "50% 1073.000000 0.000000 1.000000 \n", "75% 1075.000000 0.000000 1.200000 \n", "max 1078.000000 0.300000 1.600000 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test file\n", "print(\"\\nTest file overview:\")\n", "print(f\"Name: {test_file_name}\")\n", "print(f\"Objects number: {test_df.shape[0]}; Attributes number: {test_df.shape[1]}\")\n", "print(\"Basic attribute statistics:\")\n", "test_df.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import RuleKit" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from rulekit import RuleKit\n", "from rulekit.regression import RuleRegressor\n", "from rulekit.params import Measures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper function for calculating metrics" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import sklearn.tree as scikit\n", "import math\n", "from sklearn.preprocessing import MultiLabelBinarizer\n", "from sklearn import metrics\n", "import pandas as pd\n", "import numpy as np\n", "from typing import Tuple\n", "from math import sqrt\n", "\n", "\n", "def 
get_regression_metrics(measure: str, y_pred, y_true) -> pd.DataFrame:\n",
"    \"\"\"Compute regression quality metrics for a single model.\n",
"\n",
"    Args:\n",
"        measure: label used as the row index of the returned frame.\n",
"        y_pred: predicted values (array-like, same length as y_true).\n",
"        y_true: ground-truth values (array-like). Relative errors divide by\n",
"            these values, so they are expected to be non-zero.\n",
"\n",
"    Returns:\n",
"        One-row DataFrame indexed by `measure`, one column per metric.\n",
"    \"\"\"\n",
"    # Work on positional float arrays so the result does not depend on the\n",
"    # pandas index of the inputs (label-based y_true[i] breaks on non-default indexes).\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"\n",
"    errors = np.abs(y_true - y_pred)\n",
"    relative_errors = errors / np.abs(y_true)\n",
"    relative_error_lenient = np.mean(errors / np.abs(np.maximum(y_true, y_pred)))\n",
"    relative_error_strict = np.mean(errors / np.abs(np.minimum(y_true, y_pred)))\n",
"    # Mean absolute deviation of the labels from their mean; normalises MAE\n",
"    # against the baseline of always predicting the mean.\n",
"    nae_denominator = np.mean(np.abs(y_true - y_true.mean()))\n",
"\n",
"    absolute_error = metrics.mean_absolute_error(y_true, y_pred)\n",
"    squared_error = metrics.mean_squared_error(y_true, y_pred)\n",
"    # np.corrcoef returns the 2x2 correlation matrix; Pearson's r is the\n",
"    # off-diagonal entry. Averaging the whole matrix mixes in the diagonal 1s\n",
"    # and reports (1 + r) / 2 instead of r.\n",
"    correlation = np.corrcoef(y_true, y_pred)[0, 1]\n",
"\n",
"    dictionary = {\n",
"        'Measure': measure,\n",
"        'absolute_error': absolute_error,\n",
"        'relative_error': np.mean(relative_errors),\n",
"        'relative_error_lenient': relative_error_lenient,\n",
"        'relative_error_strict': relative_error_strict,\n",
"        'normalized_absolute_error': absolute_error / nae_denominator,\n",
"        'squared_error': squared_error,\n",
"        # sqrt of the MSE instead of `squared=False`, which newer\n",
"        # scikit-learn releases deprecate and then remove.\n",
"        'root_mean_squared_error': np.sqrt(squared_error),\n",
"        # NOTE(review): this is the root of the mean squared *relative* error\n",
"        # (unchanged from before); confirm it matches RuleKit's definition.\n",
"        'root_relative_squared_error': np.sqrt(np.mean(relative_errors ** 2)),\n",
"        'correlation': correlation,\n",
"        'squared_correlation': correlation ** 2,\n",
"    }\n",
"    return pd.DataFrame.from_records([dictionary], index='Measure')\n",
"\n",
"\n",
"def get_ruleset_stats(measure: str, model) -> pd.DataFrame:\n",
"    \"\"\"Collect a ruleset's induction parameters and statistics into one row.\n",
"\n",
"    Args:\n",
"        measure: label used as the row index of the returned frame.\n",
"        model: trained RuleKit ruleset (e.g. `RuleRegressor.model`); the\n",
"            attributes of its `parameters` and `stats` objects become columns.\n",
"\n",
"    Returns:\n",
"        One-row DataFrame indexed by `measure`.\n",
"    \"\"\"\n",
"    # Build a filtered copy rather than deleting from `parameters.__dict__`:\n",
"    # `del` would strip `_java_object` from the live model and make a second\n",
"    # call on the same model raise KeyError.\n",
"    params = {key: value for key, value in vars(model.parameters).items() if key != '_java_object'}\n",
"    return pd.DataFrame.from_records([{'Measure': measure, **params, **vars(model.stats)}], index='Measure')" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Rule induction on training dataset" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "X_train = train_df.drop(['MM116_pred'], axis=1)\n", "y_train = train_df['MM116_pred']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
minimum_coveredmaximum_uncovered_fractionignore_missingpruning_enabledmax_growing_conditiontime_total_stime_growing_stime_pruning_srules_countconditions_per_ruleinduced_conditions_per_ruleavg_rule_coverageavg_rule_precisionavg_rule_qualitypvalueFDR_pvalueFWER_pvaluefraction_significantfraction_FDR_significantfraction_FWER_significant
Measure
C25.00.0FalseTrue0.010.7143001.5810148.724361283.10714324.2857140.1692340.9432780.8089540.0015810.0015810.0015811.0000001.0000001.000000
Correlation5.00.0FalseTrue0.050.2735512.40494647.560259212.38095236.3333330.2380060.853824NaN0.0474860.0474860.0474870.9523810.9523810.952381
RSS5.00.0FalseTrue0.03.5787590.2408473.04861283.50000022.3750000.2521130.738030NaN0.0000040.0000040.0000041.0000001.0000001.000000
\n", "
" ], "text/plain": [ " minimum_covered maximum_uncovered_fraction ignore_missing \\\n", "Measure \n", "C2 5.0 0.0 False \n", "Correlation 5.0 0.0 False \n", "RSS 5.0 0.0 False \n", "\n", " pruning_enabled max_growing_condition time_total_s \\\n", "Measure \n", "C2 True 0.0 10.714300 \n", "Correlation True 0.0 50.273551 \n", "RSS True 0.0 3.578759 \n", "\n", " time_growing_s time_pruning_s rules_count conditions_per_rule \\\n", "Measure \n", "C2 1.581014 8.724361 28 3.107143 \n", "Correlation 2.404946 47.560259 21 2.380952 \n", "RSS 0.240847 3.048612 8 3.500000 \n", "\n", " induced_conditions_per_rule avg_rule_coverage \\\n", "Measure \n", "C2 24.285714 0.169234 \n", "Correlation 36.333333 0.238006 \n", "RSS 22.375000 0.252113 \n", "\n", " avg_rule_precision avg_rule_quality pvalue FDR_pvalue \\\n", "Measure \n", "C2 0.943278 0.808954 0.001581 0.001581 \n", "Correlation 0.853824 NaN 0.047486 0.047486 \n", "RSS 0.738030 NaN 0.000004 0.000004 \n", "\n", " FWER_pvalue fraction_significant fraction_FDR_significant \\\n", "Measure \n", "C2 0.001581 1.000000 1.000000 \n", "Correlation 0.047487 0.952381 0.952381 \n", "RSS 0.000004 1.000000 1.000000 \n", "\n", " fraction_FWER_significant \n", "Measure \n", "C2 1.000000 \n", "Correlation 0.952381 \n", "RSS 1.000000 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
absolute_errorrelative_errorrelative_error_lenientrelative_error_strictnormalized_absolute_errorsquared_errorroot_mean_squared_errorroot_relative_squared_errorcorrelationsquared_correlation
Measure
C20.0852930.1000460.0930690.1113290.3629670.0160600.1267300.1322210.9495590.901663
Correlation0.0721640.0836870.0783680.0919830.3070950.0131190.1145380.1240880.9645930.930439
RSS0.1426180.1777150.1506290.2052150.6069160.0408550.2021250.2481990.8626620.744186
\n", "
" ], "text/plain": [ " absolute_error relative_error relative_error_lenient \\\n", "Measure \n", "C2 0.085293 0.100046 0.093069 \n", "Correlation 0.072164 0.083687 0.078368 \n", "RSS 0.142618 0.177715 0.150629 \n", "\n", " relative_error_strict normalized_absolute_error squared_error \\\n", "Measure \n", "C2 0.111329 0.362967 0.016060 \n", "Correlation 0.091983 0.307095 0.013119 \n", "RSS 0.205215 0.606916 0.040855 \n", "\n", " root_mean_squared_error root_relative_squared_error \\\n", "Measure \n", "C2 0.126730 0.132221 \n", "Correlation 0.114538 0.124088 \n", "RSS 0.202125 0.248199 \n", "\n", " correlation squared_correlation \n", "Measure \n", "C2 0.949559 0.901663 \n", "Correlation 0.964593 0.930439 \n", "RSS 0.862662 0.744186 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# C2\n", "c2_reg = RuleRegressor(\n", " induction_measure=Measures.C2,\n", " pruning_measure=Measures.C2,\n", " voting_measure=Measures.C2,\n", ")\n", "c2_reg.fit(X_train, y_train)\n", "c2_ruleset = c2_reg.model\n", "predictions = c2_reg.predict(X_train)\n", "\n", "regression_metrics = get_regression_metrics('C2', predictions, y_train)\n", "ruleset_stats = get_ruleset_stats('C2', c2_ruleset)\n", "\n", "\n", "# Correlation\n", "corr_reg = RuleRegressor(\n", " induction_measure=Measures.Correlation,\n", " pruning_measure=Measures.Correlation,\n", " voting_measure=Measures.Correlation,\n", " mean_based_regression=True\n", ")\n", "corr_reg.fit(X_train, y_train)\n", "corr_ruleset = corr_reg.model\n", "predictions = corr_reg.predict(X_train)\n", "\n", "tmp = get_regression_metrics('Correlation', predictions, y_train)\n", "regression_metrics = pd.concat([regression_metrics, tmp])\n", "ruleset_stats = pd.concat([ruleset_stats, get_ruleset_stats('Correlation', corr_ruleset)])\n", "\n", "\n", "# RSS\n", "rss_reg = RuleRegressor(\n", " induction_measure=Measures.RSS,\n", " pruning_measure=Measures.RSS,\n", " voting_measure=Measures.RSS,\n", " mean_based_regression=True\n", ")\n", 
"rss_reg.fit(X_train, y_train)\n", "rss_ruleset = rss_reg.model\n", "predictions = rss_reg.predict(X_train)\n", "\n", "tmp = get_regression_metrics('RSS', predictions, y_train)\n", "regression_metrics = pd.concat([regression_metrics, tmp])\n", "ruleset_stats = pd.concat([ruleset_stats, get_ruleset_stats('RSS', rss_ruleset)])\n", "\n", "\n", "display(ruleset_stats)\n", "display(regression_metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### C2 Measure generated rules" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\n", "IF MM116 = <0.35, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.39,0.42]\n", "IF MM116 = <0.35, 0.45) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\n", "IF MM31 = <0.24, 0.25) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.50} [0.50,0.50]\n", "IF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.45) AND MM31 = <0.19, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.42]\n", "IF MM116 = (-inf, 0.45) THEN MM116_pred = {0.40} [0.37,0.44]\n", "IF MM116 = (-inf, 0.55) AND MM31 = <0.19, 0.29) AND BA13 = <1072.50, inf) THEN MM116_pred = {0.45} [0.39,0.50]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.50} [0.48,0.53]\n", "IF MM116 = <0.45, inf) AND DMM116 = <-0.05, inf) AND MM31 = <0.23, 0.30) AND BA13 = <1073.50, 1076.50) THEN MM116_pred = {0.50} [0.48,0.53]\n", "IF MM116 = (-inf, 0.55) AND AS038 = <2.25, inf) AND MM31 = <0.29, inf) AND BA13 = <1076.50, inf) THEN MM116_pred = {0.55} [0.49,0.61]\n", "IF MM116 = (-inf, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = <0.19, 0.31) THEN MM116_pred = {0.45} [0.40,0.51]\n", "IF MM116 = (-inf, 0.55) AND DMM116 = <-0.15, inf) AND MM31 = (-inf, 0.31) THEN MM116_pred = 
{0.45} [0.39,0.51]\n", "IF MM116 = (-inf, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.32) AND BA13 = <1070.50, inf) THEN MM116_pred = {0.45} [0.39,0.51]\n", "IF MM116 = (-inf, 0.55) AND MM31 = <0.32, 0.33) AND PG072 = (-inf, 1.95) AND BA13 = <1074.50, 1076.50) THEN MM116_pred = {0.60} [0.60,0.60]\n", "IF MM116 = (-inf, 0.55) AND MM31 = (-inf, 0.34) THEN MM116_pred = {0.45} [0.39,0.52]\n", "IF MM116 = (-inf, 0.55) AND DMM116 = <-0.05, inf) THEN MM116_pred = {0.45} [0.39,0.52]\n", "IF MM116 = <0.55, 0.70) THEN MM116_pred = {0.61} [0.56,0.65]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.41} [0.37,0.45]\n", "IF MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.15) AND AS038 = <2.05, 2.45) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.41} [0.36,0.46]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.35, 0.75) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.26) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.45} [0.38,0.51]\n", "IF PD = <0.50, inf) AND MM116 = <0.55, 0.75) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.35) AND BA13 = <1074.50, inf) THEN MM116_pred = {0.70} [0.60,0.80]\n", "IF MM116 = <0.75, 0.85) THEN MM116_pred = {0.83} [0.76,0.90]\n", "IF MM116 = (-inf, 0.95) AND DMM116 = <-0.05, 0.05) AND AS038 = (-inf, 2.45) AND MM31 = <0.19, 0.26) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.45} [0.37,0.53]\n", "IF MM116 = <0.95, inf) THEN MM116_pred = {1.15} [0.97,1.32]\n", "IF MM116 = <0.45, 0.75) AND MM31 = <0.23, inf) THEN MM116_pred = {0.60} [0.49,0.71]\n", "IF PD = (-inf, 0.50) AND MM116 = (-inf, 0.95) AND DMM116 = (-inf, 0.05) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.27) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.46} [0.38,0.54]\n", "IF PD = <0.50, inf) AND MM116 = <0.45, 0.95) AND AS038 = (-inf, 2.35) AND MM31 = <0.26, 0.27) THEN MM116_pred = {0.84} [0.68,1.01]\n", "IF MM116 = <0.85, inf) THEN MM116_pred = {1.06} [0.88,1.24]\n" ] } ], "source": [ "for rule in c2_ruleset.rules:\n", " 
print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Correlation Measure generated rules" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\n", "IF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\n", "IF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\n", "IF MM31 = (-inf, 0.25) THEN MM116_pred = {0.44} [0.37,0.51]\n", "IF MM31 = (-inf, 0.26) THEN MM116_pred = {0.46} [0.36,0.55]\n", "IF MM31 = (-inf, 0.28) THEN MM116_pred = {0.49} [0.37,0.61]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.25, inf) AND DMM116 = <-0.05, 0.05) AND AS038 = <2, 2.45) AND MM31 = <0.23, inf) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.71} [0.50,0.92]\n", "IF MM116 = <0.25, 0.45) AND MM31 = <0.18, inf) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.43]\n", "IF PD = (-inf, 0.50) AND MM116 = (-inf, 0.25) AND DMM116 = <-0.05, 0.05) AND AS038 = <2.35, 2.45) AND MM31 = <0.19, inf) AND PG072 = <1.75, 1.95) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.25} [0.20,0.30]\n", "IF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = <0.18, inf) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\n", "IF MM116 = (-inf, 0.55) AND MM31 = (-inf, 0.32) THEN MM116_pred = {0.45} [0.39,0.51]\n", "IF MM116 = (-inf, 0.55) AND DMM116 = <-0.15, inf) THEN MM116_pred = {0.45} [0.39,0.52]\n", "IF MM116 = <0.45, 0.65) THEN MM116_pred = {0.55} [0.49,0.61]\n", "IF MM116 = <0.45, 0.75) AND DMM116 = <-0.15, inf) THEN MM116_pred = {0.60} [0.49,0.71]\n", "IF MM116 = <0.45, 0.85) AND DMM116 = <-0.15, inf) AND MM31 = <0.25, inf) THEN MM116_pred = {0.70} [0.56,0.84]\n", "IF MM116 = <0.70, inf) AND DMM116 = <-0.30, 0.15) THEN MM116_pred = {0.97} [0.77,1.17]\n", "IF MM116 = <1.05, 1.35) THEN MM116_pred = {1.19} [1.08,1.31]\n", "IF MM116 = 
<1.35, 1.65) AND MM31 = <0.35, inf) THEN MM116_pred = {1.48} [1.35,1.61]\n", "IF MM116 = <1.65, inf) THEN MM116_pred = {1.84} [1.44,2.24]\n", "IF MM116 = <0.85, 1.15) AND DMM116 = <-0.35, inf) THEN MM116_pred = {1.00} [0.89,1.12]\n", "IF MM116 = <0.65, 1.55) AND DMM116 = <-0.50, inf) AND PG072 = (-inf, 2.35) THEN MM116_pred = {0.97} [0.78,1.16]\n" ] } ], "source": [ "for rule in corr_ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RSS Measure generated rules" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\n", "IF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.43]\n", "IF MM31 = (-inf, 0.26) THEN MM116_pred = {0.46} [0.36,0.55]\n", "IF MM116 = <0.35, inf) AND MM31 = <0.26, inf) THEN MM116_pred = {0.91} [0.67,1.14]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.25, inf) AND DMM116 = <-0.95, 0.05) AND AS038 = <2, 2.45) AND MM31 = <0.23, inf) AND PG072 = <1.65, 2.05) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.71} [0.50,0.93]\n", "IF PD = (-inf, 0.50) AND MM116 = (-inf, 0.25) AND DMM116 = <-0.05, 0.05) AND AS038 = <2.35, 2.45) AND MM31 = <0.19, inf) AND PG072 = <1.75, 1.95) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.25} [0.20,0.30]\n", "IF MM116 = (-inf, 0.25) THEN MM116_pred = {0.23} [0.19,0.28]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.25, inf) AND AS038 = <2, 2.45) AND MM31 = <0.23, inf) AND PG072 = (-inf, 1.95) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.71} [0.50,0.93]\n" ] } ], "source": [ "for rule in rss_ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation on a test set" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "X_test = test_df.drop(['MM116_pred'], axis=1)\n", "y_test = test_df['MM116_pred']" ] 
}, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# C2\n", "c2_predictions = c2_reg.predict(X_test)\n", "c2_regression_metrics = get_regression_metrics('C2', c2_predictions, y_test)\n", "\n", "# Correlation\n", "corr_predictions = corr_reg.predict(X_test)\n", "corr_regression_metrics = get_regression_metrics('Correlation', corr_predictions, y_test)\n", "\n", "# RSS\n", "rss_predictions = rss_reg.predict(X_test)\n", "rss_regression_metrics = get_regression_metrics('RSS', rss_predictions, y_test)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
absolute_errorrelative_errorrelative_error_lenientrelative_error_strictnormalized_absolute_errorsquared_errorroot_mean_squared_errorroot_relative_squared_errorcorrelationsquared_correlation
Measure
C20.1013330.0967540.0901870.1055760.6986900.0167460.1294050.1216160.8296470.688314
Correlation0.1079440.0970080.0950770.1107300.7442760.0188600.1373330.1155500.9072980.823190
RSS0.1949680.1731260.1713320.2267071.3443020.0573230.2394220.2019520.6287340.395306
\n", "
" ], "text/plain": [ " absolute_error relative_error relative_error_lenient \\\n", "Measure \n", "C2 0.101333 0.096754 0.090187 \n", "Correlation 0.107944 0.097008 0.095077 \n", "RSS 0.194968 0.173126 0.171332 \n", "\n", " relative_error_strict normalized_absolute_error squared_error \\\n", "Measure \n", "C2 0.105576 0.698690 0.016746 \n", "Correlation 0.110730 0.744276 0.018860 \n", "RSS 0.226707 1.344302 0.057323 \n", "\n", " root_mean_squared_error root_relative_squared_error \\\n", "Measure \n", "C2 0.129405 0.121616 \n", "Correlation 0.137333 0.115550 \n", "RSS 0.239422 0.201952 \n", "\n", " correlation squared_correlation \n", "Measure \n", "C2 0.829647 0.688314 \n", "Correlation 0.907298 0.823190 \n", "RSS 0.628734 0.395306 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(pd.concat([c2_regression_metrics, corr_regression_metrics, rss_regression_metrics]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter tuning\n", "\n", "This step takes a while: the grid search below fits a model for every parameter combination in every cross-validation fold (162 fits in total)." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 3 folds for each of 54 candidates, totalling 162 fits\n", "Best RMSE: -0.128827 using {'induction_measure': , 'minsupp_new': 6, 'pruning_measure': , 'voting_measure': }\n" ] } ], "source": [ "from sklearn.model_selection import KFold\n", "from sklearn.model_selection import GridSearchCV\n", "from rulekit.params import Measures\n", "\n", "\n", "# define models and parameters\n", "model = RuleRegressor(mean_based_regression=True)\n", "minsupp_new = range(5, 7)\n", "measures_choice = [Measures.C2, Measures.Correlation, Measures.RSS]\n", "\n", "# define grid search\n", "grid = {\n", " 'minsupp_new': minsupp_new, \n", " 'induction_measure': measures_choice, \n", " 'pruning_measure': measures_choice, \n", " 'voting_measure': measures_choice\n", "}\n", "cv = KFold(n_splits=3)\n", "grid_search = GridSearchCV(estimator=model, param_grid=grid, cv=cv, scoring='neg_root_mean_squared_error', verbose=True)\n", "grid_result = grid_search.fit(X_train, y_train)\n", "\n", "# summarize results\n", "print(\"Best RMSE: %f using %s\" % (grid_result.best_score_, grid_result.best_params_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction using the model selected from the tuning" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "reg = grid_result.best_estimator_" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "ruleset = reg.model\n", "ruleset_stats = get_ruleset_stats('', ruleset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generated rules" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\n", "IF MM116 = <0.35, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.24) 
THEN MM116_pred = {0.40} [0.39,0.42]\n", "IF MM116 = <0.35, 0.45) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\n", "IF PD = (-inf, 0.50) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.45) AND MM31 = <0.24, 0.25) THEN MM116_pred = {0.50} [0.47,0.54]\n", "IF PD = <0.50, inf) AND MM116 = (-inf, 0.45) AND AS038 = (-inf, 2.45) AND MM31 = <0.24, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.41} [0.38,0.44]\n", "IF PD = (-inf, 0.50) AND MM31 = <0.24, 0.25) THEN MM116_pred = {0.51} [0.47,0.54]\n", "IF DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.26) THEN MM116_pred = {0.46} [0.36,0.55]\n", "IF MM116 = (-inf, 0.45) THEN MM116_pred = {0.40} [0.37,0.44]\n", "IF MM116 = <0.45, inf) AND MM31 = <0.23, 0.24) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.50} [0.48,0.52]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.51} [0.48,0.53]\n", "IF MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = <0.23, 0.29) AND PG072 = <1.65, inf) THEN MM116_pred = {0.51} [0.48,0.53]\n", "IF MM116 = <0.35, 0.55) AND DMM116 = (-inf, -0.05) AND MM31 = (-inf, 0.26) AND BA13 = <1077.50, inf) THEN MM116_pred = {0.54} [0.48,0.60]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND AS038 = (-inf, 2.45) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.50} [0.48,0.53]\n", "IF PD = <0.50, inf) AND MM116 = <0.45, 0.55) AND DMM116 = <-0.05, 0.05) AND AS038 = <2.25, 2.35) AND MM31 = <0.28, 0.30) AND PG072 = <1.75, 1.95) AND BA13 = <1075.50, 1076.50) THEN MM116_pred = {0.55} [0.50,0.60]\n", "IF PD = <0.50, inf) AND MM116 = (-inf, 0.55) AND MM31 = <0.29, 0.30) AND PG072 = (-inf, 1.95) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.55} [0.50,0.60]\n", "IF MM116 = (-inf, 0.55) THEN MM116_pred = {0.45} [0.39,0.52]\n", "IF PD = (-inf, 0.50) AND MM116 = <0.55, 0.65) AND DMM116 = <-0.05, 0.05) AND AS038 = <2.25, 2.45) AND MM31 = <0.26, 0.27) AND 
PG072 = <1.75, 1.85) AND BA13 = <1074.50, 1077.50) THEN MM116_pred = {0.60} [NaN,NaN]\n", "IF MM116 = <0.45, 0.65) AND MM31 = <0.23, inf) THEN MM116_pred = {0.55} [0.49,0.61]\n", "IF MM116 = <0.55, 0.75) THEN MM116_pred = {0.67} [0.58,0.77]\n", "IF MM116 = <0.75, 0.85) THEN MM116_pred = {0.83} [0.76,0.90]\n", "IF MM116 = <0.85, inf) THEN MM116_pred = {1.06} [0.88,1.24]\n" ] } ], "source": [ "for rule in ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ruleset evaluation" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
minimum_coveredmaximum_uncovered_fractionignore_missingpruning_enabledmax_growing_conditiontime_total_stime_growing_stime_pruning_srules_countconditions_per_ruleinduced_conditions_per_ruleavg_rule_coverageavg_rule_precisionavg_rule_qualitypvalueFDR_pvalueFWER_pvaluefraction_significantfraction_FDR_significantfraction_FWER_significant
Measure
6.00.0FalseTrue0.012.8566381.25023311.273084213.19047629.8095240.1161520.849723NaNNaNNaNNaN0.9523810.9523810.952381
\n", "
" ], "text/plain": [ " minimum_covered maximum_uncovered_fraction ignore_missing \\\n", "Measure \n", " 6.0 0.0 False \n", "\n", " pruning_enabled max_growing_condition time_total_s time_growing_s \\\n", "Measure \n", " True 0.0 12.856638 1.250233 \n", "\n", " time_pruning_s rules_count conditions_per_rule \\\n", "Measure \n", " 11.273084 21 3.190476 \n", "\n", " induced_conditions_per_rule avg_rule_coverage avg_rule_precision \\\n", "Measure \n", " 29.809524 0.116152 0.849723 \n", "\n", " avg_rule_quality pvalue FDR_pvalue FWER_pvalue \\\n", "Measure \n", " NaN NaN NaN NaN \n", "\n", " fraction_significant fraction_FDR_significant \\\n", "Measure \n", " 0.952381 0.952381 \n", "\n", " fraction_FWER_significant \n", "Measure \n", " 0.952381 " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(ruleset_stats)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Validate model on test dataset" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "absolute_error 0.111355\n", "relative_error 0.103524\n", "relative_error_lenient 0.097884\n", "relative_error_strict 0.114888\n", "normalized_absolute_error 0.767792\n", "squared_error 0.019642\n", "root_mean_squared_error 0.140148\n", "root_relative_squared_error 0.125609\n", "correlation 0.801204\n", "squared_correlation 0.641927\n", "Name: , dtype: float64" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "predictions = reg.predict(X_test)\n", "regression_metrics = get_regression_metrics('', predictions, y_test)\n", "display(regression_metrics.iloc[0])" ] } ], "metadata": { "kernelspec": { "display_name": "env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "metadata": { "interpreter": { 
"hash": "62266c16fff41e971c13e9cb2ad3d47e4ef45d0678714c255381eb9fdcbd7032" } }, "orig_nbformat": 2 }, "nbformat": 4, "nbformat_minor": 2 }