Coverage for tests/utils.py: 90% (323 statements)

import io
import os
import re
from io import StringIO
from typing import Union
from xml.etree import ElementTree

import pandas as pd
from sklearn import metrics

from rulekit._helpers import ExampleSetFactory
from rulekit._problem_types import ProblemType
from rulekit.arff import read_arff

dir_path = os.path.dirname(os.path.realpath(__file__))

TEST_CONFIG_PATH = f"{dir_path}/resources/config"
REPORTS_IN_DIRECTORY_PATH = f"{dir_path}/resources/reports"
DATA_IN_DIRECTORY_PATH = f"{dir_path}/resources/"
REPORTS_OUT_DIRECTORY_PATH = f"{dir_path}/test_out"

REPORTS_SECTIONS_HEADERS = {"RULES": "RULES"}

DATASETS_PATH = "/resources/data"
EXPERIMENTS_PATH = "/resources/config"

def _fix_missing_values(column) -> None:
    for i, value in enumerate(column.values):
        if value == b"?":
            column.values[i] = None

class ExampleSetWrapper:

    def __init__(self, values, labels, problem_type: ProblemType):
        self.values = values
        self.labels = labels
        self.example_set = ExampleSetFactory(problem_type).make(values, labels)

    def get_data(self) -> tuple:
        return self.values, self.labels

def load_arff_to_example_set(
    path: str, label_attribute: str, problem_type: ProblemType
) -> ExampleSetWrapper:
    with open(path, "r") as file:
        content = file.read().replace('"', "'")
    arff_file = io.StringIO(content)
    arff_data_frame = read_arff(arff_file)

    attributes_names = []
    for column_name in arff_data_frame.columns:
        if column_name != label_attribute:
            attributes_names.append(column_name)

    values = arff_data_frame[attributes_names]
    labels = arff_data_frame[label_attribute]

    for column in values:
        _fix_missing_values(values[column])
    _fix_missing_values(labels)
    return ExampleSetWrapper(values, labels, problem_type)
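# Usage sketch (the file name and label attribute below are hypothetical, for
# illustration only):
#
#     wrapper = load_arff_to_example_set(
#         f"{DATA_IN_DIRECTORY_PATH}/data/some-train.arff",
#         label_attribute="class",
#         problem_type=ProblemType.CLASSIFICATION,
#     )
#     values, labels = wrapper.get_data()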

def get_dataset_path(name: str) -> str:
    return f"{DATASETS_PATH}/{name}"


class Knowledge:

    def __init__(self):
        self.expert_rules = []
        self.expert_preferred_conditions = []
        self.expert_forbidden_conditions = []


class TestReport:

    def __init__(self, file_name: str):
        self.file_name = file_name
        self.rules = None


class TestCase:

    def __init__(self, problem_type: ProblemType):
        self.param_config: dict[str, object] = None
        self._reference_report: TestReport = None
        self._example_set: ExampleSetWrapper = None
        self.induction_params: dict = None
        self.knowledge: Knowledge = None
        self.problem_type: ProblemType = problem_type

        self.name: str = None
        self.data_set_file_path: str = None
        self.label_attribute: str = None
        self.survival_time: str = None
        self.report_file_path: str = None
        self.using_existing_report_file: bool = False

    @property
    def example_set(self) -> ExampleSetWrapper:
        if self._example_set is None:
            self._example_set = load_arff_to_example_set(
                self.data_set_file_path, self.label_attribute, self.problem_type
            )
        return self._example_set

    @property
    def reference_report(self) -> TestReport:
        if self._reference_report is None:
            reader = TestReportReader(f"{self.report_file_path}.txt")
            self._reference_report = reader.read()
            reader.close()
        return self._reference_report


class DataSetConfig:

    def __init__(self):
        self.name: str = None
        self.label_attribute: str = None
        self.train_file_name: str = None
        self.survival_time: str = None


class TestConfig:

    def __init__(self):
        self.name: str = None
        self.parameter_configs: dict[str, dict[str, object]] = None
        self.datasets: list[DataSetConfig] = None

class TestConfigParser:
    TEST_KEY = "test"
    NAME_KEY = "name"
    IN_FILE_KEY = "in_file"
    TRAINING_KEY = "training"
    TRAIN_KEY = "train"
    LABEL_KEY = "label"
    DATASET_KEY = "dataset"
    DATASETS_KEY = "datasets"
    PARAM_KEY = "param"
    PARAMETERS_SET_KEY = "parameter_sets"
    PARAMETERS_KEY = "parameter_set"
    ENTRY_KEY = "entry"
    SURVIVAL_TIME_ROLE = "survival_time"

    EXPERTS_RULES_PARAMETERS_NAMES = [
        "expert_rules",
        "expert_preferred_conditions",
        "expert_forbidden_conditions",
    ]

    def __init__(self):
        self.tests_configs: dict[str, TestConfig] = {}
        self.root: ElementTree = None

    def _parse_survival_time(self, element) -> Union[str, None]:
        survival_time_element = element.find(TestConfigParser.SURVIVAL_TIME_ROLE)
        if survival_time_element is not None:
            return survival_time_element.text
        return None

    def _parse_experts_rules_parameters(self, elements) -> list[tuple]:
        expert_rules = []
        for element in elements:
            rule_name: str = element.attrib["name"]
            rule_content: str = element.text
            # RuleKit originally used XML for specifying parameters, so rule strings
            # may contain escaped XML characters that need to be restored.
            rule_content = rule_content.replace("&lt;", "<").replace("&gt;", ">")
            expert_rules.append((rule_name, rule_content))
        return expert_rules if len(expert_rules) > 0 else None

    def _check_ambigous_data_sets_names(self, data_sets_configs: list[DataSetConfig]):
        dictionary = {}
        for element in data_sets_configs:
            dictionary[element.name] = None
        if len(dictionary.keys()) < len(data_sets_configs):
            raise ValueError("Dataset names are ambiguous")

    def _parse_data_set(self, element) -> DataSetConfig:
        data_set_config = DataSetConfig()
        data_set_config.label_attribute = element.find(TestConfigParser.LABEL_KEY).text
        train_element = element.find(TestConfigParser.TRAINING_KEY)
        train_element = train_element.find(TestConfigParser.TRAIN_KEY)
        data_set_config.train_file_name = train_element.find(
            TestConfigParser.IN_FILE_KEY
        ).text
        data_set_config.name = element.attrib.get("name", None)
        if data_set_config.name is None:
            file_name = os.path.basename(data_set_config.train_file_name)
            data_set_config.name = file_name.split(".")[0]
        data_set_config.survival_time = self._parse_survival_time(element)
        return data_set_config

    def _parse_data_sets(self, element) -> list[DataSetConfig]:
        data_set_configs = []
        node = element.find(TestConfigParser.DATASETS_KEY)
        for dataset_element in node.findall(TestConfigParser.DATASET_KEY):
            data_set_configs.append(self._parse_data_set(dataset_element))
        return data_set_configs

    def parse_test_parameters(self, element) -> dict[str, object]:
        params = {}
        for param_node in element.findall(TestConfigParser.PARAM_KEY):
            name: str = param_node.attrib["name"]
            if name in TestConfigParser.EXPERTS_RULES_PARAMETERS_NAMES:
                value = self._parse_experts_rules_parameters(
                    param_node.findall(TestConfigParser.ENTRY_KEY)
                )
            else:
                value = param_node.text
            params[name] = value
        return params

    def _parse_test_parameters_sets(self, element) -> dict[str, dict[str, object]]:
        parameters_sets = {}
        params_sets_node = element.findall(TestConfigParser.PARAMETERS_SET_KEY)[0]
        for param_set in params_sets_node.findall(TestConfigParser.PARAMETERS_KEY):
            name: str = param_set.attrib["name"]
            parameters_sets[name] = self.parse_test_parameters(param_set)
        return parameters_sets

    def _parse_test(self, element) -> TestConfig:
        test_config = TestConfig()
        test_config.parameter_configs = self._parse_test_parameters_sets(element)
        test_config.datasets = self._parse_data_sets(element)
        test_config.name = element.attrib["name"]
        return test_config

    def parse(self, file_path: str) -> dict[str, TestConfig]:
        self.tests_configs = {}
        self.root = ElementTree.parse(file_path).getroot()
        if self.root.tag == TestConfigParser.TEST_KEY:
            test_elements = [self.root]
        else:
            test_elements = self.root.findall(TestConfigParser.TEST_KEY)
        for test_element in test_elements:
            test_config = self._parse_test(test_element)
            self.tests_configs[test_config.name] = test_config
        return self.tests_configs
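    # Shape of the XML this parser accepts: a minimal sketch inferred from the
    # element and attribute keys above (all names and values below are hypothetical;
    # the real config files ship with the test resources under tests/resources/config):
    #
    #     <test name="some-test">
    #         <parameter_sets>
    #             <parameter_set name="mincov=5">
    #                 <param name="min_rule_covered">5</param>
    #             </parameter_set>
    #         </parameter_sets>
    #         <datasets>
    #             <dataset>
    #                 <label>class</label>
    #                 <training>
    #                     <train>
    #                         <in_file>data/some-train.arff</in_file>
    #                     </train>
    #                 </training>
    #             </dataset>
    #         </datasets>
    #     </test>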

class TestCaseFactory:

    def _make_test_case(
        self,
        test_case_name: str,
        params: dict[str, object],
        data_set_config: DataSetConfig,
        problem_type: ProblemType,
    ) -> TestCase:
        test_case = TestCase(problem_type)
        self._fix_params_typing(params)
        self._fix_deprecated_params(params)
        test_case.induction_params = params
        test_case.data_set_file_path = (
            f"{DATA_IN_DIRECTORY_PATH}/{data_set_config.train_file_name}"
        )
        test_case.label_attribute = data_set_config.label_attribute
        test_case.name = test_case_name
        test_case.param_config = params
        return test_case

    def _fix_deprecated_params(self, params: dict[str, object]):
        deprecated_param_name = "min_rule_covered"
        if deprecated_param_name in params:
            params["minsupp_new"] = params.pop(deprecated_param_name)

    def _fix_params_typing(self, params: dict):
        for key, value in params.items():
            if value == "false":
                params[key] = False
                continue
            if value == "true":
                params[key] = True
                continue
            if "measure" not in key:
                params[key] = int(float(value))
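    # Intended effect of _fix_params_typing, sketched on a hypothetical parameter set:
    # {"use_expert": "true", "min_rule_covered": "5.0", "induction_measure": "C2"}
    # becomes {"use_expert": True, "min_rule_covered": 5, "induction_measure": "C2"};
    # values of keys containing "measure" are left as strings.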

    def make(
        self,
        tests_configs: dict[str, TestConfig],
        report_dir_path: str,
        problem_type: ProblemType,
    ) -> list[TestCase]:
        test_cases = []
        for key in tests_configs.keys():
            test_config = tests_configs[key]
            for config_name in test_config.parameter_configs.keys():
                for data_set_config in test_config.datasets:
                    params = test_config.parameter_configs[config_name]
                    test_case_name = f"{key}.{config_name}.{data_set_config.name}"
                    test_config.parameter_configs[config_name].pop("use_expert", None)
                    expert_rules = test_config.parameter_configs[config_name].pop(
                        "expert_rules", None
                    )
                    preferred_conditions = test_config.parameter_configs[
                        config_name
                    ].pop("expert_preferred_conditions", None)
                    forbidden_conditions = test_config.parameter_configs[
                        config_name
                    ].pop("expert_forbidden_conditions", None)
                    test_case = self._make_test_case(
                        test_case_name,
                        test_config.parameter_configs[config_name],
                        data_set_config,
                        problem_type,
                    )
                    if "use_report" in params:
                        report_file_name = params["use_report"]
                        test_case.using_existing_report_file = True
                    else:
                        report_file_name = test_case_name
                    report_path = f"{report_dir_path}/{report_file_name}"
                    test_case.report_file_path = report_path
                    test_case.survival_time = data_set_config.survival_time
                    if (
                        expert_rules is not None
                        or preferred_conditions is not None
                        or forbidden_conditions is not None
                    ):
                        test_case.knowledge = Knowledge()
                        if expert_rules is not None:
                            test_case.knowledge.expert_rules = expert_rules
                        if forbidden_conditions is not None:
                            test_case.knowledge.expert_forbidden_conditions = (
                                forbidden_conditions
                            )
                        if preferred_conditions is not None:
                            test_case.knowledge.expert_preferred_conditions = (
                                preferred_conditions
                            )
                    test_cases.append(test_case)
        return test_cases
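    # Note (informational): each generated TestCase gets a dotted name of the form
    # "<test name>.<parameter set name>.<dataset name>", and its report file path is
    # derived from that name unless the parameter set provides a "use_report" entry.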

def get_rule_string(rule) -> str:
    return re.sub(r"(\[[^\]]*\]$)|(\([^\)]*\)$)", "", str(rule))
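# Illustrative example (the rule text is hypothetical): get_rule_string strips a
# trailing bracketed statistics block from a rule's string form, so for a rule printed
# as "IF age < 42 THEN class = {1} [p=12, n=0]" the "[p=12, n=0]" part is removed
# before rules are written to or compared against report files.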

class TestReportReader:

    def __init__(self, file_name: str):
        self.file_name = file_name
        self._file = open(file_name, encoding="utf-8", mode="r")

    def _read_rules(self, test_report: TestReport):
        rules = []
        for line in self._file:
            if len(line) == 0:
                break
            else:
                rules.append(line)
        test_report.rules = rules

    def read(self) -> TestReport:
        test_report = TestReport(self.file_name)
        for line in self._file:
            line = line.upper()
            line = re.sub(r"\t", "", line)
            line = line.replace("\n", "")
            if line == REPORTS_SECTIONS_HEADERS["RULES"]:
                self._read_rules(test_report)
            elif line == "":
                continue
            else:
                raise ValueError(
                    f"Invalid report file format for file: {self.file_name}"
                )
        return test_report

    def close(self):
        self._file.close()
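# Reference report layout, as a rough sketch consistent with TestReportReader above
# and TestReportWriter below (the rule lines are illustrative): optional blank lines,
# then a line containing the RULES header, then one tab-indented rule per line:
#
#     RULES
#         IF attribute < 42 THEN class = {1}
#         IF attribute >= 42 THEN class = {0}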

class TestReportWriter:

    def __init__(self, file_name: str):
        # Make sure the output directory exists before opening the report file.
        if not os.path.exists(REPORTS_OUT_DIRECTORY_PATH):
            os.makedirs(REPORTS_OUT_DIRECTORY_PATH)
        self._file = open(file_name, encoding="utf-8", mode="w")

    def write(self, rule_set):
        self._file.write("\n")
        self._file.write(f'{REPORTS_SECTIONS_HEADERS["RULES"]}\n')
        for rule in rule_set.rules:
            self._file.write(f"\t{get_rule_string(rule)}")

    def close(self):
        self._file.close()

def get_test_cases(class_name: str) -> list[TestCase]:
    if not os.path.exists(DATA_IN_DIRECTORY_PATH):
        raise RuntimeError(
            """\n
Test resources directory doesn't exist. Check if the 'tests/resources/' directory exists.

If you're running the tests for the first time, you need to download the resources folder from the RuleKit repository by running:
    python tests/resources.py download
            """
        )
    problem_type: ProblemType = _get_problem_type_from_test_case_class_name(class_name)
    configs = TestConfigParser().parse(f"{TEST_CONFIG_PATH}/{class_name}.xml")
    return TestCaseFactory().make(
        configs, f"{REPORTS_IN_DIRECTORY_PATH}/{class_name}/", problem_type
    )
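# Typical usage from a test module (the class name below is hypothetical; it must match
# an XML file under tests/resources/config and contain "classification", "regression"
# or "survival" so the problem type can be resolved):
#
#     test_cases = get_test_cases("MyClassificationTest")
#     for test_case in test_cases:
#         params = test_case.induction_params
#         expected_rules = test_case.reference_report.rules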

def _get_problem_type_from_test_case_class_name(class_name: str) -> ProblemType:
    class_name = class_name.lower()
    if "regression" in class_name:
        return ProblemType.REGRESSION
    elif "survival" in class_name:
        return ProblemType.SURVIVAL
    elif "classification" in class_name:
        return ProblemType.CLASSIFICATION
    raise Exception(f"Unknown problem type for test case class name: {class_name}")

def assert_rules_are_equals(expected: list[str], actual: list[str]):
    def sanitize_rule_string(rule_string: str) -> str:
        return re.sub(r"(\t)|(\n)|(\[[^\]]*\]$)", "", rule_string)

    expected = list(map(sanitize_rule_string, expected))
    actual = list(map(sanitize_rule_string, actual))

    if len(expected) != len(actual):
        raise AssertionError(
            "Rulesets have different numbers of rules, actual: "
            f"{len(actual)}, expected: {len(expected)}"
        )
    dictionary = {}
    for rule in expected:
        dictionary[rule] = 0
    for rule in actual:
        key = rule
        if key in dictionary:
            dictionary[key] = dictionary[key] + 1
        else:
            raise AssertionError(
                "Actual ruleset contains rules not present in expected ruleset"
            )
    for value in dictionary.values():
        if value == 0:
            raise AssertionError("Rulesets are not equal, some rules are missing")
        elif value > 1:
            raise AssertionError("Some rules were duplicated")

def assert_accuracy_is_greater(prediction, expected, threshold: float):
    labels = expected.to_numpy().astype(str)
    acc = metrics.accuracy_score(labels, prediction)
    if acc <= threshold:
        raise AssertionError(f"Accuracy should be greater than {threshold} (was {acc})")

def assert_score_is_greater(prediction, expected, threshold: float):
    if isinstance(prediction[0], int):
        labels = expected.to_numpy().astype(int)
    elif isinstance(prediction[0], float):
        labels = expected.to_numpy().astype(float)
    else:
        raise ValueError(
            f"Invalid prediction type: {str(type(prediction[0]))}. "
            + "Supported types are: 1 dimensional numpy array or pandas Series object."
        )
    explained_variance_score = metrics.explained_variance_score(labels, prediction)

    if explained_variance_score <= threshold:
        raise AssertionError(
            f"Score should be greater than {threshold} (was {explained_variance_score})"
        )