# SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from collections.abc import Callable, Iterator, Sequence
from copy import copy as shallowcopy
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast, overload
from warnings import warn
import numpy as np
import pyvista as pv
from matplotlib import pyplot as plt
from typeguard import typechecked
from ogstools import plot
from ogstools._internal import copy_function_signature, deprecated
from ogstools.core.storage import StorageBase
from ogstools.mesh import read
from ogstools.mesh.utils import reshape_obs_points
from ogstools.plot.lineplots import line
from ogstools.variables import Variable, _normalize_vars, u_reg
from .data_dict import DataDict
[docs]
class MeshSeries(Sequence[pv.UnstructuredGrid], StorageBase):
"""
A wrapper around pyvista and meshio for reading of pvd and xdmf timeseries.
"""
__hash__ = None
[docs]
def __init__(
    self,
    filepath: str | Path | None = None,
    spatial_unit: str | Sequence[str] = "m",
    time_unit: str | Sequence[str] = "s",
    id: str | None = None,
) -> None:
    """
    Initialize a MeshSeries object

    :param filepath: Path to the PVD or XDMF file.
    :param spatial_unit: Unit/s of the mesh points. See note.
    :param time_unit: Unit/s of the timevalues. See note.
    :param id: Optional unique id handed to the storage backend.
    :returns: A MeshSeries object

    :note:
        If given as a single string, the data is read in SI units i.e. in
        seconds and meters and converted to the given units. If given as a
        tuple, the first str corresponds to the data_unit, the second to the
        output_unit. E.g.: ``ot.MeshSeries(filepath, "km", ("a", "d"))``
        would read in the spatial data in meters and convert to kilometers
        and read in the timevalues in years and convert to days.
    """
    # TODO: if filepath extension = Path(filepath).suffix to also support xdmf
    super().__init__("MeshSeries", file_ext="pvd", id=id)
    # Multiplicative factor applied to the raw timevalues (see scale()).
    self._time_factor = 1.0
    # Tolerance below which two timevalues count as identical (see extend()).
    self._epsilon = 1.0e-6
    # Cache of already read (and transformed) meshes, keyed by timevalue.
    self._mesh_cache: dict[float, pv.UnstructuredGrid] = {}
    # Optional transformation applied to each mesh on read.
    self._mesh_func_opt: (
        Callable[[pv.UnstructuredGrid], pv.UnstructuredGrid] | None
    ) = None
    # list of slices to be able to have nested slices with xdmf
    # (but only the first slice will be efficient)
    self._time_indices: list[slice | Any] = [slice(None)]
    self._timevalues: np.ndarray
    "original data timevalues - do not change."
    # Split the unit arguments: a single str means "data is SI, convert to
    # this unit"; a pair is interpreted as (data_unit, output_unit).
    base_spatial, spatial_unit = (
        ("m", spatial_unit)
        if isinstance(spatial_unit, str)
        else spatial_unit
    )
    base_time, time_unit = (
        ("s", time_unit) if isinstance(time_unit, str) else time_unit
    )
    self.spatial_unit = u_reg.Quantity(1, base_spatial)
    self.time_unit = u_reg.Quantity(1, base_time)
    # Only rescale if the output units actually differ from the data units.
    if (self.spatial_unit != u_reg.Quantity(1, spatial_unit) or
        self.time_unit != u_reg.Quantity(1, time_unit)):  # fmt: skip
        self.scale(spatial_unit, time_unit)
    if filepath is None:
        # In-memory series (e.g. created via from_data) - no reader needed.
        self.filepath = self._data_type = None
        return
    self.filepath = Path(filepath)
    self._data_type = self.filepath.suffix
    match self._data_type:
        case ".pvd":
            self._pvd_reader = pv.PVDReader(self.filepath)
            # Paths of the individual timestep files listed in the pvd.
            self.timestep_files = [
                str(self.filepath.parent / dataset.path)
                for dataset in self._pvd_reader.datasets
            ]
            self._timevalues = np.asarray(self._pvd_reader.time_values)
            self.skip_pvtu2vtu = False
        case ".xdmf" | ".xmf":
            from .xdmf_reader import XDMFReader

            self._data_type = ".xdmf"
            self._xdmf_reader = XDMFReader(self.filepath)
            # Collect timevalues from the XML "Time" elements.
            self._timevalues = np.asarray(
                [
                    float(element.attrib["Value"])
                    for collection_i in self._xdmf_reader.collection
                    for element in collection_i
                    if element.tag == "Time"
                ]
            )
        case ".vtu":
            # A single vtu is treated as a series with one timestep at t=0.
            self._vtu_reader = pv.XMLUnstructuredGridReader(self.filepath)
            self._timevalues = np.zeros(1)
        case suffix:
            msg = (
                "Can only read 'pvd', 'xdmf', 'xmf'(from Paraview) or "
                f"'vtu' files, but not '{suffix}'"
            )
            raise TypeError(msg)
    self.dim = self.mesh(0).GetMaxSpatialDimension()
[docs]
@classmethod
def from_data(
    cls,
    meshes: Sequence[pv.UnstructuredGrid],
    timevalues: np.ndarray,
    spatial_unit: str | Sequence[str] = "m",
    time_unit: str | Sequence[str] = "s",
) -> MeshSeries:
    """Create a MeshSeries from a list of meshes and timevalues.

    The meshes and timevalues are deep-copied into the cache of a new,
    file-less MeshSeries.
    """
    series = cls(spatial_unit=spatial_unit, time_unit=time_unit)
    series._timevalues = deepcopy(timevalues)  # pylint: disable=W0212
    copied_meshes = deepcopy(meshes)
    series._mesh_cache.update(
        dict(zip(series.timevalues, copied_meshes, strict=True))
    )
    series.dim = meshes[0].GetMaxSpatialDimension()
    return series
[docs]
@classmethod
def from_id(cls, meshseries_id: str) -> MeshSeries:
    """
    Load MeshSeries from the user storage path using its ID.

    StorageBase.Userpath must be set.

    :param meshseries_id: The unique ID of the MeshSeries to load.
    :returns: A MeshSeries instance restored from disk.
    """
    storage_dir = StorageBase.saving_path() / "MeshSeries" / meshseries_id
    pvd_file = storage_dir / "ms.pvd"
    xdmf_file = storage_dir / "ms.xdmf"
    # Prefer the pvd representation, fall back to xdmf.
    for candidate in (pvd_file, xdmf_file):
        if candidate.exists():
            meshseries = cls(filepath=candidate)
            break
    else:
        msg = f"No MeshSeries found at {pvd_file} or {xdmf_file}"
        raise FileNotFoundError(msg)
    meshseries._id = meshseries_id
    return meshseries
