From 0692666bcfd2e9319a01535ab22ec83d9b8be280 Mon Sep 17 00:00:00 2001
From: Sortofamudkip <wishyutp0328@gmail.com>
Date: Thu, 13 Jul 2023 14:23:58 +0200
Subject: [PATCH] plotting param tests

---
 src/test_plotter.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/src/test_plotter.py b/src/test_plotter.py
index b1b16f4..e2caf4d 100644
--- a/src/test_plotter.py
+++ b/src/test_plotter.py
@@ -4,12 +4,49 @@ from Dataset import Dataset
 from Plotter import Plotter
 import pandas as pd
 
+import pytest
+
 this_file_dir = Path(__file__).parent
 
 
+@pytest.fixture
+def the_plotter() -> Dataset:
+    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
+    plotter = Plotter(dataset)
+    return plotter
+
+
 def test_load_plotter():
     """Tests that the Plotter class can be loaded."""
     dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
     plotter = Plotter(dataset)
     assert type(plotter.df) == pd.DataFrame
     assert type(plotter.ds) == Dataset
+
+
+def test_catch_colname_not_in_df(the_plotter: Plotter):
+    """Tests that functions that take colname correctly
+    catch colnames not in dataset."""
+    with pytest.raises(KeyError):
+        the_plotter.distribution_plot("GAAAD_T")
+
+
+def test_catch_target_not_string(the_plotter: Plotter):
+    """Tests that functions that take target correctly
+    catch non strings."""
+    with pytest.raises(ValueError):
+        the_plotter.distribution_plot(True)
+
+
+@pytest.mark.parametrize(
+    "param",
+    [True, None, "notdict", 6.4, pd, (1,), Dataset],
+)
+def test_catch_styling_params_not_dict(the_plotter: Plotter, param):
+    """Tests that functions that take styling_params correctly
+    catch non dictionaries."""
+    with pytest.raises(ValueError):
+        the_plotter.distribution_plot("GAD_T", param)
+        the_plotter.plot_categorical_bar_chart("GAD_T", "GAD_T", param)
+        the_plotter.plot_categorical_histplot("GAD_T", "GAD_T", param)
+        the_plotter.plot_scatterplot("GAD_T", "GAD_T", param)
-- 
GitLab