generax provides implementations of flow-based generative models. The library is built on top of Equinox, which represents models as PyTrees, so there is no need to keep track of model parameters separately.
from jax import random
import equinox as eqx
# NeuralSpline is provided by generax; its import is omitted here

key = random.PRNGKey(0) # JAX random key
init_key, ddi_key, sample_key = random.split(key, 3) # one independent key per use
x = ... # some data

# Create a flow model
model = NeuralSpline(input_shape=x.shape[1:],
                     n_flow_layers=3,
                     n_blocks=4,
                     hidden_size=32,
                     working_size=16,
                     n_spline_knots=8,
                     key=init_key)

# Data dependent initialization
model = model.data_dependent_init(x, key=ddi_key)

# Take multiple samples using vmap
keys = random.split(sample_key, 1000)
samples = eqx.filter_vmap(model.sample)(keys)

# Compute the log probability of data
log_prob = eqx.filter_vmap(model.log_prob)(x)
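For readers new to flows, the `sample` and `log_prob` calls above can be illustrated with a toy model. The sketch below is not generax code: `sample`, `log_prob`, `base_log_prob`, and `params` are hypothetical names for an elementwise affine flow written in plain JAX. It shows the change-of-variables formula and the same pattern of writing per-example functions and batching them with `vmap`.

```python
import jax
import jax.numpy as jnp
from jax import random

def base_log_prob(z):
    # Log density of a standard normal, summed over dimensions
    return -0.5 * jnp.sum(z ** 2 + jnp.log(2.0 * jnp.pi))

def sample(params, key):
    # Draw z from the base distribution and push it through x = z * exp(log_scale) + shift
    shift, log_scale = params
    z = random.normal(key, shift.shape)
    return z * jnp.exp(log_scale) + shift

def log_prob(params, x):
    # Change of variables: log p(x) = log p_base(z) - sum(log_scale)
    shift, log_scale = params
    z = (x - shift) * jnp.exp(-log_scale)
    return base_log_prob(z) - jnp.sum(log_scale)

# Hypothetical parameters for a 2-dimensional toy flow
params = (jnp.array([1.0, -1.0]), jnp.log(jnp.array([1.0, 2.0])))

# Both functions act on a single example; vmap adds the batch axis
keys = random.split(random.PRNGKey(0), 1000)
samples = jax.vmap(sample, in_axes=(None, 0))(params, keys)
log_probs = jax.vmap(log_prob, in_axes=(None, 0))(params, samples)
```

Here `in_axes=(None, 0)` keeps the parameters fixed and maps only over the keys or data. `eqx.filter_vmap` plays a similar role when the function is a method of an Equinox module.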
There is also support for probability paths (time-dependent probability distributions), which can be used to train continuous normalizing flows with flow matching. See the examples on flow matching and multi-sample flow matching for more details.
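As a rough illustration of the idea, the sketch below computes a conditional flow matching loss in plain JAX. It does not use the generax API: the straight-line path between noise and data, the affine `vector_field`, and all function and variable names are assumptions chosen for brevity. A real continuous normalizing flow would replace `vector_field` with a neural network.

```python
import jax
import jax.numpy as jnp
from jax import random

def sample_path(x0, x1, t):
    # Straight-line probability path: point at time t and its target velocity
    xt = (1.0 - t) * x0 + t * x1
    ut = x1 - x0
    return xt, ut

def vector_field(params, t, x):
    # Hypothetical stand-in for a neural network: a time-conditioned affine map
    W, b, c = params
    return x @ W + t * c + b

def flow_matching_loss(params, x1_batch, key):
    # Pair each data point with a noise sample and a random time in [0, 1)
    k0, kt = random.split(key)
    x0_batch = random.normal(k0, x1_batch.shape)
    t_batch = random.uniform(kt, (x1_batch.shape[0],))

    def per_example(x0, x1, t):
        xt, ut = sample_path(x0, x1, t)
        vt = vector_field(params, t, xt)
        return jnp.sum((vt - ut) ** 2)  # regress the model onto the path velocity

    return jnp.mean(jax.vmap(per_example)(x0_batch, x1_batch, t_batch))

# Toy data and zero-initialized parameters
data_key, loss_key = random.split(random.PRNGKey(0))
x1 = 2.0 + 0.5 * random.normal(data_key, (256, 2))
params = (jnp.zeros((2, 2)), jnp.zeros(2), jnp.zeros(2))

loss, grads = jax.value_and_grad(flow_matching_loss)(params, x1, loss_key)
```

The gradients in `grads` have the same structure as `params`, so they can be passed to any gradient-based optimizer.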
generax can be installed with pip:
pip install generax
Generax provides an easy interface to train these models:
# Trainer is provided by generax; my_objective, tester, optimizer and train_ds are user-defined
trainer = Trainer(checkpoint_path='tmp/model_path')

model = trainer.train(model=model,              # Generax model
                      objective=my_objective,   # Objective function
                      evaluate_model=tester,    # Testing function
                      optimizer=optimizer,      # Optax optimizer
                      num_steps=10000,          # Number of training steps
                      data_iterator=train_ds,   # Training data iterator
                      double_batch=1000,        # Train this many batches in a scan loop
                      checkpoint_every=1000,    # Checkpoint interval
                      test_every=1000,          # Test interval
                      retrain=True)             # Retrain from checkpoint
See the examples folder for more details.