Source code for ogstools.variables.vector

# SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar, Literal, TypeAlias, TypeVar

import numpy as np
from pint.facets.plain import PlainQuantity

from ogstools.variables.variable import Scalar, Variable

from .tensor_math import _split_quantity, _to_quantity

ValType: TypeAlias = PlainQuantity | np.ndarray

T = TypeVar("T")


def vector_norm(values: ValType) -> ValType:
    """
    Compute the norm of vectors stored along the last axis.

    :param values: Vector data, either a plain array or a pint quantity.
                   The vector components are expected on the last axis.

    :returns: The norm of the vector, carrying the unit which
              `_split_quantity` extracted from the input.
    """
    # Strip the unit first so numpy operates on the bare magnitudes only.
    magnitudes, unit = _split_quantity(values)
    return _to_quantity(np.linalg.norm(magnitudes, axis=-1), unit)
class Vector(Variable):
    """Represent a vector variable.

    Vector variables should contain either 2 (2D) or 3 (3D) components.
    Vector components can be accessed with brackets, e.g. displacement[0]
    or displacement["x"].
    """

    def __getitem__(self, index: int | Literal["x", "y", "z"]) -> Scalar:
        """
        Get a scalar variable as a specific component of the vector variable.

        :param index: The index of the component, either as an integer or
                      as one of the axis labels "x", "y" or "z".

        :returns: A scalar variable as a vector component.

        :raises ValueError: If a string index is not "x", "y" or "z".
        :raises TypeError:  If the index is neither an int nor a str.
        """
        if isinstance(index, int):
            int_index = index
        elif isinstance(index, str):
            # Require an exact label: "xyz".index() on its own performs a
            # substring search, so "" or "xy" would silently map to
            # component 0 and "yz" to component 1.
            if index not in ("x", "y", "z"):
                msg = (
                    f"Unknown vector component {index!r}, "
                    "expected 'x', 'y' or 'z'."
                )
                raise ValueError(msg)
            int_index = "xyz".index(index)
        else:
            msg = (
                "Vector index must be an int or one of 'x', 'y', 'z', "
                f"got {type(index).__name__}."
            )
            raise TypeError(msg)

        return Scalar.from_variable(
            self,
            output_name=self.output_name + f"_{index}",
            symbol=f"{{{self.symbol}}}_{index}",
            # The component is selected on the last axis of the data.
            func=lambda x: self.func(x)[..., int_index],
            bilinear_cmap=True,
        )

    @property
    def magnitude(self) -> Scalar:
        ":returns: A scalar variable as the magnitude of the vector."
        return Scalar.from_variable(
            self,
            output_name=self.output_name + "_magnitude",
            symbol=f"||{{{self.symbol}}}||",
            func=lambda x: vector_norm(self.func(x)),
        )
