Source code for src.model.plot_model

"""
Plot weight-training data with fit.
"""

import argparse
from datetime import datetime
from src.utils.config import settings  # type: ignore
import matplotlib.pyplot as plt  # type: ignore
import seaborn as sns  # type: ignore
from src.model.model import (  # type: ignore
    get_data,
    get_df,
    one_rep_max_estimator,
    calc_volume
)
from src.utils.set_db_and_table import set_db_and_table  # type: ignore


[docs] def create_1rm_plots(datatype: str, x: list, y: list, exercise: str) -> None: """Plot training data 1RM with fit. :param datatype: Data type: real or simulated :type datatype: str :param x: x-axis data :type x: list :param y: y-axis data :type y: list :param exercise: Exercise name :type exercise: str """ plt.figure(figsize=(8, 8)) x_prg1, x_prg2 = x[:10], x[10:] y_prg1, y_prg2 = y[:10], y[10:] # Only add confidence intervals if there are sufficient data points if len(x) < 5: sns.set_theme() ax = sns.scatterplot(x=x, y=y) ax.set_title(f"{exercise}") else: # ax = sns.regplot(x=x, y=y, ci=68, truncate=False) ax = sns.regplot( x=x_prg1, y=y_prg1, ci=68, truncate=False, label="Program 1: 4-split" ) sns.regplot(x=x_prg2, y=y_prg2, ci=68, truncate=True, label="Program 2: PPL") ax.set_title(f"{exercise}", fontsize=30) xticks = ax.get_xticks() xticks_dates = [datetime.fromtimestamp(x).strftime("%Y-%m-%d") for x in xticks] ax.set_xticklabels(xticks_dates) plt.ylim(0, max(y) + 5) plt.xticks(rotation=45) ax.set_ylabel("1 RM estimates [kg]", fontsize=20) ax.legend(loc="lower right", fontsize=20) plt.savefig( f"{settings['IMG_PATH']}all_years/one_rep_max/{datatype}_fitted_data_{exercise}_splines.png" ) plt.clf() # clear figure before next plot
[docs] def create_volume_plots(datatype: str, x: list, y: list, exercise: str) -> None: """Create volume plots. :param datatype: Data type: real or simulated :type datatype: str :param x: x-axis data :type x: list :param y: y-axis data :type y: list :param exercise: Exercise name :type exercise: str """ plt.figure(figsize=(8, 8)) # TODO: lookup dates and use get_program instead of list slicing x_prg1, x_prg2, x_prg3 = x[:10], x[10:20], x[20:] y_prg1, y_prg2, y_prg3 = y[:10], y[10:20], y[20:] ax = sns.regplot(x=x_prg1, y=y_prg1, ci=68, label="Program 1: 4-SPLIT") sns.regplot(x=x_prg2, y=y_prg2, ci=68, label="Program 2: PPL") sns.regplot(x=x_prg3, y=y_prg3, ci=68, label="Program 3: GVT") ax.set_title(f"{exercise}", fontsize=30) xticks = ax.get_xticks() xticks_dates = [datetime.fromtimestamp(x).strftime("%Y-%m-%d") for x in xticks] ax.set_xticklabels(xticks_dates) plt.ylim(0, max(y) + 5) plt.xticks(rotation=45) ax.set_ylabel("Volume [kg]", fontsize=20) ax.legend(loc="lower right", fontsize=20) plt.savefig( f"{settings['IMG_PATH']}all_years/volume/{datatype}_fitted_data_{exercise}_gvt.png" ) plt.clf() # clear figure before next plot
# TODO: unit test below function
[docs] def get_split(pgm: str) -> list[tuple[list, str]]: """Get split and key exercises. :param pgm: Program type: 1rm or gvt :type pgm: str :return: Split and key exercises :rtype: list[tuple[list, str]] """ splits_and_key_exercises_1rm = [ (["chest", "push"], "bb_bench_press"), (["back", "pull"], "seated_row"), (["legs"], "squat"), (["legs"], "deadlift"), ] splits_and_key_exercises_gvt = [ (["chest", "push", "chest_and_back"], "bb_bench_press"), (["legs", "legs_and_abs"], "squat"), ] split_selector = { "1rm": splits_and_key_exercises_1rm, "gvt": splits_and_key_exercises_gvt, } return split_selector[pgm]
[docs] def make_plots( pgm: str, split_selection: list[tuple[list, str]], table, datatype: str ) -> None: """Make plots. :param pgm: Program type :type pgm: str :param split_selection: Split and key exercises :type split_selection: list[tuple[list, str]] :param table: TinyDB table :type table: TinyDB.table :param datatype: Data type: real or simulated :type datatype: str :raises ValueError: If program type is not 1rm or gvt """ for splits, exercise in split_selection: df = get_df(table, splits, exercise) match pgm: case "1rm": df_plot = one_rep_max_estimator(df) x, y = get_data(df_plot) create_1rm_plots(datatype, x, y, exercise) case "gvt": df_plot = calc_volume(df) x, y = get_data(df_plot, y_col="volume") create_volume_plots(datatype, x, y, exercise) case _: raise ValueError
[docs] def main() -> None: """Get data and create figure. """ parser = argparse.ArgumentParser() parser.add_argument("--datatype", type=str, required=True) # real/simulated parser.add_argument("--pgm", type=str, required=True) # 1rm/gvt args = parser.parse_args() datatype = args.datatype pgm = args.pgm _, table, _ = set_db_and_table(datatype) split_selection = get_split(pgm) make_plots(pgm, split_selection, table, datatype)
if __name__ == "__main__": main()