# Source code for ogstools.plot.lineplots

from collections.abc import Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv

from ogstools.plot import setup, utils
from ogstools.variables import Variable, normalize_vars


def _format_ax(
    ax: plt.Axes,
    x_var: Variable,
    y_var: Variable,
    pure_spatial: bool,
    kwargs: dict,
) -> None:
    """Apply axis labels, grid and layout settings to a line plot axis.

    :param ax:           The matplotlib axis to format.
    :param x_var:        Variable providing the label of the x-axis.
    :param y_var:        Variable providing the label of the y-axis.
    :param pure_spatial: If True, neither a grid is drawn nor is the layout
                         of the figure tightened.
    :param kwargs:       Plot keyword arguments. The entry "grid" is removed
                         from this dict (default True), so it does not reach
                         whatever consumes the dict afterwards.
    """
    # Consume the option unconditionally, even if it ends up unused.
    grid_requested = kwargs.pop("grid", True)

    # Only fill in labels which have not been set already.
    label_accessors = (
        (ax.get_xlabel, ax.set_xlabel, x_var),
        (ax.get_ylabel, ax.set_ylabel, y_var),
    )
    for current_label, assign_label, variable in label_accessors:
        if current_label() == "":
            assign_label(variable.get_label(setup.label_split))

    if pure_spatial:
        return

    if grid_requested:
        ax.grid(which="major", color="lightgrey", linestyle="-")
        ax.grid(which="minor", color="0.95", linestyle="--")
        ax.minorticks_on()

    ax.figure.tight_layout()


def line(
    dataset: pv.DataSet | Sequence[pv.DataSet],
    var1: str | Variable | None = None,
    var2: str | Variable | None = None,
    ax: plt.Axes | None = None,
    sort: bool = True,
    **kwargs: Any,
) -> plt.Figure | None:
    """
    Plot some data of a (1D) mesh.

    You can pass "x", "y" or "z" to either of var1 or var2 to specify which
    spatial dimension should be used for the corresponding axis. You can also
    pass two data variables for a phase plot. If no value is given, automatic
    detection of spatial axis is tried.

    >>> line(mesh)  # z=const: y over x, y=const: z over x, x=const: z over y
    >>> line(mesh, ot.variables.temperature)    # temperature over x, y or z
    >>> line(mesh, "y", "temperature")          # temperature over y
    >>> line(mesh, ot.variables.pressure, "y")  # y over pressure
    >>> line(mesh, "pressure", "temperature")   # temperature over pressure

    :param dataset: The mesh which contains the data to plot. If a sequence
                    is given, its first entry is used to determine the
                    variables and the sorting.
    :param var1:    Variable for the x-axis if var2 is given else for y-axis.
    :param var2:    Variable for the y-axis if var1 is given.
    :param ax:      The matplotlib axis to use for plotting, if None a new
                    figure will be created.
    :param sort:    Automatically sort the values along the dimension of the
                    mesh with the largest extent. Sorting is skipped if
                    "time" is passed as var1 or var2.

    Keyword Arguments:
        - figsize:      figure size (default=[16, 10])
        - color:        color of the line
        - colors:       takes precedence over color if dataset is a
                        sequence; passed to utils.colors_from_cmap
                        (default="tab10")
        - linewidth:    width of the line (alias: lw), gets multiplied by 4
                        for purely spatial plots and by 2.5 otherwise
        - linestyle:    style of the line
        - label:        label in the legend (labels takes precedence)
        - grid:         if True, show grid
        - fontsize:     font size (default=setup.fontsize)
        - all other kwargs get passed to matplotlib's plot function

    :returns: The figure of the axis if ax is None, otherwise None.
    :raises TypeError: If a matplotlib axis is passed as var1 or var2.
    """
    if isinstance(var1, plt.Axes) or isinstance(var2, plt.Axes):
        msg = "Please provide ax as keyword argument only!"
        raise TypeError(msg)

    figsize = kwargs.pop("figsize", [16, 10])
    ax_: plt.Axes
    ax_ = plt.subplots(figsize=figsize)[1] if ax is None else ax

    is_sequence = isinstance(dataset, Sequence)
    mesh = dataset[0] if is_sequence else dataset
    x_var, y_var = normalize_vars(var1, var2, mesh)

    if is_sequence:
        # One color per dataset, drawn from the given colormap / color spec.
        color = kwargs.pop("colors", kwargs.pop("color", "tab10"))
        colorlist = utils.colors_from_cmap(color, len(dataset))
        ax_.set_prop_cycle(color=colorlist)
    else:
        kwargs.setdefault("color", y_var.color)

    # Exact match against the axis names. A substring test against "xyz"
    # would also accept names such as "" or "xy".
    spatial_names = ("x", "y", "z")
    pure_spatial = (
        y_var.data_name in spatial_names and x_var.data_name in spatial_names
    )
    lw_scale = 4 if pure_spatial else 2.5
    # "lw" is always consumed, so it cannot collide with "linewidth" in the
    # call to plot; an explicitly given "linewidth" takes precedence.
    kwargs.setdefault("linewidth", kwargs.pop("lw", None) or setup.linewidth)
    kwargs["linewidth"] *= lw_scale
    labels = kwargs.pop("labels", kwargs.pop("label", None))

    if sort and "time" not in [var1, var2]:
        # Sort along the coordinate axis with the largest bounding box extent.
        sort_idx = np.argmax(np.abs(np.diff(np.reshape(mesh.bounds, (3, 2)))))
        sort_ids = np.argsort(mesh.points[:, sort_idx])
    else:
        sort_ids = slice(None)

    x = x_var.transform(dataset)[..., sort_ids]
    y = y_var.transform(dataset)[..., sort_ids]
    # If the arrays differ in dimensionality and their first dimensions do
    # not match, transpose the higher dimensional one.
    if len(x.shape) < len(y.shape) and x.shape[0] != y.shape[0]:
        y = y.T
    if len(x.shape) > len(y.shape) and x.shape[0] != y.shape[0]:
        x = x.T

    if labels:
        kwargs["label"] = labels

    # Removes "grid" from kwargs, thus has to run before the call to plot.
    _format_ax(ax_, x_var, y_var, pure_spatial, kwargs)
    fontsize = kwargs.pop("fontsize", setup.fontsize)
    ax_.plot(x, y, **kwargs)
    if labels:
        ax_.legend(fontsize=fontsize)
    utils.update_font_sizes(axes=ax_, fontsize=fontsize)
    return ax_.figure if ax is None else None