Source code for species.util.box_util
"""
Utility functions for boxes.
"""
import warnings
from typing import Dict, Optional
import numpy as np
from typeguard import typechecked
from species.core import constants
from species.core.box import ObjectBox
from species.read.read_model import ReadModel
from species.util.core_util import print_section
[docs]
@typechecked
def update_objectbox(
    objectbox: ObjectBox, model_param: Dict[str, float], model: Optional[str] = None
) -> ObjectBox:
    """
    Function for updating the spectra and/or photometric fluxes in
    an :class:`~species.core.box.ObjectBox`, for example by
    applying a flux scaling and/or error inflation.

    The arrays stored in the box are modified in place, so the
    returned box is the same object as the input box.

    Parameters
    ----------
    objectbox : species.core.box.ObjectBox
        Box with the object's data, including the spectra
        and/or photometric fluxes.
    model_param : dict
        Dictionary with the model parameters. Should contain the
        value(s) of the flux scaling and/or the error inflation.
        The following keys are recognized:

        - ``{filter}_error`` or ``{instrument}_error``: inflation
          of a photometric uncertainty, relative to the
          uncertainty itself. The filter-specific key takes
          precedence over the instrument key.
        - ``scaling_{spectrum}``: multiplicative flux scaling
          of a spectrum.
        - ``error_{spectrum}``: error inflation of a spectrum
          (see ``model`` for how the value is interpreted).
        - ``radvel_{spectrum}``: radial velocity that is used
          for shifting the wavelengths of a spectrum.
    model : str, None
        Name of the atmospheric model. Only used for inflating
        the errors of spectra. When set to ``None``, a warning is
        issued and the errors of the spectra are not inflated.
        With ``model='petitradtrans'``, the value of
        ``error_{spectrum}`` is used as the log10 of a constant
        that is added to the uncertainties. For any other model,
        the uncertainties are inflated relative to the fluxes of
        the model spectrum, which is calculated with
        :class:`~species.read.read_model.ReadModel`.

    Returns
    -------
    species.core.box.ObjectBox
        The input box which includes the spectra with the
        scaled fluxes and/or inflated errors.
    """

    print_section("Update ObjectBox")

    if objectbox.flux is not None:
        for key, value in objectbox.flux.items():
            # The part of the filter name before the first dot
            # is used as the name of the instrument
            instr_name = key.split(".")[0]

            if f"{key}_error" in model_param:
                # Inflate the photometric uncertainty of a filter
                infl_factor = model_param[f"{key}_error"]

            elif f"{instr_name}_error" in model_param:
                # Inflate photometric uncertainty of an instrument
                infl_factor = model_param[f"{instr_name}_error"]

            else:
                # No inflation required
                continue

            # Scale relative to the uncertainty
            var_add = infl_factor**2 * value[1] ** 2

            message = (
                f"Inflating the uncertainty of {key} by a "
                + f"factor {infl_factor:.2f} to "
                + f"{np.sqrt(var_add):.2e} (W m-2 um-1)..."
            )

            print(message, end="", flush=True)

            # Add the extra variance in quadrature
            value[1] = np.sqrt(value[1] ** 2 + var_add)

            print(" [DONE]")

            objectbox.flux[key] = value

    if objectbox.spectrum is not None:
        # Check if there are any spectra

        for key, value in objectbox.spectrum.items():
            # Get the spectrum (3 columns), which
            # is modified in place further below
            spec_tmp = value[0]

            if f"scaling_{key}" in model_param:
                # Scale the flux of the spectrum
                scaling = model_param[f"scaling_{key}"]

                print(
                    f"Scaling the flux of {key} by: {scaling:.2f}...",
                    end="",
                    flush=True,
                )

                spec_tmp[:, 1] *= scaling

                print(" [DONE]")

            if f"error_{key}" in model_param:
                if model is None:
                    warnings.warn(
                        "The dictionary with model parameters "
                        f"contains the error inflation for '{key}' "
                        "but the argument of 'model' is set to "
                        "'None'. Inflation of the errors is "
                        "therefore not possible."
                    )

                elif model == "petitradtrans":
                    # Increase the errors by a constant value
                    # The parameter is the log10 of that value
                    add_error = 10.0 ** model_param[f"error_{key}"]

                    log_msg = (
                        f"Inflating the uncertainties of {key} "
                        + "by a constant value of "
                        + f"{add_error:.2e} (W m-2 um-1)..."
                    )

                    print(log_msg, end="", flush=True)

                    spec_tmp[:, 2] += add_error

                    print(" [DONE]")

                else:
                    # Calculate the model spectrum, using a wavelength
                    # range that is 10% wider on both sides than the
                    # range that is covered by the data
                    wavel_range = (0.9 * spec_tmp[0, 0], 1.1 * spec_tmp[-1, 0])

                    readmodel = ReadModel(model, wavel_range=wavel_range)

                    model_box = readmodel.get_model(
                        model_param,
                        spec_res=value[3],
                        wavel_resample=spec_tmp[:, 0],
                    )

                    # Inflate the uncertainties relative to
                    # the fluxes of the model spectrum
                    infl_factor = model_param[f"error_{key}"]

                    # Both parts need to be f-strings, otherwise
                    # the placeholder is printed literally
                    log_msg = (
                        f"Inflating the uncertainties of {key} "
                        + f"by a factor {infl_factor:.2f}..."
                    )

                    print(log_msg, end="", flush=True)

                    spec_tmp[:, 2] = np.sqrt(
                        spec_tmp[:, 2] ** 2 + (infl_factor * model_box.flux) ** 2
                    )

                    print(" [DONE]")

            if f"radvel_{key}" in model_param:
                # Shift the wavelengths of the data by
                # the radial velocity in opposite direction
                # NOTE(review): the factor 1e3 presumably converts
                # the radial velocity from km/s to m/s, to match
                # the units of constants.LIGHT -- confirm
                wavel_shift = (
                    -1.0
                    * model_param[f"radvel_{key}"]
                    * 1e3
                    * spec_tmp[:, 0]
                    / constants.LIGHT
                )

                # NOTE(review): assumes that the wavelengths
                # are in um, so 1e3 converts to nm -- confirm
                mean_shift = np.mean(wavel_shift) * 1e3  # (nm)

                print(
                    f"Mean wavelength shift (nm) for {key}: {mean_shift:.2f}...",
                    end="",
                    flush=True,
                )

                spec_tmp[:, 0] += wavel_shift

                print(" [DONE]")

            # Store the spectra with the scaled fluxes and/or errors
            # The other three elements (i.e. the covariance matrix,
            # the inverted covariance matrix, and the spectral
            # resolution) remain unaffected
            objectbox.spectrum[key] = (spec_tmp, value[1], value[2], value[3])

    return objectbox