[docs]
def extend(self, mesh_series: MeshSeries) -> None:
    """
    Extends self with mesh_series.

    If the last element of the mesh series is within epsilon
    to the first element of mesh_series to extend, the duplicate element
    is removed. If the appended series starts before the end of self, its
    timevalues are shifted so the series continues seamlessly.

    :param mesh_series: The MeshSeries to append to this one.
    """
    ms1_list = list(self)
    ms2_list = list(mesh_series)
    ms1_timevalues = self.timevalues
    ms2_timevalues = mesh_series.timevalues
    # Series created purely from data have no timestep files; fall back
    # to empty filenames (getattr replaces the former hasattr branches).
    ms1_timestep_files = getattr(
        self, "timestep_files", [""] * len(ms1_list)
    )
    ms2_timestep_files = getattr(
        mesh_series, "timestep_files", [""] * len(ms2_list)
    )
    delta = ms2_timevalues[0] - ms1_timevalues[-1]
    # Shift the appended series if it would start before the end of self.
    # (The former redundant `offset = 0.0` pre-assignment was removed.)
    offset = 0.0 if delta >= 0 else ms1_timevalues[-1]
    if ((delta < 0) and (ms2_timevalues[0] == 0.0)) or (
        np.abs(delta) < self._epsilon
    ):
        # Drop the duplicate element at the seam.
        ms1_timevalues = ms1_timevalues[:-1]
        ms1_list = ms1_list[:-1]
        ms1_timestep_files = ms1_timestep_files[:-1]
    ms2_timevalues = ms2_timevalues + offset
    # Compute the combined timevalues once (previously done twice).
    new_timevalues = np.append(ms1_timevalues, ms2_timevalues, axis=0)
    self._timevalues = new_timevalues
    self._mesh_cache.update(
        dict(zip(new_timevalues, ms1_list + ms2_list, strict=True))
    )
    if hasattr(self, "timestep_files") or hasattr(
        mesh_series, "timestep_files"
    ):
        self.timestep_files = ms1_timestep_files + ms2_timestep_files
        assert len(self.timestep_files) == len(
            self._timevalues
        ), "Timestep files and timevalues do not match."
[docs]
def resample_temporal(self, timevalues: np.ndarray) -> MeshSeries:
    """Return a new MeshSeries interpolated to the given timevalues.

    Each requested timevalue is interpolated via :meth:`mesh_interp`.
    """
    meshes = []
    for timevalue in timevalues:
        meshes.append(self.mesh_interp(timevalue))
    # NOTE(review): units are passed through as Quantity objects here,
    # while _extract_probe passes str(...) - presumably equivalent for
    # the unit registry; verify.
    return MeshSeries.from_data(
        meshes,
        timevalues,
        (self.spatial_unit, self.spatial_unit),
        (self.time_unit, self.time_unit),
    )
def _extract_probe(
    self,
    pts_or_mesh: np.ndarray | pv.DataSet,
    data_name: str | Variable | list[str | Variable] | None = None,
    interp_method: Literal["nearest", "linear"] = "linear",
) -> MeshSeries:
    """Sample this MeshSeries at points or on a mesh.

    Shared backend of :meth:`probe` and :meth:`interpolate`.
    """
    if isinstance(pts_or_mesh, pv.DataSet):
        mesh = pts_or_mesh
        points = pts_or_mesh.points
    else:
        points = np.asarray(pts_or_mesh, dtype=float)
        mesh = pv.PolyData(np.asarray(points))
    # Default to every available point and cell array.
    if data_name is None:
        variables = list(set().union(self[0].point_data, self[0].cell_data))
    else:
        variables = data_name if isinstance(data_name, list) else [data_name]
    probe_ms = MeshSeries.from_data(
        [mesh.copy() for _ in self.timevalues],
        self.timevalues,
        (str(self.spatial_unit), str(self.spatial_unit)),
        (str(self.time_unit), str(self.time_unit)),
    )
    # Sample point- and cell-associated variables separately.
    for keys, data in [
        ([k for k in variables if k in self.point_data], probe_ms.point_data),
        ([k for k in variables if k in self.cell_data], probe_ms.cell_data),
    ]:
        if len(keys) == 0:
            continue
        sampled = self.probe_values(points, keys, interp_method)
        for var, vals in zip(keys, sampled, strict=True):
            name = var.output_name if isinstance(var, Variable) else var
            data[name] = vals
    return probe_ms
[docs]
def probe(
    self,
    points: np.ndarray,
    data_name: str | Variable | list[str | Variable] | None = None,
    interp_method: Literal["nearest", "linear"] = "linear",
) -> MeshSeries:
    """Create a new MeshSeries by probing points on an existing MeshSeries.

    :param points: The points at which to probe.
    :param data_name: Data to extract. If None, use all point data.
    :param interp_method: The interpolation method to use.

    :returns: A MeshSeries (Pointcloud) containing the probed data.
    """
    # Thin wrapper around the shared probing backend.
    return self._extract_probe(points, data_name, interp_method)
[docs]
def interpolate(
    self,
    mesh: pv.DataSet,
    data_name: str | Variable | list[str | Variable] | None = None,
) -> MeshSeries:
    """Create a new MeshSeries by spatial interpolation.

    :param mesh: The mesh on which to interpolate.
    :param data_name: Data to extract. If None, use all point data.

    :returns: A spatially interpolated MeshSeries.
    """
    # Thin wrapper around the shared probing backend (always linear).
    return self._extract_probe(mesh, data_name, "linear")
def __eq__(self, other: object) -> bool:
    """Equality via :meth:`MeshSeries.compare`, excluding field data."""
    if isinstance(other, MeshSeries):
        # TODO:: Field data is currently not copied, to not break copy/eq - contract -> False
        return MeshSeries.compare(self, other, field_data=False)
    return NotImplemented
def __deepcopy__(self, memo: dict) -> MeshSeries:
    """Create a deep copy of this MeshSeries (hook for ``copy.deepcopy``)."""
    # Deep copy is the default when using self.copy()
    # For shallow copy: self.copy(deep=False)
    self_copy = self.__class__(self.filepath)
    # Register early so cyclic references resolve to the new object.
    memo[id(self)] = self_copy
    for key, value in self.__dict__.items():
        if key == "_next_target" or key == "id":
            # Storage bookkeeping attributes are not copied over.
            pass
        elif key != "_pvd_reader" and key != "_xdmf_reader":
            if isinstance(value, pv.UnstructuredGrid):
                # For PyVista objects use their own copy method
                setattr(self_copy, key, value.copy(deep=True))
            else:
                # For everything that is neither reader nor PyVista object
                # use the deepcopy
                setattr(self_copy, key, deepcopy(value, memo))
        else:
            # Shallow copy of reader is needed, because timesteps are
            # stored in reader, deep copy doesn't work for _pvd_reader
            # and _xdmf_reader
            setattr(self_copy, key, shallowcopy(value))
    return self_copy
@overload
def __getitem__(self, index: int) -> pv.UnstructuredGrid: ...
@overload
def __getitem__(self, index: slice | Sequence) -> MeshSeries: ...
@overload
def __getitem__(self, index: str) -> np.ndarray: ...
[docs]
def __getitem__(self, index: Any) -> Any:
    """Index by timestep (int), variable name (str) or time-slice."""
    if isinstance(index, int | np.unsignedinteger | np.signedinteger):
        return self.mesh(index)
    # str must be checked before Sequence: a str is itself a Sequence.
    if isinstance(index, str):  # type: ignore[unreachable]
        return self.values(index)
    if isinstance(index, slice | Sequence):
        sliced = self.copy(deep=False)
        if sliced._time_indices == [slice(None)]:
            sliced._time_indices = [index]
        else:
            # NOTE(review): in-place += on a shallow copy may alias the
            # original's _time_indices list - TODO confirm copy semantics.
            sliced._time_indices += [index]
        return sliced
    msg = (
        "Index type not supported.\n "
        "It has to be one of following: "
        "int, np.int, str, slice or sequence."
    )
    raise ValueError(msg)
