Coverage for rulekit/_operator.py: 91%
121 statements
coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1"""Contains base classes for rule induction operators
2"""
3from __future__ import annotations
5from abc import ABC
6from abc import abstractmethod
7from typing import Any
8from typing import Optional
9from typing import Union
11import numpy as np
12import pandas as pd
13from pydantic import BaseModel
15from rulekit._helpers import ExampleSetFactory
16from rulekit._helpers import get_rule_generator
17from rulekit._helpers import ModelSerializer
18from rulekit._helpers import PredictionResultMapper
19from rulekit._helpers import RuleGeneratorConfigurator
20from rulekit._problem_types import ProblemType
21from rulekit.events import _command_listener_factory
22from rulekit.events import RuleInductionProgressListener
23from rulekit.main import RuleKit
24from rulekit.rules import BaseRule
25from rulekit.rules import RuleSet
# Tabular input accepted by the operators' ``fit`` / ``predict`` /
# ``get_coverage_matrix`` methods: a numpy array, a pandas DataFrame, or a
# plain Python list.
Data = Union[
    np.ndarray,
    pd.DataFrame,
    list,
]
class BaseOperator(ABC):
    """Base class for rule induction operators.

    Subclasses provide ``__params_class__`` (a pydantic model used to validate
    hyperparameters in :meth:`set_params`) and implement
    :meth:`_get_problem_type`.
    """

    # Pydantic model class instantiated by ``set_params``; ``None`` here is a
    # placeholder that subclasses are expected to override.
    __params_class__: type = None

    def __init__(self, **kwargs):
        """Initialize RuleKit (if needed) and configure hyperparameters.

        Args:
            **kwargs: hyperparameters validated by ``__params_class__``.
        """
        self._initialize_rulekit()
        self._params = None
        self._rule_generator = None
        self.set_params(**kwargs)
        # Populated by ``fit``; ``None`` means the operator is not trained yet.
        self.model: RuleSet[BaseRule] = None

    def _initialize_rulekit(self):
        """Initialize RuleKit unless ``RuleKit.initialized`` is already set."""
        if not RuleKit.initialized:
            RuleKit.init()

    def _map_result(self, predicted_example_set) -> np.ndarray:
        """Convert a predicted example set using ``PredictionResultMapper``."""
        return PredictionResultMapper.map(predicted_example_set)

    def _validate_contrast_attribute(
        self, example_set, contrast_attribute: Optional[str]
    ) -> None:
        """Check that the contrast attribute exists and is nominal.

        Args:
            example_set: example set built by ``ExampleSetFactory``.
            contrast_attribute: attribute name, or ``None`` to skip the check.

        Raises:
            ValueError: if the attribute cannot be found or is numerical.
        """
        if contrast_attribute is None:
            return
        contrast_attribute_instance = example_set.getAttributes().get(
            contrast_attribute
        )
        if contrast_attribute_instance is None:
            # Without this check the call below fails with an opaque
            # AttributeError on ``None``.
            raise ValueError(
                f'Contrast set attribute "{contrast_attribute}" does not exist '
                + "in the dataset."
            )
        if contrast_attribute_instance.isNumerical():
            raise ValueError(
                "Contrast set attributes must be a nominal attribute while "
                + f'"{contrast_attribute}" is a numerical one.'
            )

    def fit(
        self,
        values: Data,
        labels: Data,
        survival_time_attribute: str = None,
        contrast_attribute: str = None,
    ) -> RuleSet[BaseRule]:
        """Induce a rule set from the data and store it in ``self.model``.

        Args:
            values: attribute values.
            labels: target labels.
            survival_time_attribute: optional survival time attribute name.
            contrast_attribute: optional nominal contrast attribute name.

        Returns:
            The induced rule set (note: the rule set, not ``self``).

        Raises:
            ValueError: if the contrast attribute is missing or numerical.
        """
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self._validate_contrast_attribute(example_set, contrast_attribute)

        java_model = self._rule_generator.learn(example_set)
        self.model = RuleSet[BaseRule](java_model)
        return self.model

    def predict(self, values: Data) -> Any:
        """Apply the trained model to ``values``.

        Returns:
            The raw result of ``self.model._java_object.apply`` (an example
            set, not a numpy array). NOTE(review): subclasses presumably
            convert it with ``_map_result`` — confirm.

        Raises:
            ValueError: if ``fit`` has not been called yet.
        """
        if self.model is None:
            raise ValueError('"fit" method must be called before calling this method')
        example_set = ExampleSetFactory(self._get_problem_type()).make(values)
        return self.model._java_object.apply(  # pylint: disable=protected-access
            example_set
        )

    def get_params(
        self, deep: bool = True  # pylint: disable=unused-argument
    ) -> dict[str, Any]:
        """
        Parameters
        ----------
        deep : bool
            Parameter for scikit-learn compatibility. Not used.

        Returns
        -------
        hyperparameters : dict[str, Any]
            Copy of the dictionary containing model hyperparameters whose
            validated value is not ``None``. A copy is returned so callers
            (e.g. scikit-learn's ``clone``) can mutate it safely.
        """
        if self._params is None:
            return None
        return dict(self._params)

    def set_params(self, **kwargs) -> object:
        """Set model hyperparameters. Parameters are the same as in constructor.

        Hyperparameters not passed keep their current values, matching the
        scikit-learn ``set_params`` contract (e.g. ``GridSearchCV`` calls
        ``clone(estimator).set_params(**grid_point)``). Passing ``None`` for a
        parameter forwards ``None`` to the configurator and omits it from
        :meth:`get_params`.

        A fresh rule generator is created on every call, so listeners added
        with :meth:`add_event_listener` beforehand are not carried over.

        Returns:
            self
        """
        # ``getattr`` because ``__setstate__`` may run without ``__init__``.
        current_params: dict = getattr(self, "_params", None) or {}
        # Validate first so a failing call leaves the current generator intact.
        params: BaseModel = self.__params_class__(  # pylint: disable=not-callable
            **{**current_params, **kwargs}
        )
        params_dict: dict = params.model_dump()
        self._rule_generator = self._get_rule_generator()
        self._params = {
            key: value for key, value in params_dict.items() if value is not None
        }
        configurator = RuleGeneratorConfigurator(self._rule_generator)
        self._rule_generator = configurator.configure(**params_dict)
        return self

    def get_metadata_routing(self) -> None:
        """
        .. warning:: Scikit-learn metadata routing is not supported yet.

        Raises:
            NotImplementedError: always, since metadata routing is unsupported.
        """
        raise NotImplementedError("Scikit-learn metadata routing is not supported yet.")

    def get_coverage_matrix(self, values: Data) -> np.ndarray:
        """Calculates coverage matrix for ruleset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            dataset

        Returns
        -------
        coverage_matrix : np.ndarray
            Each row of the matrix represents a single example from the
            dataset and every column represents one rule from the rule set.
            Value 1 in a cell means that the rule covered that example, value
            0 means that it did not.

        Raises
        ------
        ValueError
            If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise ValueError('"fit" method must be called before calling this method')
        example_set = ExampleSetFactory(self._get_problem_type()).make(values)
        # Materialized once because it is iterated again for every example.
        covering_info = list(self.model.covering(example_set))
        # ``len`` gives the number of examples for lists, arrays and pandas
        # objects alike, so no conversion is needed.
        return np.array(
            [
                np.array(
                    [
                        0 if item is None or i not in item else 1
                        for item in covering_info
                    ]
                )
                for i in range(len(values))
            ]
        )

    def add_event_listener(self, listener: RuleInductionProgressListener):
        """Add event listener object to the operator which allows to monitor
        rule induction progress.

        The listener is attached to the current rule generator; a later
        ``set_params`` call replaces that generator, so add listeners after
        configuring hyperparameters.

        Example:
            >>> from rulekit.events import RuleInductionProgressListener
            >>> from rulekit.classification import RuleClassifier
            >>>
            >>> class MyEventListener(RuleInductionProgressListener):
            >>>     def on_new_rule(self, rule):
            >>>         print('Do something with new rule', rule)
            >>>
            >>> operator = RuleClassifier()
            >>> operator.add_event_listener(MyEventListener())

        Args:
            listener (RuleInductionProgressListener): listener object
        """
        command_listener = _command_listener_factory(listener)
        self._rule_generator.addOperatorListener(command_listener)

    def __getstate__(self) -> dict:
        """Pickle only the hyperparameters and the serialized model."""
        return {"_params": self._params, "model": ModelSerializer.serialize(self.model)}

    def __setstate__(self, state: dict):
        """Restore the model and rebuild the rule generator from hyperparameters."""
        # Unpickling may happen in a fresh process where RuleKit is not
        # initialized yet.
        self._initialize_rulekit()
        self._params = None
        self.model = ModelSerializer.deserialize(state["model"])
        self._rule_generator = get_rule_generator()
        self.set_params(**state["_params"])

    def _get_rule_generator(self) -> Any:
        """Return a new (non-expert) rule generator."""
        return get_rule_generator()

    @abstractmethod
    def _get_problem_type(self) -> ProblemType:
        """Return the problem type used to build example sets."""
class ExpertKnowledgeOperator(BaseOperator, ABC):
    """Base class for expert rule induction operators.

    Extends the base operator with user-supplied expert knowledge (expert
    rules, preferred conditions and forbidden conditions) passed to ``fit``.
    """

    def fit(  # pylint: disable=too-many-arguments
        self,
        values: Data,
        labels: Data,
        survival_time_attribute: str = None,
        contrast_attribute: str = None,
        expert_rules: Optional[list[tuple[str, str]]] = None,
        expert_preferred_conditions: Optional[list[tuple[str, str]]] = None,
        expert_forbidden_conditions: Optional[list[tuple[str, str]]] = None,
    ) -> RuleSet[BaseRule]:
        """Induce a rule set using expert knowledge and store it in ``self.model``.

        Each expert knowledge item is an ``(identifier, value)`` pair.
        NOTE(review): ``value`` is presumably a rule/condition string in
        RuleKit's expert syntax — confirm against the configurator.

        Args:
            values: attribute values.
            labels: target labels.
            survival_time_attribute: optional survival time attribute name.
            contrast_attribute: optional nominal contrast attribute name.
            expert_rules: expert rules; ``None`` means none.
            expert_preferred_conditions: preferred conditions; ``None`` means none.
            expert_forbidden_conditions: forbidden conditions; ``None`` means none.

        Returns:
            The induced rule set (note: the rule set, not ``self``).

        Raises:
            ValueError: if the contrast attribute is missing or numerical, or an
                expert knowledge item is not an ``(identifier, value)`` pair.
        """
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self._validate_contrast_attribute(example_set, contrast_attribute)
        self._configure_expert_parameters(
            expert_rules,
            expert_preferred_conditions,
            expert_forbidden_conditions,
        )
        java_model = self._rule_generator.learn(example_set)
        # Built the same way as in the non-expert base operator.
        self.model = RuleSet[BaseRule](java_model)
        return self.model

    def _get_rule_generator(self) -> Any:
        """Return a new expert-mode rule generator."""
        return get_rule_generator(expert=True)

    def _configure_expert_parameters(
        self,
        expert_rules: Optional[list[tuple[str, str]]] = None,
        expert_preferred_conditions: Optional[list[tuple[str, str]]] = None,
        expert_forbidden_conditions: Optional[list[tuple[str, str]]] = None,
    ) -> None:
        """Enable expert mode on the current rule generator and pass it the
        expert knowledge; ``None`` arguments are treated as empty lists."""
        expert_rules = [] if expert_rules is None else expert_rules
        expert_preferred_conditions = (
            [] if expert_preferred_conditions is None else expert_preferred_conditions
        )
        expert_forbidden_conditions = (
            [] if expert_forbidden_conditions is None else expert_forbidden_conditions
        )

        configurator = RuleGeneratorConfigurator(self._rule_generator)
        configurator._configure_simple_parameter(  # pylint: disable=protected-access
            "use_expert", True
        )
        configurator._configure_expert_parameter(  # pylint: disable=protected-access
            "expert_preferred_conditions",
            self._sanitize_expert_parameter(expert_preferred_conditions),
        )
        configurator._configure_expert_parameter(  # pylint: disable=protected-access
            "expert_forbidden_conditions",
            self._sanitize_expert_parameter(expert_forbidden_conditions),
        )
        configurator._configure_expert_parameter(  # pylint: disable=protected-access
            "expert_rules", self._sanitize_expert_parameter(expert_rules)
        )

    def _sanitize_expert_parameter(
        self, expert_parameter: Optional[list[tuple[str, str]]]
    ) -> Optional[list[tuple[str, str]]]:
        """Normalize expert knowledge items into a new list of 2-tuples.

        Args:
            expert_parameter: iterable of ``(identifier, value)`` pairs, or None.

        Returns:
            A new list of ``(identifier, value)`` tuples, or ``None`` if
            ``expert_parameter`` is ``None``.

        Raises:
            ValueError: if an item cannot be unpacked into exactly two values.
        """
        if expert_parameter is None:
            return None
        sanitized_parameter: list[tuple[str, str]] = []
        for index, item in enumerate(expert_parameter):
            try:
                item_id, item_value = item
            except (TypeError, ValueError) as error:
                raise ValueError(
                    "Expert knowledge items must be (identifier, value) pairs, "
                    f"got {item!r} at position {index}."
                ) from error
            sanitized_parameter.append((item_id, item_value))
        return sanitized_parameter