CrossValidation

This module contains the CrossValidation class.

class eurydice.CV.CrossValidation(training_data, test_data, inst_params, kernel, include_keplerian=False, orbit_params=None)

A class for running and analyzing Gaussian Process-based cross-validation on radial velocity data.

Parameters:
  • training_data (pd.DataFrame) – Contains ‘times’, ‘rv’, ‘err’, ‘inst’ columns for conditioning the GP.

  • test_data (pd.DataFrame) – Contains ‘times’, ‘rv’, ‘err’, ‘inst’ columns for testing in cross-validation.

  • inst_params (dict) – Dictionary of {instrument: [gamma, jitter]}.

  • kernel (eurydice.kernel object) – Kernel object used for computing the GP covariance matrices.

  • include_keplerian (bool) – Whether to use the combined Keplerian signal of the planet(s) in orbit_params as the GP’s mean function. Defaults to False.

  • orbit_params (dict of dicts, optional) – Dictionary of per-planet parameters; required if include_keplerian is True. Each planet’s parameters must include the keys ‘T0’, ‘P’, ‘e’, ‘omega’, ‘K’.
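The inputs described above can be assembled as in this minimal sketch. The data are made up: the instrument labels "inst_A" and "inst_B", all numerical values, and the random 32/8 train/test split are illustrative assumptions, and omega is assumed to be in radians. Only the column names and the dictionary layouts come from the parameter descriptions.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 40

# Illustrative dataset with the four documented columns; values are made up.
data = pd.DataFrame({
    "times": np.sort(rng.uniform(0.0, 100.0, n)),  # observation times
    "rv": rng.normal(0.0, 3.0, n),                  # radial velocities
    "err": np.full(n, 1.0),                         # per-point uncertainties
    "inst": rng.choice(["inst_A", "inst_B"], n),    # hypothetical instrument labels
})

# Hold out 8 randomly chosen points as the test set (an assumed splitting scheme).
test_idx = rng.choice(n, size=8, replace=False)
is_test = np.isin(np.arange(n), test_idx)
training_data = data[~is_test].reset_index(drop=True)
test_data = data[is_test].reset_index(drop=True)

# {instrument: [gamma, jitter]} with placeholder values.
inst_params = {"inst_A": [0.5, 1.2], "inst_B": [-0.3, 2.0]}

# Per-planet parameters using the documented keys; omega assumed in radians.
orbit_params = {
    "b": {"T0": 10.0, "P": 12.3, "e": 0.1, "omega": 1.2, "K": 4.0},
}
```

These objects have the shapes the constructor expects and could be passed as CrossValidation(training_data, test_data, inst_params, kernel, include_keplerian=True, orbit_params=orbit_params) once a eurydice kernel object has been built; kernel construction is not shown here.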

calc_total_planetary_signal(times)

Computes the combined radial velocity signal of all exoplanets provided in orbit_params.

Parameters:

times (np.array) – Times at which to calculate the planetary signal.

Returns:

Total radial velocity signal of the system at the given times.

Return type:

(np.array)
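To illustrate what calc_total_planetary_signal computes, the sketch below sums standard Keplerian RV curves, v(t) = K [cos(ν + ω) + e cos ω], over all planets, solving Kepler's equation with Newton's method. It assumes that T0 is the time of periastron passage and that omega is in radians; eurydice may use a different convention (for example, time of conjunction), so treat this as a sketch rather than the package's implementation. The helper names are hypothetical.

```python
import numpy as np


def solve_kepler(M, e, tol=1e-12, max_iter=50):
    """Solve Kepler's equation E - e*sin(E) = M for the eccentric anomaly E (Newton's method)."""
    # Start from E = M for moderate e and from E = pi for very eccentric orbits.
    E = M.copy() if e < 0.8 else np.full_like(M, np.pi)
    for _ in range(max_iter):
        dE = (E - e * np.sin(E) - M) / (1.0 - e * np.cos(E))
        E -= dE
        if np.all(np.abs(dE) < tol):
            break
    return E


def keplerian_rv(times, T0, P, e, omega, K):
    """RV of one planet; ASSUMES T0 = time of periastron and omega in radians."""
    M = np.mod(2.0 * np.pi * (times - T0) / P, 2.0 * np.pi)  # mean anomaly
    E = solve_kepler(M, e)
    # True anomaly from the eccentric anomaly (quadrant-safe form).
    nu = 2.0 * np.arctan2(np.sqrt(1.0 + e) * np.sin(E / 2.0),
                          np.sqrt(1.0 - e) * np.cos(E / 2.0))
    return K * (np.cos(nu + omega) + e * np.cos(omega))


def total_planetary_signal(times, orbit_params):
    """Hypothetical stand-in for calc_total_planetary_signal: sum over all planets."""
    times = np.asarray(times, dtype=float)
    total = np.zeros_like(times)
    for p in orbit_params.values():
        total += keplerian_rv(times, p["T0"], p["P"], p["e"], p["omega"], p["K"])
    return total
```

For a circular orbit (e = 0) this reduces to K cos(2π(t − T0)/P + ω), which is a quick sanity check on any convention you adopt.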

condition()

Conditions the Gaussian Process using the set of training data provided.
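Conditioning can be sketched as below. This assumes (it is not stated in this documentation) that each instrument's gamma is subtracted from its RVs, its jitter is added in quadrature to the per-point errors, and the optional Keplerian mean is subtracted before the covariance is Cholesky-factorized. A squared-exponential kernel stands in for the eurydice kernel object, and all function names are hypothetical.

```python
import numpy as np


def squared_exponential(t1, t2, amp=3.0, length=10.0):
    """Hypothetical stand-in kernel: amp^2 * exp(-(t - t')^2 / (2 * length^2))."""
    dt = np.subtract.outer(t1, t2)
    return amp**2 * np.exp(-0.5 * (dt / length) ** 2)


def condition(times, rv, err, inst, inst_params, kernel=squared_exponential, mean_fn=None):
    """Sketch of GP conditioning; returns the Cholesky factor and alpha = (K + N)^-1 y."""
    gamma = np.array([inst_params[i][0] for i in inst])
    jitter = np.array([inst_params[i][1] for i in inst])
    y = rv - gamma                      # remove per-instrument offsets (assumed)
    if mean_fn is not None:             # e.g. a Keplerian mean function
        y = y - mean_fn(times)
    # Noise: measurement errors and instrument jitter added in quadrature (assumed).
    cov = kernel(times, times) + np.diag(err**2 + jitter**2)
    L = np.linalg.cholesky(cov)         # lower-triangular factor, cov = L @ L.T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    return L, alpha
```

Caching L and alpha is what makes later predictions cheap: each new prediction only needs the cross-covariance with the training times.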

get_CV_results()

Get cross-validation results.

Returns:

(results_df, residual_stats), where:

  • results_df (pd.DataFrame) – Detailed results DataFrame.

  • residual_stats (dict of floats) – Means and standard deviations of the training-set and test-set residuals.

Raises:

RuntimeError – If ‘run_CV’ hasn’t been called yet.
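After run_CV(), the results are retrieved with results_df, residual_stats = cv.get_CV_results(), where cv is a CrossValidation instance. The sketch below shows the kind of summary residual_stats holds; the key names and the use of the population standard deviation (ddof=0) are assumptions, so check the returned dict for the actual keys.

```python
import numpy as np


def residual_stats(train_resid, test_resid):
    """Hypothetical summary of residuals; real key names may differ."""
    return {
        "train_mean": float(np.mean(train_resid)),
        "train_std": float(np.std(train_resid)),   # population std (ddof=0), assumed
        "test_mean": float(np.mean(test_resid)),
        "test_std": float(np.std(test_resid)),
    }
```

A test-set scatter much larger than the training-set scatter is one common sign that the GP is overfitting the training data.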

plot_CV(inst_to_plot=None, include_Gaussian=False, save_path=None, colors=None, labels=None, prediction_plot_axis_kwargs=None, histogram_axis_kwargs=None, legend_kwargs=None)

Plots the CV results in a multi-panel figure, with the predictions and residuals on the left and a histogram of the residuals on the right.

Parameters:
  • inst_to_plot (str or np.array, optional) – Instrument(s) corresponding to the data you’d like to plot.

  • include_Gaussian (bool) – If True, overlay Gaussian curve fits on histogram. Defaults to False.

  • save_path (str or Path, optional) – If provided, saves the figure to given path.

  • colors (dict, optional) – Custom color mapping for ‘train’ and ‘test’.

  • labels (dict, optional) – Custom label mapping for ‘train’ and ‘test’.

  • prediction_plot_axis_kwargs (dict, optional) – Axis settings for the prediction and residual panel.

  • histogram_axis_kwargs (dict, optional) – Axis settings for the histogram panel.

  • legend_kwargs (dict, optional) – Custom legend settings.

Returns:

(matplotlib.figure.Figure)
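With include_Gaussian=True, a Gaussian curve is overlaid on the residual histogram. One plausible way to compute such an overlay is sketched below using NumPy only: it fits the Gaussian by the sample mean and standard deviation and evaluates its density over the histogram's range. The fitting method and the function name are assumptions, not eurydice's documented behavior.

```python
import numpy as np


def gaussian_overlay(residuals, bins=20, n_grid=200):
    """Density histogram of residuals plus a Gaussian fitted by sample mean and std (assumed)."""
    counts, edges = np.histogram(residuals, bins=bins, density=True)
    mu, sigma = np.mean(residuals), np.std(residuals)
    x = np.linspace(edges[0], edges[-1], n_grid)
    pdf = np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))
    return counts, edges, x, pdf
```

Using density=True puts the histogram and the Gaussian on the same scale. The colors and labels arguments take dicts keyed by 'train' and 'test', for example colors={'train': 'tab:blue', 'test': 'tab:orange'}; those color values are only examples.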

predict(times_predict, inst_predict)

Performs Gaussian Process regression, conditioned on the training data, to predict values at new times.

Parameters:
  • times_predict (np.array) – Set of times for the GP to predict values at.

  • inst_predict (str or np.array) – Instrument(s) corresponding to the data you’d like to predict for. If a string is provided, uses the same instrument to predict at all times. If an array is passed, it must be the same length as times_predict.

Returns:

predictive means, predictive variances

Return type:

(np.array, np.array)
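The standard GP predictive equations behind a method like predict are sketched below: mean = K_s (K + N)^-1 y and variance = diag(k(t*, t*)) - diag(K_s (K + N)^-1 K_s^T), where K_s is the cross-covariance between prediction and training times and N is the diagonal noise matrix. The sketch also mirrors the documented inst_predict handling (a string is applied to every time; an array must match times_predict in length). Adding the instrument's gamma back to the mean, returning the latent (jitter-free) variance, and the squared-exponential kernel are assumptions, and the function names are hypothetical.

```python
import numpy as np


def squared_exponential(t1, t2, amp=3.0, length=10.0):
    """Hypothetical stand-in kernel: amp^2 * exp(-(t - t')^2 / (2 * length^2))."""
    dt = np.subtract.outer(t1, t2)
    return amp**2 * np.exp(-0.5 * (dt / length) ** 2)


def predict(t_train, y_train, noise_var, times_predict, inst_predict, inst_params,
            kernel=squared_exponential):
    """Sketch of GP prediction. y_train must already have gammas (and any mean) removed."""
    times_predict = np.asarray(times_predict, dtype=float)
    # A single instrument name applies to every prediction time.
    if np.ndim(inst_predict) == 0:
        inst_predict = np.full(times_predict.shape, inst_predict)
    elif len(inst_predict) != len(times_predict):
        raise ValueError("inst_predict must match times_predict in length")

    L = np.linalg.cholesky(kernel(t_train, t_train) + np.diag(noise_var))
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    K_s = kernel(times_predict, t_train)               # cross-covariance
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)
    var = np.diag(kernel(times_predict, times_predict)) - np.sum(v**2, axis=0)

    # Add each instrument's gamma back onto the mean (assumed behavior).
    gamma = np.array([inst_params[i][0] for i in inst_predict])
    return mean + gamma, var
```

Far from the training data the prediction falls back to the prior: the mean tends to the instrument's gamma and the variance to the kernel amplitude squared.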

run_CV()

Runs cross-validation by measuring how well the model conditioned on the training set predicts the values of the test set.
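The cross-validation pass can be sketched as follows: predict at the training and test times with the conditioned model, then tabulate the residuals (observed minus predicted) for each set. Here predict_fn is a hypothetical stand-in for a conditioned GP's predict, and the column names are assumptions rather than the actual layout of results_df.

```python
import numpy as np
import pandas as pd


def run_cv(training_data, test_data, predict_fn):
    """Sketch of a CV pass. predict_fn(times, inst) -> (mean, var) is hypothetical."""
    frames = []
    for label, df in (("train", training_data), ("test", test_data)):
        times = df["times"].to_numpy()
        rv = df["rv"].to_numpy()
        mean, var = predict_fn(times, df["inst"].to_numpy())
        frames.append(pd.DataFrame({
            "times": times,
            "inst": df["inst"].to_numpy(),
            "rv": rv,
            "pred_mean": mean,
            "pred_std": np.sqrt(var),
            "residual": rv - mean,   # observed minus predicted
            "set": label,
        }))
    return pd.concat(frames, ignore_index=True)
```

Keeping the model behind a callable separates the bookkeeping from the GP itself, so the same loop works with any predictor that returns means and variances.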