"""
Module for plotting MCMC results.
"""
import warnings
from numbers import Real
import corner
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from matplotlib.ticker import ScalarFormatter
from scipy.stats import norm
from species.core import constants
from species.read.read_isochrone import ReadIsochrone
from species.read.read_model import ReadModel
from species.util.convert_util import logg_to_mass
from species.util.core_util import print_section
from species.util.plot_util import update_labels
from species.util.dust_util import (
interp_lognorm,
interp_powerlaw,
ism_extinction,
log_normal_distribution,
power_law_distribution,
)
from species.util.retrieval_util import (
atomic_masses,
calc_metal_ratio,
get_line_species,
mass_frac_dict,
mean_molecular_weight,
)
[docs]
@beartype
def plot_posterior(
    tag: str,
    title: Optional[str] = None,
    offset: Optional[Tuple[Real, Real]] = None,
    title_fmt: Union[str, List[str]] = ".2f",
    limits: Optional[List[Tuple[Real, Real]]] = None,
    max_prob: bool = False,
    vmr: bool = False,
    inc_luminosity: bool = False,
    inc_mass: bool = False,
    inc_log_mass: bool = False,
    inc_pt_param: bool = False,
    inc_loglike: bool = False,
    inc_abund: bool = True,
    output: Optional[str] = None,
    object_type: str = "planet",
    param_inc: Optional[List[str]] = None,
    show_priors: bool = False,
    show_grid: bool = True,
    kwargs_corner: Optional[dict] = None,
) -> mpl.figure.Figure:
    """
    Function to plot the posterior distribution of the
    model parameters. For model grids, colored ticks show
    the grid points, which can be used to check the width
    of the posterior distributions relative to the spacing
    between the grid points.

    Parameters that have the same value in all samples (e.g.
    fixed parameters) are excluded from the corner plot.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    title : str, None
        Plot title. No title is shown if the argument
        is set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values
        are used if the argument is set to ``None``.
    title_fmt : str, list(str)
        Format of the titles above the 1D distributions. Either a
        single string, which will be used for all parameters, or a
        list with the title format for each parameter separately
        (in the order as shown in the corner plot).
    limits : list(tuple(float, float), ), None
        Axis limits of all parameters. Automatically set if the
        argument is set to ``None``.
    max_prob : bool
        Plot the position of the sample with the maximum likelihood.
    vmr : bool
        Plot the volume mixing ratios (i.e. number fractions)
        instead of the mass fractions of the retrieved species with
        :class:`~species.fit.retrieval.AtmosphericRetrieval`.
    inc_luminosity : bool
        Include the log10 of the luminosity in the posterior plot
        as calculated from the effective temperature and radius.
        When disk components (``disk_teff``/``disk_radius`` or
        ``disk_teff_i``/``disk_radius_i``) are present, the
        atmospheric and disk luminosities are included separately,
        together with the blackbody radius for each disk component.
    inc_mass : bool
        Include the mass in the posterior plot as calculated
        from the surface gravity and radius.
    inc_log_mass : bool
        Include the logarithm of the mass, :math:`\\log_{10}{M}`, in
        the posterior plot, as calculated from the surface gravity
        and radius.
    inc_pt_param : bool
        Include the parameters of the pressure-temperature profile.
        Only used if the ``tag`` contains samples obtained with
        :class:`~species.fit.retrieval.AtmosphericRetrieval`.
    inc_loglike : bool
        Include the log-likelihood, :math:`\\ln{L}`, as additional
        parameter in the corner plot.
    inc_abund : bool
        Include the abundances when retrieving free abundances with
        :class:`~species.fit.retrieval.AtmosphericRetrieval`.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius
        and mass are expressed in Jupiter units. With 'star', the
        radius and mass are expressed in solar units.
    param_inc : list(str), None
        List with subset of parameters that will be included in the
        posterior plot. This parameter can also be used to change the
        order of the parameters in the posterior plot. All parameters
        will be included if the argument is set to ``None``. Names
        that are not present in the samples are ignored with a
        warning.
    show_priors : bool
        Plot the normal priors in the diagonal panels together with the
        1D marginalized posterior distributions (default: False). This
        will only show the priors that had a normal distribution, so
        those that were set with the ``normal_prior`` parameter in
        :class:`~species.fit.fit_model.FitModel` and
        :class:`~species.fit.retrieval.AtmosphericRetrieval.setup_retrieval`.
    show_grid : bool
        Show with lines for the grid points of the atmospheric model
        on the 1D and 2D marginalized posteriors (default: True).
        This parameter has only an effect for results obtained with
        :class:`~species.fit.fit_model.FitModel`.
    kwargs_corner : dict, None
        Dictionary with keyword arguments that can be used to adjust the
        parameters of the `corner() function
        <https://corner.readthedocs.io/en/latest/api/>`_ of ``corner.py``.

    Returns
    -------
    matplotlib.figure.Figure
        The ``Figure`` object that can be used for further
        customization of the plot.

    Raises
    ------
    ValueError
        If ``title_fmt`` is a list whose length is not equal to the
        number of parameters shown in the corner plot.
    """

    from species.data.database import Database

    species_db = Database()
    box = species_db.get_samples(tag)
    samples = box.samples

    print_section("Plot posterior distributions")

    print(f"Database tag: {tag}")
    print(f"Object type: {object_type}")
    print(f"Manual parameters: {param_inc}\n")

    if "model_type" in box.attributes:
        print(f"Model type: {box.attributes['model_type']}")
    elif "spec_type" in box.attributes:
        print(f"Model type: {box.attributes['spec_type']}")

    if "model_name" in box.attributes:
        print(f"Model name: {box.attributes['model_name']}")
    elif "spec_name" in box.attributes:
        # This attribute is the name (not the type) of the model
        print(f"Model name: {box.attributes['spec_name']}")

    if "sampler" in box.attributes:
        print(f"Sampler: {box.attributes['sampler']}")

    print(f"\nShow priors: {show_priors}")
    print(f"Show grid: {show_grid}")

    # Create empty dictionary if needed
    if kwargs_corner is None:
        kwargs_corner = {}

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["mathtext.fontset"] = "dejavuserif"

    # Remove the parameters of the P-T profile of a retrieval

    if not inc_pt_param and box.model_name == "petitradtrans":
        pt_param = [
            "tint",
            "t1",
            "t2",
            "t3",
            "alpha",
            "log_delta",
            "T_bottom",
            "PTslope_1",
            "PTslope_2",
            "PTslope_3",
            "PTslope_4",
            "PTslope_5",
            "PTslope_6",
        ]

        index_del = []
        item_del = []

        # Temperature nodes are named t0, t1, t2, ... so stop at the first gap
        for i in range(100):
            pt_item = f"t{i}"
            if pt_item not in box.parameters:
                break
            index_del.append(_param_column(box.parameters, pt_item))
            item_del.append(pt_item)

        for item in pt_param:
            if item in box.parameters and item not in item_del:
                index_del.append(_param_column(box.parameters, item))
                item_del.append(item)

        samples = np.delete(samples, index_del, axis=1)

        for item in item_del:
            box.parameters.remove(item)

    if box.model_name == "petitradtrans":
        n_line_species = box.attributes["n_line_species"]
        line_species = [
            box.attributes[f"line_species{i}"] for i in range(n_line_species)
        ]

    if "abund_nodes" not in box.attributes:
        box.attributes["abund_nodes"] = "None"

    # Derive [C/H], [O/H], and C/O from free abundances. The columns
    # are stacked further down, and only when these ratios were
    # computed, so the parameter names and sample columns stay aligned

    add_metal_ratios = False

    if box.model_name == "petitradtrans" and box.attributes["chemistry"] == "free":
        if box.attributes["abund_nodes"] == "None":
            box.parameters.append("c_h_ratio")
            box.parameters.append("o_h_ratio")
            box.parameters.append("c_o_ratio")

            abund_index = {
                line_item: box.parameters.index(line_item)
                for line_item in line_species
            }

            c_h_ratio = np.zeros(samples.shape[0])
            o_h_ratio = np.zeros(samples.shape[0])
            c_o_ratio = np.zeros(samples.shape[0])

            for i, sample_item in enumerate(samples):
                abund_dict = {
                    line_item: sample_item[abund_index[line_item]]
                    for line_item in line_species
                }

                c_h_ratio[i], o_h_ratio[i], c_o_ratio[i] = calc_metal_ratio(
                    abund_dict,
                    line_species,
                )

            add_metal_ratios = True

    if (
        vmr
        and box.model_name == "petitradtrans"
        and box.attributes["chemistry"] == "free"
    ):
        print("Changing mass fractions to number fractions...", end="", flush=True)

        # Get all available line species
        all_line_species = get_line_species()

        # Get the atomic and molecular masses
        masses = atomic_masses()

        # The column of each retrieved species is the same for
        # every sample, so look the indices up only once
        species_index = {
            param_item: box.parameters.index(param_item)
            for param_item in box.parameters
            if param_item in all_line_species
        }

        # Work on a float64 copy so the conversion is explicit
        updated_samples = samples.astype(np.float64)

        for sample_item in updated_samples:
            # Dictionary with the log10 mass fractions of the metals
            log_x_abund = {
                param_item: sample_item[param_index]
                for param_item, param_index in species_index.items()
            }

            # Create a dictionary with all mass fractions, including H2 and He
            x_abund = mass_frac_dict(log_x_abund, line_species)

            # Calculate the mean molecular weight from the input mass fractions
            mmw = mean_molecular_weight(x_abund)

            # Overwrite the log10 mass fractions with log10 number fractions
            for param_item, param_index in species_index.items():
                sample_item[param_index] = np.log10(
                    10.0 ** sample_item[param_index] * mmw / masses[param_item]
                )

        samples = updated_samples
        box.samples = updated_samples

    print("\nMedian parameters:")

    for param_key, param_value in box.median_sample.items():
        if isinstance(param_value, Real):
            if -0.1 < param_value < 0.1:
                print(f" - {param_key} = {param_value:.2e}")
            else:
                print(f" - {param_key} = {param_value:.2f}")

    if "gauss_mean" in box.parameters:
        param_index = _param_column(box.parameters, "gauss_mean")
        samples[:, param_index] *= 1e3  # (um) -> (nm)

    if "gauss_sigma" in box.parameters:
        param_index = _param_column(box.parameters, "gauss_sigma")
        samples[:, param_index] *= 1e3  # (um) -> (nm)

    if box.prob_sample is not None:
        print("\nSample with the maximum likelihood:")

        for param_key, param_value in box.prob_sample.items():
            if isinstance(param_value, Real):
                if -0.1 < param_value < 0.1:
                    print(f" - {param_key} = {param_value:.2e}")
                else:
                    print(f" - {param_key} = {param_value:.2f}")

    # Convert wavelength parameters from (um) to (nm). This must act on
    # the plotted 'samples' array, which is not box.samples once columns
    # were deleted or the VMR conversion was applied
    for param_index, param_item in enumerate(box.parameters):
        if param_item.startswith("wavelength_"):
            samples[:, param_index] *= 1e3

    # Add [C/H], [O/H], and C/O if free abundances were retrieved
    if add_metal_ratios:
        samples = np.column_stack((samples, c_h_ratio, o_h_ratio, c_o_ratio))

    # Include the derived bolometric luminosity

    if inc_luminosity:
        if "teff" in box.parameters and "radius" in box.parameters:
            teff_index = _param_column(box.parameters, "teff")
            radius_index = _param_column(box.parameters, "radius")

            lum_atm = _bol_luminosity(
                samples[..., radius_index], samples[..., teff_index]
            )

            if "disk_teff" in box.parameters and "disk_radius" in box.parameters:
                # A single disk component without index
                teff_index = _param_column(box.parameters, "disk_teff")
                radius_index = _param_column(box.parameters, "disk_radius")

                lum_disk = _bol_luminosity(
                    samples[..., radius_index], samples[..., teff_index]
                )

                samples = np.append(samples, np.log10(lum_atm), axis=-1)
                box.parameters.append("log_lum_atm")

                samples = np.append(samples, np.log10(lum_disk), axis=-1)
                box.parameters.append("log_lum_disk")

                radius_bb = _bb_radius(lum_atm, samples[..., teff_index])
                samples = np.append(samples, radius_bb, axis=-1)
                box.parameters.append("radius_bb")

            else:
                # Disk components with an index: disk_teff_0, disk_teff_1, ...
                n_disk = 0

                for disk_idx in range(100):
                    if (
                        f"disk_teff_{disk_idx}" in box.parameters
                        and f"disk_radius_{disk_idx}" in box.parameters
                    ):
                        n_disk += 1
                    else:
                        break

                if n_disk > 0:
                    lum_disk = 0.0

                    for disk_idx in range(n_disk):
                        teff_index = _param_column(
                            box.parameters, f"disk_teff_{disk_idx}"
                        )
                        radius_index = _param_column(
                            box.parameters, f"disk_radius_{disk_idx}"
                        )

                        lum_disk += _bol_luminosity(
                            samples[..., radius_index], samples[..., teff_index]
                        )

                        radius_bb = _bb_radius(lum_atm, samples[..., teff_index])
                        samples = np.append(samples, radius_bb, axis=-1)
                        box.parameters.append(f"radius_bb_{disk_idx}")

                    samples = np.append(samples, np.log10(lum_atm), axis=-1)
                    box.parameters.append("log_lum_atm")

                    samples = np.append(samples, np.log10(lum_disk), axis=-1)
                    box.parameters.append("log_lum_disk")

                else:
                    samples = np.append(samples, np.log10(lum_atm), axis=-1)
                    box.parameters.append("log_lum_atm")

        # Luminosities of the individual components (teff_0/radius_0,
        # teff_1/radius_1, ...) followed by their summed luminosity

        lum_total = 0.0
        n_comp = 0

        for i in range(100):
            if f"teff_{i}" not in box.parameters or f"radius_{i}" not in box.parameters:
                break

            teff_index = _param_column(box.parameters, f"teff_{i}")
            radius_index = _param_column(box.parameters, f"radius_{i}")

            luminosity = _bol_luminosity(
                samples[..., radius_index], samples[..., teff_index]
            )

            samples = np.append(samples, np.log10(luminosity), axis=-1)
            box.parameters.append(f"log_lum_{i}")

            lum_total += luminosity
            n_comp += 1

        if n_comp > 0:
            samples = np.append(samples, np.log10(lum_total), axis=-1)
            box.parameters.append("log_lum")

    # Include the derived mass and/or log10 of the mass

    derived_mass = []

    if inc_mass:
        derived_mass.append(("mass", False, "inc_mass"))

    if inc_log_mass:
        derived_mass.append(("log_mass", True, "inc_log_mass"))

    for mass_name, use_log, flag_name in derived_mass:
        check_param = False

        for suffix in ("", "_0", "_1"):
            if (
                f"logg{suffix}" in box.parameters
                and f"radius{suffix}" in box.parameters
            ):
                logg_index = _param_column(box.parameters, f"logg{suffix}")
                radius_index = _param_column(box.parameters, f"radius{suffix}")

                mass_samples = logg_to_mass(
                    samples[..., logg_index], samples[..., radius_index]
                )

                if use_log:
                    mass_samples = np.log10(mass_samples)

                samples = np.append(samples, mass_samples, axis=-1)
                box.parameters.append(f"{mass_name}{suffix}")
                check_param = True

        if not check_param:
            warnings.warn(
                f"Samples with the log(g) and radius are required for '{flag_name}=True'."
            )

    # Change from Jupiter to solar units (and disk radii to au) if star

    if object_type == "star":
        for param_item in _indexed_param_names(box.parameters, "radius"):
            radius_index = _param_column(box.parameters, param_item)
            samples[:, radius_index] *= constants.R_JUP / constants.R_SUN

        for suffix in ("", "_0", "_1"):
            if f"mass{suffix}" in box.parameters:
                mass_index = _param_column(box.parameters, f"mass{suffix}")
                samples[:, mass_index] *= constants.M_JUP / constants.M_SUN

        for suffix in ("", "_0", "_1"):
            if f"log_mass{suffix}" in box.parameters:
                mass_index = _param_column(box.parameters, f"log_mass{suffix}")
                samples[:, mass_index] = np.log10(
                    10.0 ** samples[:, mass_index] * constants.M_JUP / constants.M_SUN
                )

        for base_name in ("disk_radius", "radius_bb"):
            for param_item in _indexed_param_names(box.parameters, base_name):
                radius_index = _param_column(box.parameters, param_item)
                samples[:, radius_index] *= constants.R_JUP / constants.AU

    # Include the log-likelihood, ln(L)

    if inc_loglike:
        ln_prob = box.ln_prob[..., np.newaxis]
        samples = np.append(samples, ln_prob, axis=-1)
        box.parameters.append("ln_prob")

    # Remove abundances

    if (
        not inc_abund
        and "chemistry" in box.attributes
        and box.attributes["chemistry"] == "free"
    ):
        index_del = []
        item_del = []

        if box.attributes["abund_nodes"] == "None":
            for line_item in line_species:
                index_del.append(_param_column(box.parameters, line_item))
                item_del.append(line_item)

        else:
            for line_item in line_species:
                for node_idx in range(box.attributes["abund_nodes"]):
                    node_item = f"{line_item}_{node_idx}"
                    index_del.append(_param_column(box.parameters, node_item))
                    item_del.append(node_item)

        samples = np.delete(samples, index_del, axis=1)

        for item in item_del:
            box.parameters.remove(item)

    # Include a subset of parameters. Unknown names are skipped so
    # that every remaining column matches its parameter name

    if param_inc is not None:
        column_sel = []
        param_inc_new = []

        for param_item in param_inc:
            if param_item in box.parameters:
                column_sel.append(box.parameters.index(param_item))
                param_inc_new.append(param_item)
            else:
                warnings.warn(
                    f"The parameter '{param_item}' is not found in the "
                    f"samples so it will be ignored."
                )

        box.parameters = param_inc_new
        samples = samples[:, column_sel]

    # Only for fitting evolutionary models
    # Remove index from parameter names when fitting 1 planet

    if "model_type" in box.attributes and box.attributes["model_type"] == "evolution":
        if box.attributes["n_planets"] == 1:
            box.parameters = [
                param_item[:-2] if param_item.endswith("_0") else param_item
                for param_item in box.parameters
            ]

    # Parameters to be included in the corner plot

    print("\nParameters included in corner plot:")

    for param_item in box.parameters:
        print(f" - {param_item}")

    # Update axes labels

    ndim = len(box.parameters)
    box_param = box.parameters.copy()
    labels = update_labels(box.parameters, object_type=object_type)

    # Remove parameters that have the same value in all samples. The
    # names in box_param are kept in sync with the labels, because
    # they are used below to select the priors and grid points

    index_sel = [
        i for i in range(ndim) if np.amin(samples[:, i]) != np.amax(samples[:, i])
    ]

    samples = samples[:, index_sel]
    labels = [labels[i] for i in index_sel]
    box_param = [box_param[i] for i in index_sel]

    ndim = len(index_sel)
    samples = samples.reshape((-1, ndim))

    # Get parameter values of maximum likelihood

    if max_prob:
        max_idx = np.argmax(box.ln_prob)
        max_sample = samples[max_idx,]

    if isinstance(title_fmt, list) and len(title_fmt) != ndim:
        raise ValueError(
            f"The number of items in the list of 'title_fmt' ({len(title_fmt)}) is "
            f"not equal to the number of dimensions of the samples ({ndim})."
        )

    hist_titles = []

    for i, item in enumerate(labels):
        unit_start = item.find("(")

        if unit_start == -1:
            param_label = item
            unit_label = None

        else:
            param_label = item[:unit_start]
            # Remove parenthesis from the units
            unit_label = item[unit_start + 1 : -1]

        q_16, q_50, q_84 = corner.quantile(samples[:, i], [0.16, 0.5, 0.84])
        q_minus, q_plus = q_50 - q_16, q_84 - q_50

        if isinstance(title_fmt, str):
            fmt = "{{0:{0}}}".format(title_fmt).format

        elif isinstance(title_fmt, list):
            fmt = "{{0:{0}}}".format(title_fmt[i]).format

        best_fit = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
        best_fit = best_fit.format(fmt(q_50), fmt(q_minus), fmt(q_plus))

        if unit_label is None:
            hist_title = f"{param_label} = {best_fit}"
        else:
            hist_title = f"{param_label} = {best_fit} {unit_label}"

        hist_titles.append(hist_title)

    # Create corner plot

    fig = corner.corner(
        samples,
        quantiles=[0.16, 0.5, 0.84],
        labels=labels,
        label_kwargs={"fontsize": 13},
        titles=hist_titles,
        show_titles=True,
        title_fmt=None,
        title_kwargs={"fontsize": 12},
        **kwargs_corner,
    )

    axes = np.array(fig.axes).reshape((ndim, ndim))

    for i in range(ndim):
        for j in range(ndim):
            ax = axes[i, j]

            if show_priors and i == j and box_param[i] in box.normal_priors:
                norm_param = box.normal_priors[box_param[i]]

                x_norm = np.linspace(
                    norm_param[0] - 5.0 * norm_param[1],
                    norm_param[0] + 5.0 * norm_param[1],
                    200,
                )

                y_norm = norm.pdf(x_norm, norm_param[0], norm_param[1])

                ax.plot(
                    x_norm,
                    0.9 * ax.get_ylim()[1] * y_norm / np.amax(y_norm),
                    ls=":",
                    lw=2.0,
                    color="dodgerblue",
                )

            if i >= j:
                ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
                ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

                labelleft = j == 0 and i != 0
                labelbottom = i == ndim - 1

                ax.tick_params(
                    axis="both",
                    which="major",
                    colors="black",
                    labelcolor="black",
                    direction="in",
                    width=1,
                    length=5,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelleft=labelleft,
                    labelbottom=labelbottom,
                    labelright=False,
                    labeltop=False,
                )

                ax.tick_params(
                    axis="both",
                    which="minor",
                    colors="black",
                    labelcolor="black",
                    direction="in",
                    width=1,
                    length=3,
                    labelsize=12,
                    top=True,
                    bottom=True,
                    left=True,
                    right=True,
                    labelleft=labelleft,
                    labelbottom=labelbottom,
                    labelright=False,
                    labeltop=False,
                )

                if limits is not None:
                    ax.set_xlim(limits[j])

                if max_prob:
                    ax.axvline(max_sample[j], color="tomato")

                if i > j:
                    if max_prob:
                        ax.axhline(max_sample[i], color="tomato")
                        ax.plot(max_sample[j], max_sample[i], "s", color="tomato")

                    if limits is not None:
                        ax.set_ylim(limits[i])

                if offset is not None:
                    ax.get_xaxis().set_label_coords(0.5, offset[0])
                    ax.get_yaxis().set_label_coords(offset[1], 0.5)

                else:
                    ax.get_xaxis().set_label_coords(0.5, -0.26)
                    ax.get_yaxis().set_label_coords(-0.27, 0.5)

    if show_grid:
        # Not every box has a 'model_type' attribute (see the
        # 'spec_type' fallback above), so read it defensively
        model_type = box.attributes.get("model_type")

        if (
            model_type == "atmosphere"
            and box.attributes["model_name"] != "petitradtrans"
        ):
            read_model = ReadModel(box.attributes["model_name"])
            grid_points = read_model.get_points()

        elif model_type == "evolution" and box.attributes.get("regular_grid") is True:
            read_model = ReadIsochrone(box.attributes["model_name"])
            grid_points = read_model.get_points()

        else:
            grid_points = None

        if grid_points is not None:
            for i in range(ndim):
                for j in range(ndim):
                    ax = axes[i, j]

                    if (i == j or i > j) and box_param[j] in grid_points:
                        ax_ymin, ax_ymax = ax.get_ylim()
                        y_span = ax_ymax - ax_ymin

                        ax.vlines(
                            grid_points[box_param[j]],
                            ymin=ax_ymin,
                            ymax=ax_ymin + 0.08 * y_span,
                            colors="cadetblue",
                            linestyles="solid",
                            linewidth=0.8,
                            zorder=1,
                        )

                        ax.vlines(
                            grid_points[box_param[j]],
                            ymin=ax_ymax - 0.08 * y_span,
                            ymax=ax_ymax,
                            colors="cadetblue",
                            linestyles="solid",
                            linewidth=0.8,
                            zorder=1,
                        )

                    if i > j and box_param[i] in grid_points:
                        ax_xmin, ax_xmax = ax.get_xlim()
                        x_span = ax_xmax - ax_xmin

                        ax.hlines(
                            grid_points[box_param[i]],
                            xmin=ax_xmin,
                            xmax=ax_xmin + 0.08 * x_span,
                            colors="cadetblue",
                            linestyles="solid",
                            linewidth=0.8,
                            zorder=1,
                        )

                        ax.hlines(
                            grid_points[box_param[i]],
                            xmin=ax_xmax - 0.08 * x_span,
                            xmax=ax_xmax,
                            colors="cadetblue",
                            linestyles="solid",
                            linewidth=0.8,
                            zorder=1,
                        )

    if title:
        fig.suptitle(title, y=1.02, fontsize=16)

    if output is None:
        plt.show()
    else:
        print(f"\nOutput: {output}")
        plt.savefig(output, bbox_inches="tight")

    return fig


