Commit 030b099c authored by Sortofamudkip

Merge branch '18-write-tests'

parents d632d964 cda0fa31
import numpy as np
import pandas as pd
import logging


class Dataset:
@@ -11,8 +12,23 @@ class Dataset:
            dataset_filename (str): the path of the dataset,
                relative to the location of the file calling this function.
        """
        if type(dataset_filename) != str:
            logging.error("parameter `dataset_filename` is not a string")
            raise ValueError(f"{dataset_filename} is not a string")
        if not dataset_filename.endswith(".csv"):
            logging.error("dataset filename should be CSV")
            raise OSError(f"{dataset_filename} is not a CSV file.")
        try:
            raw_dataframe = pd.read_csv(
                dataset_filename, encoding="windows-1254"
            )
            logging.info("Dataframe successfully loaded")
        except FileNotFoundError as e:
            logging.error("CSV file not found")
            raise e
        self.dataframe = self.preprocess_dataset(raw_dataframe)
        logging.info("Dataset class successfully initialised")

    def preprocess_dataset(self, raw_dataframe: pd.DataFrame) -> pd.DataFrame:
        """preprocess dataframe immediately after loading it.
@@ -25,6 +41,12 @@ class Dataset:
        Returns:
            pd.DataFrame: resulting preprocessed dataframe.
        """
        if type(raw_dataframe) != pd.DataFrame:
            logging.error(
                "parameter `raw_dataframe` is not a pandas DataFrame"
            )
            raise ValueError(f"{raw_dataframe} is not a pandas DataFrame")
        dataframe = self._drop_unnecessary_columns(
            raw_dataframe
        )  # for convenience
@@ -39,6 +61,10 @@ class Dataset:
        return dataframe

    def get_is_competitive_col(self, dataframe: pd.DataFrame):
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        is_competitive_col = np.zeros(shape=len(dataframe))
        is_competitive_col[
            (dataframe["whyplay"] == "improving")
@@ -64,6 +90,10 @@ class Dataset:
        Returns:
            pd.DataFrame: the dataframe.
        """
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        rows_to_drop = (
            [
                "League",
@@ -92,6 +122,10 @@ class Dataset:
        Returns:
            pd.DataFrame: the dataframe.
        """
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        # drop rows where users did not consent to having their data used
        dataframe = dataframe.drop(
            dataframe[dataframe["accept"] != "Accept"].index,
@@ -108,6 +142,10 @@ class Dataset:
        Returns:
            pd.Series: the Is_competitive column.
        """
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        dataframe["whyplay"] = dataframe["whyplay"].str.lower()
        most_common_whyplay_reasons = list(
            dataframe.groupby("whyplay")
@@ -138,6 +176,10 @@ class Dataset:
        Returns:
            pd.Series: the anxiety score column.
        """
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        gad_max = 21
        gad_min = 0
        gad_normalised = (dataframe["GAD_T"] - gad_min) / gad_max
@@ -161,6 +203,10 @@ class Dataset:
        Returns:
            pd.Series: the boolean narcissist column.
        """
        if type(dataframe) != pd.DataFrame:
            logging.error("parameter `dataframe` is not a pandas DataFrame")
            raise ValueError(f"{dataframe} is not a pandas DataFrame")
        return np.where(dataframe["Narcissism"] <= 1.0, True, False)

    def get_dataframe(self) -> pd.DataFrame:
@@ -183,6 +229,16 @@ class Dataset:
        Returns:
            pd.Series: The sorted column.
        """
        if type(colname) != str:
            logging.error("parameter `colname` is not a string")
            raise ValueError(f"{colname} is not a string")
        if colname not in self.dataframe.columns:
            logging.error("column requested not in dataframe")
            raise KeyError(f"{colname} is not a column in dataframe")
        if not (ascending is None or type(ascending) is bool):
            logging.error("parameter `ascending` is not a bool or None")
            raise ValueError(f"{ascending} is not a bool or None")
        return self.dataframe[colname].sort_values(ascending=ascending)

    def get_unique_column_values(self, colname: str):
@@ -195,6 +251,10 @@ class Dataset:
            string array: an array of strings containing the unique values
                present in the column
        """
        if type(colname) != str:
            logging.error("parameter `colname` is not a string")
            raise ValueError(f"{colname} is not a string")
        return self.dataframe[colname].explode().unique()

    def get_category_counts(
@@ -210,6 +270,17 @@ class Dataset:
        Returns:
            pd.Series: the counted categories.
        """
        if type(colname) != str:
            logging.error("parameter `colname` is not a string")
            raise ValueError(f"{colname} is not a string")
        if colname not in self.dataframe.columns:
            logging.error("column requested not in dataframe")
            raise KeyError(f"{colname} is not a column in dataframe")
        if not (ascending is None or type(ascending) is bool):
            logging.error("parameter `ascending` is not a bool or None")
            raise ValueError(f"{ascending} is not a bool or None")
        grouped_size = self.dataframe.groupby(colname).size()
        return (
            grouped_size
...
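
For context, a minimal usage sketch (not part of the commit) of the validation added to Dataset above; the filenames other than the GamingStudy_data.csv path used by the tests below are purely illustrative:

    import logging
    from Dataset import Dataset

    logging.basicConfig(level=logging.INFO)

    # A valid argument: a string ending in ".csv" that points to an existing file.
    dataset = Dataset("../data/GamingStudy_data.csv")

    # Each new guard logs an error and then raises a specific exception:
    # Dataset(1234)          -> ValueError        (not a string)
    # Dataset("notes.txt")   -> OSError           (does not end in ".csv")
    # Dataset("missing.csv") -> FileNotFoundError (re-raised from pd.read_csv)
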
@@ -3,10 +3,15 @@ import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from .Dataset import Dataset
import logging


class Plotter:
    def __init__(self, dataset: Dataset):
        if type(dataset) != Dataset:
            logging.error("dataset parameter is not of type Dataset")
            raise ValueError(f"{dataset} is not of type Dataset")
        self.ds = dataset
        self.df = dataset.get_dataframe()
@@ -25,6 +30,46 @@ class Plotter:
        if styling_params.get("title"):
            ax.set_title(styling_params["title"])

    def distribution_plot(self, target, styling_params={}) -> None:
        """plot a distribution plot.

        Args:
            target (str, must be present as a column in the dataset),
            styling_params (dict)

        Returns:
            None
        """
        # implementing sensible logging and error catching
        if type(target) != str:
            logging.error("parameter target should be a string.")
            raise ValueError("parameter target should be a string.")
        if not (target in self.df.columns):
            logging.error("parameter target cannot be found in the dataset.")
            raise ValueError(
                "parameter target cannot be found in the dataset."
            )
        if type(styling_params) != dict:
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")

        # plotting the plot
        grouped_data = self.df.groupby(target).size()
        plt.barh(grouped_data.index, grouped_data.values)
        print(
            str(grouped_data),
            str(grouped_data.index),
            str(grouped_data.values),
        )
        plt.xlabel("Size")
        plt.ylabel(target)
        plt.title(f"Distribution of {target}")

    def plot_categorical_bar_chart(
        self, category1, category2, styling_params={}
    ) -> None:
@@ -39,6 +84,36 @@ class Plotter:
        Returns:
            None
        """
        # implementing sensible logging and error catching
        if type(category1) != str:
            logging.error("parameter category1 should be a string.")
            raise ValueError("parameter category1 should be a string.")
        if not (category1 in self.df.columns):
            logging.error(
                "parameter category1 cannot be found in the dataset."
            )
            raise ValueError(
                "parameter category1 cannot be found in the dataset."
            )
        if type(category2) != str:
            logging.error("parameter category2 should be a string.")
            raise ValueError("parameter category2 should be a string.")
        if not (category2 in self.df.columns):
            logging.error(
                "parameter category2 cannot be found in the dataset."
            )
            raise ValueError(
                "parameter category2 cannot be found in the dataset."
            )
        if type(styling_params) != dict:
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")

        # plotting the plot
        ct = pd.crosstab(self.df[category1], self.df[category2])
        # Normalise each column of the crosstab to percentages (axis=0)
        ct_percent = ct.apply(lambda r: r / r.sum() * 100, axis=0)
@@ -79,6 +154,33 @@ class Plotter:
        Returns:
            None
        """
        # implementing sensible logging and error catching
        if type(target) != str:
            logging.error("parameter target should be a string.")
            raise ValueError("parameter target should be a string.")
        if not (target in self.df.columns):
            logging.error("parameter target cannot be found in the dataset.")
            raise ValueError(
                "parameter target cannot be found in the dataset."
            )
        if type(category) != str:
            logging.error("parameter category should be a string.")
            raise ValueError("parameter category should be a string.")
        if not (category in self.df.columns):
            logging.error("parameter category cannot be found in the dataset.")
            raise ValueError(
                "parameter category cannot be found in the dataset."
            )
        if type(styling_params) != dict:
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")

        # plotting the plot
        uniques = self.ds.get_unique_column_values(category)
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
@@ -104,6 +206,33 @@ class Plotter:
        Returns:
            None
        """
        # implementing sensible logging and error catching
        if type(target1) != str:
            logging.error("parameter target1 should be a string.")
            raise ValueError("parameter target1 should be a string.")
        if not (target1 in self.df.columns):
            logging.error("parameter target1 cannot be found in the dataset.")
            raise ValueError(
                "parameter target1 cannot be found in the dataset."
            )
        if type(target2) != str:
            logging.error("parameter target2 should be a string.")
            raise ValueError("parameter target2 should be a string.")
        if not (target2 in self.df.columns):
            logging.error("parameter target2 cannot be found in the dataset.")
            raise ValueError(
                "parameter target2 cannot be found in the dataset."
            )
        if type(styling_params) != dict:
            logging.error("parameter styling params should be a dict.")
            raise ValueError("parameter styling params should be a dict.")

        # plotting the plot
        fig, ax = plt.subplots()
        self.customize_plot(fig, ax, styling_params)
        ax.scatter(self.df[target1], self.df[target2])
@@ -125,4 +254,3 @@ class Plotter:
        plt.xlabel("Size")
        plt.ylabel(target)
        plt.title(f"Distribution of {target}")
"""
This test file tests the Dataset class in Dataset.py.
"""
from Dataset import Dataset
import pandas as pd
from pathlib import Path
import pytest
import numpy as np
this_file_dir = Path(__file__).parent

@pytest.fixture
def the_dataset() -> Dataset:
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    return dataset

def test_load_Dataset_class():
    """Tests if the dataset is successfully loaded."""
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    assert type(dataset) == Dataset
    assert type(dataset.dataframe) == pd.DataFrame

def test_incorrectly_load_Dataset_class():
    with pytest.raises(ValueError):
        dataset = Dataset(1234)  # not a string
    with pytest.raises(OSError):
        dataset = Dataset("aaa.bcd")  # doesn't end with .csv
    with pytest.raises(FileNotFoundError):
        dataset = Dataset(
            str(this_file_dir / "./data/GamingStudy_data.csv")
        )  # wrong file location

def test_get_dataframe(the_dataset: Dataset):
    """Tests Dataset.get_dataframe()."""
    assert type(the_dataset.get_dataframe()) == pd.DataFrame

def test_combined_anxiety_score(the_dataset: Dataset):
    """Tests Dataset.get_combined_anxiety_score()."""
    dataframe = the_dataset.get_dataframe()
    anxiety_scores = the_dataset.get_combined_anxiety_score(dataframe)
    assert anxiety_scores.dtype == float
    assert anxiety_scores.min() >= 0
    assert anxiety_scores.max() <= 1

def test_get_is_narcissist_col(the_dataset: Dataset):
    """Tests Dataset.get_is_narcissist_col()."""
    dataframe = the_dataset.get_dataframe()
    is_narcissist_row = the_dataset.get_is_narcissist_col(dataframe)
    assert is_narcissist_row.dtype == bool

def test_preprocessed_dataframe(the_dataset: Dataset):
    """Tests that the dataframe is preprocessed correctly."""
    dataframe = the_dataset.get_dataframe()
    columns_set = set(dataframe.columns)
    assert "League" not in columns_set
    assert "Anxiety_score" in columns_set
    assert "Is_narcissist" in columns_set

def test_get_sorted_columns(the_dataset: Dataset):
    """Tests Dataset.get_sorted_column()."""
    sorted_GAD1 = the_dataset.get_sorted_column("GAD_T")
    assert sorted_GAD1.iloc[0] <= sorted_GAD1.iloc[-1]

@pytest.mark.parametrize(
    "param",
    [np.array([1, 2, 3]), "123", 3, 0.1, [], np, pd, True, None],
)
def test_catch_non_dataframe(the_dataset: Dataset, param):
    """Tests that functions that take pd.DataFrame correctly
    catch incorrect input data types.
    """
    # Each call gets its own pytest.raises block; inside a single block,
    # only the first raising call would ever execute.
    for method in (
        the_dataset.preprocess_dataset,
        the_dataset.get_is_competitive_col,
        the_dataset._drop_unnecessary_columns,
        the_dataset.remove_nonaccepting_rows,
        the_dataset.preprocess_whyplay,
        the_dataset.get_combined_anxiety_score,
        the_dataset.get_is_narcissist_col,
    ):
        with pytest.raises(ValueError):
            method(param)

@pytest.mark.parametrize(
    "param",
    ["true", "false", "True", "False", 1, 0, -1],
)
def test_catch_non_bool(the_dataset: Dataset, param):
    """Tests that functions that take bool or None correctly
    catch incorrect input data types."""
    dataframe = the_dataset.get_dataframe()
    columns_set = set(dataframe.columns)
    # Separate pytest.raises blocks so both calls are actually exercised.
    with pytest.raises(ValueError):
        the_dataset.get_category_counts("GAD_T", param)
    with pytest.raises(ValueError):
        the_dataset.get_sorted_column("GAD_T", param)

@pytest.mark.parametrize(
    "param",
    [True, False, None],
)
def test_bool_or_none_params(the_dataset: Dataset, param):
    """Tests that functions that take bool or None correctly
    work as intended.
    """
    dataframe = the_dataset.get_dataframe()
    columns_set = set(dataframe.columns)
    the_dataset.get_category_counts("GAD_T", param)

def test_catch_colname_not_in_df(the_dataset: Dataset):
    """Tests that functions that take colname correctly
    catch colnames not in dataset."""
    # Separate pytest.raises blocks so both calls are actually exercised.
    with pytest.raises(KeyError):
        the_dataset.get_category_counts("GAAAD_T")
    with pytest.raises(KeyError):
        the_dataset.get_sorted_column("GAAAD_T")
from pathlib import Path
from Dataset import Dataset
from Plotter import Plotter
import pandas as pd
this_file_dir = Path(__file__).parent

def test_load_plotter():
    """Tests that the Plotter class can be loaded."""
    dataset = Dataset(str(this_file_dir / "../data/GamingStudy_data.csv"))
    plotter = Plotter(dataset)
    assert type(plotter.df) == pd.DataFrame
    assert type(plotter.ds) == Dataset
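
A natural companion test (again a sketch, not in the commit) would exercise the ValueError guard added to Plotter.__init__; it needs only pytest plus the imports above:

    import pytest

    def test_load_plotter_with_wrong_type():
        """A non-Dataset argument should be rejected by Plotter.__init__."""
        with pytest.raises(ValueError):
            Plotter("not a dataset")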