{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Expert Rules" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook presents example usage of user-guided rule induction, which follows the scheme introduced by the [GuideR](https://www.sciencedirect.com/science/article/abs/pii/S0950705119300802?dgcid=coauthor) algorithm (Sikora et al., 2019). \n", "Each problem type (classification, regression, survival) has an expert class in addition to the basic one, e.g. RuleClassifier and ExpertRuleClassifier for classification. Expert classes allow you to define a set of initial rules, preferred conditions, and forbidden conditions. \n", "This tutorial shows how to define such rules and conditions.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import RuleKit" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from rulekit import RuleKit\n", "from rulekit.classification import RuleClassifier\n", "from rulekit.params import Measures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from scipy.io import arff\n", "import pandas as pd\n", "\n", "\n", "data_df = pd.DataFrame(arff.loadarff(\"seismic-bumps.arff\")[0])\n", "data_df['class'] = data_df['class'].astype(int)\n", "\n", "X = data_df.drop(['class'], axis=1)\n", "y = data_df['class']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define rules and conditions" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "expert_rules = [\n", " ('rule-0', 'IF [[gimpuls = <-inf, 750)]] THEN class = {0}'),\n", " ('rule-1', 'IF [[gimpuls = <750, inf)]] THEN class = {1}')\n", "]\n", "\n", "expert_preferred_conditions = [('preferred-condition-0', '1: IF [[seismic = {a}]] THEN class = {0}'), (\n", " 
'preferred-attribute-0', '1: IF [[gimpuls = Any]] THEN class = {1}')]\n", "\n", "expert_forbidden_conditions = [('forb-attribute-0', '1: IF [[seismoacoustic = Any]] THEN class = {0}'), (\n", " 'forb-attribute-1', 'inf: IF [[ghazard = Any]] THEN class = {1}')]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Rule induction" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from rulekit.classification import ExpertRuleClassifier\n", "\n", "clf = ExpertRuleClassifier(\n", " minsupp_new=8,\n", " max_growing=0,\n", " extend_using_preferred=True,\n", " extend_using_automatic=True,\n", " induce_using_preferred=True,\n", " induce_using_automatic=True\n", ")\n", "clf.fit(\n", " X, y,\n", " expert_rules=expert_rules,\n", " expert_preferred_conditions=expert_preferred_conditions,\n", " expert_forbidden_conditions=expert_forbidden_conditions\n", ")\n", "ruleset = clf.model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IF [seismic = {a}] AND gimpuls = (-inf, 521.50) AND genergy = (-inf, 32875) AND nbumps = (-inf, 0.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1252.50) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1342.50) AND goimpuls = (-inf, 312) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1427.50) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1653.50) AND genergy = (-inf, 1006585) AND goimpuls = (-inf, 312) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1752) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 2733) AND goimpuls = (-inf, 312) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = <2965, inf) AND genergy = <634250, inf) AND nbumps = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1331) AND nbumps = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1655.50) AND genergy = (-inf, 386010) AND 
nbumps = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1686) AND goimpuls = (-inf, 312) AND nbumps5 = (-inf, 0.50) AND nbumps = (-inf, 2.50) AND nbumps2 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 2892) AND genergy = (-inf, 386010) AND goimpuls = (-inf, 312) AND nbumps = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 2068.50) AND goimpuls = (-inf, 312) AND genergy = (-inf, 1004565) AND nbumps = (-inf, 2.50) AND nbumps2 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 2184.50) AND nbumps = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps3 = (-inf, 1.50) AND nbumps2 = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 901) AND goimpuls = (-inf, 96.50) AND senergy = (-inf, 3850) AND nbumps = (-inf, 3.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps3 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND goimpuls = (-inf, 312) AND senergy = (-inf, 9600) AND nbumps3 = (-inf, 2.50) AND nbumps2 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps3 = (-inf, 2.50) AND nbumps2 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND goimpuls = (-inf, 312) AND senergy = (-inf, 8100) AND nbumps2 = (-inf, 2.50) THEN class = {0}\n", "IF ghazard = {a} AND goenergy = <-40.50, 68.50) AND maxenergy = (-inf, 5500) AND gimpuls = (-inf, 901) AND goimpuls = <-39.50, inf) AND senergy = <1150, inf) AND nbumps2 = <1.50, inf) THEN class = {0}\n", "IF goenergy = <-48.50, inf) AND gimpuls = (-inf, 695.50) AND maxenergy = <2500, inf) AND goimpuls = <-54.50, inf) AND genergy = <10915, inf) AND nbumps3 = (-inf, 3.50) AND senergy = <3950, inf) AND nbumps2 = (-inf, 1.50) AND nbumps = (-inf, 6.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps = (-inf, 4.50) AND nbumps2 = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1253.50) AND nbumps3 = (-inf, 2.50) AND nbumps = (-inf, 5.50) THEN class = {0}\n", "IF maxenergy = (-inf, 75000) AND gimpuls 
= (-inf, 901) AND genergy = (-inf, 378500) AND nbumps3 = (-inf, 3.50) AND nbumps4 = (-inf, 2.50) THEN class = {0}\n", "IF gimpuls = (-inf, 1139.50) AND goimpuls = (-inf, 312) AND senergy = (-inf, 85450) THEN class = {0}\n", "IF gimpuls = <1150.50, inf) AND goimpuls = <-35.50, inf) AND nbumps3 = (-inf, 2.50) AND nbumps2 = (-inf, 0.50) AND nbumps = <1.50, inf) THEN class = {0}\n", "IF goenergy = <-18.50, inf) AND gimpuls = <927, inf) AND genergy = (-inf, 508210) AND senergy = (-inf, 5750) AND nbumps2 = <1.50, inf) THEN class = {0}\n", "IF senergy = (-inf, 5750) THEN class = {0}\n", "IF gimpuls = (-inf, 2489.50) AND genergy = (-inf, 318735) AND nbumps3 = (-inf, 2.50) AND nbumps2 = (-inf, 2.50) THEN class = {0}\n", "IF goenergy = <-36.50, inf) AND goimpuls = (-inf, 6.50) AND genergy = <392530, inf) AND senergy = <6750, inf) AND nbumps2 = (-inf, 1.50) THEN class = {0}\n", "IF gimpuls = (-inf, 3881.50) AND nbumps = (-inf, 4.50) AND nbumps2 = (-inf, 2.50) THEN class = {0}\n", "IF [gimpuls = <1253.50, inf)] AND goenergy = <-40.50, 87) AND maxenergy = (-inf, 7500) AND genergy = <96260, 673155) AND seismic = {b} AND seismoacoustic = {a} AND senergy = (-inf, 10000) AND nbumps = (-inf, 3.50) THEN class = {1}\n", "IF goenergy = (-inf, 96) AND maxenergy = <1500, inf) AND gimpuls = <605, 1959) AND goimpuls = <-55, 95) AND genergy = <61250, 662435) AND senergy = (-inf, 36050) AND nbumps3 = <0.50, inf) AND nbumps2 = <0.50, inf) AND nbumps = (-inf, 6.50) THEN class = {1}\n", "IF goenergy = (-inf, 186) AND maxenergy = <1500, inf) AND gimpuls = <538.50, inf) AND genergy = <58310, 934630) AND goimpuls = <-55, inf) AND senergy = (-inf, 40650) AND nbumps2 = <0.50, inf) THEN class = {1}\n", "IF gimpuls = <521.50, inf) AND genergy = <58310, inf) AND goimpuls = <-71, inf) AND senergy = <650, inf) AND nbumps = <1.50, inf) AND nbumps2 = <0.50, inf) THEN class = {1}\n", "IF goenergy = (-inf, 97) AND gimpuls = <378, 2132) AND maxenergy = <2500, inf) AND genergy = <34880, 587745) AND goimpuls = 
(-inf, 95) AND senergy = <3150, 36050) AND nbumps3 = <0.50, inf) AND nbumps2 = <0.50, inf) AND nbumps = (-inf, 6.50) THEN class = {1}\n", "IF goenergy = (-inf, 135.50) AND gimpuls = <306, inf) AND genergy = <19245, inf) AND senergy = <550, inf) AND nbumps = <1.50, inf) THEN class = {1}\n", "IF goenergy = (-inf, -1.50) AND gimpuls = <153.50, 289) AND genergy = <17405, 37085) AND goimpuls = <-60.50, inf) AND senergy = (-inf, 40500) AND nbumps3 = (-inf, 3.50) AND nbumps = <1.50, inf) AND nbumps2 = <0.50, inf) THEN class = {1}\n", "IF goenergy = (-inf, 131.50) AND gimpuls = <1253.50, inf) AND genergy = <54930, 1062020) AND goimpuls = <-60.50, 109) AND shift = {W} AND senergy = (-inf, 36050) AND nbumps2 = (-inf, 2.50) THEN class = {1}\n", "IF gimpuls = <98.50, inf) AND senergy = <650, inf) AND nbumps2 = <0.50, inf) THEN class = {1}\n", "IF goenergy = <-78.50, inf) AND gimpuls = <66, inf) AND goimpuls = <-74.50, inf) AND genergy = <3065, inf) AND senergy = <550, inf) THEN class = {1}\n", "IF goenergy = (-inf, 176.50) AND gimpuls = <131, inf) AND genergy = <48545, inf) THEN class = {1}\n", "IF goenergy = <-4, inf) AND gimpuls = <396, 1445.50) AND genergy = <32795, 49585) AND goimpuls = <-19, inf) AND shift = {W} AND senergy = (-inf, 350) THEN class = {1}\n", "IF goenergy = <-37.50, inf) AND gimpuls = <537.50, 796) AND genergy = <16805, 32020) AND goimpuls = <-36.50, inf) AND senergy = (-inf, 250) THEN class = {1}\n", "IF goenergy = <-37.50, 181) AND gimpuls = <240, 470.50) AND genergy = <19670, 40735) AND goimpuls = <-42.50, inf) AND shift = {W} THEN class = {1}\n", "IF gimpuls = <54.50, inf) AND goimpuls = <-74.50, inf) AND genergy = <1510, inf) AND senergy = (-inf, 115450) THEN class = {1}\n" ] } ], "source": [ "for rule in ruleset.rules:\n", " print(rule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare dataset" ] }, { "cell_type": "code", "execution_count": 6, 
"metadata": {}, "outputs": [], "source": [ "from scipy.io import arff\n", "import pandas as pd\n", "\n", "data_df = pd.DataFrame(arff.loadarff(\"methane-train.arff\")[0])\n", "\n", "X = data_df.drop(['MM116_pred'], axis=1)\n", "y = data_df['MM116_pred']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | MM31 | MM116 | AS038 | PG072 | PD | BA13 | DMM116 |
|---|---|---|---|---|---|---|---|
| 0 | 0.46 | 1.3 | 2.4 | 2.0 | 1.0 | 1076.0 | 0.0 |
| 1 | 0.46 | 1.3 | 2.2 | 1.9 | 1.0 | 1076.0 | 0.0 |
| 2 | 0.49 | 1.3 | 2.2 | 1.9 | 1.0 | 1076.0 | 0.0 |
| 3 | 0.50 | 1.3 | 2.3 | 1.9 | 1.0 | 1076.0 | 0.0 |
| 4 | 0.54 | 1.3 | 2.3 | 1.9 | 1.0 | 1076.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 13363 | 0.64 | 1.2 | 2.4 | 1.8 | 1.0 | 1077.0 | 0.0 |
| 13364 | 0.59 | 1.2 | 2.4 | 1.8 | 1.0 | 1077.0 | 0.0 |
| 13365 | 0.60 | 1.1 | 2.2 | 1.8 | 1.0 | 1077.0 | -0.1 |
| 13366 | 0.64 | 1.1 | 2.2 | 1.8 | 1.0 | 1077.0 | 0.0 |
| 13367 | 0.65 | 1.2 | 2.2 | 1.7 | 0.0 | 1077.0 | 0.1 |
13368 rows × 7 columns
\n", "