"""
Module for plotting MCMC results.
"""
import os
import corner
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from species.data import database
from species.util import plot_util
# Module-wide matplotlib defaults applied on import: serif text rendered with
# Bitstream Vera Serif, and black 2.2-pt frames around every axes.
mpl.rc('font', serif=['Bitstream Vera Serif'], family='serif')
mpl.rc('axes', edgecolor='black', linewidth=2.2)
def plot_walkers(tag,
                 nsteps=None,
                 offset=None,
                 output='walkers.pdf'):
    """
    Function to plot the step history of the walkers.

    Parameters
    ----------
    tag : str
        Database tag with the MCMC samples.
    nsteps : int, None
        Upper limit of the x-axis (step number). The full range is shown if
        set to None.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default label positions are used
        if set to None.
    output : str
        Output filename. A relative path is resolved against the current
        working directory; an absolute path is used as given.

    Returns
    -------
    NoneType
        None
    """
    print(f'Plotting walkers: {output}...', end='', flush=True)

    species_db = database.Database()
    box = species_db.get_samples(tag)

    # Indexed below as samples[walker, step, parameter]; the last axis holds
    # the fitted parameters (see ndim).
    samples = box.samples
    labels = plot_util.update_labels(box.parameters)
    ndim = samples.shape[-1]

    # Always start from a new figure so an already open figure is never drawn
    # over (a fixed figure number would reuse it and ignore figsize).
    fig = plt.figure(figsize=(6, ndim*1.5))
    gridsp = mpl.gridspec.GridSpec(ndim, 1)
    gridsp.update(wspace=0, hspace=0.1, left=0, right=1, bottom=0, top=1)

    for i in range(ndim):
        ax = fig.add_subplot(gridsp[i, 0])

        # Panels are stacked vertically; only the bottom one shows x tick
        # labels and the x-axis label.
        is_bottom = i == ndim-1

        for which, length in (('major', 5), ('minor', 3)):
            ax.tick_params(axis='both', which=which, colors='black', labelcolor='black',
                           direction='in', width=1, length=length, labelsize=12, top=True,
                           bottom=True, left=True, right=True, labelbottom=is_bottom)

        ax.set_xlabel('Step number' if is_bottom else '', fontsize=10)
        ax.set_ylabel(labels[i], fontsize=10)

        if offset is not None:
            ax.get_xaxis().set_label_coords(0.5, offset[0])
            ax.get_yaxis().set_label_coords(offset[1], 0.5)
        else:
            ax.get_xaxis().set_label_coords(0.5, -0.22)
            ax.get_yaxis().set_label_coords(-0.09, 0.5)

        if nsteps is not None:
            ax.set_xlim(0, nsteps)

        # Transposed to (step, walker): a 2D array is plotted column by
        # column, so this draws one line per walker in a single call.
        ax.plot(samples[:, :, i].T, ls='-', lw=0.5, color='black', alpha=0.5)

    fig.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close(fig)

    print(' [DONE]')
