Deep PIP Kernels

Classes

PIPLayerKernel


class molpipx.pipgp_flax.PIPLayerKernel(*args: Any, **kwargs: Any)

Bases: AbstractKernel

A GPJax kernel that passes both inputs through a neural network transformation (PIP) and then evaluates the base kernel on the transformed features.
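In symbols, this is the usual deep-kernel construction. The notation below is ours, not the library's: f_θ is the network with parameters θ = nn_params, and k_base is base_kernel.

```latex
k(x, y) = k_{\mathrm{base}}\big(f_\theta(x),\, f_\theta(y)\big), \qquad \theta = \texttt{nn\_params}
```

For any fixed θ, the Gram matrix of k on points x_1, ..., x_n equals the Gram matrix of k_base on f_θ(x_1), ..., f_θ(x_n). Consequently k inherits positive semidefiniteness from k_base and remains a valid GP covariance.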

Variables:
  • base_kernel (gpx.kernels.base.AbstractKernel) – The kernel function to apply after the transformation.

  • network (nn.Module) – A Flax neural network module (e.g., PIPlayer) that transforms the input.

  • dummy_x (jax.Array) – A sample input array used to initialize the neural network parameters.

  • key (jax.Array) – A JAX random key used for parameter initialization. Defaults to jax.random.PRNGKey(123).

  • compute_engine (DenseKernelComputation) – The computation engine for the kernel matrix.

  • nn_params (Any) – The initialized parameters of the neural network (automatically generated in __post_init__).
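The compute_engine field turns the pairwise kernel into a full kernel matrix. The NumPy sketch below illustrates that idea under loud assumptions: `toy_features` is a hypothetical stand-in for the network (a single tanh layer), an RBF stands in for base_kernel, and `dense_gram` is our own helper, not GPJax's DenseKernelComputation. It applies the network once per input, then forms all pairwise base-kernel values.

```python
import numpy as np

def toy_features(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the network: a tanh layer applied row-wise."""
    return np.tanh(X @ W)

def rbf(a: np.ndarray, b: np.ndarray, lengthscale: float = 1.0) -> float:
    """Base kernel on two feature vectors: squared-exponential."""
    return float(np.exp(-0.5 * np.sum((a - b) ** 2) / lengthscale**2))

def deep_kernel(x: np.ndarray, y: np.ndarray, W: np.ndarray, lengthscale: float = 1.0) -> float:
    """Pairwise deep kernel: transform both inputs, then apply the base kernel."""
    fx = toy_features(x[None, :], W)[0]
    fy = toy_features(y[None, :], W)[0]
    return rbf(fx, fy, lengthscale)

def dense_gram(X: np.ndarray, Y: np.ndarray, W: np.ndarray, lengthscale: float = 1.0) -> np.ndarray:
    """Dense kernel matrix: transform each row once, then compare all feature pairs."""
    FX = toy_features(X, W)  # (n, F)
    FY = toy_features(Y, W)  # (m, F)
    sq_dists = np.sum((FX[:, None, :] - FY[None, :, :]) ** 2, axis=-1)  # (n, m)
    return np.exp(-0.5 * sq_dists / lengthscale**2)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 9))  # 5 inputs, e.g. 3 atoms x 3 coordinates, flattened
W = rng.normal(size=(9, 4))  # hypothetical network weights
K = dense_gram(X, X, W)      # (5, 5) kernel matrix
```

Caching the features means the network runs once per input rather than twice per pair. Each entry K[i, j] is the same value deep_kernel(X[i], X[j], W) would return.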

__call__(x: jaxtyping.Float[jaxtyping.Array, 'D'], y: jaxtyping.Float[jaxtyping.Array, 'D']) -> jaxtyping.Float[jaxtyping.Array, '1']

Computes the kernel value between two inputs after transforming them via the network.

Parameters:
  • x (Array) – First input vector. If 1D (flattened coordinates of length 3 * N_atoms), it is reshaped to (1, N_atoms, 3) before being passed to the network.

  • y (Array) – Second input vector, of the same length D as x.

Returns:

The scalar kernel value.

Return type:

Array
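The reshaping rule for x can be made concrete with a small NumPy sketch. Everything here is illustrative. `pairwise_distance_features` is a hypothetical stand-in for the PIP network that returns the interatomic distances of each geometry, and an RBF stands in for base_kernel. The sketch assumes a 1D input holds flattened Cartesian coordinates of length 3 * N_atoms, and it applies the same reshape to y.

```python
import numpy as np

def pairwise_distance_features(geoms: np.ndarray) -> np.ndarray:
    """Hypothetical network stand-in: interatomic distances per geometry.

    geoms has shape (batch, N_atoms, 3); the result has shape
    (batch, N_atoms * (N_atoms - 1) / 2).
    """
    diff = geoms[:, :, None, :] - geoms[:, None, :, :]  # (batch, N, N, 3)
    dist = np.linalg.norm(diff, axis=-1)                # (batch, N, N)
    i, j = np.triu_indices(geoms.shape[1], k=1)         # upper triangle, no diagonal
    return dist[:, i, j]

def pip_layer_kernel_call(x: np.ndarray, y: np.ndarray, lengthscale: float = 1.0) -> np.ndarray:
    """Sketch of __call__: reshape, transform both inputs, apply an RBF base kernel."""
    if x.ndim == 1:
        x = x.reshape(1, -1, 3)  # (3 * N_atoms,) -> (1, N_atoms, 3)
    if y.ndim == 1:
        y = y.reshape(1, -1, 3)
    fx = pairwise_distance_features(x)[0]
    fy = pairwise_distance_features(y)[0]
    sq = np.sum((fx - fy) ** 2)
    # Shape (1,), matching the Float[Array, '1'] annotation.
    return np.atleast_1d(np.exp(-0.5 * sq / lengthscale**2))

# Illustrative 3-atom geometry (water-like, roughly in angstrom), flattened to 9 values.
water = np.array([0.0, 0.0, 0.0, 0.96, 0.0, 0.0, -0.24, 0.93, 0.0])
shifted = water + 5.0  # every atom moved by (5, 5, 5): a rigid translation
k_same = pip_layer_kernel_call(water, shifted)
```

Because this stand-in sees only interatomic distances, a rigidly translated copy of a geometry gets kernel value 1, while a stretched geometry such as water * 1.5 gets a value strictly between 0 and 1.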

__post_init__()

Initializes the network parameters (nn_params) from the sample input dummy_x and the random key.
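The initialization pattern sets parameters from a sample input and a fixed seed after the dataclass fields are assigned. It can be sketched with a plain Python dataclass. `ToyNetwork`, its `init`/`apply` methods, and `ToyDeepKernel` are hypothetical. They imitate the shape of Flax's `init`/`apply` convention using NumPy, and `seed=123` mirrors the documented default key `jax.random.PRNGKey(123)`.

```python
from dataclasses import dataclass, field
from typing import Any

import numpy as np

@dataclass
class ToyNetwork:
    """Hypothetical network: one dense layer whose input width is inferred at init."""
    features: int = 4

    def init(self, rng: np.random.Generator, x: np.ndarray) -> dict:
        # The input width comes from the sample input, much as Flax modules
        # infer input shapes from the sample passed to init.
        return {"W": rng.normal(size=(x.shape[-1], self.features))}

    def apply(self, params: dict, x: np.ndarray) -> np.ndarray:
        return np.tanh(x @ params["W"])

@dataclass
class ToyDeepKernel:
    network: ToyNetwork
    dummy_x: np.ndarray
    seed: int = 123  # stands in for jax.random.PRNGKey(123)
    nn_params: Any = field(init=False, default=None)

    def __post_init__(self) -> None:
        # Runs after the generated __init__, so network and dummy_x are already set.
        self.nn_params = self.network.init(np.random.default_rng(self.seed), self.dummy_x)

    def __call__(self, x: np.ndarray, y: np.ndarray, lengthscale: float = 1.0) -> float:
        fx = self.network.apply(self.nn_params, x)
        fy = self.network.apply(self.nn_params, y)
        return float(np.exp(-0.5 * np.sum((fx - fy) ** 2) / lengthscale**2))

dummy = np.zeros((1, 9))  # sample input: one flattened 3-atom geometry
k1 = ToyDeepKernel(ToyNetwork(), dummy)
k2 = ToyDeepKernel(ToyNetwork(), dummy)
```

With a fixed seed, two kernels built from the same dummy input get identical parameters. This keeps construction reproducible, and it means the parameter shapes are fixed by dummy_x rather than by the first real call.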