# Repository web-viewer residue (not part of the module):
# Skip to content
# Snippets Groups Projects
# test_plotter.py 1.57 KiB
# Newer Older
from pathlib import Path

from Dataset import Dataset
from Plotter import Plotter
import pandas as pd

# Sortofamudkip's avatar
# Sortofamudkip committed
import pytest

this_file_dir = Path(__file__).parent


# Sortofamudkip's avatar
# Sortofamudkip committed
@pytest.fixture
def the_plotter() -> Plotter:
    """Build a Plotter backed by the GamingStudy dataset for use in tests.

    Returns:
        A ``Plotter`` wrapping a ``Dataset`` loaded from
        ``../data/GamingStudy_data.csv``, resolved relative to the
        directory containing this test file.
    """
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    return Plotter(dataset)


def test_load_plotter():
    """Tests that the Plotter class can be loaded.

    Builds a Plotter from the GamingStudy dataset and checks that it exposes
    a pandas DataFrame as ``df`` and a Dataset as ``ds``.
    """
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    plotter = Plotter(dataset)
    # isinstance rather than an exact type() comparison, so a DataFrame or
    # Dataset subclass still counts as the expected type.
    assert isinstance(plotter.df, pd.DataFrame)
    assert isinstance(plotter.ds, Dataset)
# Sortofamudkip's avatar
# Sortofamudkip committed


def test_catch_colname_not_in_df(the_plotter: Plotter):
    """Tests that distribution_plot raises a KeyError when asked to
    plot a colname that is not in the dataset."""
    missing_colname = "GAAAD_T"
    with pytest.raises(KeyError):
        the_plotter.distribution_plot(missing_colname)


def test_catch_target_not_string(the_plotter: Plotter):
    """Tests that distribution_plot raises a ValueError when its
    target argument is not a string."""
    non_string_target = True
    with pytest.raises(ValueError):
        the_plotter.distribution_plot(non_string_target)


@pytest.mark.parametrize(
    "param",
    [True, None, "notdict", 6.4, pd, (1,), Dataset],
)
def test_catch_styling_params_not_dict(the_plotter: Plotter, param):
    """Tests that functions that take styling_params correctly
    catch non dictionaries.

    Each plotting call gets its own ``pytest.raises`` block: a ``with`` body
    stops at the first exception, so calls placed after a raising call in
    the same block would never run and would go unchecked.
    """
    with pytest.raises(ValueError):
        the_plotter.distribution_plot("GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_categorical_bar_chart("GAD_T", "GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_categorical_histplot("GAD_T", "GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_scatterplot("GAD_T", "GAD_T", param)