Coverage for tests/test_sklearn_metrics.py: 96%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1import unittest 

2 

3import sklearn.tree as scikit 

4from sklearn import metrics 

5from sklearn.datasets import load_iris 

6 

7from rulekit import classification 

8from rulekit.main import RuleKit 

9 

10 

class TestMetrics(unittest.TestCase):
    """Compare RuleKit's ``RuleClassifier`` with a scikit-learn decision tree."""

    @classmethod
    def setUpClass(cls):
        # Initialise RuleKit once for every test in this class.
        RuleKit.init()

    def test_classification_accuracy_on_iris(self):
        """RuleClassifier accuracy on iris should be within 0.04 of a decision tree's.

        Both models are fit and scored on the full iris dataset, so this is a
        training-accuracy parity check, not a generalisation test.
        """
        x, y = load_iris(return_X_y=True)

        # Fixed random_state so the reference tree is reproducible across runs.
        scikit_clf = scikit.DecisionTreeClassifier(random_state=0)
        rulekit_clf = classification.RuleClassifier()

        scikit_clf.fit(x, y)
        rulekit_clf.fit(x, y)
        scikit_prediction = scikit_clf.predict(x)
        rulekit_prediction = rulekit_clf.predict(x)

        scikit_accuracy = metrics.accuracy_score(y, scikit_prediction)
        rulekit_accuracy = metrics.accuracy_score(y, rulekit_prediction)

        # assertLess rather than a bare assert: bare asserts are stripped
        # under ``python -O`` and report no values when they fail.
        self.assertLess(
            abs(scikit_accuracy - rulekit_accuracy),
            0.04,
            'RuleKit model should perform similar to scikit model',
        )

        confusion_matrix = metrics.confusion_matrix(y, rulekit_prediction)
        # confusion_matrix always returns an array (never None), so check its
        # content instead: every sample must be counted exactly once.
        self.assertEqual(
            confusion_matrix.sum(),
            len(y),
            'Confusion matrix should be calculated',
        )

35 

36 

if __name__ == "__main__":
    # Executed only when this file is run directly, not when a test runner imports it.
    unittest.main()