Coverage for rulekit/kaplan_meier.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-07 11:26 +0000

1from typing import Optional 

2 

3import numpy as np 

4from jpype import JObject 

5 

6 

class KaplanMeierEstimator:
    """Kaplan-Meier estimator of survival function.

    Thin wrapper around a Java
    ``adaa.analytics.rules.logic.representation.KaplanMeierEstimator``.
    The time points are copied from the Java object when the wrapper is
    created. The per-time-point arrays (``probabilities``, ``events_count``,
    ``at_risk_count``) are computed on first access with one Java call per
    time point, and cached afterwards.
    """

    def __init__(self, java_object: JObject) -> None:
        """
        Args:
            java_object (JObject): instance of the Java class
                ``adaa.analytics.rules.logic.representation.KaplanMeierEstimator``
        """
        self._java_object: JObject = java_object
        self._times: np.ndarray = np.array(
            [float(t) for t in self._java_object.getTimes()], dtype=float
        )
        # Lazily-populated caches, filled by the corresponding properties.
        self._probabilities: Optional[np.ndarray] = None
        self._events_count: Optional[np.ndarray] = None
        # NOTE(review): not read or written anywhere else in this class;
        # kept only for backward compatibility.
        self._censored_count: Optional[np.ndarray] = None
        self._at_risk_count: Optional[np.ndarray] = None

    def _evaluate_at_times(self, getter, dtype: type) -> np.ndarray:
        """Call ``getter`` for every time point and pack results into an array.

        The explicit ``dtype`` matters for estimators without time points:
        ``np.array([])`` would otherwise default to float64 even for the
        integer count arrays.
        """
        return np.array([getter(t) for t in self._times], dtype=dtype)

    @property
    def times(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: time points of the Kaplan-Meier estimator (values of
            the Java ``getTimes()`` converted to floats)
        """
        return self._times

    @property
    def probabilities(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: survival probability at each entry of ``times``
            (float array, same length as ``times``); computed once and cached
        """
        if self._probabilities is None:
            self._probabilities = self._evaluate_at_times(
                self.get_probability_at, float
            )
        return self._probabilities

    @property
    def events_count(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: number of events at each entry of ``times``
            (integer array, same length as ``times``); computed once and cached
        """
        if self._events_count is None:
            self._events_count = self._evaluate_at_times(
                self.get_events_count_at, int
            )
        return self._events_count

    @property
    def at_risk_count(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: risk-set count at each entry of ``times``, as reported
            by the Java ``getRiskSetCountAt`` (integer array, same length as
            ``times``); computed once and cached
        """
        if self._at_risk_count is None:
            self._at_risk_count = self._evaluate_at_times(
                self.get_risk_set_count_at, int
            )
        return self._at_risk_count

    def get_probability_at(self, time: float) -> float:
        """Gets survival probability at given time point.

        Args:
            time (float): time point

        Returns:
            float: survival probability at given time point, as returned by
            the Java ``getProbabilityAt``
        """
        return float(self._java_object.getProbabilityAt(time))

    def get_events_count_at(self, time: float) -> int:
        """Gets number of events at given time point.

        Args:
            time (float): time point

        Returns:
            int: number of events at given time point, as returned by the
            Java ``getEventsCountAt``
        """
        return int(self._java_object.getEventsCountAt(time))

    def get_risk_set_count_at(self, time: float) -> int:
        """Gets the risk-set count at given time point.

        Args:
            time (float): time point

        Returns:
            int: risk-set count at given time point, as returned by the Java
            ``getRiskSetCountAt``
        """
        return int(self._java_object.getRiskSetCountAt(time))