from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from .Dataset import Dataset


class Plotter:
    """Draw common exploratory plots for the columns of a ``Dataset``.

    Every plotting method creates a fresh figure/axes pair with
    ``plt.subplots()`` and draws onto it. Nothing is returned and
    ``plt.show()`` is not called; displaying or saving is up to the caller.
    """

    def __init__(self, dataset: Dataset):
        """Store the dataset and the dataframe it exposes.

        Args:
            dataset (Dataset): data source. ``get_dataframe()`` is called once
                here and the result is reused by every plot, so later changes
                to the dataset's dataframe are not picked up.
        """
        self.ds = dataset
        self.df = dataset.get_dataframe()

    def customize_plot(self, fig, ax, styling_params: Optional[dict] = None) -> None:
        """Apply optional user styling to an axes.

        Args:
            fig (matplotlib.figure.Figure): figure owning ``ax``; currently
                unused, kept for API compatibility.
            ax (matplotlib.axes.Axes): axes to style.
            styling_params (dict | None): recognised keys are ``"title"``,
                ``"xlabel"`` and ``"ylabel"``. Each is applied only when its
                value is truthy; other keys are ignored. ``None`` means no
                styling.

        Returns:
            None
        """
        if not styling_params:
            return
        setters = {
            "title": ax.set_title,
            "xlabel": ax.set_xlabel,
            "ylabel": ax.set_ylabel,
        }
        for key, setter in setters.items():
            value = styling_params.get(key)
            if value:
                setter(value)

    def plot_categorical_bar_chart(
        self, category1: str, category2: str, styling_params: Optional[dict] = None
    ) -> None:
        """Plot a grouped bar chart of row-normalised category frequencies.

        ``category1`` values form the bar groups along the x axis and each
        ``category2`` value gets one bar per group. Every group sums to 100,
        i.e. it shows the percentage breakdown of ``category2`` within one
        value of ``category1``.

        Args:
            category1 (str): column name present in the dataset (bar groups).
            category2 (str): column name present in the dataset (bars).
            styling_params (dict | None): see ``customize_plot``.

        Returns:
            None
        """
        # Percentages by row: each category1 value's counts sum to 100.
        ct_percent = pd.crosstab(
            self.df[category1], self.df[category2], normalize="index"
        ) * 100
        fig, ax = plt.subplots()
        ct_percent.plot(kind="bar", ax=ax)
        # Style after drawing so user-supplied labels override pandas' defaults.
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_boxplot(
        self, target: str, category: str, styling_params: Optional[dict] = None
    ) -> None:
        """Plot the distribution of ``target`` as one box per ``category`` value.

        Args:
            target (str): numeric column name present in the dataset (y axis).
            category (str): column name present in the dataset (x axis).
            styling_params (dict | None): see ``customize_plot``.

        Returns:
            None
        """
        fig, ax = plt.subplots()
        # Pass ax explicitly instead of relying on pyplot's "current axes".
        sns.boxplot(x=category, y=target, data=self.df, palette="rainbow", ax=ax)
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_histplot(
        self,
        target: str,
        category: str,
        styling_params: Optional[dict] = None,
        bins=30,
    ) -> None:
        """Overlay one normalised histogram of ``target`` per ``category`` value.

        Each group's bars show the fraction of that group's non-null values
        falling in each bin, so groups of different sizes are comparable.
        Groups with no non-null ``target`` values are skipped.

        Args:
            target (str): numeric column name present in the dataset.
            category (str): column name present in the dataset; one histogram
                is drawn per value returned by
                ``Dataset.get_unique_column_values``.
            styling_params (dict | None): see ``customize_plot``.
            bins (int | str | sequence): bin count, numpy binning strategy or
                explicit edges. The same edges, computed from all non-null
                ``target`` values, are used for every group.

        Returns:
            None
        """
        uniques = self.ds.get_unique_column_values(category)
        # Shared bin edges: otherwise each ax.hist call picks its own
        # boundaries and the overlaid histograms cannot be compared.
        edges = np.histogram_bin_edges(self.df[target].dropna(), bins=bins)
        fig, ax = plt.subplots()
        drew_any = False
        for val in uniques:
            group = self.df.loc[self.df[category] == val, target].dropna()
            if group.empty:
                continue
            # Weight each observation by 1/len so bars show fractions, not counts.
            weights = np.full(len(group), 1.0 / len(group))
            ax.hist(group, weights=weights, bins=edges, alpha=0.5, label=str(val))
            drew_any = True
        if drew_any:
            ax.legend(title=str(category))
        self.customize_plot(fig, ax, styling_params)

    def plot_scatterplot(
        self, target1: str, target2: str, styling_params: Optional[dict] = None
    ) -> None:
        """Plot ``target2`` against ``target1`` as a scatterplot.

        Args:
            target1 (str): column name present in the dataset (x axis).
            target2 (str): column name present in the dataset (y axis).
            styling_params (dict | None): see ``customize_plot``; its
                ``xlabel``/``ylabel`` override the column-name defaults.

        Returns:
            None
        """
        fig, ax = plt.subplots()
        ax.scatter(self.df[target1], self.df[target2])
        ax.set_xlabel(target1)
        ax.set_ylabel(target2)
        self.customize_plot(fig, ax, styling_params)

    def distribution_plot(
        self, target: str, styling_params: Optional[dict] = None
    ) -> None:
        """Plot and print how many rows fall into each value of ``target``.

        The counts (largest first) are printed to stdout. A horizontal bar
        chart is then drawn on a new figure, with the smallest group at the
        bottom and the largest at the top. Rows whose ``target`` is null are
        not counted (``groupby`` drops them).

        Args:
            target (str): column name present in the dataset.
            styling_params (dict | None): optional overrides, see
                ``customize_plot``; applied after the default labels/title.

        Returns:
            None
        """
        counts = self.df.groupby(target).size()
        print(counts.sort_values(ascending=False))
        ascending_counts = counts.sort_values(ascending=True)
        # Own figure: drawing via global pyplot state would land on whatever
        # axes was current, e.g. the previous plot's.
        fig, ax = plt.subplots()
        ax.barh(ascending_counts.index, ascending_counts.values)
        ax.set_xlabel("Size")
        ax.set_ylabel(target)
        ax.set_title(f"Distribution of {target}")
        self.customize_plot(fig, ax, styling_params)