def __len__(self) -> int:
    """Number of timesteps in the (possibly sliced) series."""
    return len(self.timesteps)
def __iter__(self) -> Iterator[pv.UnstructuredGrid]:
    """Yield the mesh of every timestep in order."""
    for timestep in range(len(self.timevalues)):
        yield self.mesh(timestep)
def __str__(self) -> str:
    """Return a human-readable summary of this MeshSeries.

    Robust against an empty series (previously raised IndexError when
    there were no timevalues).
    """
    if self._data_type == ".vtu":
        reader: Any = self._vtu_reader
    elif self._data_type == ".pvd":
        reader = self._pvd_reader
    elif self._data_type == ".xdmf":
        reader = self._xdmf_reader
    else:
        reader = "None"
    # Evaluate the timevalues property only once.
    timevalues = self.timevalues
    if len(timevalues):
        time_str = (
            f"{timevalues[0]} to {timevalues[-1]} "
            f"in {len(timevalues)} steps"
        )
    else:
        time_str = "no timevalues"
    return (
        f"MeshSeries:\n"
        f"filepath: {self._format_path(self.filepath)}\n"
        f"data_type: {self._data_type}\n"
        f"timevalues: {time_str}\n"
        f"reader: {reader}\n"
        f"rawdata_file: {self._format_path(self.rawdata_file())}\n"
    )
# deliberately typing as Sequence and not as zip because typing as zip
# leads to a weird cross-referencing error from sphinx side with no easy
# apparent fix
[docs]
def items(self) -> Sequence[tuple[float, pv.UnstructuredGrid]]:
    """Returns zipped tuples of timevalues and meshes.

    Note: the returned ``zip`` object is a one-shot iterator, although it
    is annotated as ``Sequence`` (see the comment above the definition).
    """
    return zip(self.timevalues, self, strict=True)  # type: ignore[return-value]
[docs]
def aggregate_temporal(
    self, variable: Variable | str, func: Callable
) -> pv.UnstructuredGrid:
    """Aggregate data over all timesteps using a specified function.

    :param variable: The mesh variable to be aggregated.
    :param func: The aggregation function to apply. E.g. np.min,
                 np.max, np.mean, np.median, np.sum, np.std, np.var

    :returns: A mesh with aggregated data according to the given function.
    """
    # TODO: add function to create an empty mesh from a given on
    # custom field_data may need to be preserved
    mesh = self.mesh(0).copy(deep=True)
    mesh.clear_point_data()
    mesh.clear_cell_data()
    base_name = (
        variable.output_name if isinstance(variable, Variable) else variable
    )
    output_name = f"{base_name}_{func.__name__}"
    # Aggregate along the time axis (axis 0).
    mesh[output_name] = func(self.values(variable), axis=0)
    return mesh
[docs]
def clear_cache(self) -> None:
    """Empty the in-memory cache of read/transformed meshes.

    Subsequent mesh accesses will be re-read (see :meth:`mesh`).
    """
    self._mesh_cache.clear()
[docs]
def closest_timestep(self, timevalue: float) -> int:
    """Return the index of the timestep closest to the given timevalue."""
    distances = np.abs(self.timevalues - timevalue)
    return int(distances.argmin())
[docs]
def closest_timevalue(self, timevalue: float) -> float:
    """Return the existing timevalue closest to the given timevalue."""
    closest_index = self.closest_timestep(timevalue)
    return self.timevalues[closest_index]
[docs]
def ip_tesselated(self) -> MeshSeries:
    """Create a new MeshSeries from integration point tessellation.

    :returns: A MeshSeries with the integration point data mapped onto
              the tessellated cells for every timestep.
    """
    from ogstools.mesh.ip_mesh import to_ip_mesh, to_ip_point_cloud

    ip_mesh = to_ip_mesh(self.mesh(0))
    ip_pt_cloud = to_ip_point_cloud(self.mesh(0))
    ordering = ip_mesh.find_containing_cell(ip_pt_cloud.points)
    # Hoisted out of the loop: the sort order of the integration points
    # is identical for every timestep and key (previously recomputed
    # per key per timestep).
    sort_order = np.argsort(ordering)
    ip_meshes = []
    for i in np.arange(len(self.timevalues), dtype=int):
        ip_data = {
            key: self.mesh(i).field_data[key][sort_order]
            for key in ip_mesh.cell_data
        }
        ip_mesh.cell_data.update(ip_data)
        ip_meshes += [ip_mesh.copy()]
    return MeshSeries.from_data(ip_meshes, self.timevalues)
[docs]
def mesh(
    self, timestep: int, lazy_eval: bool = True
) -> pv.UnstructuredGrid:
    """Returns the mesh at the given timestep.

    :param timestep: Index into the (possibly sliced) timevalues.
    :param lazy_eval: If True, cache the mesh for subsequent accesses.
    :returns: The (possibly transformed) mesh at the given timestep.
    """
    timevalue = self.timevalues[timestep]
    if not np.any(self.timevalues == timevalue):
        msg = f"Value {timevalue} not found in the array."
        raise ValueError(msg)
    # Map the (sliced/scaled) timevalue back to the raw data index.
    data_timestep = np.argmax(
        self._timevalues * self._time_factor == timevalue
    )
    if timevalue in self._mesh_cache:
        mesh = self._mesh_cache[timevalue]
    else:
        match self._data_type:
            case ".pvd":
                suffix = self.timestep_files[data_timestep].split(".")[-1]
                if suffix == "pvtu" and not self.skip_pvtu2vtu:
                    # Merge the partitioned pvtu into a single temporary
                    # vtu via the OGS cli before reading.
                    from ogstools._find_ogs import cli
                    from ogstools.core.storage import _date_temp_path

                    tmp_path = _date_temp_path("pvtu2vtu", "vtu")
                    tmp_path.parent.mkdir(parents=True)
                    cli().pvtu2vtu(
                        i=self.timestep_files[data_timestep], o=tmp_path
                    )
                    pv_mesh = read(tmp_path)
                else:
                    pv_mesh = self._read_pvd(data_timestep)
            case ".xdmf":
                pv_mesh = self._read_xdmf(data_timestep)
            case ".vtu":
                pv_mesh = self._vtu_reader.read()
            case _:
                msg = f"Unexpected datatype {self._data_type}."
                raise TypeError(msg)
        # Apply the stored transformation (identity if none is set).
        mesh = self.mesh_func(pv_mesh)
        if lazy_eval:
            self._mesh_cache[timevalue] = mesh
    filepath = (
        Path(self.timestep_files[data_timestep])
        if self._data_type == ".pvd"
        else self.filepath
    )
    # Attach metadata attributes to the mesh; use pv.set_new_attribute
    # where available to avoid pyvista's attribute guards.
    for attr, val in [
        ("filepath", filepath),
        ("spatial_unit", self.spatial_unit),
        ("time_unit", self.time_unit),
    ]:
        if not hasattr(pv, "set_new_attribute") or hasattr(mesh, attr):
            setattr(mesh, attr, val)
        else:
            pv.set_new_attribute(mesh, attr, val)
    return mesh
