Source code for ogstools.meshseries._meshseries

# SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
# SPDX-License-Identifier: BSD-3-Clause


from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence
from copy import copy as shallowcopy
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast, overload
from warnings import warn

import numpy as np
import pyvista as pv
from matplotlib import pyplot as plt
from typeguard import typechecked

from ogstools import plot
from ogstools._internal import copy_function_signature, deprecated
from ogstools.core.storage import StorageBase
from ogstools.mesh import read
from ogstools.mesh.utils import reshape_obs_points
from ogstools.plot.lineplots import line
from ogstools.variables import Variable, _normalize_vars, u_reg

from .data_dict import DataDict


[docs] class MeshSeries(Sequence[pv.UnstructuredGrid], StorageBase): """ A wrapper around pyvista and meshio for reading of pvd and xdmf timeseries. """ __hash__ = None
def __init__(
    self,
    filepath: str | Path | None = None,
    spatial_unit: str | Sequence[str] = "m",
    time_unit: str | Sequence[str] = "s",
    id: str | None = None,
) -> None:
    """
    Set up a MeshSeries, optionally backed by a result file.

    :param filepath:        Path to a PVD, XDMF/XMF or VTU file. If None,
                            no reader is created and no timevalues are set.
    :param spatial_unit:    Unit/s of the mesh points. See note.
    :param time_unit:       Unit/s of the timevalues. See note.
    :param id:              Identifier handed to the storage base class.

    :raises TypeError:  If the file suffix is none of the supported ones.

    :note:
        If given as a single string, the data is read in SI units i.e. in
        seconds and meters and converted to the given units. If given as a
        tuple, the first str corresponds to the data_unit, the second to
        the output_unit. E.g.:
        ``ot.MeshSeries(filepath, "km", ("a", "d"))``
        would read in the spatial data in meters and convert to
        kilometers and read in the timevalues in years and convert to
        days.
    """
    # TODO: if filepath extension = Path(filepath).suffix to also support xdmf
    super().__init__("MeshSeries", file_ext="pvd", id=id)

    # Internal bookkeeping shared by all data sources.
    self._time_factor = 1.0
    self._epsilon = 1.0e-6
    self._mesh_cache: dict[float, pv.UnstructuredGrid] = {}
    self._mesh_func_opt: (
        Callable[[pv.UnstructuredGrid], pv.UnstructuredGrid] | None
    ) = None
    # list of slices to be able to have nested slices with xdmf
    # (but only the first slice will be efficient)
    self._time_indices: list[slice | Any] = [slice(None)]
    # original data timevalues - do not change.
    self._timevalues: np.ndarray

    # A plain string only names the output unit; the data unit is SI then.
    if isinstance(spatial_unit, str):
        base_spatial = "m"
    else:
        base_spatial, spatial_unit = spatial_unit
    if isinstance(time_unit, str):
        base_time = "s"
    else:
        base_time, time_unit = time_unit

    self.spatial_unit = u_reg.Quantity(1, base_spatial)
    self.time_unit = u_reg.Quantity(1, base_time)
    if (self.spatial_unit != u_reg.Quantity(1, spatial_unit)
            or self.time_unit != u_reg.Quantity(1, time_unit)):  # fmt: skip
        self.scale(spatial_unit, time_unit)

    if filepath is None:
        self.filepath = self._data_type = None
        return

    self.filepath = Path(filepath)
    suffix = self.filepath.suffix
    self._data_type = suffix

    if suffix == ".pvd":
        self._pvd_reader = pv.PVDReader(self.filepath)
        parent_dir = self.filepath.parent
        self.timestep_files = [
            str(parent_dir / dataset.path)
            for dataset in self._pvd_reader.datasets
        ]
        self._timevalues = np.asarray(self._pvd_reader.time_values)
        self.skip_pvtu2vtu = False
    elif suffix in (".xdmf", ".xmf"):
        from .xdmf_reader import XDMFReader

        # Both suffixes are handled as ".xdmf" from here on.
        self._data_type = ".xdmf"
        self._xdmf_reader = XDMFReader(self.filepath)
        self._timevalues = np.asarray(
            [
                float(element.attrib["Value"])
                for collection_i in self._xdmf_reader.collection
                for element in collection_i
                if element.tag == "Time"
            ]
        )
    elif suffix == ".vtu":
        # A single mesh is treated as a series with one timevalue of zero.
        self._vtu_reader = pv.XMLUnstructuredGridReader(self.filepath)
        self._timevalues = np.zeros(1)
    else:
        msg = (
            "Can only read 'pvd', 'xdmf', 'xmf'(from Paraview) or "
            f"'vtu' files, but not '{suffix}'"
        )
        raise TypeError(msg)

    self.dim = self.mesh(0).GetMaxSpatialDimension()
@classmethod
def from_data(
    cls,
    meshes: Sequence[pv.UnstructuredGrid],
    timevalues: np.ndarray,
    spatial_unit: str | Sequence[str] = "m",
    time_unit: str | Sequence[str] = "s",
) -> MeshSeries:
    """
    Create a MeshSeries from a list of meshes and timevalues.

    :param meshes:          One mesh per timevalue. The meshes are deep
                            copied into the cache of the new series.
    :param timevalues:      The timevalues belonging to the meshes. Any
                            array-like is accepted and copied into a numpy
                            array.
    :param spatial_unit:    Forwarded to the constructor.
    :param time_unit:       Forwarded to the constructor.

    :returns:   A MeshSeries without a file reader, serving only the
                cached meshes.

    :raises ValueError: If meshes and timevalues differ in length.
    """
    new_ms = cls(spatial_unit=spatial_unit, time_unit=time_unit)
    # np.array both copies the input and turns lists/tuples into an array;
    # the timevalues property multiplies by a float factor, which a plain
    # list does not support.
    new_ms._timevalues = np.array(timevalues)  # pylint: disable=W0212
    new_ms._mesh_cache.update(
        dict(zip(new_ms.timevalues, deepcopy(meshes), strict=True))
    )
    new_ms.dim = meshes[0].GetMaxSpatialDimension()
    return new_ms
@classmethod
def from_id(cls, meshseries_id: str) -> MeshSeries:
    """
    Restore a MeshSeries from the user storage path by its ID.

    StorageBase.Userpath must be set.

    :param meshseries_id:   The unique ID of the MeshSeries to load.
    :returns:               A MeshSeries instance restored from disk.

    :raises FileNotFoundError:  If neither ``ms.pvd`` nor ``ms.xdmf``
                                exists for the given ID.
    """
    series_dir = StorageBase.saving_path() / "MeshSeries" / f"{meshseries_id}"
    pvd_file = series_dir / "ms.pvd"
    xdmf_file = series_dir / "ms.xdmf"

    # Try .pvd first, then .xdmf
    for candidate in (pvd_file, xdmf_file):
        if candidate.exists():
            meshseries = cls(filepath=candidate)
            break
    else:
        msg = f"No MeshSeries found at {pvd_file} or {xdmf_file}"
        raise FileNotFoundError(msg)

    meshseries._id = meshseries_id
    return meshseries
