Coverage for tests/test_serialization.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1import os 

2import shutil 

3import sys 

4import unittest 

5 

# Directory that holds this test module, with symlinks resolved.
dir_path = os.path.split(os.path.realpath(__file__))[0]
# Make the parent of the tests directory importable so that the ``rulekit``
# imports further down can be resolved from a source checkout.
sys.path.append(f'{dir_path}/..')

8 

9import io 

10import pickle 

11import time 

12 

13import sklearn.tree as scikit 

14from sklearn import metrics 

15from sklearn.datasets import load_diabetes, load_iris 

16 

17from rulekit import classification 

18from rulekit.classification import ExpertRuleClassifier, RuleClassifier 

19from rulekit.main import RuleKit 

20from rulekit.regression import ExpertRuleRegressor, RuleRegressor 

21from rulekit.survival import ExpertSurvivalRules, SurvivalRules 

22 

23 

24class TestModelSerialization(unittest.TestCase): 

25 

26 TMP_DIR_PATH = f'{dir_path}/tmp' 

27 PICKLE_FILE_PATH = f'{TMP_DIR_PATH}/model.pickle' 

28 

@classmethod
def setUpClass(cls):
    """Prepare shared fixtures once for the whole test case.

    Ensures the scratch directory used for pickle files exists and then
    calls ``RuleKit.init()``.
    """
    # ``exist_ok=True`` replaces the former exists()/mkdir() pair: it has no
    # check-then-create race and tolerates a directory left behind by an
    # interrupted earlier run.
    os.makedirs(TestModelSerialization.TMP_DIR_PATH, exist_ok=True)
    RuleKit.init()

34 

@classmethod
def tearDownClass(cls):
    """Remove the scratch directory, including any pickle file inside it."""
    scratch_dir = TestModelSerialization.TMP_DIR_PATH
    shutil.rmtree(scratch_dir)

38 

