
Approximate Distribution Metrics

ApproxDistributionMetricsModule compares the empirical distribution of terminal states sampled by recent policies with the environment's ground-truth distribution. The sampled terminal states are stored in a replay buffer of bounded size.
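To illustrate what the buffer contributes, here is a minimal sketch of a bounded buffer and the empirical distribution read from it. It assumes first-in-first-out eviction and hashable states (tuples); the `TerminalStateBuffer` class is hypothetical and is not the gfnx implementation.

```python
from collections import Counter, deque


class TerminalStateBuffer:
    """Hypothetical FIFO buffer of terminal states (not the gfnx class)."""

    def __init__(self, buffer_size):
        # A deque with maxlen drops the oldest state once the buffer is full
        self.states = deque(maxlen=buffer_size)

    def add(self, batch):
        # Append every terminal state of a batch; states must be hashable
        self.states.extend(batch)

    def empirical_distribution(self):
        # Relative frequency of each stored terminal state
        total = len(self.states)
        if total == 0:
            return {}
        counts = Counter(self.states)
        return {state: c / total for state, c in counts.items()}


buffer = TerminalStateBuffer(buffer_size=4)
buffer.add([(0, 0), (1, 0), (1, 0)])
buffer.add([(2, 1), (2, 1)])  # (0, 0) is evicted: only the 4 newest remain
print(buffer.empirical_distribution())  # {(1, 0): 0.5, (2, 1): 0.5}
```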

Intuition

  • Use this metric when the environment can enumerate its terminal space and provide the true distribution.
  • Prefer the exact-distribution metric if its computational overhead is acceptable.
  • For even larger or non-enumerable environments, use ELBO instead, since distribution-based metrics become impractical there.
  • Choose the replay buffer size carefully so that it covers the policy's support: too small yields a noisy metric, too large reacts slowly to policy updates.
  • Add "2d_marginal_distribution" to metrics to view a coarse heatmap.
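How gfnx builds the 2D marginal is not described on this page. As an illustration only, assuming hypergrid states are integer coordinate tuples, a 2D marginal can be formed by summing each state's probability onto two chosen axes; `marginal_2d` is a hypothetical helper.

```python
def marginal_2d(dist, side, dims=(0, 1)):
    """Hypothetical helper: project a distribution over grid states onto two axes.

    dist maps coordinate tuples to probabilities; side is the grid side length.
    """
    i, j = dims
    grid = [[0.0] * side for _ in range(side)]
    for coords, prob in dist.items():
        # All coordinates other than i and j are summed out
        grid[coords[i]][coords[j]] += prob
    return grid


# 3D hypergrid with side 2: equal mass on four states
dist = {(0, 0, 0): 0.25, (0, 0, 1): 0.25, (1, 0, 0): 0.25, (1, 1, 1): 0.25}
heatmap = marginal_2d(dist, side=2)
for row in heatmap:
    print(row)  # [0.5, 0.0] then [0.25, 0.25]
```

Computing the same projection for both the empirical and the true distribution gives two heatmaps that can be compared side by side.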

Key parameters

  • metrics: List of metric names to compute; choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}.
  • env: Enumerable environment for which to compute metrics.
  • buffer_size: Maximum number of states to store in the replay buffer for empirical distribution computation.
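For reference, the standard definitions of the three divergences are sketched below for distributions stored as dictionaries. The KL direction, log base, and smoothing that gfnx uses are not stated on this page; the natural log, the `eps` clamp, and the KL(p_true || p_emp) direction in the example are assumptions.

```python
import math


def tv(p, q):
    # Total variation: half the L1 distance over the union of supports
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(x, 0.0) - q.get(x, 0.0)) for x in support)


def kl(p, q, eps=1e-12):
    # KL(p || q) in nats; the eps clamp (an assumption) avoids log(0)
    # when q assigns no mass to a state that p covers
    return sum(
        px * math.log(px / max(q.get(x, 0.0), eps))
        for x, px in p.items()
        if px > 0
    )


def jsd(p, q):
    # Jensen-Shannon divergence: average KL of p and q to their mixture m
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


p_true = {"a": 0.5, "b": 0.5}
p_emp = {"a": 0.75, "b": 0.25}  # e.g. frequencies read off a replay buffer
print(tv(p_true, p_emp))  # 0.25
print(kl(p_true, p_emp), jsd(p_true, p_emp))
```

TV is bounded by 1 and JSD by ln 2, whereas KL is unbounded and grows sharply when the buffer has not yet seen states that the true distribution covers, which is why some form of smoothing is needed.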

Quick start

Environment requirement: the environment must be able to enumerate (or sample from) its true terminal distribution, and you must capture terminal states during rollouts (for example via gfnx.utils.forward_rollout) before updating the metric.
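To make the enumeration requirement concrete, the sketch below lists every cell of a small hypergrid and normalises a reward into the target p(x) = R(x) / Z. It assumes every cell is a valid terminal state; `toy_reward` and `true_distribution` are hypothetical helpers, and `toy_reward` is not the reward of EasyHypergridRewardModule.

```python
import itertools


def toy_reward(x, side):
    # Hypothetical reward: 1 plus the number of coordinates at the far edge
    return 1.0 + sum(1 for c in x if c == side - 1)


def true_distribution(side, dim, reward):
    # Enumerate all side**dim cells and normalise: p(x) = R(x) / Z
    states = itertools.product(range(side), repeat=dim)
    rewards = {x: reward(x, side) for x in states}
    z = sum(rewards.values())
    return {x: r / z for x, r in rewards.items()}


p_true = true_distribution(side=3, dim=2, reward=toy_reward)
print(len(p_true), p_true[(2, 2)])  # 9 0.2
```

The number of states grows as side**dim, which is why enumeration-based metrics stop being practical for large environments.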

import jax
import gfnx

env = gfnx.HypergridEnvironment(reward_module=gfnx.EasyHypergridRewardModule())
env_params = env.init(jax.random.PRNGKey(0))

metrics = gfnx.metrics.ApproxDistributionMetricsModule(
    metrics=["tv", "kl", "jsd"],
    env=env,
    buffer_size=10_000,
)

state = metrics.init(jax.random.PRNGKey(1), metrics.InitArgs(env_params=env_params))

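# During training: add every batch of terminal states you collect.
# `trajectory` is assumed to be the output of your own rollout code
# (for example gfnx.utils.forward_rollout); building it is not shown here.
state = metrics.update(
    state,
    jax.random.PRNGKey(2),
    metrics.UpdateArgs(states=trajectory.final_env_state),  # terminal states from your rollout
)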

# When you want a report: rebuild the empirical distribution and read the metrics.
state = metrics.process(
    state,
    jax.random.PRNGKey(3),
    metrics.ProcessArgs(env_params=env_params),
)
scores = metrics.get(state)
print(scores["tv"], scores["kl"], scores["jsd"])
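The buffer-size trade-off described under Intuition can be reproduced with a stdlib-only simulation that mimics repeated update calls followed by process. Everything in it is hypothetical: a policy that ends in state 0 for the first half of training and afterwards samples a uniform target over 10 states exactly, with FIFO eviction assumed for the buffer.

```python
import random
from collections import Counter, deque


def tv(p, q):
    # Total variation distance between two dict-based distributions
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(x, 0.0) - q.get(x, 0.0)) for x in support)


def simulated_tv(buffer_size, n_batches=50, batch_size=100, seed=0):
    """Hypothetical stand-in for repeated update() calls followed by process()."""
    rng = random.Random(seed)
    states = list(range(10))
    p_true = {s: 0.1 for s in states}  # hypothetical uniform target
    buffer = deque(maxlen=buffer_size)
    for step in range(n_batches):
        if step < n_batches // 2:
            # Early, untrained policy: hypothetically always ends in state 0
            batch = [0] * batch_size
        else:
            # Trained policy: samples the target distribution exactly
            batch = rng.choices(states, k=batch_size)
        buffer.extend(batch)  # like update(): the oldest states are evicted
    # Like process(): rebuild the empirical distribution from the buffer
    counts = Counter(buffer)
    p_emp = {s: c / len(buffer) for s, c in counts.items()}
    return tv(p_true, p_emp)


for size in (20, 2_500, 5_000):
    print(size, simulated_tv(size))
```

With these settings a 20-state buffer gives a noisy estimate, a 5,000-state buffer still holds the stale early samples and reports a large TV, and a 2,500-state buffer, which holds only samples from the trained policy, scores best.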