Source code for species.plot.plot_spectrum

"""
Module with a function for plotting spectra.
"""

import os
import math
import itertools

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from species.core import box, constants
from species.read import read_filter
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False


[docs]def plot_spectrum(boxes, filters=None, residuals=None, colors=None, xlim=None, ylim=None, ylim_res=None, scale=('linear', 'linear'), title=None, offset=None, legend=None, figsize=(7., 5.), object_type='planet', quantity='flux', output='spectrum.pdf'): """ Parameters ---------- boxes : list(species.core.box, ) Boxes with data. filters : list(str, ), None Filter IDs for which the transmission profile is plotted. Not plotted if set to None. residuals : species.core.box.ResidualsBox, None Box with residuals of a fit. Not plotted if set to None. colors : list(str, ), None Colors to be used for the different boxes. Note that a box with residuals requires a tuple with two colors (i.e., for the photometry and spectrum). Automatic colors are used if set to None. xlim : tuple(float, float) Limits of the wavelength axis. ylim : tuple(float, float) Limits of the flux axis. ylim_res : tuple(float, float), None Limits of the residuals axis. Automatically chosen (based on the minimum and maximum residual value) if set to None. scale : tuple(str, str) Scale of the axes ('linear' or 'log'). title : str Title. offset : tuple(float, float) Offset for the label of the x- and y-axis. legend : str, tuple, dict, None Location of the legend (str, tuple) or a dictionary with the ``**kwargs`` of ``matplotlib.pyplot.legend``, e.g. ``{'loc': 'upper left', 'fontsize: 12.}``. figsize : tuple(float, float) Figure size. object_type : str Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in Jupiter units. With 'star', the radius and mass are expressed in solar units. quantity: str The quantity of the y-axis ('flux' or 'magnitude'). output : str Output filename. Returns ------- NoneType None """ marker = itertools.cycle(('o', 's', '*', 'p', '<', '>', 'P', 'v', '^')) if colors is not None and len(boxes) != len(colors): raise ValueError(f'The number of \'boxes\' ({len(boxes)}) should be equal to the ' f'number of \'colors\' ({len(colors)}).') if residuals is not None and filters is not None: plt.figure(1, figsize=figsize) gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1]) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax1 = plt.subplot(gridsp[1, 0]) ax2 = plt.subplot(gridsp[0, 0]) ax3 = plt.subplot(gridsp[2, 0]) elif residuals is not None: plt.figure(1, figsize=figsize) gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[4, 1]) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax1 = plt.subplot(gridsp[0, 0]) ax3 = plt.subplot(gridsp[1, 0]) elif filters is not None: plt.figure(1, figsize=figsize) gridsp = mpl.gridspec.GridSpec(2, 1, height_ratios=[1, 4]) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax1 = plt.subplot(gridsp[1, 0]) ax2 = plt.subplot(gridsp[0, 0]) else: plt.figure(1, figsize=figsize) gridsp = mpl.gridspec.GridSpec(1, 1) gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1) ax1 = plt.subplot(gridsp[0, 0]) if residuals is not None: labelbottom = False else: labelbottom = True ax1.tick_params(axis='both', which='major', colors='black', labelcolor='black', direction='in', width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=labelbottom) ax1.tick_params(axis='both', which='minor', colors='black', labelcolor='black', direction='in', width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=labelbottom) if filters is not None: ax2.tick_params(axis='both', which='major', colors='black', labelcolor='black', direction='in', width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=False) ax2.tick_params(axis='both', which='minor', colors='black', labelcolor='black', direction='in', width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True, labelbottom=False) if residuals is not None: ax3.tick_params(axis='both', which='major', colors='black', labelcolor='black', direction='in', width=1, length=5, labelsize=12, top=True, bottom=True, left=True, right=True) ax3.tick_params(axis='both', which='minor', colors='black', labelcolor='black', direction='in', width=1, length=3, labelsize=12, top=True, bottom=True, left=True, right=True) if residuals is not None and filters is not None: ax1.set_xlabel('', fontsize=13) ax2.set_xlabel('', fontsize=13) ax3.set_xlabel(r'Wavelength ($\mu$m)', fontsize=13) elif residuals is not None: ax1.set_xlabel('', fontsize=13) ax3.set_xlabel(r'Wavelength ($\mu$m)', fontsize=13) elif filters is not None: ax1.set_xlabel(r'Wavelength ($\mu$m)', fontsize=13) ax2.set_xlabel('', fontsize=13) else: ax1.set_xlabel(r'Wavelength ($\mu$m)', fontsize=13) if filters is not None: ax2.set_ylabel('Transmission', fontsize=13) if residuals is not None: ax3.set_ylabel(r'Residual ($\sigma$)', fontsize=13) if xlim is not None: ax1.set_xlim(xlim[0], xlim[1]) else: ax1.set_xlim(0.6, 6.) if quantity == 'magnitude': scaling = 1. ax1.set_ylabel('Flux contrast (mag)', fontsize=13) if ylim: ax1.set_ylim(ylim[0], ylim[1]) elif quantity == 'flux': if ylim: ax1.set_ylim(ylim[0], ylim[1]) ylim = ax1.get_ylim() exponent = math.floor(math.log10(ylim[1])) scaling = 10.**exponent ylabel = r'Flux (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)' ax1.set_ylabel(ylabel, fontsize=13) ax1.set_ylim(ylim[0]/scaling, ylim[1]/scaling) if ylim[0] < 0.: ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5) else: ax1.set_ylabel(r'Flux (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13) scaling = 1. if filters is not None: ax2.set_ylim(0., 1.) xlim = ax1.get_xlim() if filters is not None: ax2.set_xlim(xlim[0], xlim[1]) if residuals is not None: ax3.set_xlim(xlim[0], xlim[1]) if offset is not None and residuals is not None and filters is not None: ax3.get_xaxis().set_label_coords(0.5, offset[0]) ax1.get_yaxis().set_label_coords(offset[1], 0.5) ax2.get_yaxis().set_label_coords(offset[1], 0.5) ax3.get_yaxis().set_label_coords(offset[1], 0.5) elif offset is not None and filters is not None: ax1.get_xaxis().set_label_coords(0.5, offset[0]) ax1.get_yaxis().set_label_coords(offset[1], 0.5) ax2.get_yaxis().set_label_coords(offset[1], 0.5) elif offset is not None and residuals is not None: ax3.get_xaxis().set_label_coords(0.5, offset[0]) ax1.get_yaxis().set_label_coords(offset[1], 0.5) ax3.get_yaxis().set_label_coords(offset[1], 0.5) elif offset is not None: ax1.get_xaxis().set_label_coords(0.5, offset[0]) ax1.get_yaxis().set_label_coords(offset[1], 0.5) else: ax1.get_xaxis().set_label_coords(0.5, -0.12) ax1.get_yaxis().set_label_coords(-0.1, 0.5) ax1.set_xscale(scale[0]) ax1.set_yscale(scale[1]) if filters is not None: ax2.set_xscale(scale[0]) if residuals is not None: ax3.set_xscale(scale[0]) color_obj_phot = None color_obj_spec = None for j, boxitem in enumerate(boxes): if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)): wavelength = boxitem.wavelength flux = boxitem.flux if isinstance(wavelength[0], (np.float32, np.float64)): data = np.array(flux, dtype=np.float64) masked = np.ma.array(data, mask=np.isnan(data)) if isinstance(boxitem, box.ModelBox): param = boxitem.parameters par_key, par_unit, par_label = plot_util.quantity_unit( param=list(param.keys()), object_type=object_type) label = '' newline = False for i, item in enumerate(par_key): if item == 'teff': value = f'{param[item]:.1f}' elif item in ['logg', 'feh', 'co', 'fsed']: value = f'{param[item]:.2f}' elif item == 'radius': if object_type == 'planet': value = f'{param[item]:.2f}' elif object_type == 'star': value = f'{param[item]*constants.R_JUP/constants.R_SUN:.2f}' elif item == 'mass': if object_type == 'planet': value = f'{param[item]:.2f}' elif object_type == 'star': value = f'{param[item]*constants.M_JUP/constants.M_SUN:.2f}' elif item == 'luminosity': value = f'{param[item]:.1e}' else: continue # if len(label) > 110 and newline == False: # label += '\n' # newline = True if par_unit[i] is None: label += f'{par_label[i]} = {value}' else: label += f'{par_label[i]} = {value} {par_unit[i]}' if i < len(par_key)-1: label += ', ' else: label = None if colors: ax1.plot(wavelength, masked/scaling, color=colors[j], lw=0.5, label=label, zorder=2) else: ax1.plot(wavelength, masked/scaling, lw=0.5, label=label, zorder=2) elif isinstance(wavelength[0], (np.ndarray)): for i, item in enumerate(wavelength): data = np.array(flux[i], dtype=np.float64) masked = np.ma.array(data, mask=np.isnan(data)) if isinstance(boxitem.name[i], bytes): label = boxitem.name[i].decode('utf-8') else: label = boxitem.name[i] ax1.plot(item, masked/scaling, lw=0.5, label=label) elif isinstance(boxitem, list): for i, item in enumerate(boxitem): wavelength = item.wavelength flux = item.flux data = np.array(flux, dtype=np.float64) masked = np.ma.array(data, mask=np.isnan(data)) if colors: ax1.plot(wavelength, masked/scaling, lw=0.2, color=colors[j], alpha=0.5, zorder=1) else: ax1.plot(wavelength, masked/scaling, lw=0.2, alpha=0.5, zorder=1) elif isinstance(boxitem, box.PhotometryBox): marker = next(marker) for i, item in enumerate(boxitem.wavelength): transmission = read_filter.ReadFilter(boxitem.filter_name[i]) fwhm = transmission.filter_fwhm() if colors: ax1.errorbar(item, boxitem.flux[i][0]/scaling, xerr=fwhm/2., yerr=boxitem.flux[i][1], marker=marker, ms=6, color=colors[j], zorder=3) else: ax1.errorbar(item, boxitem.flux[i][0]/scaling, xerr=fwhm/2., yerr=boxitem.flux[i][1], marker=marker, ms=6, color='black', zorder=3) elif isinstance(boxitem, box.ObjectBox): if boxitem.flux is not None: for item in boxitem.flux: transmission = read_filter.ReadFilter(item) wavelength = transmission.mean_wavelength() fwhm = transmission.filter_fwhm() if colors is None: ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling, xerr=fwhm/2., yerr=boxitem.flux[item][1]/scaling, marker='s', ms=5, zorder=3, markerfacecolor=color_obj_phot) else: color_obj_phot = colors[j][0] ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling, xerr=fwhm/2., yerr=boxitem.flux[item][1]/scaling, marker='s', ms=5, zorder=3, color=color_obj_phot, markerfacecolor=color_obj_phot) if boxitem.spectrum is not None: for key, value in boxitem.spectrum.items(): masked = np.ma.array(boxitem.spectrum[key][0], mask=np.isnan(boxitem.spectrum[key][0])) if colors is None: ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling, ms=2, marker='s', zorder=2.5, ls='none') else: color_obj_spec = colors[j][1] ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling, marker='o', ms=2, zorder=2.5, color=color_obj_spec, markerfacecolor=color_obj_spec, ls='none') elif isinstance(boxitem, box.SynphotBox): for item in boxitem.flux: transmission = read_filter.ReadFilter(item) wavelength = transmission.mean_wavelength() fwhm = transmission.filter_fwhm() ax1.errorbar(wavelength, boxitem.flux[item]/scaling, xerr=fwhm/2., yerr=None, alpha=0.7, marker='s', ms=5, zorder=4, color=colors[j], markerfacecolor='white') if filters is not None: for i, item in enumerate(filters): transmission = read_filter.ReadFilter(item) data = transmission.get_filter() ax2.plot(data[0, ], data[1, ], '-', lw=0.7, color='black', zorder=1) if residuals is not None: res_max = 0. if residuals.photometry is not None: ax3.plot(residuals.photometry[0, ], residuals.photometry[1, ], marker='s', ms=5, linestyle='none', color=color_obj_phot, zorder=2) res_max = np.nanmax(np.abs(residuals.photometry[1, ])) if residuals.spectrum is not None: for key, value in residuals.spectrum.items(): if colors is None: ax3.plot(value[:, 0], value[:, 1], marker='o', ms=2, linestyle='none', zorder=1) else: ax3.plot(value[:, 0], value[:, 1], marker='o', ms=2, linestyle='none', color=color_obj_spec, zorder=1) max_tmp = np.nanmax(np.abs(value[:, 1])) if max_tmp > res_max: res_max = max_tmp res_lim = math.ceil(1.1*res_max) if res_lim > 10.: res_lim = 5. ax3.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5) if ylim_res is None: ax3.set_ylim(-res_lim, res_lim) else: ax3.set_ylim(ylim_res[0], ylim_res[1]) if filters is not None: ax2.set_ylim(0., 1.1) print(f'Plotting spectrum: {output}...', end='', flush=True) if title is not None: if filters: ax2.set_title(title, y=1.02, fontsize=15) else: ax1.set_title(title, y=1.02, fontsize=15) handles, _ = ax1.get_legend_handles_labels() if handles and legend is not None: if isinstance(legend, (str, tuple)): ax1.legend(loc=legend, fontsize=8, frameon=False) else: ax1.legend(**legend) plt.savefig(os.getcwd()+'/'+output, bbox_inches='tight') plt.clf() plt.close() print(' [DONE]')