# Source code for species.plot.plot_spectrum

"""
Module with a function for plotting spectra.
"""

import os
import math
import itertools

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from species.core import box, constants
from species.read import read_filter
from species.util import plot_util

# Global matplotlib settings, applied when this module is imported: a serif
# font family (Bitstream Vera Serif), black axes frames with a line width of
# 2.2, and axes.axisbelow disabled so ticks and grid lines are not forced
# below the plotted artists.
mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False

def _create_axes(filters, residuals, figsize):
    """
    Create figure 1 with the panels for the spectrum, filter profiles, and residuals.

    Parameters
    ----------
    filters : list(str, ), None
        Filter IDs. A panel for the transmission profiles is added above the
        spectrum panel if not None.
    residuals : species.core.box.ResidualsBox, None
        Box with residuals. A residuals panel is added below the spectrum
        panel if not None.
    figsize : tuple(float, float)
        Figure size.

    Returns
    -------
    tuple
        ``(ax1, ax2, ax3)`` with the axes of the spectrum, transmission, and
        residuals panels. ``ax2`` is None if ``filters`` is None and ``ax3``
        is None if ``residuals`` is None.
    """

    plt.figure(1, figsize=figsize)

    if residuals is not None and filters is not None:
        gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1])

    elif residuals is not None:
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1])

    elif filters is not None:
        gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[1, 4])

    else:
        gridsp = mpl.gridspec.GridSpec(1, 1)

    gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    # The spectrum panel moves down one row when the filter panel is on top,
    # and the residuals panel is always directly below the spectrum panel.
    spec_row = 0 if filters is None else 1

    ax1 = plt.subplot(gridsp[spec_row, 0])
    ax2 = None if filters is None else plt.subplot(gridsp[0, 0])
    ax3 = None if residuals is None else plt.subplot(gridsp[spec_row+1, 0])

    return ax1, ax2, ax3


def _set_tick_params(ax, labelbottom=True):
    """
    Draw inward-pointing major and minor ticks on all four sides of ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes of which the ticks are styled.
    labelbottom : bool
        Show the tick labels along the bottom x-axis.

    Returns
    -------
    NoneType
        None
    """

    for which, length in (('major', 5), ('minor', 3)):
        ax.tick_params(axis='both', which=which, colors='black', labelcolor='black',
                       direction='in', width=1, length=length, labelsize=12, top=True,
                       bottom=True, left=True, right=True, labelbottom=labelbottom)


def _model_label(param, object_type):
    """
    Create the legend label of a ``ModelBox`` from its model parameters.

    Parameters
    ----------
    param : dict
        Model parameters (i.e. ``ModelBox.parameters``). The radius and mass
        are converted with R_JUP/R_SUN and M_JUP/M_SUN for a star, so they
        are expected in Jupiter units.
    object_type : str
        Object type ('planet' or 'star'), which sets the units of the radius
        and mass.

    Returns
    -------
    str
        Comma-separated ``label = value [unit]`` entries. Parameters other
        than teff, logg, feh, co, fsed, radius, mass, and luminosity are not
        included. Radius and mass are also left out for an object type other
        than 'planet' or 'star'.
    """

    par_key, par_unit, par_label = plot_util.quantity_unit(
        param=list(param.keys()), object_type=object_type)

    entries = []

    for key, unit, label in zip(par_key, par_unit, par_label):
        if key == 'teff':
            value = f'{param[key]:.1f}'

        elif key in ('logg', 'feh', 'co', 'fsed'):
            value = f'{param[key]:.2f}'

        elif key in ('radius', 'mass') and object_type == 'planet':
            value = f'{param[key]:.2f}'

        elif key == 'radius' and object_type == 'star':
            value = f'{param[key]*constants.R_JUP/constants.R_SUN:.2f}'

        elif key == 'mass' and object_type == 'star':
            value = f'{param[key]*constants.M_JUP/constants.M_SUN:.2f}'

        elif key == 'luminosity':
            value = f'{param[key]:.1e}'

        else:
            # Skipping here (instead of reusing the value of a previous
            # parameter) keeps the label free of wrong values.
            continue

        if unit is None:
            entries.append(f'{label} = {value}')
        else:
            entries.append(f'{label} = {value} {unit}')

    # Joining avoids a trailing separator when the last parameters are skipped.
    return ', '.join(entries)