def _param_column(parameters: List[str], name: str) -> np.ndarray:
    """
    Return the column index of ``name`` in ``parameters`` as a
    length-1 integer array. Indexing with it, ``samples[..., index]``,
    keeps a trailing axis of size 1, so the result can be passed to
    ``np.append(..., axis=-1)``. Raises ``IndexError`` if ``name``
    is not in ``parameters``.
    """

    return np.argwhere(np.array(parameters) == name)[0]


def _indexed_param_names(parameters: List[str], base_name: str) -> List[str]:
    """
    Return ``base_name`` (if present in ``parameters``), followed by
    ``base_name_0``, ``base_name_1``, ... up to the first index that
    is missing (checking at most 100 indices).
    """

    names = [base_name] if base_name in parameters else []

    for idx in range(100):
        indexed_name = f"{base_name}_{idx}"
        if indexed_name not in parameters:
            break
        names.append(indexed_name)

    return names


def _bol_luminosity(radius: np.ndarray, teff: np.ndarray) -> np.ndarray:
    """
    Stefan-Boltzmann luminosity, 4 pi (R R_J)^2 sigma T^4 / L_sun,
    for a radius in units of ``constants.R_JUP``. The result is in
    units of ``constants.L_SUN``.
    """

    return (
        4.0
        * np.pi
        * (radius * constants.R_JUP) ** 2
        * constants.SIGMA_SB
        * teff**4.0
        / constants.L_SUN
    )


