From 8ec765559d0603064a3f58b1325dee1bd29953aa Mon Sep 17 00:00:00 2001
From: Sortofamudkip <wishyutp0328@gmail.com>
Date: Thu, 13 Jul 2023 00:23:05 +0200
Subject: [PATCH] bool param checks

---
 src/test_dataset.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/src/test_dataset.py b/src/test_dataset.py
index 6005194..7219d72 100644
--- a/src/test_dataset.py
+++ b/src/test_dataset.py
@@ -6,6 +6,7 @@ from Dataset import Dataset
 import pandas as pd
 from pathlib import Path
 import pytest
+import numpy as np
 
 this_file_dir = Path(__file__).parent
 
@@ -68,3 +69,45 @@ def test_get_sorted_columns(the_dataset: Dataset):
     """Tests Dataset.get_sorted_column()."""
     sorted_GAD1 = the_dataset.get_sorted_column("GAD_T")
     assert sorted_GAD1.iloc[0] <= sorted_GAD1.iloc[-1]
+
+
+@pytest.mark.parametrize(
+    "param",
+    [np.array([1, 2, 3]), "123", 3, 0.1, [], np, pd, True, None],
+)
+def test_catch_non_dataframe(the_dataset: Dataset, param):
+    """Tests that functions that take pd.DataFrame correctly
+    catch incorrect input data types.
+    """
+    with pytest.raises(ValueError):
+        the_dataset.preprocess_dataset(param)
+        the_dataset.get_is_competitive_col(param)
+        the_dataset._drop_unnecessary_columns(param)
+        the_dataset.remove_nonaccepting_rows(param)
+        the_dataset.preprocess_whyplay(param)
+        the_dataset.get_combined_anxiety_score(param)
+        the_dataset.get_is_narcissist_col(param)
+
+
+@pytest.mark.parametrize(
+    "param",
+    ["true", "false", "True", "False"],
+)
+def test_catch_non_bool(the_dataset: Dataset, param):
+    """Tests that the dataframe is preprocessed correctly."""
+    dataframe = the_dataset.get_dataframe()
+    columns_set = set(dataframe.columns)
+    with pytest.raises(ValueError):
+        the_dataset.get_category_counts("GAD_T", param)
+        the_dataset.get_sorted_column("GAD_T", param)
+
+
+@pytest.mark.parametrize(
+    "param",
+    [True, False, None],
+)
+def test_bool_or_none_params(the_dataset: Dataset, param):
+    """Tests that the dataframe is preprocessed correctly."""
+    dataframe = the_dataset.get_dataframe()
+    columns_set = set(dataframe.columns)
+    the_dataset.get_category_counts("GAD_T", param)
-- 
GitLab