Skip to content

Reward API Reference

Priors

BaseDAGPrior

Source code in gfnx/reward/dag_prior.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class BaseDAGPrior:
    def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> TRewardParams:
        """Initialize the prior. Default implementation returns None.

        Args:
        - rng_key: chex.PRNGKey, random key
        - dummy_state: DAGEnvState, shape [1, ...], a dummy state
        """
        return None

    def log_prob(self, state: DAGEnvState, env_params: DAGEnvParams) -> TLogReward:
        """Computes log P(G).

        Args:
        - state: DAGEnvState, shape [B, ...], batch of states
        - env_params: DAGEnvParams, params of environment,
          always includes reward params

        Returns:
        - TLogReward, shape [B], batch of log P(G)
        """
        raise NotImplementedError

    def delta_score(
        self,
        state: DAGEnvState,
        action: TAction,
        next_state: DAGEnvState,
        env_params: DAGEnvParams,
    ) -> TLogReward:
        """Computes log P(G') - log P(G), where G' is the result of adding
        the edge X_i -> X_j to G.

        Args:
        - state: DAGEnvState, shape [B, ...], batch of states
        - action: DAGEnvAction, shape [B], batch of actions
        - next_state: DAGEnvState, shape [B, ...], batch of next states
        - env_params: DAGEnvParams, params of environment,
          always includes reward params

        Returns:
        - TLogReward, shape [B], batch of log P(G') - log P(G)
        """
        return self.log_prob(next_state, env_params) - self.log_prob(state, env_params)

    @staticmethod
    def num_parents(state: DAGEnvState) -> chex.Array:
        return jnp.count_nonzero(state.adjacency_matrix, axis=1)

delta_score(state, action, next_state, env_params)

