Coverage for tests/utils.py: 90%
323 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1import io
2import os
3import re
4from io import StringIO
5from typing import Union
6from xml.etree import ElementTree
8import pandas as pd
9from sklearn import metrics
11from rulekit._helpers import ExampleSetFactory
12from rulekit._problem_types import ProblemType
13from rulekit.arff import read_arff
# Absolute path of the directory containing this file; all test resource
# paths below are resolved relative to it.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Directory holding the XML test-configuration files parsed by TestConfigParser.
TEST_CONFIG_PATH = f"{dir_path}/resources/config"
# Directory with reference report files that test results are compared against.
REPORTS_IN_DIRECTORY_PATH = f"{dir_path}/resources/reports"
# Root directory of the downloaded test resources (datasets etc.).
DATA_IN_DIRECTORY_PATH = f"{dir_path}/resources/"
# Directory where reports produced during a test run are written.
REPORTS_OUT_DIRECTORY_PATH = f"{dir_path}/test_out"

# Section header markers recognised in report files (see TestReportReader/Writer).
REPORTS_SECTIONS_HEADERS = {"RULES": "RULES"}

# Paths relative to the tests package root.
DATASETS_PATH = "/resources/data"
EXPERIMENTS_PATH = "/resources/config"
def _fix_missing_values(column) -> None:
    """Replace byte-string ``b"?"`` missing-value markers with ``None``.

    Mutates the underlying array of *column* (a pandas Series) in place.
    """
    data = column.values
    for index in range(len(data)):
        if data[index] == b"?":
            data[index] = None
class ExampleSetWrapper:
    """Bundles a feature table and label column with their RuleKit example set."""

    def __init__(self, values, labels, problem_type: ProblemType):
        factory = ExampleSetFactory(problem_type)
        self.values = values
        self.labels = labels
        self.example_set = factory.make(values, labels)

    def get_data(self) -> tuple:
        """Return the original ``(values, labels)`` pair."""
        return (self.values, self.labels)
def load_arff_to_example_set(
    path: str, label_attribute: str, problem_type: ProblemType
) -> ExampleSetWrapper:
    """Load an ARFF file and wrap it as an :class:`ExampleSetWrapper`.

    Args:
        path: path to the ARFF file.
        label_attribute: name of the column used as the label.
        problem_type: RuleKit problem type for the example set.

    Returns:
        ExampleSetWrapper with ``b"?"`` missing markers replaced by ``None``.
    """
    # Explicit encoding: previously the open() relied on the platform default,
    # which breaks on non-UTF-8 locales.
    with open(path, "r", encoding="utf-8") as file:
        # Normalise double quotes to single quotes before parsing.
        content = file.read().replace('"', "'")
    arff_file = io.StringIO(content)
    arff_data_frame = read_arff(arff_file)

    # Every column except the label is a feature attribute.
    attributes_names = [
        column_name
        for column_name in arff_data_frame.columns
        if column_name != label_attribute
    ]

    values = arff_data_frame[attributes_names]
    labels = arff_data_frame[label_attribute]

    for column in values:
        _fix_missing_values(values[column])
    _fix_missing_values(labels)
    return ExampleSetWrapper(values, labels, problem_type)
def get_dataset_path(name: str) -> str:
    """Return the repository-relative path of the dataset called *name*."""
    return "/".join((DATASETS_PATH, name))
class Knowledge:
    """Container for expert knowledge passed to RuleKit induction."""

    def __init__(self):
        # Each list holds (name, rule-string) tuples parsed from the XML config.
        self.expert_rules: list = []
        self.expert_preferred_conditions: list = []
        self.expert_forbidden_conditions: list = []
class TestReport:
    """Parsed contents of a single reference report file."""

    def __init__(self, file_name: str):
        self.file_name = file_name
        # Filled in by TestReportReader when the RULES section is read.
        self.rules = None