def _plot_residuals(ax3, residuals, colors, color_obj_phot, color_obj_spec, ylim_res):
    """
    Plot the residuals of a fit (in units of sigma) and set the limits of the
    residuals panel.

    Parameters
    ----------
    ax3 : matplotlib.axes.Axes
        Axes of the residuals panel.
    residuals : species.core.box.ResidualsBox
        Box with residuals. ``photometry`` is plotted as rows 0 (wavelength)
        and 1 (residual). Every array in the ``spectrum`` dictionary is
        plotted as columns 0 (wavelength) and 1 (residual).
    colors : list, None
        Colors of the boxes. Automatic colors are used for the spectral
        residuals if None.
    color_obj_phot : str, None
        Color of the photometric data of the ObjectBox.
    color_obj_spec : str, None
        Color of the spectral data of the ObjectBox.
    ylim_res : tuple(float, float), None
        Limits of the residuals axis. Derived from the largest absolute
        residual if set to None.

    Returns
    -------
    NoneType
        None
    """

    res_max = 0.

    if residuals.photometry is not None:
        ax3.plot(residuals.photometry[0, ], residuals.photometry[1, ], marker='s', ms=5,
                 linestyle='none', color=color_obj_phot, zorder=2)

        res_max = np.nanmax(np.abs(residuals.photometry[1, ]))

    if residuals.spectrum is not None:
        for value in residuals.spectrum.values():
            if colors is None:
                ax3.plot(value[:, 0], value[:, 1], marker='o', ms=2, linestyle='none',
                         zorder=1)
            else:
                ax3.plot(value[:, 0], value[:, 1], marker='o', ms=2, linestyle='none',
                         color=color_obj_spec, zorder=1)

            max_tmp = np.nanmax(np.abs(value[:, 1]))

            if max_tmp > res_max:
                res_max = max_tmp

    ax3.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)

    if ylim_res is None:
        res_lim = math.ceil(1.1*res_max)

        # The panel falls back to +/-5 sigma when the largest residual would
        # require limits beyond +/-10 sigma.
        if res_lim > 10.:
            res_lim = 5.

        ax3.set_ylim(-res_lim, res_lim)

    else:
        ax3.set_ylim(ylim_res[0], ylim_res[1])