[docs]
def rawdata_file(self) -> Path | None:
    """
    Checks, if working with the raw data is possible. For example,
    OGS Simulation results with XDMF support efficient raw data access via
    `h5py <https://docs.h5py.org/en/stable/quick.html#quick>`_

    :returns: The location of the file containing the raw data. If it does not
              support efficient read (e.g., no efficient slicing), it returns None.
    """
    if self._data_type != ".xdmf":
        return None
    if self._xdmf_reader.has_fast_access():
        return self._xdmf_reader.rawdata_path()  # single h5 file
    return None
[docs]
def mesh_interp(
    self, timevalue: float, lazy_eval: bool = True
) -> pv.UnstructuredGrid:
    """Return the temporal interpolated mesh for a given timevalue.

    Point data is linearly interpolated between the two timesteps
    bracketing ``timevalue``.

    :param timevalue: Time at which to interpolate.
    :param lazy_eval: Passed through to :meth:`mesh` (mesh caching).
    :returns: A mesh with temporally interpolated point data.
    """
    t_vals = self.timevalues
    # Index of the last timestep with t <= timevalue ...
    ts1 = int(t_vals.searchsorted(timevalue, "right") - 1)
    # ... and the following timestep, clamped to the last index.
    ts2 = min(ts1 + 1, len(t_vals) - 1)
    if np.isclose(timevalue, t_vals[ts1]):
        # Exact hit - no interpolation needed.
        return self.mesh(ts1, lazy_eval)
    mesh1 = self.mesh(ts1, lazy_eval)
    mesh2 = self.mesh(ts2, lazy_eval)
    mesh = mesh1.copy(deep=True)
    # TODO interpolate cell_data and field_data as well
    for key in mesh1.point_data:
        if key not in mesh2.point_data:
            # Data missing at the second timestep: warn and keep mesh1's.
            msg = f"{key} not in timestep {ts2}."
            warn(msg, RuntimeWarning, stacklevel=2)
            continue
        if np.all(mesh1.point_data[key] == mesh2.point_data[key]):
            # Constant data - nothing to interpolate.
            continue
        dt = t_vals[ts2] - t_vals[ts1]
        slope = (mesh2.point_data[key] - mesh1.point_data[key]) / dt
        mesh.point_data[key] = mesh1.point_data[key] + slope * (
            timevalue - t_vals[ts1]
        )
    return mesh
@property
def timevalues(self) -> np.ndarray:
    """The timevalues, with all time-slices and scaling applied."""
    selected = self._timevalues
    # Apply every recorded time-slice in order (supports nested slicing).
    for time_index in self._time_indices:
        selected = selected[time_index]
    return selected * self._time_factor
@property
def timesteps(self) -> np.ndarray:
    """
    Return the OGS simulation timesteps of the timeseries data.

    Not to be confused with timevalues which returns a list of
    times usually given in time units.

    :returns: An integer array enumerating the timesteps.
    """
    # TODO: read time steps from fn string if available
    # The annotation previously claimed `list`, but an integer numpy
    # array has always been returned - the annotation now matches.
    return np.arange(len(self.timevalues), dtype=int)
def _xdmf_values(self, variable_name: str) -> np.ndarray:
    """Read one variable's data directly from the XDMF raw data.

    Honors the recorded time-slices and - if present - the index masks
    produced by a pyvista filter.
    """
    dataitems = self._xdmf_reader.data_items[variable_name]
    # pv filters produces these arrays, which we can use for slicing
    # to also reflect the previous use of self.transform here
    mask_map = {
        "vtkOriginalPointIds": self.mesh(0).point_data,
        "vtkOriginalCellIds": self.mesh(0).cell_data,
    }
    for mask, data in mask_map.items():
        if variable_name in data and mask in data:
            # Fast path: slice the raw data with the filter's index mask.
            result = dataitems[self._time_indices[0], self.mesh(0)[mask]]
            break
    else:
        result = dataitems[self._time_indices[0]]
    # Apply any nested (inefficient) time slices in-memory.
    for index in self._time_indices[1:]:
        result = result[index]
    if self._mesh_func_opt is not None and not any(
        mask in data for mask, data in mask_map.items()
    ):
        # if transform function doesn't produce the mask arrays we have to
        # map data xdmf data to the entire list of meshes and apply the
        # function on each mesh individually.
        ms_copy = self.copy(deep=True)
        ms_copy._mesh_func_opt = None  # pylint: disable=protected-access
        ms_copy.clear_cache()
        raw_meshes = list(ms_copy)
        for mesh, data in zip(raw_meshes, result, strict=True):
            mesh[variable_name] = data
        meshes = list(map(self.mesh_func, raw_meshes))
        result = np.asarray([mesh[variable_name] for mesh in meshes])
    return result
@overload
def values(self, variable: str | Variable) -> np.ndarray: ...
@overload
def values(self, variable: list[str | Variable]) -> list[np.ndarray]: ...
[docs]
def values(
    self, variable: str | Variable | list[str | Variable]
) -> np.ndarray | list[np.ndarray]:
    """
    Get the data in the MeshSeries for all timesteps.

    Adheres to time slicing via `__getitem__` and an applied pyvista filter
    via `transform` if the applied filter produced 'vtkOriginalPointIds' or
    'vtkOriginalCellIds' (e.g. `clip(..., crinkle=True)`,
    `extract_cells(...)`, `threshold(...)`.)

    :param variable: Data to read/process from the MeshSeries. Can also
                     be a list of str or Variable.

    :returns: A numpy array of shape (n_timesteps, n_points/c_cells).
              If given an argument of type Variable is given, its
              transform function is applied on the data. If a list of
              str or Variable is given, a list of the individual values
              is returned.
    """
    if not isinstance(variable, list):
        return self._values(variable)
    return [self._values(single_var) for single_var in variable]
def _values(self, variable: str | Variable) -> np.ndarray:
    """Read the data of a single variable for all timesteps.

    Backend of :meth:`values`; applies a Variable's transform function
    where appropriate.
    """
    if isinstance(variable, Variable):
        if variable.mesh_dependent:
            # The transform needs the whole mesh, not just a raw array.
            return np.asarray([variable.transform(mesh) for mesh in self])
        if (
            variable.data_name != variable.output_name
            and variable.output_name
            in set().union(self.point_data, self.cell_data)
        ):
            # Data already stored under the output name - presumably
            # already transformed, so skip the transform.
            variable_name = variable.output_name
            do_transform = False
        else:
            variable_name = variable.data_name
            do_transform = True
    else:
        variable_name = variable
    all_cached = self._is_all_cached
    if (
        self._data_type == ".xdmf"
        and variable_name in self._xdmf_reader.data_items
        and not all_cached
    ):
        # Fast path: read directly from the XDMF raw data.
        result = self._xdmf_values(variable_name)
    else:
        from tqdm import tqdm

        # Show a progress bar only if meshes still need to be read.
        result = np.asarray(
            [
                mesh[variable_name]
                for mesh in (self if all_cached else tqdm(self))
            ]
        )
    if isinstance(variable, Variable) and do_transform:
        return variable.transform(result)
    return result
@property
def _is_all_cached(self) -> bool:
    """True if every timevalue of this series has a cached mesh."""
    cached_timevalues = np.fromiter(self._mesh_cache.keys(), float)
    return np.isin(self.timevalues, cached_timevalues).all()
