Source code for ogstools.variables.vector
# Copyright (c) 2012-2024, OpenGeoSys Community (http://www.opengeosys.org)
# Distributed under a Modified BSD License.
# See accompanying file LICENSE.txt or
# http://www.opengeosys.org/project/license
#
from dataclasses import dataclass
from typing import Literal, TypeAlias
import numpy as np
from pint.facets.plain import PlainQuantity
from ogstools.variables.variable import Scalar, Variable
from .tensor_math import _split_quantity, _to_quantity
ValType: TypeAlias = PlainQuantity | np.ndarray
[docs]
def vector_norm(values: ValType) -> ValType:
    """Compute the norm of vectors stored along the last axis.

    :param values: Vector data, either a plain array or a pint quantity.
                   The components are expected on the last axis.

    :returns: The norm of the vector, with the last axis reduced. The unit
              obtained from splitting the input is handed back to
              ``_to_quantity`` together with the result.
    """
    # Separate numbers from the unit so numpy operates on the raw values.
    magnitudes, unit = _split_quantity(values)
    norms = np.linalg.norm(magnitudes, axis=-1)
    return _to_quantity(norms, unit)
[docs]
@dataclass
class Vector(Variable):
    """Represent a vector variable.

    Vector variables should contain either 2 (2D) or 3 (3D) components.
    Vector components can be accessed with brackets, e.g. ``displacement[0]``
    or ``displacement["x"]``.
    """

    def __getitem__(self, index: int | Literal["x", "y", "z"]) -> Scalar:
        """
        Get a scalar variable as a specific component of the vector variable.

        :param index: The index of the component: either an integer (Python
                      or numpy integer) or one of the axis letters "x", "y",
                      "z", which map to 0, 1 and 2 respectively.

        :returns: A scalar variable as a vector component.

        :raises ValueError: If a string index is not exactly "x", "y" or "z".
        :raises TypeError: If the index is neither an integer nor a string.
        """
        if isinstance(index, (int, np.integer)):
            # numpy integers are accepted as well, since they index arrays
            # exactly like Python ints.
            int_index = index
        elif isinstance(index, str):
            # str.index performs a substring search, so "" or "xy" would
            # silently resolve to component 0; require an exact axis letter.
            if index not in ("x", "y", "z"):
                msg = (
                    f"Invalid vector component {index!r}: "
                    "expected 'x', 'y', 'z' or an integer index."
                )
                raise ValueError(msg)
            int_index = "xyz".index(index)
        else:
            msg = (
                "Vector component index must be an integer or one of "
                f"'x', 'y', 'z', not {type(index).__name__}."
            )
            raise TypeError(msg)
        # The name/symbol suffix keeps the index as given by the caller
        # (e.g. "_x" or "_0"), while the data lookup uses the integer index.
        # NOTE(review): bilinear_cmap is set for components but not for the
        # magnitude - presumably because components are signed; confirm.
        return Scalar.from_variable(
            self,
            output_name=self.output_name + f"_{index}",
            symbol=f"{{{self.symbol}}}_{index}",
            func=lambda x: self.func(x)[..., int_index],
            bilinear_cmap=True,
        )

    @property
    def magnitude(self) -> Scalar:
        """Derive the magnitude of this vector as a scalar variable.

        :returns: A scalar variable as the magnitude of the vector; its
                  values are the norm over the last axis of this variable's
                  transformed data.
        """
        return Scalar.from_variable(
            self,
            output_name=self.output_name + "_magnitude",
            symbol=f"||{{{self.symbol}}}||",
            func=lambda x: vector_norm(self.func(x)),
        )
[docs]
@dataclass
class VectorList(Variable):
    """Represent a list of vector variables."""

    def __getitem__(self, index: int) -> Vector:
        """Pick a single vector out of the vector list.

        :param index: Position of the requested vector along the last axis
                      of the transformed data.

        :returns: A vector variable as a component of the vectorlist variable.
        """
        suffix = f"_{index}"
        component_symbol = f"{{{self.symbol}}}_{index}"

        def _select(x):
            # Transform first, then take the requested entry of the last axis.
            return np.take(self.func(x), index, axis=-1)

        return Vector.from_variable(
            self,
            output_name=self.output_name + suffix,
            symbol=component_symbol,
            func=_select,
        )