Computes log P(G') - log P(G), where G' is the result of adding the edge X_i -> X_j to G.

Args: - state: DAGEnvState, shape [B, ...], batch of states - action: DAGEnvAction, shape [B], batch of actions - next_state: DAGEnvState, shape [B, ...], batch of next states - env_params: DAGEnvParams, params of environment, always includes reward params

Returns: - TLogReward, shape [B], batch of log P(G') - log P(G)

Source code in gfnx/reward/dag_prior.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def delta_score(
    self,
    state: DAGEnvState,
    action: TAction,
    next_state: DAGEnvState,
    env_params: DAGEnvParams,
) -> TLogReward:
    """Computes log P(G') - log P(G), where G' is the result of adding
    the edge X_i -> X_j to G.

    Args:
    - state: DAGEnvState, shape [B, ...], batch of states
    - action: DAGEnvAction, shape [B], batch of actions
    - next_state: DAGEnvState, shape [B, ...], batch of next states
    - env_params: DAGEnvParams, params of environment,
      always includes reward params

    Returns:
    - TLogReward, shape [B], batch of log P(G') - log P(G)
    """
    return self.log_prob(next_state, env_params) - self.log_prob(state, env_params)

init(rng_key, dummy_state)

Initialize the prior. Default implementation returns None.

Args: - rng_key: chex.PRNGKey, random key - dummy_state: DAGEnvState, shape [1, ...], a dummy state

Source code in gfnx/reward/dag_prior.py
 9
10
11
12
13
14
15
16
def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> TRewardParams:
    """Initialize the prior. Default implementation returns None.

    Args:
    - rng_key: chex.PRNGKey, random key
    - dummy_state: DAGEnvState, shape [1, ...], a dummy state
    """
    return None

log_prob(state, env_params)

Computes log P(G).

Args: - state: DAGEnvState, shape [B, ...], batch of states - env_params: DAGEnvParams, params of environment, always includes reward params

Returns: - TLogReward, shape [B], batch of log P(G)

Source code in gfnx/reward/dag_prior.py
18
19
20
21
22
23
24
25
26
27
28
29
def log_prob(self, state: DAGEnvState, env_params: DAGEnvParams) -> TLogReward:
    """Computes log P(G).

    Args:
    - state: DAGEnvState, shape [B, ...], batch of states
    - env_params: DAGEnvParams, params of environment,
      always includes reward params

    Returns:
    - TLogReward, shape [B], batch of log P(G)
    """
    raise NotImplementedError

Likelihoods

BGeScore

Bases: BaseDAGLikelihood

Source code in gfnx/reward/dag_likelihood.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
class BGeScore(BaseDAGLikelihood):
    @chex.dataclass(frozen=True)
    class BGeScoreParams:
        r_matrix: chex.Array
        log_gamma_term: chex.Array

    def __init__(
        self,
        data: chex.Array,
        mean_obs: chex.Array | None = None,
        alpha_mu: float = 1.0,
        alpha_w: float | None = None,
    ):
        self.num_samples, self.num_variables = data.shape
        if mean_obs is None:
            mean_obs = jnp.zeros((self.num_variables,))
        if alpha_w is None:
            alpha_w = self.num_variables + 2.0

        self.mean_obs = mean_obs
        self.alpha_mu = alpha_mu
        self.alpha_w = alpha_w

        self.t = (self.alpha_mu * (self.alpha_w - self.num_variables - 1)) / (self.alpha_mu + 1)

        t_matrix = self.t * jnp.eye(self.num_variables)
        data_mean = jnp.mean(data, axis=0, keepdims=True)
        data_centered = data - data_mean

        self.r_matrix = (
            t_matrix
            + jnp.matmul(data_centered.T, data_centered)
            + ((self.num_samples * self.alpha_mu) / (self.num_samples + self.alpha_mu))
            * jnp.dot((data_mean - self.mean_obs).T, data_mean - self.mean_obs)
        )
        all_parents = jnp.arange(self.num_variables)
        self.log_gamma_term = (
            0.5 * (jnp.log(self.alpha_mu) - jnp.log(self.num_samples + self.alpha_mu))
            + gammaln(
                0.5 * (self.num_samples + self.alpha_w - self.num_variables + all_parents + 1)
            )
            - gammaln(0.5 * (self.alpha_w - self.num_variables + all_parents + 1))
            - 0.5 * self.num_samples * jnp.log(jnp.pi)
            + 0.5 * (self.alpha_w - self.num_variables + 2 * all_parents + 1) * jnp.log(self.t)
        )

    def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> BGeScoreParams:
        """
        Initialize the likelihood.
        Args:
        - rng_key: chex.PRNGKey, random key
        - dummy_state: DAGEnvState, shape [1, ...], a dummy state

        Returns:
        - Precomputed BGeScore parameters
        """
        return self.BGeScoreParams(r_matrix=self.r_matrix, log_gamma_term=self.log_gamma_term)

    def _local_score(
        self, variables: chex.Array, parents: chex.Array, likelihood_params: BGeScoreParams
    ) -> TLogReward:
        r_matrix = likelihood_params.r_matrix
        log_gamma_term = likelihood_params.log_gamma_term

        def _logdet(array: chex.Array, mask: chex.Array) -> chex.Array:
            mask = mask[:, None, :] * mask[:, :, None]
            array = mask * array + (1.0 - mask) * jnp.eye(self.num_variables)
            _, logdet = jnp.linalg.slogdet(array)
            return logdet

        num_parents = jnp.sum(parents, axis=1)  # (num_graphs,)
        arange = jnp.arange(parents.shape[0])
        parents_and_variable = parents.at[arange, variables].set(True)

        factor = self.num_samples + self.alpha_w - self.num_variables + num_parents

        log_term_r = 0.5 * factor * _logdet(r_matrix, parents) - 0.5 * (factor + 1) * _logdet(
            r_matrix, parents_and_variable
        )

        return log_gamma_term[num_parents] + log_term_r

init(rng_key, dummy_state)

Initialize the likelihood. Args: - rng_key: chex.PRNGKey, random key - dummy_state: DAGEnvState, shape [1, ...], a dummy state

Returns: - Precomputed BGeScore parameters

Source code in gfnx/reward/dag_likelihood.py
212
213
214
215
216
217
218
219
220
221
222
def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> BGeScoreParams:
    """
    Initialize the likelihood.
    Args:
    - rng_key: chex.PRNGKey, random key
    - dummy_state: DAGEnvState, shape [1, ...], a dummy state

    Returns:
    - Precomputed BGeScore parameters
    """
    return self.BGeScoreParams(r_matrix=self.r_matrix, log_gamma_term=self.log_gamma_term)

BaseDAGLikelihood

Source code in gfnx/reward/dag_likelihood.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class BaseDAGLikelihood:
    def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> TRewardParams:
        """Initialize the likelihood. Default implementation returns None.

        Args:
        - rng_key: chex.PRNGKey, random key
        - dummy_state: DAGEnvState, shape [1, ...], a dummy state

        """
        return None

    def log_prob(self, state: DAGEnvState, env_params: DAGEnvParams) -> TLogReward:
        """Computes the log-likelihood of the data given the state - graph G:

            log P(D | G) = sum_j LocalScore(X_j | Pa_G(X_j))

        Args:
        - state: DAGEnvState, shape [B, ...], batch of states
        - env_params: DAGEnvParams, params of environment,
          always includes reward params

        Returns:
        - TLogReward, shape [B], batch of log-likelihoods
        """
        num_graphs, num_variables = state.adjacency_matrix.shape[:2]
        adjacency_matrix = state.adjacency_matrix.transpose(0, 2, 1)
        parents = adjacency_matrix.reshape(-1, num_variables)
        variables = jnp.tile(jnp.arange(num_variables), num_graphs)
        likelihood_params = env_params.reward_params.likelihood_params

        local_scores = self._local_score(variables, parents, likelihood_params)
        local_scores = local_scores.reshape(num_graphs, num_variables)
        return jnp.sum(local_scores, axis=1)  # [B]

    def delta_score(
        self,
        state: DAGEnvState,
        action: TAction,
        next_state: DAGEnvState,
        env_params: DAGEnvParams,
    ) -> TLogReward:
        """Computes the delta-score for adding an edge X_i -> X_j to some grpah
        G, for a specific choice of local score. The delta-score is given by:

            LocalScore(X_j | Pa_G(X_j) U X_i) - LocalScore(X_j | Pa_G(X_j))

        Args:
        - state: DAGEnvState, shape [B, ...], batch of states
        - action: DAGEnvAction, shape [B], batch of actions
        - next_state: DAGEnvState, shape [B, ...], batch of next states
        - env_params: DAGEnvParams, params of environment,
            always includes reward params

        Returns:
        - TLogReward, shape [B], batch of delta-scores
        """
        arange = jnp.arange(state.is_pad.shape[0])  # [B]
        _source, target = jnp.divmod(action, env_params.num_variables)
        parents = state.adjacency_matrix[arange, :, target]  # [B, num_variables]
        next_parents = next_state.adjacency_matrix[arange, :, target]  # [B, num_variables]
        likelihood_params = env_params.reward_params.likelihood_params
        return self._local_score(target, next_parents, likelihood_params) - self._local_score(
            target, parents, likelihood_params
        )

    def _local_score(
        self,
        variables: chex.Array,
        parents: chex.Array,
        likelihood_params: TRewardParams,
    ) -> TLogReward:
        """Computes the local score LocalScore(X_j | Pa_G(X_j)).

        Args:
        - variables: chex.Array, shape [B], batch of variables
        - parents: chex.Array, shape [B, num_variables], batched mask of parents
        - likelihood_params: TRewardParams, params of likelihood

        Returns:
        - TLogReward, shape [B], batch of local scores
        """
        raise NotImplementedError

delta_score(state, action, next_state, env_params)

Computes the delta-score for adding an edge X_i -> X_j to some grpah G, for a specific choice of local score. The delta-score is given by:

LocalScore(X_j | Pa_G(X_j) U X_i) - LocalScore(X_j | Pa_G(X_j))

Args: - state: DAGEnvState, shape [B, ...], batch of states - action: DAGEnvAction, shape [B], batch of actions - next_state: DAGEnvState, shape [B, ...], batch of next states - env_params: DAGEnvParams, params of environment, always includes reward params

Returns: - TLogReward, shape [B], batch of delta-scores

Source code in gfnx/reward/dag_likelihood.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def delta_score(
    self,
    state: DAGEnvState,
    action: TAction,
    next_state: DAGEnvState,
    env_params: DAGEnvParams,
) -> TLogReward:
    """Computes the delta-score for adding an edge X_i -> X_j to some grpah
    G, for a specific choice of local score. The delta-score is given by:

        LocalScore(X_j | Pa_G(X_j) U X_i) - LocalScore(X_j | Pa_G(X_j))

    Args:
    - state: DAGEnvState, shape [B, ...], batch of states
    - action: DAGEnvAction, shape [B], batch of actions
    - next_state: DAGEnvState, shape [B, ...], batch of next states
    - env_params: DAGEnvParams, params of environment,
        always includes reward params

    Returns:
    - TLogReward, shape [B], batch of delta-scores
    """
    arange = jnp.arange(state.is_pad.shape[0])  # [B]
    _source, target = jnp.divmod(action, env_params.num_variables)
    parents = state.adjacency_matrix[arange, :, target]  # [B, num_variables]
    next_parents = next_state.adjacency_matrix[arange, :, target]  # [B, num_variables]
    likelihood_params = env_params.reward_params.likelihood_params
    return self._local_score(target, next_parents, likelihood_params) - self._local_score(
        target, parents, likelihood_params
    )

init(rng_key, dummy_state)

Initialize the likelihood. Default implementation returns None.

Args: - rng_key: chex.PRNGKey, random key - dummy_state: DAGEnvState, shape [1, ...], a dummy state

Source code in gfnx/reward/dag_likelihood.py
12
13
14
15
16
17
18
19
20
def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> TRewardParams:
    """Initialize the likelihood. Default implementation returns None.

    Args:
    - rng_key: chex.PRNGKey, random key
    - dummy_state: DAGEnvState, shape [1, ...], a dummy state

    """
    return None

log_prob(state, env_params)

Computes the log-likelihood of the data given the state - graph G:

log P(D | G) = sum_j LocalScore(X_j | Pa_G(X_j))

Args: - state: DAGEnvState, shape [B, ...], batch of states - env_params: DAGEnvParams, params of environment, always includes reward params

Returns: - TLogReward, shape [B], batch of log-likelihoods

Source code in gfnx/reward/dag_likelihood.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def log_prob(self, state: DAGEnvState, env_params: DAGEnvParams) -> TLogReward:
    """Computes the log-likelihood of the data given the state - graph G:

        log P(D | G) = sum_j LocalScore(X_j | Pa_G(X_j))

    Args:
    - state: DAGEnvState, shape [B, ...], batch of states
    - env_params: DAGEnvParams, params of environment,
      always includes reward params

    Returns:
    - TLogReward, shape [B], batch of log-likelihoods
    """
    num_graphs, num_variables = state.adjacency_matrix.shape[:2]
    adjacency_matrix = state.adjacency_matrix.transpose(0, 2, 1)
    parents = adjacency_matrix.reshape(-1, num_variables)
    variables = jnp.tile(jnp.arange(num_variables), num_graphs)
    likelihood_params = env_params.reward_params.likelihood_params

    local_scores = self._local_score(variables, parents, likelihood_params)
    local_scores = local_scores.reshape(num_graphs, num_variables)
    return jnp.sum(local_scores, axis=1)  # [B]

LinearGaussianScore

Bases: BaseDAGLikelihood

Source code in gfnx/reward/dag_likelihood.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class LinearGaussianScore(BaseDAGLikelihood):
    @chex.dataclass(frozen=True)
    class LinearGaussianScoreParams:
        data: chex.Array

    def __init__(
        self,
        data: chex.Array,
        prior_mean: float = 0.0,
        prior_scale: float = 1.0,
        obs_scale: float = math.sqrt(0.1),
    ):
        self.data = data
        self.prior_mean = prior_mean
        self.prior_scale = prior_scale
        self.obs_scale = obs_scale

    def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> LinearGaussianScoreParams:
        """
        Initialize the likelihood.
        Args:
        - rng_key: chex.PRNGKey, random key
        - dummy_state: DAGEnvState, shape [1, ...], a dummy state

         Returns:
        - Loaded samples as a LinearGaussianScoreParams
        """
        return self.LinearGaussianScoreParams(data=self.data)

    def _local_score(
        self,
        variables: chex.Array,
        parents: chex.Array,
        likelihood_params: LinearGaussianScoreParams,
    ) -> TLogReward:
        data = likelihood_params.data
        num_samples, num_variables = data.shape
        masked_data = data * parents[:, jnp.newaxis]

        means = self.prior_mean * jnp.sum(masked_data, axis=2)
        diffs = (data[:, variables].T - means) / self.obs_scale
        y_matrix = self.prior_scale * jnp.matmul(diffs[:, None], masked_data)
        y = jnp.squeeze(y_matrix, axis=1)
        sigma_matrix = (self.obs_scale**2) * jnp.eye(num_variables) + (
            self.prior_scale**2
        ) * jnp.matmul(masked_data.transpose(0, 2, 1), masked_data)
        term1 = jnp.sum(diffs**2, axis=1)
        term2 = -jnp.sum(y * jnp.linalg.solve(sigma_matrix, y[..., None])[..., 0], axis=1)
        _, term3 = jnp.linalg.slogdet(sigma_matrix)
        term4 = 2 * (num_samples - num_variables) * jnp.log(self.obs_scale)
        term5 = num_samples * jnp.log(2 * jnp.pi)

        return -0.5 * (term1 + term2 - term3 - term4 - term5)

init(rng_key, dummy_state)

Initialize the likelihood. Args: - rng_key: chex.PRNGKey, random key - dummy_state: DAGEnvState, shape [1, ...], a dummy state

Returns: - Loaded samples as a LinearGaussianScoreParams

Source code in gfnx/reward/dag_likelihood.py
128
129
130
131
132
133
134
135
136
137
138
def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> LinearGaussianScoreParams:
    """
    Initialize the likelihood.
    Args:
    - rng_key: chex.PRNGKey, random key
    - dummy_state: DAGEnvState, shape [1, ...], a dummy state

     Returns:
    - Loaded samples as a LinearGaussianScoreParams
    """
    return self.LinearGaussianScoreParams(data=self.data)

Final reward for enviornment

DAGRewardModule

Bases: BaseRewardModule[DAGEnvState, DAGEnvParams]

Reward module for directed acyclic graph (DAG) structures. The reward is defined as the product of a prior over DAGs and a likelihood of data given the DAG.

Source code in gfnx/reward/dag.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@chex.dataclass
class DAGRewardModule(BaseRewardModule[DAGEnvState, DAGEnvParams]):
    '''
    Reward module for directed acyclic graph (DAG) structures.
    The reward is defined as the product of a prior over DAGs and a likelihood
    of data given the DAG.
    '''
    prior: BaseDAGPrior
    likelihood: BaseDAGLikelihood

    def init(self, rng_key: chex.PRNGKey, dummy_state: DAGEnvState) -> DAGRewardParams:
        _, prior_key, likelihood_key = jax.random.split(rng_key, 3)
        return DAGRewardParams(
            prior_params=self.prior.init(prior_key, dummy_state),
            likelihood_params=self.likelihood.init(likelihood_key, dummy_state),
        )

    def reward(self, state: DAGEnvState, env_params: DAGEnvParams) -> TReward:
        return jnp.exp(self.log_reward(state, env_params))

    def log_reward(self, state: DAGEnvState, env_params: DAGEnvParams) -> TLogReward:
        return self.likelihood.log_prob(state, env_params) + self.prior.log_prob(state, env_params)

    def delta_score(
        self,
        state: DAGEnvState,
        action: TAction,
        next_state: DAGEnvState,
        env_params: DAGEnvParams,
    ) -> TLogReward:
        return self.prior.delta_score(
            state, action, next_state, env_params
        ) + self.likelihood.delta_score(state, action, next_state, env_params)