@property
def point_data(self) -> DataDict:
    """Read or set point_data across the entire meshseries."""
    return DataDict(self, lambda m: m.point_data, self[0].n_points)
@property
def cell_data(self) -> DataDict:
    """Read or set cell_data across the entire meshseries."""
    return DataDict(self, lambda m: m.cell_data, self[0].n_cells)
@property
def field_data(self) -> DataDict:
    """Read or set field_data across the entire meshseries."""
    return DataDict(self, lambda m: m.field_data, None)
def _read_pvd(self, timestep: int) -> pv.UnstructuredGrid:
    """Read one timestep through the PVD reader."""
    reader = self._pvd_reader
    reader.set_active_time_point(timestep)
    # read() yields a MultiBlock; the first block is the mesh.
    return reader.read()[0]
def _read_xdmf(self, timestep: int) -> pv.UnstructuredGrid:
    """Read one timestep from the XDMF file via meshio."""
    import meshio

    points, cells = self._xdmf_reader.read_points_cells()
    _, point_data, cell_data, field_data = self._xdmf_reader.read_data(
        timestep
    )
    pv_mesh = pv.from_meshio(
        meshio.Mesh(points, cells, point_data, cell_data, field_data)
    )
    # pv.from_meshio does not copy field_data (fix in pyvista?)
    pv_mesh.field_data.update(field_data)
    return pv_mesh
def _time_of_extremum(
    self,
    variable: Variable | str,
    np_func: Callable,
    prefix: Literal["min", "max"],
) -> pv.UnstructuredGrid:
    """Returns a Mesh with the time of a given variable extremum as data.

    The data is named as `f'{prefix}_{variable.output_name}_time'`.
    """
    mesh = self.mesh(0).copy(deep=True)
    variable = Variable.find(variable, mesh)
    mesh.clear_point_data()
    mesh.clear_cell_data()
    # Index of the extremum along the time axis, per point/cell.
    extremum_index = np_func(self.values(variable), axis=0)
    mesh[f"{prefix}_{variable.output_name}_time"] = self.timevalues[
        extremum_index
    ]
    return mesh
[docs]
def time_of_min(self, variable: Variable | str) -> pv.UnstructuredGrid:
    """Returns a Mesh with the time of the variable minimum as data."""
    # Delegates to the shared extremum backend with np.argmin.
    return self._time_of_extremum(variable, np.argmin, "min")
[docs]
def time_of_max(self, variable: Variable | str) -> pv.UnstructuredGrid:
    """Returns a Mesh with the time of the variable maximum as data."""
    # Delegates to the shared extremum backend with np.argmax.
    return self._time_of_extremum(variable, np.argmax, "max")
[docs]
def aggregate_spatial(
    self, variable: Variable | str, func: Callable
) -> np.ndarray:
    """Aggregate data over domain per timestep using a specified function.

    :param variable: The mesh variable to be aggregated.
    :param func: The aggregation function to apply. E.g. np.min,
                 np.max, np.mean, np.median, np.sum, np.std, np.var

    :returns: A numpy array with aggregated data.
    """
    data = self.values(variable)
    # Aggregate along the spatial axis (axis 1), keeping the time axis.
    return func(data, axis=1)
def _flatten_vectors(self, data: list[np.ndarray]) -> list[np.ndarray]:
    """Split vector-valued arrays (ndim > 2) into per-component arrays."""
    flattened: list[np.ndarray] = []
    for vals in data:
        if vals.ndim > 2:
            flattened.extend(vals[..., i] for i in range(vals.shape[-1]))
        else:
            # Scalar data is passed through unchanged.
            flattened.append(vals)
    return flattened
def _restore_vectors(
    self, data: list[np.ndarray], components: list[int]
) -> list[np.ndarray]:
    """Inverse of ``_flatten_vectors``: re-stack per-component arrays."""
    restored = []
    offset = 0
    for n_comp in components:
        chunk = data[offset : offset + n_comp]
        restored.append(np.squeeze(np.stack(chunk, axis=-1)))
        offset += n_comp
    return restored
[docs]
def probe_values(
    self,
    points: np.ndarray | list,
    data_name: str | Variable | list[str | Variable],
    interp_method: Literal["nearest", "linear"] = "linear",
) -> np.ndarray | list[np.ndarray]:
    """
    Return the sampled data of the MeshSeries at observation points.

    Similar to :func:`~ogstools.MeshSeries.probe` but returns the
    data directly instead of creating a new MeshSeries.

    :param points: The observation points to sample at.
    :param data_name: Data to sample. If provided as a Variable, the
                      output will transformed accordingly. Can also be
                      a list of str or Variable.
    :param interp_method: Interpolation method, defaults to `linear`

    :returns: `numpy` array/s of interpolated data at observation points
              with the following shape:

              - multiple points: (n_timesteps, n_points, [n_components])
              - single points: (n_timesteps, [n_components])

              If `data_name` is a list, a corresponding list of arrays is
              returned.
    """
    if (
        isinstance(data_name, list)
        and any(key in self.point_data for key in data_name)
        and any(key in self.cell_data for key in data_name)
    ):
        msg = "Cannot probe point and cell data together."
        raise TypeError(msg)
    pts = reshape_obs_points(points, self.mesh(0))
    values = self.values(data_name)
    if isinstance(data_name, list):
        # Remember each variable's component count so the flattened
        # vector components can be re-assembled afterwards.
        components = [
            1 if arr.ndim == 2 else np.shape(arr)[-1] for arr in values
        ]
        values = np.moveaxis(self._flatten_vectors(values), 0, -1)
    # Interpolators expect the spatial axis first: (n_points, n_timesteps).
    values = np.swapaxes(values, 0, 1)
    geom = self.mesh(0).points
    if values.shape[0] != geom.shape[0]:
        # assume cell_data
        geom = self.mesh(0).cell_centers().points
    # remove flat dimensions for interpolation
    flat_axis = np.argwhere(np.all(np.isclose(geom, geom[0]), axis=0))
    geom = np.delete(geom, flat_axis, 1)
    pts = np.delete(pts, flat_axis, 1)
    more_than_1d = self.dim > 1 or len(flat_axis) < 2
    match more_than_1d, interp_method:
        case True, "nearest":
            from scipy.interpolate import NearestNDInterpolator

            result = np.swapaxes(
                NearestNDInterpolator(geom, values)(pts), 0, 1
            )
        case True, "linear":
            from scipy.interpolate import LinearNDInterpolator

            # np.nan fills values outside the convex hull of the data.
            result = np.swapaxes(
                LinearNDInterpolator(geom, values, np.nan)(pts), 0, 1
            )
        case False, kind:
            from scipy.interpolate import interp1d

            # 1D case: interpolate along the remaining coordinate.
            result = interp1d(geom[:, 0], values.T, kind=kind)(
                np.squeeze(pts, 1)
            )
        case _, _:
            msg = (
                "No interpolation method implemented for mesh of "
                f"{self.dim=} and {interp_method=}"
            )
            raise TypeError(msg)
    if np.shape(points)[0] != 1 and np.shape(result)[1] == 1:
        # Presumably collapses a spurious singleton points axis for a
        # single observation point - TODO confirm intent.
        result = np.squeeze(result, axis=1)
    if isinstance(data_name, list):
        if more_than_1d:
            result = np.moveaxis(result, -1, 0)
        # Re-assemble the flattened vector components per variable.
        result = self._restore_vectors(result, components)
    return result