[docs] def extend(self, mesh_series: MeshSeries) -> None: """ Extends self with mesh_series. If the last element of the mesh series is within epsilon to the first element of mesh_series to extend, the duplicate element is removed """ ms1_list = list(self) ms2_list = list(mesh_series) ms1_timevalues = self.timevalues ms2_timevalues = mesh_series.timevalues if hasattr(self, "timestep_files"): ms1_timestep_files = self.timestep_files else: ms1_timestep_files = [""] * len(ms1_list) if hasattr(mesh_series, "timestep_files"): ms2_timestep_files = mesh_series.timestep_files else: ms2_timestep_files = [""] * len(ms2_list) offset = 0.0 delta = ms2_timevalues[0] - ms1_timevalues[-1] offset = 0.0 if delta >= 0 else ms1_timevalues[-1] if ((delta < 0) and (ms2_timevalues[0] == 0.0)) or ( np.abs(delta) < self._epsilon ): ms1_timevalues = ms1_timevalues[:-1] ms1_list = ms1_list[:-1] ms1_timestep_files = ms1_timestep_files[:-1] ms2_timevalues = ms2_timevalues + offset self._timevalues = np.append(ms1_timevalues, ms2_timevalues, axis=0) self._mesh_cache.update( dict( zip( np.append(ms1_timevalues, ms2_timevalues, axis=0), ms1_list + ms2_list, strict=True, ) ) ) if hasattr(self, "timestep_files") or hasattr( mesh_series, "timestep_files" ): self.timestep_files = ms1_timestep_files + ms2_timestep_files assert len(self.timestep_files) == len( self._timevalues ), "Timestep files and timevalues do not match."
def resample_temporal(self, timevalues: np.ndarray) -> MeshSeries:
    """
    Return a new MeshSeries interpolated to the given timevalues.

    :param timevalues:  The timevalues at which meshes are created via
                        :meth:`mesh_interp`.
    :returns:           A MeshSeries built from the interpolated meshes,
                        keeping the units of this series.
    """
    interp_meshes = list(map(self.mesh_interp, timevalues))
    spatial, temporal = self.spatial_unit, self.time_unit
    # Passing identical data and output units prevents any rescaling.
    return MeshSeries.from_data(
        interp_meshes, timevalues, (spatial, spatial), (temporal, temporal)
    )
def _extract_probe(
    self,
    pts_or_mesh: np.ndarray | pv.DataSet,
    data_name: str | Variable | list[str | Variable] | None = None,
    interp_method: Literal["nearest", "linear"] = "linear",
) -> MeshSeries:
    """
    Shared implementation of :meth:`probe` and :meth:`interpolate`.

    :param pts_or_mesh:     Either a pyvista dataset, whose points are
                            probed, or an array of points, which is wrapped
                            in a ``pv.PolyData``.
    :param data_name:       Data to extract. If None, all point and cell
                            data names of the first mesh are used.
    :param interp_method:   Interpolation method passed to
                            :meth:`probe_values`.
    :returns:   A MeshSeries holding one copy of the target mesh per
                timevalue, filled with the probed data.
    """
    if isinstance(pts_or_mesh, pv.DataSet):
        mesh = pts_or_mesh
        points = mesh.points
    else:
        points = np.asarray(pts_or_mesh, dtype=float)
        mesh = pv.PolyData(np.asarray(points))

    if data_name is None:
        first = self[0]
        variables = list(set().union(first.point_data, first.cell_data))
    elif isinstance(data_name, list):
        variables = data_name
    else:
        variables = [data_name]

    spatial = str(self.spatial_unit)
    temporal = str(self.time_unit)
    probe_ms = MeshSeries.from_data(
        [mesh.copy() for _ in self.timevalues],
        self.timevalues,
        (spatial, spatial),
        (temporal, temporal),
    )

    # Point and cell data are probed separately and land in the matching
    # data container of the result.
    point_data_keys = [key for key in variables if key in self.point_data]
    cell_data_keys = [key for key in variables if key in self.cell_data]
    targets = [
        (point_data_keys, probe_ms.point_data),
        (cell_data_keys, probe_ms.cell_data),
    ]
    for keys, data in targets:
        if not keys:
            continue
        values = self.probe_values(points, keys, interp_method)
        for var, vals in zip(keys, values, strict=True):
            if isinstance(var, Variable):
                data[var.output_name] = vals
            else:
                data[var] = vals
    return probe_ms
def probe(
    self,
    points: np.ndarray,
    data_name: str | Variable | list[str | Variable] | None = None,
    interp_method: Literal["nearest", "linear"] = "linear",
) -> MeshSeries:
    """Create a new MeshSeries by probing points on an existing MeshSeries.

    :param points:          The points at which to probe.
    :param data_name:       Data to extract. If None, all point and cell
                            data found in the first mesh is used.
    :param interp_method:   The interpolation method to use.

    :returns: A MeshSeries (Pointcloud) containing the probed data.
    """
    probed = self._extract_probe(points, data_name, interp_method)
    return probed
def interpolate(
    self,
    mesh: pv.DataSet,
    data_name: str | Variable | list[str | Variable] | None = None,
) -> MeshSeries:
    """Create a new MeshSeries by spatial interpolation.

    The interpolation is always linear.

    :param mesh:        The mesh on which to interpolate.
    :param data_name:   Data to extract. If None, all point and cell data
                        found in the first mesh is used.

    :returns: A spatially interpolated MeshSeries.
    """
    interpolated = self._extract_probe(mesh, data_name, "linear")
    return interpolated
def __eq__(self, other: object) -> bool:
    """Compare with another MeshSeries, ignoring field data."""
    # TODO:: Field data is currently not copied, to not break copy/eq - contract -> False
    if isinstance(other, MeshSeries):
        return MeshSeries.compare(self, other, field_data=False)
    return NotImplemented

def __deepcopy__(self, memo: dict) -> MeshSeries:
    """
    Create a deep copy in which only the file readers are shared shallowly.

    Deep copy is the default when using self.copy().
    For shallow copy: self.copy(deep=False)
    """
    self_copy = self.__class__(self.filepath)
    memo[id(self)] = self_copy
    for key, value in self.__dict__.items():
        if key in ("_next_target", "id"):
            # these keep the value assigned by the constructor of the copy
            continue
        if key in ("_pvd_reader", "_xdmf_reader"):
            # Shallow copy of reader is needed, because timesteps are
            # stored in reader, deep copy doesn't work for _pvd_reader
            # and _xdmf_reader
            copied = shallowcopy(value)
        elif isinstance(value, pv.UnstructuredGrid):
            # For PyVista objects use their own copy method
            copied = value.copy(deep=True)
        else:
            # For everything that is neither reader nor PyVista object
            # use the deepcopy
            copied = deepcopy(value, memo)
        setattr(self_copy, key, copied)
    return self_copy

@overload
def __getitem__(self, index: int) -> pv.UnstructuredGrid: ...

@overload
def __getitem__(self, index: slice | Sequence) -> MeshSeries: ...

@overload
def __getitem__(self, index: str) -> np.ndarray: ...
[docs] def __getitem__(self, index: Any) -> Any: if isinstance( index, int | np.unsignedinteger | np.signedinteger, ): return self.mesh(index) if isinstance(index, str): # type: ignore[unreachable] return self.values(index) if isinstance(index, slice | Sequence): ms_copy = self.copy(deep=False) if ms_copy._time_indices == [slice(None)]: ms_copy._time_indices = [index] else: ms_copy._time_indices += [index] return ms_copy msg = ( "Index type not supported.\n " "It has to be one of following: " "int, np.int, str, slice or sequence." ) raise ValueError(msg)
def __len__(self) -> int:
    """Return the number of timesteps of the (possibly sliced) series."""
    n_steps = len(self.timesteps)
    return n_steps

