Source code for taurex.util.math

"""Optimized Math functions used in taurex"""


import typing as t

import numpy as np
import numpy.typing as npt

from taurex.log import setup_log
from taurex.types import AnyValType

_log = setup_log(__name__)

try:
    from .math_numba import intepr_bilin_numba_II, interp_lin_numba

    numba_enabled = True
except ImportError:
    _log.warning("Numba not installed, using numpy")
    numba_enabled = False


def interp_exp_and_lin_numpy(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    x21: npt.NDArray[np.float64],
    x22: npt.NDArray[np.float64],
    temperature: float,
    temperature_min: float,
    temperature_max: float,
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Interpolate linearly in pressure and exponentially in temperature.

    The four corner arrays are first combined along pressure (linear),
    giving one value at ``temperature_min`` and one at ``temperature_max``;
    those two are then combined with an exponential interpolation in
    temperature.

    Parameters
    ----------
    x11: array
        Values at (pressure_min, temperature_min)
    x12: array
        Values at (pressure_min, temperature_max)
    x21: array
        Values at (pressure_max, temperature_min)
    x22: array
        Values at (pressure_max, temperature_max)
    temperature: float
        Temperature to exp-interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature
    pressure: float
        Pressure to linearly interpolate to
    pressure_min: float
        Lower bracketing pressure
    pressure_max: float
        Upper bracketing pressure

    Returns
    -------
    array
        Interpolated values at (pressure, temperature)

    """
    pressure_span = pressure_max - pressure_min
    pressure_offset = pressure - pressure_min

    # Pressure-interpolated values at each temperature edge, still scaled
    # by ``pressure_span`` (the division is deferred to the return line).
    at_t_min = x11 * pressure_span - pressure_offset * (x11 - x21)
    at_t_max = x12 * pressure_span - pressure_offset * (x12 - x22)

    exponent = (
        temperature_max
        * (-temperature + temperature_min)
        * np.log(at_t_min / at_t_max)
        / (temperature * (temperature_max - temperature_min))
    )
    return at_t_min * np.exp(exponent) / pressure_span
def interp_exp_numpy(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    temperature,
    temperature_min,
    temperature_max,
) -> npt.NDArray[np.float64]:
    """Exponentially interpolate between two arrays along temperature.

    Parameters
    ----------
    x11: array
        Values at ``temperature_min``
    x12: array
        Values at ``temperature_max``
    temperature: float
        Temperature to interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature

    Returns
    -------
    array
        ``x11`` when ``temperature == temperature_min`` and ``x12`` when
        ``temperature == temperature_max``.

    """
    exponent = (
        temperature_max
        * (-temperature + temperature_min)
        * np.log(x11 / x12)
        / (temperature * (temperature_max - temperature_min))
    )
    return x11 * np.exp(exponent)
def interp_lin_numpy(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Linearly interpolate between two arrays along one coordinate.

    Parameters
    ----------
    x11: array
        Values at ``pressure_min``
    x12: array
        Values at ``pressure_max``
    pressure: float
        Coordinate to interpolate to
    pressure_min: float
        Lower bracketing coordinate
    pressure_max: float
        Upper bracketing coordinate

    Returns
    -------
    array
        Interpolated values at ``pressure``

    """
    span = pressure_max - pressure_min
    offset = pressure - pressure_min
    scaled = x11 * span - offset * (x11 - x12)
    return scaled / span
def interp_bilin_numpy(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    x21: npt.NDArray[np.float64],
    x22: npt.NDArray[np.float64],
    temperature: float,
    temperature_min: float,
    temperature_max: float,
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Bilinearly interpolate in pressure and temperature.

    Parameters
    ----------
    x11: array
        Values at (pressure_min, temperature_min)
    x12: array
        Values at (pressure_min, temperature_max)
    x21: array
        Values at (pressure_max, temperature_min)
    x22: array
        Values at (pressure_max, temperature_max)
    temperature: float
        Temperature to interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature
    pressure: float
        Pressure to interpolate to
    pressure_min: float
        Lower bracketing pressure
    pressure_max: float
        Upper bracketing pressure

    Returns
    -------
    array
        Interpolated values at (pressure, temperature)

    """
    # Fractional position of the target point inside the bracketing cell.
    frac_p = (pressure - pressure_min) / (pressure_max - pressure_min)
    frac_t = (temperature - temperature_min) / (
        temperature_max - temperature_min
    )

    along_pressure = frac_p * (x11 - x21)
    cross_term = frac_p * frac_t * (x21 - x11 + x12 - x22)
    along_temperature = frac_t * (x11 - x12)
    return x11 - along_pressure - cross_term - along_temperature
def intepr_bilin_numexpr(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    x21: npt.NDArray[np.float64],
    x22: npt.NDArray[np.float64],
    temperature: float,
    temperature_min: float,
    temperature_max: float,
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Bilinear interpolation in pressure and temperature using numexpr.

    Evaluates the same closed-form expression as ``intepr_bilin_old`` but
    hands it to ``numexpr.evaluate`` as a string.

    Parameters
    ----------
    x11: array
        Values at (pressure_min, temperature_min)
    x12: array
        Values at (pressure_min, temperature_max)
    x21: array
        Values at (pressure_max, temperature_min)
    x22: array
        Values at (pressure_max, temperature_max)
    temperature: float
        Temperature to interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature
    pressure: float
        Pressure to interpolate to
    pressure_min: float
        Lower bracketing pressure
    pressure_max: float
        Upper bracketing pressure

    Returns
    -------
    array
        Result of evaluating the expression with numexpr.

    """
    # Imported lazily so numexpr is only required when this variant is used.
    import numexpr as ne

    # NOTE: the names inside the expression string are resolved by numexpr
    # from this function's arguments, so the parameter names above must stay
    # in sync with the string below.
    return ne.evaluate(
        "(x11*(pressure_max - pressure_min)*(temperature_max - temperature_min)"
        " - (pressure - pressure_min)*(temperature_max - temperature_min)*(x11 - x21)"
        " - (temperature - temperature_min)*(-(pressure - pressure_min)*(x11 - x21)"
        " + (pressure - pressure_min)*(x12 - x22) + (pressure_max - "
        "pressure_min)*(x11 - x12)))/((pressure_max - "
        "pressure_min)*(temperature_max - temperature_min))"
    )
def intepr_bilin_double(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    x21: npt.NDArray[np.float64],
    x22: npt.NDArray[np.float64],
    temperature: float,
    temperature_min: float,
    temperature_max: float,
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Bilinear interpolation built from three 1D linear interpolations.

    Each pressure edge is first interpolated along temperature with
    ``interp_lin_only``; the two results are then interpolated along
    pressure with the same routine.

    Parameters
    ----------
    x11: array
        Values at (pressure_min, temperature_min)
    x12: array
        Values at (pressure_min, temperature_max)
    x21: array
        Values at (pressure_max, temperature_min)
    x22: array
        Values at (pressure_max, temperature_max)
    temperature: float
        Temperature to interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature
    pressure: float
        Pressure to interpolate to
    pressure_min: float
        Lower bracketing pressure
    pressure_max: float
        Upper bracketing pressure

    Returns
    -------
    array
        Interpolated values at (pressure, temperature)

    """
    at_pressure_min = interp_lin_only(
        x11, x12, temperature, temperature_min, temperature_max
    )
    at_pressure_max = interp_lin_only(
        x21, x22, temperature, temperature_min, temperature_max
    )
    return interp_lin_only(
        at_pressure_min, at_pressure_max, pressure, pressure_min, pressure_max
    )
def intepr_bilin_old(
    x11: npt.NDArray[np.float64],
    x12: npt.NDArray[np.float64],
    x21: npt.NDArray[np.float64],
    x22: npt.NDArray[np.float64],
    temperature: float,
    temperature_min: float,
    temperature_max: float,
    pressure: float,
    pressure_min: float,
    pressure_max: float,
) -> npt.NDArray[np.float64]:
    """Bilinear interpolation in pressure and temperature (expanded form).

    Parameters
    ----------
    x11: array
        Values at (pressure_min, temperature_min)
    x12: array
        Values at (pressure_min, temperature_max)
    x21: array
        Values at (pressure_max, temperature_min)
    x22: array
        Values at (pressure_max, temperature_max)
    temperature: float
        Temperature to interpolate to
    temperature_min: float
        Lower bracketing temperature
    temperature_max: float
        Upper bracketing temperature
    pressure: float
        Pressure to interpolate to
    pressure_min: float
        Lower bracketing pressure
    pressure_max: float
        Upper bracketing pressure

    Returns
    -------
    array
        Interpolated values at (pressure, temperature)

    """
    p_span = pressure_max - pressure_min
    t_span = temperature_max - temperature_min
    p_off = pressure - pressure_min
    t_off = temperature - temperature_min

    # Terms that vary with temperature, grouped as in the closed-form
    # expression so the floating-point evaluation order is unchanged.
    temperature_part = (
        -p_off * (x11 - x21) + p_off * (x12 - x22) + p_span * (x11 - x12)
    )
    numerator = (
        x11 * p_span * t_span
        - p_off * t_span * (x11 - x21)
        - t_off * temperature_part
    )
    return numerator / (p_span * t_span)
def compute_rayleigh_cross_section(
    wngrid: npt.NDArray[np.float64],
    n: float,
    n_air: t.Optional[float] = 2.6867805e25,
    king: t.Optional[float] = 1.0,
) -> npt.NDArray[np.float64]:
    """Compute the Rayleigh cross section on a wavenumber grid.

    Evaluates ``24 * pi**3 * king * F**2 / wl**4`` with
    ``F = (n**2 - 1) / (n_air * (n**2 + 2))`` and
    ``wl = (10000 / wngrid) * 1e-6``.

    Parameters
    ----------
    wngrid: array
        Wavenumber grid (presumably in cm-1, so that ``wl`` is in metres
        -- TODO confirm against callers).
    n: float
        Refractive index.
    n_air: float, optional
        Number density used to normalise the refractive term. ``None``
        selects the default ``2.6867805e25``.
    king: float, optional
        King correction factor. ``None`` selects the default ``1.0``.

    Returns
    -------
    array
        Cross section, same shape as ``wngrid``.

    """
    # Both parameters are annotated Optional, so honour an explicit None by
    # falling back to the documented default rather than raising TypeError.
    if n_air is None:
        n_air = 2.6867805e25
    if king is None:
        king = 1.0

    wlgrid = (10000 / wngrid) * 1e-6
    n_factor = (n**2 - 1) / (n_air * (n**2 + 2))
    return 24.0 * (np.pi**3) * king * (n_factor**2) / (wlgrid**4)
def test_nan(val: t.Union[float, npt.ArrayLike]) -> bool:
    """Test whether a value is, or contains, NaN.

    Parameters
    ----------
    val: float or array-like
        Value to inspect. Anything exposing ``__len__`` is treated as a
        collection and checked element-wise with ``numpy.isnan``.

    Returns
    -------
    bool
        ``True`` if a scalar is NaN or any element of a collection is NaN.
        A collection that ``numpy.isnan`` cannot handle (raises
        ``TypeError``, e.g. strings) is also reported as ``True``.

    """
    if hasattr(val, "__len__"):
        try:
            # bool() so the declared return type holds (not numpy.bool_).
            return bool(np.isnan(val).any())
        except TypeError:
            # Non-numeric content: treated as "not a valid number".
            return True
    # NaN is the only float value that compares unequal to itself.
    return bool(val != val)


# The name matches pytest's default ``test_*`` collection pattern; without
# this flag, importing the helper into a test module makes pytest try to run
# it as a test and fail on the missing ``val`` fixture.
test_nan.__test__ = False
# Bind the public interpolation aliases: prefer the numba kernels when the
# optional import succeeded, otherwise use the numpy implementations.
# (Conditional expressions only evaluate the selected branch, so the numba
# names are never touched when numba is unavailable.)
interp_lin_only = interp_lin_numba if numba_enabled else interp_lin_numpy
intepr_bilin = intepr_bilin_numba_II if numba_enabled else interp_bilin_numpy

# The exponential variants only exist as numpy implementations.
interp_exp_and_lin = interp_exp_and_lin_numpy
interp_exp_only = interp_exp_numpy
class OnlineVariance:
    """Compute a weighted running mean and variance in a streaming fashion.

    Values are fed one at a time through :meth:`update`; the class keeps the
    running mean and the running sum of weighted squared deviations
    (``M2``), from which the variance is derived on demand.

    Attributes are plain instance fields set by :meth:`reset`:
    ``count`` (number of updates), ``wcount`` (sum of weights),
    ``wcount2`` (sum of squared weights; accumulated but not read inside
    this class), ``mean`` and ``M2``.
    """

    def __init__(self) -> None:
        """Initialise the accumulator in its empty state."""
        self.reset()

    def reset(self) -> None:
        """Discard everything accumulated so far."""
        self.count = 0.0
        self.wcount = 0.0
        self.wcount2 = 0.0
        self.mean = None
        self.M2 = None

    def update(
        self, value: "AnyValType", weight: t.Optional[float] = 1.0
    ) -> None:
        """Fold one more sample into the running mean and variance.

        Parameters
        ----------
        value:
            New sample (scalar or array; must support arithmetic with
            floats).
        weight: float, optional
            Weight of the sample. Defaults to ``1.0``.

        """
        self.count += 1
        self.wcount += weight
        self.wcount2 += weight * weight

        if self.mean is None:
            # ``value * 0.0`` gives a zero of the same shape/type as value.
            self.mean = value * 0.0
            self.M2 = value * 0.0

        mean_old = self.mean
        try:
            self.mean = mean_old + (weight / self.wcount) * (value - mean_old)
        except ZeroDivisionError:
            # Total weight is still zero: no meaningful mean yet.
            self.mean = value * 0.0
        self.M2 += weight * (value - mean_old) * (value - self.mean)

    @property
    def variance(self) -> "AnyValType":
        """Return ``M2 / wcount``, or NaN before the second update."""
        if self.count < 2:
            return np.nan
        return self.M2 / self.wcount

    @property
    def sampleVariance(self) -> "AnyValType":  # noqa: N802
        """Return ``M2 / (wcount - 1)``, or NaN before the second update."""
        if self.count < 2:
            return np.nan
        return self.M2 / (self.wcount - 1)

    @staticmethod
    def _is_missing(value: t.Any) -> bool:
        """Return True for ``None`` or a scalar NaN (any NaN object)."""
        if value is None:
            return True
        try:
            return np.ndim(value) == 0 and bool(np.isnan(value))
        except TypeError:
            return False

    def combine_variance(
        self,
        averages: npt.ArrayLike,
        variance: npt.ArrayLike,
        counts: npt.ArrayLike,
    ) -> t.Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Combine per-partition means and variances into global ones.

        Parameters
        ----------
        averages:
            Mean of each partition. ``None`` or scalar NaN marks a
            partition without a usable mean; it is skipped.
        variance:
            Variance of each partition. ``None`` or scalar NaN marks a
            partition without a variance; only its mean contributes.
        counts:
            Weight (sample count) of each partition.

        Returns
        -------
        tuple
            ``(average, variance)`` of the combined data. Both are NaN when
            no partition contributes.

        """
        size = np.sum(counts)

        # Weighted sum of the partition means.
        average = None
        for avg, cnt in zip(averages, counts):
            if cnt == 0 or self._is_missing(avg):
                continue
            weighted = avg * cnt
            average = weighted if average is None else average + weighted

        if average is None:
            # Nothing to combine (empty input or all counts zero).
            return np.nan, np.nan
        average = average / size

        # Between-partition spread plus each partition's own variance.
        # Missing values are detected by value, not by identity with the
        # ``np.nan`` singleton: a NaN that has been copied or pickled (e.g.
        # by a gather across processes) is a different object.
        squares = None
        for avg, cnt, var in zip(averages, counts, variance):
            if not cnt > 0 or self._is_missing(avg):
                continue
            deviation = cnt * (average - avg) ** 2
            squares = deviation if squares is None else squares + deviation
            if not self._is_missing(var):
                squares = squares + cnt * var

        if squares is None:
            return average, np.nan
        return average, squares / size

    def parallelVariance(self) -> "AnyValType":  # noqa: N802
        """Return the variance combined over all processes.

        Gathers this instance's variance, mean, weight sum and update count
        through ``taurex.mpi.allgather`` (presumably one entry per MPI
        rank -- see that module) and merges them with
        :meth:`combine_variance`. Returns NaN when fewer than two updates
        were made in total.
        """
        from taurex import mpi

        local_mean = np.nan if self.mean is None else self.mean

        # Keep the gather order fixed: every process must issue the same
        # sequence of collective calls.
        variances = mpi.allgather(self.variance)
        averages = mpi.allgather(local_mean)
        counts = mpi.allgather(self.wcount)
        all_counts = mpi.allgather(self.count)

        if sum(all_counts) < 2:
            return np.nan
        return self.combine_variance(averages, variances, counts)[-1]