import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import seaborn as sns
from .Dataset import Dataset

class Plotter:
    """Draw exploratory plots for the dataframe held by a ``Dataset``.

    Every ``plot_*`` method creates a new figure with ``plt.subplots()``,
    draws on its single axes, applies optional styling and returns
    ``(fig, ax)`` so callers can customise further or save the figure.
    """

    def __init__(self, dataset: Dataset):
        """Wrap *dataset* for plotting.

        Args:
            dataset: Data source. Its dataframe is fetched once here via
                ``get_dataframe()`` and reused by every plot method.
        """
        self.ds = dataset
        self.df = dataset.get_dataframe()

    def customize_plot(self, fig, ax, styling_params=None):
        """Apply optional text styling to *ax*.

        Args:
            fig: Figure owning *ax*. Currently unused; kept for API
                compatibility.
            ax: Matplotlib axes to style.
            styling_params: Optional mapping. Recognised keys are
                ``"title"``, ``"xlabel"`` and ``"ylabel"``; a key that is
                missing or has a falsy value is skipped, other keys are ignored.
        """
        if not styling_params:
            return
        setters = {
            "title": ax.set_title,
            "xlabel": ax.set_xlabel,
            "ylabel": ax.set_ylabel,
        }
        for key, setter in setters.items():
            value = styling_params.get(key)
            if value:
                setter(value)

    def plot_categorical_bar_chart(self, category1, category2, styling_params=None,
                                   normalize="columns"):
        """Bar chart of the percentage cross-tabulation of two columns.

        Args:
            category1: Column whose values form the groups on the x-axis.
            category2: Column whose values form one bar within each group.
            styling_params: Optional styling mapping, see ``customize_plot``.
            normalize: Forwarded to ``pd.crosstab``. ``"columns"`` (default)
                makes the bars of each ``category2`` value sum to 100 across
                all ``category1`` groups; ``"index"`` makes the bars within
                each ``category1`` group sum to 100; ``"all"`` divides every
                cell by the grand total.

        Returns:
            ``(fig, ax)`` of the new figure.
        """
        ct_percent = pd.crosstab(
            self.df[category1], self.df[category2], normalize=normalize
        ) * 100
        fig, ax = plt.subplots()
        ct_percent.plot(kind="bar", ax=ax)
        # Style after plotting so axis labels set by pandas don't overwrite ours.
        self.customize_plot(fig, ax, styling_params)
        return fig, ax

    def plot_categorical_boxplot(self, target, category, styling_params=None):
        """Box plot of *target* for each value of *category*.

        Args:
            target: Column plotted on the y-axis.
            category: Column whose values form one box each on the x-axis.
            styling_params: Optional styling mapping, see ``customize_plot``.

        Returns:
            ``(fig, ax)`` of the new figure.
        """
        fig, ax = plt.subplots()
        # Pass ax explicitly rather than relying on seaborn's "current axes".
        # NOTE(review): seaborn >= 0.13 warns when `palette` is given without
        # `hue`; consider hue=category, legend=False once that version is required.
        sns.boxplot(x=category, y=target, data=self.df, palette="rainbow", ax=ax)
        self.customize_plot(fig, ax, styling_params)
        return fig, ax

    def plot_categorical_histplot(self, target, category, styling_params=None, bins=30):
        """Overlay normalised histograms of *target*, one per value of *category*.

        Each group's bar heights are the fraction of that group's non-missing
        values falling into the bin, so groups of different sizes are
        comparable. All groups share one set of bin edges computed from the
        whole (non-missing) *target* column.

        Args:
            target: Numeric column to histogram.
            category: Column whose distinct values define the groups.
            styling_params: Optional styling mapping, see ``customize_plot``.
            bins: Anything ``np.histogram_bin_edges`` accepts: a bin count,
                explicit edges, or a strategy name such as ``"auto"``.

        Returns:
            ``(fig, ax)`` of the new figure.
        """
        uniques = self.ds.get_unique_column_values(category)
        # Shared edges: binning each group separately would give every overlaid
        # histogram different bins, making them visually incomparable.
        edges = np.histogram_bin_edges(self.df[target].dropna(), bins=bins)
        fig, ax = plt.subplots()
        for val in uniques:
            values = self.df.loc[self.df[category] == val, target].dropna()
            if values.empty:
                continue  # nothing to draw; also avoids dividing by zero below
            weights = np.full(len(values), 1.0 / len(values))
            ax.hist(values, weights=weights, bins=edges, alpha=0.5, label=str(val))
        # Only add a legend when at least one group was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(title=str(category))
        self.customize_plot(fig, ax, styling_params)
        return fig, ax