Phylogenetic Trees Environment

This environment follows the PhyloGFN formulation of Zhou et al. (2024): the task of inferring a rooted binary phylogenetic tree is represented as a sequential decision process. Each terminal state corresponds to a candidate tree topology built by repeatedly merging species, which makes the environment a natural stress test for GFlowNet objectives on structured combinatorial spaces.

Intuition

  • State – a forest of partially built trees. The initial state \(s_0\) contains \(n\) singleton trees, one per species; intermediate states keep track of the current forest plus the merge history required to compute rewards.
  • Action – pick two roots from the forest and combine them under a common ancestor. After \(n-1\) merges you obtain a single rooted binary tree spanning all species and the episode terminates.
  • Observation – one-hot Fitch features derived from the binary-encoded DNA (or RNA) sequences that appear at each root. These features are suitable for transformer-style policies and match the setup in the original PhyloGFN work.
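
The state, action, and observation described above can be illustrated with a minimal pure-Python sketch. It is not the gfnx implementation: the helper names (`encode`, `fitch_merge`, `merge`), the bitmask layout, and the three toy sequences are all assumptions made for illustration. Each site is stored as a 4-bit set of admissible nucleotides, and merging two roots applies the standard Fitch rule (intersection if non-empty, otherwise union plus one mutation).

```python
# Illustrative sketch only; names and encoding are assumptions, not the gfnx API.
NUCLEOTIDE_BITS = {"A": 0b1000, "C": 0b0100, "G": 0b0010, "T": 0b0001}


def encode(seq):
    """Encode a DNA string as one 4-bit nucleotide set per site."""
    return [NUCLEOTIDE_BITS[ch] for ch in seq]


def fitch_merge(left, right):
    """Apply the Fitch rule site by site.

    Returns the parent's feature vector and the number of mutations
    this merge adds to the parsimony score.
    """
    parent, added = [], 0
    for a, b in zip(left, right):
        common = a & b
        if common:
            parent.append(common)   # the children agree on at least one base
        else:
            parent.append(a | b)    # disagreement: keep both options
            added += 1              # and count one mutation
    return parent, added


def merge(forest, x, y):
    """Action: join roots x and y of the forest under a new common ancestor."""
    fx, cx = forest.pop(x)
    fy, cy = forest.pop(y)
    parent, added = fitch_merge(fx, fy)
    forest[f"({x},{y})"] = (parent, cx + cy + added)
    return forest


# Initial state s_0: n = 3 singleton trees, each with zero accumulated mutations.
species = {"s1": "ACGT", "s2": "ACGA", "s3": "TCGA"}  # toy data
forest = {name: (encode(seq), 0) for name, seq in species.items()}

# n - 1 = 2 merges yield a single rooted binary tree.
merge(forest, "s1", "s2")
merge(forest, "(s1,s2)", "s3")

((tree_name, (root_features, parsimony)),) = forest.items()
print(tree_name, parsimony)
```

For these toy sequences the final forest holds one tree, `((s1,s2),s3)`, with a parsimony score of 2: one mutation at the last site in the first merge and one at the first site in the second.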

Reward structure

For a terminating tree \(T\) with parsimony score \(M(T)\) (the minimum number of mutations required to explain the observed species), the raw reward follows a Gibbs distribution:

$$ R(T) = \exp\left(- \frac{M(T)}{\alpha} \right), $$ where \(\alpha\) is a temperature hyperparameter. For training stability we recenter this expression using a dataset-specific constant \(C\):

$$ R(T) = \exp\left(\frac{C - M(T)}{\alpha}\right). $$ Recentering multiplies every reward by the same constant \(\exp(C/\alpha)\), so it keeps rewards in a numerically friendly range while preserving both the ranking and the relative reward ratios over tree topologies. Following Deleu et al. (2024) we set \(\alpha = 4\) and choose \(C\) per dataset:

  • DS1: \(C = 5800\)
  • DS2: \(C = 8000\)
  • DS3: \(C = 8800\)
  • DS4: \(C = 3500\)
  • DS5: \(C = 2300\)
  • DS6: \(C = 2300\)
  • DS7: \(C = 12500\)
  • DS8: \(C = 2800\)
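
To see why the recentering matters numerically, the following sketch evaluates both reward expressions for DS1 (\(C = 5800\), \(\alpha = 4\)). The function names and the three parsimony scores are hypothetical values chosen for illustration; this is not the code of PhyloTreeRewardModule.

```python
import math

ALPHA = 4.0      # temperature used in this document
C_DS1 = 5800.0   # recentering constant for DS1


def raw_log_reward(parsimony, alpha=ALPHA):
    """log of the uncentred Gibbs reward: -M(T) / alpha."""
    return -parsimony / alpha


def log_reward(parsimony, c=C_DS1, alpha=ALPHA):
    """log of the recentred reward: (C - M(T)) / alpha."""
    return (c - parsimony) / alpha


# Hypothetical parsimony scores, chosen near C for illustration.
scores = [5796.0, 5800.0, 5820.0]

raw = [math.exp(raw_log_reward(m)) for m in scores]
recentred = [math.exp(log_reward(m)) for m in scores]

# The two log-rewards differ by the constant C / alpha for every tree.
offsets = [log_reward(m) - raw_log_reward(m) for m in scores]

print(raw)        # every entry underflows to 0.0 in double precision
print(recentred)  # finite, distinct values that decrease as M(T) grows
print(offsets)    # C / alpha = 1450.0 for each score
```

The uncentred rewards are all indistinguishable from zero in floating point, whereas the recentred ones remain finite and ordered by parsimony score. Working with the log-reward directly avoids the exponential altogether.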

You can supply your own scaling by constructing PhyloTreeRewardModule with custom C and scale values.

Loading datasets and building the environment

Datasets DS1–DS8 ship with the library in JSON form. Use gfnx.utils.get_phylo_initialization_args to retrieve the encoded sequences and reward parameters:

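from pathlib import Path

import jax
import gfnx
from gfnx.utils import get_phylo_initialization_args

# Directory containing the bundled DS1-DS8 JSON files (adjust to your setup).
data_dir = Path("path/to/phylo_datasets")
# Retrieve the encoded sequences and the reward parameters for DS1.
env_kwargs, reward_kwargs = get_phylo_initialization_args("DS1", data_dir)

# Build the reward module and the environment from the dataset arguments.
reward = gfnx.PhyloTreeRewardModule(**reward_kwargs)
env = gfnx.PhyloTreeEnvironment(reward_module=reward, **env_kwargs)
params = env.init(jax.random.PRNGKey(0))

# Start a single environment in the initial forest of singleton trees.
obs, state = env.reset(num_envs=1, env_params=params)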

Like other GFNX environments, PhyloTreeEnvironment is fully vectorised: set num_envs > 1 to roll out multiple forests in parallel. When a trajectory terminates, the returned log_reward is the logarithm of the recentred reward above, that is, \((C - M(T))/\alpha\).

References

  • Zhou, M. et al. (2024). PhyloGFN: Phylogenetic inference with generative flow networks. The Twelfth International Conference on Learning Representations (ICLR).
  • Deleu, T. et al. (2024). Discrete Probabilistic Inference as Control in Multi-path Environments. Proceedings of the 40th Conference on Uncertainty in Artificial Intelligence (UAI).