def __iter__(self) -> Iterator[pv.UnstructuredGrid]:
    """Yield the meshes of all timesteps in order."""
    yield from map(self.mesh, np.arange(len(self.timevalues), dtype=int))

def __str__(self) -> str:
    """Summarize filepath, data type, time range and reader as text."""
    reader_attrs = {
        ".vtu": "_vtu_reader",
        ".pvd": "_pvd_reader",
        ".xdmf": "_xdmf_reader",
    }
    reader_attr = reader_attrs.get(self._data_type)
    reader = "None" if reader_attr is None else getattr(self, reader_attr)
    lines = [
        f"MeshSeries:\n",
        f"filepath:         {self._format_path(self.filepath)}\n",
        f"data_type:        {self._data_type}\n",
        f"timevalues:       {self.timevalues[0]} to {self.timevalues[-1]} in {len(self.timevalues)} steps\n",
        f"reader:           {reader}\n",
        f"rawdata_file:     {self._format_path(self.rawdata_file())}\n",
    ]
    return "".join(lines)

# deliberately typing as Sequence and not as zip because typing as zip
# leads to a weird cross-referencing error from sphinx side with no easy
# apparent fix
[docs] def items(self) -> Sequence[tuple[float, pv.UnstructuredGrid]]: "Returns zipped tuples of timevalues and meshes." return zip(self.timevalues, self, strict=True) # type: ignore[return-value]
def aggregate_temporal(
    self, variable: Variable | str, func: Callable
) -> pv.UnstructuredGrid:
    """Aggregate data over all timesteps using a specified function.

    :param variable:    The mesh variable to be aggregated.
    :param func:        The aggregation function to apply. E.g. np.min,
                        np.max, np.mean, np.median, np.sum, np.std, np.var.
                        It is called with ``axis=0`` and its ``__name__``
                        becomes part of the name of the resulting data.
    :returns:   A copy of the first mesh, stripped of its point and cell
                data, carrying only the aggregated data.
    """
    # TODO: add function to create an empty mesh from a given on
    # custom field_data may need to be preserved
    mesh = self.mesh(0).copy(deep=True)
    mesh.clear_point_data()
    mesh.clear_cell_data()
    base_name = (
        variable.output_name if isinstance(variable, Variable) else variable
    )
    output_name = f"{base_name}_{func.__name__}"
    aggregated = func(self.values(variable), axis=0)
    mesh[output_name] = aggregated
    return mesh
def clear_cache(self) -> None:
    """Remove all cached meshes, keeping the cache dict itself."""
    cache = self._mesh_cache
    cache.clear()
def closest_timestep(self, timevalue: float) -> int:
    """Return the index of the timevalue closest to the given one."""
    distances = np.abs(self.timevalues - timevalue)
    return int(distances.argmin())
def closest_timevalue(self, timevalue: float) -> float:
    """Return the timevalue of the series closest to the given one."""
    step = self.closest_timestep(timevalue)
    return self.timevalues[step]
def ip_tesselated(self) -> MeshSeries:
    """
    Create a new MeshSeries from integration point tessellation.

    The tessellated mesh is built once from the first timestep. For each
    timestep the field data arrays named like the cell data of the
    tessellated mesh are reordered and written into a copy of it.

    :returns:   A MeshSeries of tessellated meshes with the timevalues of
                this series.
    """
    from ogstools.mesh.ip_mesh import to_ip_mesh, to_ip_point_cloud

    first_mesh = self.mesh(0)
    ip_mesh = to_ip_mesh(first_mesh)
    ip_pt_cloud = to_ip_point_cloud(first_mesh)
    ordering = ip_mesh.find_containing_cell(ip_pt_cloud.points)
    # The reordering is identical for every key and timestep, so compute
    # it once instead of inside the innermost loop.
    sort_idx = np.argsort(ordering)

    ip_meshes = []
    for i in np.arange(len(self.timevalues), dtype=int):
        field_data = self.mesh(i).field_data
        ip_data = {
            key: field_data[key][sort_idx] for key in ip_mesh.cell_data
        }
        ip_mesh.cell_data.update(ip_data)
        ip_meshes.append(ip_mesh.copy())
    return MeshSeries.from_data(ip_meshes, self.timevalues)
[docs] def mesh( self, timestep: int, lazy_eval: bool = True ) -> pv.UnstructuredGrid: """Returns the mesh at the given timestep.""" timevalue = self.timevalues[timestep] if not np.any(self.timevalues == timevalue): msg = f"Value {timevalue} not found in the array." raise ValueError(msg) data_timestep = np.argmax( self._timevalues * self._time_factor == timevalue ) if timevalue in self._mesh_cache: mesh = self._mesh_cache[timevalue] else: match self._data_type: case ".pvd": suffix = self.timestep_files[data_timestep].split(".")[-1] if suffix == "pvtu" and not self.skip_pvtu2vtu: from ogstools._find_ogs import cli from ogstools.core.storage import _date_temp_path tmp_path = _date_temp_path("pvtu2vtu", "vtu") tmp_path.parent.mkdir(parents=True) cli().pvtu2vtu( i=self.timestep_files[data_timestep], o=tmp_path ) pv_mesh = read(tmp_path) else: pv_mesh = self._read_pvd(data_timestep) case ".xdmf": pv_mesh = self._read_xdmf(data_timestep) case ".vtu": pv_mesh = self._vtu_reader.read() case _: msg = f"Unexpected datatype {self._data_type}." raise TypeError(msg) mesh = self.mesh_func(pv_mesh) if lazy_eval: self._mesh_cache[timevalue] = mesh filepath = ( Path(self.timestep_files[data_timestep]) if self._data_type == ".pvd" else self.filepath ) for attr, val in [ ("filepath", filepath), ("spatial_unit", self.spatial_unit), ("time_unit", self.time_unit), ]: if not hasattr(pv, "set_new_attribute") or hasattr(mesh, attr): setattr(mesh, attr, val) else: pv.set_new_attribute(mesh, attr, val) return mesh
def rawdata_file(self) -> Path | None:
    """
    Checks, if working with the raw data is possible. For example,
    OGS Simulation results with XDMF support efficient raw data access via
    `h5py <https://docs.h5py.org/en/stable/quick.html#quick>`_

    :returns:   The location of the file containing the raw data. If it
                does not support efficient read (e.g., no efficient
                slicing), it returns None.
    """
    if self._data_type != ".xdmf":
        return None
    reader = self._xdmf_reader
    if not reader.has_fast_access():
        return None
    return reader.rawdata_path()  # single h5 file
def mesh_interp(
    self, timevalue: float, lazy_eval: bool = True
) -> pv.UnstructuredGrid:
    """
    Return the temporal interpolated mesh for a given timevalue.

    Point data is interpolated linearly between the two neighbouring
    timesteps. Outside of the time range of the series the first or last
    mesh is held constant.

    :param timevalue:   The time at which to evaluate the series.
    :param lazy_eval:   Forwarded to :meth:`mesh`.

    :returns:   The mesh of a timestep itself if timevalue is close to its
                timevalue, otherwise a new mesh with interpolated point
                data.
    """
    t_vals = self.timevalues
    ts1 = int(t_vals.searchsorted(timevalue, "right") - 1)
    if ts1 < 0:
        # timevalue lies before the first timestep. Without this guard the
        # index -1 would wrap around to the last mesh.
        first = self.mesh(0, lazy_eval)
        if np.isclose(timevalue, t_vals[0]):
            return first
        return first.copy(deep=True)
    ts2 = min(ts1 + 1, len(t_vals) - 1)
    if np.isclose(timevalue, t_vals[ts1]):
        return self.mesh(ts1, lazy_eval)
    mesh1 = self.mesh(ts1, lazy_eval)
    mesh2 = self.mesh(ts2, lazy_eval)
    mesh = mesh1.copy(deep=True)
    # TODO interpolate cell_data and field_data as well
    for key in mesh1.point_data:
        if key not in mesh2.point_data:
            msg = f"{key} not in timestep {ts2}."
            warn(msg, RuntimeWarning, stacklevel=2)
            continue
        # Unchanged data needs no interpolation. This also covers
        # ts1 == ts2 (timevalue beyond the last timestep), where dt is 0.
        if np.all(mesh1.point_data[key] == mesh2.point_data[key]):
            continue
        dt = t_vals[ts2] - t_vals[ts1]
        slope = (mesh2.point_data[key] - mesh1.point_data[key]) / dt
        mesh.point_data[key] = mesh1.point_data[key] + slope * (
            timevalue - t_vals[ts1]
        )
    return mesh