def _bb_radius(lum: np.ndarray, teff: np.ndarray) -> np.ndarray:
    """
    Radius sqrt(L / (16 pi sigma T^4)) for a luminosity ``lum`` in
    units of ``constants.L_SUN`` and a temperature ``teff``. The
    result is in units of ``constants.R_JUP``.
    """

    radius = np.sqrt(
        lum * constants.L_SUN / (16.0 * np.pi * constants.SIGMA_SB * teff**4)
    )

    return radius / constants.R_JUP
[docs]
@beartype
def plot_mag_posterior(
    tag: str,
    filter_name: str,
    n_samples: Optional[int] = None,
    xlim: Optional[Tuple[Real, Real]] = None,
    output: Optional[str] = None,
) -> Tuple[np.ndarray, mpl.figure.Figure]:
    """
    Plot the 1D posterior distribution of a synthetic magnitude
    and return the magnitude samples that were plotted.

    Parameters
    ----------
    tag : str
        Database tag with the posterior samples.
    filter_name : str
        Filter name for which the magnitudes are requested.
    n_samples : int, None
        Number of randomly drawn samples. All samples of the
        posterior are selected if the argument is set to ``None``.
    xlim : tuple(float, float), None
        Axis limits. Set automatically if the argument is
        set to ``None``.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    np.ndarray
        Array with the posterior samples of the magnitude, as
        returned by ``Database.get_mcmc_photometry``.
    matplotlib.figure.Figure
        The ``Figure`` object that can be used for further
        customization of the plot.
    """

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["mathtext.fontset"] = "dejavuserif"

    if output is not None:
        progress_msg = f"Plotting photometry samples: {output}..."
    else:
        progress_msg = "Plotting photometry samples..."

    print(progress_msg, end="", flush=True)

    # Imported inside the function, as in the rest of this module
    from species.data.database import Database

    mag_samples = Database().get_mcmc_photometry(
        tag, filter_name, random=n_samples
    )

    fig = corner.corner(
        mag_samples,
        labels=["Magnitude"],
        quantiles=[0.16, 0.5, 0.84],
        label_kwargs={"fontsize": 13.0},
        show_titles=True,
        title_kwargs={"fontsize": 12.0},
        title_fmt=".2f",
    )

    # A single panel is expected for the 1D samples
    ax = np.array(fig.axes).reshape((1, 1))[0, 0]

    # Identical styling for the major and minor ticks, apart from the length
    for tick_which, tick_length in (("major", 5), ("minor", 3)):
        ax.tick_params(
            axis="both",
            which=tick_which,
            colors="black",
            labelcolor="black",
            direction="in",
            width=1,
            length=tick_length,
            labelsize=12,
            top=True,
            bottom=True,
            left=True,
            right=True,
        )

    if xlim is not None:
        ax.set_xlim(xlim)

    ax.get_xaxis().set_label_coords(0.5, -0.26)

    if output is not None:
        plt.savefig(output, bbox_inches="tight")
    else:
        plt.show()

    print(" [DONE]")

    return mag_samples, fig
