Coverage for rulekit/kaplan_meier.py: 100%
35 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-07 11:26 +0000
1from typing import Optional
3import numpy as np
4from jpype import JObject
class KaplanMeierEstimator:
    """Kaplan-Meier estimator of survival function.

    Thin Python wrapper around a Java
    ``adaa.analytics.rules.logic.representation.KaplanMeierEstimator``.
    The time points are read from Java once, at construction. The
    per-time-point arrays (``probabilities``, ``events_count``,
    ``at_risk_count``) are computed lazily on first access, one Java call per
    time point, and then cached on the instance.
    """

    def __init__(self, java_object: "JObject") -> None:
        """
        Args:
            java_object (JObject): \
                `adaa.analytics.rules.logic.representation.KaplanMeierEstimator` \
                object instance from Java
        """
        self._java_object: JObject = java_object
        # Convert Java numbers to Python floats up front. An explicit dtype
        # keeps the array float64 even when the estimator has no time points.
        self._times: np.ndarray = np.array(
            [float(t) for t in self._java_object.getTimes()], dtype=float
        )
        # Lazily-filled caches for the per-time-point arrays (see properties).
        self._probabilities: Optional[np.ndarray] = None
        self._events_count: Optional[np.ndarray] = None
        # NOTE(review): nothing in this class ever populates this attribute.
        # It is kept only so code that reads it does not break.
        self._censored_count: Optional[np.ndarray] = None
        self._at_risk_count: Optional[np.ndarray] = None

    def _evaluate_at_times(self, getter, dtype) -> np.ndarray:
        """Apply ``getter`` to every time point and collect the results.

        Args:
            getter: callable taking a single time point
            dtype: dtype of the resulting array; passed explicitly so that an
                empty time grid still yields an array of the expected kind
                (plain ``np.array([])`` would always be float64)

        Returns:
            np.ndarray: ``[getter(t) for t in self.times]`` as an array
        """
        return np.array([getter(t) for t in self._times], dtype=dtype)

    @property
    def times(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: time points of the Kaplan-Meier estimator (float
            array). The internal array is returned directly, so copy it
            before modifying.
        """
        return self._times

    @property
    def probabilities(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: survival probabilities (float array) for each time
            point in ``times``; computed on first access and cached
        """
        if self._probabilities is None:
            self._probabilities = self._evaluate_at_times(
                self.get_probability_at, float
            )
        return self._probabilities

    @property
    def events_count(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: number of events (integer array) for each time point
            in ``times``; computed on first access and cached
        """
        if self._events_count is None:
            self._events_count = self._evaluate_at_times(
                self.get_events_count_at, int
            )
        return self._events_count

    @property
    def at_risk_count(self) -> np.ndarray:
        """
        Returns:
            np.ndarray: risk set count (integer array) for each time point
            in ``times``; computed on first access and cached
        """
        if self._at_risk_count is None:
            self._at_risk_count = self._evaluate_at_times(
                self.get_risk_set_count_at, int
            )
        return self._at_risk_count

    def get_probability_at(self, time: float) -> float:
        """Gets survival probability at given time point.

        Args:
            time (float): time point

        Returns:
            float: survival probability at given time point
        """
        return float(self._java_object.getProbabilityAt(time))

    def get_events_count_at(self, time: float) -> int:
        """Gets number of events at given time point.

        Args:
            time (float): time point

        Returns:
            int: number of events at given time point
        """
        return int(self._java_object.getEventsCountAt(time))

    def get_risk_set_count_at(self, time: float) -> int:
        """Gets the risk set count at given time point.

        Args:
            time (float): time point

        Returns:
            int: risk set count at given time point, as reported by the
            Java ``getRiskSetCountAt`` method
        """
        return int(self._java_object.getRiskSetCountAt(time))