Training Utilities

Functions

split_train_and_test_data

Splits geometry and energy datasets into training and validation sets.

split_train_and_test_data_w_forces

Splits geometry, force, and energy datasets into training and validation sets.

molpipx.utils_training.split_train_and_test_data(Geometries: jaxtyping.Float[jaxtyping.Array, '...'], Energies: jaxtyping.Float[jaxtyping.Array, '...'], N: int, key: jaxtyping.Key, Nval: int = 0) -> ((jaxtyping.Float, jaxtyping.Float), (jaxtyping.Float, jaxtyping.Float))

Splits geometry and energy datasets into training and validation sets.

Parameters:
  • Geometries (Array) – The complete dataset of geometries, with shape (Batch, Na, 3).

  • Energies (Array) – The complete dataset of corresponding energies, with shape (Batch, 1).

  • N (int) – The number of samples to include in the training set.

  • key (Key) – A JAX random key used to shuffle the data.

  • Nval (int, optional) – The number of samples to include in the validation set. If 0 or None, the remaining samples after selecting N are used. Defaults to 0.

Returns:

Two tuples containing the split data:
  • Train: (X_tr, y_tr)

  • Validation: (X_val, y_val)

Return type:

Tuple
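The behaviour described above can be illustrated with a minimal sketch. It is not the molpipx implementation: `split_sketch` is a hypothetical helper, and it assumes the data are shuffled by permuting sample indices with `jax.random.permutation`, with the first N shuffled samples used for training and the rest for validation when Nval is 0.

```python
import jax
import jax.numpy as jnp


def split_sketch(geometries, energies, n_train, key, n_val=0):
    """Hypothetical stand-in for the documented split (not molpipx's code)."""
    n_total = geometries.shape[0]
    # Shuffle sample indices once so geometries and energies stay aligned.
    perm = jax.random.permutation(key, n_total)
    train_idx = perm[:n_train]
    # n_val of 0 or None: use every sample left after the training set.
    if not n_val:
        val_idx = perm[n_train:]
    else:
        val_idx = perm[n_train:n_train + n_val]
    train = (geometries[train_idx], energies[train_idx])
    val = (geometries[val_idx], energies[val_idx])
    return train, val


key = jax.random.PRNGKey(0)
# Toy data: 10 samples, 3 atoms, 3 Cartesian coordinates.
geometries = jnp.arange(10 * 3 * 3, dtype=jnp.float32).reshape(10, 3, 3)
energies = jnp.arange(10, dtype=jnp.float32).reshape(10, 1)

(X_tr, y_tr), (X_val, y_val) = split_sketch(geometries, energies, 6, key)
print(X_tr.shape, y_tr.shape, X_val.shape, y_val.shape)
# (6, 3, 3) (6, 1) (4, 3, 3) (4, 1)
```

With N = 6 and Nval left at 0, the four remaining samples form the validation set. Reusing the same key reproduces the same split.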

molpipx.utils_training.split_train_and_test_data_w_forces(Geometries: jaxtyping.Float[jaxtyping.Array, '...'], Forces: jaxtyping.Float[jaxtyping.Array, '...'], Energies: jaxtyping.Float[jaxtyping.Array, '...'], N: int, key: jaxtyping.Key, Nval: int = 0) -> ((jaxtyping.Float, jaxtyping.Float, jaxtyping.Float), (jaxtyping.Float, jaxtyping.Float, jaxtyping.Float))

Splits geometry, force, and energy datasets into training and validation sets.

Parameters:
  • Geometries (Array) – The complete dataset of geometries, with shape (Batch, Na, 3).

  • Forces (Array) – The complete dataset of corresponding forces, with shape (Batch, Na, 3).

  • Energies (Array) – The complete dataset of corresponding energies, with shape (Batch, 1).

  • N (int) – The number of samples to include in the training set.

  • key (Key) – A JAX random key used to shuffle the data.

  • Nval (int, optional) – The number of samples to include in the validation set. If 0 or None, the remaining samples after selecting N are used. Defaults to 0.

Returns:

Two tuples containing the split data:
  • Train: (X_tr, G_tr, y_tr)

  • Validation: (X_val, G_val, y_val)

Return type:

Tuple
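A similar sketch covers the three-array case with an explicit validation size. Again, this is an illustration under assumptions rather than the library's code: `split_with_forces_sketch` is a hypothetical helper, and it assumes a single index permutation is applied to geometries, forces, and energies so that each sample's entries stay together.

```python
import jax
import jax.numpy as jnp


def split_with_forces_sketch(geometries, forces, energies, n_train, key, n_val=0):
    """Hypothetical stand-in for the documented split (not molpipx's code)."""
    n_total = geometries.shape[0]
    # One permutation shared by all three arrays keeps samples aligned.
    perm = jax.random.permutation(key, n_total)
    train_idx = perm[:n_train]
    # n_val of 0 or None: use every sample left after the training set.
    if not n_val:
        val_idx = perm[n_train:]
    else:
        val_idx = perm[n_train:n_train + n_val]
    train = (geometries[train_idx], forces[train_idx], energies[train_idx])
    val = (geometries[val_idx], forces[val_idx], energies[val_idx])
    return train, val


key = jax.random.PRNGKey(1)
# Toy data: 10 samples, 3 atoms; forces are set to -geometries for easy checking.
geometries = jnp.arange(10 * 3 * 3, dtype=jnp.float32).reshape(10, 3, 3)
forces = -geometries
energies = jnp.arange(10, dtype=jnp.float32).reshape(10, 1)

(X_tr, G_tr, y_tr), (X_val, G_val, y_val) = split_with_forces_sketch(
    geometries, forces, energies, 6, key, n_val=3
)
print(X_tr.shape, G_tr.shape, y_tr.shape)
# (6, 3, 3) (6, 3, 3) (6, 1)
print(X_val.shape, G_val.shape, y_val.shape)
# (3, 3, 3) (3, 3, 3) (3, 1)
```

Here N = 6 and Nval = 3, so one of the ten samples ends up in neither set. That is the practical difference from leaving Nval at 0.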