Gradient Utilities

Functions

get_energy_and_forces

Compute the energy and the forces for a Flax-based PIP model using JAX's value_and_grad function.

get_forces

Compute the forces for a Flax-based PIP model using reverse-mode differentiation.

get_forces_gp

Computes forces (gradients) for a Gaussian Process model.

get_pip_grad

Computes the gradients of the PIP vectors with respect to the input coordinates.

molpipx.utils_gradients.get_energy_and_forces(model: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']

Compute the energy and the forces for a Flax-based PIP model using JAX's value_and_grad function.

Parameters:
  • model (Callable) – The energy model function (e.g., energy_model.apply).

  • x (Array) – Input geometries with shape (Batch, Na, 3), where Na is the number of atoms.

  • params (PyTree) – The parameters of the model.

Returns:

A tuple containing:
  • Energy values, shape (Batch, 1)

  • Forces (gradients), shape (Batch, Na, 3)

Return type:

Tuple[Array, Array]
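A minimal sketch of the value_and_grad pattern described above, assuming the model follows the Flax apply convention model(params, x) and returns energies of shape (Batch, 1). The names toy_energy (a pairwise harmonic potential) and energy_and_forces_sketch are hypothetical stand-ins, not part of molpipx. Forces are returned as the negative energy gradient, the usual physical convention; the sign convention of molpipx itself is not stated in this entry.

```python
import jax
import jax.numpy as jnp


def toy_energy(params, x):
    """Hypothetical stand-in for a Flax `model.apply`: pairwise harmonic energy.

    x: (Batch, Na, 3) geometries -> energies of shape (Batch, 1).
    """
    na = x.shape[1]
    diff = x[:, :, None, :] - x[:, None, :, :]               # (Batch, Na, Na, 3)
    # Pad the diagonal with 1 so sqrt is never evaluated at 0 (infinite derivative there).
    r = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + jnp.eye(na))  # (Batch, Na, Na)
    off_diag = 1.0 - jnp.eye(na)
    # 0.25 = 0.5 (harmonic prefactor) * 0.5 (each pair appears twice for i != j).
    e = 0.25 * params["k"] * jnp.sum(off_diag * (r - params["r0"]) ** 2, axis=(1, 2))
    return e[:, None]


def energy_and_forces_sketch(model, x, params):
    """Return (energies (Batch, 1), forces (Batch, Na, 3)) via jax.value_and_grad."""

    def total_energy(xs):
        e = model(params, xs)
        # Differentiate the scalar batch total; carry per-geometry energies as aux.
        return jnp.sum(e), e

    (_, energies), grad_x = jax.value_and_grad(total_energy, has_aux=True)(x)
    return energies, -grad_x  # force = -dE/dx


x = jnp.array([[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])  # one geometry, two atoms
energies, forces = energy_and_forces_sketch(toy_energy, x, {"k": 1.0, "r0": 1.0})
```

Summing the batch energies before differentiating is valid because each energy depends only on its own geometry, so the gradient of the sum equals the stacked per-geometry gradients; has_aux=True keeps the unsummed energies available from the same forward pass.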

molpipx.utils_gradients.get_forces(model: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']

Compute the forces for a Flax-based PIP model using reverse-mode differentiation.

Parameters:
  • model (Callable) – The energy model function (e.g., energy_model.apply).

  • x (Array) – Input geometries with shape (Batch, Na, 3).

  • params (PyTree) – The parameters of the energy model.

Returns:

Forces (gradients) for each atom, shape (Batch, Na, 3).

Return type:

Array
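To make the reverse-mode step explicit, the sketch below uses jax.vjp directly rather than jax.grad; both run one reverse pass over the model. It assumes the same model(params, x) convention with energies of shape (Batch, 1). The model pair_repulsion and the helper forces_sketch are hypothetical illustrations, not molpipx code, and forces are again taken as the negative gradient.

```python
import jax
import jax.numpy as jnp


def pair_repulsion(params, x):
    """Hypothetical energy model: E = a * sum_{i<j} exp(-r_ij), returned as (Batch, 1)."""
    na = x.shape[1]
    diff = x[:, :, None, :] - x[:, None, :, :]               # (Batch, Na, Na, 3)
    r = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + jnp.eye(na))  # diagonal padded to 1
    off_diag = 1.0 - jnp.eye(na)
    # 0.5 corrects for counting each pair twice in the full i != j sum.
    e = 0.5 * params["a"] * jnp.sum(off_diag * jnp.exp(-r), axis=(1, 2))
    return e[:, None]


def forces_sketch(model, x, params):
    """Forces (Batch, Na, 3) from one reverse-mode pass (vector-Jacobian product)."""
    energies, vjp_fn = jax.vjp(lambda xs: model(params, xs), x)
    # Pulling back a cotangent of ones gives d(sum of energies)/dx; since each
    # energy depends only on its own geometry, this is the per-geometry gradient.
    (grad_x,) = vjp_fn(jnp.ones_like(energies))
    return -grad_x  # force = -dE/dx


x = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]])
forces = forces_sketch(pair_repulsion, x, {"a": 1.0})
```

Reverse mode suits this problem because the output per geometry is a single scalar energy while the input has Na * 3 coordinates, so one backward pass yields the full gradient.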

molpipx.utils_gradients.get_forces_gp(train_data, gp_model, x)

Computes forces (gradients) for a Gaussian Process model.

Parameters:
  • train_data (Dataset) – The training dataset used by the GP.

  • gp_model (GP) – The GPJax model instance.

  • x (Array) – Input geometries at which to evaluate the GP.

Returns:

A tuple containing:
  • Forces (gradients of the predictive mean with respect to x)

  • A tuple of (predictive_mean, predictive_std)

Return type:

Tuple
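The idea behind this function can be sketched without GPJax: differentiate the GP predictive mean with respect to the test inputs. The sketch below swaps in exact GP regression with a squared-exponential kernel written directly in JAX, and treats each row of x as a flattened geometry or descriptor vector (an assumption, since the expected shape of x is not documented here). The helpers rbf, gp_predict and gp_forces_sketch are hypothetical and do not reproduce the GPJax API; the return layout (forces, (mean, std)) mirrors the entry above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf(a, b, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel matrix between rows of a (n, d) and b (m, d)."""
    sq = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq / lengthscale ** 2)


def gp_predict(x_train, y_train, x, noise=1e-4):
    """Exact GP posterior mean and standard deviation at the rows of x."""
    chol = jnp.linalg.cholesky(rbf(x_train, x_train) + noise * jnp.eye(x_train.shape[0]))
    k_xt = rbf(x, x_train)                                  # (n_test, n_train)
    mean = k_xt @ cho_solve((chol, True), y_train)          # K_xt (K + noise I)^-1 y
    v = solve_triangular(chol, k_xt.T, lower=True)          # L^-1 K_tx
    var = jnp.diag(rbf(x, x)) - jnp.sum(v ** 2, axis=0)
    return mean, jnp.sqrt(jnp.maximum(var, 0.0))


def gp_forces_sketch(x_train, y_train, x):
    """Return (forces, (mean, std)), with forces = -d(mean)/dx for each row of x."""

    def total_mean(xs):
        mean, _ = gp_predict(x_train, y_train, xs)
        return jnp.sum(mean)  # each mean depends only on its own row of xs

    grad_x = jax.grad(total_mean)(x)
    return -grad_x, gp_predict(x_train, y_train, x)


x_train = jnp.array([[0.0], [2.0], [4.0]])
y_train = jnp.array([0.0, 1.0, 0.0])
forces, (mean, std) = gp_forces_sketch(x_train, y_train, jnp.array([[1.0], [100.0]]))
```

Only the predictive mean is differentiated; the standard deviation is returned alongside it as an uncertainty estimate, matching the (predictive_mean, predictive_std) pair described above.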

molpipx.utils_gradients.get_pip_grad(model_pip: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params_pip: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']

Computes the gradients of the PIP vectors with respect to the input coordinates.

Parameters:
  • model_pip (Callable) – The PIP model function (e.g., model.apply).

  • x (Array) – Input geometries with shape (Batch, Na, 3).

  • params_pip (PyTree) – The parameters of the PIP model.

Returns:

Gradients of the PIP vectors with respect to x, shape (Batch, N_pips * Na * 3), where N_pips is the number of PIP features.

Return type:

Array
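A sketch of how such a flattened PIP Jacobian can be formed with jax.jacrev and jax.vmap, assuming model_pip(params, x) maps (Batch, Na, 3) to (Batch, N_pips) and that the flattening order is (N_pips, Na, 3) per geometry (an assumption consistent with the shape above). The toy_pip layer, built from Morse variables and two permutationally invariant polynomials, and pip_grad_sketch are hypothetical illustrations, not molpipx's PIP basis.

```python
import jax
import jax.numpy as jnp


def toy_pip(params, x):
    """Hypothetical PIP layer: x (Batch, Na, 3) -> (Batch, 2) invariant features.

    Uses Morse variables y_ij = exp(-r_ij / lam) over pairs i < j and the two
    permutationally invariant polynomials sum(y) and sum(y**2).
    """
    na = x.shape[1]
    diff = x[:, :, None, :] - x[:, None, :, :]
    r = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + jnp.eye(na))  # diagonal padded to 1
    i, j = jnp.triu_indices(na, k=1)
    y = jnp.exp(-r[:, i, j] / params["lam"])                  # (Batch, n_pairs)
    return jnp.stack([y.sum(axis=1), (y ** 2).sum(axis=1)], axis=1)


def pip_grad_sketch(model_pip, x, params_pip):
    """Jacobian of the PIP vector w.r.t. coordinates, flattened to (Batch, N_pips*Na*3)."""

    def single(xi):
        # Re-add a batch axis of size 1 so the batched model can be reused per geometry.
        return model_pip(params_pip, xi[None])[0]            # (N_pips,)

    jac = jax.vmap(jax.jacrev(single))(x)                     # (Batch, N_pips, Na, 3)
    return jac.reshape(x.shape[0], -1)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 3))  # 2 geometries, 3 atoms
g = pip_grad_sketch(toy_pip, x, {"lam": 1.0})
```

Reverse mode (jacrev) costs one backward pass per PIP feature, while forward mode (jacfwd) costs one pass per coordinate; which is cheaper depends on whether N_pips or Na * 3 is smaller.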