Deep PIP Kernels
Classes
PIPLayerKernel – A GPJax kernel that applies a neural network transformation (PIP) before computing the base kernel.
- class molpipx.pipgp_flax.PIPLayerKernel(*args: Any, **kwargs: Any)
Bases: AbstractKernel
A GPJax kernel that applies a neural network transformation (PIP) before computing the base kernel.
- Variables:
base_kernel (gpx.kernels.base.AbstractKernel) – The kernel function to apply after the transformation.
network (nn.Module) – A Flax neural network module (e.g., PIPlayer) that transforms the input.
dummy_x (jax.Array) – A sample input array used to initialize the neural network parameters.
key (jax.Array) – A JAX random key used for parameter initialization. Defaults to jax.random.PRNGKey(123).
compute_engine (DenseKernelComputation) – The computation engine for the kernel matrix.
nn_params (Any) – The initialized parameters of the neural network (automatically generated in __post_init__).
- __call__(x: jaxtyping.Float[jaxtyping.Array, 'D'], y: jaxtyping.Float[jaxtyping.Array, 'D']) -> jaxtyping.Float[jaxtyping.Array, '1']
Computes the kernel value between two inputs after transforming them via the network.
- Parameters:
x (Array) – First input vector. If 1D, it is reshaped to (1, N_atoms, 3).
y (Array) – Second input vector.
- Returns:
The scalar kernel value.
- Return type:
Array
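The evaluation described above, transforming both inputs and then applying the base kernel, can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's code: toy_invariant_features is a hypothetical replacement for the network (it treats all atoms as identical and uses power sums of Morse-like variables), and rbf is a hand-written stand-in for the base kernel.

```python
import jax.numpy as jnp


def toy_invariant_features(x):
    """Hypothetical stand-in for the network: permutation-invariant features."""
    # Mirror the documented reshape: a flat (3 * N_atoms,) vector becomes (1, N_atoms, 3).
    coords = x.reshape(1, -1, 3) if x.ndim == 1 else x
    n_atoms = coords.shape[1]
    # Indices of all unique atom pairs (i < j).
    i, j = jnp.triu_indices(n_atoms, k=1)
    dist = jnp.linalg.norm(coords[:, i] - coords[:, j], axis=-1)  # (1, n_pairs)
    morse = jnp.exp(-dist)
    # Power sums over all pairs do not change when atoms are relabelled.
    return jnp.concatenate([morse.sum(-1), (morse**2).sum(-1)])  # (2,)


def rbf(u, v, lengthscale=1.0):
    """Hand-written squared-exponential kernel, standing in for base_kernel."""
    return jnp.exp(-0.5 * jnp.sum((u - v) ** 2) / lengthscale**2)


def deep_kernel(x, y):
    # Transform both inputs first, then evaluate the base kernel on the features.
    return rbf(toy_invariant_features(x), toy_invariant_features(y))


# Flattened Cartesian coordinates of a three-atom geometry (illustrative values).
x = jnp.array([0.0, 0.0, 0.0, 0.96, 0.0, 0.0, -0.24, 0.93, 0.0])
y = 1.1 * x  # a uniformly stretched copy
x_perm = x.reshape(3, 3)[jnp.array([1, 0, 2])].reshape(-1)  # swap atoms 0 and 1

k_xy = deep_kernel(x, y)
k_x_perm = deep_kernel(x, x_perm)
```

In this toy setup the kernel cannot distinguish x from x_perm, because the features are identical for both. That is the practical reason for applying an invariant transformation before the base kernel.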