import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from .Dataset import Dataset
import logging


class Plotter:
    """Plotting helpers bound to a single ``Dataset``.

    Public plotting methods validate their column arguments against the
    dataset's dataframe, draw with matplotlib (seaborn for the boxplot), and
    then apply the optional ``styling_params`` through :meth:`customize_plot`.
    They return ``None``; figures are neither shown nor closed here.
    """

    def __init__(self, dataset: "Dataset"):
        """Bind the plotter to ``dataset``.

        Args:
            dataset (Dataset): provides the dataframe via ``get_dataframe()``;
                subclasses of ``Dataset`` are accepted.

        Raises:
            ValueError: if ``dataset`` is not a ``Dataset`` instance.
        """
        # isinstance (not an exact type comparison) so subclasses are accepted.
        if not isinstance(dataset, Dataset):
            logging.error("dataset parameter is not of type Dataset")
            raise ValueError(f"{dataset} is not of type Dataset")

        self.ds = dataset
        self.df = dataset.get_dataframe()

    def _check_column(self, param_name, value) -> None:
        """Raise ValueError unless ``value`` is a str naming a column of ``self.df``."""
        if not isinstance(value, str):
            logging.error("parameter %s should be a string.", param_name)
            raise ValueError(f"parameter {param_name} should be a string.")
        if value not in self.df.columns:
            logging.error(
                "parameter %s cannot be found in the dataset.", param_name
            )
            raise ValueError(
                f"parameter {param_name} cannot be found in the dataset."
            )

    @staticmethod
    def _check_styling_params(styling_params) -> dict:
        """Return ``styling_params`` as a dict (``None`` -> ``{}``) or raise ValueError."""
        if styling_params is None:
            return {}
        if not isinstance(styling_params, dict):
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")
        return styling_params

    def customize_plot(self, fig, ax, styling_params) -> None:
        """Apply user styling to an axes that has already been drawn on.

        The plotting methods call this last, so these values take precedence
        over any title/labels set while drawing.

        Args:
            fig (matplotlib.figure.Figure): figure owning ``ax``; currently
                unused, kept for interface compatibility.
            ax (matplotlib.axes.Axes): axes to style.
            styling_params (dict | None): supported keys are ``"title"``,
                ``"xlabel"`` and ``"ylabel"``. Missing or falsy values and
                unknown keys are ignored.

        Returns:
            None
        """
        if not styling_params:
            return
        setters = (
            ("title", ax.set_title),
            ("xlabel", ax.set_xlabel),
            ("ylabel", ax.set_ylabel),
        )
        for key, setter in setters:
            value = styling_params.get(key)
            if value:
                setter(value)

    def distribution_plot(self, target: str, styling_params=None) -> None:
        """Plot a horizontal bar chart of row counts per value of ``target``.

        Bars are handed to matplotlib in ascending order of count, and the
        counts are printed to stdout in descending order. Unlike the other
        methods, this draws on the *current* axes (``plt.gca()``) rather than
        a new figure, preserving this method's existing behaviour.

        Args:
            target (str): column to group by; must exist in the dataset.
            styling_params (dict | None): see :meth:`customize_plot`.

        Raises:
            ValueError: if ``target`` is not a string column of the dataset or
                ``styling_params`` is not a dict.

        Returns:
            None
        """
        self._check_column("target", target)
        styling_params = self._check_styling_params(styling_params)

        counts = self.df.groupby(target).size()
        ascending = counts.sort_values(ascending=True)

        ax = plt.gca()
        ax.barh(ascending.index, ascending.values)
        print(counts.sort_values(ascending=False))
        ax.set_xlabel("Size")
        ax.set_ylabel(target)
        ax.set_title(f"Distribution of {target}")
        self.customize_plot(ax.figure, ax, styling_params)

    def plot_categorical_bar_chart(
        self, category1, category2, styling_params=None, *, normalize="columns"
    ) -> None:
        """Plot a grouped bar chart of the percentage crosstab of two columns.

        ``category1`` values index the crosstab rows (x-axis groups) and
        ``category2`` values its columns (bars within each group).

        Args:
            category1 (str): row column; must exist in the dataset.
            category2 (str): column whose values become the bars; must exist.
            styling_params (dict | None): see :meth:`customize_plot`.
            normalize (str): ``"columns"`` (default, the historical behaviour)
                makes each ``category2`` value sum to 100 across the
                ``category1`` groups; ``"index"`` makes each ``category1``
                group sum to 100; ``"all"`` divides by the grand total.

        Raises:
            ValueError: on invalid column names, ``styling_params`` or
                ``normalize``.

        Returns:
            None
        """
        self._check_column("category1", category1)
        self._check_column("category2", category2)
        styling_params = self._check_styling_params(styling_params)

        percentages = self._percentage_crosstab(category1, category2, normalize)
        fig, ax = plt.subplots()
        percentages.plot(kind="bar", ax=ax)
        self.customize_plot(fig, ax, styling_params)

    def _percentage_crosstab(self, category1, category2, normalize="columns"):
        """Return the crosstab of two columns, normalised and scaled to percent."""
        if normalize not in ("index", "columns", "all"):
            message = (
                "parameter normalize should be 'index', 'columns' or 'all'."
            )
            logging.error(message)
            raise ValueError(message)
        table = pd.crosstab(
            self.df[category1], self.df[category2], normalize=normalize
        )
        return table * 100

    def plot_categorical_boxplot(
        self, target, category, styling_params=None
    ) -> None:
        """Plot boxplots of ``target`` for each value of ``category``.

        Args:
            target (str): numeric column for the y-axis; must exist.
            category (str): grouping column for the x-axis; must exist.
            styling_params (dict | None): see :meth:`customize_plot`.

        Raises:
            ValueError: on invalid column names or ``styling_params``.

        Returns:
            None
        """
        self._check_column("target", target)
        self._check_column("category", category)
        styling_params = self._check_styling_params(styling_params)

        fig, ax = plt.subplots()
        # Pass ax explicitly instead of relying on pyplot's "current axes".
        sns.boxplot(
            x=category, y=target, data=self.df, palette="rainbow", ax=ax
        )
        self.customize_plot(fig, ax, styling_params)

    def plot_categorical_histplot(
        self, target, category, styling_params=None, bins=30
    ) -> None:
        """Overlay one normalised histogram of ``target`` per ``category`` value.

        Each group's weights are ``1 / group_size``, so every histogram's bar
        heights sum to 1 and groups of different sizes are comparable. A
        legend labels each group by its category value.

        Args:
            target (str): column whose values are binned; must exist.
            category (str): grouping column; must exist.
            styling_params (dict | None): see :meth:`customize_plot`.
            bins: forwarded to ``Axes.hist``.

        Raises:
            ValueError: on invalid column names or ``styling_params``.

        Returns:
            None
        """
        self._check_column("target", target)
        self._check_column("category", category)
        styling_params = self._check_styling_params(styling_params)

        # NOTE(review): assumes this returns the distinct values of `category`.
        group_values = self.ds.get_unique_column_values(category)
        fig, ax = plt.subplots()
        drew_any = False
        for value in group_values:
            group = self.df.loc[self.df[category] == value, target]
            if group.empty:
                # No row matched (e.g. NaN, which never compares equal).
                continue
            weights = np.full(len(group), 1.0 / len(group))
            ax.hist(
                group,
                weights=weights,
                bins=bins,
                alpha=0.5,
                label=str(value),
            )
            drew_any = True
        if drew_any:
            ax.legend(title=category)
        self.customize_plot(fig, ax, styling_params)

    def plot_scatterplot(self, target1, target2, styling_params=None) -> None:
        """Scatter ``target1`` (x-axis) against ``target2`` (y-axis).

        Args:
            target1 (str): x-axis column; must exist in the dataset.
            target2 (str): y-axis column; must exist in the dataset.
            styling_params (dict | None): see :meth:`customize_plot`.

        Raises:
            ValueError: on invalid column names or ``styling_params``.

        Returns:
            None
        """
        self._check_column("target1", target1)
        self._check_column("target2", target2)
        styling_params = self._check_styling_params(styling_params)

        fig, ax = plt.subplots()
        ax.scatter(self.df[target1], self.df[target2])
        ax.set_xlabel(target1)
        ax.set_ylabel(target2)
        self.customize_plot(fig, ax, styling_params)