Source code for ogstools.plot.utils

# Copyright (c) 2012-2024, OpenGeoSys Community (http://www.opengeosys.org)
#            Distributed under a Modified BSD License.
#            See accompanying file LICENSE.txt or
#            http://www.opengeosys.org/project/license
#


from itertools import cycle, islice
from math import nextafter
from pathlib import Path
from typing import Any
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
from cycler import Cycler
from matplotlib import colormaps
from matplotlib import colors as mcolors
from matplotlib.animation import FFMpegWriter, FuncAnimation, ImageMagickWriter
from typeguard import typechecked

from ogstools.plot.levels import level_boundaries
from ogstools.variables import Variable

from .shared import setup


def get_style_cycler(
    min_number_of_styles: int,
    colors: list | None = None,
    linestyles: list | None = None,
) -> Cycler:
    """
    Build a property cycler combining colors and linestyles.

    :param min_number_of_styles: number of distinct styles that are needed
    :param colors: colors to use; defaults to the colors of matplotlib's
        ``axes.prop_cycle`` rcParam
    :param linestyles: linestyles to use; defaults to
        ``["-", "--", ":", "-."]``

    :returns: If ``min_number_of_styles`` does not exceed the length of the
        shorter of both lists, colors and linestyles are paired one-to-one
        (both truncated to that length). Otherwise the product of all
        linestyles with all colors is returned.
    """
    if colors is None:
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    if linestyles is None:
        linestyles = ["-", "--", ":", "-."]
    styles_len = min(len(colors), len(linestyles))
    c_cycler = plt.cycler(color=colors)
    ls_cycler = plt.cycler(linestyle=linestyles)
    if min_number_of_styles <= styles_len:
        # enough styles when every color gets exactly one linestyle
        return c_cycler[:styles_len] + ls_cycler[:styles_len]
    # not enough pairs: combine every linestyle with every color
    return ls_cycler * c_cycler
[docs] def justified_labels(points: np.ndarray) -> list[str]: "Formats an array of points to a list of aligned str." def fmt(val: float) -> str: return f"{val:.2f}".rstrip("0").rstrip(".") col_lens = np.max( [[len(fmt(coord)) for coord in point] for point in points], axis=0 ) dim = points.shape[1] return [ ", ".join(fmt(point[i]).rjust(col_lens[i]) for i in range(dim)) for point in points ]
@typechecked
def label_spatial_axes(
    fig: plt.Figure | None,
    axes: plt.Axes | np.ndarray,
    x_var: Variable,
    y_var: Variable,
) -> None:
    """
    Put the labels of x_var and y_var on the x and y axis.

    A single axes gets both labels. For an array of axes with a figure,
    every axes is handed to label_ax. For an array of axes without a
    figure, the last row gets x labels and the first column gets y labels
    (the array is indexed in two dimensions in that case).
    """
    ax: plt.Axes
    if not isinstance(axes, np.ndarray):
        axes.set_xlabel(x_var.get_label())
        axes.set_ylabel(y_var.get_label())
        return
    if fig is not None:
        for ax in np.ravel(axes):
            label_ax(fig, ax, x_var, y_var)
        return
    for ax in axes[-1, :]:
        ax.set_xlabel(x_var.get_label())
    for ax in axes[:, 0]:
        ax.set_ylabel(y_var.get_label())
def label_ax(
    fig: plt.Figure,
    ax: plt.Axes,
    var_x: Variable,
    var_y: Variable,
    fontsize: float | None = None,
) -> None:
    """Label the x- and y-axis of ax with the labels of the given Variables.

    An axis which is not shared with the first axes of the figure is always
    labeled. A shared x-axis is only labeled if ax has the smallest ymin of
    all axes positions in the figure, a shared y-axis only if ax has the
    smallest xmin.

    :param fig: figure which contains ax
    :param ax: the axes to label
    :param var_x: Variable providing the x label
    :param var_y: Variable providing the y label
    :param fontsize: label font size, defaults to setup.fontsize
    """
    reference_ax = fig.axes[0]
    shares_x = ax.get_shared_x_axes().joined(ax, reference_ax)
    shares_y = ax.get_shared_y_axes().joined(ax, reference_ax)

    own_position = ax.get_position()
    leftmost = min(other.get_position().xmin for other in fig.axes)
    lowest = min(other.get_position().ymin for other in fig.axes)
    at_left_edge = own_position.xmin == leftmost
    at_bottom_edge = own_position.ymin == lowest

    if fontsize is None:
        fontsize = setup.fontsize

    if at_bottom_edge or not shares_x:
        ax.set_xlabel(var_x.get_label(), fontsize=fontsize)
    if at_left_edge or not shares_y:
        ax.set_ylabel(var_y.get_label(), fontsize=fontsize)