[docs]
@beartype
def plot_size_distributions(
    tag: str,
    random: Optional[int] = None,
    offset: Optional[Tuple[Real, Real]] = None,
    output: Optional[str] = None,
) -> mpl.figure.Figure:
    """
    Function to plot random samples of the log-normal
    or power-law size distributions.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    random : int, None
        Number of randomly selected samples. The samples are
        drawn with replacement, so a sample can be selected more
        than once. All samples are used if the argument is set
        to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used
        if the argument is set to ``None``.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    matplotlib.figure.Figure
        The ``Figure`` object that can be used for further
        customization of the plot.

    Raises
    ------
    ValueError
        If the samples contain neither the ``lognorm_ext`` nor
        the ``powerlaw_ext`` parameter.
    """

    from species.data.database import Database

    species_db = Database()
    box = species_db.get_samples(tag)

    print_section("Plot size distributions")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["mathtext.fontset"] = "dejavuserif"

    # Evaluate the distribution type once instead of on every
    # iteration of the plotting loop below
    use_lognorm = "lognorm_ext" in box.parameters
    use_powerlaw = "powerlaw_ext" in box.parameters

    if not use_lognorm and not use_powerlaw:
        raise ValueError(
            "The SamplesBox does not contain extinction parameter "
            "for a log-normal or power-law size distribution."
        )

    samples = box.samples

    if random is not None:
        # np.random.randint draws indices with replacement
        ran_index = np.random.randint(samples.shape[0], size=random)
        samples = samples[ran_index,]

    print(f"Database tag: {tag}")
    print(f"Number of samples: {samples.shape[0]}")
    print(f"\nLabel offset: {offset}")

    if use_lognorm:
        log_r_g = samples[:, box.parameters.index("lognorm_radius")]
        sigma_g = samples[:, box.parameters.index("lognorm_sigma")]

    if use_powerlaw:
        r_max = samples[:, box.parameters.index("powerlaw_max")]
        exponent = samples[:, box.parameters.index("powerlaw_exp")]

    fig = plt.figure(figsize=(6, 3))
    gridsp = mpl.gridspec.GridSpec(1, 1)
    gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    ax = plt.subplot(gridsp[0, 0])

    ax.tick_params(
        axis="both",
        which="major",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=5,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.tick_params(
        axis="both",
        which="minor",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=3,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.set_xlabel("Grain size (µm)", fontsize=12)
    ax.set_ylabel("dn/dr", fontsize=12)

    ax.set_xscale("log")

    if use_powerlaw:
        ax.set_yscale("log")

    if offset is not None:
        ax.get_xaxis().set_label_coords(0.5, offset[0])
        ax.get_yaxis().set_label_coords(offset[1], 0.5)

    for i in range(samples.shape[0]):
        if use_lognorm:
            dn_grains, r_width, radii = log_normal_distribution(
                10.0 ** log_r_g[i], sigma_g[i], 1000
            )

            # Exclude radii smaller than 1 nm. flatnonzero returns
            # 1D indices so the arrays stay 1D (np.argwhere would
            # return (k, 1) indices and turn them into columns)
            keep = np.flatnonzero(radii >= 1e-3)

            dn_grains = dn_grains[keep]
            r_width = r_width[keep]
            radii = radii[keep]

        else:
            # Validated above: the power-law parameters are present
            dn_grains, r_width, radii = power_law_distribution(
                exponent[i], 1e-3, 10.0 ** r_max[i], 1000
            )

        ax.plot(radii, dn_grains / r_width, ls="-", lw=0.5, color="tab:gray", alpha=0.5)

    if output is None:
        plt.show()
    else:
        print(f"Output: {output}")
        plt.savefig(output, bbox_inches="tight")

    return fig
[docs]
@beartype
def plot_extinction(
    tag: str,
    random: Optional[int] = None,
    wavel_range: Optional[Tuple[Real, Real]] = None,
    xlim: Optional[Tuple[Real, Real]] = None,
    ylim: Optional[Tuple[Real, Real]] = None,
    offset: Optional[Tuple[Real, Real]] = None,
    output: Optional[str] = None,
) -> mpl.figure.Figure:
    """
    Function to plot random samples of the extinction.

    Parameters
    ----------
    tag : str
        Database tag with the samples.
    random : int, None
        Number of randomly selected samples (drawn with
        replacement). All samples are used if the argument
        is set to ``None``.
    wavel_range : tuple(float, float), None
        Wavelength range (um) for the extinction. The default
        wavelength range (0.4, 10.) is used if the argument is
        set to ``None``.
    xlim : tuple(float, float), None
        Limits of the wavelength axis. The range is set automatically
        if the argument is set to ``None``.
    ylim : tuple(float, float), None
        Limits of the extinction axis. The range is set automatically
        if the argument is set to ``None``.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default values are used
        if the argument is set to ``None``.
    output : str, None
        Output filename for the plot. The plot is shown in an
        interface window if the argument is set to ``None``.

    Returns
    -------
    matplotlib.figure.Figure
        The ``Figure`` object that can be used for further
        customization of the plot.

    Raises
    ------
    ValueError
        If the samples contain none of the extinction parameters
        ``lognorm_ext``, ``powerlaw_ext``, ``ext_av``, or ``ism_ext``.
    """

    from species.data.database import Database

    species_db = Database()
    samples_box = species_db.get_samples(tag)
    samples = samples_box.samples
    parameters = samples_box.parameters

    print_section("Plot extinction")

    plt.rcParams["font.family"] = "serif"
    plt.rcParams["mathtext.fontset"] = "dejavuserif"

    if random is not None:
        # np.random.randint draws indices with replacement
        ran_index = np.random.randint(samples.shape[0], size=random)
        samples = samples[ran_index,]

    if wavel_range is None:
        wavel_range = (0.4, 10.0)

    print(f"Database tag: {tag}")
    print(f"Number of samples: {samples.shape[0]}")
    print(f"\nWavelength range: {wavel_range}")
    print(f"Label offset: {offset}")

    # Validate before creating the figure so that no empty,
    # orphaned figure is left open when raising
    ext_params = ("lognorm_ext", "powerlaw_ext", "ext_av", "ism_ext")

    if not any(param in parameters for param in ext_params):
        raise ValueError("The SamplesBox does not contain extinction parameters.")

    fig = plt.figure(figsize=(6, 3))
    gridsp = mpl.gridspec.GridSpec(1, 1)
    gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    ax = plt.subplot(gridsp[0, 0])

    ax.tick_params(
        axis="both",
        which="major",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=5,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.tick_params(
        axis="both",
        which="minor",
        colors="black",
        labelcolor="black",
        direction="in",
        width=1,
        length=3,
        labelsize=12,
        top=True,
        bottom=True,
        left=True,
        right=True,
        labelbottom=True,
    )

    ax.set_xlabel("Wavelength (µm)", fontsize=12)
    ax.set_ylabel("Extinction (mag)", fontsize=12)

    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])

    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    if offset is not None:
        ax.get_xaxis().set_label_coords(0.5, offset[0])
        ax.get_yaxis().set_label_coords(offset[1], 0.5)

    sample_wavel = np.linspace(wavel_range[0], wavel_range[1], 100)

    # Same precedence as before: lognorm > powerlaw > ext_av > ism_ext
    if "lognorm_ext" in parameters:
        dust_interp, _, _ = interp_lognorm(verbose=False)

        ext_curves = _grain_extinction_curves(
            dust_interp,
            samples,
            sample_wavel,
            parameters.index("lognorm_radius"),
            parameters.index("lognorm_sigma"),
            parameters.index("lognorm_ext"),
        )

    elif "powerlaw_ext" in parameters:
        dust_interp, _, _ = interp_powerlaw(verbose=False)

        ext_curves = _grain_extinction_curves(
            dust_interp,
            samples,
            sample_wavel,
            parameters.index("powerlaw_max"),
            parameters.index("powerlaw_exp"),
            parameters.index("powerlaw_ext"),
        )

    elif "ext_av" in parameters:
        ext_curves = _ism_extinction_curves(
            samples, parameters, sample_wavel, "ext_av", "ext_rv"
        )

    else:
        # Validated above, so "ism_ext" is present here
        ext_curves = _ism_extinction_curves(
            samples, parameters, sample_wavel, "ism_ext", "ism_red"
        )

    for sample_ext in ext_curves:
        ax.plot(sample_wavel, sample_ext, ls="-", lw=0.5, color="black", alpha=0.5)

    if output is None:
        plt.show()
    else:
        print(f"Output: {output}")
        plt.savefig(output, bbox_inches="tight")

    return fig


def _grain_extinction_curves(
    dust_interp, samples, sample_wavel, size_index, shape_index, ext_index
):
    """
    Yield, per sample row, the extinction (mag) of a grain size
    distribution model, evaluated at ``sample_wavel``.

    ``samples[:, size_index]`` holds the log10 of the size parameter
    and ``samples[:, shape_index]`` the shape parameter. Both are passed
    with the wavelengths to ``dust_interp``. ``samples[:, ext_index]``
    scales the returned cross sections.
    """

    # -2.5 * log10(exp(-x)) == 2.5 * log10(e) * x. Evaluating the closed
    # form avoids exp() underflowing to 0 (and log10 returning -inf)
    # for large optical depths.
    mag_per_tau = 2.5 * np.log10(np.e)

    for sample in samples:
        cross_sections = dust_interp(
            (sample_wavel, 10.0 ** sample[size_index], sample[shape_index])
        )

        yield mag_per_tau * sample[ext_index] * cross_sections


def _ism_extinction_curves(samples, parameters, sample_wavel, av_key, rv_key):
    """
    Yield, per sample row, the ISM extinction (mag) computed with
    ``ism_extinction`` from the ``av_key`` and ``rv_key`` parameters.

    The default ISM reddening (R_V = 3.1) is used for all samples
    if ``rv_key`` was not fitted.
    """

    ext_av = samples[:, parameters.index(av_key)]

    if rv_key in parameters:
        ext_rv = samples[:, parameters.index(rv_key)]
    else:
        ext_rv = np.full(samples.shape[0], 3.1)

    for av_value, rv_value in zip(ext_av, ext_rv):
        yield ism_extinction(av_value, rv_value, sample_wavel)