Coverage for rulekit/_helpers.py: 85%
260 statements
coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1"""Contains helper functions and classes
2"""
3import io
4import json
5from typing import Any
6from typing import Callable
7from typing import Optional
8from typing import Union
10import numpy as np
11import pandas as pd
12from jpype import JArray
13from jpype import java
14from jpype import JClass
15from jpype import JObject
16from jpype.pickle import JPickler
17from jpype.pickle import JUnpickler
19from rulekit._logging import _RuleKitJavaLoggerConfig
20from rulekit._problem_types import ProblemType
21from rulekit.exceptions import RuleKitMisconfigurationException
22from rulekit.main import RuleKit
23from rulekit.params import Measures
24from rulekit.rules import BaseRule


def get_rule_generator(expert: bool = False) -> Any:
    """Factory for the Java RuleGenerator class object.

    Args:
        expert (bool, optional): Whether expert induction is enabled.
            Defaults to False.

    Returns:
        Any: RuleGenerator instance
    """
    RuleGenerator = JClass(  # pylint: disable=invalid-name
        "adaa.analytics.rules.logic.rulegenerator.RuleGenerator"
    )
    rule_generator = RuleGenerator(expert)
    return rule_generator
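

# Hedged usage sketch: get_rule_generator() needs a running JVM, so
# RuleKit.init() must have been called first (assumption based on how the rest
# of this module uses RuleKit; not an authoritative recipe):
#
#   RuleKit.init()
#   generator = get_rule_generator()                      # standard induction
#   expert_generator = get_rule_generator(expert=True)    # expert induction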


class RuleGeneratorConfigurator:
    """Class for configuring rule induction parameters"""

    _MEASURES_PARAMETERS: list[str] = [
        "induction_measure",
        "pruning_measure",
        "voting_measure",
    ]
    _USER_DEFINED_MEASURE_VALUE: str = "UserDefined"

    def __init__(self, rule_generator):
        self.rule_generator = rule_generator
        self.LogRank = None  # pylint: disable=invalid-name

    def configure(self, **params: dict[str, Any]) -> Any:
        """Configures RuleGenerator instance with given induction parameters

        Returns:
            Any: configured RuleGenerator instance
        """
        self._configure_rule_generator(**params)
        self._validate_rule_generator_parameters(**params)
        self._configure_java_logging()
        return self.rule_generator

    def _configure_expert_parameter(self, param_name: str, param_value: Any):
        if param_value is None:
            return
        rules_list = java.util.ArrayList()
        if isinstance(param_value, list) and len(param_value) > 0:
            if isinstance(param_value[0], str):
                for index, rule in enumerate(param_value):
                    rule_name = f"{param_name[:-1]}-{index}"
                    rules_list.add(
                        JObject([rule_name, rule], JArray("java.lang.String", 1))
                    )
            elif isinstance(param_value[0], BaseRule):
                for index, rule in enumerate(param_value):
                    rule_name = f"{param_name[:-1]}-{index}"
                    rules_list.add(
                        JObject([rule_name, str(rule)], JArray("java.lang.String", 1))
                    )
            elif isinstance(param_value[0], tuple):
                for index, rule in enumerate(param_value):
                    rules_list.add(
                        JObject([rule[0], rule[1]], JArray("java.lang.String", 1))
                    )
        self.rule_generator.setListParameter(param_name, rules_list)

    def _configure_simple_parameter(self, param_name: str, param_value: Any):
        if param_value is not None:
            if isinstance(param_value, bool):
                param_value = str(param_value).lower()
            elif isinstance(param_value, tuple):
                param_value = " ".join(map(str, param_value))
            elif not isinstance(param_value, str):
                param_value = str(param_value)
            self.rule_generator.setParameter(param_name, param_value)

    def _configure_measure_parameter(
        self, param_name: str, param_value: Union[str, Measures]
    ):
        if param_value is not None:
            if isinstance(param_value, Measures):
                self.rule_generator.setParameter(param_name, param_value.value)
            if isinstance(param_value, Callable):
                self._configure_user_defined_measure_parameter(param_name, param_value)

    def _configure_user_defined_measure_parameter(
        self, param_name: str, param_value: Any
    ):
        from rulekit.params import _user_defined_measure_factory

        user_defined_measure = _user_defined_measure_factory(param_value)
        {
            "induction_measure": self.rule_generator.setUserMeasureInductionObject,
            "pruning_measure": self.rule_generator.setUserMeasurePurningObject,
            "voting_measure": self.rule_generator.setUserMeasureVotingObject,
        }[param_name](user_defined_measure)
        self.rule_generator.setParameter(param_name, self._USER_DEFINED_MEASURE_VALUE)
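
    # Hedged sketch of a user-defined measure. It assumes the callable receives
    # the confusion-matrix counts (p, n, P, N) that _user_defined_measure_factory
    # is expected to wrap; check rulekit.params before relying on this signature:
    #
    #   def my_measure(p, n, P, N):
    #       return (p + 1) / (p + n + 2)  # Laplace-style estimate
    #
    #   configurator.configure(induction_measure=my_measure)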

    def _configure_rule_generator(self, **kwargs: dict[str, Any]):
        if any(
            kwargs.get(param_name) == Measures.LogRank
            for param_name in self._MEASURES_PARAMETERS
        ):
            self.LogRank = JClass("adaa.analytics.rules.logic.quality.LogRank")
        for measure_param_name in self._MEASURES_PARAMETERS:
            measure_param_value: Measures = kwargs.pop(measure_param_name, None)
            self._configure_measure_parameter(measure_param_name, measure_param_value)
        for param_name, param_value in kwargs.items():
            self._configure_simple_parameter(param_name, param_value)

    def _validate_rule_generator_parameters(self, **python_parameters: dict[str, Any]):
        """Validates that the operator parameters configured in Python were correctly
        passed to the Java RuleGenerator class; raises an error otherwise.

        Args:
            python_parameters (dict[str, Any]): parameter values configured for the
                operator in Python

        Raises:
            ValueError: If the RuleGenerator parameters JSON could not be decoded
            RuleKitMisconfigurationException: If Java and Python parameters do not match
        """

        def are_params_equal(
            java_params: dict[str, Any], python_params: dict[str, Any]
        ):
            if java_params.keys() != python_params.keys():
                return False
            for key in java_params.keys():
                java_value = java_params.get(key)
                python_value = python_params.get(key)
                skip_check: bool = isinstance(python_value, Callable)
                if java_value == "None":
                    java_value = None
                if java_value is None and python_value is None:
                    continue
                if java_value != python_value and not skip_check:
                    return False
            return True

        python_parameters = dict(python_parameters)
        for param_name, param_value in python_parameters.items():
            # convert measures to string values for comparison
            if isinstance(param_value, Measures):
                python_parameters[param_name] = param_value.value
            # convert booleans to lowercase strings for comparison
            elif isinstance(param_value, bool):
                python_parameters[param_name] = str(param_value).lower()
            # normalize tuples to strings for comparison
            elif isinstance(param_value, tuple):
                value = " ".join(map(str, param_value))
                python_parameters[param_name] = value
            # convert numbers to strings for comparison
            elif isinstance(param_value, (int, float)):
                python_parameters[param_name] = str(param_value)
        java_params_json: str = str(self.rule_generator.getParamsAsJsonString())
        try:
            java_params: dict[str, Any] = json.loads(java_params_json)
        except json.JSONDecodeError as error:
            raise ValueError(
                "Failed to decode RuleGenerator parameters JSON"
            ) from error
        # select only the values that are used by the Python wrapper
        java_params = {
            param_name: str(java_params.get(param_name))
            for param_name in python_parameters.keys()
        }
        if not are_params_equal(java_params, python_parameters):
            raise RuleKitMisconfigurationException(
                java_parameters=java_params, python_parameters=python_parameters
            )

    def _configure_java_logging(self):
        logger_config: Optional[_RuleKitJavaLoggerConfig] = (
            RuleKit.get_java_logger_config()
        )
        if logger_config is None:
            return
        self.rule_generator.configureLogger(
            logger_config.log_file_path, logger_config.verbosity_level
        )
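

# Hedged configuration sketch (Measures.C2 is used purely as an illustrative
# value; any parameter accepted by the Java RuleGenerator could be passed):
#
#   configurator = RuleGeneratorConfigurator(get_rule_generator())
#   generator = configurator.configure(induction_measure=Measures.C2)
#
# configure() re-reads the parameters from the Java side and raises
# RuleKitMisconfigurationException if they do not match the Python values.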


class ExampleSetFactory:
    """Creates ExampleSet object from given data"""

    DEFAULT_LABEL_ATTRIBUTE_NAME: str = "label"
    AUTOMATIC_ATTRIBUTES_NAMES_PREFIX: str = "att"

    def __init__(self, problem_type: ProblemType) -> None:
        self._problem_type: ProblemType = problem_type
        self._attributes_names: list[str] = None
        self._label_name: str = None
        self._survival_time_attribute: str = None
        self._contrast_attribute: str = None
        self._X: np.ndarray = None
        self._y: np.ndarray = None

    def make(
        self,
        X: Union[pd.DataFrame, np.ndarray],
        y: Union[pd.Series, np.ndarray] = None,
        survival_time_attribute: str = None,
        contrast_attribute: str = None,
    ) -> JObject:
        """Creates ExampleSet object from given data

        Args:
            X (Union[pd.DataFrame, np.ndarray]): Data
            y (Union[pd.Series, np.ndarray], optional): Labels. Defaults to None.
            survival_time_attribute (str, optional): Name of survival time
                attribute. Defaults to None.
            contrast_attribute (str, optional): Name of contrast attribute.
                Defaults to None.

        Returns:
            JObject: ExampleSet object
        """
        self._attributes_names = []
        self._survival_time_attribute = survival_time_attribute
        self._contrast_attribute = contrast_attribute
        self._sanitize_X(X)
        self._sanitize_y(y)
        self._validate_X()
        self._validate_y()
        return self._create_example_set()

    def _sanitize_y(self, y: Union[pd.Series, np.ndarray, list]):
        if y is None:
            return
        elif isinstance(y, pd.Series):
            self._label_name = y.name
            self._attributes_names.append(self._label_name)
            self._y = y.to_numpy()
        elif isinstance(y, list):
            self._label_name = self.DEFAULT_LABEL_ATTRIBUTE_NAME
            self._attributes_names.append(self._label_name)
            self._y = np.array(y)
        elif isinstance(y, np.ndarray):
            self._label_name = self.DEFAULT_LABEL_ATTRIBUTE_NAME
            self._attributes_names.append(self._label_name)
            self._y = y
        else:
            raise ValueError(
                f"Invalid y type: {str(type(y))}. "
                "Supported types are: 1 dimensional numpy array, pandas Series "
                "object or list."
            )

    def _sanitize_X(
        self,
        X: Union[pd.DataFrame, np.ndarray],
    ) -> None:
        if isinstance(X, pd.DataFrame):
            self._attributes_names = X.columns.tolist()
            # replace NaN values with None
            X = X.where(pd.notnull(X), None)
            self._X = X.to_numpy()
        elif isinstance(X, np.ndarray):
            self._attributes_names = [
                f"{self.AUTOMATIC_ATTRIBUTES_NAMES_PREFIX}{index + 1}"
                for index in range(X.shape[1])
            ]
            self._X = X
        else:
            raise ValueError(
                f"Invalid X type: {str(type(X))}. "
                "Supported types are: 2 dimensional numpy array or pandas DataFrame "
                "object."
            )

    def _validate_X(self):
        if len(self._X.shape) != 2:
            raise ValueError(
                "X must be a 2 dimensional numpy array or pandas DataFrame object. "
                + f"Its current shape is: {str(self._X.shape)}"
            )

    def _validate_y(self):
        if self._y is not None and len(self._y.shape) != 1:
            raise ValueError(
                "y must be a 1 dimensional numpy array or pandas Series object. "
                + f"Its current shape is: {str(self._y.shape)}"
            )

    def _create_example_set(self) -> JObject:
        data: JObject = self._prepare_data()
        args: list = [
            data,
            self._attributes_names,
            self._label_name,
            self._survival_time_attribute,
            self._contrast_attribute,
        ]
        DataTable = JClass(  # pylint: disable=invalid-name
            "adaa.analytics.rules.data.DataTable"
        )
        try:
            table = DataTable(*args)
            if self._y is not None:
                ExampleSetFactory = JClass(
                    "adaa.analytics.rules.logic.representation."
                    "exampleset.ExampleSetFactory"
                )
                factory = ExampleSetFactory(2)
                example_set = factory.create(table)
                return example_set
            else:
                return table
        except Exception as error:
            from rulekit.exceptions import RuleKitJavaException

            RuleKitJavaException(error).print_java_stack_trace()
            raise error

    def _wrap_training_example_set(self, example_set: JObject) -> JObject:
        # the training dataset must be wrapped in additional classes for rule
        # induction to work properly
        ExampleSetFactory = JClass(
            "adaa.analytics.rules.logic.representation.exampleset.ExampleSetFactory"
        )
        factory: JObject = ExampleSetFactory(self._problem_type.value)
        return factory.create(example_set)

    def _prepare_data(self) -> JObject:
        if self._y is None:
            data = self._X
        else:
            data = np.hstack((self._X.astype(object), self._y.reshape(-1, 1)))
        java_data = JObject(data, JArray("java.lang.Object", 2))
        return java_data
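

# Hedged usage sketch for building an ExampleSet from pandas data. It assumes a
# running JVM and that ProblemType exposes a CLASSIFICATION member (check
# rulekit._problem_types for the exact names); the data is purely illustrative:
#
#   X = pd.DataFrame({"age": [30, 45], "bmi": [21.5, 27.1]})
#   y = pd.Series(["no", "yes"], name="diabetic")
#   example_set = ExampleSetFactory(ProblemType.CLASSIFICATION).make(X, y)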


class PredictionResultMapper:
    """Maps prediction results to numpy array"""

    PREDICTION_COLUMN_ROLE: str = "prediction"
    CONFIDENCE_COLUMN_ROLE: str = "confidence"

    @staticmethod
    def map_confidence(predicted_example_set, label_unique_values: list) -> np.ndarray:
        """Maps the model's confidence values to a numpy array

        Args:
            predicted_example_set (_type_): predicted ExampleSet instance
            label_unique_values (list): unique label values

        Returns:
            np.ndarray: numpy array with mapped confidence values
        """
        confidence_matrix: list[list[float]] = []
        for label_value in label_unique_values:
            confidence_col: JObject = PredictionResultMapper._get_column_by_role(
                predicted_example_set,
                f"{PredictionResultMapper.CONFIDENCE_COLUMN_ROLE}_{label_value}",
            )

            confidence_values = [
                float(predicted_example_set.getExample(i).getValue(confidence_col))
                for i in range(predicted_example_set.size())
            ]

            confidence_matrix.append(confidence_values)
        return np.array(confidence_matrix, dtype=float).T
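
    # Note: the transposed array returned above has shape
    # (number of examples, len(label_unique_values)), with columns in the same
    # order as label_unique_values.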

    @staticmethod
    def map(predicted_example_set: JObject) -> np.ndarray:
        """Maps the model's predictions to a numpy array

        Args:
            predicted_example_set (_type_): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )
        if prediction_col.isNominal():
            return PredictionResultMapper.map_to_nominal(predicted_example_set)
        return PredictionResultMapper.map_to_numerical(predicted_example_set)

    @staticmethod
    def map_to_nominal(predicted_example_set: JObject) -> np.ndarray:
        """Maps the model's predictions to a nominal numpy array of strings

        Args:
            predicted_example_set (_type_): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )

        return np.array(
            [
                str(predicted_example_set.getExample(i).getNominalValue(prediction_col))
                for i in range(predicted_example_set.size())
            ],
            dtype=str,
        )

    @staticmethod
    def map_to_numerical(
        predicted_example_set: JObject, remap: bool = True
    ) -> np.ndarray:
        """Maps the model's predictions to a numerical numpy array

        Args:
            predicted_example_set (_type_): ExampleSet with predictions
            remap (bool, optional): whether to map prediction indices back to
                their label values before converting them to floats.
                Defaults to True.

        Returns:
            np.ndarray: numpy array containing predictions
        """
        prediction_col: JObject = PredictionResultMapper._get_column_by_role(
            predicted_example_set, PredictionResultMapper.PREDICTION_COLUMN_ROLE
        )
        label_mapping = predicted_example_set.getAttributes().getLabel().getMapping()
        if remap:
            predictions: list = [
                label_mapping.mapIndex(
                    int(predicted_example_set.getExample(i).getValue(prediction_col))
                )
                for i in range(predicted_example_set.size())
            ]
            predictions = list(map(lambda x: float(str(x)), predictions))
            return np.array(predictions)
        return np.array(
            [
                float(predicted_example_set.getExample(i).getValue(prediction_col))
                for i in range(predicted_example_set.size())
            ]
        )

    @staticmethod
    def map_survival(predicted_example_set) -> np.ndarray:
        """Maps survival model predictions to a numpy array. Used instead of the
        `map` method for survival analysis.

        Args:
            predicted_example_set (_type_): ExampleSet with predictions

        Returns:
            np.ndarray: numpy array containing predictions
        """
        estimators = []
        attribute = predicted_example_set.getAttributes().get("estimator")
        example_set_iterator = predicted_example_set.iterator()
        while example_set_iterator.hasNext():
            example = example_set_iterator.next()
            example_estimator = str(example.getValueAsString(attribute))
            example_estimator = example_estimator.split(" ")
            _, example_estimator[0] = example_estimator[0].split(":")
            times = [
                float(example_estimator[i])
                for i in range(len(example_estimator) - 1)
                if i % 2 == 0
            ]
            probabilities = [
                float(example_estimator[i])
                for i in range(len(example_estimator))
                if i % 2 != 0
            ]
            estimator = {"times": times, "probabilities": probabilities}
            estimators.append(estimator)
        return np.array(estimators)

    @staticmethod
    def _get_column_by_role(predicted_example_set: JObject, role: str) -> JObject:
        return predicted_example_set.getAttributes().getColumnByRole(role)
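

# Hedged sketch of how the mapper is typically used on a predicted ExampleSet
# (`predicted_example_set` and the label values below are illustrative):
#
#   y_pred = PredictionResultMapper.map(predicted_example_set)
#   confidences = PredictionResultMapper.map_confidence(
#       predicted_example_set, label_unique_values=["no", "yes"]
#   )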


class ModelSerializer:
    """Class for serializing models"""

    @staticmethod
    def serialize(real_model: object) -> bytes:
        """Serialize Java ruleset object.

        Args:
            real_model (object): Java ruleset object

        Returns:
            bytes: serialized model
        """
        in_memory_file = io.BytesIO()
        JPickler(in_memory_file).dump(real_model)
        serialized_bytes = in_memory_file.getvalue()
        in_memory_file.close()
        return serialized_bytes

    @staticmethod
    def deserialize(serialized_bytes: bytes) -> object:
        """Deserialize Java ruleset object from bytes.

        Args:
            serialized_bytes (bytes): serialized bytes

        Returns:
            object: deserialized Java ruleset object
        """
        if not RuleKit.initialized:
            RuleKit.init()
        in_memory_file = io.BytesIO(serialized_bytes)
        model = JUnpickler(in_memory_file).load()
        in_memory_file.close()
        return model
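

# Hedged round-trip sketch (`java_ruleset` stands for any Java ruleset object
# taken from a trained model; shown for illustration only):
#
#   payload: bytes = ModelSerializer.serialize(java_ruleset)
#   restored = ModelSerializer.deserialize(payload)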