Utilities

Functions

  • all_distances – Computes all pairwise distances between atoms in a molecule.

  • flax_params – Updates the weights of the first Dense layer in a Flax PyTree.

  • flax_params_aniso – Updates the length scale parameters for an Anisotropic PIP model.

  • mae_loss – Computes the Mean Absolute Error (MAE) loss.

  • morse_variables – Computes Morse-like variables using a single length scale parameter.

  • mse_loss – Computes the Mean Squared Error (MSE) loss.

  • softplus_inverse – Computes the inverse of the softplus function.

molpipx.utils.all_distances(xi: jaxtyping.Float[jaxtyping.Array, '...']) -> jaxtyping.Float[jaxtyping.Array, '...']

Computes all pairwise distances between atoms in a molecule.

Calculates the Euclidean distance (L2 norm) between every pair of atoms in the input geometry and returns the strictly upper-triangular part of the distance matrix (pairs i < j), flattened in lexicographic order.

Parameters:

xi (Array) – Cartesian coordinates of the atoms, shape (N_atoms, 3).

Returns:

A flattened array of length N_atoms(N_atoms - 1)/2 containing all unique pairwise distances.

Return type:

Array
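A minimal sketch of the same computation in JAX, assuming the pairs are ordered row-major over the upper triangle, i.e. (0, 1), (0, 2), (1, 2) for three atoms. `pairwise_distances` is a hypothetical stand-in written for illustration, not the molpipx source.

```python
import jax.numpy as jnp


def pairwise_distances(xi):
    """Hypothetical stand-in for all_distances: unique interatomic distances.

    xi: Cartesian coordinates with shape (N_atoms, 3).
    Returns an array of shape (N_atoms * (N_atoms - 1) // 2,).
    """
    n_atoms = xi.shape[0]                    # static shape, so this also works under jax.jit
    i, j = jnp.triu_indices(n_atoms, k=1)    # index pairs with i < j, in row-major order
    diff = xi[i] - xi[j]                     # displacement vector for each pair
    return jnp.linalg.norm(diff, axis=-1)    # Euclidean (L2) norm per pair


xyz = jnp.array([[0.0, 0.0, 0.0],
                 [3.0, 4.0, 0.0],
                 [0.0, 0.0, 1.0]])
print(pairwise_distances(xyz))  # distances for pairs (0, 1), (0, 2), (1, 2)
```

For a batch of geometries with shape (N_geoms, N_atoms, 3), `jax.vmap(pairwise_distances)` maps the function over the leading axis.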

molpipx.utils.flax_params(w: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.PyTree

Updates the weights of the first Dense layer in a Flax PyTree.

Parameters:
  • w (Array) – Array containing the new linear weights (e.g., from a least-squares solution).

  • params (PyTree) – The existing Flax parameter PyTree.

Returns:

The updated parameter PyTree.

Return type:

PyTree
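The kind of update this performs can be sketched as a pure function on a nested parameter dictionary. The layer name `Dense_0`, the `kernel` key, and the reshape to the kernel's existing (in_features, out_features) shape are assumptions based on Flax's default naming for a `Dense` layer in a compact module; `set_first_dense_kernel` is hypothetical and not part of molpipx.

```python
import jax.numpy as jnp


def set_first_dense_kernel(w, params, layer_name="Dense_0"):
    """Hypothetical helper: return a copy of `params` whose Dense kernel is `w`.

    Assumes the Flax layout {"params": {layer_name: {"kernel": ...}}}.
    The input PyTree is not modified; only the dicts on the path are copied.
    """
    inner = dict(params["params"])                            # copy the "params" collection
    layer = dict(inner[layer_name])                           # copy the layer's variables
    layer["kernel"] = jnp.reshape(w, layer["kernel"].shape)   # e.g. (n,) -> (n, 1)
    inner[layer_name] = layer
    return {**params, "params": inner}


params = {"params": {"Dense_0": {"kernel": jnp.zeros((3, 1))}}}
w_lstsq = jnp.array([1.0, 2.0, 3.0])   # e.g. a least-squares solution for 3 features
new_params = set_first_dense_kernel(w_lstsq, params)
```

Returning a new tree instead of mutating in place matches how JAX code usually treats parameters as immutable values.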

molpipx.utils.flax_params_aniso(l: jaxtyping.Float[jaxtyping.Array, '...'], params: jaxtyping.PyTree) -> jaxtyping.PyTree

Updates the length scale parameters for an Anisotropic PIP model.

Warning

This function assumes a specific Flax model structure (Anisotropic PIP).

Parameters:
  • l (Array) – Array containing the new length scale parameters.

  • params (PyTree) – The existing Flax parameter PyTree.

Returns:

The updated parameter PyTree.

Return type:

PyTree
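Because the update depends on where the length scales live in the PyTree, a generic path-based setter illustrates the idea. The key names `PIPAniso_0` and `lengthscales` below are hypothetical placeholders, not the names used by the molpipx Anisotropic PIP module, and `set_leaf` is likewise an illustrative helper.

```python
import jax.numpy as jnp


def set_leaf(tree, path, value):
    """Hypothetical helper: return a copy of a nested dict with tree[path] replaced.

    Only the dictionaries along `path` are copied; every other leaf is shared.
    """
    if not path:
        return value
    head, *rest = path
    updated = dict(tree)                               # shallow copy of this level
    updated[head] = set_leaf(tree[head], rest, value)  # recurse into the next level
    return updated


# Hypothetical layout; the real key names depend on the Anisotropic PIP module.
params = {"params": {"PIPAniso_0": {"lengthscales": jnp.ones((3,))},
                     "Dense_0": {"kernel": jnp.zeros((10, 1))}}}
path = ("params", "PIPAniso_0", "lengthscales")

new_l = jnp.array([0.8, 1.2, 1.5])   # illustrative length scale values
new_params = set_leaf(params, path, new_l)
```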

molpipx.utils.mae_loss(predictions: jaxtyping.Float[jaxtyping.Array, 'dim1'], targets: jaxtyping.Float[jaxtyping.Array, 'dim1']) -> jaxtyping.Float

Computes the Mean Absolute Error (MAE) loss.

Parameters:
  • predictions (Array) – The predicted values.

  • targets (Array) – The ground truth values.

Returns:

The mean absolute error.

Return type:

Float
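The quantity in question is MAE = (1/n) Σ_i |predictions_i - targets_i|. A short sketch, assuming an unweighted mean over all elements (the molpipx implementation is not reproduced here):

```python
import jax.numpy as jnp


def mae(predictions, targets):
    """MAE = (1/n) * sum_i |predictions_i - targets_i| (unweighted mean)."""
    return jnp.mean(jnp.abs(predictions - targets))


pred = jnp.array([1.0, 2.0, 3.0])
true = jnp.array([1.0, 0.0, 6.0])
loss = mae(pred, true)   # (0 + 2 + 3) / 3
```

Because each residual enters linearly, MAE is less sensitive to a few large errors than MSE.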

molpipx.utils.morse_variables(x: jaxtyping.Float[jaxtyping.Array, 'dim1'], l: jaxtyping.Float[jaxtyping.Array, '']) -> jaxtyping.Float[jaxtyping.Array, 'dim1']

Computes Morse-like variables using a single length scale parameter.

Parameters:
  • x (Array) – Cartesian coordinates of the atoms, shape (N_atoms, 3).

  • l (float) – Length scale parameter (decay rate).

Returns:

The computed Morse variables for all pairwise distances.

Return type:

Array
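In the permutationally invariant polynomial literature, Morse variables are commonly defined as y_ij = exp(-r_ij / a) for a length scale a. The sketch below assumes that form, with `l` playing the role of a; if molpipx instead uses exp(-l · r_ij) (the "decay rate" reading), `l` must be replaced by its reciprocal. `morse_sketch` is hypothetical.

```python
import jax.numpy as jnp


def morse_sketch(x, l):
    """Hypothetical Morse-variable transform, assuming y_ij = exp(-r_ij / l).

    x: Cartesian coordinates, shape (N_atoms, 3); l: scalar length scale.
    """
    i, j = jnp.triu_indices(x.shape[0], k=1)    # unique pairs i < j
    r = jnp.linalg.norm(x[i] - x[j], axis=-1)   # interatomic distances
    return jnp.exp(-r / l)                      # 1 at r = 0, decays toward 0 as r grows


xyz = jnp.array([[0.0, 0.0, 0.0],
                 [3.0, 4.0, 0.0],
                 [0.0, 0.0, 1.0]])
y = morse_sketch(xyz, 2.0)
```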

molpipx.utils.mse_loss(predictions: jaxtyping.Float[jaxtyping.Array, 'dim1'], targets: jaxtyping.Float[jaxtyping.Array, 'dim1']) -> jaxtyping.Float

Computes the Mean Squared Error (MSE) loss.

Parameters:
  • predictions (Array) – The predicted values.

  • targets (Array) – The ground truth values.

Returns:

The mean squared error.

Return type:

Float
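MSE = (1/n) Σ_i (predictions_i - targets_i)². The sketch below assumes an unweighted mean and also shows the gradient JAX produces, 2(predictions - targets)/n, which is what a gradient-based training step would use.

```python
import jax
import jax.numpy as jnp


def mse(predictions, targets):
    """MSE = (1/n) * sum_i (predictions_i - targets_i)**2 (unweighted mean)."""
    return jnp.mean((predictions - targets) ** 2)


pred = jnp.array([1.0, 2.0, 3.0])
true = jnp.array([1.0, 0.0, 6.0])
loss = mse(pred, true)              # (0 + 4 + 9) / 3
grad = jax.grad(mse)(pred, true)    # derivative w.r.t. predictions: 2 * (pred - true) / 3
```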

molpipx.utils.softplus_inverse(x: jaxtyping.Float[jaxtyping.Array, 'dim1']) -> jaxtyping.Float[jaxtyping.Array, 'dim1']

Computes the inverse of the softplus function.

Parameters:

x (Array) – Input value (must be positive).

Returns:

The inverse softplus of the input.

Return type:

Array
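For softplus(y) = log(1 + exp(y)), the inverse is y = log(exp(x) - 1) for x > 0. A common numerically stable rewrite is y = x + log(-expm1(-x)), which avoids overflow in exp(x) for large x. The sketch below uses that form; the molpipx implementation may compute it differently.

```python
import jax
import jax.numpy as jnp


def softplus_inv(x):
    """Inverse of softplus(y) = log(1 + exp(y)), valid for x > 0.

    log(exp(x) - 1) is rewritten as x + log(1 - exp(-x)) so that exp(x)
    never overflows for large x; expm1 keeps precision for small x.
    """
    return x + jnp.log(-jnp.expm1(-x))


x = jnp.array([0.1, 1.0, 5.0, 30.0])
roundtrip = jax.nn.softplus(softplus_inv(x))   # should recover x
```

A typical use is initialization: storing softplus_inv(a) as an unconstrained parameter makes softplus map it back to the desired positive value a, such as a length scale.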