# Source code for ogstools.plot.lineplots

# SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
# SPDX-License-Identifier: BSD-3-Clause


from collections.abc import Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
from matplotlib.figure import Figure

from ogstools.plot import setup, utils
from ogstools.variables import Variable, _normalize_vars


def _format_ax(
    ax: plt.Axes,
    x_var: Variable,
    y_var: Variable,
    show_grid: bool,
) -> None:
    """Apply default axis labels and, optionally, a major/minor grid.

    An axis label is only written when the axis currently has an empty
    label, so labels that were set beforehand are left untouched.

    :param ax:        The matplotlib axes to format.
    :param x_var:     Variable providing the label of the x-axis.
    :param y_var:     Variable providing the label of the y-axis.
    :param show_grid: If True, draw major and minor grid lines and enable
                      minor ticks.
    """
    label_accessors = (
        (ax.get_xlabel, ax.set_xlabel, x_var),
        (ax.get_ylabel, ax.set_ylabel, y_var),
    )
    for current_label, assign_label, variable in label_accessors:
        if current_label() == "":
            assign_label(variable.get_label(setup.label_split))

    if not show_grid:
        return

    grid_styles = (
        ("major", "lightgrey", "-"),
        ("minor", "0.95", "--"),
    )
    for which, color, linestyle in grid_styles:
        ax.grid(which=which, color=color, linestyle=linestyle)
    ax.minorticks_on()


def _separate_by_empty_cells(mesh: pv.DataSet, *arrays: np.ndarray) -> None:
    """Set array entries at ghost-flagged points to NaN, in place.

    The cell data "vtkGhostType" is converted to point data via
    ``mesh.ctp()``; every point with a non-zero value is masked. Presumably
    this is used to interrupt plotted lines at blanked cells, as NaN values
    are not drawn - confirm against the caller.

    Nothing happens if the mesh has no "vtkGhostType" cell data or if any of
    the given arrays does not have one entry per mask value along its first
    axis.

    :param mesh:    Mesh whose "vtkGhostType" cell data defines the mask.
    :param arrays:  Arrays to modify in place. They need a dtype which can
                    hold NaN (the visible caller casts them to float).
    """
    if "vtkGhostType" not in mesh.cell_data:
        return

    point_data = mesh.ctp().point_data
    # If the converted point data lacks the array, fall back to "no ghosts".
    ghost_flags = point_data.get("vtkGhostType", np.zeros(mesh.n_points))
    mask = ghost_flags != 0

    # Only apply the mask if it fits all arrays, otherwise leave them as is.
    if any(len(arr) != len(mask) for arr in arrays):
        return
    for array in arrays:
        array[mask] = np.nan


