Source code for environmentaltools.graphics.utils

import matplotlib.pyplot as plt
import shutil


[docs] def enable_latex_rendering(): """Enable LaTeX rendering for matplotlib plots.""" if shutil.which("latex") is None: raise RuntimeError( "LaTeX is required to render text in figures. Supported distributions:\n" " - MiKTeX (https://miktex.org/download) - Windows\n" " - TeX Live (https://tug.org/texlive/) - Cross-platform\n" " - TinyTeX (https://yihui.org/tinytex/) - Lightweight" ) plt.rcParams.update({ "text.usetex": True, "font.family": "serif", "text.latex.preamble": r'\usepackage{amsmath}' }) return
[docs] def show(file_name: str = None, res: int = 600): """Save figure to file or display on screen. Saves the current matplotlib figure to a file with specified resolution or displays it interactively. Args: file_name (str, optional): Path to save the plot. If None, displays interactively. If "to_axes", returns without action. Defaults to None. res (int): Resolution in DPI for saved figure. Defaults to 600. Returns: None: Either displays the figure or saves it to file. """ if not file_name: plt.show() elif file_name == ("to_axes"): pass else: if "png" or "pdf" in file_name: plt.savefig(f"{file_name}", dpi=res, bbox_inches="tight") else: plt.savefig(f"{file_name}" + ".png", dpi=res, bbox_inches="tight") plt.close() return
[docs] def handle_axis( ax, row_plots: int = 1, col_plots: int = 1, dim: int = 2, figsize: tuple = (5, 5), projection=None, kwargs: dict = {}, ): """Create matplotlib axis for figure if needed. Creates a new matplotlib axis/figure if none provided, supporting 2D, 3D, and projection-based plots with customizable layout. Args: ax (matplotlib.axes.Axes): Existing axis for the plot, or None to create new. row_plots (int): Number of subplot rows. Defaults to 1. col_plots (int): Number of subplot columns. Defaults to 1. dim (int): Number of dimensions (2 for 2D, 3 for 3D plots). Defaults to 2. figsize (tuple): Figure size as (width, height) in inches. Defaults to (5, 5). projection: Projection type for the axis (e.g., 'polar', CRS objects). Defaults to None. kwargs (dict): Additional keyword arguments passed to plt.subplots. Defaults to {}. Returns: tuple: (fig, ax) where fig is the Figure object (or None if ax was provided) and ax is the Axes object(s). If multiple subplots, ax is flattened array. """ fig = None if not ax: if dim == 2: fig, ax = plt.subplots(row_plots, col_plots, figsize=figsize, **kwargs) elif not projection is None: ax = plt.axes( projection=projection ) # project using coordinate reference system (CRS) of street map else: fig = plt.figure() ax = fig.add_subplot(111, projection="3d") if row_plots + col_plots > 2: ax = ax.flatten() return fig, ax
[docs] def labels(variable): """Gives the labels and units for plots Args: * variable (string): name of the variable Returns: * units (string): label with the name of the variable and the units """ units = { "calms": r"$\delta$ (h)", "depth": "Depth (m)", "dm": r"$\theta_m$ (deg)", "dv": r"$\theta_v$ (deg)", "DirM": r"$\theta_m$ (deg)", "DirU": r"$\theta_U$ (deg)", "DirV": r"$\theta_v$ (deg)", "Dmd": r"$\theta_m$ (deg)", "Dmv": r"$\theta_v$ (deg)", "dur": r"d (hr)", "dur_calms": r"$\Delta_0$ (hr)", "dur_storm": r"$d_0$ (hr)", "eta": r"$\eta$ (m)", "hs": r"$H_{s}$ (m)", "Hm0": r"$H_{m0}$ (m)", "Hs": r"$H_{s}$ (m)", "lat": "Latitude (deg)", "lon": "Longitude (deg)", "ma": r"$M_{ast}$ (m)", "mm": r"$M_{met}$ (m)", "pr": r"P (mm/day)", "Q": r"Q (m$^3$/s)", "Qd": r"$Q_d$ (m$^3$/s)", "S": r"S (psu)", "slr": r"$\Delta\eta$ (m)", "storm": r"d (h)", "surge": r"$\eta_s$ (m)", "swh": r"$H_{s}$ (m)", "t": "t (s)", "Tm0": r"$T_{m0}$ (s)", "tp": r"$T_p$ (s)", "Tp": r"$T_p$ (s)", "U": r"U (m/s)", "V": r"V (m/s)", "VelV": "[m/s]", "vv": r"$V_v$ (m/s)", "Vv": r"$V_v$ (m/s)", "Wd": r"$W_d$ (deg)", "Wv": r"$W_v$ (m/s)", "x": "x (m)", "y": "y (m)", "z": "z (m)", "None": "None", } if isinstance(variable, str): if not variable in units.keys(): labels_ = "" else: labels_ = units[variable] elif isinstance(variable, list): labels_ = list() for var_ in variable: if not var_ in units.keys(): labels_.append("") else: labels_.append(units[var_]) else: raise ValueError return labels_