"""
Unit tests for the Dataset class defined in Dataset.py.
"""

from Dataset import Dataset
import pandas as pd
from pathlib import Path
import pytest
import numpy as np

# Directory containing this test module; the CSV paths below are built from it.
this_file_dir: Path = Path(__file__).parent


@pytest.fixture
def the_dataset() -> Dataset:
    """Provide a Dataset built from the GamingStudy CSV for each test."""
    return Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))


def test_load_Dataset_class():
    """Tests that a Dataset can be built from the CSV and exposes a DataFrame."""
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    # isinstance is the idiomatic type check (also accepts subclasses).
    assert isinstance(dataset, Dataset)
    assert isinstance(dataset.dataframe, pd.DataFrame)


def test_incorrectly_load_Dataset_class():
    """Tests that invalid constructor arguments raise the expected errors.

    Each case has its own ``pytest.raises`` block so that every case is
    actually executed.
    """
    with pytest.raises(ValueError):
        Dataset(1234)  # not a string
    with pytest.raises(OSError):
        Dataset("aaa.bcd")  # doesn't end with .csv
    with pytest.raises(FileNotFoundError):
        # Wrong file location: "./data" instead of "../data".
        Dataset(str(this_file_dir / "./data/GamingStudy_data.csv"))


def test_get_dataframe(the_dataset: Dataset):
    """Tests that Dataset.get_dataframe() returns a pandas DataFrame."""
    assert isinstance(the_dataset.get_dataframe(), pd.DataFrame)


def test_combined_anxiety_score(the_dataset: Dataset):
    """Tests Dataset.get_combined_anxiety_score().

    The combined score should be float-typed and lie within [0, 1].
    """
    scores = the_dataset.get_combined_anxiety_score(the_dataset.get_dataframe())
    assert scores.dtype == float
    assert 0 <= scores.min()
    assert scores.max() <= 1


def test_get_is_narcissist_col(the_dataset: Dataset):
    """Tests that Dataset.get_is_narcissist_col() yields a boolean column."""
    narcissist_col = the_dataset.get_is_narcissist_col(
        the_dataset.get_dataframe()
    )
    assert narcissist_col.dtype == bool


def test_preprocessed_dataframe(the_dataset: Dataset):
    """Tests that preprocessing drops 'League' and adds the derived columns."""
    columns = set(the_dataset.get_dataframe().columns)
    assert "League" not in columns
    for derived in ("Anxiety_score", "Is_narcissist"):
        assert derived in columns


def test_get_sorted_columns(the_dataset: Dataset):
    """Tests that Dataset.get_sorted_column() returns values in ascending order."""
    sorted_gad_t = the_dataset.get_sorted_column("GAD_T")
    # Comparing only the first and last elements would miss out-of-order
    # values in the middle; require the whole column to be non-decreasing.
    assert sorted_gad_t.iloc[0] <= sorted_gad_t.iloc[-1]
    assert sorted_gad_t.is_monotonic_increasing


@pytest.mark.parametrize(
    "param",
    [np.array([1, 2, 3]), "123", 3, 0.1, [], np, pd, True, None],
)
def test_catch_non_dataframe(the_dataset: Dataset, param):
    """Tests that functions that take pd.DataFrame correctly
    catch incorrect input data types.
    """
    methods = (
        the_dataset.preprocess_dataset,
        the_dataset.get_is_competitive_col,
        the_dataset._drop_unnecessary_columns,
        the_dataset.remove_nonaccepting_rows,
        the_dataset.preprocess_whyplay,
        the_dataset.get_combined_anxiety_score,
        the_dataset.get_is_narcissist_col,
    )
    for method in methods:
        # One ``pytest.raises`` per call: inside a single context manager the
        # first call that raises exits the block, so later calls never run.
        with pytest.raises(ValueError):
            method(param)


@pytest.mark.parametrize(
    "param",
    ["true", "false", "True", "False", 1, 0, -1],
)
def test_catch_non_bool(the_dataset: Dataset, param):
    """Tests that functions that take bool or None correctly
    catch incorrect input data types."""
    # Separate ``pytest.raises`` blocks so both methods are exercised; a
    # shared block would stop after the first call raised.
    with pytest.raises(ValueError):
        the_dataset.get_category_counts("GAD_T", param)
    with pytest.raises(ValueError):
        the_dataset.get_sorted_column("GAD_T", param)


@pytest.mark.parametrize(
    "param",
    [True, False, None],
)
def test_bool_or_none_params(the_dataset: Dataset, param):
    """Tests that Dataset.get_category_counts() accepts True, False and None
    without raising.
    """
    the_dataset.get_category_counts("GAD_T", param)


def test_catch_colname_not_in_df(the_dataset: Dataset):
    """Tests that functions that take colname correctly
    catch colnames not in dataset."""
    # Separate ``pytest.raises`` blocks so both methods are exercised; a
    # shared block would stop after the first call raised.
    with pytest.raises(KeyError):
        the_dataset.get_category_counts("GAAAD_T")
    with pytest.raises(KeyError):
        the_dataset.get_sorted_column("GAAAD_T")