API Reference

AccumulatedModesMetricsModule

Bases: BaseMetricsModule

Metric module for tracking mode discovery in GFlowNet training.

This module monitors how well a GFlowNet policy discovers different modes by tracking which known modes are visited during training or evaluation. It maintains a set of reference modes and records which ones have been encountered based on a distance threshold.

Attributes:

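| Name | Type | Description |
| --- | --- | --- |
| env | TEnvironment | Environment instance for mode-related operations |
| distance_fn | Callable[[TEnvState, TEnvState], float] | Function to compute distance between two states |
| distance_threshold | float | Maximum distance for considering a state as visiting a mode |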
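
The pairwise-distance step behind this module can be sketched with two nested `jax.vmap` calls over batched states. This is a minimal sketch, not the library's code: `PointState`, `euclidean`, and `distance_matrix` are hypothetical names introduced for illustration, and the only property assumed of a state is that it is a pytree whose leaves share a leading batch axis (the module itself reads `is_pad.shape[0]` to get the batch size).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PointState(NamedTuple):
    """Hypothetical batched state: 2-D positions plus a padding flag."""

    pos: jax.Array     # shape (batch, 2)
    is_pad: jax.Array  # shape (batch,)


def euclidean(a: PointState, b: PointState) -> jax.Array:
    # Called on single (unbatched) states, so a.pos and b.pos have shape (2,).
    return jnp.linalg.norm(a.pos - b.pos)


def distance_matrix(lhs: PointState, rhs: PointState) -> jax.Array:
    # Outer vmap walks over lhs states, inner vmap over rhs states -> (N, M).
    return jax.vmap(lambda a: jax.vmap(lambda b: euclidean(a, b))(rhs))(lhs)


visited = PointState(
    pos=jnp.array([[0.0, 0.0], [1.0, 1.05], [5.0, 5.0]]),
    is_pad=jnp.zeros(3, dtype=bool),
)
modes = PointState(
    pos=jnp.array([[1.0, 1.0], [3.0, 3.0]]),
    is_pad=jnp.zeros(2, dtype=bool),
)

d = distance_matrix(visited, modes)  # shape (3, 2)
mode_hit = jnp.any(d < 0.1, axis=0)  # one flag per reference mode
```

Only the second visited point lies within 0.1 of a mode, so the first reference mode is flagged and the second is not.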

Source code in gfnx/metrics/modes.py
class AccumulatedModesMetricsModule(BaseMetricsModule):
    """Metric module for tracking mode discovery in GFlowNet training.

    This module monitors how well a GFlowNet policy discovers different modes
    by tracking which known modes are visited during training or evaluation.
    It maintains a set of reference modes and records which ones have been
    encountered based on a distance threshold.

    Attributes:
        env: Environment instance for mode-related operations
        distance_fn: Function to compute distance between two states
        distance_threshold: Maximum distance for considering a state as visiting a mode
    """

    def __init__(
        self,
        env: TEnvironment,
        distance_fn: Callable[[TEnvState, TEnvState], float],
        distance_threshold: float = 0.1,
    ):
        """Initialize the accumulated modes metric module.

        Args:
            env: Environment instance for mode-related operations
            distance_fn: Function that computes distance between two environment states.
                Must return a scalar distance value.
            distance_threshold: Maximum distance threshold for considering a visited
                state as discovering a mode. States within this distance of a mode
                are considered to have "visited" that mode.
        """
        self.env = env
        self.distance_fn = distance_fn
        self.distance_threshold = distance_threshold

    def _get_distance_matrix(self, lhs_states: TEnvState, rhs_states: TEnvState) -> jnp.ndarray:
        """Compute distance matrix between two sets of states.

        Computes pairwise distances between all states in lhs_states and all states
        in rhs_states using the configured distance function. This is used to determine
        which visited states are close enough to known modes.

        Args:
            lhs_states: First set of states (typically visited states), N states
            rhs_states: Second set of states (typically reference modes), M states

        Returns:
            jnp.ndarray: Distance matrix of shape (N, M) where entry (i,j) contains
                the distance between lhs_states[i] and rhs_states[j]
        """
        result = jax.vmap(
            lambda lhs_state, rhs_states: jax.vmap(
                lambda rhs_state: self.distance_fn(lhs_state, rhs_state)
            )(rhs_states),
            in_axes=(0, None),
        )(lhs_states, rhs_states)
        chex.assert_shape(result, (lhs_states.is_pad.shape[0], rhs_states.is_pad.shape[0]))
        return result

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the AccumulatedModesMetricsModule.

        Attributes:
            modes: Reference set of mode states to track during training/evaluation.
                These represent the target modes that the policy should discover.
        """

        modes: TEnvState

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> AccumulatedModesMetricsState:
        """Initialize the accumulated modes metric state.

        Creates initial state with the provided reference modes and initializes
        the visited modes tracking array to all False (no modes visited yet).

        Args:
            rng_key: JAX PRNG key for any random initialization (currently unused)
            args: InitArgs object containing the reference modes to track

        Returns:
            AccumulatedModesMetricsState: Initialized state with reference modes
                and empty visited modes tracking array
        """
        return AccumulatedModesMetricsState(
            modes=args.modes,
            visited_modes_idx=jnp.zeros((args.modes.is_pad.shape[0],), dtype=jnp.bool),
        )

    @chex.dataclass
    class UpdateArgs(BaseUpdateArgs):
        """Arguments for updating the AccumulatedModesMetricsModule.

        Attributes:
            states: Environment states visited during training/evaluation to check
                against the reference modes for mode discovery tracking.
        """

        states: TEnvState

    def update(
        self, metrics_state: AccumulatedModesMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
    ) -> AccumulatedModesMetricsState:
        """Update the metric state with newly visited states.

        Checks which reference modes have been visited by computing distances between
        the provided states and all reference modes. A mode is considered "visited"
        if any of the provided states is within the distance threshold of that mode.

        Args:
            metrics_state: Current metric state containing modes and visited tracking
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: UpdateArgs object containing the newly visited states

        Returns:
            AccumulatedModesMetricsState: Updated state with updated visited modes tracking.
                The visited_modes_idx array is updated to mark newly discovered modes.
        """
        # Compute distance matrix between current states and modes
        num_modes = metrics_state.visited_modes_idx.shape[0]
        d_matrix = self._get_distance_matrix(args.states, metrics_state.modes)
        mode_passed = jnp.any(d_matrix < self.distance_threshold, axis=0)
        chex.assert_shape(mode_passed, (num_modes,))
        visited_modes_idx = jnp.logical_or(metrics_state.visited_modes_idx, mode_passed)
        return metrics_state.replace(visited_modes_idx=visited_modes_idx)

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the AccumulatedModesMetricsModule.

        Attributes:
            env_params: Environment parameters (currently unused but maintained
                for interface consistency with other metric modules).
        """

        env_params: Any

    def process(
        self,
        metrics_state: AccumulatedModesMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs | None = None,
    ) -> AccumulatedModesMetricsState:
        """Process the metric state for final computation (no-op for modes metrics).

        This method performs any final processing needed before metric computation.
        For accumulated modes metrics, no additional processing is required as the
        state is maintained incrementally during updates.

        Args:
            metrics_state: Current metric state with accumulated mode visit information
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: ProcessArgs object containing environment parameters (currently unused)

        Returns:
            AccumulatedModesMetricsState: Unchanged metric state ready for get() call
        """
        return metrics_state

    def get(self, metrics_state: AccumulatedModesMetricsState) -> dict:
        """Get the computed mode discovery metrics from the current state.

        Computes and returns metrics that quantify how well the policy has discovered
        the reference modes. Provides both absolute and relative measures of mode coverage.

        Args:
            metrics_state: Current metric state containing mode visit tracking information

        Returns:
            Dict[str, Any]: Dictionary containing computed mode discovery metrics:
                - 'num_modes': Total number of modes discovered (integer count)
                - 'percent_modes': Fraction of modes discovered (float between 0 and 1)
        """
        return {
            "num_modes": metrics_state.visited_modes_idx.sum(),
            "percent_modes": metrics_state.visited_modes_idx.mean(),
        }

InitArgs

Bases: BaseInitArgs

Arguments for initializing the AccumulatedModesMetricsModule.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| modes | TEnvState | Reference set of mode states to track during training/evaluation. These represent the target modes that the policy should discover. |

Source code in gfnx/metrics/modes.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the AccumulatedModesMetricsModule.

    Attributes:
        modes: Reference set of mode states to track during training/evaluation.
            These represent the target modes that the policy should discover.
    """

    modes: TEnvState

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the AccumulatedModesMetricsModule.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| env_params | Any | Environment parameters (currently unused but maintained for interface consistency with other metric modules). |

Source code in gfnx/metrics/modes.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the AccumulatedModesMetricsModule.

    Attributes:
        env_params: Environment parameters (currently unused but maintained
            for interface consistency with other metric modules).
    """

    env_params: Any

UpdateArgs

Bases: BaseUpdateArgs

Arguments for updating the AccumulatedModesMetricsModule.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| states | TEnvState | Environment states visited during training/evaluation to check against the reference modes for mode discovery tracking. |

Source code in gfnx/metrics/modes.py
@chex.dataclass
class UpdateArgs(BaseUpdateArgs):
    """Arguments for updating the AccumulatedModesMetricsModule.

    Attributes:
        states: Environment states visited during training/evaluation to check
            against the reference modes for mode discovery tracking.
    """

    states: TEnvState

__init__(env, distance_fn, distance_threshold=0.1)

Initialize the accumulated modes metric module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| env | TEnvironment | Environment instance for mode-related operations | required |
| distance_fn | Callable[[TEnvState, TEnvState], float] | Function that computes distance between two environment states. Must return a scalar distance value. | required |
| distance_threshold | float | Distance threshold for considering a visited state as discovering a mode. A mode counts as "visited" when some state lies strictly closer than this value (the check is distance < distance_threshold). | 0.1 |

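
The choice of `distance_fn` interacts with the threshold. The sketch below uses two hypothetical distance functions (`hamming_distance` and `normalized_hamming` are illustration-only names, and states are assumed to be plain token arrays rather than a real `TEnvState`). With an integer-valued distance, the default threshold of 0.1 only accepts exact matches, while a normalized distance tolerates small differences.

```python
import jax.numpy as jnp


def hamming_distance(a, b):
    # Number of positions where two equal-length token arrays differ (scalar).
    return jnp.sum(a != b).astype(jnp.float32)


def normalized_hamming(a, b):
    # Same quantity scaled to [0, 1], so a threshold like 0.1 is meaningful.
    return jnp.mean(a != b)


x = jnp.array([1, 0, 2, 2, 1, 0, 0, 1, 2, 1, 0, 2])
y = x.at[3].set(0)  # differs from x in exactly one of 12 positions

raw = hamming_distance(x, y)       # one mismatch
scaled = normalized_hamming(x, y)  # one mismatch out of twelve

raw_is_hit = bool(raw < 0.1)        # False: 1.0 is not below the threshold
scaled_is_hit = bool(scaled < 0.1)  # True: 1/12 is below the threshold
```
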
Source code in gfnx/metrics/modes.py
def __init__(
    self,
    env: TEnvironment,
    distance_fn: Callable[[TEnvState, TEnvState], float],
    distance_threshold: float = 0.1,
):
    """Initialize the accumulated modes metric module.

    Args:
        env: Environment instance for mode-related operations
        distance_fn: Function that computes distance between two environment states.
            Must return a scalar distance value.
        distance_threshold: Maximum distance threshold for considering a visited
            state as discovering a mode. States within this distance of a mode
            are considered to have "visited" that mode.
    """
    self.env = env
    self.distance_fn = distance_fn
    self.distance_threshold = distance_threshold

get(metrics_state)

Get the computed mode discovery metrics from the current state.

Computes and returns metrics that quantify how well the policy has discovered the reference modes. Provides both absolute and relative measures of mode coverage.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | AccumulatedModesMetricsState | Current metric state containing mode visit tracking information | required |

Returns:

| Type | Description |
| --- | --- |
| dict | Dict[str, Any]: Dictionary containing computed mode discovery metrics, with the keys listed below. |

  • 'num_modes': Total number of modes discovered (integer count)
  • 'percent_modes': Fraction of modes discovered (float between 0 and 1)
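
As a small illustration of the two returned values, the sketch below applies the same reductions to a hypothetical tracking mask. The mask contents are made up; the conversion to Python scalars is an optional convenience for logging and is not something `get` does itself.

```python
import jax.numpy as jnp

# Hypothetical tracking mask: 3 of 8 reference modes have been visited.
visited_modes_idx = jnp.array([True, False, True, False, False, True, False, False])

num_modes = visited_modes_idx.sum()       # count of True entries
percent_modes = visited_modes_idx.mean()  # fraction in [0, 1], not 0-100

# Both values are JAX arrays; convert them when plain Python numbers are needed.
metrics = {"num_modes": int(num_modes), "percent_modes": float(percent_modes)}
```

Despite its name, `percent_modes` is a fraction, so a dashboard that expects percentages would need to multiply it by 100.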

Source code in gfnx/metrics/modes.py
def get(self, metrics_state: AccumulatedModesMetricsState) -> dict:
    """Get the computed mode discovery metrics from the current state.

    Computes and returns metrics that quantify how well the policy has discovered
    the reference modes. Provides both absolute and relative measures of mode coverage.

    Args:
        metrics_state: Current metric state containing mode visit tracking information

    Returns:
        Dict[str, Any]: Dictionary containing computed mode discovery metrics:
            - 'num_modes': Total number of modes discovered (integer count)
            - 'percent_modes': Fraction of modes discovered (float between 0 and 1)
    """
    return {
        "num_modes": metrics_state.visited_modes_idx.sum(),
        "percent_modes": metrics_state.visited_modes_idx.mean(),
    }

init(rng_key, args)

Initialize the accumulated modes metric state.

Creates initial state with the provided reference modes and initializes the visited modes tracking array to all False (no modes visited yet).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs | InitArgs object containing the reference modes to track | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | Initialized state with reference modes and empty visited modes tracking array |

Source code in gfnx/metrics/modes.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> AccumulatedModesMetricsState:
    """Initialize the accumulated modes metric state.

    Creates initial state with the provided reference modes and initializes
    the visited modes tracking array to all False (no modes visited yet).

    Args:
        rng_key: JAX PRNG key for any random initialization (currently unused)
        args: InitArgs object containing the reference modes to track

    Returns:
        AccumulatedModesMetricsState: Initialized state with reference modes
            and empty visited modes tracking array
    """
    return AccumulatedModesMetricsState(
        modes=args.modes,
        visited_modes_idx=jnp.zeros((args.modes.is_pad.shape[0],), dtype=jnp.bool),
    )

process(metrics_state, rng_key, args=None)

Process the metric state for final computation (no-op for modes metrics).

This method performs any final processing needed before metric computation. For accumulated modes metrics, no additional processing is required as the state is maintained incrementally during updates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | AccumulatedModesMetricsState | Current metric state with accumulated mode visit information | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | ProcessArgs \| None | ProcessArgs object containing environment parameters (currently unused) | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | Unchanged metric state ready for get() call |

Source code in gfnx/metrics/modes.py
def process(
    self,
    metrics_state: AccumulatedModesMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs | None = None,
) -> AccumulatedModesMetricsState:
    """Process the metric state for final computation (no-op for modes metrics).

    This method performs any final processing needed before metric computation.
    For accumulated modes metrics, no additional processing is required as the
    state is maintained incrementally during updates.

    Args:
        metrics_state: Current metric state with accumulated mode visit information
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: ProcessArgs object containing environment parameters (currently unused)

    Returns:
        AccumulatedModesMetricsState: Unchanged metric state ready for get() call
    """
    return metrics_state

update(metrics_state, rng_key, args)

Update the metric state with newly visited states.

Checks which reference modes have been visited by computing distances between the provided states and all reference modes. A mode is considered "visited" if any of the provided states is within the distance threshold of that mode.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | AccumulatedModesMetricsState | Current metric state containing modes and visited tracking | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs | UpdateArgs object containing the newly visited states | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | State whose visited_modes_idx array additionally marks the modes discovered in this update. Modes flagged by earlier updates stay flagged. |
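
The accumulation across several update calls can be sketched as follows. This is an illustration under simplifying assumptions, not the module's code: states and modes are hypothetical 1-D points, and the distance is a plain absolute difference computed by broadcasting instead of a user-supplied `distance_fn`.

```python
import jax.numpy as jnp

threshold = 0.1
modes = jnp.array([0.0, 1.0, 2.0])  # hypothetical 1-D reference modes
visited = jnp.zeros(modes.shape[0], dtype=bool)

batches = [
    jnp.array([0.02, 0.5]),   # hits mode 0
    jnp.array([1.5, 1.7]),    # hits nothing
    jnp.array([2.05, 0.01]),  # hits mode 2 and revisits mode 0
]

for states in batches:
    d_matrix = jnp.abs(states[:, None] - modes[None, :])  # shape (N, M)
    mode_passed = jnp.any(d_matrix < threshold, axis=0)   # shape (M,)
    visited = jnp.logical_or(visited, mode_passed)        # flags are never reset
```

Because the flags are combined with a logical OR, a batch that hits no mode leaves the record unchanged, and revisiting a mode has no further effect.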

Source code in gfnx/metrics/modes.py
def update(
    self, metrics_state: AccumulatedModesMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
) -> AccumulatedModesMetricsState:
    """Update the metric state with newly visited states.

    Checks which reference modes have been visited by computing distances between
    the provided states and all reference modes. A mode is considered "visited"
    if any of the provided states is within the distance threshold of that mode.

    Args:
        metrics_state: Current metric state containing modes and visited tracking
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: UpdateArgs object containing the newly visited states

    Returns:
        AccumulatedModesMetricsState: Updated state with updated visited modes tracking.
            The visited_modes_idx array is updated to mark newly discovered modes.
    """
    # Compute distance matrix between current states and modes
    num_modes = metrics_state.visited_modes_idx.shape[0]
    d_matrix = self._get_distance_matrix(args.states, metrics_state.modes)
    mode_passed = jnp.any(d_matrix < self.distance_threshold, axis=0)
    chex.assert_shape(mode_passed, (num_modes,))
    visited_modes_idx = jnp.logical_or(metrics_state.visited_modes_idx, mode_passed)
    return metrics_state.replace(visited_modes_idx=visited_modes_idx)

AccumulatedModesMetricsState

Bases: MetricsState

State for accumulating mode discovery metrics.

This state container tracks which modes have been visited during training or evaluation by maintaining a record of known modes and a boolean array indicating which modes have been encountered.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| modes | TEnvState | Reference set of mode states to track. These represent the target modes that the policy should discover. |
| visited_modes_idx | Array | Boolean array indicating which modes have been visited at least once. Shape: (num_modes,) |

Source code in gfnx/metrics/modes.py
@chex.dataclass
class AccumulatedModesMetricsState(MetricsState):
    """State for accumulating mode discovery metrics.

    This state container tracks which modes have been visited during training
    or evaluation by maintaining a record of known modes and a boolean array
    indicating which modes have been encountered.

    Attributes:
        modes: Reference set of mode states to track. These represent the
            target modes that the policy should discover.
        visited_modes_idx: Boolean array indicating which modes have been
            visited at least once. Shape: (num_modes,)
    """

    modes: TEnvState
    visited_modes_idx: chex.Array

ApproxDistributionMetricsModule

Bases: BaseMetricsModule

Distribution-based metrics module for enumerable environments.

This metric module computes distribution-based metrics by comparing the true distribution of an enumerable environment with an empirical distribution derived from collected samples. It supports various distance metrics between distributions such as KL divergence and total variation distance.

The module maintains a replay buffer to accumulate environment states and computes empirical distributions from these samples. It can only be used with enumerable environments that provide access to their true distribution.

Supported metrics
  • "tv": Total variation distance between distributions
  • "kl": KL divergence from empirical to true distribution
  • "jsd": Jensen-Shannon divergence from empirical to true distribution
  • "2d_marginal_distribution": Marginal distribution computation
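
For orientation, the three divergences named above can be written in a few lines using their textbook definitions. This is a sketch only: the functions `tv`, `kl`, and `jsd` are illustrative stand-ins, and the library's own implementations may differ in argument order, logarithm base, or handling of zero probabilities.

```python
import jax.numpy as jnp


def tv(p, q):
    # Total variation: half the L1 distance between the two distributions.
    return 0.5 * jnp.sum(jnp.abs(p - q))


def kl(p, q, eps=1e-12):
    # KL(p || q); terms with p == 0 contribute zero, eps guards against log(0).
    log_ratio = jnp.log(jnp.maximum(p, eps)) - jnp.log(jnp.maximum(q, eps))
    return jnp.sum(jnp.where(p > 0, p * log_ratio, 0.0))


def jsd(p, q):
    # Jensen-Shannon: average KL of p and q to their midpoint m.
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


true_dist = jnp.array([0.5, 0.25, 0.25, 0.0])
empirical = jnp.array([0.4, 0.4, 0.2, 0.0])

tv_value = tv(true_dist, empirical)
kl_value = kl(true_dist, empirical)
jsd_value = jsd(true_dist, empirical)
```

Total variation and Jensen-Shannon are symmetric in their arguments, whereas KL is not, so the direction in which the two distributions are passed matters for "kl".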

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| metrics | list[str] | List of metric names to compute, choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}. |
| env | TEnvironment | Enumerable environment for which to compute metrics. |
| buffer_size | int | Maximum number of states to store in the replay buffer. |
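
The idea of turning buffered samples into an empirical distribution can be sketched as a normalized histogram. The example assumes, purely for illustration, that every sampled state has already been mapped to an integer index into the enumerated state space; how the environment's `get_empirical_distribution` performs that mapping is not shown in this reference.

```python
import jax.numpy as jnp

num_states = 4
# Hypothetical buffer contents: each sampled state as an integer index.
sampled_idx = jnp.array([0, 1, 1, 3, 1, 0, 3, 1])

# Count visits per state; `length` fixes the output size to the state space.
counts = jnp.bincount(sampled_idx, length=num_states)
empirical = counts / counts.sum()
```

States that were never sampled get probability zero, which is worth keeping in mind when the result is fed into a KL divergence.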

Source code in gfnx/metrics/approx_distribution.py
class ApproxDistributionMetricsModule(BaseMetricsModule):
    """Distribution-based metrics module for enumerable environments.

    This metric module computes distribution-based metrics by comparing the true
    distribution of an enumerable environment with an empirical distribution
    derived from collected samples. It supports various distance metrics between
    distributions such as KL divergence and total variation distance.

    The module maintains a replay buffer to accumulate environment states and
    computes empirical distributions from these samples. It can only be used
    with enumerable environments that provide access to their true distribution.

    Supported metrics:
        - "tv": Total variation distance between distributions
        - "kl": KL divergence from empirical to true distribution
        - "jsd": Jensen-Shannon divergence from empirical to true distribution
        - "2d_marginal_distribution": Marginal distribution computation

    Attributes:
        metrics: List of metric names to compute, 
            choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}.
        env: Enumerable environment for which to compute metrics.
        buffer_size: Maximum number of states to store in the replay buffer.
    """

    _supported_metrics: ClassVar[dict[str, Callable[..., chex.Array]]] = {
        "tv": total_variation_distance,
        "kl": kl_divergence,
        "jsd": jensen_shannon_divergence,
        "2d_marginal_distribution": marginal_distribution,
    }

    def __init__(
        self,
        metrics: list[str],
        env: TEnvironment,
        buffer_size: int = 1000,
    ):
        """Initialize the distribution metrics module.

        Sets up the metric module with specified metrics, environment, and buffer size.
        Validates that the environment is enumerable and that all requested metrics
        are supported.

        Args:
            metrics: List of metric names to compute. Must be a subset of the supported metrics.
                    Supported options: ["tv", "kl", "jsd", "2d_marginal_distribution"]
            env: Environment instance for which to compute metrics. Must be enumerable
                (i.e., env.is_enumerable must be True)
            buffer_size: Maximum number of states to store in the replay buffer for
                       empirical distribution computation. Must be positive integer.

        Raises:
            ValueError: If environment is not enumerable, buffer_size is not positive,
                       metrics is not a list of strings, or contains unsupported metrics
        """
        if not env.is_enumerable:
            raise ValueError(f"Environment {env.name} is not enumerable")
        if not isinstance(buffer_size, int) or buffer_size <= 0:
            raise ValueError("buffer_size must be a positive integer")
        if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
            raise ValueError("metrics must be a list of strings")
        if not all(m in self._supported_metrics for m in metrics):
            raise ValueError(
                # List only the metric names; the dict values are function objects.
                f"Unsupported metrics. Supported metrics are: {sorted(self._supported_metrics)}"
            )

        self.metrics = metrics
        self.env = env

        self.buffer_module = fbx.make_item_buffer(
            max_length=buffer_size,
            min_length=1,
            sample_batch_size=1,
            add_batches=True,
        )

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the ApproxDistributionMetricsModule.

        Attributes:
            env_params: Environment parameters needed to obtain the true distribution
                and initialize environment states for the replay buffer.
        """

        env_params: TEnvParams

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ApproxDistributionMetricsState:
        """Initialize the metric state for distribution metrics.

        Creates and initializes all components needed for distribution metric computation:
        the true distribution from the environment, an initial uniform empirical distribution,
        and a replay buffer for storing environment states.

        Args:
            rng_key: JAX PRNG key for any random initialization (currently unused)
            args: InitArgs object containing environment parameters

        Returns:
            ApproxDistributionMetricsState: Initialized state containing:
                - true_distribution: Ground truth distribution from the environment
                - empirical_distribution: Initial uniform distribution
                - replay_buffer: Empty buffer ready to collect states
        """
        _, fake_state = self.env.reset(1, args.env_params)
        fake_single_state = jax.tree.map(lambda x: x[0], fake_state)
        # Replay buffer takes only shapes, but not values
        replay_buffer_state = self.buffer_module.init(fake_single_state)
        true_distribution = self.env.get_true_distribution(args.env_params)
        return ApproxDistributionMetricsState(
            true_distribution=true_distribution,
            empirical_distribution=jnp.ones_like(true_distribution) / true_distribution.size,
            replay_buffer=replay_buffer_state,
        )

    @chex.dataclass
    class UpdateArgs(BaseUpdateArgs):
        """Arguments for updating the ApproxDistributionMetricsModule.

        Attributes:
            states: Environment states to add to the replay buffer for empirical
                distribution computation. These represent states visited during
                training/evaluation episodes.
        """

        states: TEnvState

    def update(
        self,
        metrics_state: ApproxDistributionMetricsState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs,
    ) -> ApproxDistributionMetricsState:
        """Update the metric state with new environment states.

        Adds new environment states to the replay buffer for empirical distribution
        computation.

        Args:
            metrics_state: Current metric state containing the replay buffer
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: UpdateArgs object containing environment states to add

        Returns:
            ApproxDistributionMetricsState: Updated state with new states added to the buffer.
                                         The true_distribution and empirical_distribution
                                         remain unchanged until process() is called.
        """
        updated_buffer = self.buffer_module.add(metrics_state.replay_buffer, args.states)
        return metrics_state.replace(replay_buffer=updated_buffer)

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the ApproxDistributionMetricsModule.

        Attributes:
            env_params: Environment parameters needed for empirical distribution
                computation from the collected states in the replay buffer.
        """

        env_params: TEnvParams

    def process(
        self,
        metrics_state: ApproxDistributionMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> ApproxDistributionMetricsState:
        """Process the metric state to compute the empirical distribution.

        This method computes the empirical distribution from the states stored in the
        replay buffer. It calls the environment's `get_empirical_distribution` method
        to convert the collected states into a probability distribution that can be
        compared with the true distribution.

        Args:
            metrics_state: Current metric state containing the replay buffer with collected states
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: ProcessArgs object containing environment parameters

        Returns:
            ApproxDistributionMetricsState: Updated state with the computed empirical distribution.
                                         This state is now ready for metric computation via get().
        """
        # Build the empirical distribution from the states stored in the replay buffer.
        # x[0] drops the leading axis of the stored experience before it is passed on.
        empirical_distribution = self.env.get_empirical_distribution(
            jax.tree.map(lambda x: x[0], metrics_state.replay_buffer.experience), args.env_params
        )
        return metrics_state.replace(empirical_distribution=empirical_distribution)

    def get(self, metrics_state: ApproxDistributionMetricsState) -> dict[str, Any]:
        """Get the computed distribution metrics from the processed state.

        Computes and returns the requested distribution metrics by comparing the
        true distribution with the empirical distribution. Each metric is computed
        using the corresponding function from the _supported_metrics dictionary.

        Args:
            metrics_state: Processed metric state containing both true and empirical distributions

        Returns:
            Dict[str, Any]: Dictionary containing the computed metrics. Keys are the metric
                          names specified during initialization, and values are the computed
                          metric values (typically float scalars).

        Example:
            If initialized with metrics=["tv", "kl"], might return:
            {"tv": 0.123, "kl": 0.456}
        """
        # Evaluate each requested metric on the (true, empirical) distribution pair
        results = {}
        for metric in self.metrics:
            results[metric] = self._supported_metrics[metric](
                metrics_state.true_distribution, metrics_state.empirical_distribution
            )
        return results

InitArgs

Bases: BaseInitArgs

Arguments for initializing the ApproxDistributionMetricsModule.

Attributes:

Name Type Description
env_params TEnvParams

Environment parameters needed to obtain the true distribution and initialize environment states for the replay buffer.

Source code in gfnx/metrics/approx_distribution.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the ApproxDistributionMetricsModule.

    Attributes:
        env_params: Environment parameters needed to obtain the true distribution
            and initialize environment states for the replay buffer.
    """

    env_params: TEnvParams

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the ApproxDistributionMetricsModule.

Attributes:

Name Type Description
env_params TEnvParams

Environment parameters needed for empirical distribution computation from the collected states in the replay buffer.

Source code in gfnx/metrics/approx_distribution.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the ApproxDistributionMetricsModule.

    Attributes:
        env_params: Environment parameters needed for empirical distribution
            computation from the collected states in the replay buffer.
    """

    env_params: TEnvParams

UpdateArgs

Bases: BaseUpdateArgs

Arguments for updating the ApproxDistributionMetricsModule.

Attributes:

Name Type Description
states TEnvState

Environment states to add to the replay buffer for empirical distribution computation. These represent states visited during training/evaluation episodes.

Source code in gfnx/metrics/approx_distribution.py
@chex.dataclass
class UpdateArgs(BaseUpdateArgs):
    """Arguments for updating the ApproxDistributionMetricsModule.

    Attributes:
        states: Environment states to add to the replay buffer for empirical
            distribution computation. These represent states visited during
            training/evaluation episodes.
    """

    states: TEnvState

__init__(metrics, env, buffer_size=1000)

Initialize the distribution metrics module.

Sets up the metric module with specified metrics, environment, and buffer size. Validates that the environment is enumerable and that all requested metrics are supported.

Parameters:

Name Type Description Default
metrics list[str]

List of metric names to compute. Must be subset of supported metrics. Supported options: ["tv", "kl", "jsd", "marginal_distribution"]

required
env TEnvironment

Environment instance for which to compute metrics. Must be enumerable (i.e., env.is_enumerable must be True)

required
buffer_size int

Maximum number of states to store in the replay buffer for empirical distribution computation. Must be positive integer.

1000

Raises:

Type Description
ValueError

If environment is not enumerable, buffer_size is not positive, metrics is not a list of strings, or contains unsupported metrics
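The validation order described above can be sketched in isolation. This is a minimal stand-in, not the library's code: `validate_metric_args` and `SUPPORTED_METRICS` are hypothetical names, and the environment is reduced to a single `is_enumerable` flag so the snippet runs without gfnx or JAX.

```python
from typing import Any

# Hypothetical stand-in for the module's metric registry; the real
# `_supported_metrics` maps each name to a metric function.
SUPPORTED_METRICS = ("tv", "kl", "jsd", "marginal_distribution")


def validate_metric_args(metrics: Any, is_enumerable: bool, buffer_size: Any) -> None:
    """Raise ValueError for the conditions the constructor documents."""
    if not is_enumerable:
        raise ValueError("Environment is not enumerable")
    if not isinstance(buffer_size, int) or buffer_size <= 0:
        raise ValueError("buffer_size must be a positive integer")
    if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
        raise ValueError("metrics must be a list of strings")
    unsupported = [m for m in metrics if m not in SUPPORTED_METRICS]
    if unsupported:
        raise ValueError(
            f"Unsupported metrics {unsupported}; supported: {list(SUPPORTED_METRICS)}"
        )


# A valid configuration passes silently.
validate_metric_args(["tv", "kl"], is_enumerable=True, buffer_size=1000)
```

Naming the offending entries in the error message is a choice made for this sketch; the constructor itself only lists the supported metrics.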

Source code in gfnx/metrics/approx_distribution.py
def __init__(
    self,
    metrics: list[str],
    env: TEnvironment,
    buffer_size: int = 1000,
):
    """Initialize the distribution metrics module.

    Sets up the metric module with specified metrics, environment, and buffer size.
    Validates that the environment is enumerable and that all requested metrics
    are supported.

    Args:
        metrics: List of metric names to compute. Must be subset of supported metrics.
                Supported options: ["tv", "kl", "jsd", "marginal_distribution"]
        env: Environment instance for which to compute metrics. Must be enumerable
            (i.e., env.is_enumerable must be True)
        buffer_size: Maximum number of states to store in the replay buffer for
                   empirical distribution computation. Must be positive integer.

    Raises:
        ValueError: If environment is not enumerable, buffer_size is not positive,
                   metrics is not a list of strings, or contains unsupported metrics
    """
    if not env.is_enumerable:
        raise ValueError(f"Environment {env.name} is not enumerable")
    if not isinstance(buffer_size, int) or buffer_size <= 0:
        raise ValueError("buffer_size must be a positive integer")
    if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
        raise ValueError("metrics must be a list of strings")
    if not all(m in self._supported_metrics for m in metrics):
        raise ValueError(
            "Unsupported metrics. Supported metrics are: "
            f"{list(self._supported_metrics)}"
        )

    self.metrics = metrics
    self.env = env

    self.buffer_module = fbx.make_item_buffer(
        max_length=buffer_size,
        min_length=1,
        sample_batch_size=1,
        add_batches=True,
    )

get(metrics_state)

Get the computed distribution metrics from the processed state.

Computes and returns the requested distribution metrics by comparing the true distribution with the empirical distribution. Each metric is computed using the corresponding function from the _supported_metrics dictionary.

Parameters:

Name Type Description Default
metrics_state ApproxDistributionMetricsState

Processed metric state containing both true and empirical distributions

required

Returns:

Type Description
dict[str, Any]

Dictionary containing the computed metrics. Keys are the metric names specified during initialization, and values are the computed metric values (typically float scalars).

Example

If initialized with metrics=["tv", "kl"], might return: {"tv": 0.123, "kl": 0.456}
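For intuition, the three scalar metrics can be sketched with their textbook definitions. The functions behind `_supported_metrics` are not shown on this page, so the argument order (true distribution first), the KL direction, and the handling of zero probabilities below are all assumptions. Plain Python is used so the snippet runs without JAX.

```python
import math


def total_variation(p, q):
    """Half the L1 distance between two discrete distributions."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def jensen_shannon(p, q):
    """Average KL of p and q against the midpoint distribution m = (p + q) / 2."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


true_distribution = [0.5, 0.3, 0.2]
empirical_distribution = [0.4, 0.4, 0.2]

# Same dispatch pattern as get(): one function per requested metric name.
supported = {"tv": total_variation, "kl": kl_divergence, "jsd": jensen_shannon}
results = {
    name: fn(true_distribution, empirical_distribution)
    for name, fn in supported.items()
}
```

All three values are zero when the two distributions coincide, so smaller is better for each.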

Source code in gfnx/metrics/approx_distribution.py
def get(self, metrics_state: ApproxDistributionMetricsState) -> dict[str, Any]:
    """Get the computed distribution metrics from the processed state.

    Computes and returns the requested distribution metrics by comparing the
    true distribution with the empirical distribution. Each metric is computed
    using the corresponding function from the _supported_metrics dictionary.

    Args:
        metrics_state: Processed metric state containing both true and empirical distributions

    Returns:
        Dict[str, Any]: Dictionary containing the computed metrics. Keys are the metric
                      names specified during initialization, and values are the computed
                      metric values (typically float scalars).

    Example:
        If initialized with metrics=["tv", "kl"], might return:
        {"tv": 0.123, "kl": 0.456}
    """
    # Evaluate each requested metric on the (true, empirical) distribution pair
    results = {}
    for metric in self.metrics:
        results[metric] = self._supported_metrics[metric](
            metrics_state.true_distribution, metrics_state.empirical_distribution
        )
    return results

init(rng_key, args)

Initialize the metric state for distribution metrics.

Creates and initializes all components needed for distribution metric computation: the true distribution from the environment, an initial uniform empirical distribution, and a replay buffer for storing environment states.

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key for any random initialization (currently unused)

required
args InitArgs

InitArgs object containing environment parameters

required

Returns:

Name Type Description
ApproxDistributionMetricsState ApproxDistributionMetricsState

Initialized state containing:

- true_distribution: Ground truth distribution from the environment
- empirical_distribution: Initial uniform distribution
- replay_buffer: Empty buffer ready to collect states

Source code in gfnx/metrics/approx_distribution.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ApproxDistributionMetricsState:
    """Initialize the metric state for distribution metrics.

    Creates and initializes all components needed for distribution metric computation:
    the true distribution from the environment, an initial uniform empirical distribution,
    and a replay buffer for storing environment states.

    Args:
        rng_key: JAX PRNG key for any random initialization (currently unused)
        args: InitArgs object containing environment parameters

    Returns:
        ApproxDistributionMetricsState: Initialized state containing:
            - true_distribution: Ground truth distribution from the environment
            - empirical_distribution: Initial uniform distribution
            - replay_buffer: Empty buffer ready to collect states
    """
    _, fake_state = self.env.reset(1, args.env_params)
    fake_single_state = jax.tree.map(lambda x: x[0], fake_state)
    # Replay buffer takes only shapes, but not values
    replay_buffer_state = self.buffer_module.init(fake_single_state)
    true_distribution = self.env.get_true_distribution(args.env_params)
    return ApproxDistributionMetricsState(
        true_distribution=true_distribution,
        empirical_distribution=jnp.ones_like(true_distribution) / true_distribution.size,
        replay_buffer=replay_buffer_state,
    )

process(metrics_state, rng_key, args)

Process the metric state to compute the empirical distribution.

This method computes the empirical distribution from the states stored in the replay buffer. It calls the environment's get_empirical_distribution method to convert the collected states into a probability distribution that can be compared with the true distribution.

Parameters:

Name Type Description Default
metrics_state ApproxDistributionMetricsState

Current metric state containing the replay buffer with collected states

required
rng_key PRNGKey

JAX PRNG key for any random operations (currently unused)

required
args ProcessArgs

ProcessArgs object containing environment parameters

required

Returns:

Name Type Description
ApproxDistributionMetricsState ApproxDistributionMetricsState

Updated state with the computed empirical distribution. This state is now ready for metric computation via get().
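The environment's `get_empirical_distribution` is not shown on this page. As a rough illustration of what such a conversion might do, the sketch below assumes each buffered terminal state can be mapped to an integer index in an enumerable state space and returns normalised visit counts. The function name and the index representation are hypothetical.

```python
from collections import Counter


def empirical_distribution(state_indices, n_states):
    """Normalised visit counts over an enumerable state space."""
    if not state_indices:
        raise ValueError("cannot build a distribution from an empty buffer")
    counts = Counter(state_indices)
    total = len(state_indices)
    return [counts.get(i, 0) / total for i in range(n_states)]


# Indices of terminal states collected in the buffer (illustrative values).
buffered = [0, 2, 2, 1, 2, 0, 2, 2]
dist = empirical_distribution(buffered, n_states=4)
```

States that never appear in the buffer receive probability zero, which is why the estimate depends on both the buffer size and how many states have been added before `process()` is called.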

Source code in gfnx/metrics/approx_distribution.py
def process(
    self,
    metrics_state: ApproxDistributionMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs,
) -> ApproxDistributionMetricsState:
    """Process the metric state to compute the empirical distribution.

    This method computes the empirical distribution from the states stored in the
    replay buffer. It calls the environment's `get_empirical_distribution` method
    to convert the collected states into a probability distribution that can be
    compared with the true distribution.

    Args:
        metrics_state: Current metric state containing the replay buffer with collected states
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: ProcessArgs object containing environment parameters

    Returns:
        ApproxDistributionMetricsState: Updated state with the computed empirical distribution.
                                     This state is now ready for metric computation via get().
    """
    # Build the empirical distribution from the states stored in the replay
    # buffer; `x[0]` drops the leading axis of the stored experience.
    empirical_distribution = self.env.get_empirical_distribution(
        jax.tree.map(lambda x: x[0], metrics_state.replay_buffer.experience), args.env_params
    )
    return metrics_state.replace(empirical_distribution=empirical_distribution)

update(metrics_state, rng_key, args)

Update the metric state with new environment states.

Adds new environment states to the replay buffer for empirical distribution computation.

Parameters:

Name Type Description Default
metrics_state ApproxDistributionMetricsState

Current metric state containing the replay buffer

required
rng_key PRNGKey

JAX PRNG key for any random operations (currently unused)

required
args UpdateArgs

UpdateArgs object containing environment states to add

required

Returns:

Name Type Description
ApproxDistributionMetricsState ApproxDistributionMetricsState

Updated state with new states added to the buffer. The true_distribution and empirical_distribution remain unchanged until process() is called.
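The division of labour between `init`, `update`, `process`, and `get` can be illustrated with a toy module in plain Python. Everything here is a stand-in: `ToyState` and `ToyDistributionMetrics` are hypothetical, states are integer indices, the buffer is a tuple with assumed first-in-first-out eviction, and only total variation is reported.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    true_distribution: tuple
    empirical_distribution: tuple
    replay_buffer: tuple = ()


class ToyDistributionMetrics:
    def __init__(self, true_distribution, buffer_size=1000):
        self.true_distribution = tuple(true_distribution)
        self.buffer_size = buffer_size

    def init(self):
        # Start from a uniform empirical estimate and an empty buffer.
        n = len(self.true_distribution)
        return ToyState(self.true_distribution, tuple(1.0 / n for _ in range(n)))

    def update(self, state, new_indices):
        # Only the buffer changes; keep the most recent `buffer_size` entries.
        buffer = (state.replay_buffer + tuple(new_indices))[-self.buffer_size:]
        return replace(state, replay_buffer=buffer)

    def process(self, state):
        # Recompute the empirical distribution from the buffered states.
        total = len(state.replay_buffer)
        if total == 0:
            return state
        n = len(self.true_distribution)
        emp = tuple(state.replay_buffer.count(i) / total for i in range(n))
        return replace(state, empirical_distribution=emp)

    def get(self, state):
        pairs = zip(state.true_distribution, state.empirical_distribution)
        return {"tv": 0.5 * sum(abs(p - q) for p, q in pairs)}


module = ToyDistributionMetrics([0.5, 0.25, 0.25], buffer_size=4)
state = module.init()
state = module.update(state, [0, 0, 1, 2])
before = module.get(state)  # still based on the uniform initial estimate
state = module.process(state)
after = module.get(state)   # now based on the buffered states
```

Calling `get` before `process` reports metrics against the stale estimate, which mirrors the note above that the distributions remain unchanged until `process()` is called.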

Source code in gfnx/metrics/approx_distribution.py
def update(
    self,
    metrics_state: ApproxDistributionMetricsState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs,
) -> ApproxDistributionMetricsState:
    """Update the metric state with new environment states.

    Adds new environment states to the replay buffer for empirical distribution
    computation.

    Args:
        metrics_state: Current metric state containing the replay buffer
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: UpdateArgs object containing environment states to add

    Returns:
        ApproxDistributionMetricsState: Updated state with new states added to the buffer.
                                     The true_distribution and empirical_distribution
                                     remain unchanged until process() is called.
    """
    updated_buffer = self.buffer_module.add(metrics_state.replay_buffer, args.states)
    return metrics_state.replace(replay_buffer=updated_buffer)

ApproxDistributionMetricsState

Bases: MetricsState

State for approximate distribution-based metrics.

This state container holds the data necessary for computing distribution-based metrics such as KL divergence and total variation distance. It stores both the true distribution from the environment and the empirical distribution computed from collected samples, along with a replay buffer for state storage.

Attributes:

Name Type Description
true_distribution Array

The true distribution of the environment (ground truth)

empirical_distribution Array

The empirical distribution computed from collected samples

replay_buffer ArrayTree

Buffer storing environment states for distribution computation

Source code in gfnx/metrics/approx_distribution.py
@chex.dataclass
class ApproxDistributionMetricsState(MetricsState):
    """State for approximate distribution-based metrics.

    This state container holds the data necessary for computing distribution-based
    metrics such as KL divergence and total variation distance. It stores both the
    true distribution from the environment and the empirical distribution computed
    from collected samples, along with a replay buffer for state storage.

    Attributes:
        true_distribution: The true distribution of the environment (ground truth)
        empirical_distribution: The empirical distribution computed from collected samples
        replay_buffer: Buffer storing environment states for distribution computation
    """

    true_distribution: chex.Array
    empirical_distribution: chex.Array
    replay_buffer: chex.ArrayTree

BaseCorrelationMetricsModule

Bases: BaseMetricsModule

Abstract base class for correlation-based GFlowNet evaluation metrics.

This class provides common functionality for computing correlation metrics between transformed model-predicted marginal probabilities and transformed log true rewards. It implements the core logic for backward trajectory sampling and log-ratio computation that is shared across different correlation metric variants.

The correlation metrics evaluate how well the learned GFlowNet policy captures the true reward distribution by computing correlations between:

- Transformed log-ratios of backward trajectories (model predictions)
- Transformed log-rewards of terminal states (ground truth)

Attributes:

Name Type Description
env

Environment instance for trajectory generation and evaluation

bwd_policy_fn

Backward policy function for computing trajectory ratios

n_rounds

Number of sampling rounds for statistical stability

batch_size

Batch size for processing terminal states during evaluation

transform_fn

Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation
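A `transform_fn` takes the terminal states and a vector of log-values and returns a vector of log-values, possibly over a smaller domain. When none is supplied, the module falls back to the identity on the values (`lambda x, y: y`). The sketch below is one hypothetical way to marginalise: it sums probability mass within groups of states, working in log space. `make_group_transform` and `group_of` are illustrative names that do not exist in gfnx, and states are plain integers here rather than environment state objects.

```python
import math


def make_group_transform(group_of, n_groups):
    """Build a transform that sums probability mass within each group.

    `group_of` maps a terminal state to an integer group id (hypothetical helper).
    """

    def transform(terminal_states, log_values):
        buckets = [[] for _ in range(n_groups)]
        for state, log_value in zip(terminal_states, log_values):
            buckets[group_of(state)].append(log_value)
        out = []
        for bucket in buckets:
            if not bucket:
                out.append(-math.inf)  # no mass assigned to this group
                continue
            m = max(bucket)
            out.append(m + math.log(sum(math.exp(v - m) for v in bucket)))
        return out

    return transform


def identity(terminal_states, log_values):
    """Default behaviour: leave the values untouched."""
    return log_values


by_parity = make_group_transform(lambda s: s % 2, n_groups=2)
states = [0, 1, 2, 3]
log_values = [math.log(0.1), math.log(0.2), math.log(0.3), math.log(0.4)]
grouped = by_parity(states, log_values)
```

Because the class source asserts that the transformed log-ratios and transformed log-rewards have equal shapes, whatever transform is used would need to be applied consistently to both quantities.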

Source code in gfnx/metrics/correlation.py
class BaseCorrelationMetricsModule(BaseMetricsModule):
    """Abstract base class for correlation-based GFlowNet evaluation metrics.

    This class provides common functionality for computing correlation metrics between
    transformed model-predicted marginal probabilities and transformed log true rewards.
    It implements the core logic for backward trajectory sampling and log-ratio computation
    that is shared across different correlation metric variants.

    The correlation metrics evaluate how well the learned GFlowNet policy captures
    the true reward distribution by computing correlations between:
    - Transformed log-ratios of backward trajectories (model predictions)
    - Transformed log-rewards of terminal states (ground truth)

    Attributes:
        env: Environment instance for trajectory generation and evaluation
        bwd_policy_fn: Backward policy function for computing trajectory ratios
        n_rounds: Number of sampling rounds for statistical stability
        batch_size: Batch size for processing terminal states during evaluation
        transform_fn: Function to marginalize distributions from terminal states
            to an arbitrary domain size for correlation computation
    """

    def __init__(
        self,
        env: TEnvironment,
        bwd_policy_fn: TPolicyFn,
        n_rounds: int,
        batch_size: int = 1,
        transform_fn: Callable[[TEnvState, jnp.ndarray], jnp.ndarray] | None = None,
    ):
        """Initialize the correlation metric module.

        Args:
            env: Environment for trajectory generation and reward computation
            bwd_policy_fn: Backward policy function for trajectory probability estimation
            n_rounds: Number of sampling rounds for averaging log-ratios
            batch_size: Batch size for processing terminal states
            transform_fn: Function to marginalize distributions from terminal states
                to an arbitrary domain size for correlation computation
        """
        self.env = env
        self.bwd_policy_fn = bwd_policy_fn
        self.n_rounds = n_rounds
        self.batch_size = batch_size
        self.transform_fn = transform_fn if transform_fn is not None else lambda x, y: y

    # Ensure the module has a consistent interface
    UpdateArgs = EmptyUpdateArgs

    def update(
        self,
        metrics_state: CorrelationMetricsState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs | None = None,
    ) -> CorrelationMetricsState:
        """Update metric state with new data (no-op for correlation metrics)."""
        return metrics_state

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the correlation metric module.

        Attributes:
            policy_params: Current policy parameters used for trajectory generation,
                backward trajectory sampling, and log-ratio computation.
            env_params: Environment parameters required for trajectory generation,
                reward computation, and rollout operations.
        """

        policy_params: TPolicyParams
        env_params: TEnvParams

    def process(
        self,
        metrics_state: CorrelationMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
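    ) -> CorrelationMetricsState:
        """Sample test data and compute transformed log-ratios for correlation metrics.

        Obtains terminal states and their transformed log-rewards from
        `_get_states_and_rewards`, estimates log-ratios with backward rollouts,
        applies `transform_fn` to them, and stores all three in the returned state.
        """
        rng_key, test_data_key, rollout_key = jax.random.split(rng_key, 3)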
        # Compute test terminal states and already transformed log-rewards
        test_terminal_states, test_log_rewards_transformed = self._get_states_and_rewards(
            metrics_state, test_data_key, args
        )
        # Compute log-ratios using backward trajectory sampling
        log_ratio_traj = self._compute_log_ratio(
            rng_key=rollout_key,
            terminal_states=test_terminal_states,
            policy_params=args.policy_params,
            env_params=args.env_params,
        )
        log_ratio_traj_transformed = self.transform_fn(
            test_terminal_states,
            log_ratio_traj,
        )
        chex.assert_equal_shape([
            log_ratio_traj_transformed,
            test_log_rewards_transformed,
        ])
        return metrics_state.replace(
            test_terminal_states=test_terminal_states,
            test_log_rewards_transformed=test_log_rewards_transformed,
            log_ratio_traj_transformed=log_ratio_traj_transformed,
        )

    def get(self, metrics_state: CorrelationMetricsState) -> dict[str, Any]:
        """Compute and return correlation metrics from the current state.

        Calculates two correlation measures between transformed model predictions and
        transformed true rewards:
        1. Pearson correlation in log-probability space
        2. Spearman rank correlation

        Args:
            metrics_state: Current state containing transformed log-ratios
                and transformed log-rewards

        Returns:
            Dict[str, Any]: Dictionary containing computed correlation metrics:
                - 'pearson': Pearson correlation of log-probabilities
                - 'spearman': Spearman rank correlation of log-probabilities
        """
        chex.assert_equal_shape([
            metrics_state.log_ratio_traj_transformed,
            metrics_state.test_log_rewards_transformed,
        ])
        return {
            "pearson": pearson_corr(
                metrics_state.log_ratio_traj_transformed,
                metrics_state.test_log_rewards_transformed,
            ),
            "spearman": spearman_corr(
                metrics_state.log_ratio_traj_transformed,
                metrics_state.test_log_rewards_transformed,
            ),
        }

    @abstractmethod
    def _get_states_and_rewards(
        self,
        metrics_state: CorrelationMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> tuple[TEnvState, jnp.ndarray]:
        """Get terminal states and transformed log rewards for correlation computation.

        This method should be implemented by subclasses to provide the terminal states
        and transformed log rewards for correlation metric computation.

        Args:
            metrics_state: Current metric state (largely ignored for on-policy)
            rng_key: Random number generator key for sampling
            args: ProcessArgs object containing policy parameters and environment parameters

        Returns:
            Tuple[TEnvState, jnp.ndarray]: A tuple containing:
                - Terminal states used for correlation evaluation
                - Transformed log rewards corresponding to the terminal states
        """
        raise NotImplementedError

    def _compute_log_ratio(
        self,
        rng_key: chex.PRNGKey,
        terminal_states: TEnvState,  # Shape: [T x ...]
        policy_params: TPolicyParams,
        env_params: TEnvParams,
    ) -> jnp.ndarray:
        """Compute log-ratios for terminal states using backward rollouts.

        This method performs multiple rounds of backward rollouts from given terminal
        states and computes the log-ratio of backward trajectories. The log-ratios
        are averaged in probability space (rather than log space) using log-sum-exp.

        Args:
            rng_key: Random number generator key for sampling
            terminal_states: Terminal states to start backward rollouts from.
                Shape: [T x ...], where T is the total number of terminal states
            policy_params: Parameters for the backward policy function
            env_params: Environment parameters for rollout execution

        Returns:
            jnp.ndarray: Computed log-ratios with shape (n_terminal_states,)
                representing the average log-ratio for each terminal state
                across all sampling rounds

        Note:
            The method uses jax.lax.scan for efficient computation across batches
            and rounds. Log-ratios are computed using backward_trajectory_log_probs
            and averaged using logsumexp for numerical stability.
        """
        n_terminal_states = terminal_states.is_pad.shape[0]
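        # Split terminal states into chunks of `batch_size` to avoid OOM; the reshape
        # requires the number of terminal states to be divisible by `batch_size`.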
        terminal_states = jax.tree.map(
            lambda x: x.reshape(-1, self.batch_size, *x.shape[1:]),
            terminal_states,
        )

        def process_batch(rng_key, terminal_states):
            """Process a single batch of terminal states."""
            rng_key, rollout_key = jax.random.split(rng_key)
            bwd_traj_data, _ = backward_rollout(
                rng_key=rollout_key,
                init_state=terminal_states,
                policy_fn=self.bwd_policy_fn,
                policy_params=policy_params,
                env=self.env,
                env_params=env_params,
            )
            log_pf_traj, log_pb_traj = backward_trajectory_log_probs(
                self.env, bwd_traj_data, env_params
            )
            log_ratio_traj = log_pf_traj - log_pb_traj
            return rng_key, log_ratio_traj

        def process_round(carry: tuple[chex.PRNGKey, TEnvState], xs: None):
            """Process a single round of sampling across all batches."""
            rng_key, terminal_states = carry
            rng_key, log_ratio_traj = jax.lax.scan(process_batch, rng_key, terminal_states)
            chex.assert_shape(log_ratio_traj, terminal_states.is_pad.shape[:2])
            return (rng_key, terminal_states), log_ratio_traj.reshape(-1)

        _, log_ratio_traj = jax.lax.scan(
            process_round,
            (rng_key, terminal_states),
            xs=None,
            length=self.n_rounds,
        )
        chex.assert_shape(log_ratio_traj, (self.n_rounds, n_terminal_states))

        # Average ratios over rounds for each test datum using log-sum-exp
        log_ratio_traj = jax.nn.logsumexp(log_ratio_traj, axis=0)
        return log_ratio_traj - jnp.log(self.n_rounds)
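The final two lines of `_compute_log_ratio` compute a log-mean-exp over the rounds axis: `logsumexp` over rounds minus `log(n_rounds)`. A minimal plain-Python sketch of the same arithmetic, using a hypothetical `log_mean_exp` helper and made-up values, shows why the maximum is factored out first.

```python
import math


def log_mean_exp(log_values):
    """log(mean(exp(v))) computed stably by factoring out the maximum."""
    m = max(log_values)
    if m == -math.inf:
        return -math.inf
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))


# Rows are sampling rounds, columns are terminal states (illustrative values).
log_ratios = [
    [math.log(0.2), -1000.0],
    [math.log(0.4), -1001.0],
]
per_state = [log_mean_exp(column) for column in zip(*log_ratios)]
```

The first column averages 0.2 and 0.4 in probability space, giving log(0.3). The second column stays finite even though `math.exp(-1000.0)` underflows to zero, which is the case a naive `log(mean(exp(...)))` cannot handle.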

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the correlation metric module.

Attributes:

Name Type Description
policy_params TPolicyParams

Current policy parameters used for trajectory generation, backward trajectory sampling, and log-ratio computation.

env_params TEnvParams

Environment parameters required for trajectory generation, reward computation, and rollout operations.

Source code in gfnx/metrics/correlation.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the correlation metric module.

    Attributes:
        policy_params: Current policy parameters used for trajectory generation,
            backward trajectory sampling, and log-ratio computation.
        env_params: Environment parameters required for trajectory generation,
            reward computation, and rollout operations.
    """

    policy_params: TPolicyParams
    env_params: TEnvParams

__init__(env, bwd_policy_fn, n_rounds, batch_size=1, transform_fn=None)

Initialize the correlation metric module.

Parameters:

Name Type Description Default
env TEnvironment

Environment for trajectory generation and reward computation

required
bwd_policy_fn TPolicyFn

Backward policy function for trajectory probability estimation

required
n_rounds int

Number of sampling rounds for averaging log-ratios

required
batch_size int

Batch size for processing terminal states

1
transform_fn Callable[[TEnvState, ndarray], ndarray] | None

Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation

None
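In the class source, terminal states are split into batches with `x.reshape(-1, self.batch_size, *x.shape[1:])`, and a reshape of this form only succeeds when the number of terminal states is a multiple of `batch_size`. The following plain-Python sketch, with a hypothetical `split_into_batches` helper operating on a list, makes that constraint explicit.

```python
def split_into_batches(items, batch_size):
    """Group a flat list into consecutive batches of equal size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if len(items) % batch_size != 0:
        raise ValueError(
            f"{len(items)} items cannot be split evenly into batches of {batch_size}"
        )
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


batches = split_into_batches(list(range(6)), batch_size=2)

# Flattening the batches restores the original order, as the final
# `reshape(-1)` does for the per-round log-ratios.
flat = [x for batch in batches for x in batch]
```

With the default `batch_size=1` every terminal state is its own batch, so the constraint is always satisfied; larger batches trade memory for fewer scan steps.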
Source code in gfnx/metrics/correlation.py
def __init__(
    self,
    env: TEnvironment,
    bwd_policy_fn: TPolicyFn,
    n_rounds: int,
    batch_size: int = 1,
    transform_fn: Callable[[TEnvState, jnp.ndarray], jnp.ndarray] | None = None,
):
    """Initialize the correlation metric module.

    Args:
        env: Environment for trajectory generation and reward computation
        bwd_policy_fn: Backward policy function for trajectory probability estimation
        n_rounds: Number of sampling rounds for averaging log-ratios
        batch_size: Batch size for processing terminal states
        transform_fn: Function to marginalize distributions from terminal states
            to an arbitrary domain size for correlation computation
    """
    self.env = env
    self.bwd_policy_fn = bwd_policy_fn
    self.n_rounds = n_rounds
    self.batch_size = batch_size
    self.transform_fn = transform_fn if transform_fn is not None else lambda x, y: y

get(metrics_state)

Compute and return correlation metrics from the current state.

Calculates two correlation measures between transformed model predictions and transformed true rewards:

1. Pearson correlation in log-probability space
2. Spearman rank correlation

Parameters:

Name Type Description Default
metrics_state CorrelationMetricsState

Current state containing transformed log-ratios and transformed log-rewards

required

Returns:

Type Description
dict[str, Any]

Dictionary containing computed correlation metrics:

- 'pearson': Pearson correlation of log-probabilities
- 'spearman': Spearman rank correlation of log-probabilities
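The two measures differ in what they reward: Pearson responds to linear agreement, while Spearman only looks at ordering. The sketch below uses textbook definitions in plain Python. The library's `pearson_corr` and `spearman_corr` are not shown on this page, so the tie handling here (average ranks) is an assumption, and the input values are made up.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = average_rank
        i = j + 1
    return result


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


log_ratios = [-3.0, -2.0, -1.0, 0.0]
log_rewards = [-9.0, -4.0, -1.0, 0.0]  # same ordering, non-linear relationship

metrics = {
    "pearson": pearson(log_ratios, log_rewards),
    "spearman": spearman(log_ratios, log_rewards),
}
```

Here the ordering matches exactly, so the Spearman value is 1, while the Pearson value is slightly lower because the relationship is not linear.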

Source code in gfnx/metrics/correlation.py
def get(self, metrics_state: CorrelationMetricsState) -> dict[str, Any]:
    """Compute and return correlation metrics from the current state.

    Calculates two correlation measures between transformed model predictions and
    transformed true rewards:
    1. Pearson correlation in log-probability space
    2. Spearman rank correlation

    Args:
        metrics_state: Current state containing transformed log-ratios
            and transformed log-rewards

    Returns:
        Dict[str, Any]: Dictionary containing computed correlation metrics:
            - 'pearson': Pearson correlation of log-probabilities
            - 'spearman': Spearman rank correlation of log-probabilities
    """
    chex.assert_equal_shape([
        metrics_state.log_ratio_traj_transformed,
        metrics_state.test_log_rewards_transformed,
    ])
    return {
        "pearson": pearson_corr(
            metrics_state.log_ratio_traj_transformed,
            metrics_state.test_log_rewards_transformed,
        ),
        "spearman": spearman_corr(
            metrics_state.log_ratio_traj_transformed,
            metrics_state.test_log_rewards_transformed,
        ),
    }
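
The implementations of `pearson_corr` and `spearman_corr` are not shown on this page. The following standard-library sketch only illustrates the two quantities being reported; `pearson`, `ranks`, and `spearman` are illustrative names, and the tie handling (average ranks) is an assumption rather than a statement about what gfnx does.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def ranks(values):
    """Average ranks (1-based); tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg
        i = j + 1
    return result


def spearman(xs, ys):
    """Spearman correlation: Pearson correlation of the ranks."""
    return pearson(ranks(xs), ranks(ys))


log_ratio = [-3.0, -2.0, -1.0, -0.5]
log_reward = [-9.0, -4.0, -1.0, -0.25]  # monotone but nonlinear in log_ratio
r_pearson = pearson(log_ratio, log_reward)
r_spearman = spearman(log_ratio, log_reward)
```

Because the relationship is monotone but not linear, the Spearman value reaches 1 while the Pearson value stays slightly below it, which is why reporting both can be informative.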

update(metrics_state, rng_key, args=None)

Update metric state with new data (no-op for correlation metrics).

Source code in gfnx/metrics/correlation.py
def update(
    self,
    metrics_state: CorrelationMetricsState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs | None = None,
) -> CorrelationMetricsState:
    """Update metric state with new data (no-op for correlation metrics)."""
    return metrics_state

BaseInitArgs

Bases: ABC

Base class for argument containers passed as args to the init method of metric modules.

Source code in gfnx/metrics/base.py
class BaseInitArgs(ABC):
    """
    Base class for argument containers passed as `args` to the `init` method
    of metric modules.
    """

BaseMetricsModule

Bases: ABC, Generic[TInitArgs, TUpdateArgs, TProcessArgs, TMetricsState]

Environment-agnostic base metric module.

This abstract base class defines the interface for all metric modules in the system. Metric modules are responsible for tracking, computing, and reporting metrics during training and evaluation phases. They maintain their own state and can be composed together using the MultiMetricsModule.

The lifecycle of a metric module follows this pattern:

1. init() - Initialize the metric state with required parameters
2. update() - Update state with new data points during training/evaluation
3. process() - Apply any final transformations before metric computation
4. get() - Retrieve the computed metrics as a dictionary
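
The four lifecycle steps can be sketched with a toy metric in plain Python. This is not a gfnx module: `MeanState` and `MeanMetric` are hypothetical, the `rng_key` and `args` parameters of the real interface are omitted for brevity, and a frozen standard-library dataclass stands in for the state container.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MeanState:
    total: float
    count: int
    mean: float


class MeanMetric:
    """Toy metric following init -> update -> process -> get."""

    def init(self):
        return MeanState(total=0.0, count=0, mean=float("nan"))

    def update(self, state, value):
        # Accumulate raw data only; no metric is computed yet.
        return replace(state, total=state.total + value, count=state.count + 1)

    def process(self, state):
        # Final transformation applied before the metric is retrieved.
        return replace(state, mean=state.total / state.count)

    def get(self, state):
        return {"mean": state.mean}


metric = MeanMetric()
state = metric.init()
for reward in (1.0, 2.0, 6.0):
    state = metric.update(state, reward)
state = metric.process(state)
result = metric.get(state)
```

Every step returns a new state instead of mutating the old one, which matches the functional style of the modules documented here.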

In addition to implementing the abstract methods, each subclass must define the following inner classes to specify argument types:

- InitArgs, inheriting from BaseInitArgs
- UpdateArgs, inheriting from BaseUpdateArgs
- ProcessArgs, inheriting from BaseProcessArgs

Subclasses that do not provide these classes or methods will raise errors via `__init_subclass__` validation.
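
A reduced reproduction of this validation pattern, using only the standard library, may make the behaviour easier to see. `BaseArgs`, `BaseModule`, and the three subclasses are stand-ins invented for this sketch, and only the `InitArgs` check is reproduced.

```python
import inspect
from abc import ABC, abstractmethod


class BaseArgs(ABC):
    """Stand-in for BaseInitArgs in this sketch."""


class BaseModule(ABC):
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if inspect.isabstract(cls):
            return  # abstract intermediates are not validated
        if not hasattr(cls, "InitArgs"):
            raise TypeError(f"{cls.__name__} must define an InitArgs class")
        if not issubclass(cls.InitArgs, BaseArgs):
            raise TypeError(f"{cls.__name__}.InitArgs must inherit from BaseArgs")

    @abstractmethod
    def get(self, state): ...


class GoodModule(BaseModule):
    class InitArgs(BaseArgs):
        pass

    def get(self, state):
        return {"value": state}


try:

    class BadModule(BaseModule):  # concrete, but InitArgs is missing
        def get(self, state):
            return {}

except TypeError as exc:
    error_message = str(exc)
else:
    error_message = None


class StillAbstract(BaseModule):  # get() not implemented, so the check is skipped
    pass
```

The error is raised when the class statement runs, not when the class is instantiated, so a misconfigured module fails at import time.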

Source code in gfnx/metrics/base.py
class BaseMetricsModule(ABC, Generic[TInitArgs, TUpdateArgs, TProcessArgs, TMetricsState]):
    """Environment-agnostic base metric module.

    This abstract base class defines the interface for all metric modules in the system.
    Metric modules are responsible for tracking, computing, and reporting metrics during
    training and evaluation phases. They maintain their own state and can be composed
    together using the MultiMetricsModule.

    The lifecycle of a metric module follows this pattern:
    1. init() - Initialize the metric state with required parameters
    2. update() - Update state with new data points during training/evaluation
    3. process() - Apply any final transformations before metric computation
    4. get() - Retrieve the computed metrics as a dictionary

    In addition to implementing the abstract methods, each subclass must define the
    following inner classes to specify argument types:
      - InitArgs, inheriting from BaseInitArgs
      - UpdateArgs, inheriting from BaseUpdateArgs
      - ProcessArgs, inheriting from BaseProcessArgs

    Subclasses that do not provide these classes or methods will raise errors via
    __init_subclass__ validation.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # skip verification if the class is still abstract
        if inspect.isabstract(cls):
            return

        # 1. Must define InitArgs, UpdateArgs, ProcessArgs
        if not hasattr(cls, "InitArgs"):
            raise TypeError(f"{cls.__name__} must define an InitArgs class")
        if not hasattr(cls, "UpdateArgs"):
            raise TypeError(f"{cls.__name__} must define an UpdateArgs class")
        if not hasattr(cls, "ProcessArgs"):
            raise TypeError(f"{cls.__name__} must define a ProcessArgs class")

        # 2. Each must inherit from its corresponding base class
        if not issubclass(cls.InitArgs, BaseInitArgs):
            raise TypeError(f"{cls.__name__}.InitArgs must inherit from BaseInitArgs")
        if not issubclass(cls.UpdateArgs, BaseUpdateArgs):
            raise TypeError(f"{cls.__name__}.UpdateArgs must inherit from BaseUpdateArgs")
        if not issubclass(cls.ProcessArgs, BaseProcessArgs):
            raise TypeError(f"{cls.__name__}.ProcessArgs must inherit from BaseProcessArgs")

    @abstractmethod
    def init(self, rng_key: chex.PRNGKey, args: TInitArgs | None = None) -> TMetricsState:
        """Initialize metric state.

        Creates and returns the initial state required for metric computation.
        This method is called once at the beginning of training or evaluation
        to set up any necessary data structures, counters, or buffers.

        Args:
            rng_key: JAX PRNG key for any random initialization required
            args: Optional InitArgs object containing metric-specific init parameters

        Returns:
            MetricsState: Initialized state object for this metric

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        raise NotImplementedError

    @abstractmethod
    def update(
        self, metrics_state: TMetricsState, rng_key: chex.PRNGKey, args: TUpdateArgs | None = None
    ) -> TMetricsState:
        """Update metric state with new data.

        Updates the metric state with new data points collected during training
        or evaluation. This method is called repeatedly as new data becomes available
        and should accumulate or process the information needed for final metric
        computation.

        Args:
            metrics_state: Current state of the metric
            rng_key: JAX PRNG key for any random operations during update
            args: Optional UpdateArgs object containing metric-specific update data

        Returns:
            MetricsState: Updated state object with new data incorporated

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        raise NotImplementedError

    @abstractmethod
    def process(
        self, metrics_state: TMetricsState, rng_key: chex.PRNGKey, args: TProcessArgs | None = None
    ) -> TMetricsState:
        """Process metric state to compute metrics and perform final transformations.

        This method is called before get() to compute the metrics and to perform
        final transformations during the evaluation period. It prepares
        the accumulated data for final metric computation by applying any necessary
        calculations, normalizations, or statistical operations.

        Args:
            metrics_state: Current state of the metric after all updates
            rng_key: JAX PRNG key for any random operations during processing
            args: Optional ProcessArgs object containing metric-specific processing parameters

        Returns:
            MetricsState: Processed state ready for metric retrieval via get()

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, metrics_state: TMetricsState) -> dict[str, Any]:
        """Get computed metrics from the current state.

        Computes and returns the final metrics based on the current state.
        This method should extract meaningful metrics from the accumulated
        data and return them in a standardized dictionary format.

        Args:
            metrics_state: Current processed state of the metric

        Returns:
            Dict[str, Any]: Dictionary containing computed metrics with
                          descriptive keys and their corresponding values

        Raises:
            NotImplementedError: Must be implemented by subclasses
        """
        raise NotImplementedError

get(metrics_state) abstractmethod

Get computed metrics from the current state.

Computes and returns the final metrics based on the current state. This method should extract meaningful metrics from the accumulated data and return them in a standardized dictionary format.

Parameters:

Name Type Description Default
metrics_state TMetricsState

Current processed state of the metric

required

Returns:

Type Description
dict[str, Any]

Dict[str, Any]: Dictionary containing computed metrics with descriptive keys and their corresponding values

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses

Source code in gfnx/metrics/base.py
@abstractmethod
def get(self, metrics_state: TMetricsState) -> dict[str, Any]:
    """Get computed metrics from the current state.

    Computes and returns the final metrics based on the current state.
    This method should extract meaningful metrics from the accumulated
    data and return them in a standardized dictionary format.

    Args:
        metrics_state: Current processed state of the metric

    Returns:
        Dict[str, Any]: Dictionary containing computed metrics with
                      descriptive keys and their corresponding values

    Raises:
        NotImplementedError: Must be implemented by subclasses
    """
    raise NotImplementedError

init(rng_key, args=None) abstractmethod

Initialize metric state.

Creates and returns the initial state required for metric computation. This method is called once at the beginning of training or evaluation to set up any necessary data structures, counters, or buffers.

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key for any random initialization required

required
args TInitArgs | None

Optional InitArgs object containing metric-specific init parameters

None

Returns:

Name Type Description
MetricsState TMetricsState

Initialized state object for this metric

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses

Source code in gfnx/metrics/base.py
@abstractmethod
def init(self, rng_key: chex.PRNGKey, args: TInitArgs | None = None) -> TMetricsState:
    """Initialize metric state.

    Creates and returns the initial state required for metric computation.
    This method is called once at the beginning of training or evaluation
    to set up any necessary data structures, counters, or buffers.

    Args:
        rng_key: JAX PRNG key for any random initialization required
        args: Optional InitArgs object containing metric-specific init parameters

    Returns:
        MetricsState: Initialized state object for this metric

    Raises:
        NotImplementedError: Must be implemented by subclasses
    """
    raise NotImplementedError

process(metrics_state, rng_key, args=None) abstractmethod

Process metric state to compute metrics and perform final transformations.

This method is called before get() to compute the metrics and to perform final transformations during the evaluation period. It prepares the accumulated data for final metric computation by applying any necessary calculations, normalizations, or statistical operations.

Parameters:

Name Type Description Default
metrics_state TMetricsState

Current state of the metric after all updates

required
rng_key PRNGKey

JAX PRNG key for any random operations during processing

required
args TProcessArgs | None

Optional ProcessArgs object containing metric-specific processing parameters

None

Returns:

Name Type Description
MetricsState TMetricsState

Processed state ready for metric retrieval via get()

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses

Source code in gfnx/metrics/base.py
@abstractmethod
def process(
    self, metrics_state: TMetricsState, rng_key: chex.PRNGKey, args: TProcessArgs | None = None
) -> TMetricsState:
    """Process metric state to compute metrics and perform final transformations.

    This method is called before get() to compute the metrics and to perform
    final transformations during the evaluation period. It prepares
    the accumulated data for final metric computation by applying any necessary
    calculations, normalizations, or statistical operations.

    Args:
        metrics_state: Current state of the metric after all updates
        rng_key: JAX PRNG key for any random operations during processing
        args: Optional ProcessArgs object containing metric-specific processing parameters

    Returns:
        MetricsState: Processed state ready for metric retrieval via get()

    Raises:
        NotImplementedError: Must be implemented by subclasses
    """
    raise NotImplementedError

update(metrics_state, rng_key, args=None) abstractmethod

Update metric state with new data.

Updates the metric state with new data points collected during training or evaluation. This method is called repeatedly as new data becomes available and should accumulate or process the information needed for final metric computation.

Parameters:

Name Type Description Default
metrics_state TMetricsState

Current state of the metric

required
rng_key PRNGKey

JAX PRNG key for any random operations during update

required
args TUpdateArgs | None

Optional UpdateArgs object containing metric-specific update data

None

Returns:

Name Type Description
MetricsState TMetricsState

Updated state object with new data incorporated

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses

Source code in gfnx/metrics/base.py
@abstractmethod
def update(
    self, metrics_state: TMetricsState, rng_key: chex.PRNGKey, args: TUpdateArgs | None = None
) -> TMetricsState:
    """Update metric state with new data.

    Updates the metric state with new data points collected during training
    or evaluation. This method is called repeatedly as new data becomes available
    and should accumulate or process the information needed for final metric
    computation.

    Args:
        metrics_state: Current state of the metric
        rng_key: JAX PRNG key for any random operations during update
        args: Optional UpdateArgs object containing metric-specific update data

    Returns:
        MetricsState: Updated state object with new data incorporated

    Raises:
        NotImplementedError: Must be implemented by subclasses
    """
    raise NotImplementedError

BaseProcessArgs

Bases: ABC

Base class for argument containers passed as args to the process method of metric modules.

Source code in gfnx/metrics/base.py
class BaseProcessArgs(ABC):
    """
    Base class for argument containers passed as `args` to the `process` method
    of metric modules.
    """

BaseUpdateArgs

Bases: ABC

Base class for argument containers passed as args to the update method of metric modules.

Source code in gfnx/metrics/base.py
class BaseUpdateArgs(ABC):
    """
    Base class for argument containers passed as `args` to the `update` method
    of metric modules.
    """

CorrelationMetricsState

Bases: MetricsState

State container for correlation-based metric computation.

This state class stores the data required for computing correlation metrics between transformed distributions of model predictions and true rewards. It maintains terminal states, their corresponding transformed log-rewards, and computed transformed log-ratios of backward trajectories.

Attributes:

Name Type Description
test_terminal_states TEnvState

Terminal states used for correlation evaluation. These can be either sampled on-policy or from a fixed test set.

test_log_rewards_transformed ndarray

Transformed log-rewards corresponding to the terminal states. Shape: (domain_size,) - obtained by marginalizing the distribution over terminal states to a domain of arbitrary size.

log_ratio_traj_transformed ndarray

Transformed log-ratios of backward trajectories from terminal states. These ratios represent the model's estimated probability of reaching each terminal state, marginalized to domain_size. Shape: (domain_size,)

Source code in gfnx/metrics/correlation.py
@chex.dataclass
class CorrelationMetricsState(MetricsState):
    """State container for correlation-based metric computation.

    This state class stores the data required for computing correlation metrics
    between transformed distributions of model predictions and true rewards.
    It maintains terminal states, their corresponding transformed log-rewards,
    and computed transformed log-ratios of backward trajectories.

    Attributes:
        test_terminal_states: Terminal states used for correlation evaluation.
            These can be either sampled on-policy or from a fixed test set.
        test_log_rewards_transformed: Transformed log-rewards corresponding to the
            terminal states. Shape: (domain_size,) - obtained by marginalizing
            the distribution over terminal states to a domain of arbitrary size.
        log_ratio_traj_transformed: Transformed log-ratios of backward trajectories from
            terminal states. These ratios represent the model's estimated probability
            of reaching each terminal state, marginalized to domain_size.
            Shape: (domain_size,)
    """

    test_terminal_states: TEnvState
    test_log_rewards_transformed: jnp.ndarray
    log_ratio_traj_transformed: jnp.ndarray

ELBOMetricState

Bases: MetricsState

State container for the Evidence Lower Bound (ELBO) metric.

This state container stores the computed ELBO metric.

Attributes:

Name Type Description
elbo ndarray

Computed ELBO value.

Source code in gfnx/metrics/elbo.py
@chex.dataclass
class ELBOMetricState(MetricsState):
    """State container for the Evidence Lower Bound (ELBO) metric.

    This state container stores the computed ELBO metric.

    Attributes:
        elbo: Computed ELBO value.
    """

    elbo: jnp.ndarray

ELBOMetricsModule

Bases: BaseMetricsModule

Computes the Evidence Lower Bound (ELBO) for a GFlowNet model.

This metric evaluates the GFlowNet model by estimating the ELBO. The ELBO is computed by sampling trajectories from the forward policy and evaluating the log-ratios of the forward and backward probabilities plus the log reward.

The ELBO is defined as:

ELBO = {
    if logZ is tractable:
        E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)] - logZ
    else:
        E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
},

where traj is sampled from the trained forward policy.
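
As a small worked example of the tractable case, consider a toy setting with two terminal states, each reachable by a single trajectory, so that log Pb = 0 and the expectation can be evaluated exactly. The rewards, the policies, and the helper `normalised_elbo` are assumptions made for this sketch and are not part of gfnx.

```python
import math

rewards = [1.0, 3.0]            # R(x) for two terminal states
log_z = math.log(sum(rewards))  # tractable normalising constant


def normalised_elbo(pf):
    """Exact E_{x ~ Pf}[log R(x) - log Pf(x)] - log Z, with log Pb = 0."""
    expectation = sum(p * (math.log(r) - math.log(p)) for p, r in zip(pf, rewards))
    return expectation - log_z


perfect = normalised_elbo([0.25, 0.75])  # Pf proportional to R
uniform = normalised_elbo([0.5, 0.5])    # Pf ignores the reward
```

When the forward policy samples in proportion to the reward, the normalised value is zero up to floating-point rounding. Any other policy gives a negative value, since the normalised ELBO equals minus the KL divergence between Pf and R/Z in this setting.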

Attributes:

Name Type Description
env

Environment instance for trajectory generation and evaluation.

env_params

Environment parameters.

fwd_policy_fn

Forward policy function producing action logits.

n_rounds

Number of sampling rounds for statistical stability.

batch_size

Batch size used when evaluating policy over states.

Source code in gfnx/metrics/elbo.py
class ELBOMetricsModule(BaseMetricsModule):
    """Computes the Evidence Lower Bound (ELBO) for a GFlowNet model.

    This metric evaluates the GFlowNet model by estimating the ELBO.
    The ELBO is computed by sampling trajectories from the forward policy and evaluating
    the log-ratios of the forward and backward probabilities plus the log reward.

    The ELBO is defined as:
    ELBO = {
        if logZ is tractable:
            E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)] - logZ
        else:
            E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
    },
    where traj is sampled from the trained forward policy.

    Attributes:
        env: Environment instance for trajectory generation and evaluation.
        env_params: Environment parameters.
        fwd_policy_fn: Forward policy function producing action logits.
        n_rounds: Number of sampling rounds for statistical stability.
        batch_size: Batch size used when evaluating policy over states.
    """

    def __init__(
        self,
        env: TEnvironment,
        env_params: TEnvParams,
        fwd_policy_fn: TPolicyFn,
        n_rounds: int,
        batch_size: int,
    ):
        """Initializes the ELBO metric module.

        Args:
            env: Environment for trajectory generation and reward computation.
            env_params: Environment parameters used for trajectory generation.
            fwd_policy_fn: Forward policy function for generating trajectories.
            n_rounds: The number of sampling rounds to perform for estimation.
            batch_size: The number of environments to run in parallel for sampling.
        """
        self.env = env
        if env.is_normalizing_constant_tractable:
            self.logZ = jnp.log(env.get_normalizing_constant(env_params))
        else:
            self.logZ = jnp.array(0.0)
        self.fwd_policy_fn = fwd_policy_fn
        self.n_rounds = n_rounds
        self.batch_size = batch_size

    # Ensure the module has a consistent interface
    InitArgs = EmptyInitArgs

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ELBOMetricState:
        """Initialize the metric state for ELBO metric."""
        return ELBOMetricState(elbo=jnp.array(-jnp.inf, dtype=jnp.float32))

    UpdateArgs = EmptyUpdateArgs

    def update(
        self,
        metrics_state: ELBOMetricState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs | None = None,
    ) -> ELBOMetricState:
        """
        Update metric state with new data.
        This is a no-op as the metric is computed on demand.
        """
        return metrics_state

    def get(self, metrics_state: ELBOMetricState) -> dict[str, Any]:
        """Returns the computed ELBO metric from the current state.

        Args:
            metrics_state: The current state containing the computed ELBO.

        Returns:
            A dictionary containing the ELBO value.
        """
        return {"elbo": metrics_state.elbo}

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the ELBO metric module.

        Attributes:
            policy_params: Current policy parameters used for forward and backward rollouts
                to generate terminal states and compute log-ratios.
            env_params: Environment parameters required for trajectory generation
                and reward computation.
        """

        policy_params: TPolicyParams
        env_params: TEnvParams

    def process(
        self,
        metrics_state: ELBOMetricState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> ELBOMetricState:
        """Computes the ELBO by sampling trajectories from the forward policy.

        This method performs multiple rounds of forward rollouts to sample
        trajectories, and then computes the ELBO for each trajectory. The final
        ELBO is the average over all sampled trajectories across all rounds.

        Args:
            metrics_state: Current metric state; its `elbo` field is replaced.
            rng_key: Random number generator key for sampling.
            args: Arguments for processing, containing policy and environment parameters.

        Returns:
            An updated metrics state containing the ELBO value, averaged over all
            trajectories and rounds.
        """

        def process_round(carry_rng_key, _):
            """Process a single round of sampling across all batches."""
            rng_key, rollout_key = jax.random.split(carry_rng_key)
            fwd_traj_data, aux_info = forward_rollout(
                rng_key=rollout_key,
                num_envs=self.batch_size,
                policy_fn=self.fwd_policy_fn,
                policy_params=args.policy_params,
                env=self.env,
                env_params=args.env_params,
            )
            # ELBO = E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
            # (without normalising constant)
            log_pf_traj, log_pb_traj = forward_trajectory_log_probs(
                self.env, fwd_traj_data, args.env_params
            )
            elbo = log_pb_traj - log_pf_traj + aux_info["log_gfn_reward"]
            chex.assert_shape(elbo, (self.batch_size,))
            return rng_key, elbo

        _, elbo_per_round = jax.lax.scan(
            process_round,
            rng_key,
            None,
            length=self.n_rounds,
        )
        chex.assert_shape(elbo_per_round, (self.n_rounds, self.batch_size))

        # Average over rounds and batch. Normalise using logZ, if it is tractable.
        elbo = jnp.mean(elbo_per_round) - self.logZ
        return metrics_state.replace(elbo=elbo)

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the ELBO metric module.

Attributes:

Name Type Description
policy_params TPolicyParams

Current policy parameters used for forward and backward rollouts to generate terminal states and compute log-ratios.

env_params TEnvParams

Environment parameters required for trajectory generation and reward computation.

Source code in gfnx/metrics/elbo.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the ELBO metric module.

    Attributes:
        policy_params: Current policy parameters used for forward and backward rollouts
            to generate terminal states and compute log-ratios.
        env_params: Environment parameters required for trajectory generation
            and reward computation.
    """

    policy_params: TPolicyParams
    env_params: TEnvParams

__init__(env, env_params, fwd_policy_fn, n_rounds, batch_size)

Initializes the ELBO metric module.

Parameters:

Name Type Description Default
env TEnvironment

Environment for trajectory generation and reward computation.

required
env_params TEnvParams

Environment parameters used for trajectory generation.

required
fwd_policy_fn TPolicyFn

Forward policy function for generating trajectories.

required
n_rounds int

The number of sampling rounds to perform for estimation.

required
batch_size int

The number of environments to run in parallel for sampling.

required
Source code in gfnx/metrics/elbo.py
def __init__(
    self,
    env: TEnvironment,
    env_params: TEnvParams,
    fwd_policy_fn: TPolicyFn,
    n_rounds: int,
    batch_size: int,
):
    """Initializes the ELBO metric module.

    Args:
        env: Environment for trajectory generation and reward computation.
        env_params: Environment parameters used for trajectory generation.
        fwd_policy_fn: Forward policy function for generating trajectories.
        n_rounds: The number of sampling rounds to perform for estimation.
        batch_size: The number of environments to run in parallel for sampling.
    """
    self.env = env
    if env.is_normalizing_constant_tractable:
        self.logZ = jnp.log(env.get_normalizing_constant(env_params))
    else:
        self.logZ = jnp.array(0.0)
    self.fwd_policy_fn = fwd_policy_fn
    self.n_rounds = n_rounds
    self.batch_size = batch_size

get(metrics_state)

Returns the computed ELBO metric from the current state.

Parameters:

Name Type Description Default
metrics_state ELBOMetricState

The current state containing the computed ELBO.

required

Returns:

Type Description
dict[str, Any]

A dictionary containing the ELBO value.

Source code in gfnx/metrics/elbo.py
def get(self, metrics_state: ELBOMetricState) -> dict[str, Any]:
    """Returns the computed ELBO metric from the current state.

    Args:
        metrics_state: The current state containing the computed ELBO.

    Returns:
        A dictionary containing the ELBO value.
    """
    return {"elbo": metrics_state.elbo}

init(rng_key, args)

Initialize the metric state for ELBO metric.

Source code in gfnx/metrics/elbo.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ELBOMetricState:
    """Initialize the metric state for ELBO metric."""
    return ELBOMetricState(elbo=jnp.array(-jnp.inf, dtype=jnp.float32))

process(metrics_state, rng_key, args)

Computes the ELBO by sampling trajectories from the forward policy.

This method performs multiple rounds of forward rollouts to sample trajectories, and then computes the ELBO for each trajectory. The final ELBO is the average over all sampled trajectories across all rounds.
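
The averaging over rounds and batch can be imitated with a Monte Carlo estimate in plain Python. This sketch replaces the environment rollout with direct sampling of terminal states from a toy forward policy. It assumes a single trajectory per terminal state (so log Pb = 0) and uses the standard-library `random` module instead of JAX PRNG keys; all names and values are illustrative.

```python
import math
import random

rewards = [1.0, 3.0]
pf = [0.5, 0.5]  # toy forward policy over terminal states
log_z = math.log(sum(rewards))
n_rounds, batch_size = 200, 50

rng = random.Random(0)
elbo_per_round = []
for _ in range(n_rounds):
    xs = rng.choices(range(len(rewards)), weights=pf, k=batch_size)
    # Per-sample term: log Pb + log R - log Pf, with log Pb = 0 in this toy case.
    elbo_per_round.append([math.log(rewards[x]) - math.log(pf[x]) for x in xs])

# Average over rounds and batch, then normalise with log Z.
flat = [v for batch in elbo_per_round for v in batch]
estimate = sum(flat) / len(flat) - log_z
exact = 0.5 * math.log(3.0) - math.log(2.0)  # closed form for this toy setup
```

With `n_rounds * batch_size` samples the estimate should land close to the closed-form value. Increasing `n_rounds` reduces the variance without increasing the per-round batch size.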

Parameters:

Name Type Description Default
metrics_state ELBOMetricState

Current metric state; its elbo field is replaced.

required
rng_key PRNGKey

Random number generator key for sampling.

required
args ProcessArgs

Arguments for processing, containing policy and environment parameters.

required

Returns:

Type Description
ELBOMetricState

An updated metrics state containing the ELBO value, averaged over all trajectories and rounds.

Source code in gfnx/metrics/elbo.py
def process(
    self,
    metrics_state: ELBOMetricState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs,
) -> ELBOMetricState:
    """Computes the ELBO by sampling trajectories from the forward policy.

    This method performs multiple rounds of forward rollouts to sample
    trajectories, and then computes the ELBO for each trajectory. The final
    ELBO is the average over all sampled trajectories across all rounds.

    Args:
        metrics_state: Current metric state; its `elbo` field is replaced.
        rng_key: Random number generator key for sampling.
        args: Arguments for processing, containing policy and environment parameters.

    Returns:
        An updated metrics state containing the ELBO value, averaged over all
        trajectories and rounds.
    """

    def process_round(carry_rng_key, _):
        """Process a single round of sampling across all batches."""
        rng_key, rollout_key = jax.random.split(carry_rng_key)
        fwd_traj_data, aux_info = forward_rollout(
            rng_key=rollout_key,
            num_envs=self.batch_size,
            policy_fn=self.fwd_policy_fn,
            policy_params=args.policy_params,
            env=self.env,
            env_params=args.env_params,
        )
        # ELBO = E_{traj ~ Pf} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
        # (without normalising constant)
        log_pf_traj, log_pb_traj = forward_trajectory_log_probs(
            self.env, fwd_traj_data, args.env_params
        )
        elbo = log_pb_traj - log_pf_traj + aux_info["log_gfn_reward"]
        chex.assert_shape(elbo, (self.batch_size,))
        return rng_key, elbo

    _, elbo_per_round = jax.lax.scan(
        process_round,
        rng_key,
        None,
        length=self.n_rounds,
    )
    chex.assert_shape(elbo_per_round, (self.n_rounds, self.batch_size))

    # Average over rounds and batch. Normalise using logZ, if it is tractable.
    elbo = jnp.mean(elbo_per_round) - self.logZ
    return metrics_state.replace(elbo=elbo)
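
The estimator in process can be illustrated on a toy problem without JAX. The sketch below is a plain-Python stand-in, not gfnx code. It assumes a hypothetical environment in which every terminal state is reached by exactly one trajectory, so log Pb(traj | traj_n) = 0, and the names rewards, pf, elbo_term, and estimate_elbo are invented for illustration. It averages the per-trajectory term over n_rounds × batch_size samples and then subtracts log Z, which is known in this toy setting.

```python
import math
import random

# Assumed toy setting: one trajectory per terminal state, so log Pb = 0.
rewards = [1.0, 2.0, 5.0]   # unnormalised R(x) for three terminal states
pf = [0.2, 0.3, 0.5]        # forward-policy probability of reaching each x
log_z = math.log(sum(rewards))


def elbo_term(x):
    """Per-trajectory term: log Pb + log R - log Pf (log Pb = 0 here)."""
    return 0.0 + math.log(rewards[x]) - math.log(pf[x])


def estimate_elbo(n_rounds, batch_size, seed=0):
    """Average the per-trajectory term over all rounds, then subtract log Z."""
    rng = random.Random(seed)
    values = []
    for _ in range(n_rounds):
        # One "round": sample a batch of terminal states from the forward policy.
        batch = rng.choices(range(len(pf)), weights=pf, k=batch_size)
        values.extend(elbo_term(x) for x in batch)
    return sum(values) / len(values) - log_z


# Exact expectation, available here because the toy problem is enumerable.
exact_elbo = sum(p * elbo_term(x) for x, p in enumerate(pf)) - log_z
estimate = estimate_elbo(n_rounds=20, batch_size=256)
```

After normalisation the ELBO equals -KL(Pf || R/Z), so it is at most 0 and reaches 0 only when the forward policy samples terminal states in proportion to their reward.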

update(metrics_state, rng_key, args=None)

Update metric state with new data. This is a no-op as the metric is computed on demand.

Source code in gfnx/metrics/elbo.py
def update(
    self,
    metrics_state: ELBOMetricState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs | None = None,
) -> ELBOMetricState:
    """
    Update metric state with new data.
    This is a no-op as the metric is computed on demand.
    """
    return metrics_state

EUBOMetricState

Bases: MetricsState

State container for the Evidence Upper Bound (EUBO) metric.

This state container stores the computed EUBO metric.

Attributes:

Name Type Description
eubo ndarray

The computed EUBO value.

Source code in gfnx/metrics/eubo.py
@chex.dataclass
class EUBOMetricState(MetricsState):
    """State container for the Evidence Upper Bound (EUBO) metric.

    This state container stores the computed EUBO metric.

    Attributes:
        eubo: The computed EUBO value.
    """

    eubo: jnp.ndarray

EUBOMetricsModule

Bases: BaseMetricsModule

Computes the Evidence Upper Bound (EUBO) for a GFlowNet model.

This metric evaluates the GFlowNet model by estimating the EUBO. The EUBO is computed by sampling trajectories from the backward policy and evaluating the log-ratios of the forward and backward probabilities plus the log reward.

The EUBO is defined as E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)] - logZ if logZ is tractable, and as the same expectation without the logZ term otherwise. Here traj is sampled by the backward policy, starting from terminal states traj_n drawn with env.get_ground_truth_sampling.

Attributes:

Name Type Description
env

Environment instance for trajectory generation and evaluation.

env_params

Environment parameters.

bwd_policy_fn

Backward policy function for sampling trajectories starting from terminal states and computing log-ratios.

n_rounds

Number of sampling rounds for statistical stability.

batch_size

Batch size used when evaluating policy over states.

rng_key

Key used for pseudo random generation.

Source code in gfnx/metrics/eubo.py
class EUBOMetricsModule(BaseMetricsModule):
    """Computes the Evidence Upper Bound (EUBO) for a GFlowNet model.

    This metric evaluates the GFlowNet model by estimating the EUBO.
    The EUBO is computed by sampling trajectories from the backward policy and
    evaluating the log-ratios of the forward and backward probabilities plus the log reward.

    The EUBO is defined as:
    EUBO = {
        if logZ is tractable:
            E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)] - logZ
        else:
            E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
    },
    where traj is sampled by the backward policy, starting from terminal states
    drawn with env.get_ground_truth_sampling.

    Attributes:
        env: Environment instance for trajectory generation and evaluation.
        env_params: Environment parameters.
        bwd_policy_fn: Backward policy function for sampling trajectories
            starting from terminal states and computing log-ratios.
        n_rounds: Number of sampling rounds for statistical stability.
        batch_size: Batch size used when evaluating policy over states.
        rng_key: Key used for pseudo random generation.
    """

    def __init__(
        self,
        env: TEnvironment,
        env_params: TEnvParams,
        bwd_policy_fn: TPolicyFn,
        n_rounds: int,
        batch_size: int,
        rng_key: chex.PRNGKey,
    ):
        """Initializes the EUBO metric module.

        Args:
            env: Environment for trajectory generation and reward computation.
            env_params: Environment parameters.
            bwd_policy_fn: Backward policy function for generating trajectories starting
                from terminal states.
            n_rounds: The number of sampling rounds to perform for estimation.
            batch_size: The number of environments to run in parallel for sampling.
            rng_key: PRNG key; split once to sample the fixed test set of terminal states.
        """
        self.env = env
        rng_key, sample_key = jax.random.split(rng_key)
        self.test_set = env.get_ground_truth_sampling(sample_key, batch_size, env_params)
        if env.is_normalizing_constant_tractable:
            self.logZ = jnp.log(env.get_normalizing_constant(env_params))
        else:
            self.logZ = jnp.array(0.0)
        self.bwd_policy_fn = bwd_policy_fn
        self.n_rounds = n_rounds
        self.batch_size = batch_size

    # Ensure the module has a consistent interface
    InitArgs = EmptyInitArgs

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> EUBOMetricState:
        """Initialize the metric state for EUBO metric."""
        return EUBOMetricState(eubo=jnp.array(jnp.inf, dtype=jnp.float32))

    UpdateArgs = EmptyUpdateArgs

    def update(
        self,
        metrics_state: EUBOMetricState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs | None = None,
    ) -> EUBOMetricState:
        """
        Update metric state with new data.
        This is a no-op as the metric is computed on demand.
        """
        return metrics_state

    def get(self, metrics_state: EUBOMetricState) -> dict[str, Any]:
        """Returns the computed EUBO metric from the current state.

        Args:
            metrics_state: The current state containing the computed EUBO.

        Returns:
            A dictionary containing the EUBO value.
        """
        return {"eubo": metrics_state.eubo}

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the EUBO metric module.

        Attributes:
            policy_params: Current policy parameters used for forward and backward rollouts
                to generate terminal states and compute log-ratios.
            env_params: Environment parameters required for trajectory generation
                and reward computation.
        """

        policy_params: TPolicyParams
        env_params: TEnvParams

    def process(
        self,
        metrics_state: EUBOMetricState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> EUBOMetricState:
        """Computes the EUBO by sampling trajectories from the backward policy.

        This method performs multiple rounds of backward rollouts to sample
        trajectories, and then computes the EUBO for each trajectory. The final
        EUBO is the average over all sampled trajectories across all rounds.

        Args:
            metrics_state: Current metric state; its ``eubo`` field is replaced.
            rng_key: Random number generator key for sampling.
            args: Arguments for processing, containing policy and environment parameters.

        Returns:
            An updated metrics state containing the EUBO value, averaged over all
            trajectories and rounds.
        """

        def process_round(carry_rng_key, _):
            """Process a single round of sampling across all batches."""
            rng_key, rollout_key = jax.random.split(carry_rng_key)
            bwd_traj_data, _ = backward_rollout(
                rng_key=rollout_key,
                init_state=self.test_set,
                policy_fn=self.bwd_policy_fn,
                policy_params=args.policy_params,
                env=self.env,
                env_params=args.env_params,
            )
            # EUBO = E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
            # (without normalising constant)
            log_rewards = self.env.reward_module.log_reward(self.test_set, args.env_params)
            log_pf_traj, log_pb_traj = backward_trajectory_log_probs(
                self.env, bwd_traj_data, args.env_params
            )
            eubo = log_pb_traj - log_pf_traj + log_rewards
            chex.assert_shape(eubo, (self.batch_size,))
            return rng_key, eubo

        _, eubo_per_round = jax.lax.scan(
            process_round,
            rng_key,
            None,
            length=self.n_rounds,
        )
        chex.assert_shape(eubo_per_round, (self.n_rounds, self.batch_size))

        # Average over rounds and batch. Normalise using logZ, if it is tractable.
        eubo = jnp.mean(eubo_per_round) - self.logZ
        return metrics_state.replace(eubo=eubo)

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the EUBO metric module.

Attributes:

Name Type Description
policy_params TPolicyParams

Current policy parameters used for forward and backward rollouts to generate terminal states and compute log-ratios.

env_params TEnvParams

Environment parameters required for trajectory generation and reward computation.

Source code in gfnx/metrics/eubo.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the EUBO metric module.

    Attributes:
        policy_params: Current policy parameters used for forward and backward rollouts
            to generate terminal states and compute log-ratios.
        env_params: Environment parameters required for trajectory generation
            and reward computation.
    """

    policy_params: TPolicyParams
    env_params: TEnvParams

__init__(env, env_params, bwd_policy_fn, n_rounds, batch_size, rng_key)

Initializes the EUBO metric module.

Parameters:

Name Type Description Default
env TEnvironment

Environment for trajectory generation and reward computation.

required
env_params TEnvParams

Environment parameters.

required
bwd_policy_fn TPolicyFn

Backward policy function for generating trajectories starting from terminal states.

required
n_rounds int

The number of sampling rounds to perform for estimation.

required
batch_size int

The number of environments to run in parallel for sampling.

required
rng_key PRNGKey

PRNG key; split once to sample the fixed test set of terminal states.

required
Source code in gfnx/metrics/eubo.py
def __init__(
    self,
    env: TEnvironment,
    env_params: TEnvParams,
    bwd_policy_fn: TPolicyFn,
    n_rounds: int,
    batch_size: int,
    rng_key: chex.PRNGKey,
):
    """Initializes the EUBO metric module.

    Args:
        env: Environment for trajectory generation and reward computation.
        env_params: Environment parameters.
        bwd_policy_fn: Backward policy function for generating trajectories starting
            from terminal states.
        n_rounds: The number of sampling rounds to perform for estimation.
        batch_size: The number of environments to run in parallel for sampling.
        rng_key: PRNG key; split once to sample the fixed test set of terminal states.
    """
    self.env = env
    rng_key, sample_key = jax.random.split(rng_key)
    self.test_set = env.get_ground_truth_sampling(sample_key, batch_size, env_params)
    if env.is_normalizing_constant_tractable:
        self.logZ = jnp.log(env.get_normalizing_constant(env_params))
    else:
        self.logZ = jnp.array(0.0)
    self.bwd_policy_fn = bwd_policy_fn
    self.n_rounds = n_rounds
    self.batch_size = batch_size

get(metrics_state)

Returns the computed EUBO metric from the current state.

Parameters:

Name Type Description Default
metrics_state EUBOMetricState

The current state containing the computed EUBO.

required

Returns:

Type Description
dict[str, Any]

A dictionary containing the EUBO value.

Source code in gfnx/metrics/eubo.py
def get(self, metrics_state: EUBOMetricState) -> dict[str, Any]:
    """Returns the computed EUBO metric from the current state.

    Args:
        metrics_state: The current state containing the computed EUBO.

    Returns:
        A dictionary containing the EUBO value.
    """
    return {"eubo": metrics_state.eubo}

init(rng_key, args)

Initialize the metric state for EUBO metric.

Source code in gfnx/metrics/eubo.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> EUBOMetricState:
    """Initialize the metric state for EUBO metric."""
    return EUBOMetricState(eubo=jnp.array(jnp.inf, dtype=jnp.float32))

process(metrics_state, rng_key, args)

Computes the EUBO by sampling trajectories from the backward policy.

This method performs multiple rounds of backward rollouts to sample trajectories, and then computes the EUBO for each trajectory. The final EUBO is the average over all sampled trajectories across all rounds.

Parameters:

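Name Type Description Default
metrics_state EUBOMetricState

Current metric state; its eubo field is replaced by the new estimate.

required
rng_key PRNGKey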

Random number generator key for sampling.

required
args ProcessArgs

Arguments for processing, containing policy and environment parameters.

required

Returns:

Type Description
EUBOMetricState

An updated metrics state containing the EUBO value, averaged over all trajectories and rounds.

Source code in gfnx/metrics/eubo.py
def process(
    self,
    metrics_state: EUBOMetricState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs,
) -> EUBOMetricState:
    """Computes the EUBO by sampling trajectories from the backward policy.

    This method performs multiple rounds of backward rollouts to sample
    trajectories, and then computes the EUBO for each trajectory. The final
    EUBO is the average over all sampled trajectories across all rounds.

    Args:
        metrics_state: Current metric state; its ``eubo`` field is replaced.
        rng_key: Random number generator key for sampling.
        args: Arguments for processing, containing policy and environment parameters.

    Returns:
        An updated metrics state containing the EUBO value, averaged over all
        trajectories and rounds.
    """

    def process_round(carry_rng_key, _):
        """Process a single round of sampling across all batches."""
        rng_key, rollout_key = jax.random.split(carry_rng_key)
        bwd_traj_data, _ = backward_rollout(
            rng_key=rollout_key,
            init_state=self.test_set,
            policy_fn=self.bwd_policy_fn,
            policy_params=args.policy_params,
            env=self.env,
            env_params=args.env_params,
        )
        # EUBO = E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
        # (without normalising constant)
        log_rewards = self.env.reward_module.log_reward(self.test_set, args.env_params)
        log_pf_traj, log_pb_traj = backward_trajectory_log_probs(
            self.env, bwd_traj_data, args.env_params
        )
        eubo = log_pb_traj - log_pf_traj + log_rewards
        chex.assert_shape(eubo, (self.batch_size,))
        return rng_key, eubo

    _, eubo_per_round = jax.lax.scan(
        process_round,
        rng_key,
        None,
        length=self.n_rounds,
    )
    chex.assert_shape(eubo_per_round, (self.n_rounds, self.batch_size))

    # Average over rounds and batch. Normalise using logZ, if it is tractable.
    eubo = jnp.mean(eubo_per_round) - self.logZ
    return metrics_state.replace(eubo=eubo)
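
To see why this quantity is an upper bound, it may help to evaluate the same log-ratio exactly on a tiny hand-written DAG. The sketch below is plain Python rather than gfnx code, and every name in it (rewards, trajectories, log_ratio, elbo, eubo) is hypothetical. It assumes two terminal states, one of which is reachable by two trajectories, and a known Z, so both expectations can be computed by enumeration and normalised by subtracting log Z.

```python
import math

# Hypothetical toy DAG: terminal state 0 is reachable by two trajectories,
# terminal state 1 by a single trajectory.
rewards = [1.0, 3.0]                 # unnormalised R(x)
log_z = math.log(sum(rewards))
trajectories = [
    {"x": 0, "pb": 0.5},             # Pb(traj | x) for each trajectory
    {"x": 0, "pb": 0.5},
    {"x": 1, "pb": 1.0},
]


def log_ratio(traj, pf_prob):
    """log Pb(traj | x) + log R(x) - log Pf(traj)."""
    return math.log(traj["pb"]) + math.log(rewards[traj["x"]]) - math.log(pf_prob)


def elbo(pf):
    """Expectation under the forward policy Pf(traj), minus log Z."""
    return sum(p * log_ratio(t, p) for t, p in zip(trajectories, pf)) - log_z


def eubo(pf):
    """Expectation under R(x)/Z * Pb(traj | x), minus log Z."""
    z = sum(rewards)
    weights = [rewards[t["x"]] / z * t["pb"] for t in trajectories]
    terms = (w * log_ratio(t, p) for t, w, p in zip(trajectories, weights, pf))
    return sum(terms) - log_z


imperfect_pf = [0.3, 0.2, 0.5]       # some forward policy over the trajectories
perfect_pf = [0.125, 0.125, 0.75]    # equals R(x)/Z * Pb(traj | x)

lower, upper = elbo(imperfect_pf), eubo(imperfect_pf)
```

With the log Z term subtracted, the two quantities are a negative and a positive KL divergence between the forward trajectory distribution and the reward-weighted backward one, so lower ≤ 0 ≤ upper, and both are 0 for perfect_pf.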

update(metrics_state, rng_key, args=None)

Update metric state with new data. This is a no-op as the metric is computed on demand.

Source code in gfnx/metrics/eubo.py
def update(
    self,
    metrics_state: EUBOMetricState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs | None = None,
) -> EUBOMetricState:
    """
    Update metric state with new data.
    This is a no-op as the metric is computed on demand.
    """
    return metrics_state
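
Because update is a no-op, the value only changes when process is called. The sketch below mimics that init / update / process / get lifecycle with a hypothetical stand-in class in plain Python. ToyMetricState and ToyOnDemandMetric are not part of gfnx, and their simplified signatures (no PRNG key, no args dataclasses) are assumptions made to keep the example short.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyMetricState:
    """Hypothetical stand-in for EUBOMetricState."""
    eubo: float


class ToyOnDemandMetric:
    """Hypothetical stand-in for a metric that is computed on demand."""

    def init(self) -> ToyMetricState:
        # Start from +inf, as the EUBO module does.
        return ToyMetricState(eubo=float("inf"))

    def update(self, state: ToyMetricState, batch=None) -> ToyMetricState:
        # No-op: training batches do not affect the metric.
        return state

    def process(self, state: ToyMetricState, per_trajectory_values) -> ToyMetricState:
        # The expensive evaluation step: average and store the result.
        mean = sum(per_trajectory_values) / len(per_trajectory_values)
        return replace(state, eubo=mean)

    def get(self, state: ToyMetricState) -> dict:
        return {"eubo": state.eubo}


metric = ToyOnDemandMetric()
state = metric.init()
for step in range(3):                 # training steps: state is unchanged
    state = metric.update(state, batch={"step": step})
before = metric.get(state)
state = metric.process(state, [0.2, 0.4, 0.6])   # evaluation step
after = metric.get(state)
```

In a training loop this means update can be called every step at no cost, while process is scheduled only at evaluation intervals.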

ExactDistributionMetricsModule

Bases: BaseMetricsModule

Exact distribution metrics for enumerable environments.

For enumerable environments, this module computes the exact terminal distribution induced by the forward policy, using a simple policy evaluation method.

Supported metrics
  • "tv": Total variation distance between distributions
  • "kl": KL divergence between true and exact terminal distributions
  • "jsd": Jensen-Shannon divergence between true and exact terminal distributions
  • "2d_marginal_distribution": Marginal distribution computation
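
For reference, the three divergences listed above have standard textbook definitions for discrete distributions, sketched here in plain Python. These are not the gfnx implementations: the library's functions may differ in argument order, clipping of small probabilities, or logarithm base, and the eps guard below is an assumption of this sketch.

```python
import math


def total_variation(p, q):
    """TV(p, q) = 0.5 * sum_i |p_i - q_i|."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) in nats; terms with p_i = 0 contribute nothing."""
    return sum(a * math.log(a / max(b, eps)) for a, b in zip(p, q) if a > 0)


def jensen_shannon(p, q):
    """JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2."""
    m = [0.5 * (a + b) for a, b in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


true_dist = [0.5, 0.25, 0.25]     # e.g. reward-proportional target
exact_dist = [0.4, 0.4, 0.2]      # e.g. distribution induced by the policy

tv = total_variation(true_dist, exact_dist)
kl = kl_divergence(true_dist, exact_dist)
jsd = jensen_shannon(true_dist, exact_dist)
```

TV and JSD are symmetric and bounded (by 1 and ln 2 respectively), whereas KL is asymmetric and can grow without bound when the policy assigns near-zero mass to a state the target supports.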

Attributes:

Name Type Description
metrics

List of required metrics, choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}

env

Enumerable environment for which to compute metrics

fwd_policy_fn

Forward policy function producing action logits

batch_size

Batch size used when evaluating policy over states

tol_epsilon

Tolerance for convergence in distribution computation

_supported_metrics dict[str, Callable[..., Array]]

Dictionary mapping metric names to computation functions

Source code in gfnx/metrics/exact_distribution.py
class ExactDistributionMetricsModule(BaseMetricsModule):
    """Exact distribution metrics for enumerable environments.

    For enumerable environments, this module computes the exact terminal distribution
    induced by the forward policy, using a simple policy evaluation method.

    Supported metrics:
        - "tv": Total variation distance between distributions
        - "kl": KL divergence between true and exact terminal distributions
        - "jsd": Jensen-Shannon divergence between true and exact terminal distributions
        - "2d_marginal_distribution": Marginal distribution computation

    Attributes:
        metrics: List of required metrics, choose from {"tv", "kl", "jsd",
            "2d_marginal_distribution"}
        env: Enumerable environment for which to compute metrics
        fwd_policy_fn: Forward policy function producing action logits
        batch_size: Batch size used when evaluating policy over states
        tol_epsilon: Tolerance for convergence in distribution computation
        _supported_metrics: Dictionary mapping metric names to computation functions
    """

    _supported_metrics: ClassVar[dict[str, Callable[..., chex.Array]]] = {
        "tv": total_variation_distance,
        "kl": kl_divergence,
        "jsd": jensen_shannon_divergence,
        "2d_marginal_distribution": marginal_distribution,
    }

    def __init__(
        self,
        metrics: list[str],
        env: TEnvironment,
        fwd_policy_fn: TPolicyFn,
        batch_size: int,
        tol_epsilon: float = 1e-7,
    ):
        """Initialize the exact-distribution metrics module.

        Validates inputs and records configuration. ``batch_size`` controls how
        states are grouped when evaluating the policy over the topological order
        for memory efficiency; it is unrelated to any replay buffer.

        Args:
            metrics: List of metric names to compute,
                choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}.
            env: Enumerable environment for which to compute metrics.
            fwd_policy_fn: Forward policy function for generating trajectories.
            batch_size: Batch size used when evaluating policy over states.
            tol_epsilon: Tolerance for convergence in distribution computation.

        Raises:
            ValueError: If the environment is not enumerable, or if ``metrics`` is not
                a list of strings or contains unsupported entries.
        """
        if not env.is_enumerable:
            raise ValueError(f"Environment {env.name} is not enumerable")
        if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
            raise ValueError("metrics must be a list of strings")
        if not all(m in self._supported_metrics for m in metrics):
            raise ValueError(
                "Unsupported metrics. Supported metrics are: "
                f"{sorted(self._supported_metrics)}"
            )

        self.metrics = metrics
        self.env = env
        self.fwd_policy_fn = fwd_policy_fn
        self.batch_size = batch_size
        self.tol_epsilon = tol_epsilon

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the ExactDistributionMetricsModule.

        Attributes:
            env_params: Environment parameters used to obtain the true distribution
                and query environment functions.
        """

        env_params: TEnvParams

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ExactDistributionMetricsState:
        """Initialize the metric state.

        Obtains the ground-truth terminal distribution from the environment and
        initializes the ``exact_distribution`` with a uniform placeholder of the
        same shape. The exact distribution will be computed in ``process``.

        Args:
            rng_key: JAX PRNG key (currently unused).
            args: Initialization arguments containing environment parameters.

        Returns:
            ExactDistributionMetricsState: Initialized state containing:
                - true_distribution: Ground truth distribution from the environment
                - exact_distribution: Initial uniform distribution
        """
        true_distribution = self.env.get_true_distribution(args.env_params)
        exact_distribution = jnp.ones_like(true_distribution) / true_distribution.size
        return ExactDistributionMetricsState(
            true_distribution=true_distribution,
            exact_distribution=exact_distribution,
        )

    UpdateArgs = EmptyUpdateArgs

    def update(
        self, metrics_state: ExactDistributionMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
    ) -> ExactDistributionMetricsState:
        return metrics_state

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the ExactDistributionMetricsModule.

        Attributes:
            policy_params: Parameters for the forward policy used during propagation.
            env_params: Environment parameters.
        """

        policy_params: TPolicyParams
        env_params: TEnvParams

    def process(
        self,
        metrics_state: ExactDistributionMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> ExactDistributionMetricsState:
        """
        Compute the exact terminal distribution induced by the forward policy.
        Uses a simple power iteration method to propagate the initial state distribution:
            - Initialize the state distribution with all mass on the initial state.
            - Repeatedly apply the transition matrix induced by the forward policy
              until convergence (L1 norm change below ``tol_epsilon``).
            - Extract the terminal distribution from the converged state distribution.

        Overall, we have
            `final_distribution = sum_{t=0}^inf (transition_matrix^t) @ initial_distribution`.

        Args:
            metrics_state: Current metrics state containing the true distribution.
            rng_key: JAX PRNG key for any stochasticity (currently unused).
            args: Processing arguments containing policy and environment parameters.

        Returns:
            ExactDistributionMetricsState: Updated metrics state with the computed
            exact distribution.
        """
        transition = self._preprare_transition_matrix(
            rng_key,
            args.env_params,
            args.policy_params,
        )
        num_states = transition.shape[0] // 2

        initial_vector = jnp.zeros((2 * num_states,))
        initial_state = jax.tree_util.tree_map(lambda x: x[0], self.env.get_init_state(1))
        initial_idx = self.env.state_to_index(initial_state, args.env_params)
        initial_vector = initial_vector.at[initial_idx].set(1.0)

        def cond_function(carry: tuple):
            last_vec, _, _ = carry
            return jnp.linalg.norm(last_vec, ord=1) > self.tol_epsilon

        def one_step(carry: tuple):
            last_vec, total, transition = carry
            last_vec = transition @ last_vec
            return last_vec, total + last_vec, transition

        _, result, _ = jax.lax.while_loop(
            cond_function, one_step, (initial_vector, initial_vector, transition)
        )

        exact_distribution = result[num_states:].reshape(*metrics_state.true_distribution.shape)
        chex.assert_shape(exact_distribution, metrics_state.true_distribution.shape)
        return metrics_state.replace(exact_distribution=exact_distribution)

    def _preprare_transition_matrix(
        self,
        rng_key: chex.PRNGKey,
        env_params: TEnvParams,
        policy_params: TPolicyParams,
    ) -> chex.Array:
        """
        Returns a transposed sparse transition matrix of size (2 * num_states, 2 * num_states)
        in the BCSR format.
        """
        all_states = self.env.get_all_states(env_params)
        num_states = jax.tree_util.tree_leaves(all_states)[0].shape[0]

        remainder = num_states % self.batch_size
        if remainder != 0:
            pad_width = self.batch_size - remainder
            padded_sorted_states = jax.tree_util.tree_map(
                lambda x: jnp.pad(
                    x, ((0, pad_width),) + ((0, 0),) * (x.ndim - 1), mode="constant"
                ),
                all_states,
            )
        else:
            padded_sorted_states = all_states

        num_batches = (
            jax.tree_util.tree_leaves(padded_sorted_states)[0].shape[0] // self.batch_size
        )
        batched_sorted_states = jax.tree_util.tree_map(
            lambda x: x.reshape((num_batches, self.batch_size, *x.shape[1:])),
            padded_sorted_states,
        )

        def scan_body(carry, states_batch: TEnvState):
            rng_key = carry
            rng_key, subkey = jax.random.split(rng_key)

            obs_batch = self.env.get_obs(states_batch, env_params)
            invalid_mask_batch = self.env.get_invalid_mask(states_batch, env_params)

            fwd_policy_logits, _ = self.fwd_policy_fn(subkey, obs_batch, policy_params)
            fwd_policy_probs = jax.nn.softmax(
                fwd_policy_logits, axis=-1, where=jnp.logical_not(invalid_mask_batch)
            )
            return rng_key, fwd_policy_probs

        _, batched_fwd_policy_probs = jax.lax.scan(
            scan_body,
            rng_key,
            batched_sorted_states,
        )

        fwd_policy_probs = batched_fwd_policy_probs.reshape(-1, batched_fwd_policy_probs.shape[-1])
        fwd_policy_probs = fwd_policy_probs[:num_states]
        chex.assert_shape(fwd_policy_probs, (num_states, self.env.action_space.n))

        state_idx = jax.vmap(self.env.state_to_index, in_axes=(0, None))(
            all_states, env_params
        )  # [num_states]
        actions = jnp.arange(self.env.action_space.n)  # [num_actions]

        next_state, is_terminal, _ = jax.vmap(
            jax.vmap(self.env._single_transition, in_axes=(None, 0, None)), in_axes=(0, None, None)
        )(all_states, actions, env_params)
        next_state_idx = jax.vmap(
            jax.vmap(self.env.state_to_index, in_axes=(0, None)), in_axes=(0, None)
        )(next_state, env_params)
        next_state_idx = next_state_idx + is_terminal * num_states

        rows = jnp.repeat(state_idx[:, None], self.env.action_space.n, axis=1).reshape(-1)
        cols = next_state_idx.reshape(-1)
        data = fwd_policy_probs.reshape(-1)

        transition_matrix = jsp.BCOO(
            (data, jnp.stack([rows, cols], axis=-1)), shape=(2 * num_states, 2 * num_states)
        )
        return jsp.BCSR.from_bcoo(transition_matrix.T)

    def get(self, metrics_state: ExactDistributionMetricsState) -> dict[str, Any]:
        return {
            metric: self._supported_metrics[metric](
                metrics_state.true_distribution, metrics_state.exact_distribution
            )
            for metric in self.metrics
        }
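
The propagation performed in process can be followed by hand on a small example. The sketch below is plain Python with a dictionary of edges instead of a sparse JAX matrix, and the three-state DAG, the edges mapping, and the propagate function are all hypothetical. It keeps the same layout as the source above: indices 0..N-1 are non-terminal copies of the states, indices N..2N-1 are their terminal copies, and terminal copies have no outgoing edges.

```python
def propagate(edges, num_states, init_idx, tol=1e-7):
    """Sum the propagated mass over all steps and return the terminal half."""
    size = 2 * num_states
    last = [0.0] * size
    last[init_idx] = 1.0            # all mass starts on the initial state
    total = list(last)
    # Stop once the mass still in flight (L1 norm of `last`) is below tol.
    while sum(abs(v) for v in last) > tol:
        nxt = [0.0] * size
        for (src, dst), prob in edges.items():
            nxt[dst] += prob * last[src]   # one application of the transposed matrix
        last = nxt
        total = [t + v for t, v in zip(total, last)]
    return total[num_states:]


N = 3
edges = {
    (0, 0 + N): 0.2,   # state 0: stop, landing on its terminal copy
    (0, 1): 0.5,       # state 0 -> state 1
    (0, 2): 0.3,       # state 0 -> state 2
    (1, 1 + N): 0.7,   # state 1: stop
    (1, 2): 0.3,       # state 1 -> state 2
    (2, 2 + N): 1.0,   # state 2: stop (no other action available)
}

terminal_distribution = propagate(edges, num_states=N, init_idx=0)
```

On a DAG the loop finishes after at most as many steps as the longest trajectory, because mass that reaches a terminal copy is added to the running total once and then leaves the in-flight vector.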

InitArgs

Bases: BaseInitArgs

Arguments for initializing the ExactDistributionMetricsModule.

Attributes:

Name Type Description
env_params TEnvParams

Environment parameters used to obtain the true distribution and query environment functions.

Source code in gfnx/metrics/exact_distribution.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the ExactDistributionMetricsModule.

    Attributes:
        env_params: Environment parameters used to obtain the true distribution
            and query environment functions.
    """

    env_params: TEnvParams

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the ExactDistributionMetricsModule.

Attributes:

Name Type Description
policy_params TPolicyParams

Parameters for the forward policy used during propagation.

env_params TEnvParams

Environment parameters.

Source code in gfnx/metrics/exact_distribution.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the ExactDistributionMetricsModule.

    Attributes:
        policy_params: Parameters for the forward policy used during propagation.
        env_params: Environment parameters.
    """

    policy_params: TPolicyParams
    env_params: TEnvParams

__init__(metrics, env, fwd_policy_fn, batch_size, tol_epsilon=1e-07)

Initialize the exact-distribution metrics module.

Validates inputs and records configuration. batch_size controls how states are grouped when evaluating the policy over the topological order for memory efficiency; it is unrelated to any replay buffer.

Parameters:

Name Type Description Default
metrics list[str]

List of metric names to compute. Choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}.

required
env TEnvironment

Enumerable environment for which to compute metrics.

required
fwd_policy_fn TPolicyFn

Forward policy function for generating trajectories.

required
batch_size int

Batch size used when evaluating policy over states.

required
tol_epsilon float

Tolerance for convergence in distribution computation.

1e-07

Raises:

Type Description
ValueError

If the environment is not enumerable or not topologically sortable, or if metrics is invalid or contains unsupported entries.

Source code in gfnx/metrics/exact_distribution.py
def __init__(
    self,
    metrics: list[str],
    env: TEnvironment,
    fwd_policy_fn: TPolicyFn,
    batch_size: int,
    tol_epsilon: float = 1e-7,
):
    """Initialize the exact-distribution metrics module.

    Validates inputs and records configuration. ``batch_size`` controls how
    states are grouped when evaluating the policy over the topological order
    for memory efficiency; it is unrelated to any replay buffer.

    Args:
        metrics: List of metric names to compute. Choose from
            {"tv", "kl", "jsd", "2d_marginal_distribution"}.
        env: Enumerable environment for which to compute metrics.
        fwd_policy_fn: Forward policy function for generating trajectories.
        batch_size: Batch size used when evaluating policy over states.
        tol_epsilon: Tolerance for convergence in distribution computation.

    Raises:
        ValueError: If the environment is not enumerable or not topologically sortable,
            or if ``metrics`` is invalid or contains unsupported entries.
    """
    if not env.is_enumerable:
        raise ValueError(f"Environment {env.name} is not enumerable")
    if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
        raise ValueError("metrics must be a list of strings")
    if not all(m in self._supported_metrics for m in metrics):
        raise ValueError(
            f"Unsupported metrics. Supported metrics are: {sorted(self._supported_metrics)}"
        )

    self.metrics = metrics
    self.env = env
    self.fwd_policy_fn = fwd_policy_fn
    self.batch_size = batch_size
    self.tol_epsilon = tol_epsilon

init(rng_key, args)

Initialize the metric state.

Obtains the ground-truth terminal distribution from the environment and initializes the exact_distribution with a uniform placeholder of the same shape. The exact distribution will be computed in process.

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key (currently unused).

required
args InitArgs

Initialization arguments containing environment parameters.

required

Returns:

Name Type Description
ExactDistributionMetricsState ExactDistributionMetricsState

Initialized state containing:

- true_distribution: Ground truth distribution from the environment
- exact_distribution: Initial uniform distribution

Source code in gfnx/metrics/exact_distribution.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> ExactDistributionMetricsState:
    """Initialize the metric state.

    Obtains the ground-truth terminal distribution from the environment and
    initializes the ``exact_distribution`` with a uniform placeholder of the
    same shape. The exact distribution will be computed in ``process``.

    Args:
        rng_key: JAX PRNG key (currently unused).
        args: Initialization arguments containing environment parameters.

    Returns:
        ExactDistributionMetricsState: Initialized state containing:
            - true_distribution: Ground truth distribution from the environment
            - exact_distribution: Initial uniform distribution
    """
    true_distribution = self.env.get_true_distribution(args.env_params)
    exact_distribution = jnp.ones_like(true_distribution) / true_distribution.size
    return ExactDistributionMetricsState(
        true_distribution=true_distribution,
        exact_distribution=exact_distribution,
    )

process(metrics_state, rng_key, args)

Compute the exact terminal distribution induced by the forward policy. Uses a simple power iteration method to propagate the initial state distribution:

- Initialize the state distribution with all mass on the initial state.
- Repeatedly apply the transition matrix induced by the forward policy until convergence (L1 norm change below tol_epsilon).
- Extract the terminal distribution from the converged state distribution.

Overall, we have final_distribution = sum_{t=0}^inf (transition_matrix^t) @ initial_distribution.
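The series above can be checked by hand on a tiny example. The sketch below is plain Python rather than JAX and assumes a made-up two-state DAG: from s0 the policy moves to s1 with probability 0.6 or stops with probability 0.4, and from s1 it always stops. Each state gets a "terminal copy" in the second half of the vector, loosely mirroring the 2 * num_states layout in the source; the helper names and the probabilities are illustrative, not part of gfnx.

```python
def matvec(matrix, vec):
    """Dense matrix-vector product on nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vec)) for row in matrix]


def propagate(transition, initial, tol=1e-7):
    """Accumulate sum_t transition^t @ initial until the newest term is below tol in L1 norm."""
    last = list(initial)
    total = list(initial)
    while sum(abs(x) for x in last) > tol:
        last = matvec(transition, last)
        total = [t + l for t, l in zip(total, last)]
    return total


# Index layout: [s0, s1, terminal copy of s0, terminal copy of s1].
# Entry [j][i] is the probability of moving from index i to index j.
transition = [
    [0.0, 0.0, 0.0, 0.0],  # nothing flows back into s0
    [0.6, 0.0, 0.0, 0.0],  # s0 -> s1
    [0.4, 0.0, 0.0, 0.0],  # s0 -> stop at s0
    [0.0, 1.0, 0.0, 0.0],  # s1 -> stop at s1
]
initial = [1.0, 0.0, 0.0, 0.0]  # all mass on the initial state

total = propagate(transition, initial)
terminal_distribution = total[2:]  # second half holds the terminal mass
```

Because the graph is acyclic, the transition matrix is nilpotent and the loop stops after a finite number of steps; here the terminal mass is 0.4 on s0 and 0.6 on s1.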

Parameters:

Name Type Description Default
metrics_state ExactDistributionMetricsState

Current metrics state containing the true distribution.

required
rng_key PRNGKey

JAX PRNG key for any stochasticity (currently unused).

required
args ProcessArgs

Processing arguments containing policy and environment parameters.

required

Returns:

Name Type Description
ExactDistributionMetricsState ExactDistributionMetricsState

Updated metrics state with the computed exact distribution.

Source code in gfnx/metrics/exact_distribution.py
def process(
    self,
    metrics_state: ExactDistributionMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs,
) -> ExactDistributionMetricsState:
    """
    Compute the exact terminal distribution induced by the forward policy.
    Uses a simple power iteration method to propagate the initial state distribution:
        - Initialize the state distribution with all mass on the initial state.
        - Repeatedly apply the transition matrix induced by the forward policy
          until convergence (L1 norm change below ``tol_epsilon``).
        - Extract the terminal distribution from the converged state distribution.

    Overall, we have
        `final_distribution = sum_{t=0}^inf (transition_matrix^t) @ initial_distribution`.

    Args:
        metrics_state: Current metrics state containing the true distribution.
        rng_key: JAX PRNG key for any stochasticity (currently unused).
        args: Processing arguments containing policy and environment parameters.

    Returns:
        ExactDistributionMetricsState: Updated metrics state with the computed
            exact distribution.
    """
    transition = self._preprare_transition_matrix(
        rng_key,
        args.env_params,
        args.policy_params,
    )
    num_states = transition.shape[0] // 2

    initial_vector = jnp.zeros((2 * num_states,))
    initial_state = jax.tree_util.tree_map(lambda x: x[0], self.env.get_init_state(1))
    initial_idx = self.env.state_to_index(initial_state, args.env_params)
    initial_vector = initial_vector.at[initial_idx].set(1.0)

    def cond_function(carry: tuple):
        last_vec, _, _ = carry
        return jnp.linalg.norm(last_vec, ord=1) > self.tol_epsilon

    def one_step(carry: tuple):
        last_vec, total, transition = carry
        last_vec = transition @ last_vec
        return last_vec, total + last_vec, transition

    _, result, _ = jax.lax.while_loop(
        cond_function, one_step, (initial_vector, initial_vector, transition)
    )

    exact_distribution = result[num_states:].reshape(*metrics_state.true_distribution.shape)
    chex.assert_shape(exact_distribution, metrics_state.true_distribution.shape)
    return metrics_state.replace(exact_distribution=exact_distribution)

ExactDistributionMetricsState

Bases: MetricsState

State for exact distribution metrics.

Holds the ground-truth terminal distribution from the environment and the exact terminal distribution computed under a given forward policy.

Attributes:

Name Type Description
true_distribution Array

Ground-truth terminal distribution provided by the environment

exact_distribution Array

Exact terminal distribution computed by process

Source code in gfnx/metrics/exact_distribution.py
@chex.dataclass
class ExactDistributionMetricsState(MetricsState):
    """State for exact distribution metrics.

    Holds the ground-truth terminal distribution from the environment and the
    exact terminal distribution computed under a given forward policy.

    Attributes:
        true_distribution: Ground-truth terminal distribution provided by the environment
        exact_distribution: Exact terminal distribution computed by ``process``
    """

    true_distribution: chex.Array
    exact_distribution: chex.Array

MeanRewardMetricsModule

Bases: BaseMetricsModule

Metric module for computing mean reward and its deviation from ground truth.

This module tracks the empirical mean reward from collected samples and compares it against the known ground truth mean reward of the environment. It computes both absolute and relative deviations to assess how well the policy's reward distribution matches the true environment distribution.

Attributes:

Name Type Description
env

Environment instance that must support tractable mean reward computation

gt_mean_reward

Ground truth mean reward from the environment

Source code in gfnx/metrics/reward_delta.py
class MeanRewardMetricsModule(BaseMetricsModule):
    """Metric module for computing mean reward and its deviation from ground truth.

    This module tracks the empirical mean reward from collected samples and compares
    it against the known ground truth mean reward of the environment. It computes
    both absolute and relative deviations to assess how well the policy's reward
    distribution matches the true environment distribution.

    Attributes:
        env: Environment instance that must support tractable mean reward computation
        gt_mean_reward: Ground truth mean reward from the environment
    """

    def __init__(self, env: TEnvironment, env_params: TEnvParams):
        """Initialize the mean reward metric module.

        Args:
            env: Environment instance that must have tractable mean reward computation
                (env.is_mean_reward_tractable must be True)
            env_params: Environment parameters needed to compute the ground truth mean reward

        Raises:
            ValueError: If the environment does not support tractable mean reward computation
        """
        self.env = env
        if self.env.is_mean_reward_tractable:
            self.gt_mean_reward = self.env.get_mean_reward(env_params)
        else:
            raise ValueError("Ground truth mean reward is not tractable for this environment.")

    InitArgs = EmptyInitArgs

    def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> MeanRewardMetricsState:
        """Initialize the mean reward metric state.

        Creates initial state with zero cumulative reward and zero sample count.
        The actual mean reward computation will be performed incrementally as
        new reward samples are added via update().

        Args:
            rng_key: JAX PRNG key for any random initialization (currently unused)
            args: EmptyInitArgs (no additional initialization parameters needed)

        Returns:
            MeanRewardMetricsState: Initialized state with zero accumulated reward and count
        """
        return MeanRewardMetricsState(sum_reward=0.0, num=0)

    @chex.dataclass
    class UpdateArgs(BaseUpdateArgs):
        """Arguments for updating the MeanRewardMetricsModule.

        Attributes:
            log_rewards: Array of log-reward values to add to the running statistics.
                These are expected to be in log space and will be converted to
                linear space for mean computation.
        """

        log_rewards: chex.Array

    def update(
        self, metrics_state: MeanRewardMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
    ) -> MeanRewardMetricsState:
        """Update the metric state with new reward samples.

        Adds new reward samples to the running statistics by converting log-rewards
        to linear space and accumulating them in the sum. Also updates the sample count.

        Args:
            metrics_state: Current metric state with accumulated reward sum and count
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: UpdateArgs object containing the new reward samples

        Returns:
            MeanRewardMetricsState: Updated state with new rewards incorporated
                into the running statistics
        """
        return MeanRewardMetricsState(
            sum_reward=metrics_state.sum_reward + jnp.sum(jnp.exp(args.log_rewards)),
            num=metrics_state.num + args.log_rewards.shape[0],
        )

    ProcessArgs = EmptyProcessArgs

    def process(
        self,
        metrics_state: MeanRewardMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs | None = None,
    ) -> MeanRewardMetricsState:
        """Process the metric state for final computation (no-op for mean reward metrics).

        This method performs any final processing needed before metric computation.
        For mean reward metrics, no additional processing is required as the
        statistics are maintained incrementally during updates.

        Args:
            metrics_state: Current metric state with accumulated reward statistics
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: EmptyProcessArgs (no additional processing parameters needed)

        Returns:
            MeanRewardMetricsState: Unchanged metric state ready for get() call
        """
        return metrics_state

    def get(self, metrics_state: MeanRewardMetricsState) -> dict[str, float]:
        """Get the computed mean reward metrics from the current state.

        Computes the empirical mean reward from accumulated samples and calculates
        both absolute and relative deviations from the ground truth mean reward.

        Args:
            metrics_state: Current metric state containing accumulated reward statistics

        Returns:
            Dict[str, float]: Dictionary containing computed reward metrics:
                - 'mean_reward': Empirical mean reward from collected samples
                - 'reward_delta': Absolute difference between empirical and ground truth mean
                - 'rel_reward_delta': Relative difference (normalized by ground truth mean)
        """
        mean_reward = metrics_state.sum_reward / jnp.maximum(metrics_state.num, 1)
        reward_delta = abs(mean_reward - self.gt_mean_reward)
        rel_reward_delta = reward_delta / self.gt_mean_reward
        return {
            "mean_reward": mean_reward,
            "reward_delta": reward_delta,
            "rel_reward_delta": rel_reward_delta,
        }

UpdateArgs

Bases: BaseUpdateArgs

Arguments for updating the MeanRewardMetricsModule.

Attributes:

Name Type Description
log_rewards Array

Array of log-reward values to add to the running statistics. These are expected to be in log space and will be converted to linear space for mean computation.

Source code in gfnx/metrics/reward_delta.py
@chex.dataclass
class UpdateArgs(BaseUpdateArgs):
    """Arguments for updating the MeanRewardMetricsModule.

    Attributes:
        log_rewards: Array of log-reward values to add to the running statistics.
            These are expected to be in log space and will be converted to
            linear space for mean computation.
    """

    log_rewards: chex.Array

__init__(env, env_params)

Initialize the mean reward metric module.

Parameters:

Name Type Description Default
env TEnvironment

Environment instance that must have tractable mean reward computation (env.is_mean_reward_tractable must be True)

required
env_params TEnvParams

Environment parameters needed to compute the ground truth mean reward

required

Raises:

Type Description
ValueError

If the environment does not support tractable mean reward computation

Source code in gfnx/metrics/reward_delta.py
def __init__(self, env: TEnvironment, env_params: TEnvParams):
    """Initialize the mean reward metric module.

    Args:
        env: Environment instance that must have tractable mean reward computation
            (env.is_mean_reward_tractable must be True)
        env_params: Environment parameters needed to compute the ground truth mean reward

    Raises:
        ValueError: If the environment does not support tractable mean reward computation
    """
    self.env = env
    if self.env.is_mean_reward_tractable:
        self.gt_mean_reward = self.env.get_mean_reward(env_params)
    else:
        raise ValueError("Ground truth mean reward is not tractable for this environment.")

get(metrics_state)

Get the computed mean reward metrics from the current state.

Computes the empirical mean reward from accumulated samples and calculates both absolute and relative deviations from the ground truth mean reward.

Parameters:

Name Type Description Default
metrics_state MeanRewardMetricsState

Current metric state containing accumulated reward statistics

required

Returns:

Type Description
dict[str, float]

Dictionary containing computed reward metrics:

- 'mean_reward': Empirical mean reward from collected samples
- 'reward_delta': Absolute difference between empirical and ground truth mean
- 'rel_reward_delta': Relative difference (normalized by ground truth mean)
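As a worked example of the three returned values, the plain-Python sketch below repeats the arithmetic of get on hand-picked numbers. The function name reward_metrics and the inputs (an accumulated sum of 6.0 over 3 samples, and a ground truth mean of 2.5) are assumptions chosen only for illustration.

```python
def reward_metrics(sum_reward: float, num: int, gt_mean_reward: float) -> dict[str, float]:
    """Compute the empirical mean reward and its deviation from the ground truth."""
    # Guard against division by zero when no samples have been accumulated yet.
    mean_reward = sum_reward / max(num, 1)
    reward_delta = abs(mean_reward - gt_mean_reward)
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        "rel_reward_delta": reward_delta / gt_mean_reward,
    }


metrics = reward_metrics(sum_reward=6.0, num=3, gt_mean_reward=2.5)
empty = reward_metrics(sum_reward=0.0, num=0, gt_mean_reward=2.5)
```

With these inputs the empirical mean is 2.0, the absolute deviation is 0.5, and the relative deviation is 0.2. With no samples the mean is reported as 0.0, so the relative deviation is 1.0.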

Source code in gfnx/metrics/reward_delta.py
def get(self, metrics_state: MeanRewardMetricsState) -> dict[str, float]:
    """Get the computed mean reward metrics from the current state.

    Computes the empirical mean reward from accumulated samples and calculates
    both absolute and relative deviations from the ground truth mean reward.

    Args:
        metrics_state: Current metric state containing accumulated reward statistics

    Returns:
        Dict[str, float]: Dictionary containing computed reward metrics:
            - 'mean_reward': Empirical mean reward from collected samples
            - 'reward_delta': Absolute difference between empirical and ground truth mean
            - 'rel_reward_delta': Relative difference (normalized by ground truth mean)
    """
    mean_reward = metrics_state.sum_reward / jnp.maximum(metrics_state.num, 1)
    reward_delta = abs(mean_reward - self.gt_mean_reward)
    rel_reward_delta = reward_delta / self.gt_mean_reward
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        "rel_reward_delta": rel_reward_delta,
    }

init(rng_key, args=None)

Initialize the mean reward metric state.

Creates initial state with zero cumulative reward and zero sample count. The actual mean reward computation will be performed incrementally as new reward samples are added via update().

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key for any random initialization (currently unused)

required
args InitArgs | None

EmptyInitArgs (no additional initialization parameters needed)

None

Returns:

Name Type Description
MeanRewardMetricsState MeanRewardMetricsState

Initialized state with zero accumulated reward and count

Source code in gfnx/metrics/reward_delta.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> MeanRewardMetricsState:
    """Initialize the mean reward metric state.

    Creates initial state with zero cumulative reward and zero sample count.
    The actual mean reward computation will be performed incrementally as
    new reward samples are added via update().

    Args:
        rng_key: JAX PRNG key for any random initialization (currently unused)
        args: EmptyInitArgs (no additional initialization parameters needed)

    Returns:
        MeanRewardMetricsState: Initialized state with zero accumulated reward and count
    """
    return MeanRewardMetricsState(sum_reward=0.0, num=0)

process(metrics_state, rng_key, args=None)

Process the metric state for final computation (no-op for mean reward metrics).

This method performs any final processing needed before metric computation. For mean reward metrics, no additional processing is required as the statistics are maintained incrementally during updates.

Parameters:

Name Type Description Default
metrics_state MeanRewardMetricsState

Current metric state with accumulated reward statistics

required
rng_key PRNGKey

JAX PRNG key for any random operations (currently unused)

required
args ProcessArgs | None

EmptyProcessArgs (no additional processing parameters needed)

None

Returns:

Name Type Description
MeanRewardMetricsState MeanRewardMetricsState

Unchanged metric state ready for get() call

Source code in gfnx/metrics/reward_delta.py
def process(
    self,
    metrics_state: MeanRewardMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs | None = None,
) -> MeanRewardMetricsState:
    """Process the metric state for final computation (no-op for mean reward metrics).

    This method performs any final processing needed before metric computation.
    For mean reward metrics, no additional processing is required as the
    statistics are maintained incrementally during updates.

    Args:
        metrics_state: Current metric state with accumulated reward statistics
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: EmptyProcessArgs (no additional processing parameters needed)

    Returns:
        MeanRewardMetricsState: Unchanged metric state ready for get() call
    """
    return metrics_state

update(metrics_state, rng_key, args)

Update the metric state with new reward samples.

Adds new reward samples to the running statistics by converting log-rewards to linear space and accumulating them in the sum. Also updates the sample count.

Parameters:

Name Type Description Default
metrics_state MeanRewardMetricsState

Current metric state with accumulated reward sum and count

required
rng_key PRNGKey

JAX PRNG key for any random operations (currently unused)

required
args UpdateArgs

UpdateArgs object containing the new reward samples

required

Returns:

Name Type Description
MeanRewardMetricsState MeanRewardMetricsState

Updated state with new rewards incorporated into the running statistics
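The accumulation described here can be sketched in plain Python, assuming the state is just a (sum_reward, num) pair. The helper name accumulate is hypothetical, and math.exp stands in for the array-wise exponential that would be used on a JAX array.

```python
import math


def accumulate(sum_reward: float, num: int, log_rewards: list[float]) -> tuple[float, int]:
    """Return the updated (sum_reward, num) after adding a batch of log-rewards."""
    # Rewards arrive in log space, so exponentiate before summing.
    batch_sum = sum(math.exp(log_r) for log_r in log_rewards)
    return sum_reward + batch_sum, num + len(log_rewards)


state = (0.0, 0)
state = accumulate(*state, [0.0, math.log(3.0)])  # rewards 1 and 3
state = accumulate(*state, [math.log(2.0)])       # reward 2
running_mean = state[0] / state[1]
```

Only the running sum and the count are kept, so the mean can be recovered at any point without storing individual rewards.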

Source code in gfnx/metrics/reward_delta.py
def update(
    self, metrics_state: MeanRewardMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
) -> MeanRewardMetricsState:
    """Update the metric state with new reward samples.

    Adds new reward samples to the running statistics by converting log-rewards
    to linear space and accumulating them in the sum. Also updates the sample count.

    Args:
        metrics_state: Current metric state with accumulated reward sum and count
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: UpdateArgs object containing the new reward samples

    Returns:
        MeanRewardMetricsState: Updated state with new rewards incorporated
            into the running statistics
    """
    return MeanRewardMetricsState(
        sum_reward=metrics_state.sum_reward + jnp.sum(jnp.exp(args.log_rewards)),
        num=metrics_state.num + args.log_rewards.shape[0],
    )

MeanRewardMetricsState

Bases: MetricsState

State for accumulating mean reward computation.

This state container tracks the cumulative sum of rewards and the count of samples to compute running mean reward statistics. It enables incremental computation of mean reward without storing all individual reward values.

Attributes:

Name Type Description
sum_reward float

Cumulative sum of all rewards (in linear space, not log space)

num int

Total number of reward samples processed

Source code in gfnx/metrics/reward_delta.py
@chex.dataclass
class MeanRewardMetricsState(MetricsState):
    """State for accumulating mean reward computation.

    This state container tracks the cumulative sum of rewards and the count
    of samples to compute running mean reward statistics. It enables incremental
    computation of mean reward without storing all individual reward values.

    Attributes:
        sum_reward: Cumulative sum of all rewards (in linear space, not log space)
        num: Total number of reward samples processed
    """

    sum_reward: float
    num: int

MetricsState

Base state for metric computation.

This is an abstract base class that serves as a pure data container for metric states. Metric states are designed to only store the necessary data and intermediate values required for computing metrics during training and evaluation. They should not contain any methods or processing logic.

All processing, computation, and transformation logic should be implemented in the corresponding BaseMetricsModule subclasses, not in the state objects themselves.

Subclasses should define specific data fields (typically using @chex.dataclass) needed for their metric computation requirements, but no methods.
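The separation between data-only states and logic-carrying modules might look like the sketch below. It uses the standard-library dataclasses module as a stand-in for @chex.dataclass so that it runs without extra dependencies; MetricsState is redefined locally, and CountState and update are hypothetical names invented for this example.

```python
from dataclasses import dataclass, replace


class MetricsState:
    """Local stand-in for the gfnx base class, so the example is self-contained."""


@dataclass(frozen=True)
class CountState(MetricsState):
    """Hypothetical metric state: data fields only, no methods."""

    total: float
    count: int


def update(state: CountState, values: list[float]) -> CountState:
    """Logic lives outside the state; in gfnx it would sit in a BaseMetricsModule subclass."""
    return replace(state, total=state.total + sum(values), count=state.count + len(values))


state = update(CountState(total=0.0, count=0), [1.0, 2.0, 3.0])
```

Returning a new state instead of mutating the old one keeps the update functional, which is the style JAX transformations expect.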

Source code in gfnx/metrics/base.py
class MetricsState:
    """Base state for metric computation.

    This is an abstract base class that serves as a pure data container for metric states.
    Metric states are designed to only store the necessary data and intermediate values
    required for computing metrics during training and evaluation. They should not
    contain any methods or processing logic.

    All processing, computation, and transformation logic should be implemented in the
    corresponding BaseMetricsModule subclasses, not in the state objects themselves.

    Subclasses should define specific data fields (typically using @chex.dataclass)
    needed for their metric computation requirements, but no methods.
    """

MultiMetricsModule

Bases: BaseMetricsModule

Module for handling multiple metrics in a unified way.

This class implements the BaseMetricsModule interface to manage multiple individual metric modules simultaneously. It provides a convenient way to compute multiple metrics together while maintaining the same interface as single metric modules.

The MultiMetricsModule coordinates the lifecycle of all contained metrics, calling their respective init, update, process, and get methods in sequence. Final metrics are returned with prefixed names to avoid conflicts between different metric modules.

Attributes:

Name Type Description
metrics

Dictionary of metric modules indexed by their names

_supported_metrics

Internal mapping of metric names to their get methods
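The coordination pattern described above can be sketched without JAX. In the plain-Python example below, CountMetric, SumMetric, and MultiMetric are hypothetical classes; the PRNG keys and the process step are omitted for brevity, and the "/" separator used for prefixed names is an assumption, since the actual prefix format is defined by MultiMetricsModule.get.

```python
class CountMetric:
    """Hypothetical metric: counts how many samples it has seen."""

    def init(self):
        return 0

    def update(self, state, values):
        return state + len(values)

    def get(self, state):
        return {"count": state}


class SumMetric:
    """Hypothetical metric: accumulates the sum of its samples."""

    def init(self):
        return 0.0

    def update(self, state, values):
        return state + sum(values)

    def get(self, state):
        return {"sum": state}


class MultiMetric:
    """Fans each lifecycle call out to the named child metrics."""

    def __init__(self, metrics):
        self.metrics = metrics

    def init(self):
        return {name: metric.init() for name, metric in self.metrics.items()}

    def update(self, states, metrics_args):
        # Children without an entry in metrics_args receive an empty batch.
        return {
            name: metric.update(states[name], metrics_args.get(name, []))
            for name, metric in self.metrics.items()
        }

    def get(self, states):
        # Prefix each child's keys with its name ("/" is an assumed separator).
        return {
            f"{name}/{key}": value
            for name, metric in self.metrics.items()
            for key, value in metric.get(states[name]).items()
        }


multi = MultiMetric({"n": CountMetric(), "s": SumMetric()})
states = multi.init()
states = multi.update(states, {"n": [1.0, 2.0], "s": [1.0, 2.0]})
result = multi.get(states)
```

Keying both the child modules and their states by the same names is what lets the composite route arguments and results without the children knowing about each other.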

Source code in gfnx/metrics/base.py
class MultiMetricsModule(BaseMetricsModule):
    """Module for handling multiple metrics in a unified way.

    This class implements the BaseMetricsModule interface to manage multiple
    individual metric modules simultaneously. It provides a convenient way to
    compute multiple metrics together while maintaining the same interface
    as single metric modules.

    The MultiMetricsModule coordinates the lifecycle of all contained metrics,
    calling their respective init, update, process, and get methods in sequence.
    Final metrics are returned with prefixed names to avoid conflicts between
    different metric modules.

    Attributes:
        metrics: Dictionary of metric modules indexed by their names
        _supported_metrics: Internal mapping of metric names to their get methods
    """

    def __init__(self, metrics: dict[str, BaseMetricsModule]):
        """Initialize the MultiMetricsModule with a collection of metrics.

        Args:
            metrics: Dictionary mapping metric names to their corresponding
                    BaseMetricsModule instances. Names should be unique and
                    descriptive as they will be used as prefixes in the final
                    metric output.
        """
        self.metrics = metrics
        self._supported_metrics = {name: metric.get for name, metric in metrics.items()}

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the MultiMetricsModule."""

        metrics_args: dict[str, BaseInitArgs]

    def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> MultiMetricsState:
        """Initialize all contained metrics.

        Calls the init method on each metric module and collects their
        individual states into a MultiMetricsState container.

        Args:
            rng_key: JAX PRNG key passed to each metric's init method
            args: Optional InitArgs object mapping metric names to init args

        Returns:
            MultiMetricsState: Container holding all initialized metric states
        """
        if args is None:
            args = self.InitArgs(metrics_args={})
        metrics_keys = jax.random.split(rng_key, len(self.metrics))
        dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
        states = {
            name: metric.init(rng_key=dict_metrics_keys[name], args=args.metrics_args.get(name))
            for name, metric in self.metrics.items()
        }
        return MultiMetricsState(states=states)

    @chex.dataclass
    class UpdateArgs(BaseUpdateArgs):
        """Arguments for updating the MultiMetricsModule."""

        metrics_args: dict[str, Any]

    def update(
        self,
        metrics_state: MultiMetricsState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs | None = None,
    ) -> MultiMetricsState:
        """Update all contained metrics with new data.

        Calls the update method on each metric module with their corresponding
        state and the provided data, then returns a new MultiMetricsState with
        all updated states.

        Args:
            metrics_state: Current MultiMetricsState containing all metric states
            rng_key: JAX PRNG key passed to each metric's update method
            args: Optional UpdateArgs object mapping metric names to update data

        Returns:
            MultiMetricsState: Container with all updated metric states
        """
        if args is None:
            args = self.UpdateArgs(metrics_args={})
        metrics_keys = jax.random.split(rng_key, len(self.metrics))
        dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
        updated_states = {
            name: metric.update(
                metrics_state=metrics_state.states[name],
                rng_key=dict_metrics_keys[name],
                args=args.metrics_args.get(name),
            )
            for name, metric in self.metrics.items()
        }
        return metrics_state.replace(states=updated_states)

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the MultiMetricsModule."""

        metrics_args: dict[str, Any]

    def process(
        self,
        metrics_state: MultiMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs | None = None,
    ) -> MultiMetricsState:
        """Process all metric states to compute metrics and perform final transformations.

        Calls the process method on each metric module with their corresponding
        state to compute metrics and apply final transformations during the
        evaluation period, preparing them for result retrieval.

        Args:
            metrics_state: Current MultiMetricsState containing all metric states
            rng_key: JAX PRNG key passed to each metric's process method
            args: Optional ProcessArgs object mapping metric names to process params

        Returns:
            MultiMetricsState: Container with all processed metric states ready for get()
        """
        if args is None:
            args = self.ProcessArgs(metrics_args={})
        metrics_keys = jax.random.split(rng_key, len(self.metrics))
        dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
        processed_states = {
            name: metric.process(
                metrics_state=metrics_state.states[name],
                rng_key=dict_metrics_keys[name],
                args=args.metrics_args.get(name),
            )
            for name, metric in self.metrics.items()
        }
        return metrics_state.replace(states=processed_states)

    def get(self, metrics_state: MultiMetricsState) -> dict[str, Any]:
        """Get computed metrics from all contained modules.

        Collects metrics from all metric modules and returns them with
        prefixed names to avoid conflicts. Each metric's results are
        prefixed with its module name followed by a forward slash.

        Args:
            metrics_state: Current MultiMetricsState containing all processed states

        Returns:
            Dict[str, Any]: Dictionary of all computed metrics with prefixed names.
                          Keys are in the format "{metric_name}/{original_key}".

        Example:
            If a metric module named "accuracy" returns {"score": 0.95},
            the result will include {"accuracy/score": 0.95}.
        """
        results = {}
        for name, state in metrics_state.states.items():
            metrics = self.metrics[name].get(state)
            for key, value in metrics.items():
                results[f"{name}/{key}"] = value
        return results

InitArgs

Bases: BaseInitArgs

Arguments for initializing the MultiMetricsModule.

Source code in gfnx/metrics/base.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the MultiMetricsModule."""

    metrics_args: dict[str, BaseInitArgs]

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the MultiMetricsModule.

Source code in gfnx/metrics/base.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the MultiMetricsModule."""

    metrics_args: dict[str, Any]

UpdateArgs

Bases: BaseUpdateArgs

Arguments for updating the MultiMetricsModule.

Source code in gfnx/metrics/base.py
@chex.dataclass
class UpdateArgs(BaseUpdateArgs):
    """Arguments for updating the MultiMetricsModule."""

    metrics_args: dict[str, Any]

__init__(metrics)

Initialize the MultiMetricsModule with a collection of metrics.

Parameters:

Name Type Description Default
metrics dict[str, BaseMetricsModule]

Dictionary mapping metric names to their corresponding BaseMetricsModule instances. Names should be unique and descriptive as they will be used as prefixes in the final metric output.

required
Source code in gfnx/metrics/base.py
def __init__(self, metrics: dict[str, BaseMetricsModule]):
    """Initialize the MultiMetricsModule with a collection of metrics.

    Args:
        metrics: Dictionary mapping metric names to their corresponding
                BaseMetricsModule instances. Names should be unique and
                descriptive as they will be used as prefixes in the final
                metric output.
    """
    self.metrics = metrics
    self._supported_metrics = {name: metric.get for name, metric in metrics.items()}

get(metrics_state)

Get computed metrics from all contained modules.

Collects metrics from all metric modules and returns them with prefixed names to avoid conflicts. Each metric's results are prefixed with its module name followed by a forward slash.

Parameters:

Name Type Description Default
metrics_state MultiMetricsState

Current MultiMetricsState containing all processed states

required

Returns:

Type Description
dict[str, Any]

Dictionary of all computed metrics with prefixed names. Keys are in the format "{metric_name}/{original_key}".

Example

If a metric module named "accuracy" returns {"score": 0.95}, the result will include {"accuracy/score": 0.95}.
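The prefixing rule can be sketched without any library code. The snippet below is a minimal illustration, not the gfnx implementation: the per-module results are hypothetical hand-written dicts standing in for what each module's get() would return.

```python
# Hypothetical per-module outputs, standing in for each module's get() result.
per_module_results = {
    "accuracy": {"score": 0.95},
    "coverage": {"score": 0.40, "num_modes": 12},
}

# Flatten into one dict, prefixing every key with its module name.
flat_results = {
    f"{name}/{key}": value
    for name, metrics in per_module_results.items()
    for key, value in metrics.items()
}

print(flat_results)
```

Both toy modules report a key called "score", but the prefixes keep the two values apart in the flattened dictionary.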

Source code in gfnx/metrics/base.py
def get(self, metrics_state: MultiMetricsState) -> dict[str, Any]:
    """Get computed metrics from all contained modules.

    Collects metrics from all metric modules and returns them with
    prefixed names to avoid conflicts. Each metric's results are
    prefixed with its module name followed by a forward slash.

    Args:
        metrics_state: Current MultiMetricsState containing all processed states

    Returns:
        Dict[str, Any]: Dictionary of all computed metrics with prefixed names.
                      Keys are in the format "{metric_name}/{original_key}".

    Example:
        If a metric module named "accuracy" returns {"score": 0.95},
        the result will include {"accuracy/score": 0.95}.
    """
    results = {}
    for name, state in metrics_state.states.items():
        metrics = self.metrics[name].get(state)
        for key, value in metrics.items():
            results[f"{name}/{key}"] = value
    return results

init(rng_key, args=None)

Initialize all contained metrics.

Calls the init method on each metric module and collects their individual states into a MultiMetricsState container.

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key passed to each metric's init method

required
args InitArgs | None

Optional InitArgs object mapping metric names to init args

None

Returns:

Name Type Description
MultiMetricsState MultiMetricsState

Container holding all initialized metric states

Source code in gfnx/metrics/base.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> MultiMetricsState:
    """Initialize all contained metrics.

    Calls the init method on each metric module and collects their
    individual states into a MultiMetricsState container.

    Args:
        rng_key: JAX PRNG key passed to each metric's init method
        args: Optional InitArgs object mapping metric names to init args

    Returns:
        MultiMetricsState: Container holding all initialized metric states
    """
    if args is None:
        args = self.InitArgs(metrics_args={})
    metrics_keys = jax.random.split(rng_key, len(self.metrics))
    dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
    states = {
        name: metric.init(rng_key=dict_metrics_keys[name], args=args.metrics_args.get(name))
        for name, metric in self.metrics.items()
    }
    return MultiMetricsState(states=states)

process(metrics_state, rng_key, args=None)

Process all metric states to compute metrics and perform final transformations.

Calls the process method on each metric module with their corresponding state to compute metrics and apply final transformations during the evaluation period, preparing them for result retrieval.

Parameters:

Name Type Description Default
metrics_state MultiMetricsState

Current MultiMetricsState containing all metric states

required
rng_key PRNGKey

JAX PRNG key passed to each metric's process method

required
args ProcessArgs | None

Optional ProcessArgs object mapping metric names to process params

None

Returns:

Name Type Description
MultiMetricsState MultiMetricsState

Container with all processed metric states ready for get()
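How process() fits between update() and get() may be easier to see in a toy version of the lifecycle. The sketch below is plain Python, not gfnx: ToyMeanMetric, MeanState, and the batch values are invented for illustration, and PRNG keys and args objects are left out.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MeanState:
    """Toy metric state: running sums plus a slot for the processed result."""

    total: float = 0.0
    count: int = 0
    mean: float = 0.0


class ToyMeanMetric:
    """Hypothetical metric with the same four-phase shape as a metrics module."""

    def init(self) -> MeanState:
        return MeanState()

    def update(self, state: MeanState, values: list) -> MeanState:
        # Cheap accumulation only; no final statistic is computed here.
        return replace(
            state,
            total=state.total + sum(values),
            count=state.count + len(values),
        )

    def process(self, state: MeanState) -> MeanState:
        # Final transformation, done once per evaluation period.
        return replace(state, mean=state.total / max(state.count, 1))

    def get(self, state: MeanState) -> dict:
        return {"mean": state.mean}


metrics = {"reward": ToyMeanMetric(), "length": ToyMeanMetric()}
batch = {"reward": [1.0, 3.0], "length": [4, 6, 8]}

# init -> update -> process -> get, fanned out over every named metric.
states = {name: m.init() for name, m in metrics.items()}
states = {name: m.update(states[name], batch[name]) for name, m in metrics.items()}
states = {name: m.process(states[name]) for name, m in metrics.items()}
results = {
    f"{name}/{key}": value
    for name, m in metrics.items()
    for key, value in m.get(states[name]).items()
}
```

In this toy, get() only reads what process() stored, which mirrors the documented contract that processed states are "ready for get()".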

Source code in gfnx/metrics/base.py
def process(
    self,
    metrics_state: MultiMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs | None = None,
) -> MultiMetricsState:
    """Process all metric states to compute metrics and perform final transformations.

    Calls the process method on each metric module with their corresponding
    state to compute metrics and apply final transformations during the
    evaluation period, preparing them for result retrieval.

    Args:
        metrics_state: Current MultiMetricsState containing all metric states
        rng_key: JAX PRNG key passed to each metric's process method
        args: Optional ProcessArgs object mapping metric names to process params

    Returns:
        MultiMetricsState: Container with all processed metric states ready for get()
    """
    if args is None:
        args = self.ProcessArgs(metrics_args={})
    metrics_keys = jax.random.split(rng_key, len(self.metrics))
    dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
    processed_states = {
        name: metric.process(
            metrics_state=metrics_state.states[name],
            rng_key=dict_metrics_keys[name],
            args=args.metrics_args.get(name),
        )
        for name, metric in self.metrics.items()
    }
    return metrics_state.replace(states=processed_states)

update(metrics_state, rng_key, args=None)

Update all contained metrics with new data.

Calls the update method on each metric module with their corresponding state and the provided data, then returns a new MultiMetricsState with all updated states.

Parameters:

Name Type Description Default
metrics_state MultiMetricsState

Current MultiMetricsState containing all metric states

required
rng_key PRNGKey

JAX PRNG key passed to each metric's update method

required
args UpdateArgs | None

Optional UpdateArgs object mapping metric names to update data

None

Returns:

Name Type Description
MultiMetricsState MultiMetricsState

Container with all updated metric states
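The args object routes data by metric name: each contained module receives args.metrics_args.get(name), so a module without an entry is handed None. A minimal sketch of that lookup, using a hypothetical stand-in dataclass rather than the real UpdateArgs:

```python
from dataclasses import dataclass, field


@dataclass
class ToyUpdateArgs:
    """Stand-in for UpdateArgs; only the metrics_args field is modelled."""

    metrics_args: dict = field(default_factory=dict)


metric_names = ["modes", "reward"]  # hypothetical registered metric names
args = ToyUpdateArgs(metrics_args={"reward": {"rewards": [0.5, 1.5]}})

# dict.get returns None for names that have no entry.
routed = {name: args.metrics_args.get(name) for name in metric_names}
```

Whether a given module tolerates None depends on that module's own update(), so a module that needs data should have an entry under its own name.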

Source code in gfnx/metrics/base.py
def update(
    self,
    metrics_state: MultiMetricsState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs | None = None,
) -> MultiMetricsState:
    """Update all contained metrics with new data.

    Calls the update method on each metric module with their corresponding
    state and the provided data, then returns a new MultiMetricsState with
    all updated states.

    Args:
        metrics_state: Current MultiMetricsState containing all metric states
        rng_key: JAX PRNG key passed to each metric's update method
        args: Optional UpdateArgs object mapping metric names to update data

    Returns:
        MultiMetricsState: Container with all updated metric states
    """
    if args is None:
        args = self.UpdateArgs(metrics_args={})
    metrics_keys = jax.random.split(rng_key, len(self.metrics))
    dict_metrics_keys = dict(zip(self.metrics.keys(), metrics_keys))
    updated_states = {
        name: metric.update(
            metrics_state=metrics_state.states[name],
            rng_key=dict_metrics_keys[name],
            args=args.metrics_args.get(name),
        )
        for name, metric in self.metrics.items()
    }
    return metrics_state.replace(states=updated_states)

MultiMetricsState

Bases: MetricsState

State container for multiple metrics.

This class extends MetricsState to hold the states of multiple individual metric modules. It provides a unified interface for managing the states of different metrics that need to be computed together.

Attributes:

Name Type Description
states dict[str, MetricsState]

Dictionary mapping metric names to their individual MetricsState objects. Each key represents a unique metric identifier, and each value is the corresponding metric's state object.

Source code in gfnx/metrics/base.py
@chex.dataclass
class MultiMetricsState(MetricsState):
    """State container for multiple metrics.

    This class extends MetricsState to hold the states of multiple individual
    metric modules. It provides a unified interface for managing the states
    of different metrics that need to be computed together.

    Attributes:
        states: Dictionary mapping metric names to their individual MetricsState objects.
               Each key represents a unique metric identifier, and each value is the
               corresponding metric's state object.
    """

    states: dict[str, MetricsState]

OnPolicyCorrelationMetricsModule

Bases: BaseCorrelationMetricsModule

On-policy correlation metric module for GFlowNet evaluation.

This metric module computes correlation metrics between transformed model-predicted marginal probabilities and transformed log true rewards. During evaluation, it generates fresh terminal states by performing forward rollouts with the current policy, then computes backward trajectory log-ratios.

Attributes:

Name Type Description
n_terminal_states

Number of terminal states to sample for evaluation

domain_size

Number of points to compute correlations on after marginalization

fwd_policy_fn

Forward policy function for generating trajectories
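For intuition about why such a correlation is informative: if a sampler's terminal distribution were exactly proportional to the reward, its log-probabilities and the log-rewards would differ only by a constant, and their correlation would be 1. The sketch below checks this on made-up numbers. The description above does not say which correlation coefficient the module computes; Pearson is used here purely as an illustration, and the values are not produced by gfnx.

```python
import math


def pearson(xs, ys):
    """Plain-Python Pearson correlation coefficient."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


log_rewards = [0.0, 1.0, 2.0, 3.0]  # made-up log-rewards
log_z = 2.5  # made-up log normalising constant
# Ideal sampler: log p(x) = log R(x) - log Z.
log_probs = [r - log_z for r in log_rewards]

corr = pearson(log_probs, log_rewards)
```

A correlation well below 1 on real estimates would indicate that the learned distribution deviates from proportionality to the reward.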

Source code in gfnx/metrics/correlation.py
class OnPolicyCorrelationMetricsModule(BaseCorrelationMetricsModule):
    """On-policy correlation metric module for GFlowNet evaluation.

    This metric module computes correlation metrics between transformed model-predicted
    marginal probabilities and transformed log true rewards. During evaluation,
    it generates fresh terminal states by performing forward rollouts with the current policy,
    then computes backward trajectory log-ratios.

    Attributes:
        n_terminal_states: Number of terminal states to sample for evaluation
        domain_size: Number of points to compute correlations on after marginalization
        fwd_policy_fn: Forward policy function for generating trajectories
    """

    def __init__(
        self,
        n_rounds: int,
        n_terminal_states: int,
        batch_size: int,
        fwd_policy_fn: TPolicyFn,
        bwd_policy_fn: TPolicyFn,
        env: TEnvironment,
        domain_size: int | None = None,
        transform_fn: Callable[[TEnvState, jnp.ndarray], jnp.ndarray] = lambda x, y: y,
    ):
        """Initialize the on-policy correlation metric module.

        Args:
            n_rounds: Number of sampling rounds for statistical stability
            n_terminal_states: Number of terminal states to generate and evaluate
            batch_size: Batch size for efficient processing of terminal states
            fwd_policy_fn: Forward policy function for generating trajectories
            bwd_policy_fn: Backward policy function for computing log-ratios
            env: Environment for trajectory generation and reward computation
            domain_size: Number of points to compute correlations on,
                i.e. the output size of transform_fn.
                Defaults to n_terminal_states, which matches the
                default identity transform.
            transform_fn: Function to marginalize distributions from terminal states
                to an arbitrary domain size for correlation computation
        """
        super().__init__(
            env=env,
            bwd_policy_fn=bwd_policy_fn,
            n_rounds=n_rounds,
            batch_size=batch_size,
            transform_fn=transform_fn,
        )
        self.fwd_policy_fn = fwd_policy_fn
        self.n_terminal_states = n_terminal_states
        self.domain_size = domain_size or n_terminal_states
        assert n_terminal_states % batch_size == 0, (
            f"n_terminal_states ({n_terminal_states}) must be divisible"
            f" by batch_size ({batch_size})"
        )

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the on-policy correlation metric module.

        Attributes:
            env_params: Environment parameters needed to create dummy terminal states
                during initialization and for environment operations.
        """

        env_params: TEnvParams

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> CorrelationMetricsState:
        """Initialize the on-policy correlation metric state.

        Creates an initial state with placeholder data structures. The actual
        terminal states and rewards will be generated during the process phase
        using on-policy sampling.

        The arrays are initialized with domain_size, which represents the size of the
        transformed probability space after marginalization. This allows the correlation
        metrics to be computed on an arbitrary domain rather than just the terminal states.

        Args:
            rng_key: Random number generator key (unused in initialization)
            args: InitArgs object containing environment parameters

        Returns:
            CorrelationMetricsState: Initialized state with dummy terminal states
                and zero-initialized arrays for rewards and log-ratios with shape
                (domain_size,) representing the transformed probability space
        """
        dummy_terminal_states = self.env.reset(self.n_terminal_states, args.env_params)[1]
        return CorrelationMetricsState(
            test_terminal_states=dummy_terminal_states,
            test_log_rewards_transformed=jnp.zeros(self.domain_size),
            log_ratio_traj_transformed=jnp.zeros(self.domain_size),
        )

    def _get_states_and_rewards(
        self,
        metrics_state: CorrelationMetricsState,
        rng_key: chex.PRNGKey,
        args: BaseCorrelationMetricsModule.ProcessArgs,
    ) -> tuple[TEnvState, jnp.ndarray]:
        """Generate fresh data on-policy.

        This method performs the core computation for on-policy correlation evaluation:
        1. Generates fresh terminal states using forward rollouts with current policy
        2. Extracts log-rewards for these terminal states
        3. Transforms the log-rewards using the transform function
        4. Returns the terminal states and transformed log-rewards

        Args:
            metrics_state: Current metric state (largely ignored for on-policy)
            rng_key: Random number generator key for sampling
            args: ProcessArgs object containing policy parameters and environment parameters

        Returns:
            Tuple[TEnvState, jnp.ndarray]: A tuple containing:
                - Terminal states generated via forward rollouts
                - Transformed log-rewards corresponding to the terminal states
        """
        # First, generate fresh terminal states and rewards via forward rollouts
        _, info = forward_rollout(
            rng_key=rng_key,
            num_envs=self.n_terminal_states,
            policy_fn=self.fwd_policy_fn,
            policy_params=args.policy_params,
            env=self.env,
            env_params=args.env_params,
        )
        # Second, extract the terminal states and log-rewards
        terminal_states = info["final_env_state"]
        log_rewards = info["log_gfn_reward"]
        log_rewards_transformed = self.transform_fn(terminal_states, log_rewards)
        # Third, return the terminal states and transformed log-rewards
        return terminal_states, log_rewards_transformed

InitArgs

Bases: BaseInitArgs

Arguments for initializing the on-policy correlation metric module.

Attributes:

Name Type Description
env_params TEnvParams

Environment parameters needed to create dummy terminal states during initialization and for environment operations.

Source code in gfnx/metrics/correlation.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the on-policy correlation metric module.

    Attributes:
        env_params: Environment parameters needed to create dummy terminal states
            during initialization and for environment operations.
    """

    env_params: TEnvParams

__init__(n_rounds, n_terminal_states, batch_size, fwd_policy_fn, bwd_policy_fn, env, domain_size=None, transform_fn=lambda x, y: y)

Initialize the on-policy correlation metric module.

Parameters:

Name Type Description Default
n_rounds int

Number of sampling rounds for statistical stability

required
n_terminal_states int

Number of terminal states to generate and evaluate

required
batch_size int

Batch size for efficient processing of terminal states

required
fwd_policy_fn TPolicyFn

Forward policy function for generating trajectories

required
bwd_policy_fn TPolicyFn

Backward policy function for computing log-ratios

required
env TEnvironment

Environment for trajectory generation and reward computation

required
domain_size int | None

Number of points to compute correlations on, i.e. the output size of transform_fn. Defaults to n_terminal_states, which matches the default identity transform.

None
transform_fn Callable[[TEnvState, ndarray], ndarray]

Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation

lambda x, y: y
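To make the relationship between transform_fn and domain_size concrete, here is a hedged sketch of a marginalizing transform. It is plain Python on lists, not the gfnx API: make_bucket_transform and bucket_of are hypothetical names, and a transform used with the real module would need to operate on jnp arrays. Mass is summed within each bucket, which in log space is a log-sum-exp.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty list."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def make_bucket_transform(bucket_of, domain_size):
    """Build a transform that sums probability mass within each bucket.

    bucket_of is a hypothetical function mapping a state to an index in
    range(domain_size); it is not part of gfnx.
    """

    def transform_fn(states, log_values):
        groups = [[] for _ in range(domain_size)]
        for state, log_value in zip(states, log_values):
            groups[bucket_of(state)].append(log_value)
        # Empty buckets carry zero mass, i.e. a log-mass of -inf.
        return [logsumexp(g) if g else -math.inf for g in groups]

    return transform_fn


states = [0, 1, 2, 3]  # toy integer states
log_values = [math.log(p) for p in (0.1, 0.2, 0.3, 0.4)]
transform_fn = make_bucket_transform(lambda s: s % 2, domain_size=2)
marginal = transform_fn(states, log_values)
```

Here four terminal states collapse onto two points, so domain_size=2 would be the matching constructor argument, whereas the default identity transform keeps one point per terminal state.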
Source code in gfnx/metrics/correlation.py
def __init__(
    self,
    n_rounds: int,
    n_terminal_states: int,
    batch_size: int,
    fwd_policy_fn: TPolicyFn,
    bwd_policy_fn: TPolicyFn,
    env: TEnvironment,
    domain_size: int | None = None,
    transform_fn: Callable[[TEnvState, jnp.ndarray], jnp.ndarray] = lambda x, y: y,
):
    """Initialize the on-policy correlation metric module.

    Args:
        n_rounds: Number of sampling rounds for statistical stability
        n_terminal_states: Number of terminal states to generate and evaluate
        batch_size: Batch size for efficient processing of terminal states
        fwd_policy_fn: Forward policy function for generating trajectories
        bwd_policy_fn: Backward policy function for computing log-ratios
        env: Environment for trajectory generation and reward computation
        domain_size: Number of points to compute correlations on,
            i.e. the output size of transform_fn.
            Defaults to n_terminal_states, which matches the
            default identity transform.
        transform_fn: Function to marginalize distributions from terminal states
            to an arbitrary domain size for correlation computation
    """
    super().__init__(
        env=env,
        bwd_policy_fn=bwd_policy_fn,
        n_rounds=n_rounds,
        batch_size=batch_size,
        transform_fn=transform_fn,
    )
    self.fwd_policy_fn = fwd_policy_fn
    self.n_terminal_states = n_terminal_states
    self.domain_size = domain_size or n_terminal_states
    assert n_terminal_states % batch_size == 0, (
        f"n_terminal_states ({n_terminal_states}) must be divisible"
        f" by batch_size ({batch_size})"
    )

init(rng_key, args)

Initialize the on-policy correlation metric state.

Creates an initial state with placeholder data structures. The actual terminal states and rewards will be generated during the process phase using on-policy sampling.

The arrays are initialized with domain_size, which represents the size of the transformed probability space after marginalization. This allows the correlation metrics to be computed on an arbitrary domain rather than just the terminal states.

Parameters:

Name Type Description Default
rng_key PRNGKey

Random number generator key (unused in initialization)

required
args InitArgs

InitArgs object containing environment parameters

required

Returns:

Name Type Description
CorrelationMetricsState CorrelationMetricsState

Initialized state with dummy terminal states and zero-initialized arrays for rewards and log-ratios with shape (domain_size,) representing the transformed probability space

Source code in gfnx/metrics/correlation.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> CorrelationMetricsState:
    """Initialize the on-policy correlation metric state.

    Creates an initial state with placeholder data structures. The actual
    terminal states and rewards will be generated during the process phase
    using on-policy sampling.

    The arrays are initialized with domain_size, which represents the size of the
    transformed probability space after marginalization. This allows the correlation
    metrics to be computed on an arbitrary domain rather than just the terminal states.

    Args:
        rng_key: Random number generator key (unused in initialization)
        args: InitArgs object containing environment parameters

    Returns:
        CorrelationMetricsState: Initialized state with dummy terminal states
            and zero-initialized arrays for rewards and log-ratios with shape
            (domain_size,) representing the transformed probability space
    """
    dummy_terminal_states = self.env.reset(self.n_terminal_states, args.env_params)[1]
    return CorrelationMetricsState(
        test_terminal_states=dummy_terminal_states,
        test_log_rewards_transformed=jnp.zeros(self.domain_size),
        log_ratio_traj_transformed=jnp.zeros(self.domain_size),
    )

SWMeanRewardMetricsState

Bases: MetricsState

State for mean reward metrics with sliding window buffer.

This state container maintains a sliding window buffer of recent rewards to compute mean reward statistics over a fixed-size window of the most recent samples.

Attributes:

Name Type Description
reward_buffer Any

Flashbax buffer storing the most recent reward samples in a circular buffer with configurable maximum size

Source code in gfnx/metrics/reward_delta.py
@chex.dataclass
class SWMeanRewardMetricsState(MetricsState):
    """State for mean reward metrics with sliding window buffer.

    This state container maintains a sliding window buffer of recent rewards
    to compute mean reward statistics over a fixed-size window of the most
    recent samples.

    Attributes:
        reward_buffer: Flashbax buffer storing the most recent reward samples
            in a circular buffer with configurable maximum size
    """

    reward_buffer: Any

SWMeanRewardSWMetricsModule

Bases: BaseMetricsModule

Sliding window mean reward metric module for recent performance tracking.

This module computes mean reward statistics using a sliding window approach, maintaining only the most recent reward samples in a circular buffer.

Attributes:

Name Type Description
env

Environment instance that must support tractable mean reward computation

gt_mean_reward

Ground truth mean reward from the environment

buffer_size

Maximum number of rewards to keep in the sliding window

buffer_module

Flashbax buffer module for managing the sliding window
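The sliding-window behaviour can be illustrated with the standard library alone. This is a sketch of the idea only: the module itself uses a Flashbax buffer, and the SlidingWindowMean class below is a hypothetical stand-in built on collections.deque.

```python
from collections import deque


class SlidingWindowMean:
    """Keep the most recent buffer_size rewards and report their mean."""

    def __init__(self, buffer_size: int):
        if buffer_size <= 0:
            raise ValueError("buffer_size must be a positive integer")
        # A deque with maxlen drops the oldest item once it is full.
        self._window = deque(maxlen=buffer_size)

    def add(self, rewards) -> None:
        self._window.extend(rewards)

    def mean(self) -> float:
        if not self._window:
            raise ValueError("no rewards have been added yet")
        return sum(self._window) / len(self._window)


tracker = SlidingWindowMean(buffer_size=3)
tracker.add([1.0, 2.0])
tracker.add([3.0, 4.0])  # the window is full, so 1.0 is evicted
window_mean = tracker.mean()  # mean of 2.0, 3.0, 4.0
```

Only the last three rewards contribute to the mean, so older samples stop influencing the statistic once they leave the window.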

Source code in gfnx/metrics/reward_delta.py
class SWMeanRewardSWMetricsModule(BaseMetricsModule):
    """Sliding window mean reward metric module for recent performance tracking.

    This module computes mean reward statistics using a sliding window approach,
    maintaining only the most recent reward samples in a circular buffer.

    Attributes:
        env: Environment instance that must support tractable mean reward computation
        gt_mean_reward: Ground truth mean reward from the environment
        buffer_size: Maximum number of rewards to keep in the sliding window
        buffer_module: Flashbax buffer module for managing the sliding window
    """

    def __init__(self, env: TEnvironment, env_params: TEnvParams, buffer_size: int):
        """Initialize the sliding window mean reward metric module.

        Args:
            env: Environment instance that must have tractable mean reward computation
                (env.is_mean_reward_tractable must be True)
            env_params: Environment parameters needed to compute the ground truth mean reward
            buffer_size: Maximum number of reward samples to keep in the sliding window.
                Must be a positive integer.

        Raises:
            ValueError: If the environment does not support tractable mean reward computation
        """
        self.env = env
        if self.env.is_mean_reward_tractable:
            self.gt_mean_reward = self.env.get_mean_reward(env_params)
        else:
            raise ValueError("Ground truth mean reward is not tractable for this environment.")

        self.buffer_size = buffer_size
        self.buffer_module = fbx.make_item_buffer(
            max_length=buffer_size,
            min_length=1,
            sample_batch_size=1,
            add_batches=True,
        )

    InitArgs = EmptyInitArgs

    def init(
        self, rng_key: chex.PRNGKey, args: InitArgs | None = None
    ) -> SWMeanRewardMetricsState:
        """Initialize the sliding window mean reward metric state.

        Creates initial state with an empty sliding window buffer ready to
        accept reward samples. The buffer will automatically manage the
        sliding window behavior as new samples are added.

        Args:
            rng_key: JAX PRNG key for any random initialization (currently unused)
            args: EmptyInitArgs (no additional initialization parameters needed)

        Returns:
            SWMeanRewardMetricsState: Initialized state with empty sliding window buffer
        """
        buffer_state = self.buffer_module.init(jnp.array(0.0))  # Initialize with a dummy value
        return SWMeanRewardMetricsState(reward_buffer=buffer_state)

    @chex.dataclass
    class UpdateArgs(BaseUpdateArgs):
        """Arguments for updating the SWMeanRewardSWMetricsModule.

        Attributes:
            rewards: Array of reward values to add to the sliding window buffer.
                These rewards will replace the oldest entries when the buffer is full.
        """

        rewards: chex.Array

    def update(
        self, metrics_state: SWMeanRewardMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
    ) -> SWMeanRewardMetricsState:
        """Update the metric state with new reward samples in the sliding window.

        Adds new reward samples to the sliding window buffer. When the buffer is full,
        the oldest samples are automatically replaced with the new ones, maintaining
        the fixed window size.

        Args:
            metrics_state: Current metric state containing the sliding window buffer
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: UpdateArgs object containing the new reward samples

        Returns:
            SWMeanRewardMetricsState: Updated state with new rewards added to the buffer
        """
        updated_data_buffer = self.buffer_module.add(metrics_state.reward_buffer, args.rewards)
        return metrics_state.replace(reward_buffer=updated_data_buffer)

    ProcessArgs = EmptyProcessArgs

    def process(
        self,
        metrics_state: SWMeanRewardMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs | None = None,
    ) -> SWMeanRewardMetricsState:
        """Process the metric state for final computation (no-op for sliding window metrics).

        This method performs any final processing needed before metric computation.
        For sliding window mean reward metrics, no additional processing is required
        as the statistics are computed directly from the buffer contents.

        Args:
            metrics_state: Current metric state with sliding window buffer
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: EmptyProcessArgs (no additional processing parameters needed)

        Returns:
            SWMeanRewardMetricsState: Unchanged metric state ready for get() call
        """
        return metrics_state

    def get(self, metrics_state: SWMeanRewardMetricsState) -> dict[str, float]:
        """Get the computed sliding window mean reward metrics from the current state.

        Computes the mean reward from samples in the sliding window buffer and calculates
        both absolute and relative deviations from the ground truth mean reward. Only
        valid (non-empty) buffer entries are included in the computation.

        Args:
            metrics_state: Current metric state containing the sliding window buffer

        Returns:
            Dict[str, float]: Dictionary containing computed reward metrics:
                - 'mean_reward': Mean reward from samples in the sliding window
                - 'reward_delta': Absolute difference between window mean and ground truth
                - 'rel_reward_delta': Relative difference (normalized by ground truth mean)
        """
        buffer_state = metrics_state.reward_buffer
        all_rewards = metrics_state.reward_buffer.experience[0]
        indices = jnp.arange(all_rewards.shape[0])
        valid_mask = jnp.array(
            jnp.logical_or(buffer_state.is_full, indices < buffer_state.current_index),
            dtype=jnp.float32,
        )
        num_valid = jnp.sum(valid_mask)
        mean_reward = jnp.sum(buffer_state.experience * valid_mask) / jnp.maximum(num_valid, 1)
        reward_delta = abs(mean_reward - self.gt_mean_reward)
        rel_reward_delta = reward_delta / self.gt_mean_reward
        return {
            "mean_reward": mean_reward,
            "reward_delta": reward_delta,
            "rel_reward_delta": rel_reward_delta,
        }

UpdateArgs

Bases: BaseUpdateArgs

Arguments for updating the SWMeanRewardSWMetricsModule.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| rewards | Array | Array of reward values to add to the sliding window buffer. These rewards will replace the oldest entries when the buffer is full. |

Source code in gfnx/metrics/reward_delta.py
@chex.dataclass
class UpdateArgs(BaseUpdateArgs):
    """Arguments for updating the SWMeanRewardSWMetricsModule.

    Attributes:
        rewards: Array of reward values to add to the sliding window buffer.
            These rewards will replace the oldest entries when the buffer is full.
    """

    rewards: chex.Array

__init__(env, env_params, buffer_size)

Initialize the sliding window mean reward metric module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| env | TEnvironment | Environment instance that must have tractable mean reward computation (env.is_mean_reward_tractable must be True) | required |
| env_params | TEnvParams | Environment parameters needed to compute the ground truth mean reward | required |
| buffer_size | int | Maximum number of reward samples to keep in the sliding window. Must be a positive integer. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the environment does not support tractable mean reward computation |

Source code in gfnx/metrics/reward_delta.py
def __init__(self, env: TEnvironment, env_params: TEnvParams, buffer_size: int):
    """Initialize the sliding window mean reward metric module.

    Args:
        env: Environment instance that must have tractable mean reward computation
            (env.is_mean_reward_tractable must be True)
        env_params: Environment parameters needed to compute the ground truth mean reward
        buffer_size: Maximum number of reward samples to keep in the sliding window.
            Must be a positive integer.

    Raises:
        ValueError: If the environment does not support tractable mean reward computation
    """
    self.env = env
    if self.env.is_mean_reward_tractable:
        self.gt_mean_reward = self.env.get_mean_reward(env_params)
    else:
        raise ValueError("Ground truth mean reward is not tractable for this environment.")

    self.buffer_size = buffer_size
    self.buffer_module = fbx.make_item_buffer(
        max_length=buffer_size,
        min_length=1,
        sample_batch_size=1,
        add_batches=True,
    )
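
The tractability check in the constructor can be tried in isolation. The sketch below is illustrative only: `StubEnv` and `resolve_gt_mean_reward` are hypothetical names introduced here, not part of gfnx, and the stub simply stores a number instead of deriving it from `env_params`.

```python
class StubEnv:
    """Hypothetical stand-in for an environment; not a gfnx class."""

    def __init__(self, tractable: bool, mean_reward: float = 0.0):
        self.is_mean_reward_tractable = tractable
        self._mean_reward = mean_reward

    def get_mean_reward(self, env_params):
        # A real environment would derive this value from env_params.
        return self._mean_reward


def resolve_gt_mean_reward(env, env_params):
    """Return the ground truth mean reward, or fail fast if it is unavailable."""
    if not env.is_mean_reward_tractable:
        raise ValueError("Ground truth mean reward is not tractable for this environment.")
    return env.get_mean_reward(env_params)


# Tractable environment: the reference value is returned.
gt_mean = resolve_gt_mean_reward(StubEnv(tractable=True, mean_reward=2.5), env_params=None)

# Intractable environment: construction of the metric would stop here.
try:
    resolve_gt_mean_reward(StubEnv(tractable=False), env_params=None)
    raised = False
except ValueError:
    raised = True
```

Failing in the constructor means a misconfigured metric is reported before any training step runs, rather than at the first `get()` call.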

get(metrics_state)

Get the computed sliding window mean reward metrics from the current state.

Computes the mean reward from samples in the sliding window buffer and calculates both absolute and relative deviations from the ground truth mean reward. Only valid (non-empty) buffer entries are included in the computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | SWMeanRewardMetricsState | Current metric state containing the sliding window buffer | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, float] | Dictionary containing the computed reward metrics listed below |

- 'mean_reward': Mean reward from samples in the sliding window
- 'reward_delta': Absolute difference between window mean and ground truth
- 'rel_reward_delta': Relative difference (normalized by ground truth mean)

Source code in gfnx/metrics/reward_delta.py
def get(self, metrics_state: SWMeanRewardMetricsState) -> dict[str, float]:
    """Get the computed sliding window mean reward metrics from the current state.

    Computes the mean reward from samples in the sliding window buffer and calculates
    both absolute and relative deviations from the ground truth mean reward. Only
    valid (non-empty) buffer entries are included in the computation.

    Args:
        metrics_state: Current metric state containing the sliding window buffer

    Returns:
        Dict[str, float]: Dictionary containing computed reward metrics:
            - 'mean_reward': Mean reward from samples in the sliding window
            - 'reward_delta': Absolute difference between window mean and ground truth
            - 'rel_reward_delta': Relative difference (normalized by ground truth mean)
    """
    buffer_state = metrics_state.reward_buffer
    all_rewards = metrics_state.reward_buffer.experience[0]
    indices = jnp.arange(all_rewards.shape[0])
    valid_mask = jnp.array(
        jnp.logical_or(buffer_state.is_full, indices < buffer_state.current_index),
        dtype=jnp.float32,
    )
    num_valid = jnp.sum(valid_mask)
    mean_reward = jnp.sum(buffer_state.experience * valid_mask) / jnp.maximum(num_valid, 1)
    reward_delta = abs(mean_reward - self.gt_mean_reward)
    rel_reward_delta = reward_delta / self.gt_mean_reward
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        "rel_reward_delta": rel_reward_delta,
    }
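
The masking logic in `get` can be followed with plain Python lists. This is a minimal sketch of the same arithmetic, not the library code: `sliding_window_stats` is a hypothetical helper, and the buffer is modelled as a list plus a write cursor rather than a Flashbax state.

```python
def sliding_window_stats(buffer, current_index, is_full, gt_mean_reward):
    """Masked mean over a fixed-size buffer plus deviations from a reference mean."""
    # A slot is valid if the buffer has wrapped around, or if it lies before the write cursor.
    valid = [is_full or i < current_index for i in range(len(buffer))]
    num_valid = sum(valid)
    # max(..., 1) guards against division by zero when nothing has been written yet.
    mean_reward = sum(r for r, v in zip(buffer, valid) if v) / max(num_valid, 1)
    reward_delta = abs(mean_reward - gt_mean_reward)
    rel_reward_delta = reward_delta / gt_mean_reward
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        "rel_reward_delta": rel_reward_delta,
    }


# Capacity 4 with two rewards written so far; the trailing zeros are unwritten slots.
partial = sliding_window_stats([1.0, 3.0, 0.0, 0.0], current_index=2, is_full=False, gt_mean_reward=4.0)

# Once the buffer is full, every slot counts regardless of the cursor position.
full = sliding_window_stats([1.0, 3.0, 5.0, 7.0], current_index=1, is_full=True, gt_mean_reward=4.0)
```

Without the mask, the two unwritten zeros in the first call would drag the mean down from 2.0 to 1.0. Note that the relative delta is undefined when the ground truth mean is zero.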

init(rng_key, args=None)

Initialize the sliding window mean reward metric state.

Creates initial state with an empty sliding window buffer ready to accept reward samples. The buffer will automatically manage the sliding window behavior as new samples are added.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs \| None | EmptyInitArgs (no additional initialization parameters needed) | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SWMeanRewardMetricsState | SWMeanRewardMetricsState | Initialized state with empty sliding window buffer |

Source code in gfnx/metrics/reward_delta.py
def init(
    self, rng_key: chex.PRNGKey, args: InitArgs | None = None
) -> SWMeanRewardMetricsState:
    """Initialize the sliding window mean reward metric state.

    Creates initial state with an empty sliding window buffer ready to
    accept reward samples. The buffer will automatically manage the
    sliding window behavior as new samples are added.

    Args:
        rng_key: JAX PRNG key for any random initialization (currently unused)
        args: EmptyInitArgs (no additional initialization parameters needed)

    Returns:
        SWMeanRewardMetricsState: Initialized state with empty sliding window buffer
    """
    buffer_state = self.buffer_module.init(jnp.array(0.0))  # Initialize with a dummy value
    return SWMeanRewardMetricsState(reward_buffer=buffer_state)

process(metrics_state, rng_key, args=None)

Process the metric state for final computation (no-op for sliding window metrics).

This method performs any final processing needed before metric computation. For sliding window mean reward metrics, no additional processing is required as the statistics are computed directly from the buffer contents.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | SWMeanRewardMetricsState | Current metric state with sliding window buffer | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | ProcessArgs \| None | EmptyProcessArgs (no additional processing parameters needed) | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SWMeanRewardMetricsState | SWMeanRewardMetricsState | Unchanged metric state ready for get() call |

Source code in gfnx/metrics/reward_delta.py
def process(
    self,
    metrics_state: SWMeanRewardMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs | None = None,
) -> SWMeanRewardMetricsState:
    """Process the metric state for final computation (no-op for sliding window metrics).

    This method performs any final processing needed before metric computation.
    For sliding window mean reward metrics, no additional processing is required
    as the statistics are computed directly from the buffer contents.

    Args:
        metrics_state: Current metric state with sliding window buffer
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: EmptyProcessArgs (no additional processing parameters needed)

    Returns:
        SWMeanRewardMetricsState: Unchanged metric state ready for get() call
    """
    return metrics_state

update(metrics_state, rng_key, args)

Update the metric state with new reward samples in the sliding window.

Adds new reward samples to the sliding window buffer. When the buffer is full, the oldest samples are automatically replaced with the new ones, maintaining the fixed window size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| metrics_state | SWMeanRewardMetricsState | Current metric state containing the sliding window buffer | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs | UpdateArgs object containing the new reward samples | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SWMeanRewardMetricsState | SWMeanRewardMetricsState | Updated state with new rewards added to the buffer |

Source code in gfnx/metrics/reward_delta.py
def update(
    self, metrics_state: SWMeanRewardMetricsState, rng_key: chex.PRNGKey, args: UpdateArgs
) -> SWMeanRewardMetricsState:
    """Update the metric state with new reward samples in the sliding window.

    Adds new reward samples to the sliding window buffer. When the buffer is full,
    the oldest samples are automatically replaced with the new ones, maintaining
    the fixed window size.

    Args:
        metrics_state: Current metric state containing the sliding window buffer
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: UpdateArgs object containing the new reward samples

    Returns:
        SWMeanRewardMetricsState: Updated state with new rewards added to the buffer
    """
    updated_data_buffer = self.buffer_module.add(metrics_state.reward_buffer, args.rewards)
    return metrics_state.replace(reward_buffer=updated_data_buffer)
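
The retention behaviour described for `update` (new samples push out the oldest once the window is full) can be illustrated with the standard library. This sketch only mimics which samples survive; it makes no claim about how the Flashbax buffer lays out its storage, and the window size of 3 is an arbitrary choice for the example.

```python
from collections import deque

# A bounded deque drops items from the left once maxlen is reached.
window = deque(maxlen=3)

window.extend([0.1, 0.2])        # first batch: window not yet full
after_first = list(window)

window.extend([0.3, 0.4])        # second batch: 0.1 is evicted to make room
after_second = list(window)

# Mean over whatever the window currently holds.
window_mean = sum(window) / len(window)
```

After the second batch the window holds the three most recent rewards, so any statistic computed from it reflects recent policy behaviour rather than the whole training history.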

TestCorrelationMetricsModule

Bases: BaseCorrelationMetricsModule

Fixed test set correlation metric module for GFlowNet evaluation.

This metric module computes correlation metrics between transformed model predictions and transformed true rewards using a fixed set of test terminal states. Unlike the on-policy variant, this module evaluates the model's performance on the same set of terminal states across different training iterations, providing consistent evaluation points for tracking training progress.

Source code in gfnx/metrics/correlation.py
class TestCorrelationMetricsModule(BaseCorrelationMetricsModule):
    """Fixed test set correlation metric module for GFlowNet evaluation.

    This metric module computes correlation metrics between transformed model predictions
    and transformed true rewards using a fixed set of test terminal states. Unlike the
    on-policy variant, this module evaluates the model's performance on the same set of
    terminal states across different training iterations, providing consistent evaluation
    points for tracking training progress.
    """

    @chex.dataclass
    class InitArgs(BaseInitArgs):
        """Arguments for initializing the test correlation metric module.

        Attributes:
            env_params: Environment parameters needed for computing log-rewards
                from the test set and for environment operations.
            test_set: Fixed set of terminal states that will be used consistently
                for correlation evaluation across training iterations.
        """

        env_params: TEnvParams
        test_set: TEnvState

    def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> CorrelationMetricsState:
        """Initialize metric state with a fixed test set.

        Sets up the metric state using a predefined set of terminal states.
        The log-rewards are computed immediately from the test set, while
        log-ratios are initialized to zero and will be computed during processing.

        Args:
            rng_key: Random number generator key (unused in initialization)
            args: InitArgs object containing environment parameters and test set

        Returns:
            CorrelationMetricsState: Initialized state with fixed test terminal
                states, their computed transformed log-rewards, and zero-initialized
                log-ratios with shape matching the transformed log-rewards
        """
        test_log_rewards = self.env.reward_module.log_reward(args.test_set, args.env_params)
        test_log_rewards_transformed = self.transform_fn(args.test_set, test_log_rewards)
        return CorrelationMetricsState(
            test_terminal_states=args.test_set,
            test_log_rewards_transformed=test_log_rewards_transformed,
            log_ratio_traj_transformed=jnp.zeros_like(test_log_rewards_transformed),
        )

    def _get_states_and_rewards(
        self,
        metrics_state: CorrelationMetricsState,
        rng_key: chex.PRNGKey,
        args: BaseCorrelationMetricsModule.ProcessArgs,
    ) -> tuple[TEnvState, jnp.ndarray]:
        """Return existing test data from the metrics state.

        This method simply returns the terminal states and transformed log-rewards
        that were computed during initialization. No new data generation is needed
        since this module uses a fixed test set.

        Args:
            metrics_state: Current metric state containing the fixed test set data
            rng_key: Random number generator key (unused for fixed test set)
            args: ProcessArgs object (unused for fixed test set)

        Returns:
            Tuple[TEnvState, jnp.ndarray]: A tuple containing:
                - Terminal states from the fixed test set
                - Transformed log-rewards corresponding to the terminal states
        """
        return metrics_state.test_terminal_states, metrics_state.test_log_rewards_transformed

InitArgs

Bases: BaseInitArgs

Arguments for initializing the test correlation metric module.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| env_params | TEnvParams | Environment parameters needed for computing log-rewards from the test set and for environment operations. |
| test_set | TEnvState | Fixed set of terminal states that will be used consistently for correlation evaluation across training iterations. |

Source code in gfnx/metrics/correlation.py
@chex.dataclass
class InitArgs(BaseInitArgs):
    """Arguments for initializing the test correlation metric module.

    Attributes:
        env_params: Environment parameters needed for computing log-rewards
            from the test set and for environment operations.
        test_set: Fixed set of terminal states that will be used consistently
            for correlation evaluation across training iterations.
    """

    env_params: TEnvParams
    test_set: TEnvState

init(rng_key, args)

Initialize metric state with a fixed test set.

Sets up the metric state using a predefined set of terminal states. The log-rewards are computed immediately from the test set, while log-ratios are initialized to zero and will be computed during processing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rng_key | PRNGKey | Random number generator key (unused in initialization) | required |
| args | InitArgs | InitArgs object containing environment parameters and test set | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| CorrelationMetricsState | CorrelationMetricsState | Initialized state with fixed test terminal states, their computed transformed log-rewards, and zero-initialized log-ratios with shape matching the transformed log-rewards |

Source code in gfnx/metrics/correlation.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs) -> CorrelationMetricsState:
    """Initialize metric state with a fixed test set.

    Sets up the metric state using a predefined set of terminal states.
    The log-rewards are computed immediately from the test set, while
    log-ratios are initialized to zero and will be computed during processing.

    Args:
        rng_key: Random number generator key (unused in initialization)
        args: InitArgs object containing environment parameters and test set

    Returns:
        CorrelationMetricsState: Initialized state with fixed test terminal
            states, their computed transformed log-rewards, and zero-initialized
            log-ratios with shape matching the transformed log-rewards
    """
    test_log_rewards = self.env.reward_module.log_reward(args.test_set, args.env_params)
    test_log_rewards_transformed = self.transform_fn(args.test_set, test_log_rewards)
    return CorrelationMetricsState(
        test_terminal_states=args.test_set,
        test_log_rewards_transformed=test_log_rewards_transformed,
        log_ratio_traj_transformed=jnp.zeros_like(test_log_rewards_transformed),
    )
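
The value of a fixed test set is that the targets are computed once and only the model side changes between evaluations. The sketch below illustrates that idea with a hand-written Pearson correlation; the choice of Pearson, the helper name `pearson`, and all numbers are assumptions made for the example, since this section does not state which correlation statistics the base class reports.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Targets for the fixed test set: computed once, reused at every evaluation.
test_log_rewards = [0.0, 1.0, 2.0, 3.0]

# Hypothetical model outputs on the same states at two training checkpoints.
early_predictions = [0.5, 0.4, 0.6, 0.5]
late_predictions = [0.1, 1.1, 2.1, 3.1]

corr_early = pearson(test_log_rewards, early_predictions)
corr_late = pearson(test_log_rewards, late_predictions)
```

Because both checkpoints are scored against identical targets, the rise from `corr_early` to `corr_late` can be attributed to the model rather than to a change in the evaluation sample.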

TopKMetricsModule

Bases: BaseMetricsModule

Metric module for computing top-K reward and diversity statistics.

This module evaluates policy performance by sampling trajectories and computing statistics for the top-K highest-reward samples. It measures both the quality (reward) and diversity of the best samples, providing insights into the policy's ability to find high-reward solutions while maintaining diversity.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| env | | Environment instance for trajectory generation and reward computation |
| fwd_policy_fn | | Forward policy function for generating trajectories |
| num_traj | | Total number of trajectories to sample for evaluation |
| batch_size | | Batch size for processing trajectories (for memory management) |
| top_k | | List of K values for which to compute top-K statistics |
| distance_fn | | Function to compute distance between states for diversity measurement |

Source code in gfnx/metrics/top_k.py
class TopKMetricsModule(BaseMetricsModule):
    """Metric module for computing top-K reward and diversity statistics.

    This module evaluates policy performance by sampling trajectories and computing
    statistics for the top-K highest-reward samples. It measures both the quality
    (reward) and diversity of the best samples, providing insights into the policy's
    ability to find high-reward solutions while maintaining diversity.

    Attributes:
        env: Environment instance for trajectory generation and reward computation
        fwd_policy_fn: Forward policy function for generating trajectories
        num_traj: Total number of trajectories to sample for evaluation
        batch_size: Batch size for processing trajectories (for memory management)
        top_k: List of K values for which to compute top-K statistics
        distance_fn: Function to compute distance between states for diversity measurement
    """

    def __init__(
        self,
        env: TEnvironment,
        fwd_policy_fn: TPolicyFn,
        num_traj: int,
        batch_size: int,
        top_k: list[int] | None = None,
        distance_fn: Callable[[TEnvState, TEnvState], float] | None = None,
    ):
        """Initialize the top-K metrics module.

        Args:
            env: Environment instance for trajectory generation and reward computation
            fwd_policy_fn: Forward policy function for generating trajectory samples
            num_traj: Total number of trajectories to sample during evaluation.
                Must be >= max(top_k) for meaningful statistics.
            batch_size: Batch size for processing trajectories (used for memory management)
            top_k: List of K values for which to compute top-K statistics.
                Default is [10, 50, 100].
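            distance_fn: Function that computes distance between two environment states
                for diversity measurement. Must return a scalar distance value.
                Although the default is None, process() calls it unconditionally,
                so a callable must be supplied before calling process().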
        """
        if top_k is None:
            top_k = [10, 50, 100]
        self.num_traj = num_traj
        self.batch_size = batch_size
        self.env = env
        self.fwd_policy_fn = fwd_policy_fn
        self.top_k = top_k
        self.distance_fn = distance_fn

    def _get_distance_matrix(self, lhs_states: TEnvState, rhs_states: TEnvState) -> jnp.ndarray:
        """Compute pairwise distance matrix between two sets of states.

        Computes all pairwise distances between states in lhs_states and rhs_states
        using the configured distance function. This is used to measure diversity
        among the top-K samples by computing distances between all pairs.

        Args:
            lhs_states: First set of states, N states
            rhs_states: Second set of states, M states

        Returns:
            jnp.ndarray: Distance matrix of shape (N, M) where entry (i,j) contains
                the distance between lhs_states[i] and rhs_states[j]
        """
        result = jax.vmap(
            lambda lhs_state, rhs_states: jax.vmap(
                lambda rhs_state: self.distance_fn(lhs_state, rhs_state)
            )(rhs_states),
            in_axes=(0, None),
        )(lhs_states, rhs_states)
        chex.assert_shape(result, (lhs_states.is_pad.shape[0], rhs_states.is_pad.shape[0]))
        return result

    InitArgs = EmptyInitArgs

    def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> TopKMetricsState:
        """Initialize the top-K metrics state.

        Creates initial state with zero-initialized arrays for top-K rewards and
        diversity statistics. The actual values will be computed during the process phase.

        Args:
            rng_key: JAX PRNG key for any random initialization (currently unused)
            args: EmptyInitArgs (no additional initialization parameters needed)

        Returns:
            TopKMetricsState: Initialized state with zero arrays for rewards and diversity
        """
        return TopKMetricsState(
            top_k_rewards=jnp.zeros((len(self.top_k),), dtype=jnp.float32),
            top_k_diversity=jnp.zeros((len(self.top_k),), dtype=jnp.float32),
        )

    UpdateArgs = EmptyUpdateArgs

    def update(
        self,
        metrics_state: TopKMetricsState,
        rng_key: chex.PRNGKey,
        args: UpdateArgs | None = None,
    ) -> TopKMetricsState:
        """Update the metric state with new data (no-op for top-K metrics).

        Top-K metrics are computed entirely during the process phase by sampling
        fresh trajectories, so no incremental updates are needed.

        Args:
            metrics_state: Current metric state (unchanged)
            rng_key: JAX PRNG key for any random operations (currently unused)
            args: EmptyUpdateArgs (no update parameters needed)

        Returns:
            TopKMetricsState: Unchanged metric state
        """
        return metrics_state

    @chex.dataclass
    class ProcessArgs(BaseProcessArgs):
        """Arguments for processing the TopKMetricsModule.

        Attributes:
            policy_params: Current policy parameters used for forward rollouts
                to generate trajectory samples for top-K evaluation.
            env_params: Environment parameters required for trajectory generation
                and reward computation.
        """

        policy_params: TPolicyParams
        env_params: Any

    def process(
        self,
        metrics_state: TopKMetricsState,
        rng_key: chex.PRNGKey,
        args: ProcessArgs,
    ) -> TopKMetricsState:
        """Process the metric state to compute top-K reward and diversity statistics.

        This method performs the core computation for top-K metrics:
        1. Samples trajectories using the current policy
        2. Computes rewards for all terminal states
        3. Identifies the top-K highest-reward samples for each K value
        4. Computes mean rewards and diversity measures for each top-K set

        Args:
            metrics_state: Current metric state (largely ignored, will be replaced)
            rng_key: JAX PRNG key for trajectory sampling
            args: ProcessArgs object containing policy parameters and environment parameters

        Returns:
            TopKMetricsState: Updated state with computed top-K rewards and diversity statistics
        """
        # Sample a batch of trajectories
        _, info = forward_rollout(
            rng_key,
            num_envs=self.num_traj,
            policy_fn=self.fwd_policy_fn,
            policy_params=args.policy_params,
            env=self.env,
            env_params=args.env_params,
        )
        final_env_state = info["final_env_state"]
        rewards = jnp.exp(info["log_gfn_reward"])
        chex.assert_shape(rewards, (self.num_traj,))
        arg_sort_idx = jnp.argsort(rewards)

        # Compute top-k rewards and top-k diversity
        topk_rew = jnp.zeros((len(self.top_k),), dtype=jnp.float32)
        topk_div = jnp.zeros((len(self.top_k),), dtype=jnp.float32)

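        for idx, k in enumerate(self.top_k):
            # argsort is ascending, so the last k indices are the k highest rewards.
            top_idx = arg_sort_idx[-k:]
            topk_rew = topk_rew.at[idx].set(jnp.mean(rewards[top_idx]))
            top_samples = jax.tree.map(lambda x, top_idx=top_idx: x[top_idx], final_env_state)
            # Number of off-diagonal pairs; this is zero when k == 1, so use k >= 2.
            num_nonzero_dist = (k - 1) * k
            distance_matrix = self._get_distance_matrix(top_samples, top_samples)
            # Mean pairwise distance, assuming distance_fn(x, x) == 0 on the diagonal.
            topk_div = topk_div.at[idx].set(distance_matrix.sum() / num_nonzero_dist)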

        return metrics_state.replace(
            top_k_rewards=topk_rew,
            top_k_diversity=topk_div,
        )

    def get(self, metrics_state: TopKMetricsState) -> dict:
        """Get the computed top-K metrics from the current state.

        Extracts the computed top-K reward and diversity statistics and formats them
        into a dictionary with descriptive keys for each K value specified during
        module initialization.

        Args:
            metrics_state: Current metric state containing computed top-K statistics

        Returns:
            Dict[str, float]: Dictionary containing computed top-K metrics with keys:
                - 'top_{k}_reward': Mean reward of the top-K samples for each K
                - 'top_{k}_diversity': Mean pairwise distance among top-K samples for each K

        Example:
            If initialized with top_k=[10, 50], might return:
            {
                "top_10_reward": 0.85,
                "top_10_diversity": 0.42,
                "top_50_reward": 0.78,
                "top_50_diversity": 0.51
            }
        """
        reward_dict = {
            f"top_{k}_reward": metrics_state.top_k_rewards[idx] for idx, k in enumerate(self.top_k)
        }
        diversity_dict = {
            f"top_{k}_diversity": metrics_state.top_k_diversity[idx]
            for idx, k in enumerate(self.top_k)
        }
        return {**reward_dict, **diversity_dict}
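The top-K selection in `process` relies on an ascending argsort, so the K best samples are the last K indices. The following is a minimal plain-Python sketch of that step only, not the library's JAX code; the helper names `argsort` and `top_k_mean_rewards` are hypothetical and exist only for this illustration.

```python
def argsort(values):
    """Return the indices that sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


def top_k_mean_rewards(rewards, top_k):
    """Mean reward of the k highest-reward samples, for each k in `top_k`."""
    order = argsort(rewards)  # ascending, so the best samples sit at the end
    means = []
    for k in top_k:
        top_idx = order[-k:]  # last k indices correspond to the k largest rewards
        means.append(sum(rewards[i] for i in top_idx) / len(top_idx))
    return means


rewards = [0.1, 0.9, 0.4, 0.7, 0.2]
means = top_k_mean_rewards(rewards, top_k=[1, 3])
print(means)
```

Because every K slices the same sorted index array, the top-K sets are nested: the samples selected for a smaller K are always contained in those selected for a larger K.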

ProcessArgs

Bases: BaseProcessArgs

Arguments for processing the TopKMetricsModule.

Attributes:

Name Type Description
policy_params TPolicyParams

Current policy parameters used for forward rollouts to generate trajectory samples for top-K evaluation.

env_params Any

Environment parameters required for trajectory generation and reward computation.

Source code in gfnx/metrics/top_k.py
@chex.dataclass
class ProcessArgs(BaseProcessArgs):
    """Arguments for processing the TopKMetricsModule.

    Attributes:
        policy_params: Current policy parameters used for forward rollouts
            to generate trajectory samples for top-K evaluation.
        env_params: Environment parameters required for trajectory generation
            and reward computation.
    """

    policy_params: TPolicyParams
    env_params: Any

__init__(env, fwd_policy_fn, num_traj, batch_size, top_k=None, distance_fn=None)

Initialize the top-K metrics module.

Parameters:

Name Type Description Default
env TEnvironment

Environment instance for trajectory generation and reward computation

required
fwd_policy_fn TPolicyFn

Forward policy function for generating trajectory samples

required
num_traj int

Total number of trajectories to sample during evaluation. Must be >= max(top_k) for meaningful statistics.

required
batch_size int

Batch size for processing trajectories (used for memory management)

required
top_k list[int] | None

List of K values for which to compute top-K statistics. Default is [10, 50, 100].

None
distance_fn Callable[[TEnvState, TEnvState], float] | None

Function that computes distance between two environment states for diversity measurement. Must return a scalar distance value.

None
Source code in gfnx/metrics/top_k.py
def __init__(
    self,
    env: TEnvironment,
    fwd_policy_fn: TPolicyFn,
    num_traj: int,
    batch_size: int,
    top_k: list[int] | None = None,
    distance_fn: Callable[[TEnvState, TEnvState], float] | None = None,
):
    """Initialize the top-K metrics module.

    Args:
        env: Environment instance for trajectory generation and reward computation
        fwd_policy_fn: Forward policy function for generating trajectory samples
        num_traj: Total number of trajectories to sample during evaluation.
            Must be >= max(top_k) for meaningful statistics.
        batch_size: Batch size for processing trajectories (used for memory management)
        top_k: List of K values for which to compute top-K statistics.
            Default is [10, 50, 100].
        distance_fn: Function that computes distance between two environment states
            for diversity measurement. Must return a scalar distance value.
    """
    if top_k is None:
        top_k = [10, 50, 100]
    self.num_traj = num_traj
    self.batch_size = batch_size
    self.env = env
    self.fwd_policy_fn = fwd_policy_fn
    self.top_k = top_k
    self.distance_fn = distance_fn
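The constructor shown above stores its arguments without checking them, although the docstring asks for `num_traj >= max(top_k)` and the diversity denominator `(k - 1) * k` used later is zero for `k = 1`. A caller who wants to catch such configurations early could run a check like the one below before building the module. This is a hedged sketch: `validate_top_k` is a hypothetical helper and not part of gfnx.

```python
def validate_top_k(top_k, num_traj):
    """Hypothetical pre-flight check for top-K settings (not part of gfnx)."""
    if top_k is None:
        top_k = [10, 50, 100]  # same default as the constructor
    for k in top_k:
        if k < 2:
            # (k - 1) * k == 0, so the mean pairwise distance is undefined
            raise ValueError(f"k={k}: diversity needs at least 2 samples")
        if k > num_traj:
            # fewer than k samples would be available for the top-k slice
            raise ValueError(f"k={k} exceeds num_traj={num_traj}")
    return list(top_k)


checked = validate_top_k(None, num_traj=128)
print(checked)
```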

get(metrics_state)

Get the computed top-K metrics from the current state.

Extracts the computed top-K reward and diversity statistics and formats them into a dictionary with descriptive keys for each K value specified during module initialization.

Parameters:

Name Type Description Default
metrics_state TopKMetricsState

Current metric state containing computed top-K statistics

required

Returns:

Type Description
dict

Dict[str, float]: Dictionary containing computed top-K metrics with keys:

- 'top_{k}_reward': Mean reward of the top-K samples for each K
- 'top_{k}_diversity': Mean pairwise distance among top-K samples for each K

Example

If initialized with top_k=[10, 50], might return:

    {
        "top_10_reward": 0.85,
        "top_10_diversity": 0.42,
        "top_50_reward": 0.78,
        "top_50_diversity": 0.51
    }

Source code in gfnx/metrics/top_k.py
def get(self, metrics_state: TopKMetricsState) -> dict:
    """Get the computed top-K metrics from the current state.

    Extracts the computed top-K reward and diversity statistics and formats them
    into a dictionary with descriptive keys for each K value specified during
    module initialization.

    Args:
        metrics_state: Current metric state containing computed top-K statistics

    Returns:
        Dict[str, float]: Dictionary containing computed top-K metrics with keys:
            - 'top_{k}_reward': Mean reward of the top-K samples for each K
            - 'top_{k}_diversity': Mean pairwise distance among top-K samples for each K

    Example:
        If initialized with top_k=[10, 50], might return:
        {
            "top_10_reward": 0.85,
            "top_10_diversity": 0.42,
            "top_50_reward": 0.78,
            "top_50_diversity": 0.51
        }
    """
    reward_dict = {
        f"top_{k}_reward": metrics_state.top_k_rewards[idx] for idx, k in enumerate(self.top_k)
    }
    diversity_dict = {
        f"top_{k}_diversity": metrics_state.top_k_diversity[idx]
        for idx, k in enumerate(self.top_k)
    }
    return {**reward_dict, **diversity_dict}
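The key formatting in `get` can be tried in isolation. The sketch below mirrors the two dict comprehensions with plain Python lists standing in for the state arrays; `format_top_k_metrics` is a hypothetical free function written for this illustration, and the numbers are the ones from the docstring example.

```python
def format_top_k_metrics(top_k, rewards, diversity):
    """Build the metrics dict the same way `get` does, from plain sequences."""
    reward_dict = {f"top_{k}_reward": rewards[i] for i, k in enumerate(top_k)}
    diversity_dict = {f"top_{k}_diversity": diversity[i] for i, k in enumerate(top_k)}
    # Merging keeps insertion order: all reward keys, then all diversity keys
    return {**reward_dict, **diversity_dict}


metrics = format_top_k_metrics([10, 50], [0.85, 0.78], [0.42, 0.51])
for name, value in metrics.items():
    print(name, value)
```

Note that the merged dictionary lists every reward key before any diversity key, so the iteration order differs from the interleaved layout in the docstring example. In the real module the values are indexed out of JAX arrays, so a logger that expects Python floats may need an explicit `float(...)` conversion.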

init(rng_key, args=None)

Initialize the top-K metrics state.

Creates initial state with zero-initialized arrays for top-K rewards and diversity statistics. The actual values will be computed during the process phase.

Parameters:

Name Type Description Default
rng_key PRNGKey

JAX PRNG key for any random initialization (currently unused)

required
args InitArgs | None

EmptyInitArgs (no additional initialization parameters needed)

None

Returns:

Name Type Description
TopKMetricsState TopKMetricsState

Initialized state with zero arrays for rewards and diversity

Source code in gfnx/metrics/top_k.py
def init(self, rng_key: chex.PRNGKey, args: InitArgs | None = None) -> TopKMetricsState:
    """Initialize the top-K metrics state.

    Creates initial state with zero-initialized arrays for top-K rewards and
    diversity statistics. The actual values will be computed during the process phase.

    Args:
        rng_key: JAX PRNG key for any random initialization (currently unused)
        args: EmptyInitArgs (no additional initialization parameters needed)

    Returns:
        TopKMetricsState: Initialized state with zero arrays for rewards and diversity
    """
    return TopKMetricsState(
        top_k_rewards=jnp.zeros((len(self.top_k),), dtype=jnp.float32),
        top_k_diversity=jnp.zeros((len(self.top_k),), dtype=jnp.float32),
    )

process(metrics_state, rng_key, args)

Process the metric state to compute top-K reward and diversity statistics.

This method performs the core computation for top-K metrics:

1. Samples trajectories using the current policy
2. Computes rewards for all terminal states
3. Identifies the top-K highest-reward samples for each K value
4. Computes mean rewards and diversity measures for each top-K set

Parameters:

Name Type Description Default
metrics_state TopKMetricsState

Current metric state; its top-K reward and diversity fields are overwritten with freshly computed values

required
rng_key PRNGKey

JAX PRNG key for trajectory sampling

required
args ProcessArgs

ProcessArgs object containing policy parameters and environment parameters

required

Returns:

Name Type Description
TopKMetricsState TopKMetricsState

Updated state with computed top-K rewards and diversity statistics

Source code in gfnx/metrics/top_k.py
def process(
    self,
    metrics_state: TopKMetricsState,
    rng_key: chex.PRNGKey,
    args: ProcessArgs,
) -> TopKMetricsState:
    """Process the metric state to compute top-K reward and diversity statistics.

    This method performs the core computation for top-K metrics:
    1. Samples trajectories using the current policy
    2. Computes rewards for all terminal states
    3. Identifies the top-K highest-reward samples for each K value
    4. Computes mean rewards and diversity measures for each top-K set

    Args:
        metrics_state: Current metric state (largely ignored, will be replaced)
        rng_key: JAX PRNG key for trajectory sampling
        args: ProcessArgs object containing policy parameters and environment parameters

    Returns:
        TopKMetricsState: Updated state with computed top-K rewards and diversity statistics
    """
    # Sample a batch of trajectories
    _, info = forward_rollout(
        rng_key,
        num_envs=self.num_traj,
        policy_fn=self.fwd_policy_fn,
        policy_params=args.policy_params,
        env=self.env,
        env_params=args.env_params,
    )
    final_env_state = info["final_env_state"]
    rewards = jnp.exp(info["log_gfn_reward"])
    chex.assert_shape(rewards, (self.num_traj,))
    arg_sort_idx = jnp.argsort(rewards)

    # Compute top-k rewards and top-k diversity
    topk_rew = jnp.zeros((len(self.top_k),), dtype=jnp.float32)
    topk_div = jnp.zeros((len(self.top_k),), dtype=jnp.float32)

    for idx, k in enumerate(self.top_k):
        top_idx = arg_sort_idx[-k:]
        topk_rew = topk_rew.at[idx].set(jnp.mean(rewards[top_idx]))
        top_samples = jax.tree.map(lambda x, top_idx=top_idx: x[top_idx], final_env_state)
        # Number of ordered pairs (i, j) with i != j; this is zero when k == 1
        num_nonzero_dist = (k - 1) * k
        distance_matrix = self._get_distance_matrix(top_samples, top_samples)
        topk_div = topk_div.at[idx].set(distance_matrix.sum() / num_nonzero_dist)

    return metrics_state.replace(
        top_k_rewards=topk_rew,
        top_k_diversity=topk_div,
    )
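The diversity value is the sum of the full K x K distance matrix divided by `(k - 1) * k`, the number of ordered pairs of distinct samples. The sketch below reproduces that normalization in plain Python, assuming `distance_fn(x, x) == 0` so the diagonal adds nothing to the sum. The names `mean_pairwise_distance` and `hamming` are hypothetical, and the Hamming distance is only an example choice of `distance_fn`.

```python
def mean_pairwise_distance(samples, distance_fn):
    """Average distance over all ordered pairs (i, j) with i != j."""
    k = len(samples)
    if k < 2:
        raise ValueError("need at least 2 samples for a pairwise distance")
    # Sum over the full k x k matrix; diagonal terms are assumed to be zero
    total = sum(distance_fn(a, b) for a in samples for b in samples)
    return total / ((k - 1) * k)


def hamming(a, b):
    """Number of positions at which two equal-length sequences differ."""
    return sum(x != y for x, y in zip(a, b))


samples = [(0, 0, 0), (0, 0, 1), (1, 1, 1)]
diversity = mean_pairwise_distance(samples, hamming)
print(diversity)
```

Here the three unordered distances are 1, 3 and 2. Each appears twice in the symmetric matrix, so the matrix sum is 12 and dividing by 3 * 2 = 6 ordered pairs gives their mean.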

update(metrics_state, rng_key, args=None)

Update the metric state with new data (no-op for top-K metrics).

Top-K metrics are computed entirely during the process phase by sampling fresh trajectories, so no incremental updates are needed.

Parameters:

Name Type Description Default
metrics_state TopKMetricsState

Current metric state (unchanged)

required
rng_key PRNGKey

JAX PRNG key for any random operations (currently unused)

required
args UpdateArgs | None

EmptyUpdateArgs (no update parameters needed)

None

Returns:

Name Type Description
TopKMetricsState TopKMetricsState

Unchanged metric state

Source code in gfnx/metrics/top_k.py
def update(
    self,
    metrics_state: TopKMetricsState,
    rng_key: chex.PRNGKey,
    args: UpdateArgs | None = None,
) -> TopKMetricsState:
    """Update the metric state with new data (no-op for top-K metrics).

    Top-K metrics are computed entirely during the process phase by sampling
    fresh trajectories, so no incremental updates are needed.

    Args:
        metrics_state: Current metric state (unchanged)
        rng_key: JAX PRNG key for any random operations (currently unused)
        args: EmptyUpdateArgs (no update parameters needed)

    Returns:
        TopKMetricsState: Unchanged metric state
    """
    return metrics_state

TopKMetricsState

Bases: MetricsState

State for Top-K reward and diversity metrics.

This state container stores the computed top-K reward and diversity statistics for different values of K. It maintains arrays where each element corresponds to a different K value specified during module initialization.

Attributes:

Name Type Description
top_k_rewards Array

Array of mean rewards for the top-K samples for each K value. Shape: (len(top_k),) where each entry is the mean reward of the top-K samples.

top_k_diversity Array

Array of diversity measures for the top-K samples for each K value. Shape: (len(top_k),) where each entry is the average pairwise distance among top-K samples.

Source code in gfnx/metrics/top_k.py
@chex.dataclass
class TopKMetricsState(MetricsState):
    """State for Top-K reward and diversity metrics.

    This state container stores the computed top-K reward and diversity statistics
    for different values of K. It maintains arrays where each element corresponds
    to a different K value specified during module initialization.

    Attributes:
        top_k_rewards: Array of mean rewards for the top-K samples for each K value.
            Shape: (len(top_k),) where each entry is the mean reward of the top-K samples.
        top_k_diversity: Array of diversity measures for the top-K samples for each K value.
            Shape: (len(top_k),) where each entry is the average pairwise distance
            among top-K samples.
    """

    top_k_rewards: chex.Array
    top_k_diversity: chex.Array
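The state is treated as an immutable record: `init` creates zero-filled arrays and `process` returns a new state via `replace` instead of mutating the old one. A rough standard-library analogue of that pattern is sketched below, using a frozen dataclass and tuples in place of `chex.dataclass` and JAX arrays. `ToyTopKState` and `init_state` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyTopKState:
    """Stand-in for TopKMetricsState, with tuples instead of JAX arrays."""

    top_k_rewards: tuple
    top_k_diversity: tuple


def init_state(top_k):
    """One zero entry per K value, mirroring the zero-initialized state."""
    zeros = (0.0,) * len(top_k)
    return ToyTopKState(top_k_rewards=zeros, top_k_diversity=zeros)


state = init_state([10, 50])
# replace() builds a new object; the original state is left untouched
new_state = replace(state, top_k_rewards=(0.85, 0.78))
print(state)
print(new_state)
```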