Coverage for rulekit/_helpers.py: 85%

260 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1"""Contains helper functions and classes 

2""" 

3import io 

4import json 

5from typing import Any 

6from typing import Callable 

7from typing import Optional 

8from typing import Union 

9 

10import numpy as np 

11import pandas as pd 

12from jpype import JArray 

13from jpype import java 

14from jpype import JClass 

15from jpype import JObject 

16from jpype.pickle import JPickler 

17from jpype.pickle import JUnpickler 

18 

19from rulekit._logging import _RuleKitJavaLoggerConfig 

20from rulekit._problem_types import ProblemType 

21from rulekit.exceptions import RuleKitMisconfigurationException 

22from rulekit.main import RuleKit 

23from rulekit.params import Measures 

24from rulekit.rules import BaseRule 

25 

26 

def get_rule_generator(expert: bool = False) -> Any:
    """Build a new instance of the Java ``RuleGenerator`` class.

    Args:
        expert (bool, optional): Whether expert induction is enabled. The
            flag is passed straight to the Java constructor.
            Defaults to False.

    Returns:
        Any: RuleGenerator instance
    """
    generator_class = JClass(
        "adaa.analytics.rules.logic.rulegenerator.RuleGenerator"
    )
    return generator_class(expert)

42 

43 

class RuleGeneratorConfigurator:
    """Class for configuring rule induction parameters.

    Wraps a Java ``RuleGenerator`` instance, pushes the Python-side parameter
    values into it and then checks that the Java side reports the same values.
    """

    # names of the parameters that hold rule quality measures
    _MEASURES_PARAMETERS: list[str] = [
        "induction_measure",
        "pruning_measure",
        "voting_measure",
    ]
    # value set on the Java side when a measure is a Python callable
    _USER_DEFINED_MEASURE_VALUE: str = "UserDefined"

    def __init__(self, rule_generator):
        """
        Args:
            rule_generator: Java RuleGenerator instance to be configured
        """
        self.rule_generator = rule_generator
        # replaced with the Java LogRank class when any measure parameter is
        # Measures.LogRank (see _configure_rule_generator)
        self.LogRank = None  # pylint: disable=invalid-name

    def configure(self, **params: Any) -> Any:
        """Configures RuleGenerator instance with given induction parameters

        Args:
            **params: induction parameters, keyed by parameter name

        Returns:
            Any: configured RuleGenerator instance

        Raises:
            ValueError: If failed to retrieve RuleGenerator parameters JSON
            RuleKitMisconfigurationException: If Java and Python parameters
                do not match
        """
        self._configure_rule_generator(**params)
        self._validate_rule_generator_parameters(**params)
        self._configure_java_logging()
        return self.rule_generator

    @staticmethod
    def _expert_rule_as_pair(
        param_name: str, index: int, rule: Any
    ) -> Optional[tuple[Any, Any]]:
        """Converts a single expert rule to a ``(name, rule)`` pair.

        Strings and ``BaseRule`` objects get a generated name built from the
        parameter name without its last character and the rule index; tuples
        are expected to already be ``(name, rule)``. Unsupported values yield
        None.
        """
        if isinstance(rule, str):
            return f"{param_name[:-1]}-{index}", rule
        if isinstance(rule, BaseRule):
            return f"{param_name[:-1]}-{index}", str(rule)
        if isinstance(rule, tuple):
            return rule[0], rule[1]
        return None

    def _configure_expert_parameter(self, param_name: str, param_value: Any):
        """Passes a list of expert rules to the RuleGenerator.

        Each element is converted on its own, so the type of the first element
        no longer decides how the whole list is handled. Elements of
        unsupported types are skipped. A value that is not a list results in
        an empty Java list being set; None leaves the parameter untouched.

        Args:
            param_name (str): name of the list parameter
            param_value (Any): list of rules given as strings, ``BaseRule``
                objects or ``(name, rule)`` tuples
        """
        if param_value is None:
            return
        rules_list = java.util.ArrayList()
        if isinstance(param_value, list):
            for index, rule in enumerate(param_value):
                pair = self._expert_rule_as_pair(param_name, index, rule)
                if pair is None:
                    continue
                rules_list.add(
                    JObject(list(pair), JArray("java.lang.String", 1))
                )
        self.rule_generator.setListParameter(param_name, rules_list)

    def _configure_simple_parameter(self, param_name: str, param_value: Any):
        """Sets a single parameter on the RuleGenerator as a string.

        Booleans become lowercase ``"true"``/``"false"``, tuples are joined
        with single spaces, any other non-string value goes through ``str``.
        None values are ignored.
        """
        if param_value is None:
            return
        # bool must be handled before the generic str() conversion so that it
        # is lowercased
        if isinstance(param_value, bool):
            param_value = str(param_value).lower()
        elif isinstance(param_value, tuple):
            param_value = " ".join(map(str, param_value))
        elif not isinstance(param_value, str):
            param_value = str(param_value)
        self.rule_generator.setParameter(param_name, param_value)

    def _configure_measure_parameter(
        self, param_name: str, param_value: Union[str, Measures]
    ):
        """Sets a quality measure parameter on the RuleGenerator.

        Args:
            param_name (str): one of the measure parameter names
            param_value (Union[str, Measures]): ``Measures`` member, measure
                name given as a plain string, or a callable implementing a
                user defined measure. None is ignored.
        """
        if param_value is None:
            return
        if isinstance(param_value, Measures):
            self.rule_generator.setParameter(param_name, param_value.value)
        elif isinstance(param_value, str):
            # plain strings are advertised by the signature; pass them through
            # instead of silently dropping them
            self.rule_generator.setParameter(param_name, param_value)
        elif callable(param_value):
            self._configure_user_defined_measure_parameter(param_name, param_value)

    def _configure_user_defined_measure_parameter(
        self, param_name: str, param_value: Any
    ):
        """Registers a Python callable as a user defined measure.

        The callable is wrapped by ``_user_defined_measure_factory`` and handed
        to the setter matching ``param_name``; the parameter itself is then set
        to the user-defined marker value.

        Raises:
            KeyError: If ``param_name`` is not a measure parameter name
        """
        from rulekit.params import _user_defined_measure_factory

        user_defined_measure = _user_defined_measure_factory(param_value)
        # NOTE: "Purning" is the spelling of the Java method name
        {
            "induction_measure": self.rule_generator.setUserMeasureInductionObject,
            "pruning_measure": self.rule_generator.setUserMeasurePurningObject,
            "voting_measure": self.rule_generator.setUserMeasureVotingObject,
        }[param_name](user_defined_measure)
        self.rule_generator.setParameter(param_name, self._USER_DEFINED_MEASURE_VALUE)

    def _configure_rule_generator(self, **kwargs: Any):
        """Pushes all given parameters into the RuleGenerator.

        Measure parameters are handled first and removed from ``kwargs``; all
        remaining entries are set as simple parameters.
        """
        if any(
            kwargs.get(param_name) == Measures.LogRank
            for param_name in self._MEASURES_PARAMETERS
        ):
            self.LogRank = JClass("adaa.analytics.rules.logic.quality.LogRank")
        for measure_param_name in self._MEASURES_PARAMETERS:
            measure_param_value = kwargs.pop(measure_param_name, None)
            self._configure_measure_parameter(measure_param_name, measure_param_value)
        for param_name, param_value in kwargs.items():
            self._configure_simple_parameter(param_name, param_value)

    @staticmethod
    def _normalize_for_comparison(value: Any) -> Any:
        """Converts a Python parameter value to the form compared with Java.

        Measures become their ``value``, booleans lowercase strings, tuples
        space separated strings and numbers plain strings. Anything else is
        returned unchanged.
        """
        if isinstance(value, Measures):
            return value.value
        # bool is a subclass of int, so it has to be checked before numbers
        if isinstance(value, bool):
            return str(value).lower()
        if isinstance(value, tuple):
            return " ".join(map(str, value))
        if isinstance(value, (int, float)):
            return str(value)
        return value

    def _validate_rule_generator_parameters(self, **python_parameters: Any):
        """Validate whether operator parameters configuration is properly passed to the
        java RuleGenerator class. Otherwise, it raises an error.

        Args:
            python_parameters (dict[str, Any]): parameters values configured for the
                operator in Python

        Raises:
            ValueError: If failed to retrieve RuleGenerator parameters JSON
            RuleKitMisconfigurationException: If Java and Python parameters do not match
        """

        def are_params_equal(
            java_params: dict[str, Any], python_params: dict[str, Any]
        ) -> bool:
            if java_params.keys() != python_params.keys():
                return False
            for key, java_value in java_params.items():
                python_value = python_params.get(key)
                # Java values were passed through str(), so a missing value
                # shows up as the literal string 'None'
                if java_value == 'None':
                    java_value = None
                if java_value is None and python_value is None:
                    continue
                # user defined measures (callables) cannot be compared with
                # the value reported by Java
                if callable(python_value):
                    continue
                if java_value != python_value:
                    return False
            return True

        expected: dict[str, Any] = {
            param_name: self._normalize_for_comparison(param_value)
            for param_name, param_value in python_parameters.items()
        }
        java_params_json: str = str(self.rule_generator.getParamsAsJsonString())
        try:
            all_java_params: dict[str, Any] = json.loads(java_params_json)
        except json.JSONDecodeError as error:
            raise ValueError(
                "Failed to decode RuleGenerator parameters JSON"
            ) from error
        # select only values that are used by Python wrapper
        java_params = {
            param_name: str(all_java_params.get(param_name))
            for param_name in expected
        }
        if not are_params_equal(java_params, expected):
            raise RuleKitMisconfigurationException(
                java_parameters=java_params, python_parameters=expected
            )

    def _configure_java_logging(self):
        """Applies the Java logger configuration stored in ``RuleKit``, if any."""
        logger_config: Optional[_RuleKitJavaLoggerConfig] = (
            RuleKit.get_java_logger_config()
        )
        if logger_config is None:
            return
        self.rule_generator.configureLogger(
            logger_config.log_file_path, logger_config.verbosity_level
        )

