# Plotter.py
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from .Dataset import Dataset

class Plotter:
    def __init__(self, dataset: Dataset):
        self.ds = dataset
        self.df = dataset.get_dataframe()
    def customize_plot(self, fig, ax, styling_params) -> None:
        """apply optional styling to an existing figure/axes pair.

        Only the "title" entry of ``styling_params`` is consumed; any other
        key is ignored and ``fig`` is currently unused.

        Args:
            fig (matplotlib.figure.Figure): figure that owns ``ax``.
            ax (matplotlib.axes.Axes): axes whose title may be set.
            styling_params (dict | None): styling options; a missing, empty
                or ``None`` mapping leaves the axes untouched.

        Returns:
            None
        """
        # Tolerate None/empty so callers that want no styling do not have to
        # build a dict first.
        if not styling_params:
            return
        title = styling_params.get("title")
        # A falsy title (None, "") means "keep whatever title is there".
        if title:
            ax.set_title(title)

    def plot_categorical_bar_chart(
        self, category1, category2, styling_params=None
    ) -> None:
        """plot a grouped bar chart of the cross-tabulation of two columns.

        The values of ``category1`` are placed on the x-axis and one bar per
        value of ``category2`` is drawn in each group. Counts are converted
        to percentages within each ``category2`` value, i.e. every
        ``category2`` series sums to 100 across the x-axis.

        Args:
            category1 (str): must be present as a column in the dataset.
            category2 (str): must be present as a column in the dataset.
            styling_params (dict, optional): forwarded to ``customize_plot``;
                ``None`` (the default) means no styling.

        Returns:
            None
        """
        # A fresh dict per call instead of a shared mutable default argument.
        if styling_params is None:
            styling_params = {}
        counts = pd.crosstab(self.df[category1], self.df[category2])
        # apply(axis=0) passes each *column* of the crosstab to the lambda,
        # so the percentages are computed per column (per category2 value).
        share = counts.apply(lambda col: col / col.sum() * 100, axis=0)
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
        share.plot(kind="bar", ax=ax)
    def plot_categorical_boxplot(
        self, target, category, styling_params=None
    ) -> None:
        """plot a boxplot of ``target`` for every value of ``category``.

        Args:
            target (str): must be present as a column in the dataset;
                plotted on the y-axis.
            category (str): must be present as a column in the dataset;
                plotted on the x-axis.
            styling_params (dict, optional): forwarded to ``customize_plot``;
                ``None`` (the default) means no styling.

        Returns:
            None
        """
        # A fresh dict per call instead of a shared mutable default argument.
        if styling_params is None:
            styling_params = {}
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
        # Draw on the axes created above explicitly rather than relying on
        # it being matplotlib's "current" axes.
        sns.boxplot(
            x=category, y=target, data=self.df, palette="rainbow", ax=ax
        )
    def plot_categorical_histplot(
        self, target, category, styling_params={}, bins=30
    ) -> None:
        """plot a categorical histplot.

        For every unique value of ``category`` a histogram of ``target`` is
        drawn on the same axes.

        Args:
            target (str, must be present as a column in the dataset),
            category (str, must be present as a column in the dataset),
            styling_params (dict): forwarded to ``customize_plot``.
            bins (int, default 30): not referenced in the lines visible here;
                presumably passed on to ``ax.hist`` (TODO confirm).


        Returns:
            None
        """
        uniques = self.ds.get_unique_column_values(category)
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
        for val in uniques:
            # Values of ``target`` for the rows whose ``category`` equals val.
            anx_score = self.df[self.df[category] == val][target]
            # Equal weights of 1/len(group), so each group's weights sum to 1
            # (relative frequencies rather than raw counts).
            anx_score_weights = np.ones(len(anx_score)) / len(anx_score)
            # NOTE(review): the argument list of this call is cut off in the
            # visible source; the remaining arguments and the closing
            # parenthesis need to be restored from the original file.
            ax.hist(
                anx_score,
                weights=anx_score_weights,
    def plot_scatterplot(self, target1, target2, styling_params=None) -> None:
        """plot a scatterplot of one dataset column against another.

        Args:
            target1 (str): must be present as a column in the dataset;
                plotted on the x-axis.
            target2 (str): must be present as a column in the dataset;
                plotted on the y-axis.
            styling_params (dict, optional): forwarded to ``customize_plot``;
                ``None`` (the default) means no styling.

        Returns:
            None
        """
        # A fresh dict per call instead of a shared mutable default argument.
        if styling_params is None:
            styling_params = {}
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
        ax.scatter(self.df[target1], self.df[target2])

    def distribution_plot(self, target: str):
        """
        plot how many rows fall into each distinct value of a column.

        Draws a horizontal bar chart on the current matplotlib axes and
        prints the per-value counts, largest first.

        Args:
            target (str): column of the dataset to group by.

        Returns:
            None
        """
        counts = self.df.groupby(target).size()
        # Work on the current axes, which is what the pyplot-level helpers
        # target as well.
        ax = plt.gca()
        ax.barh(counts.index, counts.values)
        print(counts.sort_values(ascending=False))
        ax.set_xlabel("Size")
        ax.set_ylabel(target)
        ax.set_title(f"Distribution of {target}")