Linear PIP Tutorial
In this tutorial, we will learn how to train a Linear Permutationally Invariant Polynomial (PIP) model from scratch using the Methane dataset.
Configuration
We define the hyperparameters for training. In a script these might come from command-line arguments, but here we set them directly.
from ml_collections import config_dict
# Configuration parameters
cfg = config_dict.ConfigDict()
cfg.molecule = 'A4B' # Methane (CH4) symmetry: four equivalent H atoms and one C
cfg.poly_degree = 3 # Polynomial degree
cfg.ntr = 1000 # Number of training points
cfg.nval = 1000 # Number of validation points
cfg.seed = 0 # Random seed
Data Loading
We begin by importing JAX and the MOLPIPx components used throughout the tutorial.
import jax
import jax.numpy as jnp
from ml_collections import config_dict
# MOLPIPx imports
from molpipx import EnergyPIP, PIPlayer as PIP
from molpipx import training, flax_params, mse_loss
from molpipx.pip_generator import get_functions, detect_molecule
from molpipx.utils_training import split_train_and_test_data
# Import the built-in data loader
from molpipx.data import load_methane
First, we load the Methane geometry and energy data.
# 1. Load Data
# Returns: (Geometries, Forces, Energies, Atoms)
X_all, _, y_all, atoms = load_methane(energy_normalization=False)
# 2. Split into Training and Validation sets
rng = jax.random.PRNGKey(cfg.seed)
_, key = jax.random.split(rng)
(X_tr, y_tr), (X_val, y_val) = split_train_and_test_data(
X_all, y_all, cfg.ntr, key, cfg.nval
)
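A quick (optional) shape check confirms that the split matches the requested sizes; the exact trailing dimensions depend on how the geometries are stored in the dataset.
# Optional sanity check on the split sizes
print(X_tr.shape, y_tr.shape)   # leading dimension should equal cfg.ntr
print(X_val.shape, y_val.shape) # leading dimension should equal cfg.nval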
Model Initialization
We need to generate the basis functions (monomials and polynomials) specific to the molecule’s symmetry before initializing the model.
# Detect molecule symmetry and generate functions
mol_dict, mol_sym = detect_molecule(cfg.molecule)
f_mono, f_poly = get_functions(cfg.molecule, cfg.poly_degree)
# Initialize the PIP Model
# 'PIP' evaluates the polynomial basis (used to set up the linear solve);
# 'EnergyPIP' returns the predicted energy for inference
model_pip = PIP(f_mono, f_poly)
model_energy = EnergyPIP(f_mono, f_poly)
# Initialize parameters using a single data point
params = model_energy.init(key, X_tr[:1])
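To see what the linear model looks like, you can inspect the shapes in the freshly initialized parameter Pytree; the exact layout depends on how the Flax module is defined.
# Inspect the parameter Pytree returned by Flax
print(jax.tree_util.tree_map(jnp.shape, params))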
Training
Because the model is linear in its weights, the optimal weights can be obtained in closed form with linear least squares via the training function; no iterative optimization is needed.
# Solve for weights 'w'
w = training(model_pip, X_tr, y_tr)
# Update the parameters with the optimized weights
params = flax_params(w, params)
Evaluation
Finally, we use the trained parameters to predict energies on the training and validation sets and compute the Mean Squared Error (MSE).
# Predict energies
y_pred_tr = model_energy.apply(params, X_tr)
y_pred_val = model_energy.apply(params, X_val)
# Calculate Loss
loss_tr = mse_loss(y_pred_tr, y_tr)
loss_val = mse_loss(y_pred_val, y_val)
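For reporting, it is often convenient to convert the MSE to a root mean squared error; the energy units depend on the dataset and on the energy_normalization flag used above.
# Report the errors (RMSE is in the same units as the energies)
rmse_tr = jnp.sqrt(loss_tr)
rmse_val = jnp.sqrt(loss_val)
print(f'Train MSE: {float(loss_tr):.6e} | Val MSE: {float(loss_val):.6e}')
print(f'Train RMSE: {float(rmse_tr):.6e} | Val RMSE: {float(rmse_val):.6e}')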
Training Utilities
For linear models, MOLPIPx provides dedicated functions for the linear least squares fit. The flax_params function converts the raw optimized weights into the Flax Pytree structure expected by the model.
from molpipx import training, flax_params
# 1. Initialize the models
model_pip = PIP(f_mono, f_poly)
model_energy = EnergyPIP(f_mono, f_poly)
# 2. Initialize parameters with dummy data
params = model_energy.init(key, X_tr[:1])
# 3. Run Training (Linear Solve)
# Returns the optimized weights 'w' as a flat array
w = training(model_pip, X_tr, y_tr)
# 4. Update Parameters
# flax_params() copies the optimized weights into the Pytree structure used by Flax
params_opt = flax_params(w, params)
Energy and Forces
Thanks to JAX's automatic differentiation, we can efficiently compute the energy and the forces (obtained by differentiating the energy with respect to the atomic positions) in a single pass. We provide the get_energy_and_forces utility for this purpose.
from molpipx.utils_gradients import get_energy_and_forces
# Define the model for inference
e_pip_model = EnergyPIP(f_mono, f_poly)
# Option 1: Predict Energy Only
y_pred = e_pip_model.apply(params_opt, X_val)
# Option 2: Predict Energy and Forces Jointly
# This evaluates the energy and differentiates it with respect to the atomic positions
y_pred, f_pred = get_energy_and_forces(e_pip_model.apply, X_val, params_opt)
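If you want explicit control over the differentiation, the same quantities can be computed by hand with jax.grad and jax.vmap. The sketch below is an assumption-based illustration: it assumes EnergyPIP returns one scalar energy per geometry and uses the convention that forces are the negative gradient of the energy; check the output of get_energy_and_forces for the sign convention it actually uses.
# Minimal hand-rolled sketch (assumes one scalar energy per geometry)
def energy_fn(x_single):
    # apply expects a batch, so add and strip a leading batch dimension
    return e_pip_model.apply(params_opt, x_single[None]).sum()

energies = jax.vmap(energy_fn)(X_val)
forces = -jax.vmap(jax.grad(energy_fn))(X_val)  # forces = -dE/dx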