@copy_function_signature(line)
def plot_line(self, *args: Any, **kwargs: Any) -> Any:
    """Plot this MeshSeries via :func:`ogstools.plot.lineplots.line`.

    Signature and docs are copied from that function by the decorator.
    """
    return line(self, *args, **kwargs)
# TODO: make us of ot.plot.heatmaps
[docs]
def plot_time_slice(
    self,
    x: Literal["x", "y", "z", "time"],
    y: Literal["x", "y", "z", "time"],
    variable: str | Variable,
    time_logscale: bool = False,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
    cbar: bool = True,
    **kwargs: Any,
) -> plt.Figure | None:
    """
    Create a heatmap for a variable over time and space.

    :param x: What to display on the x-axis (x, y, z or time)
    :param y: What to display on the y-axis (x, y, z or time)
    :param variable: The variable to be visualized.
    :param time_logscale: Should log-scaling be applied to the time-axis?
    :param fig: matplotlib figure to use for plotting.
    :param ax: matplotlib axis to use for plotting.
    :param cbar: If True, adds a colorbar.

    Keyword Arguments:
        - cb_labelsize: colorbar labelsize
        - cb_loc: colorbar location ('left' or 'right')
        - cb_pad: colorbar padding
        - cmap: colormap
        - vmin: minimum value for colorbar
        - vmax: maximum value for colorbar
        - num_levels: number of levels for colorbar
        - figsize: figure size
        - dpi: resolution
        - log_scaled: logarithmic scaling

    :returns: The figure if it was created here, otherwise None.
    """
    # Either create both fig and ax, or require both to be given.
    if ax is None and fig is None:
        fig, ax = plt.subplots(
            figsize=kwargs.get("figsize", [18, 14]),
            dpi=kwargs.get("dpi", 100),
        )
        optional_return_figure = fig
    elif ax is None or fig is None:
        msg = "Please provide fig and ax together or not at all."
        raise ValueError(msg)
    else:
        # Caller-owned figure: do not return it.
        optional_return_figure = None
    if "time" not in [x, y]:
        msg = "One of x_var and y_var has to be 'time'."
        raise KeyError(msg)
    if x not in "xyz" and y not in "xyz":
        msg = "One of x_var and y_var has to be a spatial coordinate."
        raise KeyError(msg)
    var_z = Variable.find(variable, self.mesh(0))
    var_x, var_y = _normalize_vars(x, y, self.mesh(0), ["time", "time"])
    time_var = var_x if var_x.data_name == "time" else var_y
    # Label the time axis with the series' time unit.
    unit = self.time_unit
    time_var.data_unit = time_var.output_unit = str(
        unit if unit.magnitude != 1 else unit.units
    )
    if time_logscale:

        def log10time(vals: np.ndarray) -> np.ndarray:
            # log10, skipping zeros (where=) and writing 0 there instead.
            log10vals = np.log10(
                vals, where=vals != 0, out=np.zeros_like(vals)
            )
            # Extrapolate a plausible value for t=0 from the first two
            # nonzero log-times to avoid a spurious 0 entry.
            if log10vals[0] == 0:
                log10vals[0] = log10vals[1] - (log10vals[2] - log10vals[1])
            return log10vals

        # Wrap the time transform and its label for log-scaling.
        get_time = time_var.func
        time_var.func = lambda ms: log10time(get_time(ms))
        time_label = time_var.get_label()
        time_var.get_label = (  # type: ignore[assignment]
            lambda *_: f"log$_{{10}}$( {time_label} )"
        )
    x_vals = var_x.transform(self)
    y_vals = var_y.transform(self)
    values = self.values(var_z)
    # Orient the value grid as (len(y), len(x)) for pcolormesh.
    if values.shape == (len(x_vals), len(y_vals)):
        values = values.T
    # sort w.r.t spatial index if not already sorted
    if x in "xyz" and not np.all(np.diff(x_vals) >= 0):
        sorted_indices = np.argsort(x_vals)
        x_vals = x_vals[sorted_indices]
        values = values[:, sorted_indices]
    elif y in "xyz" and not np.all(np.diff(y_vals) >= 0):
        sorted_indices = np.argsort(y_vals)
        y_vals = y_vals[sorted_indices]
        values = values[sorted_indices]
    if "levels" in kwargs:
        levels = np.asarray(kwargs.pop("levels"))
    else:
        vmin, vmax = (plot.setup.vmin, plot.setup.vmax)
        if (
            kwargs.get("log_scaled", plot.setup.log_scaled)
            and not var_z.is_mask()
        ):
            # log10 of the data, filling non-positive entries with vmin.
            values = np.log10(
                values,
                where=values > 0.0,
                out=np.ones_like(values) * kwargs.get("vmin", np.nan),
            )
        levels = plot.levels.compute_levels(
            kwargs.get("vmin", np.nanmin(values) if vmin is None else vmin),
            kwargs.get("vmax", np.nanmax(values) if vmax is None else vmax),
            kwargs.get("num_levels", plot.setup.num_levels),
        )
    cmap, norm = plot.utils.get_cmap_norm(levels, var_z, **kwargs)
    ax.pcolormesh(x_vals, y_vals, values, cmap=cmap, norm=norm)
    fontsize = kwargs.get("fontsize", plot.setup.fontsize)
    x_label = var_x.get_label()
    y_label = var_y.get_label()
    plot.utils.label_ax(fig, ax, x_label, y_label, fontsize)
    ax.tick_params(axis="both", labelsize=fontsize, length=fontsize * 0.5)
    ax.margins(0, 0)
    if cbar:
        plot.contourplots.add_colorbars(fig, ax, var_z, levels, **kwargs)
    plot.utils.update_font_sizes(fig.axes, fontsize)
    return optional_return_figure
@property
def mesh_func(self) -> Callable[[pv.UnstructuredGrid], pv.UnstructuredGrid]:
    """Returns stored transformation function or identity if not given."""
    if self._mesh_func_opt is None:
        return lambda mesh: mesh
    # Deliberately late-binding: the returned lambda reads
    # self._mesh_func_opt at call time, so later scale() calls that
    # re-chain the function are picked up. The deep copy decouples the
    # result from the raw input mesh.
    return lambda mesh: pv.UnstructuredGrid(
        self._mesh_func_opt(mesh),  # type: ignore[misc]
        deep=True,
    )