def serialize_model(self, model):
    """Pickle ``model`` to ``PICKLE_FILE_PATH``, overwriting any old file.

    The highest pickle protocol supported by the running interpreter is
    used.
    """
    target_path = TestModelSerialization.PICKLE_FILE_PATH
    with open(target_path, 'wb') as pickle_file:
        pickler = pickle.Pickler(pickle_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(model)

42 

def deserialize_model(self) -> object:
    """Load and return the object pickled at ``PICKLE_FILE_PATH``."""
    source_path = TestModelSerialization.PICKLE_FILE_PATH
    with open(source_path, 'rb') as pickle_file:
        restored = pickle.Unpickler(pickle_file).load()
    return restored

46 

def test_classifier_serialization(self):
    """Check a ``RuleClassifier`` after one and after two pickle round-trips.

    The model is fitted on the iris dataset and its predictions and
    prediction metrics are recorded. It is then pickled and unpickled, and
    the restored model must reproduce both. The restored model is pickled
    and unpickled a second time and the same checks are repeated.
    """
    x, y = load_iris(return_X_y=True)

    model = RuleClassifier(minsupp_new=1)
    model.fit(x, y)
    # Named ``prediction_metrics`` so the module-level ``sklearn.metrics``
    # import is not shadowed inside this test.
    prediction, prediction_metrics = model.predict(x, return_metrics=True)

    self.serialize_model(model)
    deserialized_model = self.deserialize_model()
    deserialized_model_prediction, deserialized_model_metrics = (
        deserialized_model.predict(x, return_metrics=True)
    )

    # Compare element by element. The previous form,
    # ``assertEqual(a.all(), b.all())``, first collapsed each prediction to
    # a single value and therefore could not detect differing predictions.
    self.assertTrue(
        (prediction == deserialized_model_prediction).all(),
        'Deserialized model should predict same as original one')
    self.assertEqual(
        prediction_metrics, deserialized_model_metrics,
        'Deserialized model should return the same prediction metrics as original one')

    # Second round-trip: pickle the already restored model again.
    self.serialize_model(deserialized_model)

    deserialized_model = self.deserialize_model()
    deserialized_model_prediction, deserialized_model_metrics = (
        deserialized_model.predict(x, return_metrics=True)
    )

    self.assertTrue(
        (prediction == deserialized_model_prediction).all(),
        'Model deserialized multiple time should predict same as original one')
    self.assertEqual(
        prediction_metrics, deserialized_model_metrics,
        'Model deserialized multiple time should return the same prediction metrics as original one')

72 

73 

def test_expert_classifier_serialization(self):
    """Check an ``ExpertRuleClassifier`` after a pickle round-trip.

    The model is fitted on the iris dataset, pickled, unpickled, and the
    restored model must produce the same predictions as the original.
    """
    x, y = load_iris(return_X_y=True)

    model = ExpertRuleClassifier(minsupp_new=1)
    model.fit(x, y)
    prediction = model.predict(x)

    self.serialize_model(model)
    deserialized_model = self.deserialize_model()
    deserialized_model_prediction = deserialized_model.predict(x)

    # Compare element by element. The previous form,
    # ``assertEqual(a.all(), b.all())``, first collapsed each prediction to
    # a single value and therefore could not detect differing predictions.
    self.assertTrue(
        (prediction == deserialized_model_prediction).all(),
        'Deserialized model should predict same as original one')

87 

def test_regressor_serialization(self):
    """Check a ``RuleRegressor`` after a pickle round-trip.

    The model is fitted on the diabetes dataset, pickled, unpickled, and
    the restored model must produce the same predictions as the original.
    """
    x, y = load_diabetes(return_X_y=True)

    model = RuleRegressor(minsupp_new=10)
    model.fit(x, y)
    prediction = model.predict(x)

    self.serialize_model(model)
    deserialized_model = self.deserialize_model()
    deserialized_model_prediction = deserialized_model.predict(x)

    # Compare element by element. The previous form,
    # ``assertEqual(a.all(), b.all())``, first collapsed each prediction to
    # a single value and therefore could not detect differing predictions.
    # NOTE(review): this is an exact comparison; it assumes the restored
    # model reproduces the original outputs bit-for-bit - confirm.
    self.assertTrue(
        (prediction == deserialized_model_prediction).all(),
        'Deserialized model should predict same as original one')

101 

102 

def test_expert_regressor_serialization(self):
    """Check an ``ExpertRuleRegressor`` after a pickle round-trip.

    The model is fitted on the diabetes dataset, pickled, unpickled, and
    the restored model must produce the same predictions as the original.
    """
    x, y = load_diabetes(return_X_y=True)

    model = ExpertRuleRegressor(minsupp_new=10)
    model.fit(x, y)
    prediction = model.predict(x)

    self.serialize_model(model)
    deserialized_model = self.deserialize_model()
    deserialized_model_prediction = deserialized_model.predict(x)

    # Compare element by element. The previous form,
    # ``assertEqual(a.all(), b.all())``, first collapsed each prediction to
    # a single value and therefore could not detect differing predictions.
    # NOTE(review): this is an exact comparison; it assumes the restored
    # model reproduces the original outputs bit-for-bit - confirm.
    self.assertTrue(
        (prediction == deserialized_model_prediction).all(),
        'Deserialized model should predict same as original one')

116 

def test_survival_serialization(self):
    """Check a ``SurvivalRules`` model after a pickle round-trip.

    The model is fitted on the iris dataset with ``att1`` given as the
    survival time attribute, pickled, unpickled, and the predictions of the
    restored model are compared with those of the original.
    """
    features, labels = load_iris(return_X_y=True)

    original_model = SurvivalRules(minsupp_new=10, survival_time_attr='att1')
    original_model.fit(features, labels)
    expected = original_model.predict(features)

    self.serialize_model(original_model)
    restored_model = self.deserialize_model()
    actual = restored_model.predict(features)

    # NOTE(review): ``.all()`` reduces each prediction before the
    # comparison, so this presumably does not compare the predictions
    # element by element - confirm against the type ``predict`` returns.
    self.assertEqual(
        expected.all(),
        actual.all(),
        'Deserialized model should predict same as original one',
    )

130 

def test_expert_survival_serialization(self):
    """Check an ``ExpertSurvivalRules`` model after a pickle round-trip.

    The model is fitted on the iris dataset with ``att1`` given as the
    survival time attribute, pickled, unpickled, and the predictions of the
    restored model are compared with those of the original.
    """
    features, labels = load_iris(return_X_y=True)

    original_model = ExpertSurvivalRules(
        minsupp_new=10,
        survival_time_attr='att1',
    )
    original_model.fit(features, labels)
    expected = original_model.predict(features)

    self.serialize_model(original_model)
    restored_model = self.deserialize_model()
    actual = restored_model.predict(features)

    # NOTE(review): ``.all()`` reduces each prediction before the
    # comparison, so this presumably does not compare the predictions
    # element by element - confirm against the type ``predict`` returns.
    self.assertEqual(
        expected.all(),
        actual.all(),
        'Deserialized model should predict same as original one',
    )

144 

145 def test_multiple_serialization(self): 

146 x, y = load_iris(return_X_y=True) 

147 

148 model = RuleClassifier(minsupp_new=1) 

149 model.fit(x, y) 

150 prediction, metrics = model.predict(x, return_metrics=True) 

151 

152 self.serialize_model(model) 

153 self.serialize_model(model)