{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "rulekit",
"display_name": "rulekit",
"language": "python"
},
"metadata": {
"interpreter": {
"hash": "62266c16fff41e971c13e9cb2ad3d47e4ef45d0678714c255381eb9fdcbd7032"
}
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"# Regression"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"This notebook presents an example of using the package to solve a regression problem on the `methane` dataset. You can download the training dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/methane/methane-train.arff) and the test dataset [here](https://raw.githubusercontent.com/adaa-polsl/RuleKit/master/data/methane/methane-test.arff).\n",
"\n",
"This tutorial will cover topics such as: \n",
"- training model \n",
"- changing model hyperparameters \n",
"- hyperparameters tuning \n",
"- calculating metrics for model \n",
"- getting RuleKit built-in ruleset statistics"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"## Summary of the dataset"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from scipy.io import arff\n",
"import pandas as pd\n",
"\n",
"datasets_path = \"\" \n",
"\n",
"train_file_name = \"methane-train.arff\"\n",
"test_file_name = \"methane-test.arff\"\n",
"\n",
"train_df = pd.DataFrame(arff.loadarff(datasets_path + train_file_name)[0])\n",
"test_df = pd.DataFrame(arff.loadarff(datasets_path + test_file_name)[0])"
]
},
{
"source": [
"### Train file"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train file overview:\nName: methane-train.arff\nObjects number: 13368; Attributes number: 8\nBasic attribute statistics:\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" MM31 MM116 AS038 PG072 PD \\\n",
"count 13368.000000 13368.000000 13368.000000 13368.000000 13368.000000 \n",
"mean 0.363960 0.775007 2.294734 1.835600 0.308573 \n",
"std 0.117105 0.269366 0.142504 0.106681 0.461922 \n",
"min 0.170000 0.200000 1.400000 1.100000 0.000000 \n",
"25% 0.260000 0.500000 2.300000 1.800000 0.000000 \n",
"50% 0.360000 0.800000 2.300000 1.800000 0.000000 \n",
"75% 0.450000 1.000000 2.400000 1.900000 1.000000 \n",
"max 0.820000 2.200000 2.700000 2.600000 1.000000 \n",
"\n",
" BA13 DMM116 MM116_pred \n",
"count 13368.000000 13368.000000 13368.00000 \n",
"mean 1073.443372 -0.000007 0.79825 \n",
"std 3.162811 0.043566 0.28649 \n",
"min 1067.000000 -1.800000 0.20000 \n",
"25% 1070.000000 0.000000 0.50000 \n",
"50% 1075.000000 0.000000 0.80000 \n",
"75% 1076.000000 0.000000 1.00000 \n",
"max 1078.000000 0.800000 2.20000 "
],
"text/html": "
\n\n
\n \n \n | \n MM31 | \n MM116 | \n AS038 | \n PG072 | \n PD | \n BA13 | \n DMM116 | \n MM116_pred | \n
\n \n \n \n | count | \n 13368.000000 | \n 13368.000000 | \n 13368.000000 | \n 13368.000000 | \n 13368.000000 | \n 13368.000000 | \n 13368.000000 | \n 13368.00000 | \n
\n \n | mean | \n 0.363960 | \n 0.775007 | \n 2.294734 | \n 1.835600 | \n 0.308573 | \n 1073.443372 | \n -0.000007 | \n 0.79825 | \n
\n \n | std | \n 0.117105 | \n 0.269366 | \n 0.142504 | \n 0.106681 | \n 0.461922 | \n 3.162811 | \n 0.043566 | \n 0.28649 | \n
\n \n | min | \n 0.170000 | \n 0.200000 | \n 1.400000 | \n 1.100000 | \n 0.000000 | \n 1067.000000 | \n -1.800000 | \n 0.20000 | \n
\n \n | 25% | \n 0.260000 | \n 0.500000 | \n 2.300000 | \n 1.800000 | \n 0.000000 | \n 1070.000000 | \n 0.000000 | \n 0.50000 | \n
\n \n | 50% | \n 0.360000 | \n 0.800000 | \n 2.300000 | \n 1.800000 | \n 0.000000 | \n 1075.000000 | \n 0.000000 | \n 0.80000 | \n
\n \n | 75% | \n 0.450000 | \n 1.000000 | \n 2.400000 | \n 1.900000 | \n 1.000000 | \n 1076.000000 | \n 0.000000 | \n 1.00000 | \n
\n \n | max | \n 0.820000 | \n 2.200000 | \n 2.700000 | \n 2.600000 | \n 1.000000 | \n 1078.000000 | \n 0.800000 | \n 2.20000 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"print(\"Train file overview:\")\n",
"print(f\"Name: {train_file_name}\")\n",
"print(f\"Objects number: {train_df.shape[0]}; Attributes number: {train_df.shape[1]}\")\n",
"print(\"Basic attribute statistics:\")\n",
"train_df.describe()"
]
},
{
"source": [
"### Test file"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\nTest file overview:\nName: methane-test.arff\nObjects number: 5728; Attributes number: 8\nBasic attribute statistics:\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" MM31 MM116 AS038 PG072 PD \\\n",
"count 5728.000000 5728.000000 5728.000000 5728.000000 5728.000000 \n",
"mean 0.556652 1.006913 2.236627 1.819239 0.538408 \n",
"std 0.114682 0.167983 0.104913 0.078865 0.498566 \n",
"min 0.350000 0.500000 1.800000 1.600000 0.000000 \n",
"25% 0.460000 0.900000 2.200000 1.800000 0.000000 \n",
"50% 0.550000 1.000000 2.200000 1.800000 1.000000 \n",
"75% 0.640000 1.100000 2.300000 1.900000 1.000000 \n",
"max 0.980000 1.600000 2.700000 2.100000 1.000000 \n",
"\n",
" BA13 DMM116 MM116_pred \n",
"count 5728.000000 5728.000000 5728.000000 \n",
"mean 1072.691690 -0.000017 1.042458 \n",
"std 2.799559 0.046849 0.171393 \n",
"min 1067.000000 -0.400000 0.600000 \n",
"25% 1071.000000 0.000000 0.900000 \n",
"50% 1073.000000 0.000000 1.000000 \n",
"75% 1075.000000 0.000000 1.200000 \n",
"max 1078.000000 0.300000 1.600000 "
],
"text/html": "\n\n
\n \n \n | \n MM31 | \n MM116 | \n AS038 | \n PG072 | \n PD | \n BA13 | \n DMM116 | \n MM116_pred | \n
\n \n \n \n | count | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n 5728.000000 | \n
\n \n | mean | \n 0.556652 | \n 1.006913 | \n 2.236627 | \n 1.819239 | \n 0.538408 | \n 1072.691690 | \n -0.000017 | \n 1.042458 | \n
\n \n | std | \n 0.114682 | \n 0.167983 | \n 0.104913 | \n 0.078865 | \n 0.498566 | \n 2.799559 | \n 0.046849 | \n 0.171393 | \n
\n \n | min | \n 0.350000 | \n 0.500000 | \n 1.800000 | \n 1.600000 | \n 0.000000 | \n 1067.000000 | \n -0.400000 | \n 0.600000 | \n
\n \n | 25% | \n 0.460000 | \n 0.900000 | \n 2.200000 | \n 1.800000 | \n 0.000000 | \n 1071.000000 | \n 0.000000 | \n 0.900000 | \n
\n \n | 50% | \n 0.550000 | \n 1.000000 | \n 2.200000 | \n 1.800000 | \n 1.000000 | \n 1073.000000 | \n 0.000000 | \n 1.000000 | \n
\n \n | 75% | \n 0.640000 | \n 1.100000 | \n 2.300000 | \n 1.900000 | \n 1.000000 | \n 1075.000000 | \n 0.000000 | \n 1.200000 | \n
\n \n | max | \n 0.980000 | \n 1.600000 | \n 2.700000 | \n 2.100000 | \n 1.000000 | \n 1078.000000 | \n 0.300000 | \n 1.600000 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 4
}
],
"source": [
"# test file\n",
"print(\"\\nTest file overview:\")\n",
"print(f\"Name: {test_file_name}\")\n",
"print(f\"Objects number: {test_df.shape[0]}; Attributes number: {test_df.shape[1]}\")\n",
"print(\"Basic attribute statistics:\")\n",
"test_df.describe()"
]
},
{
"source": [
"## Import and init RuleKit"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from rulekit import RuleKit\n",
"from rulekit.regression import RuleRegressor\n",
"from rulekit.params import Measures\n",
"\n",
"\n",
"RuleKit.init()"
]
},
{
"source": [
"## Helper functions for calculating metrics and ruleset statistics"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import sklearn.tree as scikit\n",
"import math\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"import numpy as np\n",
"from typing import Tuple\n",
"from math import sqrt\n",
"\n",
"\n",
"def get_regression_metrics(measure: str, y_pred, y_true) -> pd.DataFrame:\n",
"    \"\"\"Compute regression quality metrics for one set of predictions.\n",
"\n",
"    Args:\n",
"        measure: Label used as the ``Measure`` index of the returned row.\n",
"        y_pred: Predicted target values (array-like of numbers).\n",
"        y_true: True target values (array-like of numbers, same length).\n",
"\n",
"    Returns:\n",
"        A one-row DataFrame indexed by ``Measure`` holding the absolute,\n",
"        relative, squared and correlation-based error metrics.\n",
"\n",
"    Raises:\n",
"        ValueError: If the inputs are empty or differ in length.\n",
"\n",
"    Note:\n",
"        The relative metrics divide by the true value (or by the min/max of\n",
"        the true and predicted value), so a zero there yields inf/nan.\n",
"    \"\"\"\n",
"    # Work on positional NumPy arrays: a pandas Series with a non-default\n",
"    # index must not be looked up by label.\n",
"    actual = np.asarray(y_true, dtype=float).ravel()\n",
"    predicted = np.asarray(y_pred, dtype=float).ravel()\n",
"    if actual.size == 0 or actual.size != predicted.size:\n",
"        raise ValueError(\n",
"            'y_true and y_pred must be non-empty and of equal length, '\n",
"            f'got {actual.size} and {predicted.size} values'\n",
"        )\n",
"\n",
"    residuals = actual - predicted\n",
"    relative_residuals = np.abs(residuals / actual)\n",
"\n",
"    absolute_error = np.mean(np.abs(residuals))\n",
"    squared_error = np.mean(residuals ** 2)\n",
"    # Mean absolute deviation of the true values from their own mean.\n",
"    nae_denominator = np.mean(np.abs(actual.mean() - actual))\n",
"    # corrcoef returns the 2x2 correlation matrix; the Pearson coefficient is\n",
"    # its off-diagonal entry (averaging the matrix would give (1 + r) / 2).\n",
"    correlation = np.corrcoef(actual, predicted)[0, 1]\n",
"\n",
"    dictionary = {\n",
"        'Measure': measure,\n",
"        'absolute_error': absolute_error,\n",
"        'relative_error': np.mean(relative_residuals),\n",
"        'relative_error_lenient': np.mean(np.abs(residuals / np.maximum(actual, predicted))),\n",
"        'relative_error_strict': np.mean(np.abs(residuals / np.minimum(actual, predicted))),\n",
"        'normalized_absolute_error': absolute_error / nae_denominator,\n",
"        'squared_error': squared_error,\n",
"        # Taken as the root of the MSE directly, because the ``squared=False``\n",
"        # switch is not accepted by recent scikit-learn releases.\n",
"        'root_mean_squared_error': np.sqrt(squared_error),\n",
"        'root_relative_squared_error': np.sqrt(np.mean(relative_residuals ** 2)),\n",
"        'correlation': correlation,\n",
"        'squared_correlation': np.power(correlation, 2),\n",
"    }\n",
"    return pd.DataFrame.from_records([dictionary], index='Measure')\n",
"\n",
"def get_ruleset_stats(measure: str, model) -> pd.DataFrame:\n",
"    \"\"\"Collect a model's parameters and statistics into a one-row table.\n",
"\n",
"    Args:\n",
"        measure: Label used as the ``Measure`` index of the returned row.\n",
"        model: Object exposing ``parameters`` and ``stats`` attributes whose\n",
"            instance attributes become the columns of the table.\n",
"\n",
"    Returns:\n",
"        A one-row DataFrame indexed by ``Measure``. The ``_java_object``\n",
"        entry of the parameters is left out of the table.\n",
"    \"\"\"\n",
"    # Build a filtered copy rather than deleting from the live ``__dict__``:\n",
"    # deleting there would strip the attribute off the model itself and make\n",
"    # a second call on the same model fail with KeyError.\n",
"    parameters = {\n",
"        name: value\n",
"        for name, value in vars(model.parameters).items()\n",
"        if name != '_java_object'\n",
"    }\n",
"    record = {'Measure': measure, **parameters, **vars(model.stats)}\n",
"    return pd.DataFrame.from_records([record], index='Measure')"
]
},
{
"source": [
"## Rule induction on training dataset"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"x_train = train_df.drop(['MM116_pred'], axis=1)\n",
"y_train = train_df['MM116_pred']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": " minimum_covered maximum_uncovered_fraction ignore_missing \\\nMeasure \nC2 5.0 0.0 False \nCorrelation 5.0 0.0 False \nRSS 5.0 0.0 False \n\n pruning_enabled max_growing_condition time_total_s \\\nMeasure \nC2 True 0.0 162.647903 \nCorrelation True 0.0 89.092752 \nRSS True 0.0 145.849569 \n\n time_growing_s time_pruning_s rules_count conditions_per_rule \\\nMeasure \nC2 109.968198 52.479965 30 4.800000 \nCorrelation 52.286917 36.736225 14 3.857143 \nRSS 82.719695 63.063596 14 3.071429 \n\n induced_conditions_per_rule avg_rule_coverage \\\nMeasure \nC2 25.266667 0.145382 \nCorrelation 31.142857 0.200041 \nRSS 33.785714 0.268429 \n\n avg_rule_precision avg_rule_quality pvalue FDR_pvalue \\\nMeasure \nC2 0.910635 0.724619 4.967428e-03 4.967460e-03 \nCorrelation 0.868585 0.850067 0.000000e+00 0.000000e+00 \nRSS 0.778758 0.835207 1.840021e-12 1.840021e-12 \n\n FWER_pvalue fraction_significant fraction_FDR_significant \\\nMeasure \nC2 4.968345e-03 0.966667 0.966667 \nCorrelation 0.000000e+00 1.000000 1.000000 \nRSS 1.840021e-12 1.000000 1.000000 \n\n fraction_FWER_significant \nMeasure \nC2 0.966667 \nCorrelation 1.000000 \nRSS 1.000000 ",
"text/html": "\n\n
\n \n \n | \n minimum_covered | \n maximum_uncovered_fraction | \n ignore_missing | \n pruning_enabled | \n max_growing_condition | \n time_total_s | \n time_growing_s | \n time_pruning_s | \n rules_count | \n conditions_per_rule | \n induced_conditions_per_rule | \n avg_rule_coverage | \n avg_rule_precision | \n avg_rule_quality | \n pvalue | \n FDR_pvalue | \n FWER_pvalue | \n fraction_significant | \n fraction_FDR_significant | \n fraction_FWER_significant | \n
\n \n | Measure | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n
\n \n \n \n | C2 | \n 5.0 | \n 0.0 | \n False | \n True | \n 0.0 | \n 162.647903 | \n 109.968198 | \n 52.479965 | \n 30 | \n 4.800000 | \n 25.266667 | \n 0.145382 | \n 0.910635 | \n 0.724619 | \n 4.967428e-03 | \n 4.967460e-03 | \n 4.968345e-03 | \n 0.966667 | \n 0.966667 | \n 0.966667 | \n
\n \n | Correlation | \n 5.0 | \n 0.0 | \n False | \n True | \n 0.0 | \n 89.092752 | \n 52.286917 | \n 36.736225 | \n 14 | \n 3.857143 | \n 31.142857 | \n 0.200041 | \n 0.868585 | \n 0.850067 | \n 0.000000e+00 | \n 0.000000e+00 | \n 0.000000e+00 | \n 1.000000 | \n 1.000000 | \n 1.000000 | \n
\n \n | RSS | \n 5.0 | \n 0.0 | \n False | \n True | \n 0.0 | \n 145.849569 | \n 82.719695 | \n 63.063596 | \n 14 | \n 3.071429 | \n 33.785714 | \n 0.268429 | \n 0.778758 | \n 0.835207 | \n 1.840021e-12 | \n 1.840021e-12 | \n 1.840021e-12 | \n 1.000000 | \n 1.000000 | \n 1.000000 | \n
\n \n
\n
"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": " absolute_error relative_error relative_error_lenient \\\nMeasure \nC2 0.113257 0.119912 0.112692 \nCorrelation 0.096903 0.099532 0.096365 \nRSS 0.095808 0.111826 0.106694 \n\n relative_error_strict normalized_absolute_error squared_error \\\nMeasure \nC2 0.146711 0.481968 0.031913 \nCorrelation 0.124060 0.412373 0.028386 \nRSS 0.133913 0.407715 0.023018 \n\n root_mean_squared_error root_relative_squared_error \\\nMeasure \nC2 0.178643 0.167711 \nCorrelation 0.168481 0.147825 \nRSS 0.151715 0.151703 \n\n correlation squared_correlation \nMeasure \nC2 0.910077 0.828241 \nCorrelation 0.935456 0.875078 \nRSS 0.944984 0.892996 ",
"text/html": "\n\n
\n \n \n | \n absolute_error | \n relative_error | \n relative_error_lenient | \n relative_error_strict | \n normalized_absolute_error | \n squared_error | \n root_mean_squared_error | \n root_relative_squared_error | \n correlation | \n squared_correlation | \n
\n \n | Measure | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n
\n \n \n \n | C2 | \n 0.113257 | \n 0.119912 | \n 0.112692 | \n 0.146711 | \n 0.481968 | \n 0.031913 | \n 0.178643 | \n 0.167711 | \n 0.910077 | \n 0.828241 | \n
\n \n | Correlation | \n 0.096903 | \n 0.099532 | \n 0.096365 | \n 0.124060 | \n 0.412373 | \n 0.028386 | \n 0.168481 | \n 0.147825 | \n 0.935456 | \n 0.875078 | \n
\n \n | RSS | \n 0.095808 | \n 0.111826 | \n 0.106694 | \n 0.133913 | \n 0.407715 | \n 0.023018 | \n 0.151715 | \n 0.151703 | \n 0.944984 | \n 0.892996 | \n
\n \n
\n
"
},
"metadata": {}
}
],
"source": [
"# Train one regressor per quality measure, using that same measure for rule\n",
"# induction, pruning and voting.\n",
"measures_by_name = {\n",
"    'C2': Measures.C2,\n",
"    'Correlation': Measures.Correlation,\n",
"    'RSS': Measures.RSS,\n",
"}\n",
"\n",
"regressors = {}\n",
"metrics_frames = []\n",
"stats_frames = []\n",
"for measure_name, measure in measures_by_name.items():\n",
"    regressor = RuleRegressor(\n",
"        induction_measure=measure,\n",
"        pruning_measure=measure,\n",
"        voting_measure=measure,\n",
"    )\n",
"    regressor.fit(x_train, y_train)\n",
"    predictions = regressor.predict(x_train)\n",
"\n",
"    metrics_frames.append(get_regression_metrics(measure_name, predictions, y_train))\n",
"    stats_frames.append(get_ruleset_stats(measure_name, regressor.model))\n",
"    regressors[measure_name] = regressor\n",
"\n",
"# Keep the per-measure names that the following cells rely on.\n",
"c2_reg = regressors['C2']\n",
"corr_reg = regressors['Correlation']\n",
"rss_reg = regressors['RSS']\n",
"c2_ruleset = c2_reg.model\n",
"corr_ruleset = corr_reg.model\n",
"rss_ruleset = rss_reg.model\n",
"\n",
"# Concatenate once instead of growing the tables step by step.\n",
"regression_metrics = pd.concat(metrics_frames)\n",
"ruleset_stats = pd.concat(stats_frames)\n",
"\n",
"display(ruleset_stats)\n",
"display(regression_metrics)"
]
},
{
"source": [
"### C2 Measure generated rules"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"IF PD = (-inf, 0.50) AND AS038 = (-inf, 2.35) AND MM31 = <0.21, 0.22) AND BA13 = <1075.50, inf) THEN MM116_pred = {0.40} [0.40,0.40]\nIF MM116 = <0.35, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM31 = <0.24, 0.25) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.50} [0.50,0.50]\nIF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.45) AND MM31 = <0.19, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND DMM116 = (-inf, 0.05) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, inf) AND MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\nIF MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, inf) AND DMM116 = <-0.05, inf) AND AS038 = <2.05, inf) AND MM31 = (-inf, 0.24) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, 0.70) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.36,0.44]\nIF PD = (-inf, 0.50) AND MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.36,0.44]\nIF MM116 = <0.55, inf) AND DMM116 = (-inf, 0.05) THEN MM116_pred = {0.90} [0.69,1.11]\nIF MM116 = <0.45, inf) AND MM31 = <0.23, 0.27) AND PG072 = <1.65, inf) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.50} [0.49,0.51]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND AS038 = (-inf, 2.45) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.55, 0.65) AND DMM116 = <0.05, inf) AND MM31 = (-inf, 
0.26) AND PG072 = (-inf, 1.85) THEN MM116_pred = {0.60} [0.60,0.60]\nIF MM116 = <0.55, 0.95) AND DMM116 = <0.05, inf) AND MM31 = (-inf, 0.27) AND BA13 = <1075.50, inf) THEN MM116_pred = {0.70} [0.59,0.81]\nIF MM116 = (-inf, 1.05) AND DMM116 = (-inf, 0.15) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.27) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.30,0.50]\nIF PD = (-inf, 0.50) AND DMM116 = <0.05, inf) AND AS038 = <2.35, inf) AND MM31 = <0.27, 0.28) THEN MM116_pred = {0.60} [0.60,0.60]\nIF PD = (-inf, 0.50) AND MM116 = <0.55, 0.75) AND DMM116 = <-0.05, inf) AND MM31 = <0.27, 0.30) THEN MM116_pred = {0.60} [0.57,0.63]\nIF MM116 = <0.55, 0.85) AND MM31 = <0.27, 0.30) THEN MM116_pred = {0.60} [0.51,0.69]\nIF MM116 = <0.45, 0.55) AND DMM116 = <-0.15, inf) AND MM31 = (-inf, 0.30) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF DMM116 = (-inf, 0.15) AND AS038 = (-inf, 2.55) AND MM31 = <0.19, 0.30) AND PG072 = <1.55, inf) THEN MM116_pred = {0.50} [0.37,0.63]\nIF MM116 = (-inf, 0.95) AND DMM116 = <-0.30, inf) AND AS038 = <2.25, 2.45) AND MM31 = <0.28, 0.31) AND PG072 = <1.75, 1.95) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.60} [0.50,0.70]\nIF MM116 = <0.45, 1.10) AND DMM116 = <-0.15, inf) AND AS038 = <2.15, 2.45) AND MM31 = (-inf, 0.31) AND BA13 = <1072.50, 1077.50) THEN MM116_pred = {0.50} [0.40,0.60]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.95) AND MM31 = <0.30, inf) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.80} [0.68,0.92]\nIF MM116 = <0.35, 0.65) AND AS038 = (-inf, 2.45) AND MM31 = <0.29, inf) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.60} [0.56,0.64]\nIF MM116 = <0.65, inf) AND DMM116 = <0.05, inf) AND AS038 = <2.15, inf) AND MM31 = <0.30, 0.32) AND PG072 = (-inf, 1.95) AND BA13 = <1074.50, inf) THEN MM116_pred = {1.20} [1.00,1.40]\nIF MM116 = <0.45, inf) AND DMM116 = <-0.15, inf) AND MM31 = <0.32, inf) THEN MM116_pred = {0.90} [0.69,1.11]\n"
]
}
],
"source": [
"for rule in c2_ruleset.rules:\n",
" print(rule)"
]
},
{
"source": [
"### Correlation Measure generated rules"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"IF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM31 = (-inf, 0.25) THEN MM116_pred = {0.40} [0.33,0.47]\nIF DMM116 = (-inf, 0.05) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.26) THEN MM116_pred = {0.40} [0.31,0.49]\nIF MM31 = <0.18, 0.28) THEN MM116_pred = {0.50} [0.38,0.62]\nIF MM116 = <0.25, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = <0.18, inf) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = <0.18, inf) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.55, 0.65) AND DMM116 = <-0.15, inf) AND MM31 = <0.24, inf) THEN MM116_pred = {0.60} [0.56,0.64]\nIF MM116 = <0.45, 0.75) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.29) AND BA13 = <1072.50, 1077.50) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.45, 0.75) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.30) AND BA13 = <1072.50, 1077.50) THEN MM116_pred = {0.50} [0.46,0.54]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.85) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.29) AND BA13 = <1072.50, 1077.50) THEN MM116_pred = {0.50} [0.44,0.56]\nIF MM116 = <0.45, 0.85) AND MM31 = <0.26, inf) THEN MM116_pred = {0.70} [0.57,0.83]\nIF MM116 = <0.70, inf) THEN MM116_pred = {0.90} [0.70,1.10]\n"
]
}
],
"source": [
"for rule in corr_ruleset.rules:\n",
" print(rule)"
]
},
{
"source": [
"### RSS Measure generated rules"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"IF MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\nIF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM31 = (-inf, 0.26) THEN MM116_pred = {0.40} [0.30,0.50]\nIF MM31 = <0.18, 0.28) THEN MM116_pred = {0.50} [0.38,0.62]\nIF PD = (-inf, 0.50) AND MM116 = <0.25, inf) AND DMM116 = <-0.95, inf) AND MM31 = <0.23, inf) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.70} [0.48,0.92]\nIF MM116 = (-inf, 0.25) THEN MM116_pred = {0.20} [0.15,0.25]\nIF MM116 = (-inf, 0.55) AND DMM116 = <-0.15, inf) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.44,0.56]\nIF MM116 = (-inf, 0.65) AND DMM116 = <-0.15, 0.15) AND MM31 = <0.23, 0.40) AND PG072 = <1.65, inf) AND BA13 = <1070.50, inf) THEN MM116_pred = {0.50} [0.43,0.57]\nIF DMM116 = <-0.25, 0.15) AND MM31 = (-inf, 0.32) AND PG072 = <1.55, inf) THEN MM116_pred = {0.50} [0.36,0.64]\nIF MM116 = (-inf, 0.75) AND DMM116 = <-0.15, inf) AND AS038 = <2.15, inf) AND BA13 = <1069.50, inf) THEN MM116_pred = {0.50} [0.37,0.63]\nIF MM116 = (-inf, 0.75) AND MM31 = <0.23, 0.46) THEN MM116_pred = {0.60} [0.48,0.72]\nIF MM116 = (-inf, 0.85) AND DMM116 = (-inf, 0.15) AND AS038 = (-inf, 2.55) AND MM31 = (-inf, 0.33) THEN MM116_pred = {0.50} [0.37,0.63]\nIF MM116 = <0.85, inf) THEN MM116_pred = {1} [0.82,1.18]\nIF MM116 = <0.65, 0.85) THEN MM116_pred = {0.80} [0.71,0.89]\n"
]
}
],
"source": [
"for rule in rss_ruleset.rules:\n",
" print(rule)"
]
},
{
"source": [
"## Evaluation on a test set"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"x_test = test_df.drop(['MM116_pred'], axis=1)\n",
"y_test = test_df['MM116_pred']"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# C2\n",
"c2_predictions = c2_reg.predict(x_test)\n",
"c2_regression_metrics = get_regression_metrics('C2', c2_predictions, y_test)\n",
"\n",
"# Correlation\n",
"corr_predictions = corr_reg.predict(x_test)\n",
"corr_regression_metrics = get_regression_metrics('Correlation', corr_predictions, y_test)\n",
"\n",
"# RSS\n",
"rss_predictions = rss_reg.predict(x_test)\n",
"rss_regression_metrics = get_regression_metrics('RSS', rss_predictions, y_test)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": " absolute_error relative_error relative_error_lenient \\\nMeasure \nC2 0.175348 0.153675 0.152205 \nCorrelation 0.167494 0.143461 0.143185 \nRSS 0.138080 0.121742 0.120687 \n\n relative_error_strict normalized_absolute_error squared_error \\\nMeasure \nC2 0.197185 1.209023 0.049460 \nCorrelation 0.187964 1.154868 0.049314 \nRSS 0.150155 0.952062 0.032734 \n\n root_mean_squared_error root_relative_squared_error \\\nMeasure \nC2 0.222395 0.183108 \nCorrelation 0.222068 0.181535 \nRSS 0.180926 0.151750 \n\n correlation squared_correlation \nMeasure \nC2 0.768065 0.589923 \nCorrelation 0.801878 0.643009 \nRSS 0.815327 0.664758 ",
"text/html": "\n\n
\n \n \n | \n absolute_error | \n relative_error | \n relative_error_lenient | \n relative_error_strict | \n normalized_absolute_error | \n squared_error | \n root_mean_squared_error | \n root_relative_squared_error | \n correlation | \n squared_correlation | \n
\n \n | Measure | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n
\n \n \n \n | C2 | \n 0.175348 | \n 0.153675 | \n 0.152205 | \n 0.197185 | \n 1.209023 | \n 0.049460 | \n 0.222395 | \n 0.183108 | \n 0.768065 | \n 0.589923 | \n
\n \n | Correlation | \n 0.167494 | \n 0.143461 | \n 0.143185 | \n 0.187964 | \n 1.154868 | \n 0.049314 | \n 0.222068 | \n 0.181535 | \n 0.801878 | \n 0.643009 | \n
\n \n | RSS | \n 0.138080 | \n 0.121742 | \n 0.120687 | \n 0.150155 | \n 0.952062 | \n 0.032734 | \n 0.180926 | \n 0.151750 | \n 0.815327 | \n 0.664758 | \n
\n \n
\n
"
},
"metadata": {}
}
],
"source": [
"display(pd.concat([c2_regression_metrics, corr_regression_metrics, rss_regression_metrics]))"
]
},
{
"source": [
"## Hyperparameters tuning\n",
"\n",
"This step is going to take a while..."
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold\n",
"from sklearn.model_selection import GridSearchCV\n",
"from rulekit.params import Measures\n",
"\n",
"\n",
"# define models and parameters\n",
"model = RuleRegressor()\n",
"# Single value keeps the search short; widen it (e.g. range(3, 15)) to tune\n",
"# this parameter as well, at the cost of many more fits.\n",
"min_rule_covered = [5]\n",
"measures_choice = [Measures.C2, Measures.RSS, Measures.WeightedLaplace, Measures.Correlation]\n",
"\n",
"# define grid search\n",
"grid = {\n",
"    'min_rule_covered': min_rule_covered,\n",
"    'induction_measure': measures_choice,\n",
"    'pruning_measure': measures_choice,\n",
"    'voting_measure': measures_choice,\n",
"}\n",
"cv = KFold(n_splits=3)\n",
"grid_search = GridSearchCV(estimator=model, param_grid=grid, cv=cv, scoring='neg_root_mean_squared_error')\n",
"grid_result = grid_search.fit(x_train, y_train)\n",
"\n",
"# summarize results: the scorer reports the *negated* RMSE (higher is better),\n",
"# so flip the sign back before printing it as an RMSE\n",
"print(\"Best RMSE: %f using %s\" % (-grid_result.best_score_, grid_result.best_params_))"
]
},
{
"source": [
"## Prediction using the model selected from the tuning"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"reg = grid_result.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"ruleset = reg.model\n",
"ruleset_stats = get_ruleset_stats('', ruleset)"
]
},
{
"source": [
"Generated rules"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"IF PD = (-inf, 0.50) AND AS038 = (-inf, 2.35) AND MM31 = <0.21, 0.22) AND BA13 = <1075.50, inf) THEN MM116_pred = {0.40} [0.40,0.40]\nIF MM116 = <0.35, 0.45) AND DMM116 = <-0.05, inf) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND MM31 = <0.18, 0.24) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM31 = <0.24, 0.25) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.50} [0.50,0.50]\nIF MM116 = (-inf, 0.45) AND DMM116 = <-0.05, inf) AND AS038 = (-inf, 2.45) AND MM31 = <0.19, 0.25) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.38,0.42]\nIF MM116 = (-inf, 0.45) AND DMM116 = (-inf, 0.05) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, inf) AND MM31 = (-inf, 0.23) THEN MM116_pred = {0.40} [0.39,0.41]\nIF MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, inf) AND DMM116 = <-0.05, inf) AND AS038 = <2.05, inf) AND MM31 = (-inf, 0.24) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.40} [0.37,0.43]\nIF MM116 = <0.35, 0.70) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.36,0.44]\nIF PD = (-inf, 0.50) AND MM116 = <0.35, inf) AND DMM116 = <-0.05, 0.05) AND MM31 = (-inf, 0.24) THEN MM116_pred = {0.40} [0.36,0.44]\nIF MM116 = <0.55, inf) AND DMM116 = (-inf, 0.05) THEN MM116_pred = {0.90} [0.69,1.11]\nIF MM116 = <0.45, inf) AND MM31 = <0.23, 0.27) AND PG072 = <1.65, inf) AND BA13 = (-inf, 1075.50) THEN MM116_pred = {0.50} [0.49,0.51]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND MM31 = <0.23, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.45, 0.55) AND DMM116 = <-0.05, inf) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.55) AND AS038 = (-inf, 2.45) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF MM116 = <0.55, 0.65) AND DMM116 = <0.05, inf) AND MM31 = (-inf, 
0.26) AND PG072 = (-inf, 1.85) THEN MM116_pred = {0.60} [0.60,0.60]\nIF MM116 = <0.55, 0.95) AND DMM116 = <0.05, inf) AND MM31 = (-inf, 0.27) AND BA13 = <1075.50, inf) THEN MM116_pred = {0.70} [0.59,0.81]\nIF MM116 = (-inf, 1.05) AND DMM116 = (-inf, 0.15) AND AS038 = (-inf, 2.45) AND MM31 = (-inf, 0.27) AND PG072 = (-inf, 2.05) THEN MM116_pred = {0.40} [0.30,0.50]\nIF PD = (-inf, 0.50) AND DMM116 = <0.05, inf) AND AS038 = <2.35, inf) AND MM31 = <0.27, 0.28) THEN MM116_pred = {0.60} [0.60,0.60]\nIF PD = (-inf, 0.50) AND MM116 = <0.55, 0.75) AND DMM116 = <-0.05, inf) AND MM31 = <0.27, 0.30) THEN MM116_pred = {0.60} [0.57,0.63]\nIF MM116 = <0.55, 0.85) AND MM31 = <0.27, 0.30) THEN MM116_pred = {0.60} [0.51,0.69]\nIF MM116 = <0.45, 0.55) AND DMM116 = <-0.15, inf) AND MM31 = (-inf, 0.30) AND PG072 = <1.65, inf) THEN MM116_pred = {0.50} [0.47,0.53]\nIF DMM116 = (-inf, 0.15) AND AS038 = (-inf, 2.55) AND MM31 = <0.19, 0.30) AND PG072 = <1.55, inf) THEN MM116_pred = {0.50} [0.37,0.63]\nIF MM116 = (-inf, 0.95) AND DMM116 = <-0.30, inf) AND AS038 = <2.25, 2.45) AND MM31 = <0.28, 0.31) AND PG072 = <1.75, 1.95) AND BA13 = (-inf, 1077.50) THEN MM116_pred = {0.60} [0.50,0.70]\nIF MM116 = <0.45, 1.10) AND DMM116 = <-0.15, inf) AND AS038 = <2.15, 2.45) AND MM31 = (-inf, 0.31) AND BA13 = <1072.50, 1077.50) THEN MM116_pred = {0.50} [0.40,0.60]\nIF PD = (-inf, 0.50) AND MM116 = <0.45, 0.95) AND MM31 = <0.30, inf) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.80} [0.68,0.92]\nIF MM116 = <0.35, 0.65) AND AS038 = (-inf, 2.45) AND MM31 = <0.29, inf) AND BA13 = (-inf, 1076.50) THEN MM116_pred = {0.60} [0.56,0.64]\nIF MM116 = <0.65, inf) AND DMM116 = <0.05, inf) AND AS038 = <2.15, inf) AND MM31 = <0.30, 0.32) AND PG072 = (-inf, 1.95) AND BA13 = <1074.50, inf) THEN MM116_pred = {1.20} [1.00,1.40]\nIF MM116 = <0.45, inf) AND DMM116 = <-0.15, inf) AND MM31 = <0.32, inf) THEN MM116_pred = {0.90} [0.69,1.11]\n"
]
}
],
"source": [
"for rule in ruleset.rules:\n",
" print(rule)"
]
},
{
"source": [
"Ruleset evaluation"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": " minimum_covered maximum_uncovered_fraction ignore_missing \\\nMeasure \n 5.0 0.0 False \n\n pruning_enabled max_growing_condition time_total_s time_growing_s \\\nMeasure \n True 0.0 105.966684 71.493354 \n\n time_pruning_s rules_count conditions_per_rule \\\nMeasure \n 34.399725 30 4.8 \n\n induced_conditions_per_rule avg_rule_coverage avg_rule_precision \\\nMeasure \n 25.266667 0.145382 0.910635 \n\n avg_rule_quality pvalue FDR_pvalue FWER_pvalue \\\nMeasure \n 0.724619 0.004967 0.004967 0.004968 \n\n fraction_significant fraction_FDR_significant \\\nMeasure \n 0.966667 0.966667 \n\n fraction_FWER_significant \nMeasure \n 0.966667 ",
"text/html": "\n\n
\n \n \n | \n minimum_covered | \n maximum_uncovered_fraction | \n ignore_missing | \n pruning_enabled | \n max_growing_condition | \n time_total_s | \n time_growing_s | \n time_pruning_s | \n rules_count | \n conditions_per_rule | \n induced_conditions_per_rule | \n avg_rule_coverage | \n avg_rule_precision | \n avg_rule_quality | \n pvalue | \n FDR_pvalue | \n FWER_pvalue | \n fraction_significant | \n fraction_FDR_significant | \n fraction_FWER_significant | \n
\n \n | Measure | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n
\n \n \n \n | \n 5.0 | \n 0.0 | \n False | \n True | \n 0.0 | \n 105.966684 | \n 71.493354 | \n 34.399725 | \n 30 | \n 4.8 | \n 25.266667 | \n 0.145382 | \n 0.910635 | \n 0.724619 | \n 0.004967 | \n 0.004967 | \n 0.004968 | \n 0.966667 | \n 0.966667 | \n 0.966667 | \n
\n \n
\n
"
},
"metadata": {}
}
],
"source": [
"display(ruleset_stats)"
]
},
{
"source": [
"### Validate model on test dataset"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": " absolute_error relative_error relative_error_lenient \\\nMeasure \n 0.175348 0.153675 0.152205 \n\n relative_error_strict normalized_absolute_error squared_error \\\nMeasure \n 0.197185 1.209023 0.04946 \n\n root_mean_squared_error root_relative_squared_error correlation \\\nMeasure \n 0.222395 0.183108 0.768065 \n\n squared_correlation \nMeasure \n 0.589923 ",
"text/html": "\n\n
\n \n \n | \n absolute_error | \n relative_error | \n relative_error_lenient | \n relative_error_strict | \n normalized_absolute_error | \n squared_error | \n root_mean_squared_error | \n root_relative_squared_error | \n correlation | \n squared_correlation | \n
\n \n | Measure | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n | \n
\n \n \n \n | \n 0.175348 | \n 0.153675 | \n 0.152205 | \n 0.197185 | \n 1.209023 | \n 0.04946 | \n 0.222395 | \n 0.183108 | \n 0.768065 | \n 0.589923 | \n
\n \n
\n
"
},
"metadata": {}
}
],
"source": [
"predictions = reg.predict(x_test)\n",
"regression_metrics = get_regression_metrics('', predictions, y_test)\n",
"display(regression_metrics.iloc[0])"
]
}
]
}