API Reference
AccumulatedModesMetricsModule
Bases: BaseMetricsModule
Metric module for tracking mode discovery in GFlowNet training.
This module monitors how well a GFlowNet policy discovers different modes by tracking which known modes are visited during training or evaluation. It maintains a set of reference modes and records which ones have been encountered based on a distance threshold.
Attributes:
| Name | Type | Description |
|---|---|---|
| env | | Environment instance for mode-related operations |
| distance_fn | | Function to compute distance between two states |
| distance_threshold | | Maximum distance for considering a state as visiting a mode |

Source code in gfnx/metrics/modes.py
InitArgs
Bases: BaseInitArgs
Arguments for initializing the AccumulatedModesMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| modes | TEnvState | Reference set of mode states to track during training/evaluation. These represent the target modes that the policy should discover. |

Source code in gfnx/metrics/modes.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the AccumulatedModesMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| env_params | Any | Environment parameters (currently unused but maintained for interface consistency with other metric modules). |

Source code in gfnx/metrics/modes.py
UpdateArgs
Bases: BaseUpdateArgs
Arguments for updating the AccumulatedModesMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| states | TEnvState | Environment states visited during training/evaluation to check against the reference modes for mode discovery tracking. |

Source code in gfnx/metrics/modes.py
__init__(env, distance_fn, distance_threshold=0.1)
Initialize the accumulated modes metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | TEnvironment | Environment instance for mode-related operations | required |
| distance_fn | Callable[[TEnvState, TEnvState], float] | Function that computes the distance between two environment states. Must return a scalar distance value. | required |
| distance_threshold | float | Maximum distance threshold for considering a visited state as discovering a mode. States within this distance of a mode are considered to have "visited" that mode. | 0.1 |

Source code in gfnx/metrics/modes.py
get(metrics_state)
Get the computed mode discovery metrics from the current state.
Computes and returns metrics that quantify how well the policy has discovered the reference modes. Provides both absolute and relative measures of mode coverage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | AccumulatedModesMetricsState | Current metric state containing mode visit tracking information | required |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing computed mode discovery metrics: 'num_modes', the total number of modes discovered (integer count), and 'percent_modes', the fraction of modes discovered (float between 0 and 1) |

Source code in gfnx/metrics/modes.py
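As a rough illustration of the two values described above, the following sketch derives a count and a fraction from a boolean visited-modes array. The helper name `modes_summary` is hypothetical and is not part of gfnx; only the dictionary keys are taken from the documentation.

```python
import jax.numpy as jnp


def modes_summary(visited_modes_idx):
    """Summarise a boolean array of shape (num_modes,). Hypothetical helper."""
    num_visited = jnp.sum(visited_modes_idx)  # True counts as 1
    return {
        "num_modes": num_visited,
        # Despite the key name, the documented value is a fraction in [0, 1].
        "percent_modes": num_visited / visited_modes_idx.shape[0],
    }


visited = jnp.array([True, False, True, True])
summary = modes_summary(visited)
```

With three of four modes visited, the fraction comes out as 0.75.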
init(rng_key, args)
Initialize the accumulated modes metric state.
Creates initial state with the provided reference modes and initializes the visited modes tracking array to all False (no modes visited yet).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs | InitArgs object containing the reference modes to track | required |

Returns:
| Name | Type | Description |
|---|---|---|
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | Initialized state with reference modes and empty visited modes tracking array |

Source code in gfnx/metrics/modes.py
process(metrics_state, rng_key, args=None)
Process the metric state for final computation (no-op for modes metrics).
This method performs any final processing needed before metric computation. For accumulated modes metrics, no additional processing is required as the state is maintained incrementally during updates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | AccumulatedModesMetricsState | Current metric state with accumulated mode visit information | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | ProcessArgs \| None | ProcessArgs object containing environment parameters (currently unused) | None |

Returns:
| Name | Type | Description |
|---|---|---|
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | Unchanged metric state ready for get() call |

Source code in gfnx/metrics/modes.py
update(metrics_state, rng_key, args)
Update the metric state with newly visited states.
Checks which reference modes have been visited by computing distances between the provided states and all reference modes. A mode is considered "visited" if any of the provided states is within the distance threshold of that mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | AccumulatedModesMetricsState | Current metric state containing modes and visited tracking | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs | UpdateArgs object containing the newly visited states | required |

Returns:
| Name | Type | Description |
|---|---|---|
| AccumulatedModesMetricsState | AccumulatedModesMetricsState | State with updated visited modes tracking. The visited_modes_idx array is updated to mark newly discovered modes. |

Source code in gfnx/metrics/modes.py
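The thresholded distance check described for update() can be sketched as below. This is a minimal illustration, not the library's code: the `euclidean` distance, the `mark_visited` helper, the use of plain arrays as states, and the inclusive comparison (`<=`) are all assumptions.

```python
import jax
import jax.numpy as jnp


def euclidean(a, b):
    # Hypothetical distance_fn: Euclidean distance between two state vectors.
    return jnp.linalg.norm(a - b)


def mark_visited(modes, visited, states, distance_fn, threshold=0.1):
    # Pairwise distances with shape (num_modes, num_states).
    pairwise = jax.vmap(
        lambda m: jax.vmap(lambda s: distance_fn(m, s))(states)
    )(modes)
    # A mode counts as hit if any state lies within the threshold (assumed inclusive).
    hit = jnp.any(pairwise <= threshold, axis=1)
    # Accumulate: a mode that was visited earlier stays visited.
    return jnp.logical_or(visited, hit)


modes = jnp.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
visited = jnp.zeros(3, dtype=bool)
visited = mark_visited(modes, visited, jnp.array([[0.05, 0.0], [3.0, 3.0]]), euclidean)
visited = mark_visited(modes, visited, jnp.array([[5.0, 5.05]]), euclidean)
```

After the two calls the first and third modes are marked, while the second mode, which no state came close to, remains unvisited.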
AccumulatedModesMetricsState
Bases: MetricsState
State for accumulating mode discovery metrics.
This state container tracks which modes have been visited during training or evaluation by maintaining a record of known modes and a boolean array indicating which modes have been encountered.
Attributes:
| Name | Type | Description |
|---|---|---|
| modes | TEnvState | Reference set of mode states to track. These represent the target modes that the policy should discover. |
| visited_modes_idx | Array | Boolean array indicating which modes have been visited at least once. Shape: (num_modes,) |

Source code in gfnx/metrics/modes.py
ApproxDistributionMetricsModule
Bases: BaseMetricsModule
Distribution-based metrics module for enumerable environments.
This metric module computes distribution-based metrics by comparing the true distribution of an enumerable environment with an empirical distribution derived from collected samples. It supports various distance metrics between distributions such as KL divergence and total variation distance.
The module maintains a replay buffer to accumulate environment states and computes empirical distributions from these samples. It can only be used with enumerable environments that provide access to their true distribution.
Supported metrics
- "tv": Total variation distance between distributions
- "kl": KL divergence from empirical to true distribution
- "jsd": Jensen-Shannon divergence from empirical to true distribution
- "2d_marginal_distribution": Marginal distribution computation
Attributes:
| Name | Type | Description |
|---|---|---|
| metrics | | List of metric names to compute, choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}. |
| env | | Enumerable environment for which to compute metrics. |
| buffer_size | | Maximum number of states to store in the replay buffer. |

Source code in gfnx/metrics/approx_distribution.py
InitArgs
Bases: BaseInitArgs
Arguments for initializing the ApproxDistributionMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| env_params | TEnvParams | Environment parameters needed to obtain the true distribution and initialize environment states for the replay buffer. |

Source code in gfnx/metrics/approx_distribution.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the ApproxDistributionMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| env_params | TEnvParams | Environment parameters needed for empirical distribution computation from the collected states in the replay buffer. |

Source code in gfnx/metrics/approx_distribution.py
UpdateArgs
Bases: BaseUpdateArgs
Arguments for updating the ApproxDistributionMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| states | TEnvState | Environment states to add to the replay buffer for empirical distribution computation. These represent states visited during training/evaluation episodes. |

Source code in gfnx/metrics/approx_distribution.py
__init__(metrics, env, buffer_size=1000)
Initialize the distribution metrics module.
Sets up the metric module with specified metrics, environment, and buffer size. Validates that the environment is enumerable and that all requested metrics are supported.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics | list[str] | List of metric names to compute. Must be a subset of the supported metrics: "tv", "kl", "jsd", "2d_marginal_distribution" | required |
| env | TEnvironment | Environment instance for which to compute metrics. Must be enumerable (i.e., env.is_enumerable must be True) | required |
| buffer_size | int | Maximum number of states to store in the replay buffer for empirical distribution computation. Must be a positive integer. | 1000 |

Raises:
| Type | Description |
|---|---|
| ValueError | If the environment is not enumerable, buffer_size is not positive, metrics is not a list of strings, or metrics contains unsupported names |

Source code in gfnx/metrics/approx_distribution.py
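The documented ValueError conditions can be pictured with a small standalone check. This is only a sketch: `validate_config` and `is_valid` are hypothetical helpers, the supported set is passed in explicitly rather than read from the library, and the error messages are invented for illustration.

```python
def validate_config(metrics, is_enumerable, buffer_size, supported):
    """Raise ValueError for each documented invalid configuration (hypothetical helper)."""
    if not is_enumerable:
        raise ValueError("environment must be enumerable")
    if not isinstance(buffer_size, int) or buffer_size <= 0:
        raise ValueError("buffer_size must be a positive integer")
    if not isinstance(metrics, list) or not all(isinstance(m, str) for m in metrics):
        raise ValueError("metrics must be a list of strings")
    unsupported = sorted(set(metrics) - set(supported))
    if unsupported:
        raise ValueError(f"unsupported metrics: {unsupported}")


def is_valid(**kwargs):
    # Convenience wrapper that turns the exception into a boolean.
    try:
        validate_config(**kwargs)
        return True
    except ValueError:
        return False


# Subset of metric names used purely for the example.
EXAMPLE_SUPPORTED = {"tv", "kl", "jsd"}
ok = is_valid(metrics=["tv", "kl"], is_enumerable=True, buffer_size=1000,
              supported=EXAMPLE_SUPPORTED)
bad_metric = is_valid(metrics=["wasserstein"], is_enumerable=True, buffer_size=1000,
                      supported=EXAMPLE_SUPPORTED)
```

Checking everything in the constructor means a misconfigured module fails immediately rather than partway through an evaluation run.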
get(metrics_state)
Get the computed distribution metrics from the processed state.
Computes and returns the requested distribution metrics by comparing the true distribution with the empirical distribution. Each metric is computed using the corresponding function from the _supported_metrics dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | ApproxDistributionMetricsState | Processed metric state containing both true and empirical distributions | required |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing the computed metrics. Keys are the metric names specified during initialization, and values are the computed metric values (typically float scalars). |

Example
If initialized with metrics=["tv", "kl"], might return: {"tv": 0.123, "kl": 0.456}

Source code in gfnx/metrics/approx_distribution.py
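For orientation, the three scalar metrics named above have standard textbook definitions, sketched here in JAX. The function names, the smoothing constant `EPS`, and the argument order are assumptions; the functions stored in the module's _supported_metrics dictionary may differ in detail.

```python
import jax.numpy as jnp

EPS = 1e-12  # assumed smoothing constant to keep log() finite


def total_variation(p, q):
    # TV(p, q) = 0.5 * sum_i |p_i - q_i|
    return 0.5 * jnp.sum(jnp.abs(p - q))


def kl_divergence(p, q):
    # KL(p || q); entries where p is zero contribute nothing.
    terms = p * (jnp.log(p + EPS) - jnp.log(q + EPS))
    return jnp.sum(jnp.where(p > 0, terms, 0.0))


def jensen_shannon(p, q):
    # JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m the midpoint mixture.
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


true_dist = jnp.array([0.5, 0.25, 0.25])
empirical = jnp.array([0.25, 0.25, 0.5])
results = {
    "tv": total_variation(empirical, true_dist),
    "kl": kl_divergence(empirical, true_dist),
    "jsd": jensen_shannon(empirical, true_dist),
}
```

Note that KL is asymmetric, so the order of the two arguments matters, whereas TV and JSD give the same value either way.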
init(rng_key, args)
Initialize the metric state for distribution metrics.
Creates and initializes all components needed for distribution metric computation: the true distribution from the environment, an initial uniform empirical distribution, and a replay buffer for storing environment states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs | InitArgs object containing environment parameters | required |

Returns:
| Name | Type | Description |
|---|---|---|
| ApproxDistributionMetricsState | ApproxDistributionMetricsState | Initialized state containing: true_distribution (ground truth distribution from the environment); empirical_distribution (initial uniform distribution); replay_buffer (empty buffer ready to collect states) |

Source code in gfnx/metrics/approx_distribution.py
process(metrics_state, rng_key, args)
Process the metric state to compute the empirical distribution.
This method computes the empirical distribution from the states stored in the replay buffer. It calls the environment's get_empirical_distribution method to convert the collected states into a probability distribution that can be compared with the true distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | ApproxDistributionMetricsState | Current metric state containing the replay buffer with collected states | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | ProcessArgs | ProcessArgs object containing environment parameters | required |

Returns:
| Name | Type | Description |
|---|---|---|
| ApproxDistributionMetricsState | ApproxDistributionMetricsState | Updated state with the computed empirical distribution. This state is now ready for metric computation via get(). |

Source code in gfnx/metrics/approx_distribution.py
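Conceptually, turning collected states into an empirical distribution amounts to counting visits per state and normalising. The sketch below assumes each state has already been mapped to an integer index in an enumerable state space; the helper name is hypothetical and the environment's get_empirical_distribution method may work differently.

```python
import jax.numpy as jnp


def empirical_distribution(state_indices, num_states):
    """Visit frequencies over an enumerable state space (hypothetical helper)."""
    # Fixed output length keeps the result aligned with the true distribution.
    counts = jnp.bincount(state_indices, length=num_states)
    # Guard against division by zero when no states were collected.
    return counts / jnp.maximum(counts.sum(), 1)


indices = jnp.array([0, 2, 2, 3])
dist = empirical_distribution(indices, num_states=4)
```

States that never appear in the buffer receive probability zero, which is why divergence computations against the true distribution typically need some care around log(0).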
update(metrics_state, rng_key, args)
Update the metric state with new environment states.
Adds new environment states to the replay buffer for empirical distribution computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | ApproxDistributionMetricsState | Current metric state containing the replay buffer | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs | UpdateArgs object containing environment states to add | required |

Returns:
| Name | Type | Description |
|---|---|---|
| ApproxDistributionMetricsState | ApproxDistributionMetricsState | Updated state with new states added to the buffer. The true_distribution and empirical_distribution remain unchanged until process() is called. |

Source code in gfnx/metrics/approx_distribution.py
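One common way to keep a bounded replay buffer is a circular (first-in, first-out) array. The sketch below illustrates that idea only; the documentation does not state how the library's buffer evicts old states, and `buffer_init` and `buffer_add` are hypothetical helpers that assume each batch is no larger than the buffer.

```python
import jax.numpy as jnp


def buffer_init(buffer_size, state_dim):
    # "pos" is the next write position, "filled" the number of valid rows.
    return {"data": jnp.zeros((buffer_size, state_dim)), "pos": 0, "filled": 0}


def buffer_add(buf, states):
    size = buf["data"].shape[0]
    n = states.shape[0]
    # Wrap write positions around the end of the buffer (oldest rows are overwritten).
    idx = (buf["pos"] + jnp.arange(n)) % size
    return {
        "data": buf["data"].at[idx].set(states),
        "pos": (buf["pos"] + n) % size,
        "filled": min(buf["filled"] + n, size),
    }


buf = buffer_init(buffer_size=4, state_dim=1)
buf = buffer_add(buf, jnp.array([[1.0], [2.0], [3.0]]))
buf = buffer_add(buf, jnp.array([[4.0], [5.0]]))
```

The position counters are plain Python integers here for readability; a version meant to run under jax.jit would need to carry them as arrays instead.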
ApproxDistributionMetricsState
Bases: MetricsState
State for approximate distribution-based metrics.
This state container holds the data necessary for computing distribution-based metrics such as KL divergence and total variation distance. It stores both the true distribution from the environment and the empirical distribution computed from collected samples, along with a replay buffer for state storage.
Attributes:
| Name | Type | Description |
|---|---|---|
| true_distribution | Array | The true distribution of the environment (ground truth) |
| empirical_distribution | Array | The empirical distribution computed from collected samples |
| replay_buffer | ArrayTree | Buffer storing environment states for distribution computation |

Source code in gfnx/metrics/approx_distribution.py
BaseCorrelationMetricsModule
Bases: BaseMetricsModule
Abstract base class for correlation-based GFlowNet evaluation metrics.
This class provides common functionality for computing correlation metrics between transformed model-predicted marginal probabilities and transformed log true rewards. It implements the core logic for backward trajectory sampling and log-ratio computation that is shared across different correlation metric variants.
The correlation metrics evaluate how well the learned GFlowNet policy captures the true reward distribution by computing correlations between:
- Transformed log-ratios of backward trajectories (model predictions)
- Transformed log-rewards of terminal states (ground truth)

Attributes:
| Name | Type | Description |
|---|---|---|
| env | | Environment instance for trajectory generation and evaluation |
| bwd_policy_fn | | Backward policy function for computing trajectory ratios |
| n_rounds | | Number of sampling rounds for statistical stability |
| batch_size | | Batch size for processing terminal states during evaluation |
| transform_fn | | Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation |

Source code in gfnx/metrics/correlation.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the correlation metric module.
Attributes:
| Name | Type | Description |
|---|---|---|
| policy_params | TPolicyParams | Current policy parameters used for trajectory generation, backward trajectory sampling, and log-ratio computation. |
| env_params | TEnvParams | Environment parameters required for trajectory generation, reward computation, and rollout operations. |

Source code in gfnx/metrics/correlation.py
__init__(env, bwd_policy_fn, n_rounds, batch_size=1, transform_fn=None)
Initialize the correlation metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | TEnvironment | Environment for trajectory generation and reward computation | required |
| bwd_policy_fn | TPolicyFn | Backward policy function for trajectory probability estimation | required |
| n_rounds | int | Number of sampling rounds for averaging log-ratios | required |
| batch_size | int | Batch size for processing terminal states | 1 |
| transform_fn | Callable[[TEnvState, ndarray], ndarray] \| None | Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation | None |

Source code in gfnx/metrics/correlation.py
get(metrics_state)
Compute and return correlation metrics from the current state.
Calculates two correlation measures between transformed model predictions and transformed true rewards:
1. Pearson correlation in log-probability space
2. Spearman rank correlation

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | CorrelationMetricsState | Current state containing transformed log-ratios and transformed log-rewards | required |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing computed correlation metrics: 'pearson', the Pearson correlation of log-probabilities, and 'spearman', the Spearman rank correlation of log-probabilities |

Source code in gfnx/metrics/correlation.py
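The two correlation measures can be sketched in JAX as follows. This is an illustration under simplifying assumptions, not the library's implementation: the helper names are hypothetical, the input arrays are made up, and the rank computation breaks ties by position instead of averaging them.

```python
import jax.numpy as jnp


def pearson(x, y):
    # Off-diagonal entry of the 2x2 correlation matrix.
    return jnp.corrcoef(x, y)[0, 1]


def ranks(x):
    # Rank of each element (0 = smallest); ties are NOT averaged in this sketch.
    return jnp.argsort(jnp.argsort(x)).astype(jnp.float32)


def spearman(x, y):
    # Spearman correlation is the Pearson correlation of the ranks.
    return pearson(ranks(x), ranks(y))


# Made-up values standing in for transformed log-ratios and log-rewards.
log_ratio = jnp.array([-3.0, -1.0, -2.0, 0.0])
log_reward = jnp.array([-9.0, -1.0, -4.0, 0.0])
metrics = {
    "pearson": pearson(log_ratio, log_reward),
    "spearman": spearman(log_ratio, log_reward),
}
```

In this toy input the two arrays are ordered identically, so the rank correlation is 1 even though the relationship is not linear and the Pearson value is slightly lower.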
update(metrics_state, rng_key, args=None)
Update metric state with new data (no-op for correlation metrics).
Source code in gfnx/metrics/correlation.py
BaseInitArgs
Bases: ABC
Base class for argument containers passed as args to the init method of metric modules.
Source code in gfnx/metrics/base.py
BaseMetricsModule
Bases: ABC, Generic[TInitArgs, TUpdateArgs, TProcessArgs, TMetricsState]
Environment-agnostic base metric module.
This abstract base class defines the interface for all metric modules in the system. Metric modules are responsible for tracking, computing, and reporting metrics during training and evaluation phases. They maintain their own state and can be composed together using the MultiMetricsModule.
The lifecycle of a metric module follows this pattern:
1. init() - Initialize the metric state with required parameters
2. update() - Update state with new data points during training/evaluation
3. process() - Apply any final transformations before metric computation
4. get() - Retrieve the computed metrics as a dictionary

In addition to implementing the abstract methods, each subclass must define the following inner classes to specify argument types:
- InitArgs, inheriting from BaseInitArgs
- UpdateArgs, inheriting from BaseUpdateArgs
- ProcessArgs, inheriting from BaseProcessArgs

Subclasses that do not provide these classes or methods will raise errors via `__init_subclass__` validation.
Source code in gfnx/metrics/base.py
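To make the init / update / process / get lifecycle concrete, here is a toy metric that follows the same four-step calling pattern. It is deliberately standalone: `MeanRewardMetric` is a hypothetical class that does not inherit from BaseMetricsModule, and it uses plain dictionaries where a real module would use its InitArgs, UpdateArgs, and ProcessArgs classes and a MetricsState.

```python
import jax
import jax.numpy as jnp


class MeanRewardMetric:
    """Toy metric mirroring the four-step lifecycle (not a gfnx class)."""

    def init(self, rng_key, args=None):
        # 1. Create the initial state.
        return {"total": jnp.float32(0.0), "count": jnp.int32(0), "mean": jnp.float32(0.0)}

    def update(self, state, rng_key, args=None):
        # 2. Accumulate a new batch of data.
        rewards = args["rewards"]
        return {**state,
                "total": state["total"] + rewards.sum(),
                "count": state["count"] + rewards.size}

    def process(self, state, rng_key, args=None):
        # 3. Final transformation before reading the metric.
        return {**state, "mean": state["total"] / jnp.maximum(state["count"], 1)}

    def get(self, state):
        # 4. Report metrics as a dictionary.
        return {"mean_reward": state["mean"]}


key = jax.random.PRNGKey(0)
metric = MeanRewardMetric()
state = metric.init(key)
for batch in (jnp.array([1.0, 2.0]), jnp.array([3.0])):
    state = metric.update(state, key, {"rewards": batch})
state = metric.process(state, key)
result = metric.get(state)
```

Each method returns a new state instead of mutating the old one, which keeps the pattern compatible with JAX's functional style.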
get(metrics_state)
abstractmethod
Get computed metrics from the current state.
Computes and returns the final metrics based on the current state. This method should extract meaningful metrics from the accumulated data and return them in a standardized dictionary format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TMetricsState | Current processed state of the metric | required |

Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary containing computed metrics with descriptive keys and their corresponding values |

Raises:
| Type | Description |
|---|---|
| NotImplementedError | Must be implemented by subclasses |

Source code in gfnx/metrics/base.py
init(rng_key, args=None)
abstractmethod
Initialize metric state.
Creates and returns the initial state required for metric computation. This method is called once at the beginning of training or evaluation to set up any necessary data structures, counters, or buffers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | JAX PRNG key for any random initialization required | required |
| args | TInitArgs \| None | Optional InitArgs object containing metric-specific init parameters | None |

Returns:
| Name | Type | Description |
|---|---|---|
| MetricsState | TMetricsState | Initialized state object for this metric |

Raises:
| Type | Description |
|---|---|
| NotImplementedError | Must be implemented by subclasses |

Source code in gfnx/metrics/base.py
process(metrics_state, rng_key, args=None)
abstractmethod
Process metric state to compute metrics and perform final transformations.
This method is called before get() to compute the metrics and apply final transformations during the evaluation period. It prepares the accumulated data for final metric computation by applying any necessary calculations, normalizations, or statistical operations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TMetricsState | Current state of the metric after all updates | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations during processing | required |
| args | TProcessArgs \| None | Optional ProcessArgs object containing metric-specific processing parameters | None |

Returns:
| Name | Type | Description |
|---|---|---|
| MetricsState | TMetricsState | Processed state ready for metric retrieval via get() |

Raises:
| Type | Description |
|---|---|
| NotImplementedError | Must be implemented by subclasses |

Source code in gfnx/metrics/base.py
update(metrics_state, rng_key, args=None)
abstractmethod
Update metric state with new data.
Updates the metric state with new data points collected during training or evaluation. This method is called repeatedly as new data becomes available and should accumulate or process the information needed for final metric computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TMetricsState | Current state of the metric | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations during update | required |
| args | TUpdateArgs \| None | Optional UpdateArgs object containing metric-specific update data | None |

Returns:
| Name | Type | Description |
|---|---|---|
| MetricsState | TMetricsState | Updated state object with new data incorporated |

Raises:
| Type | Description |
|---|---|
| NotImplementedError | Must be implemented by subclasses |

Source code in gfnx/metrics/base.py
BaseProcessArgs
Bases: ABC
Base class for argument containers passed as args to the process method of metric modules.
Source code in gfnx/metrics/base.py
BaseUpdateArgs
Bases: ABC
Base class for argument containers passed as args to the update method of metric modules.
Source code in gfnx/metrics/base.py
CorrelationMetricsState
Bases: MetricsState
State container for correlation-based metric computation.
This state class stores the data required for computing correlation metrics between transformed distributions of model predictions and true rewards. It maintains terminal states, their corresponding transformed log-rewards, and computed transformed log-ratios of backward trajectories.
Attributes:
| Name | Type | Description |
|---|---|---|
| test_terminal_states | TEnvState | Terminal states used for correlation evaluation. These can be either sampled on-policy or taken from a fixed test set. |
| test_log_rewards_transformed | ndarray | Transformed log-rewards corresponding to the terminal states, obtained by marginalizing the distribution over terminal states to a domain of arbitrary size. Shape: (domain_size,) |
| log_ratio_traj_transformed | ndarray | Transformed log-ratios of backward trajectories from terminal states. These ratios represent the model's estimated probability of reaching each terminal state, marginalized to domain_size. Shape: (domain_size,) |

Source code in gfnx/metrics/correlation.py
ELBOMetricState
Bases: MetricsState
State container for the Evidence Lower Bound (ELBO) metric.
This state container stores the computed ELBO metric.
Attributes:
| Name | Type | Description |
|---|---|---|
| elbo | ndarray | Computed ELBO value. |

Source code in gfnx/metrics/elbo.py
ELBOMetricsModule
Bases: BaseMetricsModule
Computes the Evidence Lower Bound (ELBO) for a GFlowNet model.
This metric evaluates the GFlowNet model by estimating the ELBO. The ELBO is computed by sampling trajectories from the forward policy and evaluating the log-ratios of the forward and backward probabilities plus the log reward.
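The ELBO is defined as:

$$
\mathrm{ELBO} =
\begin{cases}
\mathbb{E}_{\mathrm{traj} \sim P_f}\left[\log P_b(\mathrm{traj} \mid \mathrm{traj}_n) + \log R(\mathrm{traj}_n) - \log P_f(\mathrm{traj})\right] - \log Z, & \text{if } \log Z \text{ is tractable}, \\
\mathbb{E}_{\mathrm{traj} \sim P_f}\left[\log P_b(\mathrm{traj} \mid \mathrm{traj}_n) + \log R(\mathrm{traj}_n) - \log P_f(\mathrm{traj})\right], & \text{otherwise},
\end{cases}
$$

where traj is sampled from the trained forward policy.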
Attributes:
| Name | Type | Description |
|---|---|---|
| env | | Environment instance for trajectory generation and evaluation. |
| env_params | | Environment parameters. |
| fwd_policy_fn | | Forward policy function producing action logits. |
| n_rounds | | Number of sampling rounds for statistical stability. |
| batch_size | | Batch size used when evaluating policy over states. |

Source code in gfnx/metrics/elbo.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the ELBO metric module.
Attributes:
| Name | Type | Description |
|---|---|---|
| policy_params | TPolicyParams | Current policy parameters used for forward and backward rollouts to generate terminal states and compute log-ratios. |
| env_params | TEnvParams | Environment parameters required for trajectory generation and reward computation. |

Source code in gfnx/metrics/elbo.py
__init__(env, env_params, fwd_policy_fn, n_rounds, batch_size)
Initializes the ELBO metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | TEnvironment | Environment for trajectory generation and reward computation. | required |
| env_params | TEnvParams | Environment parameters used for trajectory generation. | required |
| fwd_policy_fn | TPolicyFn | Forward policy function for generating trajectories. | required |
| n_rounds | int | The number of sampling rounds to perform for estimation. | required |
| batch_size | int | The number of environments to run in parallel for sampling. | required |

Source code in gfnx/metrics/elbo.py
get(metrics_state)
Returns the computed ELBO metric from the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
ELBOMetricState
|
The current state containing the computed ELBO. |
required |
Returns:
| Type | Description |
|---|---|
dict[str, Any]
|
A dictionary containing the ELBO value. |
Source code in gfnx/metrics/elbo.py
init(rng_key, args)
Initialize the metric state for ELBO metric.
Source code in gfnx/metrics/elbo.py
process(metrics_state, rng_key, args)
Computes the ELBO by sampling trajectories from the forward policy.
This method performs multiple rounds of forward rollouts to sample trajectories, and then computes the ELBO for each trajectory. The final ELBO is the average over all sampled trajectories across all rounds.
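The round-and-batch averaging described above can be sketched in plain Python on a toy problem. Everything here is an illustrative assumption rather than the gfnx implementation: the three-level DAG, the `FWD_POLICY`, `BWD_POLICY`, and `LOG_REWARD` tables, and the helper names are hypothetical, and the per-trajectory term is assumed to be log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj), with no logZ offset applied.

```python
import math
import random

# Hypothetical toy DAG (not a gfnx environment): s0 -> {a, b}, a -> {x, y}, b -> {y}.
FWD_POLICY = {
    "s0": {"a": 0.5, "b": 0.5},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"y": 1.0},
}
# Backward policy: probability of each parent given the child state.
BWD_POLICY = {
    "x": {"a": 1.0},
    "y": {"a": 0.5, "b": 0.5},
    "a": {"s0": 1.0},
    "b": {"s0": 1.0},
}
LOG_REWARD = {"x": math.log(1.0), "y": math.log(3.0)}


def sample_forward(rng):
    """Roll out the forward policy from s0; return (trajectory, log Pf(traj))."""
    state, traj, log_pf = "s0", ["s0"], 0.0
    while state in FWD_POLICY:  # terminal states have no outgoing actions
        children = list(FWD_POLICY[state])
        probs = [FWD_POLICY[state][c] for c in children]
        nxt = rng.choices(children, weights=probs)[0]
        log_pf += math.log(FWD_POLICY[state][nxt])
        state = nxt
        traj.append(state)
    return traj, log_pf


def log_pb(traj):
    """log Pb(traj | terminal state), summed over backward steps."""
    return sum(math.log(BWD_POLICY[child][parent])
               for parent, child in zip(traj, traj[1:]))


def estimate_elbo(n_rounds, batch_size, seed=0):
    """Average the per-trajectory term over all rounds and all batch entries."""
    rng = random.Random(seed)
    total, count = 0.0, 0
    for _ in range(n_rounds):
        for _ in range(batch_size):
            traj, log_pf = sample_forward(rng)
            total += log_pb(traj) + LOG_REWARD[traj[-1]] - log_pf
            count += 1
    return total / count


elbo_estimate = estimate_elbo(n_rounds=20, batch_size=1000)
log_z = math.log(sum(math.exp(v) for v in LOG_REWARD.values()))
print(f"ELBO estimate: {elbo_estimate:.4f}, log Z: {log_z:.4f}")
```

In this toy setting the estimate stays below log Z, which is the sense in which the quantity is a lower bound.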
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
Random number generator key for sampling. |
required |
args
|
ProcessArgs
|
Arguments for processing, containing policy and environment parameters. |
required |
Returns:
| Type | Description |
|---|---|
ELBOMetricState
|
An updated metrics state containing the ELBO value, averaged over all trajectories and rounds. |
Source code in gfnx/metrics/elbo.py
update(metrics_state, rng_key, args=None)
Update metric state with new data. This is a no-op as the metric is computed on demand.
Source code in gfnx/metrics/elbo.py
EUBOMetricState
Bases: MetricsState
State container for the Evidence Upper Bound (EUBO) metric.
This state container stores the computed EUBO metric.
Attributes:
| Name | Type | Description |
|---|---|---|
eubo |
ndarray
|
Computed EUBO value. |
Source code in gfnx/metrics/eubo.py
EUBOMetricsModule
Bases: BaseMetricsModule
Computes the Evidence Upper Bound (EUBO) for a GFlowNet model.
This metric evaluates the GFlowNet model by estimating the EUBO. The EUBO is computed by sampling trajectories from the backward policy and evaluating the log-ratios of the forward and backward probabilities plus the log reward.
The EUBO is defined as follows, where traj is sampled from the trained backward policy and traj_n is the terminal state of traj:
- if logZ is tractable: EUBO = E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)]
- else: EUBO = E_{traj ~ R * Pb} [log Pb(traj | traj_n) + log R(traj_n) - log Pf(traj)] - logZ
Attributes:
| Name | Type | Description |
|---|---|---|
env |
Environment instance for trajectory generation and evaluation. |
|
env_params |
Environment parameters. |
|
bwd_policy_fn |
Backward policy function for sampling trajectories starting from terminal states and computing log-ratios. |
|
n_rounds |
Number of sampling rounds for statistical stability. |
|
batch_size |
Batch size used when evaluating policy over states. |
|
rng_key |
Key used for pseudo-random number generation. |
Source code in gfnx/metrics/eubo.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the EUBO metric module.
Attributes:
| Name | Type | Description |
|---|---|---|
policy_params |
TPolicyParams
|
Current policy parameters used for forward and backward rollouts to generate terminal states and compute log-ratios. |
env_params |
TEnvParams
|
Environment parameters required for trajectory generation and reward computation. |
Source code in gfnx/metrics/eubo.py
__init__(env, env_params, bwd_policy_fn, n_rounds, batch_size, rng_key)
Initializes the EUBO metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
env
|
TEnvironment
|
Environment for trajectory generation and reward computation. |
required |
env_params
|
TEnvParams
|
Environment parameters. |
required |
bwd_policy_fn
|
TPolicyFn
|
Backward policy function for generating trajectories starting from terminal states. |
required |
n_rounds
|
int
|
The number of sampling rounds to perform for estimation. |
required |
batch_size
|
int
|
The number of environments to run in parallel for sampling. |
required |
Source code in gfnx/metrics/eubo.py
get(metrics_state)
Returns the computed EUBO metric from the current state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
EUBOMetricState
|
The current state containing the computed EUBO. |
required |
Returns:
| Type | Description |
|---|---|
dict[str, Any]
|
A dictionary containing the EUBO value. |
Source code in gfnx/metrics/eubo.py
init(rng_key, args)
Initialize the metric state for EUBO metric.
Source code in gfnx/metrics/eubo.py
process(metrics_state, rng_key, args)
Computes the EUBO by sampling trajectories from the backward policy.
This method performs multiple rounds of backward rollouts to sample trajectories, and then computes the EUBO for each trajectory. The final EUBO is the average over all sampled trajectories across all rounds.
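A minimal sketch of the backward-sampling estimator, under loudly stated assumptions: the toy DAG and the `FWD_POLICY`, `BWD_POLICY`, and `LOG_REWARD` tables are hypothetical, terminal states are drawn proportionally to their reward (only feasible because the toy problem has two terminal states), and the optional logZ offset from the class-level definition is left out. It is not the gfnx implementation.

```python
import math
import random

# Hypothetical toy DAG (not a gfnx environment): s0 -> {a, b}, a -> {x, y}, b -> {y}.
FWD_POLICY = {
    "s0": {"a": 0.5, "b": 0.5},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"y": 1.0},
}
# Backward policy: probability of each parent given the child state.
BWD_POLICY = {
    "x": {"a": 1.0},
    "y": {"a": 0.5, "b": 0.5},
    "a": {"s0": 1.0},
    "b": {"s0": 1.0},
}
LOG_REWARD = {"x": math.log(1.0), "y": math.log(3.0)}


def sample_backward(rng):
    """Draw a terminal state proportionally to R, then roll back to s0 with Pb."""
    terminals = list(LOG_REWARD)
    weights = [math.exp(LOG_REWARD[t]) for t in terminals]
    state = rng.choices(terminals, weights=weights)[0]
    traj, log_pb = [state], 0.0
    while state in BWD_POLICY:  # the initial state has no parents
        parents = list(BWD_POLICY[state])
        probs = [BWD_POLICY[state][p] for p in parents]
        parent = rng.choices(parents, weights=probs)[0]
        log_pb += math.log(BWD_POLICY[state][parent])
        state = parent
        traj.append(state)
    traj.reverse()  # store the trajectory in forward order
    return traj, log_pb


def log_pf(traj):
    """log Pf(traj), summed over forward steps."""
    return sum(math.log(FWD_POLICY[parent][child])
               for parent, child in zip(traj, traj[1:]))


def estimate_eubo(n_rounds, batch_size, seed=0):
    """Average the per-trajectory term over all rounds and all batch entries."""
    rng = random.Random(seed)
    total, count = 0.0, 0
    for _ in range(n_rounds):
        for _ in range(batch_size):
            traj, log_pb = sample_backward(rng)
            total += log_pb + LOG_REWARD[traj[-1]] - log_pf(traj)
            count += 1
    return total / count


eubo_estimate = estimate_eubo(n_rounds=20, batch_size=1000)
log_z = math.log(sum(math.exp(v) for v in LOG_REWARD.values()))
print(f"EUBO estimate: {eubo_estimate:.4f}, log Z: {log_z:.4f}")
```

Here the estimate lands above log Z, mirroring the upper-bound role of the metric.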
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
Random number generator key for sampling. |
required |
args
|
ProcessArgs
|
Arguments for processing, containing policy and environment parameters. |
required |
Returns:
| Type | Description |
|---|---|
EUBOMetricState
|
An updated metrics state containing the EUBO value, averaged over all trajectories and rounds. |
Source code in gfnx/metrics/eubo.py
update(metrics_state, rng_key, args=None)
Update metric state with new data. This is a no-op as the metric is computed on demand.
Source code in gfnx/metrics/eubo.py
ExactDistributionMetricsModule
Bases: BaseMetricsModule
Exact distribution metrics for enumerable environments.
For enumerable environments, this module computes the exact terminal distribution induced by the forward policy using a simple policy evaluation method.
Supported metrics
- "tv": Total variation distance between distributions
- "kl": KL divergence between true and exact terminal distributions
- "jsd": Jensen-Shannon divergence between true and exact terminal distributions
- "2d_marginal_distribution": Marginal distribution computation
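For orientation, the three divergence metrics can be written out on plain Python lists. This is a sketch of the textbook definitions, not the functions registered in `_supported_metrics`; the helper names, the `eps` guard, and the direction KL(true || exact) are assumptions.

```python
import math


def total_variation(p, q):
    """TV(p, q) = 0.5 * sum_i |p_i - q_i|."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q); terms with p_i = 0 contribute nothing, q_i is clipped at eps."""
    return sum(pi * math.log(pi / max(qi, eps))
               for pi, qi in zip(p, q) if pi > 0.0)


def jensen_shannon(p, q):
    """JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m the midpoint."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


true_dist = [0.25, 0.75]   # hypothetical ground-truth terminal distribution
exact_dist = [0.5, 0.5]    # hypothetical distribution induced by a policy

tv = total_variation(true_dist, exact_dist)
kl = kl_divergence(true_dist, exact_dist)
jsd = jensen_shannon(true_dist, exact_dist)
print(f"tv={tv:.4f} kl={kl:.4f} jsd={jsd:.4f}")
```

TV and JSD are symmetric in their arguments, while KL is not, so the argument order matters only for "kl".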
Attributes:
| Name | Type | Description |
|---|---|---|
metrics |
List of required metrics, choose from {"tv", "kl", "jsd", "2d_marginal_distribution"} |
|
env |
Enumerable environment for which to compute metrics |
|
fwd_policy_fn |
Forward policy function producing action logits |
|
batch_size |
Batch size used when evaluating policy over states |
|
tol_epsilon |
Tolerance for convergence in distribution computation |
|
_supported_metrics |
dict[str, Callable[..., Array]]
|
Dictionary mapping metric names to computation functions |
Source code in gfnx/metrics/exact_distribution.py
InitArgs
Bases: BaseInitArgs
Arguments for initializing the ExactDistributionMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
env_params |
TEnvParams
|
Environment parameters used to obtain the true distribution and query environment functions. |
Source code in gfnx/metrics/exact_distribution.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the ExactDistributionMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
policy_params |
TPolicyParams
|
Parameters for the forward policy used during propagation. |
env_params |
TEnvParams
|
Environment parameters. |
Source code in gfnx/metrics/exact_distribution.py
__init__(metrics, env, fwd_policy_fn, batch_size, tol_epsilon=1e-07)
Initialize the exact-distribution metrics module.
Validates inputs and records configuration. batch_size controls how states are grouped when evaluating the policy over the topological order for memory efficiency; it is unrelated to any replay buffer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics
|
list[str]
|
List of metric names to compute, choose from {"tv", "kl", "jsd", "2d_marginal_distribution"}. |
required |
env
|
TEnvironment
|
Enumerable environment for which to compute metrics. |
required |
fwd_policy_fn
|
TPolicyFn
|
Forward policy function for generating trajectories. |
required |
batch_size
|
int
|
Batch size used when evaluating policy over states. |
required |
tol_epsilon
|
float
|
Tolerance for convergence in distribution computation. |
1e-07
|
Raises:
| Type | Description |
|---|---|
ValueError
|
If the environment is not enumerable or not topologically sortable,
or if |
Source code in gfnx/metrics/exact_distribution.py
init(rng_key, args)
Initialize the metric state.
Obtains the ground-truth terminal distribution from the environment and initializes the exact_distribution with a uniform placeholder of the same shape. The exact distribution will be computed in process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
JAX PRNG key (currently unused). |
required |
args
|
InitArgs
|
Initialization arguments containing environment parameters. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
ExactDistributionMetricsState |
ExactDistributionMetricsState
|
Initialized state containing true_distribution (ground-truth distribution from the environment) and exact_distribution (initial uniform distribution). |
Source code in gfnx/metrics/exact_distribution.py
process(metrics_state, rng_key, args)
Compute the exact terminal distribution induced by the forward policy.
Uses a simple power iteration method to propagate the initial state distribution:
- Initialize the state distribution with all mass on the initial state.
- Repeatedly apply the transition matrix induced by the forward policy
until convergence (L1 norm change below tol_epsilon).
- Extract the terminal distribution from the converged state distribution.
Overall, we have
final_distribution = sum_{t=0}^inf (transition_matrix^t) @ initial_distribution.
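The propagation steps above can be sketched on a five-state toy DAG. The state names, the dense `TRANSITION` matrix, and the helper functions are hypothetical; the real module evaluates the policy in batches over a topological order rather than building a dense matrix like this.

```python
def matvec(matrix, vec):
    """Dense matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


STATES = ["s0", "a", "b", "x", "y"]
TERMINAL = ["x", "y"]
# TRANSITION[j][i] = probability of moving from state i to state j.
# Terminal states have no outgoing mass, so their columns are zero.
TRANSITION = [
    [0.0, 0.0, 0.0, 0.0, 0.0],  # into s0
    [0.5, 0.0, 0.0, 0.0, 0.0],  # into a
    [0.5, 0.0, 0.0, 0.0, 0.0],  # into b
    [0.0, 0.5, 0.0, 0.0, 0.0],  # into x
    [0.0, 0.5, 1.0, 0.0, 0.0],  # into y
]


def accumulate_visits(transition, initial, tol_epsilon=1e-7, max_iters=1000):
    """Accumulate sum_t transition^t @ initial until the update is negligible."""
    current = list(initial)
    accumulated = list(initial)
    for _ in range(max_iters):
        current = matvec(transition, current)
        # L1 norm of the newly propagated mass, i.e. the change in the sum.
        if sum(abs(v) for v in current) < tol_epsilon:
            break
        accumulated = [a + c for a, c in zip(accumulated, current)]
    return accumulated


initial = [1.0, 0.0, 0.0, 0.0, 0.0]  # all mass on the initial state s0
visits = accumulate_visits(TRANSITION, initial)
terminal_dist = {s: visits[STATES.index(s)] for s in TERMINAL}
print(terminal_dist)
```

Because the toy graph is acyclic, the propagated mass vanishes after a finite number of steps, and the accumulated mass on terminal states sums to one.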
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
ExactDistributionMetricsState
|
Current metrics state containing the true distribution. |
required |
rng_key
|
PRNGKey
|
JAX PRNG key for any stochasticity (currently unused). |
required |
args
|
ProcessArgs
|
Processing arguments containing policy and environment parameters. |
required |
Returns:
| Name | Type | Description |
|---|---|---|
ExactDistributionMetricsState |
ExactDistributionMetricsState
|
Updated metrics state with the computed exact distribution. |
Source code in gfnx/metrics/exact_distribution.py
ExactDistributionMetricsState
Bases: MetricsState
State for exact distribution metrics.
Holds the ground-truth terminal distribution from the environment and the exact terminal distribution computed under a given forward policy.
Attributes:
| Name | Type | Description |
|---|---|---|
true_distribution |
Array
|
Ground-truth terminal distribution provided by the environment |
exact_distribution |
Array
|
Exact terminal distribution computed by |
Source code in gfnx/metrics/exact_distribution.py
MeanRewardMetricsModule
Bases: BaseMetricsModule
Metric module for computing mean reward and its deviation from ground truth.
This module tracks the empirical mean reward from collected samples and compares it against the known ground truth mean reward of the environment. It computes both absolute and relative deviations to assess how well the policy's reward distribution matches the true environment distribution.
Attributes:
| Name | Type | Description |
|---|---|---|
env |
Environment instance that must support tractable mean reward computation |
|
gt_mean_reward |
Ground truth mean reward from the environment |
Source code in gfnx/metrics/reward_delta.py
UpdateArgs
Bases: BaseUpdateArgs
Arguments for updating the MeanRewardMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
log_rewards |
Array
|
Array of log-reward values to add to the running statistics. These are expected to be in log space and will be converted to linear space for mean computation. |
Source code in gfnx/metrics/reward_delta.py
__init__(env, env_params)
Initialize the mean reward metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
env
|
TEnvironment
|
Environment instance that must have tractable mean reward computation (env.is_mean_reward_tractable must be True) |
required |
env_params
|
TEnvParams
|
Environment parameters needed to compute the ground truth mean reward |
required |
Raises:
| Type | Description |
|---|---|
ValueError
|
If the environment does not support tractable mean reward computation |
Source code in gfnx/metrics/reward_delta.py
get(metrics_state)
Get the computed mean reward metrics from the current state.
Computes the empirical mean reward from accumulated samples and calculates both absolute and relative deviations from the ground truth mean reward.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MeanRewardMetricsState
|
Current metric state containing accumulated reward statistics |
required |
Returns:
| Type | Description |
|---|---|
dict[str, float]
|
Dictionary containing the computed reward metrics: 'mean_reward' (empirical mean reward from collected samples), 'reward_delta' (absolute difference between empirical and ground truth mean), and 'rel_reward_delta' (relative difference, normalized by ground truth mean). |
Source code in gfnx/metrics/reward_delta.py
init(rng_key, args=None)
Initialize the mean reward metric state.
Creates initial state with zero cumulative reward and zero sample count. The actual mean reward computation will be performed incrementally as new reward samples are added via update().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
JAX PRNG key for any random initialization (currently unused) |
required |
args
|
InitArgs | None
|
EmptyInitArgs (no additional initialization parameters needed) |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
MeanRewardMetricsState |
MeanRewardMetricsState
|
Initialized state with zero accumulated reward and count |
Source code in gfnx/metrics/reward_delta.py
process(metrics_state, rng_key, args=None)
Process the metric state for final computation (no-op for mean reward metrics).
This method performs any final processing needed before metric computation. For mean reward metrics, no additional processing is required as the statistics are maintained incrementally during updates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MeanRewardMetricsState
|
Current metric state with accumulated reward statistics |
required |
rng_key
|
PRNGKey
|
JAX PRNG key for any random operations (currently unused) |
required |
args
|
ProcessArgs | None
|
EmptyProcessArgs (no additional processing parameters needed) |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
MeanRewardMetricsState |
MeanRewardMetricsState
|
Unchanged metric state ready for get() call |
Source code in gfnx/metrics/reward_delta.py
update(metrics_state, rng_key, args)
Update the metric state with new reward samples.
Adds new reward samples to the running statistics by converting log-rewards to linear space and accumulating them in the sum. Also updates the sample count.
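As a rough illustration of the running statistics, the sketch below accumulates a sum and a count from log-rewards and then derives the three reported values. It uses a frozen stdlib dataclass as a stand-in for the real state class; `RunningRewardState`, `update_state`, and `summarize` are hypothetical names, and taking the absolute value for `reward_delta` is an assumption based on the wording "absolute difference".

```python
import math
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunningRewardState:
    """Hypothetical stand-in for the metric state: data only, no logic."""
    sum_reward: float = 0.0
    num: int = 0


def update_state(state, log_rewards):
    """Convert log-rewards to linear space and fold them into the running sums."""
    rewards = [math.exp(lr) for lr in log_rewards]
    return replace(
        state,
        sum_reward=state.sum_reward + sum(rewards),
        num=state.num + len(rewards),
    )


def summarize(state, gt_mean_reward):
    """Derive the mean and its deviations; assumes at least one sample was seen."""
    mean_reward = state.sum_reward / state.num
    reward_delta = abs(mean_reward - gt_mean_reward)  # assumed: absolute value
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        "rel_reward_delta": reward_delta / gt_mean_reward,
    }


state = RunningRewardState()
state = update_state(state, [math.log(1.0), math.log(3.0)])
state = update_state(state, [math.log(2.0)])
summary = summarize(state, gt_mean_reward=2.5)
print(summary)
```

Returning a new state from each update, instead of mutating in place, matches the functional style of the method signatures in this reference.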
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MeanRewardMetricsState
|
Current metric state with accumulated reward sum and count |
required |
rng_key
|
PRNGKey
|
JAX PRNG key for any random operations (currently unused) |
required |
args
|
UpdateArgs
|
UpdateArgs object containing the new reward samples |
required |
Returns:
| Name | Type | Description |
|---|---|---|
MeanRewardMetricsState |
MeanRewardMetricsState
|
Updated state with new rewards incorporated into the running statistics |
Source code in gfnx/metrics/reward_delta.py
MeanRewardMetricsState
Bases: MetricsState
State for accumulating mean reward computation.
This state container tracks the cumulative sum of rewards and the count of samples to compute running mean reward statistics. It enables incremental computation of mean reward without storing all individual reward values.
Attributes:
| Name | Type | Description |
|---|---|---|
sum_reward |
float
|
Cumulative sum of all rewards (in linear space, not log space) |
num |
int
|
Total number of reward samples processed |
Source code in gfnx/metrics/reward_delta.py
MetricsState
Base state for metric computation.
This is an abstract base class that serves as a pure data container for metric states. Metric states are designed to only store the necessary data and intermediate values required for computing metrics during training and evaluation. They should not contain any methods or processing logic.
All processing, computation, and transformation logic should be implemented in the corresponding BaseMetricsModule subclasses, not in the state objects themselves.
Subclasses should define specific data fields (typically using @chex.dataclass) needed for their metric computation requirements, but no methods.
Source code in gfnx/metrics/base.py
MultiMetricsModule
Bases: BaseMetricsModule
Module for handling multiple metrics in a unified way.
This class implements the BaseMetricsModule interface to manage multiple individual metric modules simultaneously. It provides a convenient way to compute multiple metrics together while maintaining the same interface as single metric modules.
The MultiMetricsModule coordinates the lifecycle of all contained metrics, calling their respective init, update, process, and get methods in sequence. Final metrics are returned with prefixed names to avoid conflicts between different metric modules.
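The coordination pattern can be sketched with toy classes. Nothing below is gfnx code: `CountModule`, `MaxModule`, and `ToyMultiMetrics` are hypothetical, and the signatures are simplified (no `rng_key`, plain dictionaries instead of InitArgs, UpdateArgs, and ProcessArgs objects).

```python
class CountModule:
    """Hypothetical metric: counts how many samples were seen."""

    def init(self):
        return 0

    def update(self, state, samples):
        return state + len(samples)

    def process(self, state):
        return state


class MaxModule:
    """Hypothetical metric: tracks the largest sample seen."""

    def init(self):
        return float("-inf")

    def update(self, state, samples):
        return max([state, *samples])

    def process(self, state):
        return state


class ToyMultiMetrics:
    """Fans each lifecycle call out to every contained module by name."""

    def __init__(self, metrics):
        self.metrics = dict(metrics)

    def init(self):
        return {name: m.init() for name, m in self.metrics.items()}

    def update(self, states, args):
        # args maps each metric name to the data for that metric.
        return {name: m.update(states[name], args[name])
                for name, m in self.metrics.items()}

    def process(self, states):
        return {name: m.process(states[name])
                for name, m in self.metrics.items()}


multi = ToyMultiMetrics({"count": CountModule(), "max": MaxModule()})
states = multi.init()
states = multi.update(states, {"count": [1.0, 2.0], "max": [1.0, 2.0]})
states = multi.update(states, {"count": [5.0], "max": [5.0]})
states = multi.process(states)
print(states)
```

Keeping the per-metric states in a dictionary keyed by name is what lets each module stay unaware of the others.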
Attributes:
| Name | Type | Description |
|---|---|---|
metrics |
Dictionary of metric modules indexed by their names |
|
_supported_metrics |
Internal mapping of metric names to their get methods |
Source code in gfnx/metrics/base.py
InitArgs
Bases: BaseInitArgs
Arguments for initializing the MultiMetricsModule.
Source code in gfnx/metrics/base.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the MultiMetricsModule.
Source code in gfnx/metrics/base.py
UpdateArgs
Bases: BaseUpdateArgs
Arguments for updating the MultiMetricsModule.
Source code in gfnx/metrics/base.py
__init__(metrics)
Initialize the MultiMetricsModule with a collection of metrics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics
|
dict[str, BaseMetricsModule]
|
Dictionary mapping metric names to their corresponding BaseMetricsModule instances. Names should be unique and descriptive as they will be used as prefixes in the final metric output. |
required |
Source code in gfnx/metrics/base.py
get(metrics_state)
Get computed metrics from all contained modules.
Collects metrics from all metric modules and returns them with prefixed names to avoid conflicts. Each metric's results are prefixed with its module name followed by a forward slash.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MultiMetricsState
|
Current MultiMetricsState containing all processed states |
required |
Returns:
| Type | Description |
|---|---|
dict[str, Any]
|
Dictionary of all computed metrics with prefixed names. Keys are in the format "{metric_name}/{original_key}". |
Example
If a metric module named "accuracy" returns {"score": 0.95}, the result will include {"accuracy/score": 0.95}.
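The key-prefixing rule in the example above amounts to a nested dictionary comprehension. The helper name `prefix_metrics` is hypothetical and the input values are made up for illustration.

```python
def prefix_metrics(per_module_results):
    """Flatten {module: {key: value}} into {"module/key": value}."""
    return {
        f"{name}/{key}": value
        for name, results in per_module_results.items()
        for key, value in results.items()
    }


merged = prefix_metrics({
    "accuracy": {"score": 0.95},
    "reward": {"mean_reward": 2.0, "reward_delta": 0.5},
})
print(merged)
```

Two modules can therefore report the same inner key without overwriting each other, as long as their module names differ.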
Source code in gfnx/metrics/base.py
init(rng_key, args=None)
Initialize all contained metrics.
Calls the init method on each metric module and collects their individual states into a MultiMetricsState container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
JAX PRNG key passed to each metric's init method |
required |
args
|
InitArgs | None
|
Optional InitArgs object mapping metric names to init args |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
MultiMetricsState |
MultiMetricsState
|
Container holding all initialized metric states |
Source code in gfnx/metrics/base.py
process(metrics_state, rng_key, args=None)
Process all metric states to compute metrics and perform final transformations.
Calls the process method on each metric module with their corresponding state to compute metrics and apply final transformations during the evaluation period, preparing them for result retrieval.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MultiMetricsState
|
Current MultiMetricsState containing all metric states |
required |
rng_key
|
PRNGKey
|
JAX PRNG key passed to each metric's process method |
required |
args
|
ProcessArgs | None
|
Optional ProcessArgs object mapping metric names to process params |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
MultiMetricsState |
MultiMetricsState
|
Container with all processed metric states ready for get() |
Source code in gfnx/metrics/base.py
update(metrics_state, rng_key, args=None)
Update all contained metrics with new data.
Calls the update method on each metric module with their corresponding state and the provided data, then returns a new MultiMetricsState with all updated states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics_state
|
MultiMetricsState
|
Current MultiMetricsState containing all metric states |
required |
rng_key
|
PRNGKey
|
JAX PRNG key passed to each metric's update method |
required |
args
|
UpdateArgs | None
|
Optional UpdateArgs object mapping metric names to update data |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
MultiMetricsState |
MultiMetricsState
|
Container with all updated metric states |
Source code in gfnx/metrics/base.py
MultiMetricsState
Bases: MetricsState
State container for multiple metrics.
This class extends MetricsState to hold the states of multiple individual metric modules. It provides a unified interface for managing the states of different metrics that need to be computed together.
Attributes:
| Name | Type | Description |
|---|---|---|
states |
dict[str, MetricsState]
|
Dictionary mapping metric names to their individual MetricsState objects. Each key represents a unique metric identifier, and each value is the corresponding metric's state object. |
Source code in gfnx/metrics/base.py
OnPolicyCorrelationMetricsModule
Bases: BaseCorrelationMetricsModule
On-policy correlation metric module for GFlowNet evaluation.
This metric module computes correlation metrics between transformed model-predicted marginal probabilities and transformed log true rewards. During evaluation, it generates fresh terminal states by performing forward rollouts with the current policy, then computes backward trajectory log-ratios.
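To make the idea concrete, the sketch below computes a Pearson correlation between estimated log-probabilities and log-rewards. Which correlation coefficients the module actually reports is not stated in this section, so Pearson is an assumption, and the `pearson` helper and all numbers are hypothetical.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# A perfectly trained sampler satisfies log P(x) = log R(x) - log Z,
# so the two quantities differ only by a constant offset.
log_probs = [-2.0, -1.0, -0.5, -3.0]
log_rewards = [lp + 1.5 for lp in log_probs]
r_perfect = pearson(log_probs, log_rewards)

# A mismatched sampler misestimates the probability of the last state.
noisy_log_probs = [-2.0, -1.0, -0.5, -1.5]
r_noisy = pearson(noisy_log_probs, log_rewards)
print(f"perfect: {r_perfect:.4f}, noisy: {r_noisy:.4f}")
```

A constant offset such as log Z does not affect the coefficient, which is why the comparison works without knowing the normalizing constant.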
Attributes:
| Name | Type | Description |
|---|---|---|
n_terminal_states |
Number of terminal states to sample for evaluation |
|
domain_size |
Number of points to compute correlations on after marginalization |
|
fwd_policy_fn |
Forward policy function for generating trajectories |
Source code in gfnx/metrics/correlation.py
InitArgs
Bases: BaseInitArgs
Arguments for initializing the on-policy correlation metric module.
Attributes:
| Name | Type | Description |
|---|---|---|
env_params |
TEnvParams
|
Environment parameters needed to create dummy terminal states during initialization and for environment operations. |
Source code in gfnx/metrics/correlation.py
__init__(n_rounds, n_terminal_states, batch_size, fwd_policy_fn, bwd_policy_fn, env, domain_size=None, transform_fn=lambda x, y: y)
Initialize the on-policy correlation metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
n_rounds
|
int
|
Number of sampling rounds for statistical stability |
required |
n_terminal_states
|
int
|
Number of terminal states to generate and evaluate |
required |
batch_size
|
int
|
Batch size for efficient processing of terminal states |
required |
fwd_policy_fn
|
TPolicyFn
|
Forward policy function for generating trajectories |
required |
bwd_policy_fn
|
TPolicyFn
|
Backward policy function for computing log-ratios |
required |
env
|
TEnvironment
|
Environment for trajectory generation and reward computation |
required |
domain_size
|
int | None
|
Number of points on which correlations are computed; it corresponds to the output size of transform_fn. Defaults to the number of terminal states, which corresponds to the identity transform. |
None
|
transform_fn
|
Callable[[TEnvState, ndarray], ndarray]
|
Function to marginalize distributions from terminal states to an arbitrary domain size for correlation computation |
lambda x, y: y
|
Source code in gfnx/metrics/correlation.py
init(rng_key, args)
Initialize the on-policy correlation metric state.
Creates an initial state with placeholder data structures. The actual terminal states and rewards will be generated during the process phase using on-policy sampling.
The arrays are initialized with domain_size, which represents the size of the transformed probability space after marginalization. This allows the correlation metrics to be computed on an arbitrary domain rather than just the terminal states.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rng_key
|
PRNGKey
|
Random number generator key (unused in initialization) |
required |
args
|
InitArgs
|
InitArgs object containing environment parameters |
required |
Returns:
| Name | Type | Description |
|---|---|---|
CorrelationMetricsState |
CorrelationMetricsState
|
Initialized state with dummy terminal states and zero-initialized arrays for rewards and log-ratios with shape (domain_size,) representing the transformed probability space |
Source code in gfnx/metrics/correlation.py
SWMeanRewardMetricsState
Bases: MetricsState
State for mean reward metrics with sliding window buffer.
This state container maintains a sliding window buffer of recent rewards to compute mean reward statistics over a fixed-size window of the most recent samples.
Attributes:
| Name | Type | Description |
|---|---|---|
reward_buffer |
Any
|
Flashbax buffer storing the most recent reward samples in a circular buffer with configurable maximum size |
Source code in gfnx/metrics/reward_delta.py
SWMeanRewardSWMetricsModule
Bases: BaseMetricsModule
Sliding window mean reward metric module for recent performance tracking.
This module computes mean reward statistics using a sliding window approach, maintaining only the most recent reward samples in a circular buffer.
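The sliding-window behaviour can be illustrated with a bounded `collections.deque` from the standard library. This is a stand-in chosen for brevity: the module itself is documented as using a Flashbax buffer with an explicit state object, and the class name `SlidingWindowMean` is hypothetical.

```python
from collections import deque


class SlidingWindowMean:
    """Keeps only the most recent `buffer_size` rewards and averages them."""

    def __init__(self, buffer_size):
        if buffer_size <= 0:
            raise ValueError("buffer_size must be a positive integer")
        # A deque with maxlen silently drops the oldest entries when full.
        self.window = deque(maxlen=buffer_size)

    def update(self, rewards):
        self.window.extend(rewards)

    def mean(self):
        return sum(self.window) / len(self.window)


tracker = SlidingWindowMean(buffer_size=3)
tracker.update([1.0, 2.0, 3.0])
mean_before = tracker.mean()   # window holds 1.0, 2.0, 3.0
tracker.update([10.0])
mean_after = tracker.mean()    # window holds 2.0, 3.0, 10.0
print(mean_before, mean_after)
```

Compared with the cumulative mean, the windowed mean forgets early samples, so it reflects recent policy behaviour rather than the whole training history.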
Attributes:
| Name | Type | Description |
|---|---|---|
| env | | Environment instance that must support tractable mean reward computation |
| gt_mean_reward | | Ground truth mean reward from the environment |
| buffer_size | | Maximum number of rewards to keep in the sliding window |
| buffer_module | | Flashbax buffer module for managing the sliding window |
Source code in gfnx/metrics/reward_delta.py
UpdateArgs
Bases: BaseUpdateArgs
Arguments for updating the SWMeanRewardSWMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| rewards | Array | Array of reward values to add to the sliding window buffer. These rewards will replace the oldest entries when the buffer is full. |
Source code in gfnx/metrics/reward_delta.py
__init__(env, env_params, buffer_size)
Initialize the sliding window mean reward metric module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | TEnvironment | Environment instance that must have tractable mean reward computation (env.is_mean_reward_tractable must be True) | required |
| env_params | TEnvParams | Environment parameters needed to compute the ground truth mean reward | required |
| buffer_size | int | Maximum number of reward samples to keep in the sliding window. Must be a positive integer. | required |

Raises:
| Type | Description |
|---|---|
| ValueError | If the environment does not support tractable mean reward computation |
Source code in gfnx/metrics/reward_delta.py
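The tractability requirement can be illustrated with a minimal sketch. `ToyEnv` and `require_tractable_mean_reward` are hypothetical names introduced here for illustration; only the `is_mean_reward_tractable` attribute and the `ValueError` come from the description above, and the real constructor may perform the check differently.

```python
class ToyEnv:
    """Hypothetical stand-in for an environment (not a gfnx class)."""

    def __init__(self, tractable: bool):
        # Attribute name taken from the parameter description above.
        self.is_mean_reward_tractable = tractable


def require_tractable_mean_reward(env) -> None:
    """Raise ValueError when the ground truth mean reward is unavailable."""
    if not env.is_mean_reward_tractable:
        raise ValueError(
            "Environment does not support tractable mean reward computation"
        )


require_tractable_mean_reward(ToyEnv(tractable=True))  # passes silently
```

Failing early in the constructor means a misconfigured metric is reported before any training step runs, rather than at the first `get()` call.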
get(metrics_state)
Get the computed sliding window mean reward metrics from the current state.
Computes the mean reward from samples in the sliding window buffer and calculates both absolute and relative deviations from the ground truth mean reward. Only valid (non-empty) buffer entries are included in the computation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | SWMeanRewardMetricsState | Current metric state containing the sliding window buffer | required |

Returns:
| Type | Description |
|---|---|
| dict[str, float] | Dict[str, float]: Dictionary containing computed reward metrics: 'mean_reward' (mean reward from samples in the sliding window), 'reward_delta' (absolute difference between window mean and ground truth), 'rel_reward_delta' (relative difference, normalized by ground truth mean) |
Source code in gfnx/metrics/reward_delta.py
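As a rough illustration of the three returned values, the sketch below computes them in plain Python from a list of buffer slots and a validity mask. The function name, the mask representation, and the exact form of the relative delta (absolute delta divided by the ground truth mean) are assumptions made for this example; the module itself reads from a Flashbax buffer.

```python
def window_reward_metrics(rewards, valid, gt_mean_reward):
    """Sketch of the sliding-window reward metrics (hypothetical helper).

    rewards: buffer slots, including slots that were never written
    valid:   True for slots that hold a real sample
    Assumes at least one valid slot and a nonzero ground truth mean.
    """
    kept = [r for r, ok in zip(rewards, valid) if ok]
    mean_reward = sum(kept) / len(kept)
    reward_delta = abs(mean_reward - gt_mean_reward)
    return {
        "mean_reward": mean_reward,
        "reward_delta": reward_delta,
        # Assumed normalization: divide by the ground truth mean.
        "rel_reward_delta": reward_delta / gt_mean_reward,
    }


metrics = window_reward_metrics(
    rewards=[1.0, 3.0, 0.0, 0.0],
    valid=[True, True, False, False],  # buffer is only half full
    gt_mean_reward=4.0,
)
print(metrics)
# {'mean_reward': 2.0, 'reward_delta': 2.0, 'rel_reward_delta': 0.5}
```

Masking out unwritten slots matters early in training: treating them as zero rewards would bias the window mean downward until the buffer fills.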
init(rng_key, args=None)
Initialize the sliding window mean reward metric state.
Creates initial state with an empty sliding window buffer ready to accept reward samples. The buffer will automatically manage the sliding window behavior as new samples are added.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs \| None | EmptyInitArgs (no additional initialization parameters needed) | None |

Returns:
| Name | Type | Description |
|---|---|---|
| SWMeanRewardMetricsState | SWMeanRewardMetricsState | Initialized state with empty sliding window buffer |
Source code in gfnx/metrics/reward_delta.py
process(metrics_state, rng_key, args=None)
Process the metric state for final computation (no-op for sliding window metrics).
This method performs any final processing needed before metric computation. For sliding window mean reward metrics, no additional processing is required as the statistics are computed directly from the buffer contents.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | MeanRewardMetricsState | Current metric state with sliding window buffer | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | ProcessArgs \| None | EmptyProcessArgs (no additional processing parameters needed) | None |

Returns:
| Name | Type | Description |
|---|---|---|
| MeanRewardMetricsState | MeanRewardMetricsState | Unchanged metric state ready for get() call |
Source code in gfnx/metrics/reward_delta.py
update(metrics_state, rng_key, args)
Update the metric state with new reward samples in the sliding window.
Adds new reward samples to the sliding window buffer. When the buffer is full, the oldest samples are automatically replaced with the new ones, maintaining the fixed window size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | SWMeanRewardMetricsState | Current metric state containing the sliding window buffer | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs | UpdateArgs object containing the new reward samples | required |

Returns:
| Name | Type | Description |
|---|---|---|
| SWMeanRewardMetricsState | SWMeanRewardMetricsState | Updated state with new rewards added to the buffer |
Source code in gfnx/metrics/reward_delta.py
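The eviction behaviour described for `update` can be sketched with a bounded `collections.deque` standing in for the Flashbax buffer. This is only an analogy under that assumption: the real buffer is a JAX data structure updated functionally, and `update_window` is a hypothetical helper.

```python
from collections import deque


def update_window(window: deque, rewards: list) -> deque:
    """Append new rewards; a bounded deque drops its oldest entries."""
    window.extend(rewards)
    return window


window = deque(maxlen=4)                # plays the role of buffer_size = 4
update_window(window, [0.1, 0.2, 0.3])  # buffer not yet full
update_window(window, [0.4, 0.5])       # 0.1 is evicted to make room
print(list(window))
# [0.2, 0.3, 0.4, 0.5]
```

The window never grows past `maxlen`, so the mean computed from it always reflects only the most recent samples.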
TestCorrelationMetricsModule
Bases: BaseCorrelationMetricsModule
Fixed test set correlation metric module for GFlowNet evaluation.
This metric module computes correlation metrics between transformed model predictions and transformed true rewards using a fixed set of test terminal states. Unlike the on-policy variant, this module evaluates the model's performance on the same set of terminal states across different training iterations, providing consistent evaluation points for tracking training progress.
Source code in gfnx/metrics/correlation.py
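To make the idea of a fixed evaluation set concrete, the sketch below scores two hypothetical sets of model log-ratios against the same test log-rewards. The Pearson coefficient is used purely as an assumption, since the text above does not say which correlation statistics the module reports; the values and the `pearson` helper are illustrative and not part of gfnx.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


# Computed once from the fixed test set (illustrative values).
test_log_rewards = [-2.0, -1.0, 0.0, 1.0]

# Model predictions on the same states at two points in training.
log_ratios_early = [1.0, -1.0, 0.5, -0.5]
log_ratios_late = [-1.9, -1.1, 0.1, 0.9]

corr_early = pearson(test_log_rewards, log_ratios_early)
corr_late = pearson(test_log_rewards, log_ratios_late)
```

Because the states never change, any movement in the correlation between iterations comes from the model rather than from sampling noise in the evaluation set.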
InitArgs
Bases: BaseInitArgs
Arguments for initializing the test correlation metric module.
Attributes:
| Name | Type | Description |
|---|---|---|
| env_params | TEnvParams | Environment parameters needed for computing log-rewards from the test set and for environment operations. |
| test_set | TEnvState | Fixed set of terminal states that will be used consistently for correlation evaluation across training iterations. |
Source code in gfnx/metrics/correlation.py
init(rng_key, args)
Initialize metric state with a fixed test set.
Sets up the metric state using a predefined set of terminal states. The log-rewards are computed immediately from the test set, while log-ratios are initialized to zero and will be computed during processing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | Random number generator key (unused in initialization) | required |
| args | InitArgs | InitArgs object containing environment parameters and test set | required |

Returns:
| Name | Type | Description |
|---|---|---|
| CorrelationMetricsState | CorrelationMetricsState | Initialized state with fixed test terminal states, their computed transformed log-rewards, and zero-initialized log-ratios with shape matching the transformed log-rewards |
Source code in gfnx/metrics/correlation.py
TopKMetricsModule
Bases: BaseMetricsModule
Metric module for computing top-K reward and diversity statistics.
This module evaluates policy performance by sampling trajectories and computing statistics for the top-K highest-reward samples. It measures both the quality (reward) and diversity of the best samples, providing insights into the policy's ability to find high-reward solutions while maintaining diversity.
Attributes:
| Name | Type | Description |
|---|---|---|
| env | | Environment instance for trajectory generation and reward computation |
| fwd_policy_fn | | Forward policy function for generating trajectories |
| num_traj | | Total number of trajectories to sample for evaluation |
| batch_size | | Batch size for processing trajectories (for memory management) |
| top_k | | List of K values for which to compute top-K statistics |
| distance_fn | | Function to compute distance between states for diversity measurement |
Source code in gfnx/metrics/top_k.py
ProcessArgs
Bases: BaseProcessArgs
Arguments for processing the TopKMetricsModule.
Attributes:
| Name | Type | Description |
|---|---|---|
| policy_params | TPolicyParams | Current policy parameters used for forward rollouts to generate trajectory samples for top-K evaluation. |
| env_params | Any | Environment parameters required for trajectory generation and reward computation. |
Source code in gfnx/metrics/top_k.py
__init__(env, fwd_policy_fn, num_traj, batch_size, top_k=None, distance_fn=None)
Initialize the top-K metrics module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| env | TEnvironment | Environment instance for trajectory generation and reward computation | required |
| fwd_policy_fn | TPolicyFn | Forward policy function for generating trajectory samples | required |
| num_traj | int | Total number of trajectories to sample during evaluation. Must be >= max(top_k) for meaningful statistics. | required |
| batch_size | int | Batch size for processing trajectories (used for memory management) | required |
| top_k | list[int] \| None | List of K values for which to compute top-K statistics. Default is [10, 50, 100]. | None |
| distance_fn | Callable[[TEnvState, TEnvState], float] \| None | Function that computes distance between two environment states for diversity measurement. Must return a scalar distance value. | None |
Source code in gfnx/metrics/top_k.py
get(metrics_state)
Get the computed top-K metrics from the current state.
Extracts the computed top-K reward and diversity statistics and formats them into a dictionary with descriptive keys for each K value specified during module initialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TopKMetricsState | Current metric state containing computed top-K statistics | required |

Returns:
| Type | Description |
|---|---|
| dict | Dict[str, float]: Dictionary containing computed top-K metrics with keys: 'top_{k}_reward' (mean reward of the top-K samples for each K) and 'top_{k}_diversity' (mean pairwise distance among top-K samples for each K) |
Example
If initialized with top_k=[10, 50], might return: { "top_10_reward": 0.85, "top_10_diversity": 0.42, "top_50_reward": 0.78, "top_50_diversity": 0.51 }
Source code in gfnx/metrics/top_k.py
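The key naming can be sketched as a small formatting step over the per-K arrays held in the state. `format_top_k` is a hypothetical helper written for this example; the key pattern follows the example output shown above.

```python
def format_top_k(top_k, top_k_rewards, top_k_diversity):
    """Turn per-K statistics into a flat dict with descriptive keys."""
    metrics = {}
    for k, reward, diversity in zip(top_k, top_k_rewards, top_k_diversity):
        metrics[f"top_{k}_reward"] = float(reward)
        metrics[f"top_{k}_diversity"] = float(diversity)
    return metrics


result = format_top_k(
    top_k=[10, 50],
    top_k_rewards=[0.85, 0.78],
    top_k_diversity=[0.42, 0.51],
)
print(result)
# {'top_10_reward': 0.85, 'top_10_diversity': 0.42,
#  'top_50_reward': 0.78, 'top_50_diversity': 0.51}
```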
init(rng_key, args=None)
Initialize the top-K metrics state.
Creates initial state with zero-initialized arrays for top-K rewards and diversity statistics. The actual values will be computed during the process phase.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | PRNGKey | JAX PRNG key for any random initialization (currently unused) | required |
| args | InitArgs \| None | EmptyInitArgs (no additional initialization parameters needed) | None |

Returns:
| Name | Type | Description |
|---|---|---|
| TopKMetricsState | TopKMetricsState | Initialized state with zero arrays for rewards and diversity |
Source code in gfnx/metrics/top_k.py
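A zero-initialized state with one slot per K value might look like the sketch below. `TopKStateSketch` and `init_state` are hypothetical names, and plain Python lists stand in for the JAX arrays the real state holds.

```python
from typing import NamedTuple


class TopKStateSketch(NamedTuple):
    """Illustrative stand-in for TopKMetricsState (not the gfnx class)."""

    top_k_rewards: list
    top_k_diversity: list


def init_state(top_k):
    """One zero entry per K value; real values are filled in by process()."""
    return TopKStateSketch(
        top_k_rewards=[0.0] * len(top_k),
        top_k_diversity=[0.0] * len(top_k),
    )


state = init_state([10, 50, 100])
print(state)
# TopKStateSketch(top_k_rewards=[0.0, 0.0, 0.0], top_k_diversity=[0.0, 0.0, 0.0])
```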
process(metrics_state, rng_key, args)
Process the metric state to compute top-K reward and diversity statistics.
This method performs the core computation for top-K metrics:
1. Samples trajectories using the current policy
2. Computes rewards for all terminal states
3. Identifies the top-K highest-reward samples for each K value
4. Computes mean rewards and diversity measures for each top-K set
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TopKMetricsState | Current metric state (largely ignored, will be replaced) | required |
| rng_key | PRNGKey | JAX PRNG key for trajectory sampling | required |
| args | ProcessArgs | ProcessArgs object containing policy parameters and environment parameters | required |

Returns:
| Name | Type | Description |
|---|---|---|
| TopKMetricsState | TopKMetricsState | Updated state with computed top-K rewards and diversity statistics |
Source code in gfnx/metrics/top_k.py
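The selection and scoring steps of `process` can be sketched in plain Python, assuming the terminal states and their rewards have already been sampled. `top_k_stats` and `hamming` are hypothetical helpers, and averaging over distinct unordered pairs is an assumption about how the diversity term is defined; the actual module works on batched JAX arrays.

```python
from itertools import combinations


def top_k_stats(states, rewards, k, distance_fn):
    """Mean reward and mean pairwise distance of the k best samples."""
    # Indices of the k highest rewards, best first.
    order = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)[:k]
    top_rewards = [rewards[i] for i in order]
    top_states = [states[i] for i in order]

    mean_reward = sum(top_rewards) / len(top_rewards)

    # Assumed diversity definition: average over distinct unordered pairs.
    pairs = list(combinations(top_states, 2))
    diversity = sum(distance_fn(a, b) for a, b in pairs) / len(pairs) if pairs else 0.0
    return mean_reward, diversity


def hamming(a, b):
    """Number of coordinates in which two equal-length states differ."""
    return sum(x != y for x, y in zip(a, b))


states = [(0, 0), (1, 1), (3, 3), (5, 0)]
rewards = [0.1, 0.9, 0.5, 0.7]
mean_reward, diversity = top_k_stats(states, rewards, k=3, distance_fn=hamming)
```

Here the three best samples are `(1, 1)`, `(5, 0)` and `(3, 3)`, so the lowest-reward state `(0, 0)` does not contribute to either statistic.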
update(metrics_state, rng_key, args=None)
Update the metric state with new data (no-op for top-K metrics).
Top-K metrics are computed entirely during the process phase by sampling fresh trajectories, so no incremental updates are needed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| metrics_state | TopKMetricsState | Current metric state (unchanged) | required |
| rng_key | PRNGKey | JAX PRNG key for any random operations (currently unused) | required |
| args | UpdateArgs \| None | EmptyUpdateArgs (no update parameters needed) | None |

Returns:
| Name | Type | Description |
|---|---|---|
| TopKMetricsState | TopKMetricsState | Unchanged metric state |
Source code in gfnx/metrics/top_k.py
TopKMetricsState
Bases: MetricsState
State for Top-K reward and diversity metrics.
This state container stores the computed top-K reward and diversity statistics for different values of K. It maintains arrays where each element corresponds to a different K value specified during module initialization.
Attributes:
| Name | Type | Description |
|---|---|---|
| top_k_rewards | Array | Array of mean rewards for the top-K samples for each K value. Shape: (len(top_k),) where each entry is the mean reward of the top-K samples. |
| top_k_diversity | Array | Array of diversity measures for the top-K samples for each K value. Shape: (len(top_k),) where each entry is the average pairwise distance among top-K samples. |
Source code in gfnx/metrics/top_k.py