def update_font_sizes(
    axes: plt.Axes | np.ndarray | list[plt.Axes],
    fontsize: float | None = None,
) -> None:
    """
    Set the font size of titles, axis labels and tick labels.

    Tick padding and tick lengths are scaled by the ratio of the requested
    font size to setup.fontsize.

    :param axes: matplotlib axes (single, list or array) to be updated
    :param fontsize: font size for the labels and ticks, defaults to
        setup.fontsize
    """
    if fontsize is None:
        fontsize = setup.fontsize
    scale = fontsize / setup.fontsize
    tick_pad = scale * setup.tick_pad
    tick_len = scale * setup.tick_length
    min_tick_len = tick_len * 2.0 / 3.5  # matplotlib default
    lengths_by_kind = (("major", tick_len), ("minor", min_tick_len))

    for ax in np.ravel(np.asarray(axes)):
        texts = [
            *ax.get_xticklabels(),
            *ax.get_yticklabels(),
            ax.title,
            ax.xaxis.label,
            ax.yaxis.label,
            ax.yaxis.get_offset_text(),
        ]
        ax.tick_params(axis="both", which="both", labelsize=fontsize)
        for text in texts:
            text.set_fontsize(fontsize)
        for kind, length in lengths_by_kind:
            ax.tick_params("both", which=kind, pad=tick_pad, length=length)
def get_data_aspect(mesh: pv.UnstructuredGrid) -> float:
    """
    Calculate the data aspect ratio of a 2D mesh.

    The axis with the largest component of the mean absolute surface normal
    is treated as the out-of-plane axis. The result is the extent of the
    mesh bounds along the first remaining axis divided by the extent along
    the second one.
    """
    surface_normals = mesh.extract_surface().cell_normals
    mean_normal = np.abs(np.mean(surface_normals, axis=0))
    normal_axis = int(np.argmax(mean_normal))
    # bounds hold a (min, max) pair per axis, hence the stride of 2
    first, second = (2 * axis for axis in (0, 1, 2) if axis != normal_axis)
    bounds = mesh.bounds
    width = abs(bounds[first + 1] - bounds[first])
    height = abs(bounds[second + 1] - bounds[second])
    return width / height
[docs] def get_rows_cols( meshes: ( list[pv.UnstructuredGrid] | np.ndarray | pv.UnstructuredGrid | pv.MultiBlock ), ) -> tuple[int, ...]: if isinstance(meshes, np.ndarray): if meshes.ndim in [1, 2]: return meshes.shape msg = "Input numpy array must be 1D or 2D." raise ValueError(msg) if isinstance(meshes, list): return (1, len(meshes)) if isinstance(meshes, pv.MultiBlock): return (1, meshes.n_blocks) return (1, 1)
def get_projection(
    mesh: pv.UnstructuredGrid,
) -> tuple[int, int, int, np.ndarray]:
    """
    Identify which projection is used: XY, XZ or YZ.

    The axis with the largest component of the mean absolute surface normal
    is taken as the projection axis; the two remaining axes (in ascending
    order) are the in-plane x and y axes.

    :param mesh: singular mesh

    :returns: x_id, y_id, projection, mean_normal
    """
    mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0))
    projection = int(np.argmax(mean_normal))
    # plain ints (instead of numpy integers) to match the declared return type
    x_id, y_id = (axis for axis in range(3) if axis != projection)
    return x_id, y_id, projection, mean_normal
