Source code for src.plot

"""
Plot weight-training data.
"""

from src.utils.config import settings  # type: ignore
from typing import Any
import matplotlib.pyplot as plt  # type: ignore
import numpy as np  # type: ignore
import pandas as pd  # type: ignore
import seaborn as sns  # type: ignore
from src.crud.read import show_exercise  # type: ignore
from src.utils.get_exercises import get_available_exercises  # type: ignore
from src.utils.set_db_and_table import set_db_and_table  # type: ignore


def get_data(date: str, split: str) -> dict[str, pd.DataFrame] | dict:
    """Build one dataframe per exercise of a split for a given workout date.

    Exercises whose dataframe turns out empty are left out of the result.

    :param date: Date of the workout
    :type date: str
    :param split: Split of the workout
    :type split: str
    :return: Non-empty dataframes with training data, keyed by exercise
    :rtype: dict[str, pd.DataFrame] | dict
    """
    _, table, training_catalogue = set_db_and_table(datatype="real")

    frames = {}
    for exercise in get_available_exercises(training_catalogue, split):
        frame = pd.DataFrame(data=show_exercise(table, exercise, date))
        if frame.empty:
            # Nothing recorded for this exercise on that date.
            continue
        frames[exercise] = frame
    return frames
def compare_workouts(
    dfs_1: dict, dfs_2: dict
) -> tuple[dict[Any, Any], dict[Any, Any]]:
    """Restrict two workouts to the exercises they have in common.

    Each returned dictionary keeps the insertion order of the input it was
    derived from.

    :param dfs_1: Dataframes with training data
    :type dfs_1: dict
    :param dfs_2: Dataframes with training data
    :type dfs_2: dict
    :return: Dataframes with common exercises
    :rtype: tuple[dict[Any, Any], dict[Any, Any]]
    """
    # Key views support set algebra directly. The result is only used for
    # membership tests, so it needs no ordering (sorting it would also fail
    # for keys that cannot be compared with each other).
    common_exercises = dfs_1.keys() & dfs_2.keys()
    dfs_1_common = {k: v for k, v in dfs_1.items() if k in common_exercises}
    dfs_2_common = {k: v for k, v in dfs_2.items() if k in common_exercises}
    return dfs_1_common, dfs_2_common
def validate_date_format(date: str) -> str:
    """Validate date format and return year.

    The string is parsed with :func:`datetime.datetime.strptime`, so inputs
    that are too short, use a different separator, or name an impossible
    calendar date are all rejected with the same error.

    :param date: Date of the workout
    :type date: str
    :return: The four-digit year part of ``date``
    :rtype: str
    :raises ValueError: If ``date`` is not a valid ``YYYY-MM-DD`` date
    """
    from datetime import datetime

    try:
        datetime.strptime(date, "%Y-%m-%d")
    except ValueError as err:
        raise ValueError("Invalid date format. Should be YYYY-MM-DD.") from err
    # %Y only matches exactly four digits, so the year is the first four chars.
    return date[:4]
def create_barplots(dfs: dict, date: str) -> None:
    """Plot training data for specific date.

    One bar plot is drawn per exercise (reps per set, coloured by weight),
    stacked vertically in a single figure that is then saved as
    ``<IMG_PATH><year>/workout_<date>.png``.

    :param dfs: Dataframes with training data; must hold 3 or 4 exercises,
        each with ``set_number``, ``reps`` and ``weight`` columns
    :type dfs: dict
    :param date: Date of the workout
    :type date: str
    :raises ValueError: If ``dfs`` does not contain exactly 3 or 4 exercises,
        or if ``date`` is rejected by ``validate_date_format``
    """
    # TODO: heighten legend transparency
    # TODO: set figure-level x- and y labels ("set_number" and "Repetitions")
    n_exercises = len(dfs)
    if n_exercises not in (3, 4):
        # Fail before any figure is created; the layout below only handles
        # three or four stacked subplots.
        raise ValueError("Wrong number of exercises. Should be 3 or 4.")

    year = validate_date_format(date)
    print(f"{n_exercises} exercises to plot")

    sns.set_theme(style="white", context="talk")
    f, axes = plt.subplots(n_exercises, 1, figsize=(11, 9), sharex=True)

    # One palette per subplot position; zip() stops at the number of axes.
    palettes = ("rocket", "vlag", "deep", "rocket")
    for ax, (exercise, df), palette in zip(axes, dfs.items(), palettes):
        sns.barplot(
            x=df["set_number"],
            y=df["reps"],
            hue=df["weight"],
            palette=palette,
            ax=ax,
        )
        ax.axhline(0, color="k", clip_on=False)
        ax.set_ylabel(exercise)
        # Only the first bar container of each subplot gets value labels.
        ax.bar_label(ax.containers[0])

    sns.despine(bottom=True)
    plt.setp(f.axes, yticks=[])
    plt.tight_layout(h_pad=2)
    # NOTE(review): plt.title acts on the current axes, not the figure --
    # confirm whether a figure-level title (suptitle) was intended.
    plt.title(f"Workout date: {date}")
    sns.move_legend(axes[0], "upper right", bbox_to_anchor=(1, 1))
    sns.move_legend(axes[2], "center right", bbox_to_anchor=(1, 1))
    plt.savefig(f"{settings['IMG_PATH']}{year}/workout_{date}.png")
def main() -> None:
    """Plot the exercises shared by the first and the last leg workout.

    Data is loaded for the first and last entry of the date list, reduced to
    the exercises present in both, and one figure is created per date before
    the figures are shown.
    """
    split = "legs"
    dates = [
        "2021-12-11",
        # "2022-03-14",
        "2022-05-28",
    ]
    first_date, last_date = dates[0], dates[-1]

    first_common, last_common = compare_workouts(
        get_data(first_date, split),
        get_data(last_date, split),
    )

    for workout, workout_date in (
        (first_common, first_date),
        (last_common, last_date),
    ):
        create_barplots(workout, workout_date)
    plt.show()
# Run the plotting workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()