@property
def timevalues(self) -> np.ndarray:
    """
    Return the timevalues.

    The original data timevalues are restricted by all stored time
    indices, one after another, and then multiplied by the time factor.
    """
    vals = self._timevalues
    for index in self._time_indices:
        vals = vals[index]
    return vals * self._time_factor

@property
def timesteps(self) -> np.ndarray:
    """
    Return the OGS simulation timesteps of the timeseries data.
    Not to be confused with timevalues which returns a list of
    times usually given in time units.

    Currently this is the integer range from zero to the number of
    timevalues.
    """
    # TODO: read time steps from fn string if available
    return np.arange(len(self.timevalues), dtype=int)

def _xdmf_values(self, variable_name: str) -> np.ndarray:
    """
    Read the values of a variable for all timesteps via the xdmf reader.

    :param variable_name:   Key into the data items of the xdmf reader.
    :returns:   The data restricted by the stored time indices and, where
                possible, by the mask arrays of the first mesh.
    """
    dataitems = self._xdmf_reader.data_items[variable_name]
    # pv filters produces these arrays, which we can use for slicing
    # to also reflect the previous use of self.transform here
    mask_map = {
        "vtkOriginalPointIds": self.mesh(0).point_data,
        "vtkOriginalCellIds": self.mesh(0).cell_data,
    }
    # The first time index (and the mask, if one applies) is handed to the
    # data item directly; the else-branch runs only if no mask matched.
    for mask, data in mask_map.items():
        if variable_name in data and mask in data:
            result = dataitems[self._time_indices[0], self.mesh(0)[mask]]
            break
    else:
        result = dataitems[self._time_indices[0]]
    # Any further (nested) time indices are applied on the loaded array.
    for index in self._time_indices[1:]:
        result = result[index]
    if self._mesh_func_opt is not None and not any(
        mask in data for mask, data in mask_map.items()
    ):
        # if transform function doesn't produce the mask arrays we have to
        # map data xdmf data to the entire list of meshes and apply the
        # function on each mesh individually.
        ms_copy = self.copy(deep=True)
        ms_copy._mesh_func_opt = None  # pylint: disable=protected-access
        ms_copy.clear_cache()
        raw_meshes = list(ms_copy)
        for mesh, data in zip(raw_meshes, result, strict=True):
            mesh[variable_name] = data
        meshes = list(map(self.mesh_func, raw_meshes))
        result = np.asarray([mesh[variable_name] for mesh in meshes])
    return result

@overload
def values(self, variable: str | Variable) -> np.ndarray: ...

@overload
def values(self, variable: list[str | Variable]) -> list[np.ndarray]: ...
def values(
    self, variable: str | Variable | list[str | Variable]
) -> np.ndarray | list[np.ndarray]:
    """
    Get the data in the MeshSeries for all timesteps.

    Adheres to time slicing via `__getitem__` and an applied pyvista filter
    via `transform` if the applied filter produced 'vtkOriginalPointIds' or
    'vtkOriginalCellIds' (e.g. `clip(..., crinkle=True)`,
    `extract_cells(...)`, `threshold(...)`.)

    :param variable:    Data to read/process from the MeshSeries. Can also
                        be a list of str or Variable.

    :returns:   A numpy array of shape (n_timesteps, n_points/c_cells).
                If given an argument of type Variable is given, its
                transform function is applied on the data. If a list of str
                or Variable is given, a list of the individual values is
                returned.
    """
    if not isinstance(variable, list):
        return self._values(variable)
    return list(map(self._values, variable))
def _values(self, variable: str | Variable) -> np.ndarray:
    """
    Read the data of a single variable for all timesteps.

    For a Variable its transform is applied, unless data named like its
    output name is already present in the series.
    """
    needs_transform = False
    if isinstance(variable, Variable):
        if variable.mesh_dependent:
            return np.asarray([variable.transform(mesh) for mesh in self])
        already_transformed = (
            variable.data_name != variable.output_name
            and variable.output_name
            in set().union(self.point_data, self.cell_data)
        )
        if already_transformed:
            variable_name = variable.output_name
        else:
            variable_name = variable.data_name
            needs_transform = True
    else:
        variable_name = variable

    all_cached = self._is_all_cached
    read_from_xdmf = (
        self._data_type == ".xdmf"
        and variable_name in self._xdmf_reader.data_items
        and not all_cached
    )
    if read_from_xdmf:
        result = self._xdmf_values(variable_name)
    else:
        from tqdm import tqdm

        # A progress bar is only shown if meshes still have to be read.
        meshes = self if all_cached else tqdm(self)
        result = np.asarray([mesh[variable_name] for mesh in meshes])

    if needs_transform:
        return variable.transform(result)
    return result

@property
def _is_all_cached(self) -> bool:
    "Check if all meshes in this meshseries are cached"
    cached_timevalues = np.fromiter(self._mesh_cache.keys(), float)
    is_cached = np.isin(self.timevalues, cached_timevalues)
    return is_cached.all()

@property
def point_data(self) -> DataDict:
    "Useful for reading or setting point_data for the entire meshseries."

    def get_point_data(mesh):  # type: ignore[no-untyped-def]
        return mesh.point_data

    return DataDict(self, get_point_data, self[0].n_points)

@property
def cell_data(self) -> DataDict:
    "Useful for reading or setting cell_data for the entire meshseries."

    def get_cell_data(mesh):  # type: ignore[no-untyped-def]
        return mesh.cell_data

    return DataDict(self, get_cell_data, self[0].n_cells)

@property
def field_data(self) -> DataDict:
    "Useful for reading or setting field_data for the entire meshseries."

    def get_field_data(mesh):  # type: ignore[no-untyped-def]
        return mesh.field_data

    return DataDict(self, get_field_data, None)

def _read_pvd(self, timestep: int) -> pv.UnstructuredGrid:
    """Read the mesh of the given data timestep with the pvd reader."""
    reader = self._pvd_reader
    reader.set_active_time_point(timestep)
    return reader.read()[0]

def _read_xdmf(self, timestep: int) -> pv.UnstructuredGrid:
    """Read the mesh of the given data timestep with the xdmf reader."""
    import meshio

    reader = self._xdmf_reader
    points, cells = reader.read_points_cells()
    _, point_data, cell_data, field_data = reader.read_data(timestep)
    meshio_mesh = meshio.Mesh(
        points, cells, point_data, cell_data, field_data
    )
    # pv.from_meshio does not copy field_data (fix in pyvista?)
    pv_mesh = pv.from_meshio(meshio_mesh)
    pv_mesh.field_data.update(field_data)
    return pv_mesh