[docs] def save_animation(anim: FuncAnimation, filename: str, fps: int) -> None: """ Save a FuncAnimation with some codec presets. :param anim: the FuncAnimation to be saved :param filename: the name of the resulting file :param fps: the number of frames per second """ print("Start saving animation...") extension = Path(filename).suffix msg = "" match extension, FFMpegWriter.isAvailable(), ImageMagickWriter.isAvailable(): case _, False, False: msg = """Neither .mp4 nor .gif writers are installed.\n Try installing ffmpeg and/or ImageMagick respectively to enable them.""" case ".mp4", True, _: codec_args = ( "-crf 28 -preset ultrafast -pix_fmt yuv420p " "-vf pad=ceil(iw/2)*2:ceil(ih/2)*2" ).split(" ") writer = FFMpegWriter( fps=fps, codec="libx265", extra_args=codec_args ) case ".gif", _, True: warn( """\n Gif format may struggle with cache overflow errors, due to lossless compression.\n You can try avoiding it by lowering dpi with: ogstools.plot.setup.dpi=50 \n or you can export to mp4 format instead.""", RuntimeWarning, stacklevel=2, ) writer = ImageMagickWriter() case ".mp4", False, True: msg = """ffmpeg is not installed. Output to .mp4 not possible.\n Try .gif extension instead or install ffmpeg.""" case ".gif", True, False: msg = """ImageMagick is not installed.\n Output to .gif not possible.\n Try .mp4 extension or install ImageMagick.""" case extension, _, _: msg = f"""{extension} is not a supported file type.\n Only .mp4 and .gif are supported.\n Try installing ffmpeg for .mp4 or ImageMagic for .gif.""" if msg != "": raise RuntimeError(msg) try: anim.save(filename, writer=writer) print("Successful!") except Exception as err: msg = f"\nSaving Animation failed with the following error: {err}" raise RuntimeError(msg) from err
def get_cmap_norm(
    levels: np.ndarray, variable: Variable, **kwargs: Any
) -> tuple[mcolors.Colormap, mcolors.Normalize | None]:
    """Construct a discrete colormap and norm for the variable field.

    :param levels: level values; the first and last entry define the value
        range
    :param variable: the Variable whose ``cmap``, ``categoric`` and
        ``bilinear_cmap`` attributes steer the construction
    :param kwargs: if a "cmap" entry is present, it is used as key into
        matplotlib's colormap registry instead of ``variable.cmap``

    :returns: the colormap and a BoundaryNorm over the levels; the norm is
        None if the last level is the next float after the first level
    """
    vmin, vmax = (levels[0], levels[-1])
    if variable.categoric:
        # categoric variables: shift the range used for the color lookup
        vmin += 0.5
        vmax += 0.5
    # a "cmap" keyword argument takes precedence over the variable's colormap
    if "cmap" in kwargs:
        continuous_cmap = colormaps[kwargs.get("cmap")]
    elif isinstance(variable.cmap, str):
        continuous_cmap = colormaps[variable.cmap]
    else:
        continuous_cmap = variable.cmap
    conti_norm: mcolors.TwoSlopeNorm | mcolors.Normalize
    if variable.bilinear_cmap:
        if vmin < 0.0 < vmax:
            # range spans zero: make it symmetric so that the center of the
            # colormap corresponds to zero
            vcenter = 0.0
            vmin, vmax = np.max(np.abs([vmin, vmax])) * np.array([-1.0, 1.0])
            conti_norm = mcolors.TwoSlopeNorm(vcenter, vmin, vmax)
        else:
            # only use one half of the diverging colormap
            col_range = np.linspace(0.0, nextafter(0.5, -np.inf), 128)
            if vmax > 0.0:
                # positive values: select the upper half instead
                col_range += 0.5
            continuous_cmap = mcolors.LinearSegmentedColormap.from_list(
                "half_cmap", continuous_cmap(col_range)
            )
            conti_norm = mcolors.Normalize(vmin, vmax)
    else:
        conti_norm = mcolors.Normalize(vmin, vmax)
    # sample the continuous colormap at the midpoint of every level interval
    # and additionally at the last level
    mid_levels = np.append((levels[:-1] + levels[1:]) * 0.5, levels[-1])
    colors = [continuous_cmap(conti_norm(m_l)) for m_l in mid_levels]
    # a colormap set in the plot setup replaces the sampled colors
    if setup.custom_cmap is None:
        cmap = mcolors.ListedColormap(colors, name="custom")
    else:
        cmap = setup.custom_cmap
    boundaries = level_boundaries(levels) if variable.categoric else levels
    # degenerate range (vmax is the next float after vmin): plain grey
    if vmax == nextafter(vmin, np.inf):
        cmap = mcolors.ListedColormap(["grey", "grey"], name="custom")
    # no BoundaryNorm is built if the levels span only one float step
    if nextafter(levels[0], np.inf) == levels[-1]:
        return cmap, None
    norm = mcolors.BoundaryNorm(
        boundaries=boundaries, ncolors=len(boundaries) - 1, clip=False
    )
    return cmap, norm