class TestCase:
    """One test scenario: a dataset, induction parameters and a reference report.

    The dataset and the reference report are loaded lazily on first access.
    """

    def __init__(self, problem_type: ProblemType):
        self.param_config: dict[str, object] = None
        self._reference_report: TestReport = None
        self._example_set: ExampleSetWrapper = None
        self.induction_params: dict = None
        self.knowledge: Knowledge = None
        self.problem_type: ProblemType = problem_type

        self.name: str = None
        self.data_set_file_path: str = None
        self.label_attribute: str = None
        self.survival_time: str = None
        self.report_file_path: str = None
        self.using_existing_report_file: bool = False

    @property
    def example_set(self) -> ExampleSetWrapper:
        """Dataset wrapper, loaded from the ARFF file on first access."""
        if self._example_set is None:
            self._example_set = load_arff_to_example_set(
                self.data_set_file_path, self.label_attribute, self.problem_type
            )
        return self._example_set

    @property
    def reference_report(self) -> TestReport:
        """Reference report, parsed from disk on first access."""
        if self._reference_report is None:
            report_reader = TestReportReader(f"{self.report_file_path}.txt")
            self._reference_report = report_reader.read()
            report_reader.close()
        return self._reference_report
class DataSetConfig:
    """Configuration of a single dataset entry from a test XML file."""

    def __init__(self):
        # Logical dataset name (attribute or derived from the file name).
        self.name: str = None
        # Relative path of the training ARFF file.
        self.train_file_name: str = None
        # Name of the label column.
        self.label_attribute: str = None
        # Name of the survival-time column, if any.
        self.survival_time: str = None
