# Copyright (c) 2012-2024, OpenGeoSys Community (http://www.opengeosys.org)
# Distributed under a Modified BSD License.
# See accompanying file LICENSE.txt or
# http://www.opengeosys.org/project/license
#
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar, Literal, TypeAlias, TypeVar
import numpy as np
from pint.facets.plain import PlainQuantity
from ogstools.variables.variable import Scalar, Variable
from .tensor_math import _split_quantity, _to_quantity
# Values accepted by the vector helpers: either a pint quantity or a plain
# numpy array (see `_split_quantity` / `_to_quantity` for how units are
# separated and re-attached).
ValType: TypeAlias = PlainQuantity | np.ndarray
# Generic type variable used in the component-selector annotations below.
T = TypeVar("T")
def vector_norm(values: ValType) -> ValType:
    """Compute the (Euclidean) norm of the vector(s) in `values`.

    :param values:  Vector data as a plain array or a quantity; the vector
                    components are expected along the last axis.
    :returns:       The norm of the vector. The unit split off by
                    `_split_quantity` is handed back to `_to_quantity`.
    """
    array, unit = _split_quantity(values)
    return _to_quantity(np.linalg.norm(array, axis=-1), unit)
class Vector(Variable):
    """Represent a vector variable.

    Vector variables should contain either 2 (2D) or 3 (3D) components.
    Vector components can be accessed with brackets, either by position or
    by axis name, e.g. displacement[0] or displacement["x"].
    """

    def __getitem__(self, index: int | Literal["x", "y", "z"]) -> Scalar:
        """
        Get a scalar variable as a specific component of the vector variable.

        :param index:   The index of the component, either an integer
                        position or one of the axis names "x", "y", "z".
        :returns:       A scalar variable as a vector component.
        :raises ValueError: If `index` is a string other than "x", "y", "z".
        :raises TypeError:  If `index` is neither an integer nor a string.
        """
        axis_names = ("x", "y", "z")
        if isinstance(index, (int, np.integer)):
            int_index = index
        elif isinstance(index, str):
            # Compare whole names: a substring search in "xyz" would wrongly
            # accept e.g. "" or "yz" and map them to components 0 or 1.
            if index not in axis_names:
                msg = (
                    f"Unknown vector component {index!r}, "
                    f"expected one of {axis_names}."
                )
                raise ValueError(msg)
            int_index = axis_names.index(index)
        else:
            msg = (
                "Vector index must be an int or a str, "
                f"not {type(index).__name__}."
            )
            raise TypeError(msg)
        return Scalar.from_variable(
            self,
            output_name=self.output_name + f"_{index}",
            symbol=f"{{{self.symbol}}}_{index}",
            func=lambda x: self.func(x)[..., int_index],
            bilinear_cmap=True,
        )

    @property
    def magnitude(self) -> Scalar:
        """
        :returns: A scalar variable as the magnitude (norm over the last
                  axis, see `vector_norm`) of the vector.
        """
        return Scalar.from_variable(
            self,
            output_name=self.output_name + "_magnitude",
            symbol=f"||{{{self.symbol}}}||",
            func=lambda x: vector_norm(self.func(x)),
        )