def plot_spectrum(boxes, filters=None, residuals=None, colors=None, xlim=None, ylim=None,
                  ylim_res=None, scale=('linear', 'linear'), title=None, offset=None,
                  legend=None, figsize=(7., 5.), object_type='planet', quantity='flux',
                  output='spectrum.pdf'):
    """
    Function for plotting spectra, photometry, synthetic photometry, filter
    transmission profiles, and the residuals of a fit. The figure is stored
    as ``output`` in the current working directory.

    Parameters
    ----------
    boxes : list(species.core.box, )
        Boxes with data. Supported are ``SpectrumBox`` and ``ModelBox``
        objects, lists of boxes with ``wavelength`` and ``flux`` attributes
        (plotted as thin, semi-transparent lines), and ``PhotometryBox``,
        ``ObjectBox``, and ``SynphotBox`` objects.
    filters : list(str, ), None
        Filter IDs for which the transmission profile is plotted. Not plotted
        if set to None.
    residuals : species.core.box.ResidualsBox, None
        Box with residuals of a fit. Not plotted if set to None.
    colors : list(str, ), None
        Colors to be used for the different boxes, in the same order as
        ``boxes``. The entry of an ``ObjectBox`` requires a tuple with two
        colors (i.e., for the photometry and spectrum), which are also used
        for the residuals. Automatic colors are used if set to None.
    xlim : tuple(float, float), None
        Limits of the wavelength axis. The range 0.6-6 micron is used if set
        to None.
    ylim : tuple(float, float), None
        Limits of the flux axis. With ``quantity='flux'``, the fluxes are
        scaled by the power of ten of the upper limit, and the scaling is
        shown in the axis label. Automatic limits and no scaling if set to
        None.
    ylim_res : tuple(float, float), None
        Limits of the residuals axis. Automatically chosen (from the largest
        absolute residual, with a fallback to +/-5 sigma when the limits would
        exceed +/-10 sigma) if set to None.
    scale : tuple(str, str)
        Scale of the x- and y-axis ('linear' or 'log').
    title : str, None
        Title. Not shown if set to None.
    offset : tuple(float, float), None
        Offset for the label of the x- and y-axis. Default positions are used
        if set to None.
    legend : str, tuple, dict, None
        Location of the legend (str, tuple) or a dictionary with the
        ``**kwargs`` of ``matplotlib.pyplot.legend``, e.g.
        ``{'loc': 'upper left', 'fontsize': 12.}``. Not shown if set to None.
    figsize : tuple(float, float)
        Figure size.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius and mass
        are expressed in Jupiter units. With 'star', the radius and mass are
        expressed in solar units.
    quantity : str
        The quantity of the y-axis ('flux' or 'magnitude').
    output : str
        Output filename, relative to the current working directory (an
        absolute path is used as is).

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If the number of ``colors`` is not equal to the number of ``boxes``,
        or if ``quantity`` is not 'flux' or 'magnitude'.
    """

    if colors is not None and len(boxes) != len(colors):
        raise ValueError(f'The number of \'boxes\' ({len(boxes)}) should be equal to the '
                         f'number of \'colors\' ({len(colors)}).')

    if quantity not in ('flux', 'magnitude'):
        raise ValueError(f'The argument of \'quantity\' should be \'flux\' or '
                         f'\'magnitude\' but \'{quantity}\' was provided.')

    # The iterator is kept apart from the drawn marker. Assigning next() to
    # the same name would break plotting a second PhotometryBox.
    marker_cycle = itertools.cycle(('o', 's', '*', 'p', '<', '>', 'P', 'v', '^'))

    try:
        ax1, ax2, ax3 = _create_axes(filters, residuals, figsize)

        # The x tick labels of the spectrum panel are hidden when the
        # residuals panel is shown below it.
        _set_tick_params(ax1, labelbottom=ax3 is None)

        if ax2 is not None:
            _set_tick_params(ax2, labelbottom=False)

        if ax3 is not None:
            _set_tick_params(ax3)

        # The wavelength label is placed on the bottom panel
        wavel_label = r'Wavelength ($\mu$m)'

        if ax3 is not None:
            ax1.set_xlabel('', fontsize=13)
            ax3.set_xlabel(wavel_label, fontsize=13)
            ax3.set_ylabel(r'Residual ($\sigma$)', fontsize=13)

        else:
            ax1.set_xlabel(wavel_label, fontsize=13)

        if ax2 is not None:
            ax2.set_xlabel('', fontsize=13)
            ax2.set_ylabel('Transmission', fontsize=13)

        if xlim is not None:
            ax1.set_xlim(xlim[0], xlim[1])
        else:
            ax1.set_xlim(0.6, 6.)

        if quantity == 'magnitude':
            scaling = 1.
            ax1.set_ylabel('Flux contrast (mag)', fontsize=13)

            if ylim:
                ax1.set_ylim(ylim[0], ylim[1])

        elif ylim:
            # Flux with user limits: the fluxes are divided by the power of
            # ten of the upper limit.
            ax1.set_ylim(ylim[0], ylim[1])
            ylim = ax1.get_ylim()

            exponent = math.floor(math.log10(ylim[1]))
            scaling = 10.**exponent

            ylabel = r'Flux (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)'

            ax1.set_ylabel(ylabel, fontsize=13)
            ax1.set_ylim(ylim[0]/scaling, ylim[1]/scaling)

            if ylim[0] < 0.:
                ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)

        else:
            ax1.set_ylabel(r'Flux (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13)
            scaling = 1.

        # The filter and residuals panels share the wavelength range of the
        # spectrum panel.
        xlim = ax1.get_xlim()

        for ax in (ax2, ax3):
            if ax is not None:
                ax.set_xlim(xlim[0], xlim[1])

        if offset is None:
            ax1.get_xaxis().set_label_coords(0.5, -0.12)
            ax1.get_yaxis().set_label_coords(-0.1, 0.5)

        else:
            bottom_ax = ax1 if ax3 is None else ax3
            bottom_ax.get_xaxis().set_label_coords(0.5, offset[0])

            for ax in (ax1, ax2, ax3):
                if ax is not None:
                    ax.get_yaxis().set_label_coords(offset[1], 0.5)

        ax1.set_xscale(scale[0])
        ax1.set_yscale(scale[1])

        for ax in (ax2, ax3):
            if ax is not None:
                ax.set_xscale(scale[0])

        # Colors of the ObjectBox data, reused for the residuals
        color_obj_phot = None
        color_obj_spec = None

        for j, boxitem in enumerate(boxes):
            color_kwargs = {} if colors is None else {'color': colors[j]}

            if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)):
                wavelength = boxitem.wavelength
                flux = boxitem.flux

                if isinstance(wavelength[0], (float, np.floating)):
                    # A single spectrum; NaN fluxes are masked
                    data = np.array(flux, dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    if isinstance(boxitem, box.ModelBox):
                        label = _model_label(boxitem.parameters, object_type)
                    else:
                        label = None

                    ax1.plot(wavelength, masked/scaling, lw=0.5, label=label, zorder=2,
                             **color_kwargs)

                elif isinstance(wavelength[0], np.ndarray):
                    # Multiple spectra, each labeled with its entry in
                    # boxitem.name (decoded when stored as bytes)
                    for i, item in enumerate(wavelength):
                        data = np.array(flux[i], dtype=np.float64)
                        masked = np.ma.array(data, mask=np.isnan(data))

                        if isinstance(boxitem.name[i], bytes):
                            label = boxitem.name[i].decode('utf-8')
                        else:
                            label = boxitem.name[i]

                        ax1.plot(item, masked/scaling, lw=0.5, label=label)

            elif isinstance(boxitem, list):
                for item in boxitem:
                    data = np.array(item.flux, dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    ax1.plot(item.wavelength, masked/scaling, lw=0.2, alpha=0.5, zorder=1,
                             **color_kwargs)

            elif isinstance(boxitem, box.PhotometryBox):
                marker = next(marker_cycle)
                phot_color = 'black' if colors is None else colors[j]

                for i, item in enumerate(boxitem.wavelength):
                    transmission = read_filter.ReadFilter(boxitem.filter_name[i])
                    fwhm = transmission.filter_fwhm()

                    # The uncertainty is scaled in the same way as the flux
                    ax1.errorbar(item, boxitem.flux[i][0]/scaling, xerr=fwhm/2.,
                                 yerr=boxitem.flux[i][1]/scaling, marker=marker, ms=6,
                                 color=phot_color, zorder=3)

            elif isinstance(boxitem, box.ObjectBox):
                if boxitem.flux is not None:
                    if colors is None:
                        phot_kwargs = {'markerfacecolor': color_obj_phot}
                    else:
                        color_obj_phot = colors[j][0]
                        phot_kwargs = {'color': color_obj_phot,
                                       'markerfacecolor': color_obj_phot}

                    for item in boxitem.flux:
                        transmission = read_filter.ReadFilter(item)
                        wavelength = transmission.mean_wavelength()
                        fwhm = transmission.filter_fwhm()

                        ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling,
                                     xerr=fwhm/2., yerr=boxitem.flux[item][1]/scaling,
                                     marker='s', ms=5, zorder=3, **phot_kwargs)

                if boxitem.spectrum is not None:
                    for value in boxitem.spectrum.values():
                        # Columns 0, 1, and 2 are plotted as wavelength, flux,
                        # and uncertainty
                        masked = np.ma.array(value[0], mask=np.isnan(value[0]))

                        if colors is None:
                            ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
                                         yerr=masked[:, 2]/scaling, ms=2, marker='s',
                                         zorder=2.5, ls='none')

                        else:
                            color_obj_spec = colors[j][1]

                            ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
                                         yerr=masked[:, 2]/scaling, marker='o', ms=2,
                                         zorder=2.5, color=color_obj_spec,
                                         markerfacecolor=color_obj_spec, ls='none')

            elif isinstance(boxitem, box.SynphotBox):
                for item in boxitem.flux:
                    transmission = read_filter.ReadFilter(item)
                    wavelength = transmission.mean_wavelength()
                    fwhm = transmission.filter_fwhm()

                    ax1.errorbar(wavelength, boxitem.flux[item]/scaling, xerr=fwhm/2.,
                                 yerr=None, alpha=0.7, marker='s', ms=5, zorder=4,
                                 markerfacecolor='white', **color_kwargs)

        if ax2 is not None:
            for item in filters:
                transmission = read_filter.ReadFilter(item)
                data = transmission.get_filter()

                ax2.plot(data[0, ], data[1, ], '-', lw=0.7, color='black', zorder=1)

            ax2.set_ylim(0., 1.1)

        if ax3 is not None:
            _plot_residuals(ax3, residuals, colors, color_obj_phot, color_obj_spec, ylim_res)

        print(f'Plotting spectrum: {output}...', end='', flush=True)

        if title is not None:
            # The title goes on the top panel
            if ax2 is not None:
                ax2.set_title(title, y=1.02, fontsize=15)
            else:
                ax1.set_title(title, y=1.02, fontsize=15)

        handles, _ = ax1.get_legend_handles_labels()

        if handles and legend is not None:
            if isinstance(legend, (str, tuple)):
                ax1.legend(loc=legend, fontsize=8, frameon=False)
            else:
                ax1.legend(**legend)

        plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')

    finally:
        # Figure 1 is released even when plotting fails, so a following call
        # does not draw on top of the axes of an aborted plot.
        plt.clf()
        plt.close()

    print(' [DONE]')