def _time_of_extremum(
    self,
    variable: Variable | str,
    np_func: Callable,
    prefix: Literal["min", "max"],
) -> pv.UnstructuredGrid:
    """Returns a Mesh with the time of a given variable extremum as data.

    The data is named as `f'{prefix}_{variable.output_name}_time'`.

    :param np_func: Called with the values and ``axis=0``; its result is
                    used to index the timevalues (e.g. np.argmin).
    """
    mesh = self.mesh(0).copy(deep=True)
    variable = Variable.find(variable, mesh)
    mesh.clear_point_data()
    mesh.clear_cell_data()
    output_name = f"{prefix}_{variable.output_name}_time"
    timevalues = self.timevalues
    extremum_steps = np_func(self.values(variable), axis=0)
    mesh[output_name] = timevalues[extremum_steps]
    return mesh
def time_of_min(self, variable: Variable | str) -> pv.UnstructuredGrid:
    """
    Returns a Mesh with the time of the variable minimum as data.

    :param variable:    The variable whose minimum over time is located.
    """
    return self._time_of_extremum(
        variable=variable, np_func=np.argmin, prefix="min"
    )
def time_of_max(self, variable: Variable | str) -> pv.UnstructuredGrid:
    """
    Returns a Mesh with the time of the variable maximum as data.

    :param variable:    The variable whose maximum over time is located.
    """
    return self._time_of_extremum(
        variable=variable, np_func=np.argmax, prefix="max"
    )
def aggregate_spatial(
    self, variable: Variable | str, func: Callable
) -> np.ndarray:
    """Aggregate data over domain per timestep using a specified function.

    :param variable:    The mesh variable to be aggregated.
    :param func:        The aggregation function to apply. E.g. np.min,
                        np.max, np.mean, np.median, np.sum, np.std, np.var.
                        It is called with ``axis=1``.
    :returns:   A numpy array with aggregated data.
    """
    data = self.values(variable)
    return func(data, axis=1)
def _flatten_vectors(self, data: list[np.ndarray]) -> list[np.ndarray]:
    """
    Split arrays with more than two dimensions along their last axis.

    Arrays with at most two dimensions are passed on unchanged, so the
    result is a flat list of component arrays.
    """
    flattened: list[np.ndarray] = []
    for vals in data:
        if vals.ndim > 2:
            flattened.extend(vals[..., i] for i in range(vals.shape[-1]))
        else:
            flattened.append(vals)
    return flattened

def _restore_vectors(
    self, data: list[np.ndarray], components: list[int]
) -> list[np.ndarray]:
    """
    Regroup a flat list of component arrays into one array per variable.

    :param data:        The component arrays, in order.
    :param components:  Number of consecutive entries of data belonging to
                        each variable.
    :returns:   Per variable the components stacked along a new last axis,
                with all axes of length one squeezed out.
    """
    restored = []
    start = 0
    for n_comp in components:
        stop = start + n_comp
        stacked = np.stack(data[start:stop], axis=-1)
        restored.append(np.squeeze(stacked))
        start = stop
    return restored
[docs] def probe_values( self, points: np.ndarray | list, data_name: str | Variable | list[str | Variable], interp_method: Literal["nearest", "linear"] = "linear", ) -> np.ndarray | list[np.ndarray]: """ Return the sampled data of the MeshSeries at observation points. Similar to :func:`~ogstools.MeshSeries.probe` but returns the data directly instead of creating a new MeshSeries. :param points: The observation points to sample at. :param data_name: Data to sample. If provided as a Variable, the output will transformed accordingly. Can also be a list of str or Variable. :param interp_method: Interpolation method, defaults to `linear` :returns: `numpy` array/s of interpolated data at observation points with the following shape: - multiple points: (n_timesteps, n_points, [n_components]) - single points: (n_timesteps, [n_components]) If `data_name` is a list, a corresponding list of arrays is returned. """ if ( isinstance(data_name, list) and any(key in self.point_data for key in data_name) and any(key in self.cell_data for key in data_name) ): msg = "Cannot probe point and cell data together." 
raise TypeError(msg) pts = reshape_obs_points(points, self.mesh(0)) values = self.values(data_name) if isinstance(data_name, list): components = [ 1 if arr.ndim == 2 else np.shape(arr)[-1] for arr in values ] values = np.moveaxis(self._flatten_vectors(values), 0, -1) values = np.swapaxes(values, 0, 1) geom = self.mesh(0).points if values.shape[0] != geom.shape[0]: # assume cell_data geom = self.mesh(0).cell_centers().points # remove flat dimensions for interpolation flat_axis = np.argwhere(np.all(np.isclose(geom, geom[0]), axis=0)) geom = np.delete(geom, flat_axis, 1) pts = np.delete(pts, flat_axis, 1) more_than_1d = self.dim > 1 or len(flat_axis) < 2 match more_than_1d, interp_method: case True, "nearest": from scipy.interpolate import NearestNDInterpolator result = np.swapaxes( NearestNDInterpolator(geom, values)(pts), 0, 1 ) case True, "linear": from scipy.interpolate import LinearNDInterpolator result = np.swapaxes( LinearNDInterpolator(geom, values, np.nan)(pts), 0, 1 ) case False, kind: from scipy.interpolate import interp1d result = interp1d(geom[:, 0], values.T, kind=kind)( np.squeeze(pts, 1) ) case _, _: msg = ( "No interpolation method implemented for mesh of " f"{self.dim=} and {interp_method=}" ) raise TypeError(msg) if np.shape(points)[0] != 1 and np.shape(result)[1] == 1: result = np.squeeze(result, axis=1) if isinstance(data_name, list): if more_than_1d: result = np.moveaxis(result, -1, 0) result = self._restore_vectors(result, components) return result
# Convenience wrapper: forwards this MeshSeries as first argument to `line`.
# The decorator is given `line`, so no docstring is set here, to not
# interfere with whatever it copies over. (NOTE(review): confirm what
# exactly copy_function_signature copies.)
@copy_function_signature(line)
def plot_line(self, *args: Any, **kwargs: Any) -> Any:
    return line(self, *args, **kwargs)