def plot_posterior(tag,
                   burnin=None,
                   title=None,
                   offset=None,
                   title_fmt='.2f',
                   limits=None,
                   max_posterior=False,
                   output='posterior.pdf'):
    """
    Function to plot the posterior distribution.

    Parameters
    ----------
    tag : str
        Database tag with the MCMC samples.
    burnin : int, None
        Number of burnin steps to exclude. All samples are used if set to None.
    title : str, None
        Plot title. No title is added if set to None (or an empty string).
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default label positions are used
        if set to None.
    title_fmt : str
        Format of the median and error values.
    limits : tuple(tuple(float, float), ), None
        Axis limits of all parameters. Automatically set if set to None.
    max_posterior : bool
        Plot the position of the sample with the maximum posterior probability.
    output : str
        Output filename. A relative path is resolved against the current
        working directory; an absolute path is used as given.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If ``max_posterior`` is True but the samples in the database have no
        maximum posterior sample stored.
    """
    species_db = database.Database()
    box = species_db.get_samples(tag, burnin=burnin)

    # Fail early with a clear message instead of a NameError halfway through
    # plotting when the maximum posterior sample is requested but unavailable.
    if max_posterior and box.prob_sample is None:
        raise ValueError(f'The samples of \'{tag}\' do not contain a maximum posterior '
                         f'sample so max_posterior=True can not be used.')

    print('Median sample:')
    for key, value in box.median_sample.items():
        print(f' - {key} = {value:.2f}')

    samples = box.samples

    par_val = None
    if box.prob_sample is not None:
        # NOTE(review): assumes prob_sample is ordered like box.parameters,
        # since par_val is indexed by the corner-plot column below — confirm.
        par_val = tuple(box.prob_sample.values())

        print('Maximum posterior sample:')
        for key, value in box.prob_sample.items():
            print(f' - {key} = {value:.2f}')

    print(f'Plotting the posterior: {output}...', end='', flush=True)

    labels = plot_util.update_labels(box.parameters)

    # Flatten all leading axes into a single list of parameter vectors.
    ndim = samples.shape[-1]
    samples = samples.reshape((-1, ndim))

    fig = corner.corner(samples, labels=labels, quantiles=[0.16, 0.5, 0.84],
                        label_kwargs={'fontsize': 13}, show_titles=True,
                        title_kwargs={'fontsize': 12}, title_fmt=title_fmt)

    axes = np.array(fig.axes).reshape((ndim, ndim))

    for i in range(ndim):
        # Only the panels on and below the diagonal (j <= i) are styled.
        for j in range(i+1):
            ax = axes[i, j]

            ax.xaxis.set_major_formatter(ScalarFormatter(useOffset=False))
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))

            # Tick labels only on the outer edges of the triangle: y labels in
            # the first column (except the top histogram), x labels in the
            # bottom row.
            labelleft = j == 0 and i != 0
            labelbottom = i == ndim-1

            for which, length in (('major', 5), ('minor', 3)):
                ax.tick_params(axis='both', which=which, colors='black', labelcolor='black',
                               direction='in', width=1, length=length, labelsize=12, top=True,
                               bottom=True, left=True, right=True, labelleft=labelleft,
                               labelbottom=labelbottom, labelright=False, labeltop=False)

            if limits is not None:
                ax.set_xlim(limits[j])

            if max_posterior:
                ax.axvline(par_val[j], color='tomato')

            # Off-diagonal panels show two parameters, so they also get a
            # horizontal marker, the point itself and y-axis limits.
            if i > j:
                if max_posterior:
                    ax.axhline(par_val[i], color='tomato')
                    ax.plot(par_val[j], par_val[i], 's', color='tomato')

                if limits is not None:
                    ax.set_ylim(limits[i])

            if offset is not None:
                ax.get_xaxis().set_label_coords(0.5, offset[0])
                ax.get_yaxis().set_label_coords(offset[1], 0.5)
            else:
                ax.get_xaxis().set_label_coords(0.5, -0.26)
                ax.get_yaxis().set_label_coords(-0.27, 0.5)

    if title:
        fig.suptitle(title, y=1.02, fontsize=16)

    fig.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close(fig)

    print(' [DONE]')
def plot_photometry(tag,
                    filter_id,
                    burnin=None,
                    xlim=None,
                    output='photometry.pdf'):
    """
    Function to plot the posterior distribution of the synthetic photometry.

    Parameters
    ----------
    tag : str
        Database tag with the MCMC samples.
    filter_id : str
        Filter ID.
    burnin : int, None
        Number of burnin steps to exclude. All samples are used if set to None.
    xlim : tuple(float, float), None
        Axis limits. Automatically set if set to None.
    output : str
        Output filename. A relative path is resolved against the current
        working directory; an absolute path is used as given.

    Returns
    -------
    NoneType
        None
    """
    species_db = database.Database()
    samples = species_db.get_mcmc_photometry(tag, burnin, filter_id)

    print(f'Plotting photometry samples: {output}...', end='', flush=True)

    fig = corner.corner(samples, labels=['Magnitude'], quantiles=[0.16, 0.5, 0.84],
                        label_kwargs={'fontsize': 13}, show_titles=True,
                        title_kwargs={'fontsize': 12}, title_fmt='.2f')

    # A single labelled quantity gives a corner plot with a single panel.
    ax = fig.axes[0]

    for which, length in (('major', 5), ('minor', 3)):
        ax.tick_params(axis='both', which=which, colors='black', labelcolor='black',
                       direction='in', width=1, length=length, labelsize=12, top=True,
                       bottom=True, left=True, right=True)

    if xlim is not None:
        ax.set_xlim(xlim)

    ax.get_xaxis().set_label_coords(0.5, -0.26)

    fig.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close(fig)

    print(' [DONE]')