# TODO: use ColorType when matplotlib can be upgraded to 3.8
def contrast_color(color: Any) -> Any:
    """Pick black ("k") or white ("w") as the better contrast to color.

    The decision is based on a weighted sum of the RGB components, see
    https://stackoverflow.com/questions/3942878/how-to-decide-font-color-in-white-or-black-depending-on-background-color

    :param color: any color accepted by matplotlib's ``to_rgb``
    :returns: "k" if the weighted brightness exceeds 0.2, otherwise "w"
    """
    red, green, blue = mcolors.to_rgb(color)
    brightness = red * 0.299 + green * 0.587 + blue * 0.114
    # threshold lowered on purpose to prefer black coloring to only use white
    # when it is really necessary
    if brightness > 0.2:
        return "k"
    return "w"
[docs] def colors_from_cmap(cmap: str | list, num: int) -> list[str]: "Convert a colormap to a list of colors." if isinstance(cmap, str): # Assuming it's a colormap name cmap_: plt.Colormap = plt.get_cmap(cmap) if cmap_.N <= 20: # for discrete colormaps return list(map(cmap_, range(num))) # for continuous colormaps return list(map(cmap_, np.linspace(0, 1, num))) # Assuming it's already a list of colors, repeat list entries until length # of num is reached return list(islice(cycle(cmap), num))
def padded(
    ax: plt.Axes, x: float, y: float, pad_x: bool = True
) -> tuple[float, float]:
    """Shift the point (x, y) towards the center of the axes.

    The point is transformed with ``ax.transLimits``, moved by 0.075 in the
    direction of 0.5 and transformed back. The x value is only moved if
    pad_x is set and the transformed x lies at or outside of 0.25 / 0.75.

    :param ax: axes whose limits define the transformation
    :param x: x coordinate of the point
    :param y: y coordinate of the point
    :param pad_x: whether x may be padded as well
    :returns: the padded x and y
    """
    pad = 0.075
    limits_transform = ax.transLimits
    rel_x, rel_y = limits_transform.transform((x, y))
    if pad_x and (rel_x <= 0.25 or rel_x >= 0.75):
        rel_x += pad if rel_x <= 0.5 else -pad
    rel_y += pad if rel_y <= 0.5 else -pad
    # Unpacking this here helps type hinting. Direct return doesn't work.
    new_x, new_y = limits_transform.inverted().transform((rel_x, rel_y))
    return new_x, new_y
def color_twin_axes(axes: list[plt.Axes], colors: list) -> None:
    """Color the y-axis of each of the given axes with its own color.

    The y ticks and the y label of every axes get the color at the same
    position in colors. In addition, the left and right spine of the second
    axes are colored with the first and second color.

    :param axes: the axes to color (at least two entries are accessed)
    :param colors: one color per axes
    """
    for current_ax, current_color in zip(axes, colors, strict=False):
        current_ax.tick_params(axis="y", which="both", colors=current_color)
        current_ax.yaxis.label.set_color(current_color)
    # Axis spine color has to be applied on twin axis for both sides
    for side, color_index in (("left", 0), ("right", 1)):
        axes[1].spines[side].set_color(colors[color_index])