# TODO: make use of ot.plot.heatmaps
def plot_time_slice(
    self,
    x: Literal["x", "y", "z", "time"],
    y: Literal["x", "y", "z", "time"],
    variable: str | Variable,
    time_logscale: bool = False,
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
    cbar: bool = True,
    **kwargs: Any,
) -> plt.Figure | None:
    """
    Create a heatmap for a variable over time and space.

    :param x:   What to display on the x-axis (x, y, z or time)
    :param y:   What to display on the y-axis (x, y, z or time)
    :param variable:    The variable to be visualized.
    :param time_logscale:   Should log-scaling be applied to the time-axis?
    :param fig:             matplotlib figure to use for plotting.
    :param ax:              matplotlib axis to use for plotting.
    :param cbar:            If True, adds a colorbar.

    :returns:   The newly created figure if neither fig nor ax were given,
                otherwise None.

    :raises ValueError: If only one of fig and ax is given.
    :raises KeyError:   If not exactly one axis is 'time' and the other a
                        spatial coordinate.

    Keyword Arguments:
        - cb_labelsize:       colorbar labelsize
        - cb_loc:             colorbar location ('left' or 'right')
        - cb_pad:             colorbar padding
        - cmap:               colormap
        - vmin:               minimum value for colorbar
        - vmax:               maximum value for colorbar
        - num_levels:         number of levels for colorbar
        - figsize:            figure size
        - dpi:                resolution
        - log_scaled:         logarithmic scaling
    """
    create_figure = ax is None and fig is None
    if create_figure:
        fig, ax = plt.subplots(
            figsize=kwargs.get("figsize", [18, 14]),
            dpi=kwargs.get("dpi", 100),
        )
    elif ax is None or fig is None:
        msg = "Please provide fig and ax together or not at all."
        raise ValueError(msg)
    # Only a figure created here is handed back to the caller.
    optional_return_figure = fig if create_figure else None

    if "time" not in [x, y]:
        msg = "One of x_var and y_var has to be 'time'."
        raise KeyError(msg)
    if x not in "xyz" and y not in "xyz":
        msg = "One of x_var and y_var has to be a spatial coordinate."
        raise KeyError(msg)

    var_z = Variable.find(variable, self.mesh(0))
    var_x, var_y = _normalize_vars(x, y, self.mesh(0), ["time", "time"])
    time_var = var_x if var_x.data_name == "time" else var_y
    unit = self.time_unit
    unit_str = str(unit if unit.magnitude != 1 else unit.units)
    time_var.data_unit = time_var.output_unit = unit_str

    if time_logscale:

        def log10time(vals: np.ndarray) -> np.ndarray:
            log10vals = np.log10(
                vals, where=vals != 0, out=np.zeros_like(vals)
            )
            # a leading zero is replaced by continuing the spacing of the
            # following two entries
            if log10vals[0] == 0:
                log10vals[0] = log10vals[1] - (log10vals[2] - log10vals[1])
            return log10vals

        get_time = time_var.func
        time_var.func = lambda ms: log10time(get_time(ms))
        time_label = time_var.get_label()
        time_var.get_label = (  # type: ignore[assignment]
            lambda *_: f"log$_{{10}}$( {time_label} )"
        )

    x_vals = var_x.transform(self)
    y_vals = var_y.transform(self)
    values = self.values(var_z)
    if values.shape == (len(x_vals), len(y_vals)):
        values = values.T

    # sort w.r.t spatial index if not already sorted
    if x in "xyz" and not np.all(np.diff(x_vals) >= 0):
        order = np.argsort(x_vals)
        x_vals = x_vals[order]
        values = values[:, order]
    elif y in "xyz" and not np.all(np.diff(y_vals) >= 0):
        order = np.argsort(y_vals)
        y_vals = y_vals[order]
        values = values[order]

    if "levels" in kwargs:
        levels = np.asarray(kwargs.pop("levels"))
    else:
        vmin, vmax = (plot.setup.vmin, plot.setup.vmax)
        log_scaled = kwargs.get("log_scaled", plot.setup.log_scaled)
        if log_scaled and not var_z.is_mask():
            values = np.log10(
                values,
                where=values > 0.0,
                out=np.ones_like(values) * kwargs.get("vmin", np.nan),
            )
        lower = kwargs.get("vmin", np.nanmin(values) if vmin is None else vmin)
        upper = kwargs.get("vmax", np.nanmax(values) if vmax is None else vmax)
        num_levels = kwargs.get("num_levels", plot.setup.num_levels)
        levels = plot.levels.compute_levels(lower, upper, num_levels)

    cmap, norm = plot.utils.get_cmap_norm(levels, var_z, **kwargs)
    ax.pcolormesh(x_vals, y_vals, values, cmap=cmap, norm=norm)

    fontsize = kwargs.get("fontsize", plot.setup.fontsize)
    x_label = var_x.get_label()
    y_label = var_y.get_label()
    plot.utils.label_ax(fig, ax, x_label, y_label, fontsize)
    ax.tick_params(axis="both", labelsize=fontsize, length=fontsize * 0.5)
    ax.margins(0, 0)
    if cbar:
        plot.contourplots.add_colorbars(fig, ax, var_z, levels, **kwargs)
    plot.utils.update_font_sizes(fig.axes, fontsize)
    return optional_return_figure
@property
def mesh_func(self) -> Callable[[pv.UnstructuredGrid], pv.UnstructuredGrid]:
    """
    Return the mesh transformation belonging to this series.

    :returns:   If no transformation is stored, a function handing the mesh
                back untouched. Otherwise a function which applies the
                stored transformation and wraps its result in a deep-copied
                ``pv.UnstructuredGrid``.
    """
    if self._mesh_func_opt is not None:

        def apply_stored(mesh):  # type: ignore[no-untyped-def]
            # The stored function is looked up at call time, not when the
            # property is read.
            transformed = self._mesh_func_opt(mesh)  # type: ignore[misc]
            return pv.UnstructuredGrid(transformed, deep=True)

        return apply_stored

    def identity(mesh):  # type: ignore[no-untyped-def]
        return mesh

    return identity
def transform(
    self,
    mesh_func: Callable[
        [pv.UnstructuredGrid], pv.UnstructuredGrid
    ] = lambda mesh: mesh,
) -> MeshSeries:
    """
    Apply a transformation function to the underlying mesh.

    The transformation already stored on this series is captured at the
    moment of the call. Later in-place changes to this series therefore do
    not leak into the returned copy.

    :param mesh_func: A function which expects to read a mesh and return a
                      mesh. Useful for slicing / clipping / thresholding.
    :returns:         A deep copy of this MeshSeries with transformed meshes.
    """
    ms_copy = self.copy(deep=True)

    # Snapshot the current transformation instead of reading it lazily
    # through ``self``: the series can be modified in place afterwards, which
    # must not alter meshes read later by the copy.
    previous = self._mesh_func_opt
    if previous is None:
        composed = mesh_func
    else:

        def composed(mesh):  # type: ignore[no-untyped-def]
            # Same chain the ``mesh_func`` property builds: apply the
            # previous function, deep-copy, then apply the new function.
            return mesh_func(pv.UnstructuredGrid(previous(mesh), deep=True))

    # pylint: disable=protected-access
    # Cached meshes already had the previous transformation applied, so only
    # the new one is needed for them.
    for cache_timevalue, cache_mesh in self._mesh_cache.items():
        ms_copy._mesh_cache[cache_timevalue] = pv.UnstructuredGrid(
            mesh_func(cache_mesh), deep=True
        )
    ms_copy._mesh_func_opt = composed
    # The transformation may change the dimension (e.g. slicing).
    ms_copy.dim = ms_copy.mesh(0).GetMaxSpatialDimension()
    return ms_copy