210 

211 

class ExampleSetFactory:
    """Creates ExampleSet object from given data"""

    # label name used when y carries no name of its own
    DEFAULT_LABEL_ATTRIBUTE_NAME: str = "label"
    # prefix of generated attribute names ("att1", "att2", ...) for numpy input
    AUTOMATIC_ATTRIBUTES_NAMES_PREFIX: str = "att"

    def __init__(self, problem_type: ProblemType) -> None:
        """
        Args:
            problem_type (ProblemType): problem type; its ``value`` is passed
                to the Java ExampleSetFactory when wrapping training data
        """
        self._problem_type: ProblemType = problem_type
        self._attributes_names: Optional[list[str]] = None
        self._label_name: Optional[str] = None
        self._survival_time_attribute: Optional[str] = None
        self._contrast_attribute: Optional[str] = None
        self._X: Optional[np.ndarray] = None
        self._y: Optional[np.ndarray] = None

    def make(
        self,
        X: Union[pd.DataFrame, np.ndarray],
        y: Optional[Union[pd.Series, np.ndarray, list]] = None,
        survival_time_attribute: Optional[str] = None,
        contrast_attribute: Optional[str] = None,
    ) -> JObject:
        """Creates ExampleSet object from given data

        Args:
            X (Union[pd.DataFrame, np.ndarray]): Data
            y (Union[pd.Series, np.ndarray, list], optional): Labels.
                Defaults to None.
            survival_time_attribute (str, optional): Name of survival time
                attribute. Defaults to None.
            contrast_attribute (str, optional): Name of contrast attribute.
                Defaults to None.

        Returns:
            JObject: ExampleSet object when labels are given, otherwise the
                Java DataTable built from X

        Raises:
            ValueError: If X or y has an unsupported type or dimensionality
        """
        self._attributes_names = []
        self._survival_time_attribute = survival_time_attribute
        self._contrast_attribute = contrast_attribute
        # X first: it resets the attribute names, y then appends the label
        self._sanitize_X(X)
        self._sanitize_y(y)
        self._validate_X()
        self._validate_y()
        return self._create_example_set()

    def _sanitize_y(self, y: Optional[Union[pd.Series, np.ndarray, list]]):
        """Stores labels as a numpy array and registers the label name.

        Label state left by a previous call is always cleared first, so that
        calling ``make`` without labels on a reused factory does not pick up
        stale labels.

        Raises:
            ValueError: If y is not a Series, list, numpy array or None
        """
        self._y = None
        self._label_name = None
        if y is None:
            return
        if isinstance(y, pd.Series):
            # an unnamed Series gets the same default label name as raw arrays
            label_name = (
                y.name if y.name is not None else self.DEFAULT_LABEL_ATTRIBUTE_NAME
            )
            labels = y.to_numpy()
        elif isinstance(y, list):
            label_name = self.DEFAULT_LABEL_ATTRIBUTE_NAME
            labels = np.array(y)
        elif isinstance(y, np.ndarray):
            label_name = self.DEFAULT_LABEL_ATTRIBUTE_NAME
            labels = y
        else:
            raise ValueError(
                f"Invalid y type: {str(type(y))}. "
                "Supported types are: 1 dimensional numpy array or pandas "
                "Series object."
            )
        self._label_name = label_name
        self._attributes_names.append(label_name)
        self._y = labels

    def _sanitize_X(
        self,
        X: Union[pd.DataFrame, np.ndarray],
    ) -> None:
        """Stores data as a numpy array and derives attribute names.

        DataFrame columns are used as attribute names; numpy arrays get
        generated names with a 1-based index.

        Raises:
            ValueError: If X has an unsupported type, or is a numpy array that
                is not 2 dimensional
        """
        if isinstance(X, pd.DataFrame):
            self._attributes_names = X.columns.tolist()
            # replace nan values with None
            X = X.where(pd.notnull(X), None)
            self._X = X.to_numpy()
        elif isinstance(X, np.ndarray):
            self._X = X
            # check the shape before reading shape[1]; a 1 dimensional array
            # would otherwise fail with an IndexError
            self._validate_X()
            self._attributes_names = [
                f"{self.AUTOMATIC_ATTRIBUTES_NAMES_PREFIX}{index + 1}"
                for index in range(X.shape[1])
            ]
        else:
            raise ValueError(
                f"Invalid X type: {str(type(X))}. "
                "Supported types are: 2 dimensional numpy array or pandas DataFrame "
                "object."
            )

    def _validate_X(self):
        """Raises ValueError unless the stored data is 2 dimensional."""
        if len(self._X.shape) != 2:
            raise ValueError(
                "X must be a 2 dimensional numpy array or pandas DataFrame object. "
                + f"Its current shape is: {str(self._X.shape)}"
            )

    def _validate_y(self):
        """Raises ValueError unless the stored labels are absent or 1 dimensional."""
        if self._y is not None and len(self._y.shape) != 1:
            raise ValueError(
                "y must be a 1 dimensional numpy array or pandas Series object. "
                + f"Its current shape is: {str(self._y.shape)}"
            )

    def _create_example_set(self) -> JObject:
        """Builds the Java DataTable and, when labels exist, the ExampleSet.

        Java stack trace is printed before any exception is re-raised.
        """
        data: JObject = self._prepare_data()
        DataTable = JClass(  # pylint: disable=invalid-name
            "adaa.analytics.rules.data.DataTable"
        )
        try:
            table = DataTable(
                data,
                self._attributes_names,
                self._label_name,
                self._survival_time_attribute,
                self._contrast_attribute,
            )
            if self._y is None:
                return table
            JavaExampleSetFactory = JClass(  # pylint: disable=invalid-name
                'adaa.analytics.rules.logic.representation.'
                'exampleset.ExampleSetFactory'
            )
            # NOTE(review): the literal 2 is used here while
            # _wrap_training_example_set passes self._problem_type.value -
            # confirm this is intentional
            factory = JavaExampleSetFactory(2)
            return factory.create(table)
        except Exception:
            import sys

            from rulekit.exceptions import RuleKitJavaException

            RuleKitJavaException(sys.exc_info()[1]).print_java_stack_trace()
            # bare raise keeps the original traceback intact
            raise

    def _wrap_training_example_set(self, example_set: JObject) -> JObject:
        """Wraps a training dataset using the Java ExampleSetFactory."""
        # training dataset must be wrapped in additional classes for rule induction
        # to work properly
        JavaExampleSetFactory = JClass(  # pylint: disable=invalid-name
            "adaa.analytics.rules.logic.representation.exampleset.ExampleSetFactory"
        )
        factory: JObject = JavaExampleSetFactory(self._problem_type.value)
        return factory.create(example_set)

    def _prepare_data(self) -> JObject:
        """Converts stored data to a 2 dimensional Java Object array.

        When labels are present they are appended as the last column.
        """
        if self._y is None:
            data = self._X
        else:
            # cast to object so that labels of another type can be stacked
            data = np.hstack((self._X.astype(object), self._y.reshape(-1, 1)))
        return JObject(data, JArray("java.lang.Object", 2))