@dataclass
class BHE_Vector(Variable):
    """
    Represent a BHE vector variable with named components.

    The BHE type is not stored explicitly: it is deduced from the number of
    entries on the last axis of the data, which then determines the
    component names available for indexing.

    ========= ===========================
    BHE type  available Vector components
    ========= ===========================
    1U        in, out, grout1, grout2
    2U        in1, in2, out1, out2, grout1, grout2, grout3, grout4
    1P        in, grout
    CXC       in, out, grout
    CXA       in, out, grout
    ========= ===========================
    """

    # Component names per BHE type, in the order they appear on the last
    # axis of the data.
    BHE_COMPONENTS: ClassVar[dict[str, list[str]]] = {
        "1U": ["in", "out", "grout1", "grout2"],
        "2U": ["in1", "in2", "out1", "out2",
               "grout1", "grout2", "grout3", "grout4"],
        "CXC": ["in", "out", "grout"],
        "CXA": ["in", "out", "grout"],
        "1P": ["in", "grout"],
    }  # fmt: skip

    def __getitem__(self, index: int | str | tuple) -> Scalar:
        """
        Get a scalar variable as a specific component of the vector variable.

        :param index: The index of the component. Either the component
                      itself (an int, a component name or a list of ints /
                      component names) or a tuple (BHE number, component).
                      In the tuple form the BHE number is appended to the
                      data name and output name of the result.

        :returns: A scalar variable as a vector component.

        :raises IndexError: If a tuple index does not have two entries.

        The returned variable's function raises a ValueError on evaluation
        if the data length matches no known BHE type or if the component
        index cannot be resolved.
        """
        suffix = ""
        comp_index: int | str | list[int] | list[str]
        if isinstance(index, tuple):
            if len(index) > 2:
                msg = "Expected at most two indices: (BHE number, component)"
                raise IndexError(msg)
            if len(index) < 2:
                # Fail with a clear message instead of a bare
                # "tuple index out of range".
                msg = "Expected two indices: (BHE number, component)"
                raise IndexError(msg)
            bhe_number, comp_index = index
            suffix = f"{bhe_number}"
        else:
            comp_index = index

        def component_selector(x: T) -> T:
            data: np.ndarray = self.func(x)
            len_data = data.shape[-1]
            # The number of entries on the last axis identifies the BHE
            # type; types of equal length share the same component names.
            components = next(
                (
                    names
                    for names in BHE_Vector.BHE_COMPONENTS.values()
                    if len(names) == len_data
                ),
                None,
            )
            if components is None:
                msg = f"Unknown BHE type with BHE vector length {len_data}"
                raise ValueError(msg)

            if isinstance(comp_index, list):
                if all(isinstance(i, int) for i in comp_index):
                    return data[..., comp_index]
                # Filtering by isinstance also narrows the element type
                # for the type checker.
                str_indices = [i for i in comp_index if isinstance(i, str)]
                if len(str_indices) == len(comp_index):
                    return data[
                        ..., [components.index(name) for name in str_indices]
                    ]
                msg = f"Unknown str index list {comp_index}"
                raise ValueError(msg)
            if isinstance(comp_index, str):
                if comp_index in components:
                    return data[..., components.index(comp_index)]
                msg = f"Unknown str index {comp_index}"
                raise ValueError(msg)
            if isinstance(comp_index, int):
                return data[..., comp_index]
            # The BHE type is known here, so report the real problem.
            msg = (
                f"Unsupported index type {type(comp_index).__name__}, "
                "expected int, str or a list of int or str"
            )
            raise ValueError(msg)

        return Scalar.from_variable(
            self,
            data_name=self.data_name + suffix,
            output_name=self.output_name + suffix + f"_{comp_index}",
            symbol=f"{{{self.symbol}}}_{comp_index}",
            func=component_selector,
        )

    @property
    def magnitude(self) -> Scalar:
        """
        Always raise, as a magnitude is not provided for BHE vectors.

        :raises TypeError: On every access; the message points to component
                           indexing and lists the available components.
        """
        # The example uses "in", a component name present for every BHE
        # type in BHE_COMPONENTS.
        msg = (
            "You tried to get the magnitude of a BHE temperature vector, "
            "which most likely is unintended. Please access the different "
            "components via indexing: "
            'e.g. ot.variables.temperature_BHE["in"].\n'
            + str(BHE_Vector.__doc__)
        )
        raise TypeError(msg)
@dataclass
class VectorList(Variable):
    """Represent a list of vector variables."""

    def __getitem__(self, index: int) -> Vector:
        """
        Get one entry of the vector list as its own vector variable.

        :param index: Position of the requested vector on the last axis
                      of the data.

        :returns: A vector variable as a component of the vectorlist variable.
        """
        # Output name and symbol are tagged with the same "_<index>" suffix.
        tag = f"_{index}"
        return Vector.from_variable(
            self,
            output_name=self.output_name + tag,
            symbol=f"{{{self.symbol}}}{tag}",
            func=lambda x: np.take(self.func(x), index, axis=-1),
        )