def scale(
    self,
    spatial: int | float | str = 1.0,
    time: int | float | str = 1.0,
) -> MeshSeries:
    """Rescale spatial coordinates and timevalues of this series in place.

    Handy for unit conversion, e.g. "m" to "km" or "s" to "a". A string
    argument names the target unit, a number is used directly as factor.
    No copy is made: the calling object itself is modified. To keep the
    original untouched use `ms_scaled = ms.copy().scale(...)`.

    :param spatial: Float factor or str for target unit.
    :param time:    Float factor or str for target unit.

    :returns:       The scaled MeshSeries (the calling object).
    """
    quantity = u_reg.Quantity

    # Resolve factor and resulting unit for the spatial axis ...
    if not isinstance(spatial, str):
        spatial_factor = spatial
        spatial_unit = self.spatial_unit / spatial
    else:
        spatial_factor = quantity(self.spatial_unit, spatial).magnitude
        spatial_unit = quantity(1, spatial)

    # ... and for the time axis.
    if not isinstance(time, str):
        time_factor = time
        time_unit = self.time_unit / time
    else:
        time_factor = quantity(self.time_unit, time).magnitude
        time_unit = quantity(1, time)

    # The cache is keyed by timevalue, so the keys have to follow the
    # time scaling.
    if time_factor != 1.0:
        scaled_cache = {}
        for old_timevalue in list(self._mesh_cache.keys()):
            new_timevalue = old_timevalue * time_factor
            scaled_cache[new_timevalue] = self._mesh_cache.pop(old_timevalue)
    else:
        scaled_cache = self._mesh_cache

    if spatial_factor != 1.0:
        # Cached meshes are scaled in place; meshes read later are scaled
        # through the stored mesh function.
        for cached_mesh in scaled_cache.values():
            cached_mesh.scale(spatial_factor, inplace=True)

        previous = self._mesh_func_opt
        if previous is None:

            def scaled_func(mesh):  # type: ignore[no-untyped-def]
                return mesh.scale(spatial_factor)

        else:

            def scaled_func(mesh):  # type: ignore[no-untyped-def]
                return previous(mesh).scale(spatial_factor)

        self._mesh_func_opt = scaled_func

    self._mesh_cache = scaled_cache
    self.spatial_unit = spatial_unit
    self.time_unit = time_unit
    self._time_factor = self._time_factor * time_factor
    return self
@classmethod
@typechecked
def difference(
    cls,
    ms_a: MeshSeries,
    ms_b: MeshSeries,
    variable: Variable | str | None = None,
) -> MeshSeries:
    """
    Compute the difference between two MeshSeries: ms = ms_a - ms_b.

    If the timevalues of both series are not equal, ``ms_b`` is resampled
    to the timevalues of ``ms_a`` and a ``RuntimeWarning`` is issued.

    :param ms_a:     The mesh from which data is to be subtracted.
    :param ms_b:     The mesh whose data is to be subtracted.
    :param variable: The variable of interest. If not given, all point and
                     cell_data will be processed raw.
    :returns:        New MeshSeries containing the difference of `variable`
                     or of all datasets between both MeshSeries.
    """
    from ogstools.mesh.differences import difference_pairwise

    same_timevalues = np.array_equal(ms_a.timevalues, ms_b.timevalues)
    if not same_timevalues:
        # Bring ms_b onto the time axis of ms_a before subtracting.
        ms_b_ok = ms_b.resample_temporal(ms_a.timevalues)
        msg = (
            "Two instances of MeshSeries have different time values. "
            "Direct difference cannot be computed. "
            "ms_b will be interpolated to match timesteps of ms_a."
        )
        warn(msg, RuntimeWarning, stacklevel=2)
    else:
        ms_b_ok = ms_b

    diff_meshes = difference_pairwise(ms_a, ms_b_ok, variable)
    # Units are passed as identical (data, output) pairs taken from ms_a.
    spatial_units = (ms_a.spatial_unit, ms_a.spatial_unit)
    time_units = (ms_a.time_unit, ms_a.time_unit)
    return MeshSeries.from_data(
        diff_meshes, ms_a.timevalues, spatial_units, time_units
    )
[docs] @staticmethod def compare( ms_a: MeshSeries, ms_b: MeshSeries, variable: Variable | str | None = None, point_data: bool = True, cell_data: bool = True, field_data: bool = True, fast_compare: bool = False, atol: float = 0.0, *, strict: bool = False, ) -> bool: """ Method to compare two ``ot.MeshSeries`` objects. Returns ``True`` if they match within the tolerances, otherwise ``False``. :param ms_a: The reference base MeshSeries for comparison. :param ms_b: The MeshSeries to compare against the reference. :param variable: The variable of interest. If not given, all point, cell and field data will be processed. :param point_data: Compare all point data if `variable` is None. :param cell_data: Compare all cell data if `variable` is None. :param field_data: Compare all field data if `variable` is None. :param fast_compare: If ``True``, mesh topology is only verified for the first timestep, otherwise, topology is checked at every timestep. :param atol: Absolute tolerance. :param strict: Raises an ``AssertionError``, if mismatch. """ from ogstools.mesh.differences import compare # Testing timevalues if ms_a.timevalues.shape != ms_b.timevalues.shape: return False if not np.allclose( ms_a.timevalues, ms_b.timevalues, atol=atol, rtol=0.0, equal_nan=True, ): if strict: err_msg = "timevalues differs between MeshSeries." raise AssertionError(err_msg) return False # Testing mesh-wise for i, (mesh_a, mesh_b) in enumerate(zip(ms_a, ms_b, strict=True)): check_topo = fast_compare or i == 0 if not compare( mesh_a, mesh_b, variable=variable, point_data=point_data, cell_data=cell_data, field_data=field_data, check_topology=check_topo, atol=atol, strict=strict, ): return False return True
@typechecked
def extract(
    self,
    index: slice | int | np.ndarray | list,
    preference: Literal["points", "cells"] = "points",
) -> MeshSeries:
    """
    Extract a subset of the domain by point or cell indices.

    :param index:       Indices of points or cells to extract.
    :param preference:  Selected entities.
    :returns:           A MeshSeries with the selected domain subset.
    """

    def by_points(mesh: pv.UnstructuredGrid) -> pv.UnstructuredGrid:
        # Indexing an arange resolves slices, ints, lists and arrays alike
        # to explicit point ids.
        point_ids = np.arange(mesh.n_points)[index]
        return mesh.extract_points(point_ids, include_cells=False)

    def by_cells(mesh: pv.UnstructuredGrid) -> pv.UnstructuredGrid:
        # Cell ids have to be resolved against the number of cells, not
        # the number of points.
        cell_ids = np.arange(mesh.n_cells)[index]
        return mesh.extract_cells(cell_ids)

    extractors: dict[
        str, Callable[[pv.UnstructuredGrid], pv.UnstructuredGrid]
    ] = {"points": by_points, "cells": by_cells}
    return self.transform(extractors[preference])
def _rename_vtufiles(self, new_pvd_fn: Path, fns: list[Path]) -> list:
    """
    Derive file paths for a series stored under a new PVD name.

    In the last component of every given path, the part of the current
    filepath's name before its first dot is replaced by the corresponding
    part of the new PVD filename.

    :param new_pvd_fn:  Path of the new PVD file.
    :param fns:         Paths of the files of the individual timesteps.
    :returns:           The renamed paths, directories unchanged.
    """
    assert self.filepath is not None
    old_stem = self.filepath.name.split(".")[0]
    new_stem = new_pvd_fn.name.split(".")[0]
    return [
        filename.parent / filename.name.replace(old_stem, new_stem)
        for filename in fns
    ]


def _vtu_name(self, source: Path, index: int) -> str:
    """
    Return the VTU filename used for the timestep at position `index`.

    Shared by ``_save_vtu`` and ``_save_pvd`` so that the written files and
    the entries of the PVD file cannot diverge.

    :param source:  Path of the file the timestep originates from.
    :param index:   Position of the timestep within the series.
    :returns:       The filename unchanged if it is a VTU file, or a
                    per-timestep VTU name if it is an XDMF file.
    :raises RuntimeError: If the file is neither VTU nor XDMF.
    """
    name = source.name
    if ".vtu" in name:
        return name
    if ".xdmf" in name:
        # One XDMF file holds all timesteps, so each one gets its own name.
        suffix = f"_ts_{self.timesteps[index]}_t_{self.timevalues[index]}.vtu"
        return name.replace(".xdmf", suffix)
    s = "File type not supported."
    raise RuntimeError(s)


