"""Tests for the ``Plotter`` class."""

from pathlib import Path

import pandas as pd
import pytest

from Dataset import Dataset
from Plotter import Plotter

this_file_dir = Path(__file__).parent
# Single source of truth for the study CSV so the fixture and tests agree.
DATA_PATH = this_file_dir / "../data/GamingStudy_data.csv"


@pytest.fixture
def the_plotter() -> Plotter:
    """Return a ``Plotter`` built over the gaming-study dataset."""
    dataset = Dataset(str(DATA_PATH))
    return Plotter(dataset)


def test_load_plotter():
    """Tests that the Plotter class can be loaded."""
    dataset = Dataset(str(DATA_PATH))
    plotter = Plotter(dataset)
    assert isinstance(plotter.df, pd.DataFrame)
    assert isinstance(plotter.ds, Dataset)


def test_catch_colname_not_in_df(the_plotter: Plotter):
    """Tests that distribution_plot raises KeyError for a colname not in the dataset."""
    # "GAAAD_T" is a deliberate misspelling of the "GAD_T" column used below;
    # it is assumed not to exist in the dataset.
    with pytest.raises(KeyError):
        the_plotter.distribution_plot("GAAAD_T")


def test_catch_target_not_string(the_plotter: Plotter):
    """Tests that distribution_plot raises ValueError when target is not a string."""
    with pytest.raises(ValueError):
        the_plotter.distribution_plot(True)


@pytest.mark.parametrize(
    "param",
    [True, None, "notdict", 6.4, pd, (1,), Dataset],
)
def test_catch_styling_params_not_dict(the_plotter: Plotter, param):
    """Tests that functions that take styling_params correctly catch non dictionaries."""
    # Each call needs its own ``pytest.raises`` context. The first exception
    # exits a shared ``with`` block, so any later calls in it would never run.
    with pytest.raises(ValueError):
        the_plotter.distribution_plot("GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_categorical_bar_chart("GAD_T", "GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_categorical_histplot("GAD_T", "GAD_T", param)
    with pytest.raises(ValueError):
        the_plotter.plot_scatterplot("GAD_T", "GAD_T", param)