@dataclass
class BHE_Vector(Variable):
    """Represent a BHE vector variable.

    The available components depend on the BHE type, which is identified by
    the length of the last axis of the data:

    ========= =====================================================
    BHE type  available Vector components
    ========= =====================================================
    1U        in, out, grout1, grout2
    2U        in1, in2, out1, out2, grout1, grout2, grout3, grout4
    1P        in, grout
    CXC       in, out, grout
    CXA       in, out, grout
    ========= =====================================================

    Components can be accessed by position or by name, e.g.
    ``temperature_BHE["in"]``, optionally together with the BHE number:
    ``temperature_BHE[1, "in"]``.
    """

    BHE_COMPONENTS: ClassVar[dict[str, list[str]]] = {
        "1U": ["in", "out", "grout1", "grout2"],
        "2U": ["in1", "in2", "out1", "out2", "grout1", "grout2", "grout3", "grout4"],
        "CXC": ["in", "out", "grout"],
        "CXA": ["in", "out", "grout"],
        "1P": ["in", "grout"],
    }  # fmt: skip

    def __getitem__(self, index: int | str | tuple) -> Scalar:
        """
        Get a scalar variable as a specific component of the vector variable.

        :param index:   The component as an integer position, a component
                        name (see the class docstring), a list of either, or
                        a tuple (BHE number, component). The BHE number is
                        appended to the data name.
        :returns:       A scalar variable as a vector component.
        :raises IndexError: If a tuple index does not have exactly two entries.

        Invalid component names or index types raise a ValueError when the
        returned variable is evaluated, since the BHE type is only known
        from the data.
        """
        if isinstance(index, tuple):
            if len(index) != 2:
                msg = "Expected exactly two indices: (BHE number, component)"
                raise IndexError(msg)
            bhe_number, comp_index = index
            suffix = f"{bhe_number}"
        else:
            suffix = ""
            comp_index = index

        def positions(components: list[str]) -> int | list:
            """Translate `comp_index` into positions for one BHE type."""

            def position(name: str) -> int:
                if name not in components:
                    msg = f"Unknown str index {name}"
                    raise ValueError(msg)
                return components.index(name)

            integer_types = (int, np.integer)
            if isinstance(comp_index, list):
                if all(isinstance(i, integer_types) for i in comp_index):
                    return comp_index
                if all(isinstance(i, str) for i in comp_index):
                    return [position(name) for name in comp_index]
                msg = f"Unknown str index list {comp_index}"
                raise ValueError(msg)
            if isinstance(comp_index, str):
                return position(comp_index)
            if isinstance(comp_index, integer_types):
                return comp_index
            msg = (
                "Unsupported BHE vector index type "
                f"{type(comp_index).__name__}"
            )
            raise ValueError(msg)

        def component_selector(x: T) -> T:
            data: np.ndarray = self.func(x)
            len_data = data.shape[-1]
            # The BHE type is identified by its number of components
            # (CXC and CXA share the same component list).
            components = next(
                (
                    comps
                    for comps in BHE_Vector.BHE_COMPONENTS.values()
                    if len(comps) == len_data
                ),
                None,
            )
            if components is None:
                msg = f"Unknown BHE type with BHE vector length {len_data}"
                raise ValueError(msg)
            return data[..., positions(components)]

        return Scalar.from_variable(
            self,
            data_name=self.data_name + suffix,
            output_name=self.output_name + suffix + f"_{comp_index}",
            symbol=f"{{{self.symbol}}}_{comp_index}",
            func=component_selector,
        )

    @property
    def magnitude(self) -> Scalar:
        """Refuse to build a magnitude variable for a BHE vector.

        Taking the magnitude of a BHE vector is most likely unintended; the
        individual components should be accessed via indexing instead.

        :raises TypeError: Always.
        """
        msg = (
            "You tried to get the magnitude of a BHE temperature vector, "
            "which most likely is unintended. Please access the different "
            "components via indexing: e.g. "
            'ot.variables.temperature_BHE["in"] or '
            'ot.variables.temperature_BHE[1, "in"].\n'
        ) + str(BHE_Vector.__doc__)
        raise TypeError(msg)
@dataclass
class VectorList(Variable):
    """Represent a list of vector variables.

    Indexing with an integer yields one of the contained vectors as a
    :class:`Vector` variable.
    """

    def __getitem__(self, index: int) -> Vector:
        """
        Select a single vector out of the vector list.

        :param index:   Position of the vector along the last axis of the
                        data.
        :returns:       A vector variable as a component of the vectorlist
                        variable.
        """

        def take_vector(x: T) -> T:
            return np.take(self.func(x), index, axis=-1)

        vector_symbol = f"{{{self.symbol}}}_{index}"
        vector_output_name = self.output_name + f"_{index}"
        return Vector.from_variable(
            self,
            output_name=vector_output_name,
            symbol=vector_symbol,
            func=take_vector,
        )