Training Utilities
Functions
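- split_train_and_test_data(Geometries, Energies, N, key[, Nval]) – Splits geometry and energy datasets into training and validation sets.
- split_train_and_test_data_w_forces(Geometries, Forces, Energies, N, key[, Nval]) – Splits geometry, force, and energy datasets into training and validation sets.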
- molpipx.utils_training.split_train_and_test_data(Geometries: jaxtyping.Float[jaxtyping.Array, '...'], Energies: jaxtyping.Float[jaxtyping.Array, '...'], N: int, key: jaxtyping.Key, Nval: int = 0) -> ((jaxtyping.Float, jaxtyping.Float), (jaxtyping.Float, jaxtyping.Float))
Splits geometry and energy datasets into training and validation sets.
- Parameters:
Geometries (Array) – The complete dataset of geometries (Batch, Na, 3).
Energies (Array) – The complete dataset of corresponding energies (Batch, 1).
N (int) – The number of samples to include in the training set.
key (Key) – A JAX random key used to shuffle the data.
Nval (int, optional) – The number of samples to include in the validation set. If 0 or None, all samples remaining after the N training samples are used. Defaults to 0.
- Returns:
- Two tuples containing the split data:
Train: (X_tr, y_tr)
Validation: (X_val, y_val)
- Return type:
Tuple
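To make the splitting behavior concrete, here is a minimal sketch of what a split with this signature might do. It assumes the data are shuffled along the batch axis with jax.random.permutation and that the first N shuffled samples form the training set. The function split_geometries_energies and its argument names are hypothetical, written for illustration only; this is not the molpipx source.

```python
import jax
import jax.numpy as jnp


def split_geometries_energies(geometries, energies, n_train, key, n_val=0):
    """Hypothetical sketch of a train/validation split (not molpipx's code).

    Assumption: one random permutation of the batch axis, with the first
    n_train shuffled samples used for training.
    """
    n_total = geometries.shape[0]
    # Shuffle sample indices along the batch axis using the provided key.
    perm = jax.random.permutation(key, n_total)
    train_idx = perm[:n_train]
    # Assumed behavior: n_val of 0 or None takes every remaining sample.
    if not n_val:
        val_idx = perm[n_train:]
    else:
        val_idx = perm[n_train:n_train + n_val]
    train = (geometries[train_idx], energies[train_idx])
    val = (geometries[val_idx], energies[val_idx])
    return train, val


# Toy data: 10 configurations of a 3-atom system, shapes (Batch, Na, 3) and (Batch, 1).
key = jax.random.PRNGKey(0)
data_key, split_key = jax.random.split(key)
X = jax.random.normal(data_key, (10, 3, 3))
y = jnp.arange(10.0).reshape(10, 1)

(X_tr, y_tr), (X_val, y_val) = split_geometries_energies(X, y, 6, split_key)
```

With n_val left at 0, the training set holds 6 samples and the validation set holds the remaining 4. Reusing the same key yields the same permutation, so the split is reproducible.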
- molpipx.utils_training.split_train_and_test_data_w_forces(Geometries: jaxtyping.Float[jaxtyping.Array, '...'], Forces: jaxtyping.Float[jaxtyping.Array, '...'], Energies: jaxtyping.Float[jaxtyping.Array, '...'], N: int, key: jaxtyping.Key, Nval: int = 0) -> ((jaxtyping.Float, jaxtyping.Float, jaxtyping.Float), (jaxtyping.Float, jaxtyping.Float, jaxtyping.Float))
Splits geometry, force, and energy datasets into training and validation sets.
- Parameters:
Geometries (Array) – The complete dataset of geometries (Batch, Na, 3).
Forces (Array) – The complete dataset of corresponding forces (Batch, Na, 3).
Energies (Array) – The complete dataset of corresponding energies (Batch, 1).
N (int) – The number of samples to include in the training set.
key (Key) – A JAX random key used to shuffle the data.
Nval (int, optional) – The number of samples to include in the validation set. If 0 or None, all samples remaining after the N training samples are used. Defaults to 0.
- Returns:
- Two tuples containing the split data:
Train: (X_tr, G_tr, y_tr)
Validation: (X_val, G_val, y_val)
- Return type:
Tuple
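The same idea extends to forces. The sketch below is again a hypothetical stand-in (split_with_forces is not a molpipx function). It assumes a single shared permutation is applied to all three arrays so that each geometry stays paired with its own forces and energy. The toy forces are F = -dE/dX for the toy energy E = sum(X**2), which makes the pairing easy to verify.

```python
import jax
import jax.numpy as jnp


def split_with_forces(geometries, forces, energies, n_train, key, n_val=0):
    """Hypothetical sketch (not molpipx's code): one shared permutation keeps
    each geometry aligned with its own forces and energy."""
    perm = jax.random.permutation(key, geometries.shape[0])
    train_idx = perm[:n_train]
    # Assumed behavior: n_val of 0 or None slices to the end of the permutation.
    stop = None if not n_val else n_train + n_val
    val_idx = perm[n_train:stop]

    def take(idx):
        # The same index array selects rows from all three inputs.
        return geometries[idx], forces[idx], energies[idx]

    return take(train_idx), take(val_idx)


key = jax.random.PRNGKey(42)
data_key, split_key = jax.random.split(key)
X = jax.random.normal(data_key, (12, 4, 3))   # geometries, (Batch, Na, 3)
G = -2.0 * X                                   # toy forces F = -dE/dX
y = jnp.sum(X**2, axis=(1, 2))[:, None]        # toy energies, (Batch, 1)

(X_tr, G_tr, y_tr), (X_val, G_val, y_val) = split_with_forces(X, G, y, 8, split_key, n_val=3)
```

Because one index array selects rows from every input, G_tr still equals -2 * X_tr row for row after the split. Shuffling each array with a separate permutation would silently break that correspondence. Here 8 + 3 = 11 of the 12 samples are used, so one sample is left out of both sets.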