Lifecycle & Multimetrics
The metrics framework in gfnx.metrics gives every metric the same lifecycle and
makes it easy to compose several metrics in parallel. All metric modules inherit
from BaseMetricsModule and operate on immutable MetricsState dataclasses that
are safe to jit, pmap, or store on device.
Lifecycle
- init(rng_key, args): allocate a fresh state object (buffers, accumulators, cached ground-truth quantities);
- update(state, rng_key, args): fold new observations into the state during training or evaluation loops;
- process(state, rng_key, args): run any expensive post-processing before evaluation (often optional for streaming metrics);
- get(state): return a flat dict[str, Any] of metric values ready for logging.
Each module defines InitArgs, UpdateArgs, and ProcessArgs classes so the
expected payload is explicit and type-checkable.
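The lifecycle can be sketched with a hypothetical mean-reward metric written as four pure functions. MeanRewardState, init, update, process and get below are illustrative stand-ins, not the gfnx API: real modules are classes whose methods also receive the InitArgs, UpdateArgs or ProcessArgs objects described above. The state is a frozen dataclass registered as a JAX pytree, which is one way (assumed here) to make a dataclass state safe to pass through jax.jit. Each update returns a new state instead of mutating the old one.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class MeanRewardState:
    """Hypothetical immutable state: running sum, sample count, processed mean."""

    total: jax.Array
    count: jax.Array
    mean: jax.Array

    # Pytree registration lets jit/pmap see the arrays inside the dataclass.
    def tree_flatten(self):
        return (self.total, self.count, self.mean), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def init(rng_key):
    del rng_key  # nothing random to allocate for this metric
    zero = jnp.zeros(())
    return MeanRewardState(total=zero, count=zero, mean=zero)


@jax.jit
def update(state, rng_key, rewards):
    del rng_key
    # Returns a new state; the input state is never mutated.
    return dataclasses.replace(
        state, total=state.total + rewards.sum(), count=state.count + rewards.size
    )


def process(state, rng_key):
    del rng_key
    # The post-processing step: here just one division before evaluation.
    return dataclasses.replace(state, mean=state.total / jnp.maximum(state.count, 1.0))


def get(state):
    # Flat dict of plain Python scalars, ready for a logger.
    return {"mean_reward": float(state.mean), "num_samples": int(state.count)}


key = jax.random.PRNGKey(0)
state = init(key)
for batch in (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0, 5.0])):
    key, subkey = jax.random.split(key)
    state = update(state, subkey, batch)
state = process(state, key)
print(get(state))  # {'mean_reward': 3.0, 'num_samples': 5}
```

Because update is jit-compiled, calling it with a batch of a new shape triggers a recompilation; fixed-size batches avoid that.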
Working with multiple metrics
MultiMetricsModule wraps a dictionary of metric modules and fans out every
lifecycle call. Results are namespaced as {metric_name}/{key} to avoid clashes.
Use it when you want a single object that drives several metrics on the same set
of trajectories.
```python
import jax
import gfnx

rng_key = jax.random.PRNGKey(0)

metrics = gfnx.metrics.MultiMetricsModule(
    {
        "elbo": gfnx.metrics.ELBOMetricsModule(...),
        "tv": gfnx.metrics.ApproxDistributionMetricsModule(...),
    }
)

# Each lifecycle call is forwarded to every wrapped metric.
state = metrics.init(rng_key, metrics.InitArgs(metrics_args={}))
state = metrics.update(state, rng_key, metrics.UpdateArgs(metrics_args={}))
state = metrics.process(state, rng_key, metrics.ProcessArgs(metrics_args={}))
values = metrics.get(state)
# values -> {"elbo/elbo": ..., "tv/tv": ...}
```
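The fan-out and namespacing behaviour can be illustrated with a small hypothetical wrapper. MultiSketch and CounterMetric are illustrative names, not the gfnx implementation. Two details are assumptions of this sketch only: per-metric arguments are passed as a dict keyed by metric name, and the RNG key is split once per wrapped metric.

```python
from typing import Any, Dict, Mapping, Optional

import jax


class CounterMetric:
    """Hypothetical stand-in metric that counts update calls, times a scale factor."""

    def __init__(self, scale: int = 1):
        self.scale = scale

    def init(self, rng_key, args):
        return 0

    def update(self, state, rng_key, args):
        return state + self.scale

    def process(self, state, rng_key, args):
        return state  # nothing to post-process

    def get(self, state) -> Dict[str, Any]:
        return {"count": state}


class MultiSketch:
    """Hypothetical fan-out wrapper (not the gfnx.metrics.MultiMetricsModule source)."""

    def __init__(self, modules: Mapping[str, Any]):
        self.modules = dict(modules)

    def _fan_out(self, method, rng_key, args_by_name, states=None):
        args_by_name = args_by_name or {}
        # One subkey per wrapped metric so they never share randomness.
        subkeys = jax.random.split(rng_key, len(self.modules))
        out = {}
        for (name, module), subkey in zip(self.modules.items(), subkeys):
            fn = getattr(module, method)
            args = args_by_name.get(name)  # assumption: per-metric args keyed by name
            out[name] = fn(subkey, args) if states is None else fn(states[name], subkey, args)
        return out

    def init(self, rng_key, args_by_name: Optional[dict] = None):
        return self._fan_out("init", rng_key, args_by_name)

    def update(self, states, rng_key, args_by_name: Optional[dict] = None):
        return self._fan_out("update", rng_key, args_by_name, states)

    def process(self, states, rng_key, args_by_name: Optional[dict] = None):
        return self._fan_out("process", rng_key, args_by_name, states)

    def get(self, states) -> Dict[str, Any]:
        flat = {}
        for name, module in self.modules.items():
            for metric_key, value in module.get(states[name]).items():
                flat[f"{name}/{metric_key}"] = value  # namespacing avoids key clashes
        return flat


multi = MultiSketch({"a": CounterMetric(), "b": CounterMetric(scale=10)})
key = jax.random.PRNGKey(0)
states = multi.init(key)
for _ in range(3):
    key, subkey = jax.random.split(key)
    states = multi.update(states, subkey)
states = multi.process(states, key)
print(multi.get(states))  # {'a/count': 3, 'b/count': 30}
```

Both wrapped metrics report a key named count; the name/ prefix keeps them apart in the single flat dictionary.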
Authoring a custom metric
- Subclass BaseMetricsModule and implement the four lifecycle methods;
- Define inner InitArgs, UpdateArgs, and ProcessArgs dataclasses (inherit the empty shells from gfnx.metrics if you have nothing to pass);
- Use pure functions and JAX arrays inside your state so it is compatible with JIT and parallel transforms;
- Return plain Python scalars or JAX arrays from get; higher-level tooling such as loggers or WandB adapters can handle the resulting dictionary directly.
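The checklist above can be sketched as follows. Since this page does not spell out the exact signatures of BaseMetricsModule, the example uses a hypothetical abstract stand-in, BaseMetricsSketch, and a hypothetical MaxRewardMetric; a real gfnx metric would subclass the library's base class instead. The state is a NamedTuple here for brevity (NamedTuples are already immutable JAX pytrees), whereas gfnx's own states are dataclasses.

```python
import abc
import dataclasses
from typing import Any, Dict, NamedTuple

import jax
import jax.numpy as jnp


class BaseMetricsSketch(abc.ABC):
    """Hypothetical stand-in for gfnx.metrics.BaseMetricsModule; real signatures may differ."""

    @abc.abstractmethod
    def init(self, rng_key, args): ...

    @abc.abstractmethod
    def update(self, state, rng_key, args): ...

    @abc.abstractmethod
    def process(self, state, rng_key, args): ...

    @abc.abstractmethod
    def get(self, state) -> Dict[str, Any]: ...


class MaxRewardState(NamedTuple):
    """Immutable pytree state: best reward so far and number of samples seen."""

    best: jax.Array
    seen: jax.Array


@jax.jit
def _update_max(state: MaxRewardState, rewards: jax.Array) -> MaxRewardState:
    # Pure and jit-compiled: builds a new state instead of mutating the old one.
    return MaxRewardState(
        best=jnp.maximum(state.best, rewards.max()),
        seen=state.seen + rewards.shape[0],
    )


class MaxRewardMetric(BaseMetricsSketch):
    """Hypothetical streaming metric tracking the best reward seen."""

    @dataclasses.dataclass(frozen=True)
    class InitArgs:
        pass  # nothing to configure

    @dataclasses.dataclass(frozen=True)
    class UpdateArgs:
        rewards: jax.Array  # shape [batch_size]

    @dataclasses.dataclass(frozen=True)
    class ProcessArgs:
        pass  # streaming metric: no post-processing inputs

    def init(self, rng_key, args: "MaxRewardMetric.InitArgs") -> MaxRewardState:
        return MaxRewardState(best=jnp.array(-jnp.inf), seen=jnp.array(0))

    def update(self, state, rng_key, args: "MaxRewardMetric.UpdateArgs"):
        return _update_max(state, args.rewards)

    def process(self, state, rng_key, args: "MaxRewardMetric.ProcessArgs"):
        return state  # nothing expensive to do for a running maximum

    def get(self, state: MaxRewardState) -> Dict[str, Any]:
        # Plain Python scalars are straightforward for loggers to serialize.
        return {"best_reward": float(state.best), "num_samples": int(state.seen)}


metric = MaxRewardMetric()
k_init, k_up1, k_up2, k_proc = jax.random.split(jax.random.PRNGKey(0), 4)
state = metric.init(k_init, MaxRewardMetric.InitArgs())
state = metric.update(state, k_up1, MaxRewardMetric.UpdateArgs(rewards=jnp.array([0.5, 2.0])))
state = metric.update(state, k_up2, MaxRewardMetric.UpdateArgs(rewards=jnp.array([1.5])))
state = metric.process(state, k_proc, MaxRewardMetric.ProcessArgs())
print(metric.get(state))  # {'best_reward': 2.0, 'num_samples': 3}
```

Keeping the array arithmetic in a standalone jitted function leaves the class methods as thin adapters that unpack the typed Args objects, so the compiled code only ever sees pytrees.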