Source code for species.plot.plot_mcmc

"""
Module for plotting MCMC results.
"""

import warnings

from typing import List, Optional, Tuple, Union

import h5py
import corner
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from typeguard import typechecked
from matplotlib.ticker import ScalarFormatter
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import norm

from species.core import constants
from species.util.convert_util import logg_to_mass
from species.util.core_util import print_section
from species.util.plot_util import update_labels
from species.util.dust_util import (
    check_dust_database,
    interp_lognorm,
    interp_powerlaw,
    ism_extinction,
    log_normal_distribution,
    power_law_distribution,
)
from species.util.retrieval_util import (
    atomic_masses,
    calc_metal_ratio,
    get_line_species,
    mass_fractions,
    mean_molecular_weight,
)


[docs] @typechecked def plot_walkers( tag: str, nsteps: Optional[int] = None, offset: Optional[Tuple[float, float]] = None, output: Optional[str] = None, ) -> mpl.figure.Figure: """ Function to plot the step history of the walkers. Parameters ---------- tag : str Database tag with the samples. nsteps : int, None Number of steps that are plotted. All steps are plotted if the argument is set to ``None``. offset : tuple(float, float), None Offset of the x- and y-axis label. Default values are used if the arguments is set to ``None``. output : str, None Output filename for the plot. The plot is shown in an interface window if the argument is set to ``None``. Returns ------- matplotlib.figure.Figure The ``Figure`` object that can be used for further customization of the plot. """ from species.data.database import Database species_db = Database() box = species_db.get_samples(tag) if output is None: print("Plotting walkers...", end="", flush=True) else: print(f"Plotting walkers: {output}...", end="", flush=True) plt.rcParams["font.family"] = "serif" plt.rcParams["mathtext.fontset"] = "dejavuserif" samples = box.samples labels = update_labels(box.parameters) if samples.ndim == 2: raise ValueError( f"The samples of '{tag}' have only 2 dimensions " f"whereas 3 are required for plotting the walkers. " f"The plot_walkers function can only be used after " "running the MCMC with run_mcmc and not after " f"running run_ultranest or run_multinest." ) ndim = samples.shape[-1] fig = plt.figure(figsize=(6, ndim * 1.5)) gridsp = mpl.gridspec.GridSpec(ndim, 1) gridsp.update(wspace=0, hspace=0.1, left=0, right=1, bottom=0, top=1) for i in range(ndim): ax = plt.subplot(gridsp[i, 0]) if i == ndim - 1: ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) else: ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=False, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=False, ) if i == ndim - 1: ax.set_xlabel("Step number", fontsize=10) else: ax.set_xlabel("", fontsize=10) ax.set_ylabel(labels[i], fontsize=10) if offset is not None: ax.get_xaxis().set_label_coords(0.5, offset[0]) ax.get_yaxis().set_label_coords(offset[1], 0.5) else: ax.get_xaxis().set_label_coords(0.5, -0.22) ax.get_yaxis().set_label_coords(-0.09, 0.5) if nsteps is not None: ax.set_xlim(0, nsteps) for j in range(samples.shape[0]): ax.plot(samples[j, :, i], ls="-", lw=0.5, color="black", alpha=0.5) if output is None: plt.show() else: plt.savefig(output, bbox_inches="tight") print(" [DONE]") return fig
[docs] @typechecked def plot_posterior( tag: str, burnin: Optional[int] = None, title: Optional[str] = None, offset: Optional[Tuple[float, float]] = None, title_fmt: Union[str, List[str]] = ".2f", limits: Optional[List[Tuple[float, float]]] = None, max_prob: bool = False, vmr: bool = False, inc_luminosity: bool = False, inc_mass: bool = False, inc_log_mass: bool = False, inc_pt_param: bool = False, inc_loglike: bool = False, inc_abund: bool = True, output: Optional[str] = None, object_type: str = "planet", param_inc: Optional[List[str]] = None, show_priors: bool = False, ) -> mpl.figure.Figure: """ Function to plot the posterior distribution of the fitted parameters. Parameters ---------- tag : str Database tag with the samples. burnin : int, None Number of burnin steps to exclude. All samples are used if the argument is set to ``None``. title : str, None Plot title. No title is shown if the arguments is set to ``None``. offset : tuple(float, float), None Offset of the x- and y-axis label. Default values are used if the arguments is set to ``None``. title_fmt : str, list(str) Format of the titles above the 1D distributions. Either a single string, which will be used for all parameters, or a list with the title format for each parameter separately (in the order as shown in the corner plot). limits : list(tuple(float, float), ), None Axis limits of all parameters. Automatically set if the argument is set to ``None``. max_prob : bool Plot the position of the sample with the maximum posterior probability. vmr : bool Plot the volume mixing ratios (i.e. number fractions) instead of the mass fractions of the retrieved species with :class:`~species.fit.retrieval.AtmosphericRetrieval`. inc_luminosity : bool Include the log10 of the luminosity in the posterior plot as calculated from the effective temperature and radius. inc_mass : bool Include the mass in the posterior plot as calculated from the surface gravity and radius. inc_log_mass : bool Include the logarithm of the mass, :math:`\\log10{M}`, in the posterior plot, as calculated from the surface gravity and radius. inc_pt_param : bool Include the parameters of the pressure-temperature profile. Only used if the ``tag`` contains samples obtained with :class:`~species.fit.retrieval.AtmosphericRetrieval`. inc_loglike : bool Include the log10 of the likelihood as additional parameter in the corner plot. inc_abund : bool Include the abundances when retrieving free abundances with :class:`~species.fit.retrieval.AtmosphericRetrieval`. output : str, None Output filename for the plot. The plot is shown in an interface window if the argument is set to ``None``. object_type : str Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in Jupiter units. With 'star', the radius and mass are expressed in solar units. param_inc : list(str), None List with subset of parameters that will be included in the posterior plot. This parameter can also be used to change the order of the parameters in the posterior plot. All parameters will be included if the argument is set to ``None``. show_priors : bool Plot the normal priors in the diagonal panels together with the 1D marginalized posterior distributions. This will only show the priors that had a normal distribution, so those that were set with the ``normal_prior`` parameter in :class:`~species.fit.fit_model.FitModel` and :class:`~species.fit.retrieval.AtmosphericRetrieval.setup_retrieval`. Returns ------- matplotlib.figure.Figure The ``Figure`` object that can be used for further customization of the plot. """ from species.data.database import Database species_db = Database() box = species_db.get_samples(tag, burnin=burnin) samples = box.samples print_section("Plot posterior distributions") print(f"Database tag: {tag}") print(f"Object type: {object_type}") print(f"Manual parameters: {param_inc}") plt.rcParams["font.family"] = "serif" plt.rcParams["mathtext.fontset"] = "dejavuserif" if burnin is None: burnin = 0 # index_sel = [0, 1, 8, 9, 14] # samples = samples[:, index_sel] # # for i in range(13, 9, -1): # del box.parameters[i] # # del box.parameters[2] # del box.parameters[2] # del box.parameters[2] # del box.parameters[2] # del box.parameters[2] # del box.parameters[2] ndim = len(box.parameters) if not inc_pt_param and box.spectrum == "petitradtrans": pt_param = [ "tint", "t1", "t2", "t3", "alpha", "log_delta", "T_bottom", "PTslope_1", "PTslope_2", "PTslope_3", "PTslope_4", "PTslope_5", "PTslope_6", ] index_del = [] item_del = [] for i in range(100): pt_item = f"t{i}" if pt_item in box.parameters: param_index = np.argwhere(np.array(box.parameters) == pt_item)[0] index_del.append(param_index) item_del.append(pt_item) else: break for item in pt_param: if item in box.parameters and item not in item_del: param_index = np.argwhere(np.array(box.parameters) == item)[0] index_del.append(param_index) item_del.append(item) samples = np.delete(samples, index_del, axis=1) ndim -= len(index_del) for item in item_del: box.parameters.remove(item) if box.spectrum == "petitradtrans": n_line_species = box.attributes["n_line_species"] line_species = [] for i in range(n_line_species): line_species.append(box.attributes[f"line_species{i}"]) if "abund_nodes" not in box.attributes: box.attributes["abund_nodes"] = "None" if box.spectrum == "petitradtrans" and box.attributes["chemistry"] == "free": if box.attributes["abund_nodes"] == "None": box.parameters.append("c_h_ratio") box.parameters.append("o_h_ratio") box.parameters.append("c_o_ratio") ndim += 3 abund_index = {} for line_item in line_species: abund_index[line_item] = box.parameters.index(line_item) c_h_ratio = np.zeros(samples.shape[0]) o_h_ratio = np.zeros(samples.shape[0]) c_o_ratio = np.zeros(samples.shape[0]) for i, sample_item in enumerate(samples): abund_dict = {} for line_item in line_species: abund_dict[line_item] = sample_item[abund_index[line_item]] ( c_h_ratio[i], o_h_ratio[i], c_o_ratio[i], ) = calc_metal_ratio( abund_dict, line_species, ) if ( vmr and box.spectrum == "petitradtrans" and box.attributes["chemistry"] == "free" ): print("Changing mass fractions to number fractions...", end="", flush=True) # Get all available line species all_line_species = get_line_species() # Get the atomic and molecular masses masses = atomic_masses() # Create array for the updated samples updated_samples = np.zeros(samples.shape) for i, samples_item in enumerate(samples): # Initiate a dictionary for the log10 mass fraction of the metals log_x_abund = {} for param_item in box.parameters: if param_item in all_line_species: # Get the index of the parameter param_index = box.parameters.index(param_item) # Store log10 mass fraction in the dictionary log_x_abund[param_item] = samples_item[param_index] # Create a dictionary with all mass fractions, including H2 and He x_abund = mass_fractions(log_x_abund, line_species) # Calculate the mean molecular weight from the input mass fractions mmw = mean_molecular_weight(x_abund) for param_item in box.parameters: if param_item in all_line_species: # Get the index of the parameter param_index = box.parameters.index(param_item) # Overwrite the sample with the log10 number fraction samples_item[param_index] = np.log10( 10.0 ** samples_item[param_index] * mmw / masses[param_item] ) # Store the updated sample to the array updated_samples[i,] = samples_item # Overwrite the samples in the SamplesBox box.samples = updated_samples print("\nMedian parameters:") for key, value in box.median_sample.items(): print(f" - {key} = {value:.2e}") if "gauss_mean" in box.parameters: param_index = np.argwhere(np.array(box.parameters) == "gauss_mean")[0] samples[:, param_index] *= 1e3 # (um) -> (nm) if "gauss_sigma" in box.parameters: param_index = np.argwhere(np.array(box.parameters) == "gauss_sigma")[0] samples[:, param_index] *= 1e3 # (um) -> (nm) if box.prob_sample is not None: print("\nSample with highest probability:") for key, value in box.prob_sample.items(): print(f" - {key} = {value:.2e}") for item in box.parameters: if item[0:11] == "wavelength_": param_index = box.parameters.index(item) # (um) -> (nm) box.samples[:, param_index] *= 1e3 # Add [C/H], [O/H], and C/O if free abundances were retrieved if box.attributes["abund_nodes"] == "None": for param_item in box.parameters: if param_item.split("_")[0] == "H2O": samples = np.column_stack((samples, c_h_ratio, o_h_ratio, c_o_ratio)) break # Include the derived bolometric luminosity if inc_luminosity: if "teff" in box.parameters and "radius" in box.parameters: teff_index = np.argwhere(np.array(box.parameters) == "teff")[0] radius_index = np.argwhere(np.array(box.parameters) == "radius")[0] lum_planet = ( 4.0 * np.pi * (samples[..., radius_index] * constants.R_JUP) ** 2 * constants.SIGMA_SB * samples[..., teff_index] ** 4.0 / constants.L_SUN ) n_disk = 0 if "disk_teff" in box.parameters and "disk_radius" in box.parameters: n_disk = 1 else: for disk_idx in range(100): if f"disk_teff_{disk_idx}" in box.parameters and f"disk_radius_{disk_idx}" in box.parameters: n_disk += 1 else: break if n_disk == 1: teff_index = np.argwhere(np.array(box.parameters) == "disk_teff")[0] radius_index = np.argwhere(np.array(box.parameters) == "disk_radius")[0] lum_disk = ( 4.0 * np.pi * (samples[..., radius_index] * constants.R_JUP) ** 2 * constants.SIGMA_SB * samples[..., teff_index] ** 4.0 / constants.L_SUN ) samples = np.append(samples, np.log10(lum_planet + lum_disk), axis=-1) box.parameters.append("luminosity") ndim += 1 samples = np.append(samples, lum_disk / lum_planet, axis=-1) box.parameters.append("luminosity_disk_planet") ndim += 1 radius_bb = np.sqrt( lum_planet * constants.L_SUN / ( 16.0 * np.pi * constants.SIGMA_SB * samples[..., teff_index] ** 4 ) ) samples = np.append(samples, radius_bb / constants.R_JUP, axis=-1) box.parameters.append("radius_bb") ndim += 1 elif n_disk > 1: lum_disk = 0.0 for disk_idx in range(n_disk): teff_index = np.argwhere(np.array(box.parameters) == f"disk_teff_{disk_idx}")[0] radius_index = np.argwhere(np.array(box.parameters) == f"disk_radius_{disk_idx}")[0] lum_disk += ( 4.0 * np.pi * (samples[..., radius_index] * constants.R_JUP) ** 2 * constants.SIGMA_SB * samples[..., teff_index] ** 4.0 / constants.L_SUN ) radius_bb = np.sqrt( lum_planet * constants.L_SUN / ( 16.0 * np.pi * constants.SIGMA_SB * samples[..., teff_index] ** 4 ) ) samples = np.append(samples, radius_bb / constants.R_JUP, axis=-1) box.parameters.append(f"radius_bb_{disk_idx}") ndim += 1 samples = np.append(samples, np.log10(lum_planet + lum_disk), axis=-1) box.parameters.append("luminosity") ndim += 1 samples = np.append(samples, lum_disk / lum_planet, axis=-1) box.parameters.append("luminosity_disk_planet") ndim += 1 else: samples = np.append(samples, np.log10(lum_planet), axis=-1) box.parameters.append("luminosity") ndim += 1 elif "teff_0" in box.parameters and "radius_0" in box.parameters: luminosity = 0.0 for i in range(100): teff_index = np.argwhere(np.array(box.parameters) == f"teff_{i}") radius_index = np.argwhere(np.array(box.parameters) == f"radius_{i}") if len(teff_index) > 0 and len(radius_index) > 0: luminosity += ( 4.0 * np.pi * (samples[..., radius_index[0]] * constants.R_JUP) ** 2 * constants.SIGMA_SB * samples[..., teff_index[0]] ** 4.0 / constants.L_SUN ) else: break samples = np.append(samples, np.log10(luminosity), axis=-1) box.parameters.append("luminosity") ndim += 1 # teff_index = np.argwhere(np.array(box.parameters) == 'teff_0') # radius_index = np.argwhere(np.array(box.parameters) == 'radius_0') # # luminosity_0 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \ # * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN # # samples = np.append(samples, np.log10(luminosity_0), axis=-1) # box.parameters.append('luminosity_0') # ndim += 1 # # teff_index = np.argwhere(np.array(box.parameters) == 'teff_1') # radius_index = np.argwhere(np.array(box.parameters) == 'radius_1') # # luminosity_1 = 4. * np.pi * (samples[..., radius_index[0]]*constants.R_JUP)**2 \ # * constants.SIGMA_SB * samples[..., teff_index[0]]**4. / constants.L_SUN # # samples = np.append(samples, np.log10(luminosity_1), axis=-1) # box.parameters.append('luminosity_1') # ndim += 1 # # teff_index_0 = np.argwhere(np.array(box.parameters) == 'teff_0') # radius_index_0 = np.argwhere(np.array(box.parameters) == 'radius_0') # # teff_index_1 = np.argwhere(np.array(box.parameters) == 'teff_1') # radius_index_1 = np.argwhere(np.array(box.parameters) == 'radius_1') # # luminosity_0 = 4. * np.pi * (samples[..., radius_index_0[0]]*constants.R_JUP)**2 \ # * constants.SIGMA_SB * samples[..., teff_index_0[0]]**4. / constants.L_SUN # # luminosity_1 = 4. * np.pi * (samples[..., radius_index_1[0]]*constants.R_JUP)**2 \ # * constants.SIGMA_SB * samples[..., teff_index_1[0]]**4. / constants.L_SUN # # samples = np.append(samples, np.log10(luminosity_0/luminosity_1), axis=-1) # box.parameters.append('luminosity_ratio') # ndim += 1 # r_tmp = samples[..., radius_index_0[0]]*constants.R_JUP # lum_diff = (luminosity_1*constants.L_SUN-luminosity_0*constants.L_SUN) # # m_mdot = (3600.*24.*365.25)*lum_diff*r_tmp/constants.GRAVITY/constants.M_JUP**2 # # samples = np.append(samples, m_mdot, axis=-1) # box.parameters.append('m_mdot') # ndim += 1 # Include the derived mass if inc_mass: if "logg" in box.parameters and "radius" in box.parameters: logg_index = np.argwhere(np.array(box.parameters) == "logg")[0] radius_index = np.argwhere(np.array(box.parameters) == "radius")[0] mass_samples = logg_to_mass( samples[..., logg_index], samples[..., radius_index] ) samples = np.append(samples, mass_samples, axis=-1) box.parameters.append("mass") ndim += 1 else: warnings.warn( "Samples with the log(g) and radius are required for 'inc_mass=True'." ) if inc_log_mass: if "logg" in box.parameters and "radius" in box.parameters: logg_index = np.argwhere(np.array(box.parameters) == "logg")[0] radius_index = np.argwhere(np.array(box.parameters) == "radius")[0] mass_samples = logg_to_mass( samples[..., logg_index], samples[..., radius_index] ) mass_samples = np.log10(mass_samples) samples = np.append(samples, mass_samples, axis=-1) box.parameters.append("log_mass") ndim += 1 else: warnings.warn( "Samples with the log(g) and radius are required for 'inc_log_mass=True'." ) # Change from Jupiter to solar units if star if "radius" in box.parameters: radius_index = np.argwhere(np.array(box.parameters) == "radius")[0] if object_type == "star": samples[:, radius_index] *= constants.R_JUP / constants.R_SUN if "mass" in box.parameters: mass_index = np.argwhere(np.array(box.parameters) == "mass")[0] if object_type == "star": samples[:, mass_index] *= constants.M_JUP / constants.M_SUN if "log_mass" in box.parameters: mass_index = np.argwhere(np.array(box.parameters) == "log_mass")[0] if object_type == "star": samples[:, mass_index] = np.log10( 10.0 ** samples[:, mass_index] * constants.M_JUP / constants.M_SUN ) if "disk_radius" in box.parameters: radius_index = np.argwhere(np.array(box.parameters) == "disk_radius")[0] if object_type == "star": samples[:, radius_index] *= constants.R_JUP / constants.AU for disk_idx in range(100): if f"disk_radius_{disk_idx}" in box.parameters: radius_index = np.argwhere(np.array(box.parameters) == f"disk_radius_{disk_idx}")[0] if object_type == "star": samples[:, radius_index] *= constants.R_JUP / constants.AU else: break if "radius_bb" in box.parameters: radius_index = np.argwhere(np.array(box.parameters) == "radius_bb")[0] if object_type == "star": samples[:, radius_index] *= constants.R_JUP / constants.AU for disk_idx in range(100): if f"radius_bb_{disk_idx}" in box.parameters: radius_index = np.argwhere(np.array(box.parameters) == f"radius_bb_{disk_idx}")[0] if object_type == "star": samples[:, radius_index] *= constants.R_JUP / constants.AU else: break # Include the log-likelihood if inc_loglike: # Get ln(L) of the samples ln_prob = box.ln_prob[..., np.newaxis] # Normalized by the maximum ln(L) ln_prob -= np.amax(ln_prob) # Convert ln(L) to log10(L) log_prob = ln_prob * np.exp(1.0) # Convert log10(L) to L prob = 10.0**log_prob # Normalize to an integrated probability of 1 prob /= np.sum(prob) samples = np.append(samples, np.log10(prob), axis=-1) box.parameters.append("log_prob") ndim += 1 # Remove abundances if ( not inc_abund and "chemistry" in box.attributes and box.attributes["chemistry"] == "free" ): index_del = [] item_del = [] if box.attributes["abund_nodes"] == "None": for line_item in line_species: param_index = np.argwhere(np.array(box.parameters) == line_item)[0] index_del.append(param_index) item_del.append(line_item) else: for line_item in line_species: for node_idx in range(box.attributes["abund_nodes"]): param_index = np.argwhere( np.array(box.parameters) == f"{line_item}_{node_idx}" )[0] index_del.append(param_index) item_del.append(f"{line_item}_{node_idx}") samples = np.delete(samples, index_del, axis=1) ndim -= len(index_del) for item in item_del: box.parameters.remove(item) # Include a subset of parameters if param_inc is not None: param_new = np.zeros((samples.shape[0], len(param_inc))) for i, item in enumerate(param_inc): if item in box.parameters: param_index = box.parameters.index(item) param_new[:, i] = samples[:, param_index] box.parameters = param_inc ndim = len(param_inc) samples = param_new # Update axes labels box_param = box.parameters.copy() labels = update_labels(box.parameters, object_type=object_type) # Check if parameter values were fixed index_sel = [] index_del = [] for i in range(ndim): if np.amin(samples[:, i]) == np.amax(samples[:, i]): index_del.append(i) else: index_sel.append(i) samples = samples[:, index_sel] for i in range(len(index_del) - 1, -1, -1): del labels[index_del[i]] ndim -= len(index_del) samples = samples.reshape((-1, ndim)) # Get parameter values of maximum likelihood if max_prob: max_idx = np.argmax(box.ln_prob) max_sample = samples[max_idx, :] if isinstance(title_fmt, list) and len(title_fmt) != ndim: raise ValueError( f"The number of items in the list of 'title_fmt' ({len(title_fmt)}) is " f"not equal to the number of dimensions of the samples ({ndim})." ) hist_titles = [] for i, item in enumerate(labels): unit_start = item.find("(") if unit_start == -1: param_label = item unit_label = None else: param_label = item[:unit_start] # Remove parenthesis from the units unit_label = item[unit_start + 1 : -1] q_16, q_50, q_84 = corner.quantile(samples[:, i], [0.16, 0.5, 0.84]) q_minus, q_plus = q_50 - q_16, q_84 - q_50 if isinstance(title_fmt, str): fmt = "{{0:{0}}}".format(title_fmt).format elif isinstance(title_fmt, list): fmt = "{{0:{0}}}".format(title_fmt[i]).format best_fit = r"${{{0}}}_{{-{1}}}^{{+{2}}}$" best_fit = best_fit.format(fmt(q_50), fmt(q_minus), fmt(q_plus)) if unit_label is None: hist_title = f"{param_label} = {best_fit}" else: hist_title = f"{param_label} = {best_fit} {unit_label}" hist_titles.append(hist_title) fig = corner.corner( samples, quantiles=[0.16, 0.5, 0.84], labels=labels, label_kwargs={"fontsize": 13}, titles=hist_titles, show_titles=True, title_fmt=None, title_kwargs={"fontsize": 12}, ) axes = np.array(fig.axes).reshape((ndim, ndim)) for i in range(ndim): for j in range(ndim): ax = axes[i, j] if show_priors and i == j and box_param[i] in box.normal_priors: norm_param = box.normal_priors[box_param[i]] x_norm = np.linspace( norm_param[0] - 5.0 * norm_param[1], norm_param[0] + 5.0 * norm_param[1], 200, ) y_norm = norm.pdf(x_norm, norm_param[0], norm_param[1]) ax.plot( x_norm, 0.9 * ax.get_ylim()[1] * y_norm / np.amax(y_norm), ls=":", lw=2.0, color="dodgerblue", ) if i >= j: ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False)) ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False)) labelleft = j == 0 and i != 0 labelbottom = i == ndim - 1 ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelleft=labelleft, labelbottom=labelbottom, labelright=False, labeltop=False, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelleft=labelleft, labelbottom=labelbottom, labelright=False, labeltop=False, ) if limits is not None: ax.set_xlim(limits[j]) if max_prob: ax.axvline(max_sample[j], color="tomato") if i > j: if max_prob: ax.axhline(max_sample[i], color="tomato") ax.plot(max_sample[j], max_sample[i], "s", color="tomato") if limits is not None: ax.set_ylim(limits[i]) if offset is not None: ax.get_xaxis().set_label_coords(0.5, offset[0]) ax.get_yaxis().set_label_coords(offset[1], 0.5) else: ax.get_xaxis().set_label_coords(0.5, -0.26) ax.get_yaxis().set_label_coords(-0.27, 0.5) if title: fig.suptitle(title, y=1.02, fontsize=16) if output is not None: print(f"\nOutput: {output}") if output is None: plt.show() else: plt.savefig(output, bbox_inches="tight") return fig
[docs] @typechecked def plot_mag_posterior( tag: str, filter_name: str, burnin: Optional[int] = None, xlim: Optional[Tuple[float, float]] = None, output: Optional[str] = None, ) -> Tuple[np.ndarray, mpl.figure.Figure]: """ Function to plot the posterior distribution of the synthetic magnitudes. The posterior samples are also returned. Parameters ---------- tag : str Database tag with the posterior samples. filter_name : str Filter name. burnin : int, None Number of burnin steps to exclude. All samples are used if the argument is set to ``None``. xlim : tuple(float, float), None Axis limits. Automatically set if the argument is set to ``None``. output : str, None Output filename for the plot. The plot is shown in an interface window if the argument is set to ``None``. Returns ------- np.ndarray Array with the posterior samples of the magnitude. matplotlib.figure.Figure The ``Figure`` object that can be used for further customization of the plot. """ plt.rcParams["font.family"] = "serif" plt.rcParams["mathtext.fontset"] = "dejavuserif" from species.data.database import Database species_db = Database() samples = species_db.get_mcmc_photometry(tag, filter_name, burnin) if output is None: print("Plotting photometry samples...", end="", flush=True) else: print(f"Plotting photometry samples: {output}...", end="", flush=True) fig = corner.corner( samples, labels=["Magnitude"], quantiles=[0.16, 0.5, 0.84], label_kwargs={"fontsize": 13.0}, show_titles=True, title_kwargs={"fontsize": 12.0}, title_fmt=".2f", ) axes = np.array(fig.axes).reshape((1, 1)) ax = axes[0, 0] ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, ) if xlim is not None: ax.set_xlim(xlim) ax.get_xaxis().set_label_coords(0.5, -0.26) if output is None: plt.show() else: plt.savefig(output, bbox_inches="tight") print(" [DONE]") return samples, fig
[docs] @typechecked def plot_size_distributions( tag: str, burnin: Optional[int] = None, random: Optional[int] = None, offset: Optional[Tuple[float, float]] = None, output: Optional[str] = None, ) -> mpl.figure.Figure: """ Function to plot random samples of the log-normal or power-law size distributions. Parameters ---------- tag : str Database tag with the samples. burnin : int, None Number of burnin steps to exclude. All samples are used if the argument is set to ``None``. Only required after running MCMC with :func:`~species.fit.fit_model.FitModel.run_mcmc`. random : int, None Number of randomly selected samples. All samples are used if the argument set to ``None``. offset : tuple(float, float), None Offset of the x- and y-axis label. Default values are used if the argument set to ``None``. output : str, None Output filename for the plot. The plot is shown in an interface window if the argument is set to ``None``. Returns ------- matplotlib.figure.Figure The ``Figure`` object that can be used for further customization of the plot. """ from species.data.database import Database species_db = Database() box = species_db.get_samples(tag) if output is None: print("Plotting size distributions...", end="", flush=True) else: print(f"Plotting size distributions: {output}...", end="", flush=True) if burnin is None: burnin = 0 plt.rcParams["font.family"] = "serif" plt.rcParams["mathtext.fontset"] = "dejavuserif" if "lognorm_radius" not in box.parameters and "powerlaw_max" not in box.parameters: raise ValueError( "The SamplesBox does not contain extinction parameter " "for a log-normal or power-law size distribution." ) samples = box.samples if samples.ndim == 2 and random is not None: ran_index = np.random.randint(samples.shape[0], size=random) samples = samples[ran_index,] elif samples.ndim == 3: if burnin > samples.shape[1]: raise ValueError( f"The 'burnin' value is larger than the number of steps " f"({samples.shape[1]}) that are made by the walkers." ) samples = samples[:, burnin:, :] ran_walker = np.random.randint(samples.shape[0], size=random) ran_step = np.random.randint(samples.shape[1], size=random) samples = samples[ran_walker, ran_step, :] if "lognorm_radius" in box.parameters: log_r_index = box.parameters.index("lognorm_radius") sigma_index = box.parameters.index("lognorm_sigma") log_r_g = samples[:, log_r_index] sigma_g = samples[:, sigma_index] if "powerlaw_max" in box.parameters: r_max_index = box.parameters.index("powerlaw_max") exponent_index = box.parameters.index("powerlaw_exp") r_max = samples[:, r_max_index] exponent = samples[:, exponent_index] fig = plt.figure(figsize=(6, 3)) gridsp = mpl.gridspec.GridSpec(1, 1) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax = plt.subplot(gridsp[0, 0]) ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) ax.set_xlabel("Grain size (µm)", fontsize=12) ax.set_ylabel("dn/dr", fontsize=12) ax.set_xscale("log") if "powerlaw_max" in box.parameters: ax.set_yscale("log") if offset is not None: ax.get_xaxis().set_label_coords(0.5, offset[0]) ax.get_yaxis().set_label_coords(offset[1], 0.5) else: ax.get_xaxis().set_label_coords(0.5, -0.22) ax.get_yaxis().set_label_coords(-0.09, 0.5) for i in range(samples.shape[0]): if "lognorm_radius" in box.parameters: dn_grains, r_width, radii = log_normal_distribution( 10.0 ** log_r_g[i], sigma_g[i], 1000 ) # Exclude radii smaller than 1 nm indices = np.argwhere(radii >= 1e-3) dn_grains = dn_grains[indices] r_width = r_width[indices] radii = radii[indices] elif "powerlaw_max" in box.parameters: dn_grains, r_width, radii = power_law_distribution( exponent[i], 1e-3, 10.0 ** r_max[i], 1000 ) ax.plot(radii, dn_grains / r_width, ls="-", lw=0.5, color="black", alpha=0.5) if output is None: plt.show() else: plt.savefig(output, bbox_inches="tight") print(" [DONE]") return fig
[docs] @typechecked def plot_extinction( tag: str, burnin: Optional[int] = None, random: Optional[int] = None, wavel_range: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None, offset: Optional[Tuple[float, float]] = None, output: Optional[str] = None, ) -> mpl.figure.Figure: """ Function to plot random samples of the extinction, either from fitting a size distribution of enstatite grains (``dust_radius``, ``dust_sigma``, and ``dust_ext``), or from fitting ISM extinction (``ism_ext`` and optionally ``ism_red``). Parameters ---------- tag : str Database tag with the samples. burnin : int, None Number of burnin steps to exclude. All samples are used if the argument is set to ``None``. Only required after running MCMC with :func:`~species.fit.fit_model.FitModel.run_mcmc`. random : int, None Number of randomly selected samples. All samples are used if the argument is set to ``None``. wavel_range : tuple(float, float), None Wavelength range (um) for the extinction. The default wavelength range (0.4, 10.) is used if the argument is set to ``None``. xlim : tuple(float, float), None Limits of the wavelength axis. The range is set automatically if the argument is set to ``None``. ylim : tuple(float, float) Limits of the extinction axis. The range is set automatically if the argument is set to ``None``. offset : tuple(float, float), None Offset of the x- and y-axis label. Default values are used if the argument is set to ``None``. output : str, None Output filename for the plot. The plot is shown in an interface window if the argument is set to ``None``. Returns ------- matplotlib.figure.Figure The ``Figure`` object that can be used for further customization of the plot. """ from species.data.database import Database species_db = Database() box = species_db.get_samples(tag) if burnin is None: burnin = 0 if wavel_range is None: wavel_range = (0.4, 10.0) plt.rcParams["font.family"] = "serif" plt.rcParams["mathtext.fontset"] = "dejavuserif" samples = box.samples if samples.ndim == 2 and random is not None: ran_index = np.random.randint(samples.shape[0], size=random) samples = samples[ran_index,] elif samples.ndim == 3: if burnin > samples.shape[1]: raise ValueError( f"The 'burnin' value is larger than the number of steps " f"({samples.shape[1]}) that are made by the walkers." ) samples = samples[:, burnin:, :] ran_walker = np.random.randint(samples.shape[0], size=random) ran_step = np.random.randint(samples.shape[1], size=random) samples = samples[ran_walker, ran_step, :] fig = plt.figure(figsize=(6, 3)) gridsp = mpl.gridspec.GridSpec(1, 1) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax = plt.subplot(gridsp[0, 0]) ax.tick_params( axis="both", which="major", colors="black", labelcolor="black", direction="in", width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) ax.tick_params( axis="both", which="minor", colors="black", labelcolor="black", direction="in", width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=True, ) ax.set_xlabel("Wavelength (\N{GREEK SMALL LETTER MU}m)", fontsize=12) ax.set_ylabel("Extinction (mag)", fontsize=12) if xlim is not None: ax.set_xlim(xlim[0], xlim[1]) if ylim is not None: ax.set_ylim(ylim[0], ylim[1]) if offset is not None: ax.get_xaxis().set_label_coords(0.5, offset[0]) ax.get_yaxis().set_label_coords(offset[1], 0.5) else: ax.get_xaxis().set_label_coords(0.5, -0.22) ax.get_yaxis().set_label_coords(-0.09, 0.5) sample_wavel = np.linspace(wavel_range[0], wavel_range[1], 100) if ( "lognorm_radius" in box.parameters and "lognorm_sigma" in box.parameters and "lognorm_ext" in box.parameters ): cross_optical, dust_radius, dust_sigma = interp_lognorm([], []) log_r_index = box.parameters.index("lognorm_radius") sigma_index = box.parameters.index("lognorm_sigma") ext_index = box.parameters.index("lognorm_ext") log_r_g = samples[:, log_r_index] sigma_g = samples[:, sigma_index] dust_ext = samples[:, ext_index] database_path = check_dust_database() with h5py.File(database_path, "r") as h5_file: cross_section = np.asarray( h5_file["dust/lognorm/mgsio3/crystalline/cross_section"] ) wavelength = np.asarray( h5_file["dust/lognorm/mgsio3/crystalline/wavelength"] ) cross_interp = RegularGridInterpolator( (wavelength, dust_radius, dust_sigma), cross_section ) for i in range(samples.shape[0]): cross_tmp = cross_optical["Generic/Bessell.V"]( (10.0 ** log_r_g[i], sigma_g[i]) ) n_grains = dust_ext[i] / cross_tmp / 2.5 / np.log10(np.exp(1.0)) sample_cross = np.zeros(sample_wavel.shape) for j, item in enumerate(sample_wavel): sample_cross[j] = cross_interp((item, 10.0 ** log_r_g[i], sigma_g[i])) sample_ext = 2.5 * np.log10(np.exp(1.0)) * sample_cross * n_grains ax.plot(sample_wavel, sample_ext, ls="-", lw=0.5, color="black", alpha=0.5) elif ( "powerlaw_max" in box.parameters and "powerlaw_exp" in box.parameters and "powerlaw_ext" in box.parameters ): cross_optical, dust_max, dust_exp = interp_powerlaw([], []) r_max_index = box.parameters.index("powerlaw_max") exp_index = box.parameters.index("powerlaw_exp") ext_index = box.parameters.index("powerlaw_ext") r_max = samples[:, r_max_index] exponent = samples[:, exp_index] dust_ext = samples[:, ext_index] database_path = check_dust_database() with h5py.File(database_path, "r") as h5_file: cross_section = np.asarray( h5_file["dust/powerlaw/mgsio3/crystalline/cross_section"] ) wavelength = np.asarray( h5_file["dust/powerlaw/mgsio3/crystalline/wavelength"] ) cross_interp = RegularGridInterpolator( (wavelength, dust_max, dust_exp), cross_section ) for i in range(samples.shape[0]): cross_tmp = cross_optical["Generic/Bessell.V"]( (10.0 ** r_max[i], exponent[i]) ) n_grains = dust_ext[i] / cross_tmp / 2.5 / np.log10(np.exp(1.0)) sample_cross = np.zeros(sample_wavel.shape) for j, item in enumerate(sample_wavel): sample_cross[j] = cross_interp((item, 10.0 ** r_max[i], exponent[i])) sample_ext = 2.5 * np.log10(np.exp(1.0)) * sample_cross * n_grains ax.plot(sample_wavel, sample_ext, ls="-", lw=0.5, color="black", alpha=0.5) elif "ism_ext" in box.parameters: ext_index = box.parameters.index("ism_ext") ism_ext = samples[:, ext_index] if "ism_red" in box.parameters: red_index = box.parameters.index("ism_red") ism_red = samples[:, red_index] else: # Use default ISM redenning (R_V = 3.1) if ism_red was not fitted ism_red = np.full(samples.shape[0], 3.1) for i in range(samples.shape[0]): sample_ext = ism_extinction(ism_ext[i], ism_red[i], sample_wavel) ax.plot(sample_wavel, sample_ext, ls="-", lw=0.5, color="black", alpha=0.5) else: raise ValueError("The SamplesBox does not contain extinction parameters.") if output is None: print("Plotting extinction...", end="", flush=True) else: print(f"Plotting extinction: {output}...", end="", flush=True) if output is None: plt.show() else: plt.savefig(output, bbox_inches="tight") print(" [DONE]") return fig