# Copyright (c) 2012-2024, OpenGeoSys Community (http://www.opengeosys.org)
# Distributed under a Modified BSD License.
# See accompanying file LICENSE.txt or
# http://www.opengeosys.org/project/license
#
from __future__ import annotations
import warnings
from collections.abc import Callable, Iterator, Sequence
from copy import copy as shallowcopy
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast, overload
import meshio
import numpy as np
import pyvista as pv
from lxml import etree as ET
from matplotlib import pyplot as plt
from scipy.interpolate import (
LinearNDInterpolator,
NearestNDInterpolator,
interp1d,
)
from tqdm import tqdm
from typeguard import typechecked
from ogstools import plot
from ogstools.variables import Variable, normalize_vars, u_reg
from .data_dict import DataDict
from .mesh import Mesh
from .xdmf_reader import XDMFReader
class MeshSeries(Sequence[Mesh]):
"""
A wrapper around pyvista and meshio for reading PVD and XDMF timeseries.
"""
def __init__(self, filepath: str | Path | None = None) -> None:
"""
Initialize a MeshSeries object
:param filepath: Path to the PVD or XDMF file.
:returns: A MeshSeries object
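
Example (a minimal sketch; ``results.pvd`` is a hypothetical file):

>>> ms = MeshSeries("results.pvd")
>>> last_mesh = ms[-1]  # Mesh at the final timestep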
"""
self._spatial_factor = 1.0
self._time_factor = 1.0
self._epsilon = 1.0e-6
self._mesh_cache: dict[float, Mesh] = {}
self._mesh_func_opt: Callable[[Mesh], Mesh] | None = None
# list of slices to be able to have nested slices with xdmf
# (but only the first slice will be efficient)
self._time_indices: list[slice | Any] = [slice(None)]
self._timevalues: np.ndarray
"original data timevalues - do not change."
if filepath is None:
self.filepath = self._data_type = None
return
self.filepath = Path(filepath)
self._data_type = self.filepath.suffix
match self._data_type:
case ".pvd":
self._pvd_reader = pv.PVDReader(self.filepath)
self.timestep_files = [
str(self.filepath.parent / dataset.path)
for dataset in self._pvd_reader.datasets
]
self._timevalues = np.asarray(self._pvd_reader.time_values)
case ".xdmf" | ".xmf":
self._data_type = ".xdmf"
self._xdmf_reader = XDMFReader(self.filepath)
self._timevalues = np.asarray(
[
float(element.attrib["Value"])
for collection_i in self._xdmf_reader.collection
for element in collection_i
if element.tag == "Time"
]
)
case ".vtu":
self._vtu_reader = pv.XMLUnstructuredGridReader(self.filepath)
self._timevalues = np.zeros(1)
case suffix:
msg = (
"Can only read 'pvd', 'xdmf', 'xmf'(from Paraview) or "
f"'vtu' files, but not '{suffix}'"
)
raise TypeError(msg)
@classmethod
def from_data(
cls, meshes: Sequence[pv.DataSet], timevalues: np.ndarray
) -> MeshSeries:
"Create a MeshSeries from a list of meshes and timevalues."
new_ms = cls()
new_ms._timevalues = deepcopy(timevalues) # pylint: disable=W0212
new_ms._mesh_cache.update(
dict(zip(new_ms._timevalues, deepcopy(meshes), strict=True))
)
return new_ms
def extend(self, mesh_series: MeshSeries) -> None:
"""
Extends self with mesh_series.
If the first timevalue of mesh_series is within epsilon of the last
timevalue of self, the duplicate timestep is removed before extending.
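
Example (a sketch; ``run1.pvd`` and ``run2.pvd`` are hypothetical
results of two consecutive simulation runs):

>>> ms = MeshSeries("run1.pvd")
>>> ms.extend(MeshSeries("run2.pvd"))
>>> ms.timevalues  # continuous time axis without the duplicate step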
"""
ms1_list = list(self)
ms2_list = list(mesh_series)
ms1_timevalues = self.timevalues
ms2_timevalues = mesh_series.timevalues
delta = ms2_timevalues[0] - ms1_timevalues[-1]
offset = 0.0 if delta >= 0 else ms1_timevalues[-1]
if ((delta < 0) and (ms2_timevalues[0] == 0.0)) or (
np.abs(delta) < self._epsilon
):
ms1_timevalues = ms1_timevalues[:-1]
ms1_list = ms1_list[:-1]
ms2_timevalues = ms2_timevalues + offset
new_timevalues = np.append(ms1_timevalues, ms2_timevalues, axis=0)
self._timevalues = new_timevalues
self._mesh_cache.update(
dict(zip(new_timevalues, ms1_list + ms2_list, strict=True))
)
@classmethod
def resample(
cls, original: MeshSeries, timevalues: np.ndarray
) -> MeshSeries:
"Return a new MeshSeries interpolated to the given timevalues."
interp_meshes = [original.read_interp(tv) for tv in timevalues]
return cls.from_data(interp_meshes, timevalues)
def __deepcopy__(self, memo: dict) -> MeshSeries:
# Deep copy is the default when using self.copy()
# For shallow copy: self.copy(deep=False)
cls = self.__class__
self_copy = cls.__new__(cls)
memo[id(self)] = self_copy
for key, value in self.__dict__.items():
if key != "_pvd_reader" and key != "_xdmf_reader":
if isinstance(value, pv.UnstructuredGrid):
# For PyVista objects use their own copy method
setattr(self_copy, key, value.copy(deep=True))
else:
# For everything that is neither reader nor PyVista object
# use the deepcopy
setattr(self_copy, key, deepcopy(value, memo))
else:
# A shallow copy of the readers is needed, because timesteps are
# stored in the reader and deepcopy does not work for _pvd_reader
# and _xdmf_reader
setattr(self_copy, key, shallowcopy(value))
return self_copy
def copy(self, deep: bool = True) -> MeshSeries:
"""
Create a copy of MeshSeries object.
Deep copy is the default.
:param deep: switch to choose between deep (default) and shallow
(self.copy(deep=False)) copy.
:returns: Copy of self.
"""
return deepcopy(self) if deep else shallowcopy(self)
@overload
def __getitem__(self, index: int) -> Mesh: ...
@overload
def __getitem__(self, index: slice | Sequence) -> MeshSeries: ...
@overload
def __getitem__(self, index: str | Variable) -> np.ndarray: ...
def __getitem__(self, index: Any) -> Any:
if isinstance(index, int):
return self.mesh(index)
if isinstance(index, str):
return self.values(index)
if isinstance(index, slice | Sequence):
ms_copy = self.copy(deep=False)
if ms_copy._time_indices == [slice(None)]:
ms_copy._time_indices = [index]
else:
ms_copy._time_indices += [index]
return ms_copy
msg = f"Unsupported index type: {type(index)}"
raise ValueError(msg)
def __len__(self) -> int:
return len(self.timesteps)
def __iter__(self) -> Iterator[Mesh]:
for i in np.arange(len(self.timevalues), dtype=int):
yield self.mesh(i)
def __str__(self) -> str:
if self._data_type == ".vtu":
reader = self._vtu_reader
elif self._data_type == ".pvd":
reader = self._pvd_reader
elif self._data_type == ".xdmf":
reader = self._xdmf_reader
else:
reader = "None"
return (
f"MeshSeries:\n"
f"filepath: {self.filepath}\n"
f"data_type: {self._data_type}\n"
f"timevalues: {self.timevalues[0]} to {self.timevalues[-1]} in {len(self.timevalues)} steps\n"
f"reader: {reader}\n"
f"rawdata_file: {self.rawdata_file()}\n"
)
# deliberately typing as Sequence and not as zip because typing as zip
# leads to a weird cross-referencing error from sphinx side with no easy
# apparent fix
def items(self) -> Sequence[tuple[float, Mesh]]:
"Returns zipped tuples of timevalues and meshes."
return zip(self.timevalues, self, strict=True) # type: ignore[return-value]
def aggregate_over_time(
self, variable: Variable | str, func: Callable
) -> Mesh:
"""Aggregate data over all timesteps using a specified function.
:param variable: The mesh variable to be aggregated.
:param func: The aggregation function to apply. E.g. np.min,
np.max, np.mean, np.median, np.sum, np.std, np.var
:returns: A mesh with aggregated data according to the given function.
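
Example (a sketch; ``temperature`` is a hypothetical point data array):

>>> import numpy as np
>>> ms = MeshSeries("results.pvd")
>>> max_mesh = ms.aggregate_over_time("temperature", np.max)
>>> result = max_mesh[f"temperature_{np.max.__name__}"]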
"""
# TODO: add function to create an empty mesh from a given one
# custom field_data may need to be preserved
mesh = self.mesh(0).copy(deep=True)
mesh.clear_point_data()
mesh.clear_cell_data()
if isinstance(variable, Variable):
output_name = f"{variable.output_name}_{func.__name__}"
else:
output_name = f"{variable}_{func.__name__}"
mesh[output_name] = func(self.values(variable), axis=0)
return mesh
def clear_cache(self) -> None:
self._mesh_cache.clear()
def closest_timestep(self, timevalue: float) -> int:
"""Return the corresponding timestep from a timevalue."""
return int(np.argmin(np.abs(self.timevalues - timevalue)))
def closest_timevalue(self, timevalue: float) -> float:
"""Return the closest timevalue to a timevalue."""
return self.timevalues[self.closest_timestep(timevalue)]
def ip_tesselated(self) -> MeshSeries:
"Create a new MeshSeries from integration point tessellation."
ip_mesh = self.mesh(0).to_ip_mesh()
ip_pt_cloud = self.mesh(0).to_ip_point_cloud()
ordering = ip_mesh.find_containing_cell(ip_pt_cloud.points)
ip_meshes = []
for i in np.arange(len(self.timevalues), dtype=int):
ip_data = {
key: self.mesh(i).field_data[key][np.argsort(ordering)]
for key in ip_mesh.cell_data
}
ip_mesh.cell_data.update(ip_data)
ip_meshes += [ip_mesh.copy()]
return MeshSeries.from_data(ip_meshes, self.timevalues)
def mesh(self, timestep: int, lazy_eval: bool = True) -> Mesh:
"""Returns the mesh at the given timestep."""
timevalue = self.timevalues[timestep]
if not np.any(self.timevalues == timevalue):
msg = f"Value {timevalue} not found in the array."
raise ValueError(msg)
data_timestep = np.argmax(
self._timevalues * self._time_factor == timevalue
)
if timevalue in self._mesh_cache:
mesh = self._mesh_cache[timevalue]
else:
match self._data_type:
case ".pvd":
pv_mesh = self._read_pvd(data_timestep)
case ".xdmf":
pv_mesh = self._read_xdmf(data_timestep)
case ".vtu":
pv_mesh = self._vtu_reader.read()
case _:
msg = f"Unexpected datatype {self._data_type}."
raise TypeError(msg)
mesh = Mesh(self.mesh_func(pv_mesh))
if lazy_eval:
self._mesh_cache[timevalue] = mesh
if self._data_type == ".pvd":
mesh.filepath = Path(self.timestep_files[data_timestep])
else:
mesh.filepath = self.filepath
return mesh
def rawdata_file(self) -> Path | None:
"""
Checks if working with the raw data is possible. For example,
OGS simulation results in XDMF format support efficient raw data access
via `h5py <https://docs.h5py.org/en/stable/quick.html#quick>`_.
:returns: The location of the file containing the raw data. If the data
does not support efficient reads (e.g., no efficient slicing), returns None.
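
Example (a sketch; efficient access requires XDMF results backed by a
single HDF5 file):

>>> ms = MeshSeries("results.xdmf")
>>> if ms.rawdata_file() is not None:
...     print("efficient raw data access via h5py is possible")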
"""
if self._data_type == ".xdmf" and self._xdmf_reader.has_fast_access():
return self._xdmf_reader.rawdata_path() # single h5 file
return None
def read_interp(self, timevalue: float, lazy_eval: bool = True) -> Mesh:
"""Return the temporal interpolated mesh for a given timevalue."""
t_vals = self.timevalues
ts1 = int(t_vals.searchsorted(timevalue, "right") - 1)
ts2 = min(ts1 + 1, len(t_vals) - 1)
if np.isclose(timevalue, t_vals[ts1]):
return self.mesh(ts1, lazy_eval)
mesh1 = self.mesh(ts1, lazy_eval)
mesh2 = self.mesh(ts2, lazy_eval)
mesh = mesh1.copy(deep=True)
for key in mesh1.point_data:
if np.all(mesh1.point_data[key] == mesh2.point_data[key]):
continue
dt = t_vals[ts2] - t_vals[ts1]
slope = (mesh2.point_data[key] - mesh1.point_data[key]) / dt
mesh.point_data[key] = mesh1.point_data[key] + slope * (
timevalue - t_vals[ts1]
)
return mesh
@property
def timevalues(self) -> np.ndarray:
"Return the timevalues."
vals = self._timevalues
for index in self._time_indices:
vals = vals[index]
return vals * self._time_factor
@property
def timesteps(self) -> np.ndarray:
"""
Return the timestep indices of the OGS simulation timeseries data.
Not to be confused with timevalues, which returns the actual times,
usually given in time units.
"""
# TODO: read time steps from fn string if available
return np.arange(len(self.timevalues), dtype=int)
def _xdmf_values(self, variable_name: str) -> np.ndarray:
dataitems = self._xdmf_reader.data_items[variable_name]
# pyvista filters produce these arrays, which we can use for slicing
# so that the result reflects a transformation applied via self.transform
mask_map = {
"vtkOriginalPointIds": self.mesh(0).point_data,
"vtkOriginalCellIds": self.mesh(0).cell_data,
}
for mask, data in mask_map.items():
if variable_name in data and mask in data:
result = dataitems[self._time_indices[0], self.mesh(0)[mask]]
break
else:
result = dataitems[self._time_indices[0]]
for index in self._time_indices[1:]:
result = result[index]
if self._mesh_func_opt is not None and not any(
mask in data for mask, data in mask_map.items()
):
# if the transform function doesn't produce the mask arrays, we have
# to map the xdmf data onto the entire list of meshes and apply the
# function on each mesh individually.
ms_copy = self.copy(deep=True)
ms_copy._mesh_func_opt = None # pylint: disable=protected-access
ms_copy.clear_cache()
raw_meshes = list(ms_copy)
for mesh, data in zip(raw_meshes, result, strict=True):
mesh[variable_name] = data
meshes = list(map(self.mesh_func, raw_meshes))
result = np.asarray([mesh[variable_name] for mesh in meshes])
return result
def values(self, variable: str | Variable) -> np.ndarray:
"""
Get the data in the MeshSeries for all timesteps.
Adheres to time slicing via `__getitem__` and to a pyvista filter
applied via `transform`, if that filter produced 'vtkOriginalPointIds'
or 'vtkOriginalCellIds' (e.g. `clip(..., crinkle=True)`,
`extract_cells(...)`, `threshold(...)`).
:param variable: Variable to read/process from the MeshSeries.
:returns: A numpy array of shape (n_timesteps, n_points/n_cells).
If an argument of type Variable is given, its transform
function is applied on the data.
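
Example (a sketch; ``temperature`` is a hypothetical array name):

>>> ms = MeshSeries("results.pvd")
>>> ms.values("temperature").shape  # (n_timesteps, n_points)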
"""
if isinstance(variable, Variable):
if variable.mesh_dependent:
return np.asarray([variable.transform(mesh) for mesh in self])
variable_name = variable.data_name
else:
variable_name = variable
all_cached = self._is_all_cached
if (
self._data_type == ".xdmf"
and variable_name in self._xdmf_reader.data_items
and not all_cached
):
result = self._xdmf_values(variable_name)
else:
result = np.asarray(
[mesh[variable_name] for mesh in tqdm(self, disable=all_cached)]
)
if isinstance(variable, Variable):
return variable.transform(result)
return result
@property
def _is_all_cached(self) -> bool:
"Check if all meshes in this meshseries are cached"
return np.isin(
self.timevalues, np.fromiter(self._mesh_cache.keys(), float)
).all()
@property
def point_data(self) -> DataDict:
"Useful for reading or setting point_data for the entire meshseries."
return DataDict(self, lambda mesh: mesh.point_data)
@property
def cell_data(self) -> DataDict:
"Useful for reading or setting cell_data for the entire meshseries."
return DataDict(self, lambda mesh: mesh.cell_data)
@property
def field_data(self) -> DataDict:
"Useful for reading or setting field_data for the entire meshseries."
return DataDict(self, lambda mesh: mesh.field_data)
def _read_pvd(self, timestep: int) -> pv.UnstructuredGrid:
self._pvd_reader.set_active_time_point(timestep)
return self._pvd_reader.read()[0]
def _read_xdmf(self, timestep: int) -> pv.UnstructuredGrid:
points, cells = self._xdmf_reader.read_points_cells()
_, point_data, cell_data, field_data = self._xdmf_reader.read_data(
timestep
)
meshio_mesh = meshio.Mesh(
points, cells, point_data, cell_data, field_data
)
# pv.from_meshio does not copy field_data (fix in pyvista?)
pv_mesh = pv.from_meshio(meshio_mesh)
pv_mesh.field_data.update(field_data)
return pv_mesh
def _time_of_extremum(
self,
variable: Variable | str,
np_func: Callable,
prefix: Literal["min", "max"],
) -> Mesh:
"""Returns a Mesh with the time of a given variable extremum as data.
The data is named as `f'{prefix}_{variable.output_name}_time'`."""
mesh = self.mesh(0).copy(deep=True)
variable = Variable.find(variable, mesh)
mesh.clear_point_data()
mesh.clear_cell_data()
output_name = f"{prefix}_{variable.output_name}_time"
mesh[output_name] = self.timevalues[
np_func(self.values(variable), axis=0)
]
return mesh
def time_of_min(self, variable: Variable | str) -> Mesh:
"Returns a Mesh with the time of the variable minimum as data."
return self._time_of_extremum(variable, np.argmin, "min")
def time_of_max(self, variable: Variable | str) -> Mesh:
"Returns a Mesh with the time of the variable maximum as data."
return self._time_of_extremum(variable, np.argmax, "max")
def aggregate_over_domain(
self, variable: Variable | str, func: Callable
) -> np.ndarray:
"""Aggregate data over domain per timestep using a specified function.
:param variable: The mesh variable to be aggregated.
:param func: The aggregation function to apply. E.g. np.min,
np.max, np.mean, np.median, np.sum, np.std, np.var
:returns: A numpy array with aggregated data.
"""
return func(self.values(variable), axis=1)
def plot_domain_aggregate(
self,
variable: Variable | str,
func: Callable,
ax: plt.Axes | None = None,
**kwargs: Any,
) -> plt.Figure | None:
"""
Plot the transient aggregated data over the domain per timestep.
:param variable: The mesh variable to be aggregated.
:param func: The aggregation function to apply. E.g. np.min,
np.max, np.mean, np.median, np.sum, np.std, np.var
:param ax: matplotlib axis to use for plotting
:returns: A matplotlib Figure or None if plotting on existing axis.
Keyword arguments get passed to `matplotlib.pyplot.plot`
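
Example (a sketch; ``temperature`` is a hypothetical array name):

>>> import numpy as np
>>> ms = MeshSeries("results.pvd")
>>> fig = ms.plot_domain_aggregate("temperature", np.mean)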
"""
variable = Variable.find(variable, self.mesh(0))
values = self.aggregate_over_domain(variable.magnitude, func)
x_values = self.timevalues
x_label = f"time t / {plot.setup.time_unit}"
if ax is None:
fig, ax = plt.subplots()
else:
fig = None
if "label" in kwargs:
label = kwargs.pop("label")
ylabel = variable.get_label() + " " + func.__name__
else:
label = func.__name__
ylabel = variable.get_label()
ax.plot(x_values, values, label=label, **kwargs)
ax.set_axisbelow(True)
ax.grid(which="major", color="lightgrey", linestyle="-")
ax.grid(which="minor", color="0.95", linestyle="--")
ax.set_xlabel(x_label)
ax.set_ylabel(ylabel)
ax.legend()
ax.label_outer()
ax.minorticks_on()
return fig
def probe(
self,
points: np.ndarray,
data_name: str,
interp_method: Literal["nearest", "linear"] = "linear",
) -> np.ndarray:
"""
Probe the MeshSeries at observation points.
:param points: The observation points to sample at.
:param data_name: Name of the data to sample.
:param interp_method: Interpolation method, defaults to `linear`
:returns: `numpy` array of interpolated data at observation points
with the following shape:
- multiple points: (n_timesteps, n_points, [n_components])
- single point: (n_timesteps, [n_components])
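
Example (a sketch; coordinates and array name are made up):

>>> import numpy as np
>>> ms = MeshSeries("results.pvd")
>>> pts = np.asarray([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
>>> ms.probe(pts, "temperature").shape  # (n_timesteps, 2)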
"""
pts = np.asarray(points).reshape((-1, 3))
values = np.swapaxes(self.values(data_name), 0, 1)
geom = self.mesh(0).points
if values.shape[0] != geom.shape[0]:
# assume cell_data
geom = self.mesh(0).cell_centers().points
# remove flat dimensions for interpolation
flat_axis = np.argwhere(np.all(np.isclose(geom, geom[0]), axis=0))
geom = np.delete(geom, flat_axis, 1)
pts = np.delete(pts, flat_axis, 1)
dim = int(np.max([cell.dimension for cell in self.mesh(0).cell]))
match dim > 1, interp_method:
case True, "nearest":
result = np.swapaxes(
NearestNDInterpolator(geom, values)(pts), 0, 1
)
case True, "linear":
result = np.swapaxes(
LinearNDInterpolator(geom, values, np.nan)(pts), 0, 1
)
case False, kind:
result = interp1d(geom[:, 0], values.T, kind=kind)(
np.squeeze(pts, 1)
)
case _, _:
msg = (
"No interpolation method implemented for mesh of "
f"{dim=} and {interp_method=}"
)
raise TypeError(msg)
if np.shape(points)[0] != 1 and np.shape(result)[1] == 1:
result = np.squeeze(result, axis=1)
return result
def plot_probe(
self,
points: np.ndarray,
variable: Variable | str,
variable_abscissa: Variable | str | None = None,
labels: list[str] | None = None,
interp_method: Literal["nearest", "linear"] = "linear",
colors: list | None = None,
linestyles: list | None = None,
ax: plt.Axes | None = None,
fill_between: bool = False,
**kwargs: Any,
) -> plt.Figure | None:
"""
Plot the transient variable on the observation points in the MeshSeries.
:param points: The points to sample at.
:param variable: The variable to be sampled.
:param variable_abscissa: Variable to use on the x-axis instead of time.
:param labels: The labels for each observation point.
:param interp_method: Interpolation method, defaults to `linear`.
:param colors: Optional list of colors for the observation points.
:param linestyles: Optional list of linestyles for the observation points.
:param ax: matplotlib axis to use for plotting.
:param fill_between: If True, fill the area between the minimum and
maximum of the sampled values instead of plotting individual lines.
:returns: A matplotlib Figure or None if plotting on an existing axis.
Keyword arguments get passed to `matplotlib.pyplot.plot`
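
Example (a sketch; point coordinates and array name are made up):

>>> ms = MeshSeries("results.pvd")
>>> fig = ms.plot_probe([0.0, 0.0, 0.0], "temperature")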
"""
points = np.asarray(points).reshape((-1, 3))
variable = Variable.find(variable, self.mesh(0))
values = variable.magnitude.transform(
self.probe(points, variable.data_name, interp_method)
)
if values.shape[0] == 1:
values = values.ravel()
if variable_abscissa is None:
x_values = self.timevalues
x_label = f"time / {plot.setup.time_unit}"
else:
variable_abscissa = Variable.find(variable_abscissa, self.mesh(0))
x_values = variable_abscissa.magnitude.transform(
self.probe(points, variable_abscissa.data_name, interp_method)
)
x_unit_str = (
f" / {variable_abscissa.get_output_unit}"
if variable_abscissa.get_output_unit
else ""
)
x_label = (
variable_abscissa.output_name.replace("_", " ") + x_unit_str
)
if ax is None:
fig, ax = plt.subplots()
else:
fig = None
if points.shape[0] > 1:
ax.set_prop_cycle(
plot.utils.get_style_cycler(len(points), colors, linestyles)
)
if fill_between:
ax.fill_between(
x_values,
np.min(values, axis=-1),
np.max(values, axis=-1),
label=labels,
**kwargs,
)
else:
ax.plot(x_values, values, label=labels, **kwargs)
if labels is not None:
ax.legend(
facecolor="white", framealpha=1, prop={"family": "monospace"}
)
ax.set_axisbelow(True)
ax.grid(which="major", color="lightgrey", linestyle="-")
ax.grid(which="minor", color="0.95", linestyle="--")
ax.set_xlabel(x_label)
ax.set_ylabel(variable.magnitude.get_label(plot.setup.label_split))
ax.label_outer()
ax.minorticks_on()
return fig
def plot_time_slice(
self,
x: Literal["x", "y", "z", "time"],
y: Literal["x", "y", "z", "time"],
variable: str | Variable,
time_logscale: bool = False,
fig: plt.Figure | None = None,
ax: plt.Axes | None = None,
cbar: bool = True,
**kwargs: Any,
) -> plt.Figure | None:
"""
Create a heatmap for a variable over time and space.
:param x: What to display on the x-axis (x, y, z or time)
:param y: What to display on the y-axis (x, y, z or time)
:param variable: The variable to be visualized.
:param time_logscale: Should log-scaling be applied to the time-axis?
:param fig: matplotlib figure to use for plotting.
:param ax: matplotlib axis to use for plotting.
:param cbar: If True, adds a colorbar.
Keyword Arguments:
- cb_labelsize: colorbar labelsize
- cb_loc: colorbar location ('left' or 'right')
- cb_pad: colorbar padding
- cmap: colormap
- vmin: minimum value for colorbar
- vmax: maximum value for colorbar
- num_levels: number of levels for colorbar
- figsize: figure size
- dpi: resolution
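
Example (a sketch; assumes a mesh resolved along the x-axis and a
hypothetical ``temperature`` array):

>>> ms = MeshSeries("results.pvd")
>>> fig = ms.plot_time_slice("time", "x", "temperature")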
"""
if ax is None and fig is None:
fig, ax = plt.subplots(
figsize=kwargs.get("figsize", [18, 14]),
dpi=kwargs.get("dpi", 100),
)
optional_return_figure = fig
elif ax is None or fig is None:
msg = "Please provide fig and ax together or not at all."
raise ValueError(msg)
else:
optional_return_figure = None
if "time" not in [x, y]:
msg = "One of x_var and y_var has to be 'time'."
raise KeyError(msg)
if x not in "xyz" and y not in "xyz":
msg = "One of x_var and y_var has to be a spatial coordinate."
raise KeyError(msg)
var_z = Variable.find(variable, self.mesh(0))
var_x, var_y = normalize_vars(x, y, self.mesh(0))
if time_logscale:
def log10time(vals: np.ndarray) -> np.ndarray:
log10vals = np.log10(vals, where=vals != 0)
if log10vals[0] == 0:
log10vals[0] = log10vals[1] - (log10vals[2] - log10vals[1])
return log10vals
time_var = var_x if var_x.data_name == "time" else var_y
get_time = time_var.func
time_var.func = lambda ms, _: log10time(get_time(ms, _))
time_var.output_name = f"log10 {time_var.output_name}"
x_vals = var_x.transform(self)
y_vals = var_y.transform(self)
values = self.values(var_z)
if values.shape == (len(x_vals), len(y_vals)):
values = values.T
if "levels" in kwargs:
levels = np.asarray(kwargs.pop("levels"))
else:
vmin, vmax = (plot.setup.vmin, plot.setup.vmax)
levels = plot.levels.compute_levels(
kwargs.get("vmin", np.nanmin(values) if vmin is None else vmin),
kwargs.get("vmax", np.nanmax(values) if vmax is None else vmax),
kwargs.get("num_levels", plot.setup.num_levels),
)
cmap, norm = plot.utils.get_cmap_norm(levels, var_z, **kwargs)
ax.pcolormesh(x_vals, y_vals, values, cmap=cmap, norm=norm)
fontsize = kwargs.get("fontsize", plot.setup.fontsize)
plot.utils.label_ax(fig, ax, var_x, var_y, fontsize)
ax.tick_params(axis="both", labelsize=fontsize, length=fontsize * 0.5)
if cbar:
plot.contourplots.add_colorbars(fig, ax, var_z, levels, **kwargs)
plot.utils.update_font_sizes(fig.axes, fontsize)
return optional_return_figure
@property
def mesh_func(self) -> Callable[[Mesh], Mesh]:
"""Returns stored transformation function or identity if not given."""
if self._mesh_func_opt is None:
return lambda mesh: mesh.scale(self._spatial_factor)
return lambda mesh: Mesh(
self._mesh_func_opt(mesh).scale( # type: ignore[misc]
self._spatial_factor
)
)
def scale(
self,
spatial: float | tuple[str, str] = 1.0,
time: float | tuple[str, str] = 1.0,
) -> MeshSeries:
"""Scale the spatial coordinates and timevalues.
Useful to convert to other units, e.g. "m" to "km" or "s" to "a".
If given as tuple of strings, the latter units will also be set in
ot.plot.setup.spatial_unit and ot.plot.setup.time_unit for plotting.
:param spatial: Float factor or a tuple of str (from_unit, to_unit).
:param time: Float factor or a tuple of str (from_unit, to_unit).
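
Example (converting lengths from "m" to "km" and times from "s" to "a";
``results.pvd`` is a hypothetical file):

>>> ms = MeshSeries("results.pvd")
>>> ms_scaled = ms.scale(spatial=("m", "km"), time=("s", "a"))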
"""
Qty = u_reg.Quantity
if isinstance(spatial, float):
spatial_factor = spatial
else:
spatial_factor = Qty(Qty(spatial[0]), spatial[1]).magnitude
plot.setup.spatial_unit = spatial[1]
if isinstance(time, float):
time_factor = time
else:
time_factor = Qty(Qty(time[0]), time[1]).magnitude
plot.setup.time_unit = time[1]
new_ms = self.copy()
new_ms._spatial_factor *= spatial_factor
new_ms._time_factor *= time_factor
scaled_cache = {
timevalue * time_factor: Mesh(mesh.scale(spatial_factor))
for timevalue, mesh in new_ms._mesh_cache.items()
}
new_ms.clear_cache()
new_ms._mesh_cache = scaled_cache
return new_ms
def _rename_vtufiles(self, new_pvd_fn: Path, fns: list[Path]) -> list:
fns_new: list[Path] = []
assert self.filepath is not None
for filename in fns:
filepathparts_at_timestep = list(filename.parts)
filepathparts_at_timestep[-1] = filepathparts_at_timestep[
-1
].replace(
self.filepath.name.split(".")[0],
new_pvd_fn.name.split(".")[0],
)
fns_new.append(Path(*filepathparts_at_timestep))
return fns_new
def _save_vtu(
self, new_pvd_fn: Path, fns: list[Path], ascii: bool = False
) -> None:
for i, timestep in enumerate(self.timesteps):
if ".vtu" in fns[i].name:
pv.save_meshio(
Path(new_pvd_fn.parent, fns[i].name),
self.mesh(i),
binary=not ascii,
)
elif ".xdmf" in fns[i].name:
newname = fns[i].name.replace(
".xdmf", f"_ts_{timestep}_t_{self.timevalues[i]}.vtu"
)
pv.save_meshio(Path(new_pvd_fn.parent, newname), self.mesh(i))
else:
s = "File type not supported."
raise RuntimeError(s)
def _save_pvd(self, new_pvd_fn: Path, fns: list[Path]) -> None:
root = ET.Element("VTKFile")
root.attrib["type"] = "Collection"
root.attrib["version"] = "0.1"
root.attrib["byte_order"] = "LittleEndian"
root.attrib["compressor"] = "vtkZLibDataCompressor"
collection = ET.SubElement(root, "Collection")
for i, timestep in enumerate(self.timevalues):
timestepselement = ET.SubElement(collection, "DataSet")
timestepselement.attrib["timestep"] = str(timestep)
timestepselement.attrib["group"] = ""
timestepselement.attrib["part"] = "0"
if ".xdmf" in fns[i].name:
newname = fns[i].name.replace(
".xdmf", f"_ts_{self.timesteps[i]}_t_{timestep}.vtu"
)
timestepselement.attrib["file"] = newname
elif ".vtu" in fns[i].name:
timestepselement.attrib["file"] = fns[i].name
else:
s = "File type not supported."
raise RuntimeError(s)
tree = ET.ElementTree(root)
tree.write(
new_pvd_fn,
encoding="ISO-8859-1",
xml_declaration=True,
pretty_print=True,
)
def _check_path(self, filename: Path | None) -> Path:
if not isinstance(filename, Path):
s = "filename is empty"
raise RuntimeError(s)
return cast(Path, filename)
def save(
self, filename: str, deep: bool = True, ascii: bool = False
) -> None:
"""
Save mesh series to disk.
:param filename: Filename to save the series to. Extension specifies
the file type. Currently only PVD is supported.
:param deep: Specifies whether VTU/H5 files should be written.
:param ascii: Specifies if ascii or binary format should be used,
defaults to binary (False) - True for ascii.
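
Example (a sketch; writes ``out.pvd`` plus one VTU file per timestep):

>>> ms = MeshSeries("results.pvd")  # hypothetical file
>>> ms.save("out.pvd")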
"""
fn = Path(filename)
fns = [
self._check_path(self.mesh(t).filepath)
for t in np.arange(len(self.timevalues), dtype=int)
]
if ".pvd" in fn.name:
if deep:
fns = self._rename_vtufiles(fn, fns)
self._save_vtu(fn, fns, ascii=ascii)
self._save_pvd(fn, fns)
else:
s = "Currently the save method is implemented for PVD/VTU only."
raise RuntimeError(s)
def remove_array(
self, name: str, data_type: str = "field", skip_last: bool = False
) -> None:
"""
Removes an array from all time slices of the mesh series.
:param name: Array name
:param data_type: Data type of the array. Could be either
field, cell or point
:param skip_last: Skips the last time slice (e.g. for restart purposes).
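
Deprecated; prefer the replacement from the warning below, e.g. with a
hypothetical field data array named ``"my_array"``:

>>> ms = MeshSeries("results.pvd")
>>> del ms.field_data["my_array"]  # removes it from all timesteps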
"""
depr_msg = (
"Please use `del meshseries.field_data[key]` or "
"`del meshseries[:-1].field_data[key]` if you want to keep the "
"data in the last timestep"
)
warnings.warn(depr_msg, DeprecationWarning, stacklevel=1)
for i, mesh in enumerate(self):
if not skip_last or i < len(self) - 1:
if data_type == "field":
mesh.field_data.remove(name)
elif data_type == "cell":
mesh.cell_data.remove(name)
elif data_type == "point":
mesh.point_data.remove(name)
else:
msg = "array type unknown"
raise RuntimeError(msg)