Gradient Utilities
Functions
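get_energy_and_forces(model, x, params) – Compute the energy and the forces for a Flax-based PIP model using value_and_grad.
get_forces(model, x, params) – Compute the forces for a Flax-based PIP model using reverse-mode differentiation.
get_forces_gp(train_data, gp_model, x) – Compute forces (gradients) for a Gaussian Process model.
get_pip_grad(model_pip, x, params_pip) – Compute the gradients of the PIP vectors with respect to the input coordinates.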
- molpipx.utils_gradients.get_energy_and_forces(model: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']
Compute the energy and the forces for a Flax-based PIP model using the value_and_grad function.
- Parameters:
model (Callable) – The energy model function.
x (Array) – Input geometries with shape (Batch, Na, 3).
params (PyTree) – The parameters of the model.
- Returns:
A tuple containing:
Energy values, shape (Batch, 1)
Forces (gradients), shape (Batch, Na, 3)
- Return type:
Tuple[Array, Array]
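For illustration, the computation described above can be sketched directly with `jax.value_and_grad`. This is a minimal sketch under stated assumptions and is not molpipx's implementation. `toy_energy_apply` is a hypothetical stand-in for a Flax `model.apply` that maps one geometry of shape (Na, 3) to a scalar energy, and `energy_and_forces` is also a hypothetical name. The batch is handled with `jax.vmap`. The sketch returns the raw gradient dE/dx. Physical forces are the negative of that gradient, and the entry above uses "forces" and "gradients" interchangeably.

```python
import jax
import jax.numpy as jnp


def toy_energy_apply(params, x):
    """Hypothetical stand-in for a Flax `model.apply` on a single geometry.

    x has shape (Na, 3); the energy is a harmonic sum over all ordered atom pairs.
    """
    diff = x[:, None, :] - x[None, :, :]        # (Na, Na, 3) pair displacements
    r2 = jnp.sum(diff ** 2, axis=-1)            # (Na, Na) squared distances
    return 0.5 * params["k"] * jnp.sum(r2)      # scalar energy


def energy_and_forces(model, x, params):
    """Sketch: energies and coordinate gradients for a batch of geometries.

    x: (Batch, Na, 3). Returns (energies (Batch, 1), dE/dx (Batch, Na, 3)).
    """
    # Differentiate with respect to the geometry (argument 1), not the parameters.
    value_and_grad_single = jax.value_and_grad(model, argnums=1)
    # Share params across the batch (None) and map over the geometries (axis 0).
    energies, grads = jax.vmap(value_and_grad_single, in_axes=(None, 0))(params, x)
    return energies[:, None], grads


params = {"k": 2.0}
x = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
               [[0.0, 0.0, 0.0], [0.0, 2.0, 0.0]]])
energies, grads = energy_and_forces(toy_energy_apply, x, params)
print(energies.shape, grads.shape)  # (2, 1) (2, 2, 3)
```

The key choice is `argnums=1`. Differentiating with respect to `x` rather than `params` turns the usual training gradient into a coordinate gradient.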
- molpipx.utils_gradients.get_forces(model: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']
Compute the forces for a Flax-based PIP model using reverse-mode differentiation.
- Parameters:
model (Callable) – The energy model function (e.g., energy_model.apply).
x (Array) – Input geometries with shape (Batch, Na, 3).
params (PyTree) – The parameters of the energy model.
- Returns:
Forces (gradients) for each atom, shape (Batch, Na, 3).
- Return type:
Array
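Here is a reverse-mode sketch of the same idea. It assumes the model accepts a whole batch and returns energies of shape (Batch, 1). The names `toy_batched_apply` and `forces_reverse_mode` are hypothetical and not part of molpipx. Each geometry's energy depends only on its own coordinates. Because of that, a single `jax.grad` of the summed batch energy yields every per-geometry gradient in one backward pass.

```python
import jax
import jax.numpy as jnp


def toy_batched_apply(params, x):
    """Hypothetical batched energy model: x (Batch, Na, 3) -> energies (Batch, 1)."""
    r2 = jnp.sum(x ** 2, axis=(1, 2))           # (Batch,) sum of squared coordinates
    return (params["w"] * r2)[:, None]          # (Batch, 1)


def forces_reverse_mode(model, x, params):
    """Sketch: dE/dx for every geometry in one reverse-mode pass.

    Geometries do not interact, so the gradient of the summed batch energy
    with respect to x equals the stack of per-geometry gradients.
    """
    def total_energy(xb):
        return jnp.sum(model(params, xb))       # scalar output, required by jax.grad
    return jax.grad(total_energy)(x)            # (Batch, Na, 3)


params = {"w": 3.0}
x = jnp.arange(12.0).reshape(2, 2, 3)
grads = forces_reverse_mode(toy_batched_apply, x, params)
print(grads.shape)  # (2, 2, 3)
```

This shortcut is only valid when nothing couples the geometries across the batch axis, such as batch-level normalization inside the model.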
- molpipx.utils_gradients.get_forces_gp(train_data, gp_model, x)
Computes forces (gradients) for a Gaussian Process model.
- Parameters:
train_data (Dataset) – The training dataset used by the GP.
gp_model (GP) – The GPJax model instance.
x (Array) – Input geometries.
- Returns:
A tuple containing:
Forces (gradients of the predictive mean)
A tuple of (predictive_mean, predictive_std)
- Return type:
Tuple
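GPJax's own API is not reproduced here. The sketch below instead illustrates the same idea, forces as gradients of the GP predictive mean, with a self-contained exact GP posterior written by hand in JAX using an RBF kernel. All names are hypothetical: `rbf`, `gp_forces_sketch`, `X_train`, `y_train` and `noise`. The GP inputs are flattened Cartesian coordinates purely for brevity. The return layout only mirrors the (forces, (mean, std)) tuple described above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf(a, b, lengthscale=1.0):
    """RBF kernel matrix between the rows of a (n, d) and b (m, d)."""
    d2 = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * d2 / lengthscale ** 2)


def gp_forces_sketch(X_train, y_train, x, noise=1e-2):
    """Sketch: gradients of an exact GP posterior mean with respect to geometries.

    X_train: (n, Na, 3) training geometries, y_train: (n,) energies,
    x: (Batch, Na, 3) query geometries. Returns (grads, (mean, std)).
    """
    Xf = X_train.reshape(X_train.shape[0], -1)          # flatten to (n, Na * 3)
    K = rbf(Xf, Xf) + noise * jnp.eye(Xf.shape[0])       # noisy training covariance
    L = jnp.linalg.cholesky(K)                           # lower-triangular factor
    alpha = cho_solve((L, True), y_train)                # (K + noise * I)^-1 y

    def mean_single(xi):
        # Posterior mean for one geometry xi of shape (Na, 3).
        k_star = rbf(xi.reshape(1, -1), Xf)[0]           # (n,) cross-covariances
        return k_star @ alpha

    def std_single(xi):
        # Posterior standard deviation of the latent function (noise excluded).
        k_star = rbf(xi.reshape(1, -1), Xf)[0]
        v = solve_triangular(L, k_star, lower=True)
        var = 1.0 - v @ v                                # k(xi, xi) = 1 for this RBF
        return jnp.sqrt(jnp.maximum(var, 0.0))

    mean = jax.vmap(mean_single)(x)                      # (Batch,)
    std = jax.vmap(std_single)(x)                        # (Batch,)
    grads = jax.vmap(jax.grad(mean_single))(x)           # (Batch, Na, 3)
    return grads, (mean, std)


X_train = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                     [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]],
                     [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]])
y_train = jnp.array([1.0, 0.5, 0.2])
x = jnp.array([[[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]],
               [[0.0, 0.0, 0.0], [1.8, 0.0, 0.0]]])
grads, (mean, std) = gp_forces_sketch(X_train, y_train, x)
print(grads.shape, mean.shape, std.shape)  # (2, 2, 3) (2,) (2,)
```

The training system is factored once with a Cholesky decomposition. Each query geometry then needs only a kernel row and a dot product, so `jax.grad` differentiates a cheap function of `xi`.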
- molpipx.utils_gradients.get_pip_grad(model_pip: Callable, x: jaxtyping.Float[jaxtyping.Array, '...'], params_pip: jaxtyping.PyTree) -> jaxtyping.Float[jaxtyping.Array, '...']
Computes the gradients of the PIP vectors with respect to the input coordinates.
- Parameters:
model_pip (Callable) – The PIP model function (e.g., model.apply).
x (Array) – Input geometries with shape (Batch, Na, 3).
params_pip (PyTree) – The parameters of the PIP model.
- Returns:
Gradients of the PIP model, shape (Batch, N_pips * Na * 3).
- Return type:
Array
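The Jacobian described above can be sketched with `jax.jacrev`. The sketch assumes the PIP model maps one geometry of shape (Na, 3) to a vector of N_pips features. `toy_pip_apply` is hypothetical: it sums Morse variables exp(-r_ij / a) into two invariant polynomials. `pip_grad_sketch` is also hypothetical. The flattening order (N_pips, Na, 3) is an assumption, chosen to match the documented shape (Batch, N_pips * Na * 3).

```python
import jax
import jax.numpy as jnp


def toy_pip_apply(params, x):
    """Hypothetical PIP-like feature map for one geometry x of shape (Na, 3).

    Morse variables y_ij = exp(-r_ij / a) over the pairs i < j are combined into
    two permutationally invariant polynomials, sum(y) and sum(y**2).
    """
    i, j = jnp.triu_indices(x.shape[0], k=1)             # all atom pairs i < j
    r = jnp.linalg.norm(x[i] - x[j], axis=-1)            # (n_pairs,) distances
    y = jnp.exp(-r / params["a"])                        # Morse variables
    return jnp.stack([jnp.sum(y), jnp.sum(y ** 2)])      # (N_pips,) with N_pips = 2


def pip_grad_sketch(model_pip, x, params_pip):
    """Sketch: Jacobian of the PIP vector w.r.t. the coordinates, per geometry.

    x: (Batch, Na, 3). Returns an array of shape (Batch, N_pips * Na * 3).
    """
    jac_single = jax.jacrev(model_pip, argnums=1)        # (N_pips, Na, 3) per geometry
    jac = jax.vmap(jac_single, in_axes=(None, 0))(params_pip, x)
    # jac has shape (Batch, N_pips, Na, 3); flatten everything after the batch axis.
    return jac.reshape(x.shape[0], -1)


x = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
               [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 0.0, 0.9]]])
grads = pip_grad_sketch(toy_pip_apply, x, {"a": 1.0})
print(grads.shape)  # (2, 18)
```

Reverse mode (`jacrev`) costs roughly one backward pass per output feature. Forward mode (`jax.jacfwd`) costs one pass per input coordinate, so `jacfwd` can be the cheaper choice when N_pips exceeds Na * 3.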