PIPGP Tutorial
This tutorial shows how to train a Gaussian Process (GP) model whose kernel acts on Permutationally Invariant Polynomial (PIP) features, referred to here as PIPGP, using the methane dataset.
Configuration
We first define the settings for the PIP basis, the GP kernel, the dataset sizes, and the optimizer's learning-rate schedule.
# Hyperparameters
MOLECULE = 'A4B' # Methane symmetry
POLY_DEGREE = 3 # PIP degree
KERNEL_TYPE = 'Matern52'
N_TR = 1000 # Training samples
N_TST = 5000 # Test samples
NUM_ITER = 800 # Optimization steps
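# Optimizer schedule settings (linear warmup followed by cosine decay)
INIT_LR = 0.0 # Learning rate at step 0
PEAK_LR = 0.01 # Learning rate reached at the end of warmup
END_LR = 0.0 # Learning rate once the decay has finished
WARMUP_STEPS = 75 # Number of linear warmup steps
DECAY_STEPS = 700 # Total schedule length in optax (includes warmup); later steps use END_LR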
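To see what these schedule values imply, the sketch below re-implements a linear-warmup plus cosine-decay schedule in plain Python. It is an illustrative stand-in written for this tutorial, not optax's own code, and the helper name `warmup_cosine_lr` is hypothetical. It assumes, as the optax documentation describes for `warmup_cosine_decay_schedule`, that `decay_steps` is the total schedule length including the warmup.

```python
import math

# Values copied from the configuration above
INIT_LR, PEAK_LR, END_LR = 0.0, 0.01, 0.0
WARMUP_STEPS, DECAY_STEPS, NUM_ITER = 75, 700, 800


def warmup_cosine_lr(step):
    """Hypothetical helper: linear warmup, then cosine decay to END_LR."""
    if step < WARMUP_STEPS:
        # Linear ramp from INIT_LR to PEAK_LR
        return INIT_LR + (PEAK_LR - INIT_LR) * step / WARMUP_STEPS
    # Cosine decay over the remaining (DECAY_STEPS - WARMUP_STEPS) steps;
    # progress is capped so the rate stays at END_LR afterwards
    span = DECAY_STEPS - WARMUP_STEPS
    progress = min(step - WARMUP_STEPS, span)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress / span))
    return END_LR + (PEAK_LR - END_LR) * cosine


for step in (0, WARMUP_STEPS, DECAY_STEPS, NUM_ITER):
    print(step, warmup_cosine_lr(step))
```

Under these assumptions the rate reaches END_LR at step 700, so with END_LR = 0.0 the last 100 of the 800 iterations would leave the parameters essentially unchanged.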
Data Loading
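import jax
# Enable 64-bit floats; by default JAX downcasts float64 arrays to float32
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp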
import gpjax as gpx
import optax as ox
from gpjax import Dataset
# MOLPIPx Imports
from molpipx import PIPlayerGP, PIPLayerKernel, get_forces_gp
from molpipx.pip_generator import get_functions, detect_molecule
from molpipx.utils_training import split_train_and_test_data
from molpipx.data import load_methane
We load the methane dataset and wrap it in a GPJax Dataset. GPJax expects 2D input arrays, so we flatten each geometry from shape (N, Atoms, 3) to (N, Atoms*3).
# 1. Load Data
X_all, _, y_all, atoms = load_methane(energy_normalization=False)
# 2. Split Data
key = jax.random.PRNGKey(123)
(X_tr, y_tr), (X_tst, y_tst) = split_train_and_test_data(
    X_all, y_all, N_TR, key, N_TST
)
# 3. Create GPJax Dataset
# Reshape inputs to (N, Na * 3)
X_tr_flat = X_tr.reshape(X_tr.shape[0], -1).astype(jnp.float64)
y_tr_64 = y_tr.astype(jnp.float64)
train_ds = Dataset(X=X_tr_flat, y=y_tr_64)
# Prepare Test Data (Dictionary format for evaluation)
X_tst_flat = X_tst.reshape(X_tst.shape[0], -1).astype(jnp.float64)
tst_ds = {'x': X_tst_flat, 'e': y_tst.astype(jnp.float64)}
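The flattening step can be checked on toy data. The sketch below uses NumPy and random numbers in place of the real geometries (both are assumptions made for illustration) and shows that reshaping to (N, Atoms*3) keeps each atom's x, y, z values contiguous and can be undone.

```python
import numpy as np

n_samples, n_atoms = 4, 5  # toy sizes; methane has 5 atoms

# Toy geometries with shape (N, Atoms, 3); random values stand in for real data
rng = np.random.default_rng(0)
X = rng.normal(size=(n_samples, n_atoms, 3))

# Flatten to (N, Atoms * 3): each row is x1, y1, z1, x2, y2, z2, ...
X_flat = X.reshape(X.shape[0], -1)

# The original layout is recovered by reshaping back
X_back = X_flat.reshape(-1, n_atoms, 3)

print(X_flat.shape)               # (4, 15)
print(np.array_equal(X, X_back))  # True
```

Because the row-major layout is preserved, columns 3 to 5 of a flattened row hold the coordinates of the second atom.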
Model Initialization
We initialize the PIP layer (feature extractor) and wrap it inside a PIPLayerKernel. This kernel projects the atomic coordinates into the PIP space before applying the base GP kernel.
# 1. Generate Basis Functions
f_mono, f_poly = get_functions(MOLECULE, POLY_DEGREE)
# 2. Initialize PIP Layer (Feature Extractor)
# We use a dummy input to initialize parameters
na = 5 # Number of atoms in methane (CH4)
x0_dummy = jnp.ones((1, na, 3))
pipgp_layer = PIPlayerGP(f_mono, f_poly, trainable_l=False)
params = pipgp_layer.init(key, x0_dummy)
# Determine output dimension (number of polynomials)
output = pipgp_layer.apply(params, x0_dummy)
n_poly = output.shape[1]
# 3. Define Base Kernel (Matern 5/2)
base_kernel = gpx.kernels.Matern52(
    active_dims=list(range(n_poly)),
    lengthscale=jnp.ones((n_poly,))
)
# 4. Create the Composite PIP Kernel
kernel = PIPLayerKernel(
    network=pipgp_layer,
    base_kernel=base_kernel,
    key=key,
    dummy_x=x0_dummy
)
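The idea of applying a base kernel to permutation-invariant features can be illustrated with a small NumPy sketch. The features below are hand-picked symmetric sums of Morse-like variables, not the polynomial basis produced by `get_functions`, and the helper names (`toy_invariants`, `matern52`, `toy_pip_kernel`) are hypothetical. The atom ordering (four A atoms first, B last) is also an assumption of this example only.

```python
import numpy as np


def toy_invariants(x):
    """Toy permutation-invariant features for an A4B geometry of shape (5, 3).

    Assumes atoms 0-3 are the four equivalent A atoms and atom 4 is B.
    """
    diff = x[:, None, :] - x[None, :, :]
    r = np.sqrt((diff ** 2).sum(-1))   # (5, 5) distance matrix
    m = np.exp(-r)                     # Morse-like variables
    ab = m[:4, 4]                      # four A-B variables
    iu = np.triu_indices(4, k=1)
    aa = m[:4, :4][iu]                 # six A-A variables
    # Symmetric sums do not depend on the order of the A atoms
    return np.array([ab.sum(), aa.sum(), (ab ** 2).sum(), (aa ** 2).sum()])


def matern52(u, v, lengthscale=1.0, variance=1.0):
    """Matern 5/2 kernel between two feature vectors."""
    r = np.linalg.norm((u - v) / lengthscale)
    s = np.sqrt(5.0) * r
    return variance * (1.0 + s + s ** 2 / 3.0) * np.exp(-s)


def toy_pip_kernel(x1, x2):
    """Composite kernel: base kernel evaluated on invariant features."""
    return matern52(toy_invariants(x1), toy_invariants(x2))


rng = np.random.default_rng(0)
x_a = rng.normal(size=(5, 3))
x_b = rng.normal(size=(5, 3))

# Relabel the four A atoms; the B atom stays in place
x_a_perm = x_a[[2, 0, 3, 1, 4]]

print(np.isclose(toy_pip_kernel(x_a, x_b), toy_pip_kernel(x_a_perm, x_b)))
```

Since the features are unchanged by relabeling the A atoms, the kernel value is unchanged too, which is the property the composite kernel is meant to provide.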
Gaussian Process Setup
We construct the GP posterior by combining a prior (zero mean function plus the PIP kernel) with a Gaussian likelihood.
# Zero Mean Function
meanf = gpx.mean_functions.Zero()
# Prior Distribution
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
# Gaussian Likelihood
likelihood = gpx.likelihoods.Gaussian(num_datapoints=train_ds.n)
# Posterior
posterior = prior * likelihood
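For a zero-mean prior and a Gaussian likelihood, the posterior predictive has a closed form. The NumPy sketch below writes it out for a one-dimensional toy problem; it uses a squared-exponential kernel as a stand-in for the PIP kernel, and the function names (`rbf`, `gp_predict`) are hypothetical. It is meant only to show the linear algebra behind a conjugate GP, not how GPJax implements it.

```python
import numpy as np


def rbf(a, b, lengthscale=1.0):
    """Squared-exponential kernel matrix between rows of a and b."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale ** 2)


def gp_predict(x_train, y_train, x_test, noise_var=1e-6):
    """Posterior mean and latent variance of a zero-mean GP with Gaussian noise."""
    k_tt = rbf(x_train, x_train) + noise_var * np.eye(len(x_train))
    k_st = rbf(x_test, x_train)
    chol = np.linalg.cholesky(k_tt)
    # alpha = (K + noise * I)^-1 y, via two solves against the Cholesky factor
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    mean = k_st @ alpha
    v = np.linalg.solve(chol, k_st.T)
    var = np.diag(rbf(x_test, x_test)) - (v ** 2).sum(0)
    return mean, var


x_train = np.linspace(-2.0, 2.0, 8)[:, None]
y_train = np.sin(x_train[:, 0])
# One test point inside the training range, one far outside it
x_test = np.array([[0.5], [10.0]])

mean, var = gp_predict(x_train, y_train, x_test)
print(mean, var)
```

Inside the training range the mean follows the data and the variance is small; far from the data the prediction falls back to the zero prior mean with close to the prior variance.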
Training
We optimize the GP hyperparameters using optax and the negative log-marginal likelihood objective.
# 1. Define Optimizer Schedule
schedule = ox.warmup_cosine_decay_schedule(
    init_value=INIT_LR,
    peak_value=PEAK_LR,
    warmup_steps=WARMUP_STEPS,
    decay_steps=DECAY_STEPS,
    end_value=END_LR
)
optimiser = ox.chain(
    ox.clip(1.0),
    ox.adamw(learning_rate=schedule),
)
# 2. Run Optimization
print("Starting training...")
opt_posterior, history = gpx.fit(
    model=posterior,
    # fit minimizes the objective, so we negate the marginal log-likelihood
    objective=jax.jit(lambda p, d: -gpx.objectives.conjugate_mll(p, d)),
    train_data=train_ds,
    optim=optimiser,
    num_iters=NUM_ITER,
    key=key,
    safe=False
)
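The quantity minimized during training can be written out explicitly for a toy problem. The NumPy sketch below evaluates the negative log-marginal likelihood of a zero-mean GP with Gaussian noise; the squared-exponential kernel and the names `rbf` and `negative_mll` are illustrative assumptions, not the PIP kernel or GPJax internals.

```python
import numpy as np


def rbf(a, b, lengthscale):
    """Squared-exponential kernel matrix between rows of a and b."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale ** 2)


def negative_mll(x, y, lengthscale, noise_var):
    """Negative log-marginal likelihood of a zero-mean GP with Gaussian noise."""
    n = len(y)
    k = rbf(x, x, lengthscale) + noise_var * np.eye(n)
    chol = np.linalg.cholesky(k)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y))
    data_fit = 0.5 * y @ alpha
    complexity = np.log(np.diag(chol)).sum()  # equals 0.5 * log det(K + noise * I)
    return data_fit + complexity + 0.5 * n * np.log(2.0 * np.pi)


x = np.linspace(-2.0, 2.0, 20)[:, None]
y = np.sin(x[:, 0])

# Scan a few lengthscales; lower values indicate a better fit to the data
for ls in (0.05, 1.0, 20.0):
    print(ls, negative_mll(x, y, ls, noise_var=1e-2))
```

A gradient-based optimizer follows the same principle, except that it updates the hyperparameters with gradients of this objective instead of scanning a fixed grid.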
Prediction (Energy & Forces)
Finally, we predict the potential energy and forces on the test set.
# 1. Predict Energy (Mean and Std Dev)
latent_dist = opt_posterior(tst_ds['x'], train_data=train_ds)
predictive_dist = opt_posterior.likelihood(latent_dist)
pred_mean = predictive_dist.mean
pred_std = jnp.sqrt(predictive_dist.variance)
# 2. Predict Forces (Gradients)
# We use the helper function 'get_forces_gp' which handles gradients automatically
forces_pred, _ = get_forces_gp(
    gp_model=opt_posterior,
    train_data=train_ds,
    x=tst_ds['x']
)
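Once predictions are available, they can be compared with the reference energies stored in the test dictionary. The sketch below is one possible way to compute RMSE and MAE with NumPy; the helper name `regression_errors` and the toy numbers are assumptions made for illustration. Both arrays are flattened first, in case the targets are stored as a column of shape (N, 1) while the predictive mean has shape (N,), since subtracting those shapes directly would broadcast to an (N, N) array.

```python
import numpy as np


def regression_errors(pred, ref):
    """Return (RMSE, MAE) after flattening both arrays to 1D."""
    pred = np.asarray(pred, dtype=float).ravel()
    ref = np.asarray(ref, dtype=float).ravel()
    if pred.shape != ref.shape:
        raise ValueError("pred and ref must contain the same number of values")
    err = pred - ref
    return np.sqrt(np.mean(err ** 2)), np.mean(np.abs(err))


# Toy numbers standing in for pred_mean and tst_ds['e']
pred = np.array([1.0, 2.0, 3.0])
ref = np.array([[1.5], [2.0], [2.0]])  # column vector, shape (3, 1)

rmse, mae = regression_errors(pred, ref)
print(rmse, mae)
```

In the tutorial's setting, the same helper could be called as `regression_errors(pred_mean, tst_ds['e'])`, with the result in the energy units of the dataset.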