Coverage for rulekit/params.py: 92%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

"""Contains constants and classes for specifying model parameters."""

3from enum import Enum 

4from typing import Callable 

5from typing import Optional 

6from typing import Tuple 

7from typing import Union 

8 

9from jpype import JImplements 

10from jpype import JOverride 

11from jpype.types import JDouble 

12from pydantic import BaseModel # pylint: disable=no-name-in-module 

13 

# Largest value a Java ``int`` can hold (Java's Integer.MAX_VALUE, 2**31 - 1).
MAX_INT: int = 2**31 - 1

# Signature of a user-supplied quality measure: four floats in, one float out.
_UserDefinedMeasure = Callable[[float, float, float, float], float]

17 

18 

def _user_defined_measure_factory(measure_function: _UserDefinedMeasure):
    """Wrap a Python callable as a Java ``IUserMeasure`` implementation.

    ``IUserMeasure`` is imported inside this function rather than at module
    level, so it is only resolved when a measure is actually wrapped
    (presumably because the JVM must already be running by then - confirm).

    Args:
        measure_function: callable invoked as ``measure_function(p, n, P, N)``
            with four Python floats; its return value is handed back to Java.

    Returns:
        A fresh JPype proxy object implementing ``IUserMeasure`` whose
        ``getResult`` delegates to ``measure_function``.
    """
    from adaa.analytics.rules.logic.quality import (  # pylint: disable=import-outside-toplevel,import-error
        IUserMeasure,
    )

    @JImplements(IUserMeasure)
    class _UserMeasure:  # pylint: disable=invalid-name,missing-function-docstring

        @JOverride
        def getResult(self, p: JDouble, n: JDouble, P: JDouble, N: JDouble) -> float:
            # Hand plain Python floats (not JPype doubles) to user code.
            return measure_function(*map(float, (p, n, P, N)))

    return _UserMeasure()

31 

32 

class Measures(Enum):
    # pylint: disable=invalid-name
    """Enum for different measures used during induction, pruning and voting.

    Each member's value is identical to its name.

    You can read more about each measure and its implementation
    `here <https://github.com/adaa-polsl/RuleKit/wiki/4-Quality-and-evaluation#41-rule-quality>`_.
    """

    Accuracy = "Accuracy"
    BinaryEntropy = "BinaryEntropy"
    C1 = "C1"
    C2 = "C2"
    CFoil = "CFoil"
    # NOTE(review): the misspelling "Significnce" is kept on purpose. Renaming
    # the member would break existing callers, and the value is presumably
    # matched against a backend measure name - confirm before correcting.
    CN2Significnce = "CN2Significnce"
    Correlation = "Correlation"
    Coverage = "Coverage"
    FBayesianConfirmation = "FBayesianConfirmation"
    FMeasure = "FMeasure"
    FullCoverage = "FullCoverage"
    GeoRSS = "GeoRSS"
    GMeasure = "GMeasure"
    InformationGain = "InformationGain"
    JMeasure = "JMeasure"
    Kappa = "Kappa"
    Klosgen = "Klosgen"
    Laplace = "Laplace"
    Lift = "Lift"
    LogicalSufficiency = "LogicalSufficiency"
    LogRank = "LogRank"
    MEstimate = "MEstimate"
    MutualSupport = "MutualSupport"
    Novelty = "Novelty"
    OddsRatio = "OddsRatio"
    OneWaySupport = "OneWaySupport"
    PawlakDependencyFactor = "PawlakDependencyFactor"
    Precision = "Precision"
    Q2 = "Q2"
    RelativeRisk = "RelativeRisk"
    Ripper = "Ripper"
    RSS = "RSS"
    RuleInterest = "RuleInterest"
    SBayesian = "SBayesian"
    Sensitivity = "Sensitivity"
    Specificity = "Specificity"
    TwoWaySupport = "TwoWaySupport"
    WeightedLaplace = "WeightedLaplace"
    WeightedRelativeAccuracy = "WeightedRelativeAccuracy"
    YAILS = "YAILS"

81 

82 

