Source code for ogstools.mesh.utils

# SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
# SPDX-License-Identifier: BSD-3-Clause

import shutil
import subprocess
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import pyvista as pv

from ogstools._find_ogs import cli

from .file_io import save


[docs] def node_reordering( mesh: pv.UnstructuredGrid, method: int = 1 ) -> pv.UnstructuredGrid: """Reorders mesh nodes to make a mesh compatible with OGS6. :param mesh: mesh whose nodes are to be reordered. :param method: 0: Reversing order of nodes for all elements.\n 1: Reversing order of nodes unless it's perceived correct by OGS6 standards. This is the default selection.\n 2: Fixing node ordering issues between VTK and OGS6 (only applies to prism-elements).\n 3: Re-ordering of mesh node vector such that all base nodes are sorted before all nonlinear nodes. """ tmp_file = Path(mkdtemp(prefix="node_reordering")) / "mesh.vtu" save(mesh, tmp_file) cli().NodeReordering(i=str(tmp_file), o=str(tmp_file), m=method) return pv.XMLUnstructuredGridReader(tmp_file).read()
[docs] def validate( mesh: pv.UnstructuredGrid | Path | str, strict: bool = False ) -> bool: """Check conformity of mesh with OGS. :param mesh: pyvista mesh or path to the mesh file. :param strict: If True, raise a UserWarning if checkMesh returns an error. """ if isinstance(mesh, pv.DataSet): mesh_file = str(Path(mkdtemp(prefix="validate")) / "mesh.vtu") save(mesh, mesh_file) else: mesh_file = str(mesh) # ToDo Either checkMesh must return status of mesh (not of itself) OR # cli() can handle stdout if shutil.which("checkMesh") is None: return True ret = subprocess.run( ["checkMesh", mesh_file, "-v"], stdout=subprocess.PIPE, check=False ) msg = ret.stdout.decode("utf-8") is_valid = "No errors found." in msg if not is_valid: print(msg) if strict and not is_valid: msg = "Provided mesh is not compliant with OGS." raise UserWarning(msg) return is_valid
[docs] def check_datatypes( mesh: pv.UnstructuredGrid, strict: bool = False, meshname: str = "" ) -> bool: mat_ids = mesh.cell_data.get("MaterialIDs", np.int32(0)) elem_ids = mesh.cell_data.get("bulk_element_ids", np.uint64(0)) node_ids = mesh.point_data.get("bulk_node_ids", np.uint64(0)) type_map = { mesh.points.dtype: ("Point coordinates", np.double), mat_ids.dtype: ("'MaterialIDs'", np.int32), elem_ids.dtype: ("'bulk_element_ids'", np.uint64), node_ids.dtype: ("'bulk_node_ids'", np.uint64), } for datatype, (name, ref_type) in type_map.items(): if datatype != ref_type: msg = ( f"{name} datatype needs to be {ref_type} for OGS, " f"but instead it is {datatype}. " ) if meshname != "": msg += f"Error raised by mesh with {meshname=}" if strict: raise TypeError(msg) return False return True
[docs] def reindex_material_ids(mesh: pv.UnstructuredGrid) -> None: unique_mat_ids = np.unique(mesh["MaterialIDs"]) id_map = dict( zip(*np.unique(unique_mat_ids, return_inverse=True), strict=True) ) mesh["MaterialIDs"] = np.int32(list(map(id_map.get, mesh["MaterialIDs"]))) return
[docs] def remove_data(mesh: pv.UnstructuredGrid, datanames: list[str]) -> None: for dataname in datanames: mesh.point_data.pop(dataname, None) mesh.cell_data.pop(dataname, None) mesh.field_data.pop(dataname, None)
[docs] def axis_ids_2D(mesh: pv.DataSet) -> tuple[int, int]: "Return the two axes, in which the mesh (predominantly) lives in." from ogstools.plot.utils import get_projection tri = pv.Triangle( [mesh.points[0], mesh.points[mesh.n_points // 2], mesh.points[-1]] ) axis_1, axis_2, _, _ = get_projection(tri) len1, len2 = (len(np.unique(mesh.points[:, ax])) for ax in [axis_1, axis_2]) if len1 == len2: if axis_2 > axis_1: return axis_1, axis_2 return axis_2, axis_1 if len1 <= len2: return axis_1, axis_2 return axis_2, axis_1
[docs] def reshape_obs_points( points: np.ndarray | list, mesh: pv.UnstructuredGrid | None = None ) -> np.ndarray: points = np.asarray(points) pts = points.reshape((-1, points.shape[-1])) # Add missing columns to comply with pyvista expectations if pts.shape[1] == 3: pts_pyvista = pts elif mesh is None: pts_pyvista = np.hstack( (pts, np.zeros((pts.shape[0], 3 - pts.shape[1]))) ) else: # Detect and handle flat dimensions geom = mesh.points flat_axis = np.argwhere(np.all(np.isclose(geom, geom[0]), axis=0)) flat_axis = flat_axis.flatten() if pts.shape[1] + len(flat_axis) < 3: err_msg = ( "Number of flat axis and number of coordinates" " in provided points doesn't add up to 3." " Please ensure that the provided points match" " the plane of the mesh." ) raise RuntimeError(err_msg) pts_pyvista = np.empty((pts.shape[0], 3)) pts_id = 0 for col_id in range(3): if col_id in flat_axis: pts_pyvista[:, col_id] = ( np.ones((pts.shape[0],)) * geom[0, col_id] ) else: pts_pyvista[:, col_id] = pts[:, pts_id] pts_id = pts_id + 1 return pts_pyvista