129class TestConfig:
131 def __init__(self):
132 self.name: str = None
133 self.parameter_configs: dict[str, dict[str, object]] = None
134 self.datasets: list[DataSetConfig] = None
137class TestConfigParser:
138 TEST_KEY = "test"
139 NAME_KEY = "name"
140 IN_FILE_KEY = "in_file"
141 TRAINING_KEY = "training"
142 TRAIN_KEY = "train"
143 LABEL_KEY = "label"
144 DATASET_KEY = "dataset"
145 DATASETS_KEY = "datasets"
146 PARAM_KEY = "param"
147 PARAMETERS_SET_KEY = "parameter_sets"
148 PARAMETERS_KEY = "parameter_set"
149 ENTRY_KEY = "entry"
150 SURVIVAL_TIME_ROLE = "survival_time"
152 EXPERTS_RULES_PARAMETERS_NAMES = [
153 "expert_rules",
154 "expert_preferred_conditions",
155 "expert_forbidden_conditions",
156 ]
158 def __init__(self):
159 self.tests_configs: dict[str, TestConfig] = {}
160 self.root: ElementTree = None
162 def _parse_survival_time(self, element) -> Union[str, None]:
163 survival_time_element = element.find(TestConfigParser.SURVIVAL_TIME_ROLE)
164 if element.find(TestConfigParser.SURVIVAL_TIME_ROLE) is not None:
165 return survival_time_element.text
166 else:
167 return None
169 def _parse_experts_rules_parameters(self, elements) -> list[tuple]:
170 expert_rules = []
171 for element in elements:
172 rule_name: str = element.attrib["name"]
173 rule_content: str = element.text
174 # RuleKit originally used XML for specifying parameters and uses special xml characters
175 rule_content = rule_content.replace("<", "<").replace(">", ">")
176 expert_rules.append((rule_name, rule_content))
177 return expert_rules if len(expert_rules) > 0 else None
179 def _check_ambigous_data_sets_names(self, data_sets_configs: list[DataSetConfig]):
180 dictionary = {}
181 for element in data_sets_configs:
182 dictionary[element.name] = None
183 if len(dictionary.keys()) < len(data_sets_configs):
184 raise ValueError("Datasets are ambigous")
186 def _parse_data_set(self, element) -> DataSetConfig:
187 data_set_config = DataSetConfig()
188 data_set_config.label_attribute = element.find(TestConfigParser.LABEL_KEY).text
189 train_element = element.find(TestConfigParser.TRAINING_KEY)
190 train_element = train_element.find(TestConfigParser.TRAIN_KEY)
191 data_set_config.train_file_name = train_element.find(
192 TestConfigParser.IN_FILE_KEY
193 ).text
194 data_set_config.name = element.attrib.get("name", None)
195 if data_set_config.name is None:
196 file_name = os.path.basename(data_set_config.train_file_name)
197 data_set_config.name = file_name.split(".")[0]
198 data_set_config.survival_time = self._parse_survival_time(element)
199 return data_set_config
201 def _parse_data_sets(self, element) -> list[DataSetConfig]:
202 data_set_configs = []
203 node = element.find(TestConfigParser.DATASETS_KEY)
204 for element in node.findall(TestConfigParser.DATASET_KEY):
205 data_set_configs.append(self._parse_data_set(element))
206 return data_set_configs
208 def parse_test_parameters(self, element) -> dict[str, object]:
209 params = {}
210 for param_node in element.findall(TestConfigParser.PARAM_KEY):
211 name: str = param_node.attrib["name"]
212 if name in TestConfigParser.EXPERTS_RULES_PARAMETERS_NAMES:
213 value = self._parse_experts_rules_parameters(
214 param_node.findall(TestConfigParser.ENTRY_KEY)
215 )
216 else:
217 value = param_node.text
218 params[name] = value
219 return params
221 def _parse_test_parameters_sets(self, element) -> dict[str, dict[str, object]]:
222 parameters_sets = {}
223 params_sets_node = element.findall(TestConfigParser.PARAMETERS_SET_KEY)[0]
224 for param_set in params_sets_node.findall(TestConfigParser.PARAMETERS_KEY):
225 name: str = param_set.attrib["name"]
226 parameters_sets[name] = self.parse_test_parameters(param_set)
227 return parameters_sets
229 def _parse_test(self, element) -> TestConfig:
230 test_config = TestConfig()
231 test_config.parameter_configs = self._parse_test_parameters_sets(element)
232 test_config.datasets = self._parse_data_sets(element)
233 test_config.name = element.attrib["name"]
234 return test_config
236 def parse(self, file_path: str) -> dict[str, TestConfig]:
237 self.tests_configs = {}
238 self.root = ElementTree.parse(file_path).getroot()
239 if self.root.tag == "test":
240 test_elements = [self.root]
241 else:
242 test_elements = self.root.findall(TestConfigParser.TEST_KEY)
243 for test_element in test_elements:
244 test_config = self._parse_test(test_element)
245 self.tests_configs[test_config.name] = test_config
246 return self.tests_configs
class TestCaseFactory:
    """Builds :class:`TestCase` objects from parsed :class:`TestConfig`s."""

    def _make_test_case(
        self,
        test_case_name: str,
        params: dict[str, object],
        data_set_config: DataSetConfig,
        problem_type: ProblemType,
    ) -> TestCase:
        """Create one TestCase with normalised induction parameters."""
        test_case = TestCase(problem_type)
        self._fix_params_typing(params)
        self._fix_deprecated_params(params)
        test_case.induction_params = params
        test_case.data_set_file_path = (
            f"{DATA_IN_DIRECTORY_PATH}/{data_set_config.train_file_name}"
        )
        test_case.label_attribute = data_set_config.label_attribute
        test_case.name = test_case_name
        test_case.param_config = params
        return test_case

    def _fix_deprecated_params(self, params: dict[str, object]):
        """Rename the deprecated 'min_rule_covered' parameter to 'minsupp_new'."""
        deprecated_minsupp_new_name = "min_rule_covered"
        if deprecated_minsupp_new_name in params:
            params["minsupp_new"] = params.pop(deprecated_minsupp_new_name)

    def _fix_params_typing(self, params: dict):
        """Coerce XML string values to bool/int in place.

        Measure parameters (keys containing 'measure') keep their string
        values; every other value is coerced to int via float, so "5.0" -> 5.
        """
        for key, value in params.items():
            if value == "false":
                params[key] = False
            elif value == "true":
                params[key] = True
            elif "measure" not in key:
                params[key] = int(float(value))

    def make(
        self,
        tests_configs: dict[str, TestConfig],
        report_dir_path: str,
        problem_type: ProblemType,
    ) -> list[TestCase]:
        """Expand every (test, parameter set, dataset) combination into a TestCase.

        NOTE(review): expert-knowledge entries are popped from the shared
        parameter dict inside the dataset loop, so they only reach the first
        dataset of a given parameter set — original behaviour preserved;
        confirm this is intended.
        """
        test_cases = []
        for test_name, test_config in tests_configs.items():
            for config_name, params in test_config.parameter_configs.items():
                for data_set_config in test_config.datasets:
                    test_case_name = (
                        f"{test_name}.{config_name}.{data_set_config.name}"
                    )
                    params.pop("use_expert", None)
                    expert_rules = params.pop("expert_rules", None)
                    preferred_conditions = params.pop(
                        "expert_preferred_conditions", None
                    )
                    forbidden_conditions = params.pop(
                        "expert_forbidden_conditions", None
                    )
                    test_case = self._make_test_case(
                        test_case_name,
                        params,
                        data_set_config,
                        problem_type,
                    )
                    if "use_report" in params:
                        # Several test cases may share one reference report.
                        report_file_name = params["use_report"]
                        test_case.using_existing_report_file = True
                    else:
                        report_file_name = test_case_name
                    test_case.report_file_path = (
                        f"{report_dir_path}/{report_file_name}"
                    )
                    test_case.survival_time = data_set_config.survival_time
                    if (
                        expert_rules is not None
                        or preferred_conditions is not None
                        or forbidden_conditions is not None
                    ):
                        test_case.knowledge = Knowledge()
                        if expert_rules is not None:
                            test_case.knowledge.expert_rules = expert_rules
                        if forbidden_conditions is not None:
                            test_case.knowledge.expert_forbidden_conditions = (
                                forbidden_conditions
                            )
                        if preferred_conditions is not None:
                            test_case.knowledge.expert_preferred_conditions = (
                                preferred_conditions
                            )
                    test_cases.append(test_case)
        return test_cases
