diff --git a/src/test_plotter.py b/src/test_plotter.py index b1b16f4ab25c700f822622bdc9c15fe049cd8320..e2caf4dd1b57fef7802018e25e7205234d01b7c2 100644 --- a/src/test_plotter.py +++ b/src/test_plotter.py @@ -4,12 +4,49 @@ from Dataset import Dataset from Plotter import Plotter import pandas as pd +import pytest + this_file_dir = Path(__file__).parent +@pytest.fixture +def the_plotter() -> Dataset: + dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv")) + plotter = Plotter(dataset) + return plotter + + def test_load_plotter(): """Tests that the Plotter class can be loaded.""" dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv")) plotter = Plotter(dataset) assert type(plotter.df) == pd.DataFrame assert type(plotter.ds) == Dataset + + +def test_catch_colname_not_in_df(the_plotter: Plotter): + """Tests that functions that take colname correctly + catch colnames not in dataset.""" + with pytest.raises(KeyError): + the_plotter.distribution_plot("GAAAD_T") + + +def test_catch_target_not_string(the_plotter: Plotter): + """Tests that functions that take target correctly + catch non strings.""" + with pytest.raises(ValueError): + the_plotter.distribution_plot(True) + + +@pytest.mark.parametrize( + "param", + [True, None, "notdict", 6.4, pd, (1,), Dataset], +) +def test_catch_styling_params_not_dict(the_plotter: Plotter, param): + """Tests that functions that take styling_params correctly + catch non dictionaries.""" + with pytest.raises(ValueError): + the_plotter.distribution_plot("GAD_T", param) + the_plotter.plot_categorical_bar_chart("GAD_T", "GAD_T", param) + the_plotter.plot_categorical_histplot("GAD_T", "GAD_T", param) + the_plotter.plot_scatterplot("GAD_T", "GAD_T", param)