[docs]
def scale(
    self,
    spatial: int | float | str = 1.0,
    time: int | float | str = 1.0,
) -> MeshSeries:
    """Scale the spatial coordinates and timevalues.

    Useful to convert to other units, e.g. "m" to "km" or "s" to "a".
    Converts from SI units (i.e. 'm' and 's') to the given arguments.

    Does not create a copy, but modifies the calling object.
    If you want to have a scaled version without changing the original do
    `ms_scaled = ms.copy().scale(...)`

    :param spatial: Float factor or str for target unit.
    :param time: Float factor or str for target unit.

    :returns: The scaled MeshSeries.
    """
    Qty = u_reg.Quantity
    if isinstance(spatial, str):
        # Conversion factor from the current to the target spatial unit.
        spatial_factor = Qty(self.spatial_unit, spatial).magnitude
        spatial_unit = Qty(1, spatial)
    else:
        spatial_factor = spatial
        spatial_unit = self.spatial_unit / spatial
    if isinstance(time, str):
        # Conversion factor from the current to the target time unit.
        time_factor = Qty(self.time_unit, time).magnitude
        time_unit = Qty(1, time)
    else:
        time_factor = time
        time_unit = self.time_unit / time
    if time_factor == 1.0:
        scaled_cache = self._mesh_cache
    else:
        # Re-key the mesh cache with the scaled timevalues.
        scaled_cache = {
            timevalue * time_factor: self._mesh_cache.pop(timevalue)
            for timevalue in list(self._mesh_cache.keys())
        }
    if spatial_factor != 1.0:
        for mesh in scaled_cache.values():
            # using transform would shorten this, but doing it explicitly
            # allows us to use inplace=True which is a bit more efficient
            mesh.scale(spatial_factor, inplace=True)
        # Meshes read in the future get scaled too: chain the scaling
        # onto the existing transformation function (if any).
        if (func := self._mesh_func_opt) is None:
            self._mesh_func_opt = lambda mesh: mesh.scale(spatial_factor)
        else:
            self._mesh_func_opt = lambda mesh: func(mesh).scale(  # type: ignore[misc]
                spatial_factor
            )
    self._mesh_cache = scaled_cache
    self.spatial_unit = spatial_unit
    self.time_unit = time_unit
    # Time factors accumulate over repeated scale() calls.
    self._time_factor = self._time_factor * time_factor
    return self
[docs]
@classmethod
@typechecked
def difference(
    cls,
    ms_a: MeshSeries,
    ms_b: MeshSeries,
    variable: Variable | str | None = None,
) -> MeshSeries:
    """
    Compute the timestep-wise difference ``ms_a - ms_b`` of two
    MeshSeries instances and return the result as a new MeshSeries.

    If the timevalues of both series differ, ``ms_b`` is first resampled
    onto the timevalues of ``ms_a`` and a RuntimeWarning is emitted.

    :param ms_a: The mesh series from which data is to be subtracted.
    :param ms_b: The mesh series whose data is to be subtracted.
    :param variable: The variable of interest. If not given, all
        point and cell_data will be processed raw.

    :returns: MeshSeries containing the difference of `variable` or
        of all datasets between both MeshSeries.
    """
    from ogstools.mesh.differences import difference_pairwise

    ms_b_aligned = ms_b
    if not np.array_equal(ms_a.timevalues, ms_b.timevalues):
        # Align ms_b temporally before warning, matching original order.
        ms_b_aligned = ms_b.resample_temporal(ms_a.timevalues)
        msg = (
            "Two instances of MeshSeries have different time values. "
            "Direct difference cannot be computed. "
            "ms_b will be interpolated to match timesteps of ms_a."
        )
        warn(msg, RuntimeWarning, stacklevel=2)
    return MeshSeries.from_data(
        difference_pairwise(ms_a, ms_b_aligned, variable),
        ms_a.timevalues,
        (ms_a.spatial_unit, ms_a.spatial_unit),
        (ms_a.time_unit, ms_a.time_unit),
    )
[docs]
@staticmethod
def compare(
    ms_a: MeshSeries,
    ms_b: MeshSeries,
    variable: Variable | str | None = None,
    point_data: bool = True,
    cell_data: bool = True,
    field_data: bool = True,
    fast_compare: bool = False,
    atol: float = 0.0,
    *,
    strict: bool = False,
) -> bool:
    """
    Method to compare two ``ot.MeshSeries`` objects.

    Returns ``True`` if they match within the tolerances,
    otherwise ``False``.

    :param ms_a: The reference base MeshSeries for comparison.
    :param ms_b: The MeshSeries to compare against the reference.
    :param variable: The variable of interest.
        If not given, all point, cell and field data will be processed.
    :param point_data: Compare all point data if `variable` is None.
    :param cell_data: Compare all cell data if `variable` is None.
    :param field_data: Compare all field data if `variable` is None.
    :param fast_compare: If ``True``, mesh topology is only verified for
        the first timestep, otherwise, topology is checked at every
        timestep.
    :param atol: Absolute tolerance.
    :param strict: Raises an ``AssertionError``, if mismatch.
    """
    from ogstools.mesh.differences import compare

    # Testing timevalues
    if ms_a.timevalues.shape != ms_b.timevalues.shape:
        # Fixed: a shape mismatch previously returned False silently even
        # with strict=True; now it raises, consistent with the docstring
        # and the value comparison below.
        if strict:
            err_msg = "timevalues differs between MeshSeries."
            raise AssertionError(err_msg)
        return False
    if not np.allclose(
        ms_a.timevalues,
        ms_b.timevalues,
        atol=atol,
        rtol=0.0,
        equal_nan=True,
    ):
        if strict:
            err_msg = "timevalues differs between MeshSeries."
            raise AssertionError(err_msg)
        return False
    # Testing mesh-wise
    for i, (mesh_a, mesh_b) in enumerate(zip(ms_a, ms_b, strict=True)):
        # Fixed: was `fast_compare or i == 0`, which inverted the
        # documented behavior (fast_compare=True checked topology at
        # every timestep; False checked only the first). fast_compare
        # must skip topology checks after the first timestep.
        check_topo = not fast_compare or i == 0
        if not compare(
            mesh_a,
            mesh_b,
            variable=variable,
            point_data=point_data,
            cell_data=cell_data,
            field_data=field_data,
            check_topology=check_topo,
            atol=atol,
            strict=strict,
        ):
            return False
    return True
def _rename_vtufiles(self, new_pvd_fn: Path, fns: list[Path]) -> list:
    """Rename per-timestep file paths so their names carry the stem of
    ``new_pvd_fn`` instead of the stem of the current PVD filepath.

    Only the final path component (the filename) is rewritten.
    """
    assert self.filepath is not None
    old_stem = self.filepath.name.split(".")[0]
    new_stem = new_pvd_fn.name.split(".")[0]
    return [
        filename.with_name(filename.name.replace(old_stem, new_stem))
        for filename in fns
    ]
def _save_vtu(
    self, new_pvd_fn: Path, fns: list[Path], ascii: bool = False
) -> None:
    """Write one VTU file per timestep into the directory of ``new_pvd_fn``.

    :param new_pvd_fn: Target PVD path; only its parent directory is used.
    :param fns: Source file path per timestep (``.vtu`` or ``.xdmf``).
    :param ascii: If True write ascii VTU, otherwise binary.

    :raises RuntimeError: For unsupported source file types.
    """
    from ogstools.mesh import save

    out_dir = new_pvd_fn.parent
    for i, timestep in enumerate(self.timesteps):
        src_name = fns[i].name
        if ".vtu" in src_name:
            save(self.mesh(i), Path(out_dir, src_name), binary=not ascii)
        elif ".xdmf" in src_name:
            # XDMF sources are rewritten as per-timestep VTU files.
            vtu_name = src_name.replace(
                ".xdmf", f"_ts_{timestep}_t_{self.timevalues[i]}.vtu"
            )
            save(self.mesh(i), Path(out_dir, vtu_name))
        else:
            s = "File type not supported."
            raise RuntimeError(s)