def get_rule_string(rule) -> str:
    """Return ``str(rule)`` with a trailing ``[...]`` or ``(...)`` block removed.

    BUGFIX: the regex backslashes were doubled, so the pattern required a
    literal backslash and never matched; the intended pattern (matching the
    sanitizer used in assert_rules_are_equals) strips a bracketed or
    parenthesised suffix at the end of the rule string.
    """
    return re.sub(r"(\[[^\]]*\]$)|(\([^\)]*\)$)", "", str(rule))
class TestReportReader:
    """Reads reference test-report files section by section."""

    def __init__(self, file_name: str):
        self.file_name = file_name
        self._file = open(file_name, encoding="utf-8", mode="r")

    def _read_rules(self, test_report: TestReport):
        """Consume the remainder of the file as the RULES section.

        NOTE(review): ``len(line) == 0`` never holds while iterating an open
        text file (every yielded line contains at least '\\n'), so in practice
        this reads to EOF — behaviour preserved as-is.
        """
        collected = []
        for line in self._file:
            if len(line) == 0:
                break
            collected.append(line)
        test_report.rules = collected

    def read(self) -> TestReport:
        """Parse the whole file; raise ValueError on an unknown section line."""
        test_report = TestReport(self.file_name)
        for raw_line in self._file:
            header = re.sub(r"\t", "", raw_line.upper()).replace("\n", "")
            if header == REPORTS_SECTIONS_HEADERS["RULES"]:
                self._read_rules(test_report)
            elif header == "":
                continue
            else:
                raise ValueError(
                    f"Invalid report file format for file: {self.file_name}"
                )
        return test_report

    def close(self):
        """Close the underlying file handle."""
        self._file.close()
class TestReportWriter:
    """Writes generated rule sets as report files."""

    def __init__(self, file_name: str):
        # BUGFIX: the output directory must exist before the file is opened;
        # previously makedirs ran only after open(), so a run against a fresh
        # checkout failed with FileNotFoundError before the directory was made.
        os.makedirs(REPORTS_OUT_DIRECTORY_PATH, exist_ok=True)
        self._file = open(file_name, encoding="utf-8", mode="w")

    def write(self, rule_set):
        """Write the RULES section header followed by the rules of *rule_set*.

        NOTE(review): rules are written without a trailing newline, so all
        rules end up on one line — verify against TestReportReader expectations.
        """
        self._file.write("\n")
        self._file.write(f'{REPORTS_SECTIONS_HEADERS["RULES"]}\n')
        for rule in rule_set.rules:
            self._file.write(f"\t{get_rule_string(rule)}")

    def close(self):
        """Close the underlying file handle."""
        self._file.close()
