Skip to content

Bayesian Networks Structure Learning

This environment models Bayesian structure learning with Generative Flow Networks, following Deleu et al. (2022). We search over directed acyclic graphs (DAGs) that connect variables \(\{X_1, \dots, X_d\}\) and learn a sampler whose terminal distribution matches the posterior \(P(G \mid \mathcal{D})\) for a dataset \(\mathcal{D}\). The environment exposes the same incremental edge-building process used in classic score-based structure learning, making it a strong benchmark for objectives that must respect structural constraints and exploit factorised rewards.

Intuition

  • State – the current DAG under construction together with the transitive closure of its transpose. The initial state is the empty graph. Every intermediate state already corresponds to a valid (possibly sparse) DAG.
  • Action – choose an ordered pair of nodes \((i, j)\) and add the edge \(X_i \rightarrow X_j\), unless doing so would create a cycle. A dedicated stop action terminates the trajectory and yields the current graph.
  • Observation – the adjacency matrix of the partially built DAG. Policies typically consume this matrix directly or embed it with a graph neural network.

Because any state can be treated as terminal, the environment naturally supports Modified Detailed Balance (MDB) training as well as more standard objectives.

Reward structure and modularity

The terminal reward is proportional to the posterior

$$ \log R(G) = \log P(\mathcal{D} \mid G) + \log P(G), $$ where we assume i.i.d., fully observed samples and a uniform prior over structures. GFNX ships two choices of marginal likelihoods \(P(\mathcal{D} \mid G)\):

  • LinearGaussianScore implements the linear-Gaussian model of Deleu et al. (2022), using data stored on disk.
  • BGeScore implements the Bayesian Gaussian equivalent score of Geiger and Heckerman (1994).

Both scores, as well as the uniform prior, satisfy the modularity property:

$$ \log R(G) = \sum_{j=1}^{d} \mathrm{LocalScore}\left(X_j \mid \mathrm{Pa}_{G}(X_j)\right) $$ When an edge \(X_i \rightarrow X_j\) is added, only the term for \(X_j\) changes, so the delta score

$$ \Delta_j = \mathrm{LocalScore}\left(X_j \mid \mathrm{Pa}_G(X_j) \cup {X_i}\right) - \mathrm{LocalScore}\left(X_j \mid \mathrm{Pa}_G(X_j)\right) $$ is cheap to evaluate. MDB relies on this delta to update flows efficiently, so the environment exposes it via the reward module.

Dataset generation

For experiments we follow the linear-Gaussian setting: ground-truth DAGs \(G^*\) are sampled from an Erdős–Rényi model with expected in-degree \(1\). Given \(G^*\), synthetic data are produced by ancestral sampling with weights \(w_{ij} \sim \mathcal{N}(0, 1)\) and fixed noise variance \(\sigma_j^2 = 0.1\). Each dataset contains \(100\) samples and we generate \(20\) such datasets (and associated true graphs) for evaluation. Utilities in gfnx.utils.dag make it easy to replicate this pipeline or load pre-generated CSV files.

Efficient action masks

To enforce acyclicity online, the environment maintains two matrices per state: the current adjacency matrix and the transitive closure of the transpose DAG. When an edge \((u \rightarrow v)\) is added, the closure is updated with a single outer-product operation that marks every node that can now reach \(u\) or be reached from \(v\). The resulting mask excludes (i) existing edges and (ii) edges that would complete a cycle, providing \(O(d^2)\) updates without resorting to repeated graph traversals.

Quickstart

from pathlib import Path

import jax
import gfnx
from gfnx.reward.dag import DAGRewardModule
from gfnx.reward.dag_likelihood import LinearGaussianScore
from gfnx.reward.dag_prior import UniformDAGPrior

num_variables = 5
data_path = Path("datasets/dag/train_data.csv")

likelihood = LinearGaussianScore(data_path=str(data_path))
prior = UniformDAGPrior(num_variables=num_variables)
reward = DAGRewardModule(prior=prior, likelihood=likelihood)
env = gfnx.DAGEnvironment(reward_module=reward, num_variables=num_variables)
params = env.init(jax.random.PRNGKey(0))

obs, state = env.reset(num_envs=1, env_params=params)

Set num_envs > 1 to explore multiple graphs in parallel. For num_variables < 6 the environment can enumerate every DAG.

API references:

References

  • Deleu, T. et al. (2022). Bayesian Structure Learning with Generative Flow Networks. Proceedings of the 38th Conference on Uncertainty in Artificial Intelligence (UAI).