Skip to content
Snippets Groups Projects
Plotter.py 8.37 KiB
Newer Older
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from .Dataset import Dataset
import logging
class Plotter:
    """Draw exploratory matplotlib/seaborn plots of a ``Dataset``'s dataframe.

    Every plotting method validates its column arguments, opens a new figure
    with ``plt.subplots()``, draws onto that figure's axes and finally applies
    ``styling_params`` through :meth:`customize_plot`. The methods return
    ``None`` and never call ``plt.show()``; displaying or saving the figure is
    left to the caller.
    """

    def __init__(self, dataset: "Dataset"):
        """Store ``dataset`` and the dataframe it exposes.

        Args:
            dataset (Dataset): data source; instances of subclasses are
                accepted too.

        Raises:
            ValueError: if ``dataset`` is not a ``Dataset`` instance.
        """
        # isinstance (rather than an exact type comparison) lets subclasses
        # of Dataset through.
        if not isinstance(dataset, Dataset):
            logging.error("dataset parameter is not of type Dataset")
            raise ValueError(f"{dataset} is not of type Dataset")

        self.ds = dataset
        self.df = dataset.get_dataframe()

    def _check_column(self, param_name, column) -> None:
        """Raise ValueError unless ``column`` names a column of ``self.df``.

        Args:
            param_name (str): the caller's parameter name, used in messages.
            column: the value the caller received for that parameter.

        Raises:
            ValueError: if ``column`` is not a string or is not a column.
        """
        if not isinstance(column, str):
            message = f"parameter {param_name} should be a string."
            logging.error(message)
            raise ValueError(message)
        if column not in self.df.columns:
            message = f"parameter {param_name} cannot be found in the dataset."
            logging.error(message)
            raise ValueError(message)

    @staticmethod
    def _check_styling_params(styling_params) -> dict:
        """Return ``styling_params`` as a dict; ``None`` means no styling.

        Raises:
            ValueError: if ``styling_params`` is neither ``None`` nor a dict.
        """
        if styling_params is None:
            return {}
        if not isinstance(styling_params, dict):
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")
        return styling_params

    def customize_plot(self, fig, ax, styling_params) -> None:
        """Apply user-supplied styling options to a plot.

        Only the ``"title"`` key is honoured at the moment: a truthy value
        becomes the axes title, anything else leaves the title untouched.

        Args:
            fig (matplotlib.figure.Figure): figure owning ``ax``; currently
                unused.
            ax (matplotlib.axes.Axes): axes to style.
            styling_params (dict or None): styling options, e.g.
                ``{"title": "Age by group"}``.

        Returns:
            None
        """
        if styling_params and styling_params.get("title"):
            ax.set_title(styling_params["title"])

    def distribution_plot(self, target: str, styling_params=None) -> None:
        """Plot how many rows fall into each value of ``target``.

        Draws one horizontal bar per distinct value on a new figure and prints
        the counts from largest to smallest.

        Args:
            target (str): column to group by; must be present in the dataset.
            styling_params (dict, optional): passed to :meth:`customize_plot`;
                a ``"title"`` entry replaces the default title.

        Returns:
            None

        Raises:
            ValueError: if ``target`` is not a string column of the dataset,
                or ``styling_params`` is not a dict.
        """
        self._check_column("target", target)
        styling_params = self._check_styling_params(styling_params)

        counts = self.df.groupby(target).size()
        # For categorical (string) values barh stacks bars bottom-up in the
        # given order, so ascending order puts the largest group on top.
        ascending = counts.sort_values(ascending=True)
        print(counts.sort_values(ascending=False))

        fig, ax = plt.subplots()
        ax.barh(ascending.index, ascending.values)
        ax.set_xlabel("Size")
        ax.set_ylabel(target)
        ax.set_title(f"Distribution of {target}")
        # Applied last so a user-supplied title overrides the default one.
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_bar_chart(
        self, category1, category2, styling_params=None
    ) -> None:
        """Plot a grouped bar chart of ``category2`` shares per ``category1``.

        The x axis has one group of bars per ``category1`` value and one bar
        per ``category2`` value. Bar heights are percentages normalised per
        ``category2`` value: for each ``category2`` value, its bars across all
        ``category1`` groups sum to 100.

        Args:
            category1 (str): column for the x-axis groups; must be present in
                the dataset.
            category2 (str): column for the bars within each group; must be
                present in the dataset.
            styling_params (dict, optional): passed to :meth:`customize_plot`.

        Returns:
            None

        Raises:
            ValueError: if either category is not a string column of the
                dataset, or ``styling_params`` is not a dict.
        """
        self._check_column("category1", category1)
        self._check_column("category2", category2)
        styling_params = self._check_styling_params(styling_params)

        # normalize="columns" divides each crosstab column by its total, i.e.
        # the same result as ct.apply(lambda r: r / r.sum() * 100, axis=0).
        # NOTE(review): an older comment said "percentages by row"; if row
        # shares were intended, use normalize="index" instead.
        percentages = pd.crosstab(
            self.df[category1], self.df[category2], normalize="columns"
        ) * 100

        fig, ax = plt.subplots()
        percentages.plot(kind="bar", ax=ax)
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_boxplot(
        self, target, category, styling_params=None
    ) -> None:
        """Plot one box of ``target`` values per value of ``category``.

        Args:
            target (str): column plotted on the y axis; must be present in the
                dataset.
            category (str): column whose values form the x-axis boxes; must be
                present in the dataset.
            styling_params (dict, optional): passed to :meth:`customize_plot`.

        Returns:
            None

        Raises:
            ValueError: if ``target`` or ``category`` is not a string column
                of the dataset, or ``styling_params`` is not a dict.
        """
        self._check_column("target", target)
        self._check_column("category", category)
        styling_params = self._check_styling_params(styling_params)

        fig, ax = plt.subplots()
        # NOTE(review): seaborn >= 0.13 warns when `palette` is given without
        # `hue`; consider hue=category, legend=False once 0.13 is required.
        sns.boxplot(
            x=category, y=target, data=self.df, palette="rainbow", ax=ax
        )
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_histplot(
        self, target, category, styling_params=None, bins=30
    ) -> None:
        """Overlay one normalised histogram of ``target`` per ``category`` value.

        Each group's samples are weighted by ``1 / group size`` so each
        group's weights sum to 1, which keeps groups of very different sizes
        comparable on the same axes.

        Args:
            target (str): column to histogram; must be present in the dataset.
            category (str): column whose values define the groups; must be
                present in the dataset.
            styling_params (dict, optional): passed to :meth:`customize_plot`.
            bins (int or sequence): forwarded to ``Axes.hist``.

        Returns:
            None

        Raises:
            ValueError: if ``target`` or ``category`` is not a string column
                of the dataset, or ``styling_params`` is not a dict.
        """
        self._check_column("target", target)
        self._check_column("category", category)
        styling_params = self._check_styling_params(styling_params)

        fig, ax = plt.subplots()
        # NOTE(review): presumably returns the distinct values of the column
        # (the helper lives on Dataset) -- confirm.
        for value in self.ds.get_unique_column_values(category):
            group = self.df.loc[self.df[category] == value, target]
            if group.empty:
                # e.g. a NaN category value, which never compares equal.
                continue
            # NOTE(review): the remaining arguments of this call were missing
            # from the file; bins/alpha/label and the legend below are a
            # reconstruction -- confirm against version control.
            ax.hist(
                group,
                weights=np.ones(len(group)) / len(group),
                bins=bins,
                alpha=0.5,
                label=str(value),
            )
        ax.legend()
        self.customize_plot(fig, ax, styling_params)

    def plot_scatterplot(self, target1, target2, styling_params=None) -> None:
        """Scatter ``target2`` (y axis) against ``target1`` (x axis).

        Args:
            target1 (str): column for the x axis; must be present in the
                dataset.
            target2 (str): column for the y axis; must be present in the
                dataset.
            styling_params (dict, optional): passed to :meth:`customize_plot`.

        Returns:
            None

        Raises:
            ValueError: if either target is not a string column of the
                dataset, or ``styling_params`` is not a dict.
        """
        self._check_column("target1", target1)
        self._check_column("target2", target2)
        styling_params = self._check_styling_params(styling_params)

        fig, ax = plt.subplots()
        ax.scatter(self.df[target1], self.df[target2])
        ax.set_xlabel(target1)
        ax.set_ylabel(target2)
        self.customize_plot(fig, ax, styling_params)