diff --git a/src/Plotter.py b/src/Plotter.py
index c9055befddaf2a526378850a681c37e187803df6..41374917984e7cf606500be9bd7b88bef39debd8 100644
--- a/src/Plotter.py
+++ b/src/Plotter.py
@@ -10,11 +10,32 @@ class Plotter:
         self.ds = dataset
         self.df = dataset.get_dataframe()
 
-    def customize_plot(self, fig, ax, styling_params):
+    def customize_plot(self, fig, ax, styling_params) -> None:
+        """Set the axes title when styling_params has a truthy "title".
+
+        Args:
+            fig: Figure passed by the caller; unused in this method body.
+            ax: Axes object whose set_title() is called.
+            styling_params (dict): Styling options. Only the "title" key
+                is read here.
+
+        Returns:
+            None
+        """
         if styling_params.get("title"):
             ax.set_title(styling_params["title"])
 
-    def distribution_plot(self, target):
+    def distribution_plot(self, target) -> None:
+        """Plot a horizontal bar chart of row counts per value of target.
+
+        Counts come from self.df.groupby(target).size(), drawn with plt.barh.
+
+        Args:
+            target (str): Column name; must be present in the dataset.
+
+        Returns:
+            None
+        """
         grouped_data = self.df.groupby(target).size()
         plt.barh(grouped_data.index, grouped_data.values)
         print(
@@ -28,7 +49,18 @@ class Plotter:
 
     def plot_categorical_bar_chart(
         self, category1, category2, styling_params={}
-    ):
+    ) -> None:
+        """Plot a bar chart of the category1/category2 crosstab, in percent.
+
+        Args:
+            category1 (str): Column whose values index the crosstab rows.
+            category2 (str): Column whose values form the crosstab columns;
+                each column is scaled to sum to 100 (apply with axis=0).
+            styling_params (dict): Styling options passed to customize_plot.
+
+        Returns:
+            None
+        """
         ct = pd.crosstab(self.df[category1], self.df[category2])
         # Calculate percentages by row
         ct_percent = ct.apply(lambda r: r / r.sum() * 100, axis=0)
@@ -36,14 +68,39 @@ class Plotter:
         self.customize_plot(fig, ax, styling_params)
         ct_percent.plot(kind="bar", ax=ax)
 
-    def plot_categorical_boxplot(self, target, category, styling_params={}):
+    def plot_categorical_boxplot(
+        self, target, category, styling_params={}
+    ) -> None:
+        """Plot a box plot of target grouped by category.
+
+        Args:
+            target (str): Column drawn on the y axis; must be in the dataset.
+            category (str): Column drawn on the x axis; must be in the dataset.
+            styling_params (dict): Styling options forwarded to
+                customize_plot (e.g. "title").
+
+        Returns:
+            None
+        """
+
         fig, ax = plt.subplots()
         self.customize_plot(fig, ax, styling_params)
         sns.boxplot(x=category, y=target, data=self.df, palette="rainbow")
 
     def plot_categorical_histplot(
         self, target, category, styling_params={}, bins=30
-    ):
+    ) -> None:
+        """Plot a categorical histplot.
+
+        Args:
+            target (str): Column name; must be present in the dataset.
+            category (str): Column name; must be present in the dataset.
+            styling_params (dict): Styling options passed to customize_plot.
+            bins (int): Presumably the number of histogram bins -- confirm.
+
+        Returns:
+            None
+        """
         uniques = self.ds.get_unique_column_values(category)
         fig, ax = plt.subplots()
         self.customize_plot(fig, ax, styling_params)
@@ -57,7 +114,18 @@ class Plotter:
                 alpha=0.5,
             )
 
-    def plot_scatterplot(self, target1, target2, styling_params={}):
+    def plot_scatterplot(self, target1, target2, styling_params={}) -> None:
+        """Plot a scatter plot of target1 against target2.
+
+        Args:
+            target1 (str): Column plotted on the x axis; must be in the dataset.
+            target2 (str): Column plotted on the y axis; must be in the dataset.
+            styling_params (dict): Styling options forwarded to
+                customize_plot (e.g. "title").
+
+        Returns:
+            None
+        """
         fig, ax = plt.subplots()
         self.customize_plot(fig, ax, styling_params)
-        ax.scatter(self.df[target1], self.df[target2])
+        ax.scatter(self.df[target1], self.df[target2])
\ No newline at end of file