CrossValidation
This module contains the CrossValidation object.
- class eurydice.CV.CrossValidation(training_data, test_data, inst_params, kernel, include_keplerian=False, orbit_params=None)
A class for running and analyzing Gaussian Process-based cross validation on radial velocity data.
- Parameters:
training_data (pd.DataFrame) – Contains ‘times’, ‘rv’, ‘err’, ‘inst’ columns for conditioning the GP.
test_data (pd.DataFrame) – Contains ‘times’, ‘rv’, ‘err’, ‘inst’ columns for testing in cross-validation.
inst_params (dict) – Dictionary of {instrument: [gamma, jitter]}.
kernel (eurydice.kernel object) – Kernel object used for computing the GP covariance matrices.
include_keplerian (bool) – Whether or not to use the Keplerian signals of the planet(s) as your GP’s mean function. Defaults to False.
orbit_params (dict of dicts, optional) – Dictionary of per-planet parameters. If include_keplerian is True, this must be included. Each planet’s parameters must include keys: ‘T0’, ‘P’, ‘e’, ‘omega’, ‘K’.
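A minimal sketch of how the constructor inputs listed above might be assembled. The instrument names ('HIRES', 'HARPS'), the planet key 'b', the random 70/30 split, and every numerical value are illustrative assumptions rather than anything prescribed by eurydice. Kernel construction is omitted because the kernel constructor is not documented in this section, so the final call is shown only as a comment.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic RV time series (illustrative values only).
n_obs = 20
data = pd.DataFrame({
    "times": np.sort(rng.uniform(0.0, 100.0, n_obs)),  # observation times
    "rv": rng.normal(0.0, 5.0, n_obs),                 # radial velocities
    "err": np.full(n_obs, 1.5),                        # measurement uncertainties
    "inst": rng.choice(["HIRES", "HARPS"], n_obs),     # hypothetical instrument names
})

# Hypothetical random split into training and test sets (about 70/30).
test_mask = rng.random(n_obs) < 0.3
training_data = data[~test_mask].reset_index(drop=True)
test_data = data[test_mask].reset_index(drop=True)

# {instrument: [gamma, jitter]}, one entry per instrument in the data.
inst_params = {"HIRES": [0.0, 1.0], "HARPS": [0.5, 0.8]}

# Per-planet parameters; only needed when include_keplerian=True.
orbit_params = {
    "b": {"T0": 2.0, "P": 10.0, "e": 0.1, "omega": 0.5, "K": 4.0},
}

# With a kernel object in hand, the call would follow the signature above:
# cv = CrossValidation(training_data, test_data, inst_params, kernel,
#                      include_keplerian=True, orbit_params=orbit_params)
```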
- calc_total_planetary_signal(times)
Computes the combined radial velocity signal of all exoplanets provided in orbit_params.
- Parameters:
times (np.array) – Times at which to calculate the planetary signal.
- Returns:
Total radial velocity signal of the system at the given times.
- Return type:
(np.array)
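The idea of summing per-planet signals can be sketched with a standard Keplerian radial velocity model. This is not eurydice's implementation: the function names are hypothetical, T0 is assumed to be the time of periastron passage, and omega is assumed to be in radians. The library may use different conventions.

```python
import numpy as np

def solve_kepler(M, e, n_iter=50):
    """Solve Kepler's equation E - e*sin(E) = M with Newton iterations."""
    E = np.array(M, dtype=float, copy=True)
    for _ in range(n_iter):
        E -= (E - e * np.sin(E) - M) / (1.0 - e * np.cos(E))
    return E

def keplerian_rv(times, T0, P, e, omega, K):
    """RV of one planet; T0 is ASSUMED to be the time of periastron passage."""
    M = 2.0 * np.pi * (np.asarray(times, dtype=float) - T0) / P  # mean anomaly
    E = solve_kepler(M, e)                                        # eccentric anomaly
    nu = 2.0 * np.arctan2(np.sqrt(1 + e) * np.sin(E / 2),
                          np.sqrt(1 - e) * np.cos(E / 2))          # true anomaly
    return K * (np.cos(nu + omega) + e * np.cos(omega))

def total_planetary_signal(times, orbit_params):
    """Sum the single-planet signals over every entry in orbit_params."""
    times = np.asarray(times, dtype=float)
    total = np.zeros_like(times)
    for params in orbit_params.values():
        total += keplerian_rv(times, **params)
    return total
```

For a circular orbit (e = 0, omega = 0) this reduces to a cosine with amplitude K that peaks at T0, which makes a convenient sanity check.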
- condition()
Conditions the Gaussian Process using the set of training data provided.
- get_CV_results()
Get cross-validation results.
- Returns:
(results_df, residual_stats), where:
results_df (pd.DataFrame) – Detailed results DataFrame.
residual_stats (dict of floats) – Means and standard deviations of the training and test set residuals.
- Raises:
RuntimeError – If ‘run_CV’ hasn’t been called yet.
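Because get_CV_results raises RuntimeError until run_CV has been called, the call order matters. The helper below is a hypothetical convenience wrapper, not part of eurydice; it assumes only the two methods documented here and works with any object that exposes them.

```python
def fetch_cv_results(cv):
    """Return (results_df, residual_stats), running the CV first if needed.

    `cv` is any object exposing run_CV() and get_CV_results() as documented
    above; this helper itself is hypothetical and not part of eurydice.
    """
    try:
        return cv.get_CV_results()
    except RuntimeError:
        # get_CV_results raises RuntimeError when run_CV has not been called.
        cv.run_CV()
        return cv.get_CV_results()
```

In most scripts it is simpler to call run_CV() explicitly before get_CV_results(); the wrapper is mainly useful in interactive sessions where the state of the object is unclear.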
- plot_CV(inst_to_plot=None, include_Gaussian=False, save_path=None, colors=None, labels=None, prediction_plot_axis_kwargs=None, histogram_axis_kwargs=None, legend_kwargs=None)
Plot CV results in a multi-panel plot with prediction + residuals on the left and histogram on the right.
- Parameters:
inst_to_plot (str or np.array, optional) – Instrument(s) corresponding to the data you’d like to plot.
include_Gaussian (bool) – If True, overlay Gaussian curve fits on histogram. Defaults to False.
save_path (str or Path, optional) – If provided, saves the figure to given path.
colors (dict, optional) – Custom color mapping for ‘train’ and ‘test’.
labels (dict, optional) – Custom label mapping for ‘train’ and ‘test’.
prediction_plot_axis_kwargs (dict, optional) – Axis settings for prediction + residual panel.
histogram_axis_kwargs (dict, optional) – Axis settings for histogram.
legend_kwargs (dict, optional) – Custom legend settings.
- Returns:
(matplotlib.figure.Figure)
- predict(times_predict, inst_predict)
Performs Gaussian Process regression conditioned on the training set of data to predict values at new points.
- Parameters:
times_predict (np.array) – Set of times for the GP to predict values at.
inst_predict (str or np.array) – Instrument(s) corresponding to the data you’d like to predict for. If a string is provided, uses the same instrument to predict at all times. If an array is passed, it must be the same length as times_predict.
- Returns:
predictive means, predictive variances
- Return type:
(np.array, np.array)
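The conditioning step behind a prediction like this can be sketched with textbook Gaussian Process regression. The sketch below is a generic illustration, not eurydice's code: it assumes a squared-exponential kernel with made-up hyperparameters, a constant mean, and a single instrument, so it ignores the per-instrument gamma and jitter terms. The returned variance is that of the latent function and excludes measurement noise; whether eurydice follows the same convention is not stated here.

```python
import numpy as np

def sq_exp_kernel(t1, t2, amp=1.0, length=5.0):
    """Squared-exponential covariance between two sets of times."""
    dt = np.subtract.outer(t1, t2)
    return amp**2 * np.exp(-0.5 * (dt / length) ** 2)

def gp_predict(t_train, y_train, err_train, t_pred, mean=0.0):
    """Predictive mean and variance of a GP conditioned on training data."""
    K = sq_exp_kernel(t_train, t_train) + np.diag(err_train**2)  # noisy training covariance
    K_s = sq_exp_kernel(t_pred, t_train)                          # cross covariance
    K_ss = sq_exp_kernel(t_pred, t_pred)                          # prior covariance at new times
    L = np.linalg.cholesky(K)                                     # K = L @ L.T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train - mean))
    v = np.linalg.solve(L, K_s.T)
    pred_mean = mean + K_s @ alpha
    pred_var = np.diag(K_ss) - np.sum(v**2, axis=0)
    return pred_mean, pred_var
```

Far from any training point the prediction falls back to the prior: the mean returns to the mean function and the variance to the kernel amplitude squared.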
- run_CV()
Run cross-validation by determining how well the model conditioned on the training set predicts the values of the test set.
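One common way to judge how well a conditioned model predicts held-out data is to compare the residuals (observed minus predicted) of the training and test sets. The sketch below shows that comparison with plain NumPy; the function name, the dictionary keys, and the numbers are illustrative assumptions and do not reproduce eurydice's output format.

```python
import numpy as np

def residual_summary(rv, pred_mean):
    """Mean and standard deviation of the residuals rv - pred_mean."""
    resid = np.asarray(rv, dtype=float) - np.asarray(pred_mean, dtype=float)
    return {"mean": float(np.mean(resid)), "std": float(np.std(resid))}

# Illustrative numbers only, not eurydice output.
train_stats = residual_summary([1.0, 2.0, 3.0], [1.1, 1.9, 3.0])
test_stats = residual_summary([4.0, 5.0], [3.0, 6.5])
```

A test-set residual spread that is much larger than the training-set spread suggests the model fits the training data more closely than it generalizes.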