def get_test_cases(class_name: str) -> list[TestCase]:
    """Build every TestCase defined for the given test class name.

    Reads ``<class_name>.xml`` from the config directory and pairs it with
    the matching reference-reports subdirectory.
    """
    if not os.path.exists(DATA_IN_DIRECTORY_PATH):
        raise RuntimeError(
            """\n
Test resources directory dosen't exist. Check if 'tests/resources/' directory exist.

If you're running tests for the first time you need to download resources folder from RuleKit repository by running:
    python tests/resources.py download
    """
        )
    config_file_path = f"{TEST_CONFIG_PATH}/{class_name}.xml"
    reports_dir_path = f"{REPORTS_IN_DIRECTORY_PATH}/{class_name}/"
    problem_type: ProblemType = _get_problem_type_from_test_case_class_name(class_name)
    configs = TestConfigParser().parse(config_file_path)
    return TestCaseFactory().make(configs, reports_dir_path, problem_type)
def _get_problem_type_from_test_case_class_name(class_name: str) -> ProblemType:
    """Infer the RuleKit problem type from a test class name (case-insensitive)."""
    class_name = class_name.lower()
    if "regression" in class_name:
        return ProblemType.REGRESSION
    if "survival" in class_name:
        return ProblemType.SURVIVAL
    if "classification" in class_name:
        return ProblemType.CLASSIFICATION
    raise Exception(f"Unknown problem type for test case class name: {class_name}")
def assert_rules_are_equals(expected: list[str], actual: list[str]):
    """Assert that two rule lists contain the same rules, as multisets.

    Rules are compared after stripping tabs, newlines and a trailing
    ``[...]`` statistics block.  Raises AssertionError on any mismatch.
    """

    def sanitize_rule_string(rule_string: str) -> str:
        return re.sub(r"(\t)|(\n)|(\[[^\]]*\]$)", "", rule_string)

    expected = list(map(sanitize_rule_string, expected))
    actual = list(map(sanitize_rule_string, actual))

    if len(expected) != len(actual):
        raise AssertionError(
            "Rulesets have different number of rules, actual: "
            f'{len(actual)}, expected: {len(expected)}'
        )
    # BUGFIX: the original stored every expected rule once (count 0), so a
    # rule legitimately duplicated in BOTH rulesets was reported as
    # "duplicated".  Compare proper occurrence counts instead.
    expected_counts: dict = {}
    for rule in expected:
        expected_counts[rule] = expected_counts.get(rule, 0) + 1
    actual_counts: dict = {}
    for rule in actual:
        if rule not in expected_counts:
            raise AssertionError(
                "Actual ruleset contains rules not present in expected ruleset"
            )
        actual_counts[rule] = actual_counts.get(rule, 0) + 1
    for rule, count in expected_counts.items():
        observed = actual_counts.get(rule, 0)
        if observed < count:
            raise AssertionError("Ruleset are not equal, some rules are missing")
        if observed > count:
            raise AssertionError("Somes rules were duplicated")
def assert_accuracy_is_greater(prediction, expected, threshold: float):
    """Raise AssertionError unless classification accuracy exceeds *threshold*."""
    labels = expected.to_numpy().astype(str)
    acc = metrics.accuracy_score(labels, prediction)
    if acc > threshold:
        return
    raise AssertionError(f"Accuracy should be greater than {threshold} (was {acc})")
def assert_score_is_greater(prediction, expected, threshold: float):
    """Raise AssertionError unless explained variance exceeds *threshold*.

    Casts *expected* to int or float to match the type of the predictions.
    """
    first_prediction = prediction[0]
    if isinstance(first_prediction, int):
        labels = expected.to_numpy().astype(int)
    elif isinstance(first_prediction, float):
        labels = expected.to_numpy().astype(float)
    else:
        raise ValueError(
            f"Invalid prediction type: {str(type(prediction[0]))}. "
            + "Supported types are: 1 dimensional numpy array or pandas Series object."
        )
    explained_variance_score = metrics.explained_variance_score(labels, prediction)

    if explained_variance_score <= threshold:
        raise AssertionError(f"Score should be greater than {threshold}")