Coverage for rulekit/_operator.py: 91%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1"""Contains base classes for rule induction operators 

2""" 

3from __future__ import annotations 

4 

5from abc import ABC 

6from abc import abstractmethod 

7from typing import Any 

8from typing import Optional 

9from typing import Union 

10 

11import numpy as np 

12import pandas as pd 

13from pydantic import BaseModel 

14 

15from rulekit._helpers import ExampleSetFactory 

16from rulekit._helpers import get_rule_generator 

17from rulekit._helpers import ModelSerializer 

18from rulekit._helpers import PredictionResultMapper 

19from rulekit._helpers import RuleGeneratorConfigurator 

20from rulekit._problem_types import ProblemType 

21from rulekit.events import _command_listener_factory 

22from rulekit.events import RuleInductionProgressListener 

23from rulekit.main import RuleKit 

24from rulekit.rules import BaseRule 

25from rulekit.rules import RuleSet 

26 

# Container types accepted for the ``values`` / ``labels`` arguments of the
# operators defined in this module.
Data = Union[np.ndarray, pd.DataFrame, list]

28 

29 

class BaseOperator(ABC):
    """Base class for rule induction operator.

    Holds a rule generator object (obtained from ``_get_rule_generator``),
    the hyperparameters it was configured with and, once ``fit`` has been
    called, the induced model wrapped in a :class:`RuleSet`.

    Subclasses must implement ``_get_problem_type`` and set
    ``__params_class__`` to a class that accepts the hyperparameters as
    keyword arguments and exposes ``model_dump()`` (it is used as a pydantic
    ``BaseModel`` in ``set_params``).
    """

    __params_class__: Optional[type] = None

    def __init__(self, **kwargs):
        self._initialize_rulekit()
        self._params = None
        self._rule_generator = None
        self.set_params(**kwargs)
        self.model: Optional[RuleSet[BaseRule]] = None

    def _initialize_rulekit(self):
        """Call ``RuleKit.init()`` unless ``RuleKit.initialized`` is already set."""
        if not RuleKit.initialized:
            RuleKit.init()

    def _map_result(self, predicted_example_set) -> np.ndarray:
        """Convert a predicted example set using ``PredictionResultMapper.map``."""
        return PredictionResultMapper.map(predicted_example_set)

    def _validate_contrast_attribute(
        self, example_set, contrast_attribute: Optional[str]
    ) -> None:
        """Ensure the contrast attribute, when given, is not numerical.

        Parameters
        ----------
        example_set
            Example set whose attributes are inspected.
        contrast_attribute : Optional[str]
            Name of the contrast attribute; ``None`` disables the check.

        Raises
        ------
        ValueError
            If the attribute of that name reports ``isNumerical()``.
        """
        if contrast_attribute is None:
            return
        contrast_attribute_instance = example_set.getAttributes().get(
            contrast_attribute
        )
        if contrast_attribute_instance.isNumerical():
            raise ValueError(
                "Contrast set attributes must be a nominal attribute while "
                + f'"{contrast_attribute}" is a numerical one.'
            )

    def fit(
        self,
        values: Data,
        labels: Data,
        survival_time_attribute: Optional[str] = None,
        contrast_attribute: Optional[str] = None,
    ) -> RuleSet[BaseRule]:
        """Train the model on the given data.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            Training examples.
        labels : :class:`rulekit.operator.Data`
            Labels of the training examples.
        survival_time_attribute : Optional[str]
            Forwarded to the example set factory.
        contrast_attribute : Optional[str]
            Forwarded to the example set factory; must not name a numerical
            attribute.

        Returns
        -------
        model : RuleSet
            The induced rule set, also stored in ``self.model``. Note that
            this is the model and not the operator itself.

        Raises
        ------
        ValueError
            If ``contrast_attribute`` names a numerical attribute.
        """
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self._validate_contrast_attribute(example_set, contrast_attribute)

        java_model = self._rule_generator.learn(example_set)
        self.model = RuleSet[BaseRule](java_model)
        return self.model

    def predict(self, values: Data) -> np.ndarray:
        """Apply the trained model to the given examples.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            Examples to predict for.

        Returns
        -------
        result
            Whatever the underlying Java model's ``apply`` returns for the
            example set built from ``values``.
            NOTE(review): the result is returned without going through
            ``_map_result``; presumably subclasses perform that conversion -
            confirm against the subclasses.

        Raises
        ------
        ValueError
            If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise ValueError('"fit" method must be called before calling this method')
        example_set = ExampleSetFactory(self._get_problem_type()).make(values)
        return self.model._java_object.apply(  # pylint: disable=protected-access
            example_set
        )

    def get_params(
        self, deep: bool = True  # pylint: disable=unused-argument
    ) -> dict[str, Any]:
        """
        Parameters
        ----------
        deep : bool
            Parameter for scikit-learn compatibility. Not used.

        Returns
        -------
        hyperparameters : dict[str, Any]
            Dictionary containing model hyperparameters. This is the
            operator's internal dictionary, not a copy.
        """
        return self._params

    def set_params(self, **kwargs) -> BaseOperator:
        """Set models hyperparameters. Parameters are the same as in constructor.

        The rule generator is re-obtained from ``_get_rule_generator`` and
        configured with every validated parameter, while only parameters whose
        value is not ``None`` are remembered for ``get_params``.

        NOTE(review): ``self._rule_generator`` is replaced here, so listeners
        registered earlier through ``add_event_listener`` were attached to the
        previous object - presumably they need to be registered again.

        Returns
        -------
        self : BaseOperator
            This operator.
        """
        self._rule_generator = self._get_rule_generator()
        params: BaseModel = self.__params_class__(  # pylint: disable=not-callable
            **kwargs
        )
        params_dict: dict = params.model_dump()
        self._params = {
            key: value for key, value in params_dict.items() if value is not None
        }
        configurator = RuleGeneratorConfigurator(self._rule_generator)
        self._rule_generator = configurator.configure(**params_dict)
        return self

    def get_metadata_routing(self) -> None:
        """
        .. warning:: Scikit-learn metadata routing is not supported yet.

        Raises
        ------
        NotImplementedError
            Always, because metadata routing is not supported.
        """
        raise NotImplementedError("Scikit-learn metadata routing is not supported yet.")

    def get_coverage_matrix(self, values: Data) -> np.ndarray:
        """Calculates coverage matrix for ruleset.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            dataset

        Returns
        -------
        coverage_matrix : np.ndarray
            Each row of the matrix represents a single example from the dataset
            and every column represents one rule from the rule set. Value 1 in
            a matrix cell means that the rule covered the example, value 0
            means that it didn't.

        Raises
        ------
        ValueError
            If ``fit`` has not been called yet.
        """
        if self.model is None:
            raise ValueError('"fit" method must be called before calling this method')
        example_set = ExampleSetFactory(self._get_problem_type()).make(values)
        covering_info = self.model.covering(example_set)
        # len() already yields the number of rows for lists, arrays, Series and
        # DataFrames, so no conversion of ``values`` is needed.
        examples_count = len(values)
        matrix = [
            [
                1 if covered is not None and example_index in covered else 0
                for covered in covering_info
            ]
            for example_index in range(examples_count)
        ]
        return np.array(matrix)

    def add_event_listener(self, listener: RuleInductionProgressListener):
        """Add event listener object to the operator which allows to monitor
        rule induction progress.

        Example:
            >>> from rulekit.events import RuleInductionProgressListener
            >>> from rulekit.classification import RuleClassifier
            >>>
            >>> class MyEventListener(RuleInductionProgressListener):
            ...     def on_new_rule(self, rule):
            ...         print('Do something with new rule', rule)
            ...
            >>> operator = RuleClassifier()
            >>> operator.add_event_listener(MyEventListener())

        Args:
            listener (RuleInductionProgressListener): listener object
        """
        command_listener = _command_listener_factory(listener)
        self._rule_generator.addOperatorListener(command_listener)

    def __getstate__(self) -> dict:
        # Only the hyperparameters and the serialized model are pickled; the
        # rule generator is left out and rebuilt in ``__setstate__``.
        return {
            "_params": self._params,
            "model": ModelSerializer.serialize(self.model),
        }

    def __setstate__(self, state: dict):
        self.model = ModelSerializer.deserialize(state["model"])
        # Use the overridable hook so subclasses get their own generator type;
        # ``set_params`` re-obtains and configures the generator afterwards.
        self._rule_generator = self._get_rule_generator()
        self.set_params(**state["_params"])

    def _get_rule_generator(self) -> RuleGeneratorConfigurator:
        """Return the rule generator used by this operator."""
        return get_rule_generator()

    @abstractmethod
    def _get_problem_type(self) -> ProblemType:
        """Return the problem type passed to ``ExampleSetFactory``."""

197 

198 

class ExpertKnowledgeOperator(BaseOperator, ABC):
    """Base class for expert rule induction operator"""

    def fit(  # pylint: disable=too-many-arguments
        self,
        values: Data,
        labels: Data,
        survival_time_attribute: Optional[str] = None,
        contrast_attribute: Optional[str] = None,
        expert_rules: Optional[list[Union[str, BaseRule]]] = None,
        expert_preferred_conditions: Optional[list[Union[str, BaseRule]]] = None,
        expert_forbidden_conditions: Optional[list[Union[str, BaseRule]]] = None,
    ) -> RuleSet[BaseRule]:
        """Train the model on the given data using expert knowledge.

        Parameters
        ----------
        values : :class:`rulekit.operator.Data`
            Training examples.
        labels : :class:`rulekit.operator.Data`
            Labels of the training examples.
        survival_time_attribute : Optional[str]
            Forwarded to the example set factory.
        contrast_attribute : Optional[str]
            Forwarded to the example set factory; must not name a numerical
            attribute.
        expert_rules : Optional[list]
            Expert rules; ``None`` is treated as an empty list.
        expert_preferred_conditions : Optional[list]
            Expert preferred conditions; ``None`` is treated as an empty list.
        expert_forbidden_conditions : Optional[list]
            Expert forbidden conditions; ``None`` is treated as an empty list.

        NOTE(review): every item of the three expert lists is unpacked into
        exactly two elements by ``_sanitize_expert_parameter``, so the items
        presumably have to be ``(identifier, value)`` pairs rather than the
        bare ``str`` / ``BaseRule`` suggested by the annotations - confirm.

        Returns
        -------
        model : RuleSet
            The induced rule set, also stored in ``self.model``. Note that
            this is the model and not the operator itself.

        Raises
        ------
        ValueError
            If ``contrast_attribute`` names a numerical attribute, or if an
            expert item cannot be unpacked into two elements.
        """
        example_set = ExampleSetFactory(self._get_problem_type()).make(
            values,
            labels,
            survival_time_attribute=survival_time_attribute,
            contrast_attribute=contrast_attribute,
        )
        self._validate_contrast_attribute(example_set, contrast_attribute)
        self._configure_expert_parameters(
            expert_rules,
            expert_preferred_conditions,
            expert_forbidden_conditions,
        )
        java_model = self._rule_generator.learn(example_set)
        # Same parametrized construction as in BaseOperator.fit.
        self.model = RuleSet[BaseRule](java_model)
        return self.model

    def _get_rule_generator(self) -> RuleGeneratorConfigurator:
        """Return the rule generator requested with ``expert=True``."""
        return get_rule_generator(expert=True)

    def _configure_expert_parameters(
        self,
        expert_rules: Optional[list[Union[str, BaseRule]]] = None,
        expert_preferred_conditions: Optional[list[Union[str, BaseRule]]] = None,
        expert_forbidden_conditions: Optional[list[Union[str, BaseRule]]] = None,
    ) -> None:
        """Pass the expert knowledge to the current rule generator.

        Enables the ``use_expert`` parameter and then sets the preferred
        conditions, forbidden conditions and rules, in that order. Arguments
        left as ``None`` are configured as empty lists.
        """
        if expert_rules is None:
            expert_rules = []
        if expert_preferred_conditions is None:
            expert_preferred_conditions = []
        if expert_forbidden_conditions is None:
            expert_forbidden_conditions = []

        configurator = RuleGeneratorConfigurator(self._rule_generator)
        configurator._configure_simple_parameter(  # pylint: disable=protected-access
            "use_expert", True
        )
        expert_parameters = (
            ("expert_preferred_conditions", expert_preferred_conditions),
            ("expert_forbidden_conditions", expert_forbidden_conditions),
            ("expert_rules", expert_rules),
        )
        for parameter_name, parameter_value in expert_parameters:
            configurator._configure_expert_parameter(  # pylint: disable=protected-access
                parameter_name,
                self._sanitize_expert_parameter(parameter_value),
            )

    def _sanitize_expert_parameter(
        self, expert_parameter: Optional[list[tuple[str, str]]]
    ) -> Optional[list[tuple[str, str]]]:
        """Normalize expert items into a new list of two-element tuples.

        Parameters
        ----------
        expert_parameter : Optional[list[tuple[str, str]]]
            Items that each unpack into exactly two elements.

        Returns
        -------
        sanitized : Optional[list[tuple[str, str]]]
            ``None`` when ``expert_parameter`` is ``None``, otherwise a new
            list holding one ``(id, value)`` tuple per input item.

        Raises
        ------
        ValueError
            If an item does not unpack into exactly two elements.
        """
        if expert_parameter is None:
            return None
        return [(item_id, item_value) for item_id, item_value in expert_parameter]