Source code for ogstools.meshplotlib.core

"""Meshplotlib core utilitites."""
import types
from typing import Optional as Opt
from typing import Union

import numpy as np
import pyvista as pv
from matplotlib import cm as mcm
from matplotlib import colormaps, rcParams
from matplotlib import colors as mcolors
from matplotlib import figure as mfigure
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker
from matplotlib import transforms as mtransforms
from matplotlib.patches import Rectangle as Rect

from ogstools.propertylib import Property, Vector
from ogstools.propertylib.presets import _resolve_property

from . import plot_features as pf
from . import setup
from .levels import get_levels

# TODO: define default data_name for regions in setup


def _q_zero_line(property: Property, levels: np.ndarray):
    return property.bilinear_cmap or (
        property.data_name == "temperature" and levels[0] < 0 < levels[-1]
    )


[docs]def get_data(mesh: pv.UnstructuredGrid, property: Property) -> np.ndarray: """Get the data associated with a scalar or vector property from a mesh.""" if property.data_name in mesh.point_data: return mesh.point_data[property.data_name] if property.data_name in mesh.cell_data: return mesh.cell_data[property.data_name] msg = f"Property not found in mesh {mesh}." raise IndexError(msg)
[docs]def get_level_boundaries(levels: np.ndarray): return np.array( [ levels[0] - 0.5 * (levels[1] - levels[0]), *0.5 * (levels[:-1] + levels[1:]), levels[-1] + 0.5 * (levels[-1] - levels[-2]), ] )
[docs]def get_cmap_norm( levels: np.ndarray, property: Property ) -> tuple[mcolors.Colormap, mcolors.Normalize]: """Construct a discrete colormap and norm for the property field.""" vmin, vmax = (levels[0], levels[-1]) bilinear = property.bilinear_cmap and vmin <= 0.0 <= vmax cmap_str = setup.cmap_str(property) if property.is_mask(): conti_cmap = mcolors.ListedColormap(cmap_str) elif isinstance(cmap_str, list): conti_cmap = [colormaps[c] for c in cmap_str] else: conti_cmap = colormaps[cmap_str] if property.data_name == "temperature": cool_colors = conti_cmap[0](np.linspace(0, 0.75, 128 * (vmin < 0))) warm_colors = conti_cmap[1](np.linspace(0, 1, 128 * (vmax >= 0))) conti_cmap = mcolors.LinearSegmentedColormap.from_list( "temperature_cmap", np.vstack((cool_colors, warm_colors)) ) bilinear = vmin < 0 < vmax if bilinear: vmin, vmax = np.max(np.abs([vmin, vmax])) * np.array([-1.0, 1.0]) if property.categoric: vmin += 0.5 vmax += 0.5 conti_norm = mcolors.Normalize(vmin=vmin, vmax=vmax) mid_levels = np.append((levels[:-1] + levels[1:]) * 0.5, levels[-1]) colors = [conti_cmap(conti_norm(m_l)) for m_l in mid_levels] cmap = mcolors.ListedColormap(colors, name="custom") if setup.custom_cmap is not None: cmap = setup.custom_cmap boundaries = get_level_boundaries(levels) if property.categoric else levels norm = mcolors.BoundaryNorm( boundaries=boundaries, ncolors=len(boundaries), clip=True ) return cmap, norm
# to fix scientific offset position # https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334
[docs]def fix_scientific_offset_position(axis, func): axis._update_offset_text_position = types.MethodType(func, axis)
[docs]def y_update_offset_text_position(self, bboxes, bboxes2): # noqa: ARG001 x, y = self.offsetText.get_position() # y in axes coords, x in display coords self.offsetText.set_transform( mtransforms.blended_transform_factory( self.axes.transAxes, mtransforms.IdentityTransform() ) ) top = self.axes.bbox.ymax y = top + 2 * self.OFFSETTEXTPAD * self.figure.dpi / 72.0 self.offsetText.set_position((x, y))
[docs]def add_colorbars( fig: mfigure.Figure, ax: Union[plt.Axes, list[plt.Axes]], property: Property, levels: np.ndarray, pad: float = 0.05, ) -> None: """Add a colorbar to the matplotlib figure.""" cmap, norm = get_cmap_norm(levels, property) cm = mcm.ScalarMappable(norm=norm, cmap=cmap) categoric = property.categoric or (len(levels) == 2) if categoric: bounds = get_level_boundaries(levels) ticks = bounds[:-1] + 0.5 * np.diff(bounds) else: ticks = levels cb = fig.colorbar( cm, norm=norm, ax=ax, ticks=ticks, drawedges=True, location="right", spacing="uniform", pad=pad, format="%.3g" # fmt: skip ) if setup.invert_colorbar: cb.ax.invert_yaxis() if property.is_mask(): cb.ax.add_patch(Rect((0, 0.5), 1, -1, lw=0, fc="none", hatch="/")) if not categoric and setup.log_scaled: levels = 10**levels unit_str = ( f" / {property.get_output_unit()}" if property.get_output_unit() else "" ) cb.set_label( property.output_name.replace("_", " ") + unit_str, size=setup.rcParams_scaled["font.size"], ) cb.ax.tick_params( labelsize=setup.rcParams_scaled["font.size"], direction="out" ) mf = mticker.ScalarFormatter(useMathText=True, useOffset=True) mf.set_scientific(True) mf.set_powerlimits([-3, 3]) fix_scientific_offset_position(cb.ax.yaxis, y_update_offset_text_position) cb.ax.yaxis.set_offset_position("left") cb.ax.yaxis.set_major_formatter(mf) if _q_zero_line(property, levels): cb.ax.axhline( y=0, color="w", lw=2 * setup.rcParams_scaled["lines.linewidth"] ) if setup.log_scaled: cb.ax.set_yticklabels(10**ticks) if property.data_name == "MaterialIDs" and setup.material_names is not None: region_names = [] for mat_id in levels: if mat_id in setup.material_names: region_names += [setup.material_names[mat_id]] else: region_names += [mat_id] cb.ax.set_yticklabels(region_names) cb.ax.set_ylabel("") elif property.categoric: cb.ax.set_yticklabels(levels.astype(int))
[docs]def subplot( mesh: pv.UnstructuredGrid, property: Union[Property, str], ax: plt.Axes, levels: Opt[np.ndarray] = None, ) -> None: """ Plot the property field of a mesh on a matplotlib.axis. In 3D the mesh gets sliced according to slice_type and the origin in the PlotSetup in meshplotlib.setup. Custom levels and a colormap string can be provided. """ if isinstance(property, str): data_shape = mesh[property].shape property = _resolve_property(property, data_shape) if mesh.get_cell(0).dimension == 3: msg = "meshplotlib is for 2D meshes only, but found 3D elements." raise ValueError(msg) ax.axis("auto") if ( not property.is_mask() and property.mask in mesh.cell_data and len(mesh.cell_data[property.mask]) ): subplot(mesh, property.get_mask(), ax) mesh = mesh.ctp(True).threshold(value=[1, 1], scalars=property.mask) surf_tri = mesh.triangulate().extract_surface() # get projection mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0)) projection = int(np.argmax(mean_normal)) x_id, y_id = np.delete([0, 1, 2], projection) # faces contains a padding indicating number of points per face which gets # removed with this reshaping and slicing to get the array of tri's x, y = setup.length.strip_units(surf_tri.points.T[[x_id, y_id]]) tri = surf_tri.faces.reshape((-1, 4))[:, 1:] values = property.magnitude.strip_units(get_data(surf_tri, property)) if setup.log_scaled: values_temp = np.where(values > 1e-14, values, 1e-14) values = np.log10(values_temp) p_min, p_max = np.nanmin(values), np.nanmax(values) if levels is None: num_levels = min(setup.num_levels, len(np.unique(values))) levels = get_levels(p_min, p_max, num_levels) cmap, norm = get_cmap_norm(levels, property) if ( property.data_name in mesh.cell_data and property.data_name not in mesh.point_data ): ax.tripcolor(x, y, tri, facecolors=values, cmap=cmap, norm=norm) if property.is_mask(): ax.tripcolor(x, y, tri, facecolors=values, mask=(values == 1), cmap=cmap, norm=norm, hatch="/") # fmt: skip else: ax.tricontourf(x, y, tri, values, levels=levels, cmap=cmap, norm=norm) if _q_zero_line(property, levels): ax.tricontour(x, y, tri, values, levels=[0], colors="w") surf = mesh.extract_surface() if setup.show_region_bounds and "MaterialIDs" in mesh.cell_data: pf.plot_layer_boundaries(ax, surf, projection) show_edges = setup.show_element_edges if isinstance(setup.show_element_edges, str): show_edges = setup.show_element_edges == property.data_name if show_edges: pf.plot_element_edges(ax, surf, projection) if isinstance(property, Vector): pf.plot_streamlines(ax, surf_tri, property, projection) ax.margins(0, 0) # otherwise it shrinks the plot content if abs(max(mean_normal) - 1) > 1e-6: sec_id = np.argmax(np.delete(mean_normal, projection)) sec_labels = [] for tick in ax.get_xticks(): origin = np.array(mesh.center) origin[sec_id] = min( max(tick, mesh.bounds[2 * sec_id] + 1e-6), mesh.bounds[2 * sec_id + 1] - 1e-6, ) sec_mesh = mesh.slice("xyz"[sec_id], origin) if sec_mesh.n_cells: sec_labels += [f"{sec_mesh.bounds[2 * projection]:.1f}"] else: sec_labels += [""] # TODO: use a function to make this short secax = ax.secondary_xaxis("top") secax.xaxis.set_major_locator(mticker.FixedLocator(ax.get_xticks())) secax.set_xticklabels(sec_labels) secax.set_xlabel(f'{"xyz"[projection]} / {setup.length.output_unit}') x_label = setup.x_label or f'{"xyz"[x_id]} / {setup.length.output_unit}' y_label = setup.y_label or f'{"xyz"[y_id]} / {setup.length.output_unit}' ax.set_xlabel(x_label) ax.set_ylabel(y_label)
def _get_rows_cols( meshes: Union[ list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid, pv.MultiBlock, ] ) -> tuple[int, ...]: if isinstance(meshes, np.ndarray): if meshes.ndim in [1, 2]: return meshes.shape msg = "Input numpy array must be 1D or 2D." raise ValueError(msg) if isinstance(meshes, list): return (1, len(meshes)) if isinstance(meshes, pv.MultiBlock): return (1, meshes.n_blocks) return (1, 1) # TODO: fixed_figure_size -> ax aspect automatic def _fig_init(rows: int, cols: int, ax_aspect: float = 1.0) -> mfigure.Figure: nx_cb = 1 if setup.combined_colorbar else cols default_size = 8 cb_width = 4 y_label_width = 2 x_label_height = 1 figsize = setup.fig_scale * np.asarray( [ default_size * cols * ax_aspect + cb_width * nx_cb + y_label_width, default_size * rows + x_label_height, ] ) if figsize[0] / figsize[1] > setup.fig_aspect_limits[1]: figsize[0] = figsize[1] * setup.fig_aspect_limits[1] elif figsize[0] / figsize[1] < setup.fig_aspect_limits[0]: figsize[0] = figsize[1] * setup.fig_aspect_limits[0] fig, _ = plt.subplots( rows, cols, dpi=setup.dpi * setup.fig_scale, figsize=figsize, layout=setup.layout, sharex=True, sharey=True, ) fig.patch.set_alpha(1) return fig
[docs]def get_combined_levels( meshes: np.ndarray, property: Union[Property, str] ) -> np.ndarray: """ Calculate well spaced levels for the encompassing property range in meshes. """ if isinstance(property, str): data_shape = meshes[0][property].shape property = _resolve_property(property, data_shape) p_min, p_max = np.inf, -np.inf unique_vals = np.array([]) for mesh in np.ravel(meshes): values = property.magnitude.strip_units(get_data(mesh, property)) if setup.log_scaled: # TODO: can be improved values = np.log10(np.where(values > 1e-14, values, 1e-14)) p_min = min(p_min, np.nanmin(values)) if setup.p_min is None else p_min p_max = max(p_max, np.nanmax(values)) if setup.p_max is None else p_max unique_vals = np.unique( np.concatenate((unique_vals, np.unique(values))) ) p_min = setup.p_min if setup.p_min is not None else p_min p_max = setup.p_max if setup.p_max is not None else p_max if p_min == p_max: return np.array([p_min, p_max + 1e-12]) if ( all(val.is_integer() for val in unique_vals) and setup.p_min is None and setup.p_max is None ): return unique_vals[(p_min <= unique_vals) & (unique_vals <= p_max)] return get_levels(p_min, p_max, setup.num_levels)
def _plot_on_figure( fig: mfigure.Figure, meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid], property: Property, ) -> mfigure.Figure: """ Plot the property field of meshes on existing figure. :param meshes: Singular mesh of 2D numpy array of meshes :param property: the property field to be visualized on all meshes """ shape = _get_rows_cols(meshes) np_meshes = np.reshape(meshes, shape) np_axs = np.reshape(fig.axes, shape) if setup.combined_colorbar: combined_levels = get_combined_levels(np_meshes, property) for i in range(shape[0]): for j in range(shape[1]): _levels = ( combined_levels if setup.combined_colorbar else get_combined_levels(np_meshes[i, j], property) ) subplot(np_meshes[i, j], property, np_axs[i, j], _levels) np_axs[0, 0].set_title(setup.title_center, loc="center", y=1.02) np_axs[0, 0].set_title(setup.title_left, loc="left", y=1.02) np_axs[0, 0].set_title(setup.title_right, loc="right", y=1.02) # make extra space for the upper limit of the colorbar if setup.layout == "tight": plt.tight_layout(pad=1.4) if setup.combined_colorbar: cb_axs = np.ravel(fig.axes).tolist() add_colorbars( fig, cb_axs, property, combined_levels, pad=0.05 / shape[1] ) else: for i in range(shape[0]): for j in range(shape[1]): _levels = get_combined_levels(np_meshes[i, j], property) add_colorbars(fig, np_axs[i, j], property, _levels) return fig
[docs]def get_data_aspect(mesh: pv.DataSet) -> float: """ Calculate the data aspect ratio of a 2D mesh. """ mean_normal = np.abs(np.mean(mesh.extract_surface().cell_normals, axis=0)) projection = int(np.argmax(mean_normal)) x_id, y_id = 2 * np.delete([0, 1, 2], projection) lims = mesh.bounds return abs(lims[x_id + 1] - lims[x_id]) / abs(lims[y_id + 1] - lims[y_id])
# TODO: add as arguments: cmap, limits # TODO: num_levels should be min_levels
[docs]def plot( meshes: Union[list[pv.UnstructuredGrid], np.ndarray, pv.UnstructuredGrid], property: Union[Property, str], ) -> mfigure.Figure: """ Plot the property field of meshes with default settings. The resulting figure adheres to the configurations in meshplotlib.setup. For 2D, the whole domain, for 3D a set of slices is displayed. :param meshes: Singular mesh of 2D numpy array of meshes :param property: The property field to be visualized on all meshes """ rcParams.update(setup.rcParams_scaled) shape = _get_rows_cols(meshes) _meshes = np.reshape(meshes, shape).flatten() if isinstance(property, str): data_shape = _meshes[0][property].shape property = _resolve_property(property, data_shape) data_aspects = np.asarray([get_data_aspect(mesh) for mesh in _meshes]) data_aspects[ (data_aspects > setup.ax_aspect_limits[0]) & (data_aspects < setup.ax_aspect_limits[1]) ] = 1.0 clamped_aspects = np.clip(data_aspects, *setup.ax_aspect_limits) ax_aspects = data_aspects / clamped_aspects _fig = _fig_init( rows=shape[0], cols=shape[1], ax_aspect=np.mean(ax_aspects) ) n_axs = shape[0] * shape[1] # setting the aspect twice is intended # the first time results in properly spaced ticks # the second time fixes any deviations from the set aspect due to # additional colorbar(s) or secondary axes. for ax, aspect in zip(_fig.axes[: n_axs + 1], clamped_aspects): ax.set_aspect(aspect) fig = _plot_on_figure(_fig, meshes, property) for ax, aspect in zip(fig.axes[: n_axs + 1], clamped_aspects): ax.set_aspect(aspect) return fig