# Dataset.py: the Dataset class (CSV loading and preprocessing).
import numpy as np
import pandas as pd


class Dataset:
    """Load a CSV dataset and keep a preprocessed copy as a pandas DataFrame.

    Attributes:
        dataframe (pd.DataFrame): the preprocessed dataset
            (see preprocess_dataset()).
    """

    def __init__(self, dataset_filename: str) -> None:
        """Read ``dataset_filename`` and preprocess it.

        Args:
            dataset_filename (str): path to the CSV file; it is decoded
                with the "windows-1254" codec.

        Raises:
            KeyError: if a column required by preprocess_dataset() is missing.
        """
        raw_dataframe = pd.read_csv(dataset_filename, encoding="windows-1254")
        self.dataframe = self.preprocess_dataset(raw_dataframe)

    def preprocess_dataset(self, raw_dataframe: pd.DataFrame) -> pd.DataFrame:
        """Preprocess the dataframe immediately after loading it.

        Drops the "League" column and adds the derived "Anxiety_score" and
        "Is_narcissist" columns. ``raw_dataframe`` itself is left untouched,
        because ``drop`` returns a new frame that is then modified.

        Args:
            raw_dataframe (pd.DataFrame): raw dataframe as read by
                pd.read_csv(). Must contain the "League", "GAD_T", "SPIN_T",
                "SWL_T" and "Narcissism" columns.

        Returns:
            pd.DataFrame: the resulting preprocessed dataframe.

        Raises:
            KeyError: if any of the required columns is missing.
        """
        dataframe = raw_dataframe.drop(["League"], axis="columns")
        dataframe["Anxiety_score"] = self.get_combined_anxiety_score(dataframe)
        dataframe["Is_narcissist"] = self.get_is_narcissist_col(dataframe)
        # more preprocessing goes here
        return dataframe

    @staticmethod
    def _min_max_normalise(
        column: pd.Series, col_min: float, col_max: float
    ) -> pd.Series:
        """Linearly rescale ``column`` so ``col_min`` maps to 0 and ``col_max`` to 1."""
        return (column - col_min) / (col_max - col_min)

    def get_combined_anxiety_score(self, dataframe: pd.DataFrame) -> pd.Series:
        """Get the combined anxiety score, as a column.

        The score is based on the GAD, SPIN and SWL metrics. Each of the three
        columns is first min-max normalised using the ranges below. SWL is then
        flipped (1 - value), so that a higher SWL lowers the score. The mean of
        the three normalised values is returned. For inputs inside the assumed
        ranges, the result lies in [0, 1].

        Args:
            dataframe (pd.DataFrame): must contain the "GAD_T", "SPIN_T" and
                "SWL_T" columns.

        Returns:
            pd.Series: the anxiety score column.
        """
        # (min, max) ranges assumed for each metric.
        gad_min, gad_max = 0, 21
        spin_min, spin_max = 0, 68
        swl_min, swl_max = 5, 35

        # Normalise by the range width (max - min), not by max alone. With a
        # non-zero minimum (SWL starts at 5), dividing by max would never
        # reach 1, so the flipped SWL term could never reach 0.
        gad_normalised = self._min_max_normalise(
            dataframe["GAD_T"], gad_min, gad_max
        )
        spin_normalised = self._min_max_normalise(
            dataframe["SPIN_T"], spin_min, spin_max
        )
        swl_flipped = 1 - self._min_max_normalise(
            dataframe["SWL_T"], swl_min, swl_max
        )
        return (gad_normalised + spin_normalised + swl_flipped) / 3

    def get_is_narcissist_col(self, dataframe: pd.DataFrame) -> np.ndarray:
        """Flag rows whose "Narcissism" value is at most 1.0.

        Args:
            dataframe (pd.DataFrame): must contain a "Narcissism" column.

        Returns:
            np.ndarray: boolean array with one entry per row. Missing (NaN)
                values compare as False.
        """
        # NOTE(review): values <= 1.0 are flagged as narcissist; confirm the
        # threshold and the direction of the scale against the dataset docs.
        return (dataframe["Narcissism"] <= 1.0).to_numpy()

    def get_dataframe(self) -> pd.DataFrame:
        """A getter function for the dataframe.

        Returns:
            pd.DataFrame: the preprocessed dataset (not a copy).
        """
        return self.dataframe

    def draw_histogram(self):
        """Placeholder: not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def get_plottable_columns(self) -> list:
        """Placeholder: not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def get_sorted_column(
        self, colname: str, ascending: bool = True
    ) -> pd.Series:
        """Returns a single column, sorted either ascending or descending.

        Args:
            colname (str): the column name; any column of the preprocessed
                dataframe (see get_dataframe().columns).
            ascending (bool, optional): Sorting order. Defaults to True.

        Returns:
            pd.Series: The sorted column, keeping its original index labels.

        Raises:
            KeyError: if ``colname`` is not a column of the dataframe.
        """
        return self.dataframe[colname].sort_values(ascending=ascending)