Approximate Distribution Metrics
ApproxDistributionMetricsModule compares the distribution of terminal states
generated by recent policies with the ground-truth distribution. Terminal states are stored in a replay buffer, and the empirical distribution is computed from its contents.
Intuition
- Use this metric when the environment can enumerate its terminal space and provide a true distribution;
- Prefer the exact-distribution metric if its time overhead is acceptable;
- For even larger or non-enumerable environments, use ELBO instead, since distribution-based metrics are not practical there;
- Choose the replay buffer size carefully so that it covers the policy’s support: too small a buffer yields a noisy metric, while too large a buffer reacts slowly to policy updates;
- Add "2d_marginal_distribution" to view a coarse heatmap.
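The buffer-size trade-off can be seen in a small, library-independent sketch. It draws samples from a known distribution over 16 hypothetical terminal states and measures the total variation distance between the buffer histogram and the truth. The helper names `empirical_tv` and `mean_tv`, the uniform target, and the buffer sizes are assumptions made for illustration and are not part of gfnx.

```python
import random
from collections import Counter

def empirical_tv(true_probs, samples):
    """Total variation distance between a histogram of `samples` and `true_probs`."""
    counts = Counter(samples)
    n = len(samples)
    return 0.5 * sum(abs(counts[i] / n - p) for i, p in enumerate(true_probs))

def mean_tv(true_probs, buffer_size, trials, rng):
    """Average TV over several buffers sampled from the true distribution itself."""
    states = range(len(true_probs))
    total = 0.0
    for _ in range(trials):
        samples = rng.choices(states, weights=true_probs, k=buffer_size)
        total += empirical_tv(true_probs, samples)
    return total / trials

rng = random.Random(0)
true_probs = [1 / 16] * 16  # hypothetical uniform target over 16 terminal states

# Even a perfect sampler shows a nonzero TV; that floor shrinks as the buffer grows.
small = mean_tv(true_probs, buffer_size=50, trials=20, rng=rng)
large = mean_tv(true_probs, buffer_size=5000, trials=20, rng=rng)
print(f"buffer=50: {small:.3f}  buffer=5000: {large:.3f}")
```

In this toy setting a small buffer puts a floor under the metric that has nothing to do with policy quality, so values are best compared across runs at a fixed buffer size.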
Key parameters
- metrics: List of metric names to compute; choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}.
- env: Enumerable environment for which to compute metrics.
- buffer_size: Maximum number of states to store in the replay buffer for empirical distribution computation.
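For reference, the three scalar metrics can be written down for two discrete distributions as follows. This is a plain-Python sketch of the standard definitions, not gfnx's implementation; in particular, the KL direction (true against empirical) and the `eps` floor are assumptions.

```python
import math

def tv(p, q):
    """Total variation distance: half the L1 distance."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

def kl(p, q, eps=1e-12):
    """KL(p || q); terms with p_i == 0 contribute 0, and q is floored at eps."""
    return sum(a * math.log(a / max(b, eps)) for a, b in zip(p, q) if a > 0)

def jsd(p, q):
    """Jensen-Shannon divergence: mean KL of p and q to their midpoint."""
    m = [0.5 * (a + b) for a, b in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

true_dist = [0.5, 0.5]    # hypothetical ground truth
empirical = [0.25, 0.75]  # hypothetical buffer histogram
print(tv(true_dist, empirical), kl(true_dist, empirical), jsd(true_dist, empirical))
```

TV is bounded by 1 and JSD by ln 2, while KL in this direction grows very large when the buffer misses states that the true distribution supports.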
Quick start
Environment requirement: you must be able to enumerate or sample the terminal distribution and capture terminal states during rollouts (for example via
gfnx.utils.forward_rollout) before updating the metric.
import jax
import gfnx

env = gfnx.HypergridEnvironment(reward_module=gfnx.EasyHypergridRewardModule())
env_params = env.init(jax.random.PRNGKey(0))

metrics = gfnx.metrics.ApproxDistributionMetricsModule(
    metrics=["tv", "kl", "jsd"],
    env=env,
    buffer_size=10_000,
)
state = metrics.init(jax.random.PRNGKey(1), metrics.InitArgs(env_params=env_params))

# During training: add every batch of terminal states you collect.
# `trajectory` is not defined in this snippet; it stands for the output of your own rollout.
state = metrics.update(
    state,
    jax.random.PRNGKey(2),
    metrics.UpdateArgs(states=trajectory.final_env_state),  # terminal states from your rollout
)

# When you want a report: rebuild the empirical distribution and read the metrics.
state = metrics.process(
    state,
    jax.random.PRNGKey(3),
    metrics.ProcessArgs(env_params=env_params),
)
scores = metrics.get(state)
print(scores["tv"], scores["kl"], scores["jsd"])
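
Conceptually, the update/process/get cycle above resembles a bounded buffer plus a histogram. The sketch below mimics that cycle in plain Python. The `ToyBufferMetric` class, its first-in-first-out eviction, and the integer state encoding are assumptions for illustration and may differ from what gfnx does internally.

```python
from collections import Counter, deque

class ToyBufferMetric:
    """Hypothetical stand-in: bounded buffer of state indices plus TV against a known target."""

    def __init__(self, true_probs, buffer_size):
        self.true_probs = true_probs
        self.buffer = deque(maxlen=buffer_size)  # oldest entries are evicted first
        self.scores = {}

    def update(self, terminal_states):
        # Cheap step: only store the new terminal states.
        self.buffer.extend(terminal_states)

    def process(self):
        # Costlier step: rebuild the empirical distribution and compare it to the truth.
        n = len(self.buffer)
        if n == 0:
            return
        counts = Counter(self.buffer)
        empirical = [counts[i] / n for i in range(len(self.true_probs))]
        self.scores["tv"] = 0.5 * sum(
            abs(e - p) for e, p in zip(empirical, self.true_probs)
        )

    def get(self):
        return dict(self.scores)

metric = ToyBufferMetric(true_probs=[0.5, 0.5], buffer_size=4)
metric.update([0, 0, 0, 0])  # early policy: only ever reaches state 0
metric.process()
early = metric.get()["tv"]
metric.update([1, 1])        # later policy: the two oldest samples are evicted
metric.process()
late = metric.get()["tv"]
print(early, late)
```

Keeping the cheap `update` separate from the costlier `process` lets you record states at every training step and recompute the metric only at reporting time.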