def line(
    dataset: pv.DataSet | Sequence[pv.DataSet],
    var1: str | Variable | None = None,
    var2: str | Variable | None = None,
    ax: plt.Axes | None = None,
    sort: bool = True,
    outer_legend: bool | tuple[float, float] = False,
    **kwargs: Any,
) -> Figure | None:
    """Plot some data of a (1D) dataset.

    You can pass "x", "y" or "z" to either of var1 or var2 to specify which
    spatial dimension should be used for the corresponding axis. By passing
    "time" the timevalues will be used for this axis. You can also pass two
    data variables for a phase plot. If no value is given, automatic
    detection of the spatial axis is tried.

    >>> line(ms, ot.variables.temperature)  # temperature over time
    >>> line(ms, ot.variables.temperature, "time")  # time over temperature
    >>> line(ms, "pressure", "temperature")  # temperature over pressure
    >>> line(mesh, ot.variables.temperature)  # temperature over x, y or z
    >>> line(mesh, "y", "temperature")  # temperature over y
    >>> line(mesh, ot.variables.pressure, "y")  # y over pressure
    >>> line(mesh)  # z=const: y over x, y=const: z over x, x=const: z over y

    :param dataset: The mesh or meshseries which contains the data to plot.
    :param var1:    Variable for the x-axis if var2 is given else for y-axis.
    :param var2:    Variable for the y-axis if var1 is given.
    :param ax:      The matplotlib axis to use for plotting, if None a new
                    figure will be created.
    :param sort:    Automatically sort the values along the dimension of the
                    mesh with the largest extent (only for pointclouds).
    :param outer_legend: Draw legend to the right next to the plot area.
                    By default False (legend stays inside). User can pass a
                    tuple of two floats (x, y), which will be passed to
                    bbox_to_anchor parameter in matplotlib legend call.
                    True will pass the default values (1.05, 1.0).

    Keyword Arguments:
        - figsize:      figure size (default=[16, 10])
        - dpi:          resolution of the figure
        - color:        color of the line
        - colors:       colormap used to color the lines of a meshseries if
                        no color is given (default="tab10")
        - linewidth:    width of the line (alias: lw)
        - linestyle:    style of the line
        - label:        label in the legend (alias: labels)
        - grid:         if True, show grid
        - monospace:    if True, the legend uses a monospace font
        - fontsize:     font size used for the legend and the axes
        - loc:          location of the legend (default="upper right" for
                        purely spatial plots, else "best"). Ignored if
                        outer_legend is used.
        - clip_on:      If True, clip the output to stay within the Axes.
                        (default=True)
        - all other kwargs get passed to matplotlib's plot function

    :returns: The newly created figure if no ax was given, otherwise None.

    :raises TypeError:  If an Axes object is passed as var1 or var2.
    :raises ValueError: If cell data is plotted against point data for cells
                        with inner points, or if cell data is requested for
                        an interrupted line.

    Note:
        Using loc="best" will take a long time, if you plot lines on top of
        a contourplot, as matplotlib is calculating the best position against
        all the underlying cells.
    """

    ##### prepare figure/ax ##################################################
    if isinstance(var1, plt.Axes) or isinstance(var2, plt.Axes):
        msg = "Please provide ax as keyword argument only!"
        raise TypeError(msg)

    figsize = kwargs.pop("figsize", [16, 10])
    dpi = kwargs.pop("dpi", None)
    ax_: plt.Axes
    ax_ = plt.subplots(figsize=figsize, dpi=dpi)[1] if ax is None else ax

    ##### process variables ##################################################
    is_meshseries = isinstance(dataset, Sequence)
    mesh: pv.DataSet = dataset[0] if is_meshseries else dataset
    default = ["time", "time"] if is_meshseries else ["x", "y", "z"]
    x_var, y_var = _normalize_vars(var1, var2, mesh, default)
    # Membership in a tuple instead of a substring test on "xyz", which
    # would also accept names like "" or "xy".
    spatial_names = ("x", "y", "z")
    pure_spatial = (
        y_var.data_name in spatial_names and x_var.data_name in spatial_names
    )

    # prefer point data over cell data
    x_cell_data = (x_var.data_name in mesh.cell_data) and (
        x_var.data_name not in mesh.point_data
    )
    y_cell_data = (y_var.data_name in mesh.cell_data) and (
        y_var.data_name not in mesh.point_data
    )

    ##### kwargs processing ##################################################
    if is_meshseries and "color" not in kwargs:
        color = kwargs.pop("colors", "tab10")
        colorlist = utils.colors_from_cmap(color, len(dataset))
        ax_.set_prop_cycle(color=colorlist)
    else:
        kwargs.setdefault("color", y_var.color)
        ax_.set_prop_cycle(linestyle=["-", "--", ":", "-."])

    lw_scale = 4 if pure_spatial else 2.5
    kwargs.setdefault("linewidth", kwargs.pop("lw", None) or setup.linewidth)
    kwargs.setdefault("clip_on", True)
    kwargs["linewidth"] *= lw_scale

    labels = kwargs.pop("labels", kwargs.pop("label", None))
    if labels:
        kwargs["label"] = labels

    outer_bool = outer_legend is True or isinstance(outer_legend, tuple)
    # Always remove "loc" from kwargs: it is a legend option and must not be
    # forwarded to ax.plot, even if an outer legend overrides it.
    user_loc = kwargs.pop("loc", "upper right" if pure_spatial else "best")
    loc = "upper left" if outer_bool else user_loc

    fontsize = kwargs.pop("fontsize", setup.fontsize)
    prop = {"size": fontsize}
    if kwargs.pop("monospace", False):
        prop["family"] = "monospace"

    show_grid = kwargs.pop("grid", True) and not pure_spatial
    _format_ax(ax_, x_var, y_var, show_grid)

    ##### prepare data for plotting ##########################################
    x = x_var.transform(dataset)
    y = y_var.transform(dataset)
    if "vtkGhostType" in mesh.cell_data:
        # float dtype is needed to be able to store NaN values
        x = x.astype(float)
        y = y.astype(float)
        _separate_by_empty_cells(mesh, x, y)

    # transposing to get individual lines in the plot in the case of plotting
    # linesamples for multiple timesteps or timeseries of multiple points
    if len(x.shape) < len(y.shape) and x.shape[0] != y.shape[0]:
        y = y.T
    if len(x.shape) > len(y.shape) and x.shape[0] != y.shape[0]:
        x = x.T

    def sorted_ids(
        mesh: pv.DataSet, use_cells: bool = False
    ) -> slice | np.ndarray:
        """Return indices sorting points/cells along the largest extent."""
        if is_meshseries or not sort:
            return slice(None)
        sort_idx = np.argmax(np.abs(np.diff(np.reshape(mesh.bounds, (3, 2)))))
        mesh_ = mesh.cell_centers() if use_cells else mesh
        return np.argsort(mesh_.points[:, sort_idx])

    ##### plotting ###########################################################
    # Only iterate over the cells in Python if there is no celltypes array.
    celltypes = getattr(mesh, "celltypes", None)
    if celltypes is None:
        celltypes = [cell.type for cell in mesh.cell]
    # Compare as a set of plain ints; comparing a NumPy array against sets
    # does not yield a usable truth value.
    unique_types = {int(cell_type) for cell_type in np.unique(celltypes)}
    only_points = unique_types in ({0}, {1})

    # The stripped surface is only needed to detect and plot interrupted
    # lines, so skip the extraction if the simple branch is taken anyway.
    strip: pv.PolyData | None = None
    if not (is_meshseries or only_points):
        surf: pv.PolyData = mesh.extract_surface(algorithm="dataset_surface")
        strip = surf.strip()

    if strip is None or strip.n_cells <= 1:
        x_sort_ids = sorted_ids(mesh=mesh, use_cells=x_cell_data)
        if x_cell_data == y_cell_data:  # pure cell or point data
            ax_.plot(x[x_sort_ids], y[x_sort_ids], **kwargs)
        elif x_cell_data or y_cell_data:
            if mesh.n_cells != mesh.n_points - 1:
                msg = "Line Plot of CellData vs. PointData for cells with inner points currently not supported!"
                raise ValueError(msg)
            # mixed point data and cell data - special case
            y_sort_ids = sorted_ids(mesh=mesh, use_cells=y_cell_data)

            def prepare_data(data: np.ndarray, use_cells: bool) -> np.ndarray:
                if use_cells:
                    # repeat the cell data to map it to the start and end
                    # point of the cell
                    return np.repeat(data, 2)
                # only repeat inner points
                return np.concatenate(
                    [[data[0]], np.repeat(data[1:-1], 2), [data[-1]]]
                )

            x_plot_vals = prepare_data(x[x_sort_ids], x_cell_data)
            y_plot_vals = prepare_data(y[y_sort_ids], y_cell_data)
            ax_.plot(x_plot_vals, y_plot_vals, **kwargs)
    else:
        kwargs.setdefault("linestyle", kwargs.pop("ls", "-"))
        orig_ids = np.arange(mesh.n_points, dtype=np.int32)
        if x_cell_data or y_cell_data:
            msg = "Plotting CellData for interrupted lines currently not supported! Convert CellData to PointData to use this function."
            raise ValueError(msg)
        # NOTE(review): "vtkOriginalPointIds" is looked up in cell_data here;
        # verify whether point_data is the intended container.
        point_id_map = strip.cell_data.get("vtkOriginalPointIds", orig_ids)
        for cell_id, linestrip in enumerate(strip.cell):
            sort_ids = point_id_map[linestrip.point_ids]
            # only label the first segment to get a single legend entry
            label = kwargs.get("label") if cell_id == 0 else None
            ax_.plot(x[sort_ids], y[sort_ids], **{**kwargs, "label": label})

    ##### legend and final formatting ########################################
    if labels is not None:
        leg_prop: dict[str, Any] = {"loc": loc}
        if outer_legend is True:
            outer_legend = (1.05, 1.0)
        if isinstance(outer_legend, tuple):
            leg_prop["bbox_to_anchor"] = outer_legend
            leg_prop["borderaxespad"] = 0.0
        ax_.legend(prop=prop, **leg_prop)

    utils.update_font_sizes(axes=ax_, fontsize=fontsize)
    if not pure_spatial:
        ax_.figure.tight_layout()
    return ax_.figure if ax is None else None