|
import json
from abc import abstractmethod
from typing import List, Tuple

import plotly
import plotly.graph_objects as go
from plotly.graph_objs import Figure
from plotly.subplots import make_subplots

from ads.feature_store.common.utils.utility import none_type_safe_json_loads
| 10 | + |
| 11 | + |
class FeatureStat:
    """Common interface for a per-feature statistic that can be plotted.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are advisory only and the class itself
    remains instantiable. Subclasses are expected to override both methods.
    """

    @abstractmethod
    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Draw this statistic onto ``fig``.

        Args:
            fig: Plotly figure that receives the trace.
            xaxis: Zero-based index of the x axis to draw on.
            yaxis: Zero-based index of the y axis to draw on.
        """

    @classmethod
    @abstractmethod
    def from_json(cls, json_dict: dict):
        """Build the statistic from its decoded JSON dictionary."""

    @staticmethod
    def get_x_y_str_axes(xaxis: int, yaxis: int) -> Tuple[str, str, str, str]:
        """Translate zero-based axis indices into one-based axis names.

        Args:
            xaxis: Zero-based x axis index.
            yaxis: Zero-based y axis index.

        Returns:
            A 4-tuple ``(layout_x, layout_y, trace_x, trace_y)``, for example
            ``("xaxis1", "yaxis1", "x1", "y1")`` for indices ``(0, 0)``.
        """
        x_number = xaxis + 1
        y_number = yaxis + 1
        return (
            f"xaxis{x_number}",
            f"yaxis{y_number}",
            f"x{x_number}",
            f"y{y_number}",
        )
| 30 | + |
| 31 | + |
class FrequencyDistribution(FeatureStat):
    """Bar-chart statistic pairing a list of frequencies with a list of bins."""

    CONST_FREQUENCY = "frequency"
    CONST_BINS = "bins"
    CONST_FREQUENCY_DISTRIBUTION_TITLE = "Frequency Distribution"

    def __init__(self, frequency: List, bins: List):
        """Store the bar heights (``frequency``) and bar positions (``bins``)."""
        self.frequency = frequency
        self.bins = bins

    @classmethod
    def from_json(cls, json_dict: dict) -> "FrequencyDistribution":
        """Build the statistic from its decoded JSON dictionary.

        Args:
            json_dict: Dictionary read with the ``"frequency"`` and ``"bins"``
                keys, or ``None``.

        Returns:
            A new instance (of ``cls``, so subclasses round-trip correctly),
            or ``None`` when ``json_dict`` is ``None``.
        """
        if json_dict is None:
            return None
        return cls(
            frequency=json_dict.get(cls.CONST_FREQUENCY),
            bins=json_dict.get(cls.CONST_BINS),
        )

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Add the frequency bars to ``fig`` and label the subplot.

        Bars are drawn only when ``frequency`` and ``bins`` are non-empty
        lists of equal length.

        Args:
            fig: Plotly figure that receives the bar trace.
            xaxis: Zero-based x axis index; also used to index
                ``fig.layout.annotations`` for the subplot title.
            yaxis: Zero-based y axis index.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        has_bars = (
            isinstance(self.frequency, list)
            and isinstance(self.bins, list)
            and 0 < len(self.frequency) == len(self.bins)
        )
        if has_bars:
            fig.add_bar(
                x=self.bins, y=self.frequency, xaxis=x_str, yaxis=y_str, name=""
            )
        # Label the subplot even when no bars are drawn: the caller
        # (FeatureStatistics.to_viz) pre-fills every subplot header with a
        # placeholder that should never reach the rendered figure.
        fig.layout.annotations[xaxis].text = self.CONST_FREQUENCY_DISTRIBUTION_TITLE
        fig.layout[xaxis_str]["title"] = "Bins"
        fig.layout[yaxis_str]["title"] = "Frequency"
| 64 | + |
| 65 | + |
class ProbabilityDistribution(FeatureStat):
    """Bar-chart statistic pairing a list of densities with a list of bins."""

    CONST_DENSITY = "density"
    CONST_BINS = "bins"
    CONST_PROBABILITY_DISTRIBUTION_TITLE = "Probability Distribution"

    def __init__(self, density: List, bins: List):
        """Store the bar heights (``density``) and bar positions (``bins``)."""
        self.density = density
        self.bins = bins

    @classmethod
    def from_json(cls, json_dict: dict):
        """Build the statistic from its decoded JSON dictionary.

        Args:
            json_dict: Dictionary read with the ``"density"`` and ``"bins"``
                keys, or ``None``.

        Returns:
            A new instance of ``cls``, or ``None`` when ``json_dict`` is
            ``None``.
        """
        if json_dict is None:
            return None
        return cls(
            density=json_dict.get(cls.CONST_DENSITY),
            bins=json_dict.get(cls.CONST_BINS),
        )

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Add the density bars to ``fig`` and label the subplot.

        Bars are drawn only when ``density`` and ``bins`` are non-empty lists
        of equal length.

        Args:
            fig: Plotly figure that receives the bar trace.
            xaxis: Zero-based x axis index; also used to index
                ``fig.layout.annotations`` for the subplot title.
            yaxis: Zero-based y axis index.

        Returns:
            A stand-alone ``go.Bar`` built from ``bins`` and ``density``.
            NOTE(review): the sibling statistics return nothing and the caller
            in this file ignores the result; kept only for backward
            compatibility.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        has_bars = (
            isinstance(self.density, list)
            and isinstance(self.bins, list)
            and 0 < len(self.density) == len(self.bins)
        )
        if has_bars:
            fig.add_bar(
                x=self.bins,
                y=self.density,
                xaxis=x_str,
                yaxis=y_str,
                name="",
            )
        # Label the subplot even when no bars are drawn: the caller
        # (FeatureStatistics.to_viz) pre-fills every subplot header with a
        # placeholder that should never reach the rendered figure.
        fig.layout.annotations[xaxis].text = self.CONST_PROBABILITY_DISTRIBUTION_TITLE
        fig.layout[xaxis_str]["title"] = "Bins"
        fig.layout[yaxis_str]["title"] = "Density"

        return go.Bar(x=self.bins, y=self.density)
| 104 | + |
| 105 | + |
class TopKFrequentElements(FeatureStat):
    """Bar-chart statistic listing the most frequent values of a feature."""

    CONST_VALUE = "value"
    CONST_TOP_K_FREQUENT_TITLE = "Top K Frequent Elements"

    class TopKFrequentElement:
        """One entry of the top-k list: a value with its count estimate."""

        CONST_VALUE = "value"
        CONST_ESTIMATE = "estimate"
        CONST_LOWER_BOUND = "lower_bound"
        CONST_UPPER_BOUND = "upper_bound"

        def __init__(
            self, value: str, estimate: int, lower_bound: int, upper_bound: int
        ):
            """Store the value, its estimate and the estimate's bounds."""
            self.value = value
            self.estimate = estimate
            self.lower_bound = lower_bound
            self.upper_bound = upper_bound

        @classmethod
        def from_json(cls, json_dict: dict):
            """Build one entry from its decoded JSON dictionary.

            Returns:
                A new instance of ``cls``, or ``None`` when ``json_dict`` is
                ``None``. Missing keys yield ``None`` attributes.
            """
            if json_dict is None:
                return None
            return cls(
                value=json_dict.get(cls.CONST_VALUE),
                estimate=json_dict.get(cls.CONST_ESTIMATE),
                lower_bound=json_dict.get(cls.CONST_LOWER_BOUND),
                upper_bound=json_dict.get(cls.CONST_UPPER_BOUND),
            )

    def __init__(self, elements: List[TopKFrequentElement]):
        """Store the list of top-k entries."""
        self.elements = elements

    @classmethod
    def from_json(cls, json_dict: dict):
        """Build the statistic from its decoded JSON dictionary.

        Args:
            json_dict: Dictionary whose ``"value"`` key holds an iterable of
                per-element dictionaries, or ``None``.

        Returns:
            A new instance of ``cls``, or ``None`` when ``json_dict`` is
            ``None`` or has no ``"value"`` entry.
        """
        if json_dict is None:
            return None
        raw_elements = json_dict.get(cls.CONST_VALUE)
        if raw_elements is None:
            return None
        return cls([cls.TopKFrequentElement.from_json(raw) for raw in raw_elements])

    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Add one bar per element to ``fig`` and label the subplot.

        Args:
            fig: Plotly figure that receives the bar trace.
            xaxis: Zero-based x axis index; also used to index
                ``fig.layout.annotations`` for the subplot title.
            yaxis: Zero-based y axis index.
        """
        xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)
        if isinstance(self.elements, list):
            # from_json of the nested class yields None for a None entry;
            # skip those instead of failing on attribute access.
            plottable = [entry for entry in self.elements if entry is not None]
            if plottable:
                fig.add_bar(
                    x=[entry.value for entry in plottable],
                    y=[entry.estimate for entry in plottable],
                    xaxis=x_str,
                    yaxis=y_str,
                    name="",
                )
        # Label the subplot even when no bars are drawn: the caller
        # (FeatureStatistics.to_viz) pre-fills every subplot header with a
        # placeholder that should never reach the rendered figure.
        fig.layout.annotations[xaxis].text = self.CONST_TOP_K_FREQUENT_TITLE
        fig.layout[yaxis_str]["title"] = "Count"
        fig.layout[xaxis_str]["title"] = "Element"
| 159 | + |
| 160 | + |
class FeatureStatistics:
    """All plottable statistics of one feature, rendered as a row of subplots."""

    CONST_FREQUENCY_DISTRIBUTION = "FrequencyDistribution"
    CONST_TITLE_FORMAT = "<b>{}</b>"
    CONST_PLOT_FORMAT = "{}_plot"
    CONST_PROBABILITY_DISTRIBUTION = "ProbabilityDistribution"
    CONST_TOP_K_FREQUENT = "TopKFrequentElements"

    def __init__(
        self,
        feature_name: str,
        top_k_frequent_elements: TopKFrequentElements,
        frequency_distribution: FrequencyDistribution,
        probability_distribution: ProbabilityDistribution,
    ):
        """Store the feature name and its statistics (each may be ``None``)."""
        self.feature_name: str = feature_name
        self.top_k_frequent_elements = top_k_frequent_elements
        self.frequency_distribution = frequency_distribution
        self.probability_distribution = probability_distribution

    @classmethod
    def from_json(cls, feature_name: str, json_dict: dict) -> "FeatureStatistics":
        """Build the statistics of ``feature_name`` from a decoded JSON dict.

        Args:
            feature_name: Name of the feature the statistics belong to.
            json_dict: Dictionary read with the ``"TopKFrequentElements"``,
                ``"FrequencyDistribution"`` and ``"ProbabilityDistribution"``
                keys, or ``None``.

        Returns:
            A new instance of ``cls``, or ``None`` when ``json_dict`` is
            ``None``.
        """
        if json_dict is None:
            return None
        return cls(
            feature_name,
            TopKFrequentElements.from_json(json_dict.get(cls.CONST_TOP_K_FREQUENT)),
            FrequencyDistribution.from_json(
                json_dict.get(cls.CONST_FREQUENCY_DISTRIBUTION)
            ),
            ProbabilityDistribution.from_json(
                json_dict.get(cls.CONST_PROBABILITY_DISTRIBUTION)
            ),
        )

    @property
    def __stat_count__(self):
        """Number of statistics that are present (not ``None``)."""
        # Derived from the filtered list so the two properties cannot drift.
        return len(self.__feature_stat_objects__)

    @property
    def __feature_stat_objects__(self) -> List[FeatureStat]:
        """Present statistics in plotting order: top-k, frequency, probability."""
        candidates = (
            self.top_k_frequent_elements,
            self.frequency_distribution,
            self.probability_distribution,
        )
        return [stat for stat in candidates if stat is not None]

    def to_viz(self):
        """Render one subplot per present statistic with ``plotly.offline.iplot``.

        Does nothing when the feature has no statistics.
        """
        stats = self.__feature_stat_objects__
        if not stats:
            return
        # "title" is a placeholder; each statistic overwrites its own
        # subplot header in add_to_figure.
        fig = make_subplots(cols=len(stats), column_titles=["title"] * len(stats))
        for index, stat in enumerate(stats):
            stat.add_to_figure(fig, index, index)
        fig.layout.title = self.CONST_TITLE_FORMAT.format(self.feature_name)
        fig.update_layout(title_font_size=20, title_x=0.5, showlegend=False)
        plotly.offline.iplot(
            fig,
            filename=self.CONST_PLOT_FORMAT.format(self.feature_name),
        )
0 commit comments