From 8ec765559d0603064a3f58b1325dee1bd29953aa Mon Sep 17 00:00:00 2001 From: Sortofamudkip <wishyutp0328@gmail.com> Date: Thu, 13 Jul 2023 00:23:05 +0200 Subject: [PATCH] bool param checks --- src/test_dataset.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/test_dataset.py b/src/test_dataset.py index 6005194..7219d72 100644 --- a/src/test_dataset.py +++ b/src/test_dataset.py @@ -6,6 +6,7 @@ from Dataset import Dataset import pandas as pd from pathlib import Path import pytest +import numpy as np this_file_dir = Path(__file__).parent @@ -68,3 +69,45 @@ def test_get_sorted_columns(the_dataset: Dataset): """Tests Dataset.get_sorted_column().""" sorted_GAD1 = the_dataset.get_sorted_column("GAD_T") assert sorted_GAD1.iloc[0] <= sorted_GAD1.iloc[-1] + + +@pytest.mark.parametrize( + "param", + [np.array([1, 2, 3]), "123", 3, 0.1, [], np, pd, True, None], +) +def test_catch_non_dataframe(the_dataset: Dataset, param): + """Tests that functions that take pd.DataFrame correctly + catch incorrect input data types. + """ + with pytest.raises(ValueError): + the_dataset.preprocess_dataset(param) + the_dataset.get_is_competitive_col(param) + the_dataset._drop_unnecessary_columns(param) + the_dataset.remove_nonaccepting_rows(param) + the_dataset.preprocess_whyplay(param) + the_dataset.get_combined_anxiety_score(param) + the_dataset.get_is_narcissist_col(param) + + +@pytest.mark.parametrize( + "param", + ["true", "false", "True", "False"], +) +def test_catch_non_bool(the_dataset: Dataset, param): + """Tests that the dataframe is preprocessed correctly.""" + dataframe = the_dataset.get_dataframe() + columns_set = set(dataframe.columns) + with pytest.raises(ValueError): + the_dataset.get_category_counts("GAD_T", param) + the_dataset.get_sorted_column("GAD_T", param) + + +@pytest.mark.parametrize( + "param", + [True, False, None], +) +def test_bool_or_none_params(the_dataset: Dataset, param): + """Tests that the dataframe is preprocessed correctly.""" + dataframe = the_dataset.get_dataframe() + columns_set = set(dataframe.columns) + the_dataset.get_category_counts("GAD_T", param) -- GitLab