def _save_vtu(
    self, new_pvd_fn: Path, fns: list[Path], ascii: bool = False
) -> None:
    """
    Write the mesh of every timestep as VTU file next to the PVD file.

    :param new_pvd_fn:  Path of the PVD file; its directory receives the
                        VTU files.
    :param fns:         Source paths of the timesteps, used for naming.
    :param ascii:       True to write ascii, False (default) for binary.
                        Applies to all timesteps, whatever their source.
    :raises RuntimeError: If a source file is neither VTU nor XDMF.
    """
    from ogstools.mesh import save

    for i, _ in enumerate(self.timesteps):
        target = Path(new_pvd_fn.parent, self._vtu_name(fns[i], i))
        save(self.mesh(i), target, binary=not ascii)


def _save_pvd(self, new_pvd_fn: Path, fns: list[Path]) -> None:
    """
    Write the PVD collection file referencing one VTU file per timevalue.

    :param new_pvd_fn:  Path of the PVD file to write.
    :param fns:         Source paths of the timesteps, used for naming.
    :raises RuntimeError: If a source file is neither VTU nor XDMF.
    """
    from lxml import etree as ET

    root = ET.Element("VTKFile")
    root.attrib["type"] = "Collection"
    root.attrib["version"] = "0.1"
    root.attrib["byte_order"] = "LittleEndian"
    root.attrib["compressor"] = "vtkZLibDataCompressor"
    collection = ET.SubElement(root, "Collection")
    for i, timevalue in enumerate(self.timevalues):
        dataset = ET.SubElement(collection, "DataSet")
        dataset.attrib["timestep"] = str(timevalue)
        dataset.attrib["group"] = ""
        dataset.attrib["part"] = "0"
        # Only the bare filename is stored in the PVD entry.
        dataset.attrib["file"] = self._vtu_name(fns[i], i)
    tree = ET.ElementTree(root)
    tree.write(
        new_pvd_fn,
        encoding="ISO-8859-1",
        xml_declaration=True,
        pretty_print=True,
    )


def _check_path(self, filename: Path | None) -> Path:
    """
    Ensure that `filename` is an actual path.

    :param filename:  Candidate path.
    :returns:         `filename` itself.
    :raises RuntimeError: If `filename` is not a ``Path`` (e.g. ``None``).
    """
    if not isinstance(filename, Path):
        s = "filename is empty"
        raise RuntimeError(s)
    return filename
def save(
    self,
    target: Path | str | None = None,
    deep: bool = True,
    ascii: bool = False,
    overwrite: bool | None = None,
    dry_run: bool = False,
    archive: bool = False,
    id: str | None = None,
) -> list[Path]:
    """
    Save the mesh series.

    Runs three steps in order: ``_pre_save``, ``_save_impl`` and
    ``_post_save``.

    :param target:    Forwarded to ``_pre_save``.
    :param deep:      Forwarded to ``_save_impl``.
    :param ascii:     Forwarded to ``_save_impl``.
    :param overwrite: Forwarded to ``_pre_save``.
    :param dry_run:   Forwarded to all three steps.
    :param archive:   Forwarded to ``_post_save``.
    :param id:        Forwarded to ``_pre_save``.
    :returns:         The list of paths returned by ``_save_impl``.
    """
    pre_save_state = self._pre_save(target, overwrite, dry_run, id=id)
    saved_paths = self._save_impl(deep=deep, ascii=ascii, dry_run=dry_run)
    self._post_save(pre_save_state, archive, dry_run)
    return saved_paths
def _save_impl(
    self,
    deep: bool = True,
    ascii: bool = False,
    dry_run: bool = False,
) -> list[Path]:
    """
    Save mesh series to disk at ``self.next_target``.

    The target's extension specifies the file type. Currently only PVD is
    supported.

    :param deep:    Specifies whether the VTU files of the timesteps are
                    written (renamed after the target) in addition to the
                    PVD file.
    :param ascii:   Specifies if ascii or binary format should be used,
                    defaults to binary (False) - True for ascii.
    :param dry_run: If True, nothing is written and no directory is
                    created; only the list of paths is returned.
    :returns:       The PVD file followed by the files of the timesteps.
    :raises RuntimeError: If the target is not a PVD file.
    """
    fn = self.next_target
    # Reject unsupported targets before touching the filesystem.
    if ".pvd" not in fn.name:
        s = "Currently the save method is implemented for PVD/VTU only."
        raise RuntimeError(s)

    fns = [
        self._check_path(self.mesh(t).filepath)
        for t in range(len(self.timevalues))
    ]
    if deep is True:
        fns = self._rename_vtufiles(fn, fns)

    if dry_run:
        # For dry_run, return paths where files would be written.
        return [fn] + [fn.parent / f.name for f in fns]

    fn.parent.mkdir(parents=True, exist_ok=True)
    if deep is True:
        self._save_vtu(fn, fns, ascii=ascii)
        # Update fns to actual written paths.
        fns = [fn.parent / f.name for f in fns]
    self._save_pvd(fn, fns)

    # Return both the PVD file and all VTU files.
    return [fn] + fns
@deprecated(
    " Please use `del meshseries.field_data[key]` or, if you want to keep "
    "the data in the last timestep: `del meshseries[:-1].field_data[key]`. "
)
def remove_array(
    self, name: str, data_type: str = "field", skip_last: bool = False
) -> None:
    """
    Remove an array from the meshes of all time slices of the series.

    :param name:      Array name
    :param data_type: Data type of the array. Could be either
                      field, cell or point
    :param skip_last: Skips the last time slice
                      (e.g. for restart purposes).
    :raises RuntimeError: If `data_type` is unknown and at least one time
                          slice is processed.
    """
    # Maps the data type onto the mesh attribute holding those arrays.
    container_by_type = {
        "field": "field_data",
        "cell": "cell_data",
        "point": "point_data",
    }
    for position, mesh in enumerate(self):
        if skip_last is not False and position >= len(self) - 1:
            continue
        container = container_by_type.get(data_type)
        if container is None:
            msg = "array type unknown"
            raise RuntimeError(msg)
        getattr(mesh, container).remove(name)
def _propagate_target(self) -> None:
    """Intentionally do nothing."""
    return None


def __repr__(self) -> str:
    """
    Return a multi-line description of the series.

    The first line is an expression describing how the object can be
    constructed: via ``from_id`` if an id is set and the target was user
    specified, otherwise via filepath and units. It is followed by the
    representation of the base class, the data type, the number of
    timesteps, the dimension (``None`` if not set) and the number of
    cached meshes.
    """
    cls_name = self.__class__.__name__
    if not (self._id and self.user_specified_target):
        filepath = self.active_target or self.next_target
        filepath_str = None if filepath is None else str(filepath)
        spatial = str(self.spatial_unit.units)
        temporal = str(self.time_unit.units)
        construct = (
            f"{cls_name}(filepath={filepath_str!r}, "
            f"spatial_unit={spatial!r}, time_unit={temporal!r})"
        )
    else:
        construct = f'{cls_name}.from_id("{self._id}")'

    lines = [
        construct,
        super().__repr__(),
        f"data_type={self._data_type!r}",
        f"num_timesteps={len(self.timevalues)}",
        f"dim={getattr(self, 'dim', None)!r}",
        f"cached_meshes={len(self._mesh_cache)}",
    ]
    # Every entry, including the last one, is terminated by a newline.
    return "\n".join(lines) + "\n"