Coverage for tests/test_serialization.py: 100%
101 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1import os
2import shutil
3import sys
4import unittest
# Absolute directory containing this test module (symlinks resolved).
dir_path = os.path.dirname(os.path.realpath(__file__))
# Put the parent directory on the import path so `rulekit` can be imported
# from a source checkout.
# NOTE(review): `append` puts this entry last, so an installed `rulekit`
# found earlier on sys.path would take precedence — confirm that is intended.
sys.path.append(dir_path + '/..')
9import io
10import pickle
11import time
13import sklearn.tree as scikit
14from sklearn import metrics
15from sklearn.datasets import load_diabetes, load_iris
17from rulekit import classification
18from rulekit.classification import ExpertRuleClassifier, RuleClassifier
19from rulekit.main import RuleKit
20from rulekit.regression import ExpertRuleRegressor, RuleRegressor
21from rulekit.survival import ExpertSurvivalRules, SurvivalRules
24class TestModelSerialization(unittest.TestCase):
26 TMP_DIR_PATH = f'{dir_path}/tmp'
27 PICKLE_FILE_PATH = f'{TMP_DIR_PATH}/model.pickle'
29 @classmethod
30 def setUpClass(cls):
31 if not os.path.exists(TestModelSerialization.TMP_DIR_PATH):
32 os.mkdir(TestModelSerialization.TMP_DIR_PATH)
33 RuleKit.init()
35 @classmethod
36 def tearDownClass(cls):
37 shutil.rmtree(TestModelSerialization.TMP_DIR_PATH)
39 def serialize_model(self, model):
40 with open(TestModelSerialization.PICKLE_FILE_PATH, 'wb') as handle:
41 pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)
43 def deserialize_model(self) -> object:
44 with open(TestModelSerialization.PICKLE_FILE_PATH, 'rb') as handle:
45 return pickle.load(handle)
47 def test_classifier_serialization(self):
48 x, y = load_iris(return_X_y=True)
50 model = RuleClassifier(minsupp_new=1)
51 model.fit(x, y)
52 prediction, metrics = model.predict(x, return_metrics=True)
54 self.serialize_model(model)
55 deserialized_model = self.deserialize_model()
56 deserialized_model_prediction, deserialized_model_metrics = deserialized_model.predict(x, return_metrics=True)
58 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
59 'Deserialized model should predict same as original one')
60 self.assertEqual(metrics, deserialized_model_metrics,
61 'Deserialized model should return the same prediction metrics as original one')
63 self.serialize_model(deserialized_model)
65 deserialized_model = self.deserialize_model()
66 deserialized_model_prediction, deserialized_model_metrics = deserialized_model.predict(x, return_metrics=True)
68 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
69 'Model deserialized multiple time should predict same as original one')
70 self.assertEqual(metrics, deserialized_model_metrics,
71 'Model deserialized multiple time should return the same prediction metrics as original one')
74 def test_expert_classifier_serialization(self):
75 x, y = load_iris(return_X_y=True)
77 model = ExpertRuleClassifier(minsupp_new=1)
78 model.fit(x, y)
79 prediction = model.predict(x)
81 self.serialize_model(model)
82 deserialized_model = self.deserialize_model()
83 deserialized_model_prediction = deserialized_model.predict(x)
85 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
86 'Deserialized model should predict same as original one')
88 def test_regressor_serialization(self):
89 x, y = load_diabetes(return_X_y=True)
91 model = RuleRegressor(minsupp_new=10)
92 model.fit(x, y)
93 prediction = model.predict(x)
95 self.serialize_model(model)
96 deserialized_model = self.deserialize_model()
97 deserialized_model_prediction = deserialized_model.predict(x)
99 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
100 'Deserialized model should predict same as original one')
103 def test_expert_regressor_serialization(self):
104 x, y = load_diabetes(return_X_y=True)
106 model = ExpertRuleRegressor(minsupp_new=10)
107 model.fit(x, y)
108 prediction = model.predict(x)
110 self.serialize_model(model)
111 deserialized_model = self.deserialize_model()
112 deserialized_model_prediction = deserialized_model.predict(x)
114 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
115 'Deserialized model should predict same as original one')
117 def test_survival_serialization(self):
118 x, y = load_iris(return_X_y=True)
120 model = SurvivalRules(minsupp_new=10, survival_time_attr='att1')
121 model.fit(x, y)
122 prediction = model.predict(x)
124 self.serialize_model(model)
125 deserialized_model = self.deserialize_model()
126 deserialized_model_prediction = deserialized_model.predict(x)
128 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
129 'Deserialized model should predict same as original one')
131 def test_expert_survival_serialization(self):
132 x, y = load_iris(return_X_y=True)
134 model = ExpertSurvivalRules(minsupp_new=10, survival_time_attr='att1')
135 model.fit(x, y)
136 prediction = model.predict(x)
138 self.serialize_model(model)
139 deserialized_model = self.deserialize_model()
140 deserialized_model_prediction = deserialized_model.predict(x)
142 self.assertEqual(prediction.all(), deserialized_model_prediction.all(),
143 'Deserialized model should predict same as original one')
145 def test_multiple_serialization(self):
146 x, y = load_iris(return_X_y=True)
148 model = RuleClassifier(minsupp_new=1)
149 model.fit(x, y)
150 prediction, metrics = model.predict(x, return_metrics=True)
152 self.serialize_model(model)
153 self.serialize_model(model)