# Default value of every hyperparameter, keyed by parameter name. The pydantic
# parameter models in this module read their field defaults from this mapping,
# so each key must match a field name exactly.
DEFAULT_PARAMS_VALUE = {
    "minsupp_new": 0.05,
    "induction_measure": Measures.Correlation,
    "pruning_measure": Measures.Correlation,
    "voting_measure": Measures.Correlation,
    "max_growing": 0.0,
    "enable_pruning": True,
    "ignore_missing": False,
    "max_uncovered_fraction": 0.0,
    "select_best_candidate": False,
    "complementary_conditions": False,
    "control_apriori_precision": True,
    "max_rule_count": 0,
    "approximate_induction": False,
    "approximate_bins_count": 100,
    "mean_based_regression": True,
    # Expert-model options (fields of ExpertModelParams)
    "extend_using_preferred": False,
    "extend_using_automatic": False,
    "induce_using_preferred": False,
    "induce_using_automatic": False,
    "consider_other_classes": False,
    # MAX_INT presumably means "no limit" on the Java side - confirm.
    "preferred_conditions_per_rule": MAX_INT,
    "preferred_attributes_per_rule": MAX_INT,
    # Contrast-set options (fields of ContrastSetModelParams)
    "minsupp_all": (0.8, 0.5, 0.2, 0.1),
    "max_neg2pos": 0.5,
    "max_passes_count": 5,
    "penalty_strength": 0.5,
    "penalty_saturation": 0.2,
}

# A quality measure is either a built-in ``Measures`` member or a user
# callable with the ``_UserDefinedMeasure`` signature (four floats -> float).
_QualityMeasure = Union[Measures, _UserDefinedMeasure]

115 

116 

class ModelsParams(BaseModel):
    """Pydantic model validating the hyperparameters shared by all models.

    Every field's default is the ``DEFAULT_PARAMS_VALUE`` entry stored under
    the key with the same name as the field.
    """

    minsupp_new: Optional[float] = DEFAULT_PARAMS_VALUE["minsupp_new"]

    # Quality measures: a ``Measures`` member or a user-defined callable.
    induction_measure: Optional[_QualityMeasure] = (
        DEFAULT_PARAMS_VALUE["induction_measure"]
    )
    pruning_measure: Optional[_QualityMeasure] = (
        DEFAULT_PARAMS_VALUE["pruning_measure"]
    )
    voting_measure: Optional[_QualityMeasure] = (
        DEFAULT_PARAMS_VALUE["voting_measure"]
    )

    max_growing: Optional[float] = DEFAULT_PARAMS_VALUE["max_growing"]
    enable_pruning: Optional[bool] = DEFAULT_PARAMS_VALUE["enable_pruning"]
    ignore_missing: Optional[bool] = DEFAULT_PARAMS_VALUE["ignore_missing"]
    max_uncovered_fraction: Optional[float] = (
        DEFAULT_PARAMS_VALUE["max_uncovered_fraction"]
    )
    select_best_candidate: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["select_best_candidate"]
    )
    complementary_conditions: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["complementary_conditions"]
    )
    # Unlike the fields above, this one is not Optional, so pydantic will not
    # accept ``None`` for it.
    max_rule_count: int = DEFAULT_PARAMS_VALUE["max_rule_count"]

139 

140 

class ExpertModelParams(ModelsParams):
    """Pydantic model validating hyperparameters of expert models.

    Adds expert-specific options to the fields inherited from
    ``ModelsParams``; each default is the ``DEFAULT_PARAMS_VALUE`` entry with
    the same name as the field.
    """

    extend_using_preferred: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["extend_using_preferred"]
    )
    extend_using_automatic: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["extend_using_automatic"]
    )
    induce_using_preferred: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["induce_using_preferred"]
    )
    induce_using_automatic: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["induce_using_automatic"]
    )
    consider_other_classes: Optional[bool] = (
        DEFAULT_PARAMS_VALUE["consider_other_classes"]
    )
    # Per-rule limits on preferred knowledge; the defaults come from the
    # shared defaults mapping.
    preferred_conditions_per_rule: Optional[int] = (
        DEFAULT_PARAMS_VALUE["preferred_conditions_per_rule"]
    )
    preferred_attributes_per_rule: Optional[int] = (
        DEFAULT_PARAMS_VALUE["preferred_attributes_per_rule"]
    )

165 

166 

class ContrastSetModelParams(ModelsParams):
    """Pydantic model validating hyperparameters of contrast set models.

    Adds contrast-set options to the fields inherited from ``ModelsParams``;
    each default is the ``DEFAULT_PARAMS_VALUE`` entry with the same name as
    the field.
    """

    # A fixed-length 4-tuple and, unlike the fields below, not Optional.
    minsupp_all: Tuple[float, float, float, float] = (
        DEFAULT_PARAMS_VALUE["minsupp_all"]
    )
    max_neg2pos: Optional[float] = DEFAULT_PARAMS_VALUE["max_neg2pos"]
    max_passes_count: Optional[int] = DEFAULT_PARAMS_VALUE["max_passes_count"]
    penalty_strength: Optional[float] = (
        DEFAULT_PARAMS_VALUE["penalty_strength"]
    )
    penalty_saturation: Optional[float] = (
        DEFAULT_PARAMS_VALUE["penalty_saturation"]
    )