Source code for mlmc.quantity.quantity_types

import abc
import copy
import numpy as np
from scipy import interpolate
from typing import List, Tuple
import mlmc.quantity.quantity


class QType(metaclass=abc.ABCMeta):
    """
    Common base of all quantity types.

    :param qtype: inner/contained QType or Python type
    """

    def __init__(self, qtype):
        self._qtype = qtype

    def size(self) -> int:
        """
        Number of flattened units occupied by this type.

        The base class provides no implementation.

        :return: int
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def base_qtype(self):
        """
        Resolve the base type by delegating to the contained qtype.

        :return: QType
        """
        inner = self._qtype
        return inner.base_qtype()

    def replace_scalar(self, substitute_qtype):
        """
        Build a copy of this type in which ScalarType is replaced.

        The replacement is delegated to the contained qtype; the result is
        stored in a deep copy of ``self``, so ``self`` is left untouched.

        :param substitute_qtype: QType that replaces ScalarType
        :return: QType (new instance with scalar replaced)
        """
        replaced_inner = self._qtype.replace_scalar(substitute_qtype)
        clone = copy.deepcopy(self)
        clone._qtype = replaced_inner
        return clone

    @staticmethod
    def keep_dims(chunk: np.ndarray) -> np.ndarray:
        """
        Bring a chunk to the three-dimensional layout [M, chunk size, 2].

        A two-dimensional input (chunk size, 2) gets a new leading axis.
        An input with more than two dimensions has all but its last two
        dimensions flattened into a single leading one.

        :param chunk: numpy array
        :return: numpy array with shape [M, chunk size, 2]
        :raises ValueError: if chunk has fewer than two dimensions
        """
        n_dims = len(chunk.shape)
        if n_dims < 2:
            raise ValueError("Chunk shape not supported: need ndim >= 2")
        if n_dims == 2:
            # Only the leading M axis is missing.
            return chunk[np.newaxis, ...]
        # Collapse every leading axis into one, keep the last two as they are.
        *leading, chunk_size, n_cols = chunk.shape
        return chunk.reshape((int(np.prod(leading)), chunk_size, n_cols))

    def _make_getitem_op(self, chunk: np.ndarray, key):
        """
        Index the chunk by key and restore the three-dimensional layout.

        :param chunk: level chunk, numpy array with shape [M, chunk size, 2]
        :param key: index/slice used by parent QType
        :return: numpy array with shape [M', chunk size', 2]
        """
        selected = chunk[key]
        return QType.keep_dims(selected)

    def reshape(self, data: np.ndarray) -> np.ndarray:
        """
        Identity reshape; the data is returned unchanged.

        :param data: numpy array
        :return: the same numpy array
        """
        return data
class ScalarType(QType):
    """
    Scalar quantity type (leaf type).
    """

    def __init__(self, qtype=float):
        """
        :param qtype: Python type or nested type used as underlying scalar type
        """
        self._qtype = qtype

    def base_qtype(self):
        """
        Return the base scalar type.

        :return: the base type of the wrapped BoolType when ``_qtype`` is a
            BoolType instance, otherwise ``self``
        """
        if isinstance(self._qtype, BoolType):
            return self._qtype.base_qtype()
        return self

    def size(self) -> int:
        """
        Size of the scalar in flattened units.

        :return: result of ``_qtype.size()`` when the underlying type has a
            callable ``size``, otherwise 1
        """
        # A bare hasattr() check is not enough: NumPy scalar types such as
        # np.float64 expose ``size`` as a non-callable attribute, and calling
        # it would raise TypeError. Only delegate when it can be called.
        inner_size = getattr(self._qtype, "size", None)
        if callable(inner_size):
            return inner_size()
        return 1

    def replace_scalar(self, substitute_qtype):
        """
        Replace this ScalarType with the substitute type.

        :param substitute_qtype: QType that replaces ScalarType
        :return: substitute_qtype
        """
        return substitute_qtype
class BoolType(ScalarType):
    """
    Boolean scalar type; all behavior is inherited from ScalarType.
    """
class ArrayType(QType):
    """
    Array quantity type.

    :param shape: int (builtin or NumPy integer) or tuple describing array shape
    :param qtype: contained QType for array elements
    """

    def __init__(self, shape, qtype: QType):
        # NumPy integers are not instances of builtin int; without this
        # normalization they would be stored bare and break the
        # ``(*self._shape, ...)`` unpacking used by the methods below.
        if isinstance(shape, np.integer):
            shape = int(shape)
        if isinstance(shape, int):
            shape = (shape,)
        self._shape = shape
        self._qtype = qtype

    def size(self) -> int:
        """
        :return: total flattened size (product of shape * inner qtype size)
        """
        return int(np.prod(self._shape)) * int(self._qtype.size())

    def get_key(self, key):
        """
        ArrayType indexing.

        :param key: int, tuple of ints or slice objects
        :return: Tuple (QuantityType, offset) where offset is always 0 here
        """
        # Let NumPy compute the resulting shape by indexing a dummy array
        # of the target shape.
        new_shape = np.empty(self._shape)[key].shape

        # A selection holding exactly one item along a single axis is
        # treated as a scalar selection.
        if len(new_shape) == 1 and new_shape[0] == 1:
            new_shape = ()

        if len(new_shape) > 0:
            # Result is again an array of the same element type.
            q_type = ArrayType(new_shape, qtype=self._qtype)
        else:
            # Result is a single array item.
            q_type = self._qtype
        return q_type, 0

    def _make_getitem_op(self, chunk: np.ndarray, key):
        """
        Slice operation for ArrayType.

        The chunk is first reshaped to ``(*shape, chunk.shape[-2],
        chunk.shape[-1])`` so that key addresses the array dimensions, then
        the selection is brought back to three dimensions.

        :param chunk: numpy array [M, chunk size, 2]
        :param key: slice or index to apply on the array-shaped leading dims
        :return: numpy array with dims restored via QType.keep_dims
        """
        assert self._shape is not None
        chunk = chunk.reshape((*self._shape, chunk.shape[-2], chunk.shape[-1]))
        return QType.keep_dims(chunk[key])

    def reshape(self, data: np.ndarray) -> np.ndarray:
        """
        Reshape flattened data to array shape.

        :param data: numpy array
        :return: data reshaped to the array shape; for non-scalar element
            types one trailing dimension holds the remaining items
        """
        if isinstance(self._qtype, ScalarType):
            return data.reshape(self._shape)
        # Assume the trailing dimension belongs to the inner types.
        total = np.prod(data.shape)
        leading = int(np.prod(self._shape))
        return data.reshape((*self._shape, int(total // leading)))
class TimeSeriesType(QType):
    """
    Quantity type holding one inner qtype per time point.

    :param times: iterable of time points
    :param qtype: QType for each time slice
    """

    def __init__(self, times, qtype):
        # numpy arrays are converted to a plain list; other inputs are
        # stored as given.
        self._times = times.tolist() if isinstance(times, np.ndarray) else times
        self._qtype = qtype

    def size(self) -> int:
        """
        :return: number of time points multiplied by the inner qtype size
        """
        n_times = len(self._times)
        return n_times * int(self._qtype.size())

    def get_key(self, key):
        """
        Locate a time point and return the inner qtype with its offset.

        :param key: time value to locate
        :return: Tuple (q_type, offset), offset being the position of the
            time point multiplied by the inner qtype size
        :raises ValueError: if key is not among the stored times (the
            available items are printed before re-raising)
        """
        inner = self._qtype
        try:
            time_idx = self._times.index(key)
        except ValueError:
            # Report the available items, then let the error propagate.
            print(
                "Item " + str(key) + " was not found in TimeSeries"
                + ". Available items: " + str(list(self._times))
            )
            raise
        return inner, time_idx * inner.size()

    @staticmethod
    def time_interpolation(quantity, value):
        """
        Create a quantity that interpolates a time series at a given time.

        :param quantity: Quantity instance with qtype being a TimeSeriesType
        :param value: float time value where to interpolate
        :return: Quantity object representing interpolated value
        """

        def interp(y):
            # Split along the third axis from the end into one piece per
            # time point, then interpolate across the pieces.
            series_type = quantity.qtype
            step = series_type._qtype.size()
            boundaries = np.arange(1, len(series_type._times)) * step
            pieces = np.split(y, boundaries, axis=-3)
            interpolator = interpolate.interp1d(quantity.qtype._times, pieces, axis=0)
            return interpolator(value)

        return mlmc.quantity.quantity.Quantity(
            quantity_type=quantity.qtype._qtype,
            input_quantities=[quantity],
            operation=interp,
        )
class FieldType(QType):
    """
    Field type: named entries whose qtypes all have the same size.

    :param args: List of (name, QType) pairs
    """

    def __init__(self, args: List[Tuple[str, QType]]):
        self._dict = dict(args)
        # The qtype of the first pair serves as the qtype of every entry.
        _, first_qtype = args[0]
        self._qtype = first_qtype
        assert all(entry.size() == self._qtype.size() for _, entry in args)

    def size(self) -> int:
        """
        :return: number of fields multiplied by the inner qtype size
        """
        n_fields = len(self._dict)
        return n_fields * int(self._qtype.size())

    def get_key(self, key):
        """
        Access sub-field by name.

        :param key: field name
        :return: Tuple (q_type, offset), offset being the position of the
            field multiplied by the inner qtype size
        :raises ValueError: if key is not a known field name (the first few
            available keys are printed before re-raising)
        """
        inner = self._qtype
        names = list(self._dict)
        try:
            field_idx = names.index(key)
        except ValueError:
            print(
                "Key " + str(key) + " was not found in FieldType"
                + ". Available keys: " + str(list(self._dict.keys())[:5]) + "..."
            )
            raise
        return inner, field_idx * inner.size()
class DictType(QType):
    """
    Dictionary-like type of named QTypes which may differ in size.

    :param args: List of (name, QType) pairs
    """

    def __init__(self, args: List[Tuple[str, QType]]):
        self._dict = dict(args)  # keep ordered mapping semantics
        self._check_base_type()

    def _check_base_type(self):
        """
        Ensure all contained qtypes share the same base_qtype.

        :raises TypeError: if base_qtypes differ
        """
        qtypes = list(self._dict.values())
        qtype_0_base_type = qtypes[0].base_qtype()
        expected_cls = type(qtype_0_base_type)
        for qtype in qtypes[1:]:
            # Compare exact classes. An isinstance() check is order-dependent
            # here because BoolType subclasses ScalarType: a ScalarType first
            # entry would accept later BoolType entries, while the reverse
            # order would be rejected.
            if type(qtype.base_qtype()) is not expected_cls:
                raise TypeError(
                    "qtype {} has base QType {}, expecting {}. "
                    "All QTypes must have same base QType, either ScalarType or BoolType".format(
                        qtype, qtype.base_qtype(), qtype_0_base_type
                    )
                )

    def base_qtype(self):
        """
        :return: base_qtype of the first element
        """
        return next(iter(self._dict.values())).base_qtype()

    def size(self) -> int:
        """
        :return: total flattened size (sum of sizes of contained qtypes)
        """
        return int(sum(q_type.size() for q_type in self._dict.values()))

    def get_qtypes(self):
        """
        :return: iterable of contained qtypes
        """
        return self._dict.values()

    def replace_scalar(self, substitute_qtype):
        """
        Replace scalar types recursively inside dict entries.

        :param substitute_qtype: QType that replaces ScalarType
        :return: new DictType instance
        """
        return DictType(
            [
                (key, qtype.replace_scalar(substitute_qtype))
                for key, qtype in self._dict.items()
            ]
        )

    def get_key(self, key):
        """
        Return the QType and starting offset for a named key.

        :param key: name of entry
        :return: Tuple (q_type, start_offset), start_offset being the sum of
            sizes of all entries stored before key
        :raises KeyError: if key is not present (the first few available
            keys are printed before re-raising)
        """
        try:
            q_type = self._dict[key]
        except KeyError:
            print(
                "Key " + str(key) + " was not found in DictType"
                + ". Available keys: " + str(list(self._dict.keys())[:5]) + "..."
            )
            raise

        # Entries may differ in size, so the offset is accumulated over
        # the entries preceding key in insertion order.
        start = 0
        for k, qt in self._dict.items():
            if k == key:
                break
            start += qt.size()
        return q_type, start