Skip to content

Reward API Reference

PhyloTreeRewardModule

Bases: BaseRewardModule[PhyloTreeEnvState, PhyloTreeEnvParams]

Reward module for phylogenetic trees using an exponential reward function: R(x) = exp((C - total_mutations) / scale), where total_mutations is the number of mutations in tree x.

Source code in gfnx/reward/phylogenetic_tree.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class PhyloTreeRewardModule(BaseRewardModule[PhyloTreeEnvState, PhyloTreeEnvParams]):
    """
    Reward module for phylogenetic trees using exponential reward function.
    R(x) = exp((C - total_mutations) / scale)

    A mutation is counted at every sequence position where the bitwise AND of
    the two children's sequence encodings is zero (presumably: the children's
    candidate-state sets are disjoint at that position — confirm against the
    environment's encoding).

    Args:
        num_nodes: Presumably the number of leaves; the tree is scanned over
            ``2 * num_nodes - 1`` node slots.
        scale: Divisor (temperature) applied to the log reward.
        C: Additive constant in the log reward.
    """

    def __init__(self, num_nodes: int, scale: float = 1.0, C: float = 0.0):
        self.num_nodes = num_nodes
        self.scale = scale
        self.C = C

        # TODO: check delta score in original paper
        # NOTE(review): `_offset` equals the constant term of `delta_score`
        # ((C / num_nodes) / scale) but is not read by any method in this class.
        self._offset = (C / scale) / num_nodes

    def init(self, rng_key: chex.PRNGKey, dummy_state: PhyloTreeEnvState) -> TRewardParams:
        """Initialize reward parameters (this reward has none)."""
        return {}  # No parameters for this reward

    @staticmethod
    def _merge_mutations(state: PhyloTreeEnvState, node_idx: chex.Array) -> chex.Array:
        """Count positions where the two children of ``node_idx`` have a zero bitwise AND."""
        left = state.sequences[state.left_child[node_idx]]
        right = state.sequences[state.right_child[node_idx]]
        # Parenthesised for readability; Python already binds `&` tighter than `==`.
        return jnp.sum((left & right) == 0)

    def _get_mutations(
        self,
        state: PhyloTreeEnvState,
    ) -> chex.Array:
        """Compute total mutations in the tree for a single (unbatched) state.

        Sums ``_merge_mutations`` over every node slot that is constructed and
        is not a leaf; all other slots contribute 0.
        """

        def compute_mutations(carry, node_idx):
            is_tree = jnp.logical_and(
                state.to_leaf[node_idx] != -1,  # the node is constructed
                state.to_leaf[node_idx] != node_idx,  # the node is not a leaf
            )
            # The count is always computed, then masked out for non-internal slots.
            mutations = jnp.where(is_tree, self._merge_mutations(state, node_idx), 0.0)
            return carry + mutations, None

        return jax.lax.scan(compute_mutations, 0.0, jnp.arange(2 * self.num_nodes - 1))[0]

    def delta_score(self, state: PhyloTreeEnvState) -> TReward:
        """Compute delta score for a batch of states.

        For each state, returns ``(C / num_nodes - m) / scale`` where ``m`` is
        the mutation count of the node at index ``state.length - 1``.
        """
        # NOTE(review): if a trajectory makes num_nodes - 1 merges, these deltas
        # sum to (C * (num_nodes - 1) / num_nodes - M) / scale rather than
        # log_reward's (C - M) / scale — related to the TODO in __init__.

        def _single_delta_score(state):
            mutations = self._merge_mutations(state, state.length - 1)
            return (self.C / self.num_nodes - mutations) / self.scale

        return jax.vmap(_single_delta_score)(state)

    def log_reward(
        self,
        state: PhyloTreeEnvState,
        env_params: PhyloTreeEnvParams,
    ) -> TLogReward:
        """Compute log reward: (C - total_mutations) / scale, batched over states.

        ``env_params`` is not used.
        """

        def _single_log_reward(state):
            total_mutations = self._get_mutations(state)
            return (self.C - total_mutations) / self.scale

        return jax.vmap(_single_log_reward)(state)

    def reward(
        self,
        state: PhyloTreeEnvState,
        env_params: PhyloTreeEnvParams,
    ) -> TReward:
        """Compute reward: exp((C - total_mutations) / scale), batched over states."""
        return jnp.exp(self.log_reward(state, env_params))

delta_score(state)

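Compute the delta score for each state in the batch: (C / num_nodes - m) / scale, where m is the number of mutations at the node with index state.length - 1.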

Source code in gfnx/reward/phylogenetic_tree.py
51
52
53
54
55
56
57
58
59
60
61
62
def delta_score(self, state: PhyloTreeEnvState) -> TReward:
    """Return the batched per-step delta score.

    For every state in the batch, takes the node at index ``state.length - 1``,
    counts the positions where the bitwise AND of its two children's sequences
    is zero, and returns ``(C / num_nodes - count) / scale``.
    """

    def _per_state(single):
        last_node = single.length - 1
        left_seq = single.sequences[single.left_child[last_node]]
        right_seq = single.sequences[single.right_child[last_node]]
        count = jnp.sum((left_seq & right_seq) == 0)
        return (self.C / self.num_nodes - count) / self.scale

    batched = jax.vmap(_per_state)
    return batched(state)

init(rng_key, dummy_state)

Initialize reward parameters

Source code in gfnx/reward/phylogenetic_tree.py
23
24
25
def init(self, rng_key: chex.PRNGKey, dummy_state: PhyloTreeEnvState) -> TRewardParams:
    """Return an empty parameter container: this reward has no parameters."""
    # Neither the RNG key nor the dummy state is consulted.
    return dict()

log_reward(state, env_params)

Compute log reward: (C - total_mutations) / scale

Source code in gfnx/reward/phylogenetic_tree.py
64
65
66
67
68
69
70
71
72
73
74
75
def log_reward(
    self,
    state: PhyloTreeEnvState,
    env_params: PhyloTreeEnvParams,
) -> TLogReward:
    """Return the batched log reward ``(C - total_mutations) / scale``.

    ``env_params`` is not used.
    """

    def _per_state(single):
        mutation_total = self._get_mutations(single)
        numerator = self.C - mutation_total
        return numerator / self.scale

    return jax.vmap(_per_state)(state)

reward(state, env_params)

Compute reward: exp((C - total_mutations) / scale)

Source code in gfnx/reward/phylogenetic_tree.py
77
78
79
80
81
82
83
def reward(
    self,
    state: PhyloTreeEnvState,
    env_params: PhyloTreeEnvParams,
) -> TReward:
    """Return the batched reward ``exp((C - total_mutations) / scale)``.

    Delegates to ``log_reward`` and exponentiates the result elementwise.
    """
    log_values = self.log_reward(state, env_params)
    return jnp.exp(log_values)