360 

361 

class PredictionResultMapper:
    """Maps prediction results to numpy array"""

    PREDICTION_COLUMN_ROLE: str = "prediction"
    CONFIDENCE_COLUMN_ROLE: str = "confidence"

    @staticmethod
    def map_confidence(predicted_example_set, label_unique_values: list) -> np.ndarray:
        """Maps models confidence values to numpy array

        Args:
            predicted_example_set (JObject): predicted ExampleSet instance
            label_unique_values (list): unique labels values; one confidence
                column with role ``confidence_<label>`` is read for each

        Returns:
            np.ndarray: numpy array with mapped confidence values, one row per
                example and one column per label value
        """
        examples_count = predicted_example_set.size()
        confidence_matrix: list[list[float]] = []
        for label_value in label_unique_values:
            confidence_col: JObject = PredictionResultMapper._get_column_by_role(
                predicted_example_set,
                f"{PredictionResultMapper.CONFIDENCE_COLUMN_ROLE}_{label_value}",
            )
            confidence_matrix.append(
                [
                    float(predicted_example_set.getExample(i).getValue(confidence_col))
                    for i in range(examples_count)
                ]
            )
        # rows were collected per label, transpose to get one row per example
        return np.array(confidence_matrix, dtype=float).T

    @staticmethod
    def map(predicted_example_set: JObject) -> np.ndarray:
        """Maps models predictions to numpy array

        Nominal prediction columns are mapped with ``map_to_nominal``, all
        others with ``map_to_numerical``.

        Args:
            predicted_example_set (JObject): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )
        if prediction_col.isNominal():
            return PredictionResultMapper.map_to_nominal(predicted_example_set)
        return PredictionResultMapper.map_to_numerical(predicted_example_set)

    @staticmethod
    def map_to_nominal(predicted_example_set: JObject) -> np.ndarray:
        """Maps models predictions to nominal numpy array of strings

        Args:
            predicted_example_set (JObject): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )
        return np.array(
            [
                str(predicted_example_set.getExample(i).getNominalValue(prediction_col))
                for i in range(predicted_example_set.size())
            ],
            dtype=str,
        )

    @staticmethod
    def map_to_numerical(
        predicted_example_set: JObject, remap: bool = True
    ) -> np.ndarray:
        """Maps models predictions to numerical numpy array

        Args:
            predicted_example_set (JObject): ExampleSet with predictions
            remap (bool, optional): When True, each predicted value is treated
                as an index into the label mapping and replaced by the mapped
                label converted to float. When False, raw predicted values are
                returned. Defaults to True.

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )
        raw_values: list = [
            predicted_example_set.getExample(i).getValue(prediction_col)
            for i in range(predicted_example_set.size())
        ]
        if not remap:
            return np.array([float(value) for value in raw_values])
        # the label mapping is only needed (and only looked up) when remapping,
        # so example sets without a label can still be mapped with remap=False
        label_mapping = predicted_example_set.getAttributes().getLabel().getMapping()
        return np.array(
            [float(str(label_mapping.mapIndex(int(value)))) for value in raw_values]
        )

    @staticmethod
    def map_survival(predicted_example_set) -> np.ndarray:
        """Maps survival models predictions to numpy array. Used as alternative to
        `map` method used in survival analysis

        The ``estimator`` attribute of every example is read as a space
        separated string. The part of the first token before ``:`` is dropped;
        the remaining tokens alternate between times and probabilities.

        Args:
            predicted_example_set (JObject): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array of dicts with ``times`` and
                ``probabilities`` lists, one per example
        """
        estimators = []
        attribute = predicted_example_set.getAttributes().get("estimator")
        example_set_iterator = predicted_example_set.iterator()
        while example_set_iterator.hasNext():
            example = example_set_iterator.next()
            tokens = str(example.getValueAsString(attribute)).split(" ")
            _, tokens[0] = tokens[0].split(":")
            estimators.append(
                {
                    # even positions, never the very last token
                    "times": [float(token) for token in tokens[:-1:2]],
                    # odd positions
                    "probabilities": [float(token) for token in tokens[1::2]],
                }
            )
        return np.array(estimators)

    @staticmethod
    def _get_column_by_role(predicted_example_set: JObject, role: str) -> JObject:
        """Returns the column of the example set registered under given role."""
        return predicted_example_set.getAttributes().getColumnByRole(role)

501 

502 

class ModelSerializer:
    """Class for serializing models"""

    @staticmethod
    def serialize(real_model: object) -> bytes:
        """Pickle a Java ruleset object into bytes using ``JPickler``.

        Args:
            real_model (object): Java ruleset object

        Returns:
            bytes: pickled representation of the model
        """
        with io.BytesIO() as buffer:
            JPickler(buffer).dump(real_model)
            return buffer.getvalue()

    @staticmethod
    def deserialize(serialized_bytes: bytes) -> object:
        """Restore a Java ruleset object from bytes using ``JUnpickler``.

        ``RuleKit.init()`` is called first when RuleKit is not initialized yet.
        As with any pickle format, only bytes from a trusted source should be
        passed in.

        Args:
            serialized_bytes (bytes): serialized bytes

        Returns:
            object: deserialized Java ruleset object
        """
        if not RuleKit.initialized:
            RuleKit.init()
        with io.BytesIO(serialized_bytes) as buffer:
            return JUnpickler(buffer).load()