Utilities
Functions
all_distances(xi) – Computes all pairwise distances between atoms in a molecule.
flax_params(w, params) – Updates the weights of the first Dense layer in a Flax PyTree.
flax_params_aniso(l, params) – Updates the length scale parameters for an Anisotropic PIP model.
mae_loss(predictions, targets) – Computes the Mean Absolute Error (MAE) loss.
morse_variables(x, l) – Computes Morse-like variables using a single length scale parameter.
mse_loss(predictions, targets) – Computes the Mean Squared Error (MSE) loss.
Computes the inverse of the softplus function.
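The softplus inverse listed above has a closed form. Since softplus(x) = log(1 + exp(x)), its inverse for y > 0 is log(exp(y) - 1), which can be rewritten as y + log(-expm1(-y)) to avoid overflowing exp(y) for large y. The function name `softplus_inverse_sketch` is hypothetical and is not the molpipx API. A common reason to need this inverse is to initialize a raw parameter so that softplus(raw) equals a chosen positive value.

```python
import jax
import jax.numpy as jnp


def softplus_inverse_sketch(y):
    """Inverse of softplus(x) = log(1 + exp(x)); valid for y > 0.

    Uses log(exp(y) - 1) = y + log(1 - exp(-y)) = y + log(-expm1(-y)),
    which avoids overflow of exp(y) for large y and keeps precision for small y.
    (Hypothetical helper, not the molpipx implementation.)
    """
    return y + jnp.log(-jnp.expm1(-y))


x = jnp.array([-3.0, 0.0, 2.5, 30.0])
# Round trip: applying the inverse to softplus(x) should recover x (up to float32 error).
print(softplus_inverse_sketch(jax.nn.softplus(x)))
```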
- molpipx.utils.all_distances(xi: jaxtyping.Float[jaxtyping.Array, '...']) -> jaxtyping.Float[jaxtyping.Array, '...']
Computes all pairwise distances between atoms in a molecule.
Calculates the Euclidean distance (L2 norm) between every pair of atoms in the input geometry. It returns the strictly upper triangular part of the distance matrix (pairs i < j), flattened in lexicographical order, which gives N_atoms(N_atoms - 1)/2 values in total.
- Parameters:
xi (Array) – Cartesian coordinates of the atoms, shape (N_atoms, 3).
- Returns:
A flattened array containing all unique pairwise distances.
- Return type:
Array
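The sketch below shows one way to build the flattened upper-triangular distance vector described above in JAX. It uses `jnp.triu_indices` with `k=1` to enumerate the pairs i < j in lexicographical order. The name `all_distances_sketch` is hypothetical, and the code is an illustration of the technique rather than the molpipx source.

```python
import jax.numpy as jnp


def all_distances_sketch(xi):
    """Pairwise distances for coordinates of shape (N_atoms, 3).

    Returns the N*(N-1)/2 entries above the diagonal of the distance
    matrix, flattened row by row: (0,1), (0,2), ..., (1,2), ...
    (Hypothetical helper, not the molpipx implementation.)
    """
    n = xi.shape[0]
    i, j = jnp.triu_indices(n, k=1)  # index pairs with i < j, in row-major order
    diff = xi[i] - xi[j]             # (n_pairs, 3) displacement vectors
    return jnp.linalg.norm(diff, axis=-1)


# A 3-4-5 right triangle: distances for pairs (0,1), (0,2), (1,2).
xyz = jnp.array([[0.0, 0.0, 0.0],
                 [3.0, 0.0, 0.0],
                 [0.0, 4.0, 0.0]])
print(all_distances_sketch(xyz))  # [3. 4. 5.]
```

Indexing only the i < j pairs skips the zero diagonal entirely. That matters under `jax.grad`, because the gradient of a norm at zero is undefined.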
- molpipx.utils.flax_params(w: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.PyTree
Updates the weights of the first Dense layer in a Flax PyTree.
- Parameters:
w (Array) – Array containing the new linear weights (e.g., from a least-squares solution).
params (PyTree) – The existing Flax parameter PyTree.
- Returns:
The updated parameter PyTree.
- Return type:
PyTree
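The following is a minimal sketch of swapping least-squares weights into a Flax parameter tree. It assumes the layout `{'params': {'Dense_0': {'kernel': ...}}}`, where `Dense_0` is Flax's default auto-generated name for the first `nn.Dense` submodule. The actual key path used by molpipx is not stated here, and `flax_params_sketch` is a hypothetical name. The tree is copied along the updated path rather than mutated, in keeping with JAX's immutable style.

```python
import jax.numpy as jnp


def flax_params_sketch(w, params, layer="Dense_0"):
    """Return a copy of `params` with the kernel of `layer` replaced by `w`.

    ASSUMPTION: layout {'params': {layer: {'kernel': ..., ...}}}; the layer
    name 'Dense_0' is Flax's default, not confirmed for molpipx models.
    """
    inner = dict(params["params"])
    dense = dict(inner[layer])
    # Reshape so a flat least-squares solution fits an (in_features, 1) kernel.
    dense["kernel"] = jnp.reshape(w, dense["kernel"].shape)
    inner[layer] = dense
    return {**params, "params": inner}


params = {"params": {"Dense_0": {"kernel": jnp.zeros((4, 1)), "bias": jnp.zeros((1,))}}}
w = jnp.arange(4.0)  # stand-in for a least-squares solution
new = flax_params_sketch(w, params)
print(new["params"]["Dense_0"]["kernel"].ravel())  # [0. 1. 2. 3.]
```

Other entries, such as the bias, are carried over unchanged. The original `params` is left intact.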
- molpipx.utils.flax_params_aniso(l: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.PyTree
Updates the length scale parameters for an Anisotropic PIP model.
Warning
This function assumes a specific Flax model structure (Anisotropic PIP).
- Parameters:
l (Array) – Array containing the new length scale parameters.
params (PyTree) – The existing Flax parameter PyTree.
- Returns:
The updated parameter PyTree.
- Return type:
PyTree
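The warning above says the update depends on a fixed model structure. The sketch below makes that dependency explicit with a generic helper that replaces one leaf at a given key path and checks that the new array has the old one's shape. The helper `set_leaf` and the key path `("params", "AnisoPIP_0", "lambda")` are hypothetical placeholders. They are not the actual module or parameter names in the molpipx Anisotropic PIP model.

```python
import jax.numpy as jnp


def set_leaf(tree, path, value):
    """Return a copy of nested dict `tree` with the leaf at `path` set to `value`.

    Only the dicts along `path` are copied; other branches are shared.
    Raises ValueError if the new value's shape differs from the old leaf's.
    (Hypothetical helper, not part of molpipx.)
    """
    if len(path) == 1:
        old = tree[path[0]]
        if jnp.shape(old) != jnp.shape(value):
            raise ValueError(
                f"shape mismatch at {path[0]!r}: {jnp.shape(old)} vs {jnp.shape(value)}"
            )
        return {**tree, path[0]: value}
    head, *rest = path
    return {**tree, head: set_leaf(tree[head], rest, value)}


# HYPOTHETICAL parameter layout for an anisotropic model.
params = {"params": {"AnisoPIP_0": {"lambda": jnp.ones(3)},
                     "Dense_0": {"kernel": jnp.zeros((5, 1))}}}
new_l = jnp.array([0.5, 1.0, 2.0])
new = set_leaf(params, ("params", "AnisoPIP_0", "lambda"), new_l)
print(new["params"]["AnisoPIP_0"]["lambda"])  # [0.5 1.  2. ]
```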
- molpipx.utils.mae_loss(predictions: jaxtyping.Float[jaxtyping.Array, 'dim1'], targets: jaxtyping.Float[jaxtyping.Array, 'dim1']) -> jaxtyping.Float
Computes the Mean Absolute Error (MAE) loss.
- Parameters:
predictions (Array) – The predicted values.
targets (Array) – The ground truth values.
- Returns:
The mean absolute error.
- Return type:
Float
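For reference, the MAE is the mean of |prediction - target| over all elements. Below is a minimal JAX sketch of that formula. The name `mae_loss_sketch` is hypothetical, and the sketch does not claim to reproduce the molpipx implementation.

```python
import jax.numpy as jnp


def mae_loss_sketch(predictions, targets):
    # Mean over all elements of |prediction - target|.
    return jnp.mean(jnp.abs(predictions - targets))


preds = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 3.0, 5.0])
print(mae_loss_sketch(preds, targets))  # mean of [0, 1, 2] = 1.0
```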
- molpipx.utils.morse_variables(x: jaxtyping.Float[jaxtyping.Array, 'dim1'], l: jaxtyping.Float[jaxtyping.Array, '']) -> jaxtyping.Float[jaxtyping.Array, 'dim1']
Computes Morse-like variables using a single length scale parameter.
- Parameters:
x (Array) – Cartesian coordinates of the atoms, shape (N_atoms, 3).
l (float) – Length scale parameter (decay rate).
- Returns:
The computed Morse variables for all pairwise distances.
- Return type:
Array
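In the PIP literature, Morse variables are commonly defined as y_ij = exp(-r_ij / a), where r_ij is an interatomic distance and a is a range parameter. The sketch below assumes that form with l in the role of a, and assumes x holds (N_atoms, 3) coordinates as the parameter list states. The "decay rate" wording could also mean exp(-l * r_ij), so treat the exact parametrization as an assumption. `morse_variables_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def morse_variables_sketch(x, l):
    """Morse-like variables y_ij = exp(-r_ij / l) for all atom pairs i < j.

    ASSUMPTION: the exp(-r / l) form; molpipx may parametrize the decay
    differently (e.g. exp(-l * r)). Hypothetical helper, not the molpipx source.
    """
    i, j = jnp.triu_indices(x.shape[0], k=1)
    r = jnp.linalg.norm(x[i] - x[j], axis=-1)  # pairwise distances, i < j
    return jnp.exp(-r / l)


xyz = jnp.array([[0.0, 0.0, 0.0],
                 [3.0, 0.0, 0.0],
                 [0.0, 4.0, 0.0]])
y = morse_variables_sketch(xyz, 2.0)  # exp(-[3, 4, 5] / 2)
print(y)
```

Each variable lies in (0, 1] and decays toward 0 as the pair separates. This keeps the polynomial inputs bounded at large distances.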
- molpipx.utils.mse_loss(predictions: jaxtyping.Float[jaxtyping.Array, 'dim1'], targets: jaxtyping.Float[jaxtyping.Array, 'dim1']) -> jaxtyping.Float
Computes the Mean Squared Error (MSE) loss.
- Parameters:
predictions (Array) – The predicted values.
targets (Array) – The ground truth values.
- Returns:
The mean squared error.
- Return type:
Float
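The MSE is the mean of (prediction - target)^2 over all elements. Because it is differentiable everywhere, `jax.grad` applies to it directly, and its gradient with respect to the predictions is 2(prediction - target)/n. The minimal sketch below shows both the loss value and that gradient. The name `mse_loss_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def mse_loss_sketch(predictions, targets):
    # Mean over all elements of (prediction - target)**2.
    return jnp.mean(jnp.square(predictions - targets))


preds = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 3.0, 5.0])
print(mse_loss_sketch(preds, targets))  # mean of [0, 1, 4] = 5/3

# Gradient w.r.t. predictions (first argument): 2 * (predictions - targets) / n
print(jax.grad(mse_loss_sketch)(preds, targets))
```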