def _save_pvd(self, new_pvd_fn: Path, fns: list[Path]) -> None:
    """Write a PVD collection file referencing one file per timestep.

    :param new_pvd_fn: Path of the PVD file to write.
    :param fns: Source file path per timestep (``.vtu`` or ``.xdmf``;
        xdmf names are rewritten to the per-timestep VTU naming scheme).

    :raises RuntimeError: For unsupported source file types.
    """
    from lxml import etree as ET

    root = ET.Element(
        "VTKFile",
        attrib={
            "type": "Collection",
            "version": "0.1",
            "byte_order": "LittleEndian",
            "compressor": "vtkZLibDataCompressor",
        },
    )
    collection = ET.SubElement(root, "Collection")
    for i, timestep in enumerate(self.timevalues):
        attrib = {"timestep": str(timestep), "group": "", "part": "0"}
        src_name = fns[i].name
        if ".xdmf" in src_name:
            # Reference the VTU name that _save_vtu derives from xdmf.
            attrib["file"] = src_name.replace(
                ".xdmf", f"_ts_{self.timesteps[i]}_t_{timestep}.vtu"
            )
        elif ".vtu" in src_name:
            attrib["file"] = src_name
        else:
            s = "File type not supported."
            raise RuntimeError(s)
        ET.SubElement(collection, "DataSet", attrib=attrib)
    ET.ElementTree(root).write(
        new_pvd_fn,
        encoding="ISO-8859-1",
        xml_declaration=True,
        pretty_print=True,
    )
def _check_path(self, filename: Path | None) -> Path:
    """Return ``filename`` if it is a Path, else raise.

    :param filename: Candidate filepath (e.g. ``mesh.filepath``).
    :raises RuntimeError: If no Path is set.
    """
    if not isinstance(filename, Path):
        s = "filename is empty"
        raise RuntimeError(s)
    # Removed dead code: a redundant `assert isinstance` and a
    # `cast(Path, ...)` — the isinstance guard above already narrows
    # the type for both the type checker and at runtime.
    return filename
[docs]
def save(
    self,
    target: Path | str | None = None,
    deep: bool = True,
    ascii: bool = False,
    overwrite: bool | None = None,
    dry_run: bool = False,
    archive: bool = False,
    id: str | None = None,
) -> list[Path]:
    """Save the mesh series to disk as PVD (+ VTU files).

    Orchestrates the StorageBase save protocol: `_pre_save` resolves the
    target, `_save_impl` writes the files, `_post_save` finalizes.

    :param target: Target path; forwarded to `_pre_save`.
    :param deep: Specifies whether VTU/H5 files should be written.
    :param ascii: Use ascii (True) or binary (False) VTU format.
    :param overwrite: Forwarded to `_pre_save`; presumably controls
        whether an existing target may be replaced — confirm in StorageBase.
    :param dry_run: If True, nothing is written; only the paths that
        would be written are returned.
    :param archive: Forwarded to `_post_save`.
    :param id: Storage id; forwarded to `_pre_save`.

    :returns: The PVD path followed by the per-timestep file paths.
    """
    user_defined = self._pre_save(target, overwrite, dry_run, id=id)
    files = self._save_impl(deep=deep, ascii=ascii, dry_run=dry_run)
    self._post_save(user_defined, archive, dry_run)
    return files
def _save_impl(
    self,
    deep: bool = True,
    ascii: bool = False,
    dry_run: bool = False,
) -> list[Path]:
    """
    Save mesh series to disk under ``self.next_target``.

    Currently only PVD targets are supported.

    :param deep: Specifies whether VTU/H5 files should be written.
    :param ascii: Specifies if ascii or binary format should be used,
        defaults to binary (False) - True for ascii.
    :param dry_run: If True, write nothing; only return the paths that
        would be written.

    :returns: The PVD path followed by the per-timestep file paths.
    :raises RuntimeError: If the target is not a PVD file.
    """
    fn = self.next_target
    fn.parent.mkdir(parents=True, exist_ok=True)
    # Current on-disk path of each timestep's mesh (must be set).
    fns = [
        self._check_path(self.mesh(t).filepath)
        for t in np.arange(len(self.timevalues), dtype=int)
    ]
    if ".pvd" in fn.name:
        if deep is True:
            # Rename timestep files to match the new PVD stem.
            fns = self._rename_vtufiles(fn, fns)
            if dry_run:
                # For dry_run, return paths where files would be written
                fns_written = [fn.parent / f.name for f in fns]
                return [fn] + fns_written
            self._save_vtu(fn, fns, ascii=ascii)
            # Update fns to actual written paths
            fns = [fn.parent / f.name for f in fns]
        if dry_run:
            # deep=False dry run: report target paths without writing.
            fns_written = [fn.parent / f.name for f in fns]
            return [fn] + fns_written
        self._save_pvd(fn, fns)
    else:
        s = "Currently the save method is implemented for PVD/VTU only."
        raise RuntimeError(s)
    # Return both the PVD file and all VTU files
    return [fn] + fns
[docs]
@deprecated("""
Please use `del meshseries.field_data[key]` or, if you want to keep the
data in the last timestep: `del meshseries[:-1].field_data[key]`.
""")
def remove_array(
    self, name: str, data_type: str = "field", skip_last: bool = False
) -> None:
    """
    Removes an array from all time slices of the mesh series.

    :param name: Array name
    :param data_type: Data type of the array. Could be either
        field, cell or point
    :param skip_last: Skips the last time slice (e.g. for restart purposes).

    :raises RuntimeError: For unknown `data_type` values.
    """
    # Map data_type to the mesh attribute holding that kind of data.
    attr_by_type = {
        "field": "field_data",
        "cell": "cell_data",
        "point": "point_data",
    }
    last_index = len(self) - 1
    for i, mesh in enumerate(self):
        # Identity check against False mirrors the original semantics
        # exactly (any non-False value activates the skip).
        if skip_last is not False and i == last_index:
            continue
        attr = attr_by_type.get(data_type)
        if attr is None:
            msg = "array type unknown"
            raise RuntimeError(msg)
        getattr(mesh, attr).remove(name)
def _propagate_target(self) -> None:
    # Intentionally a no-op. NOTE(review): presumably a StorageBase hook
    # for forwarding target paths to child storages, of which MeshSeries
    # has none — confirm against StorageBase.
    pass
def __repr__(self) -> str:
    """Return a constructor-style string followed by a state summary."""
    cls_name = type(self).__name__
    if self._id and self.user_specified_target:
        construct = f'{cls_name}.from_id("{self._id}")'
    else:
        target = self.active_target or self.next_target
        target_str = None if target is None else str(target)
        construct = (
            f"{cls_name}("
            f"filepath={target_str!r}, "
            f"spatial_unit={str(self.spatial_unit.units)!r}, "
            f"time_unit={str(self.time_unit.units)!r}"
            f")"
        )
    # Assemble one line per field; a trailing newline terminates the repr.
    parts = [
        construct,
        super().__repr__(),
        f"data_type={self._data_type!r}",
        f"num_timesteps={len(self.timevalues)}",
        f"dim={getattr(self, 'dim', None)!r}",
        f"cached_meshes={len(self._mesh_cache)}",
    ]
    return "\n".join(parts) + "\n"