Bases: BaseRewardModule[PhyloTreeEnvState, PhyloTreeEnvParams]
Reward module for phylogenetic trees using an exponential reward function.
R(x) = exp((C - total_mutations) / scale)
Source code in gfnx/reward/phylogenetic_tree.py
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class PhyloTreeRewardModule(BaseRewardModule[PhyloTreeEnvState, PhyloTreeEnvParams]):
    """Exponential reward for phylogenetic trees.

    The reward of a state is ``R(x) = exp((C - total_mutations) / scale)``,
    where ``total_mutations`` is accumulated by ``_get_mutations`` over the
    constructed, non-leaf nodes of the tree.
    """

    def __init__(self, num_nodes: int, scale: float = 1.0, C: float = 0.0):
        """Store the reward hyper-parameters.

        Args:
            num_nodes: Sizes the node scan (``2 * num_nodes - 1`` slots are
                visited) and divides ``C`` in ``delta_score``.
            scale: Divisor applied to the shifted mutation count.
            C: Constant from which the mutation count is subtracted.
        """
        self.num_nodes = num_nodes
        self.scale = scale
        self.C = C
        # TODO: check delta score in original paper
        self._offset = C / scale / num_nodes

    def init(self, rng_key: chex.PRNGKey, dummy_state: PhyloTreeEnvState) -> TRewardParams:
        """Return the reward parameters; this reward has none, so the dict is empty."""
        return {}

    def _get_mutations(
        self,
        state: PhyloTreeEnvState,
    ) -> chex.Array:
        """Sum the mutation counts of all constructed, non-leaf nodes of one state."""

        def add_node_cost(running_total, idx):
            leaf_ref = state.to_leaf[idx]
            # A slot contributes only when it is constructed (to_leaf != -1)
            # and is not a leaf (to_leaf does not point back at the slot).
            is_internal = jnp.logical_and(leaf_ref != -1, leaf_ref != idx)
            left_seq = state.sequences[state.left_child[idx]]
            right_seq = state.sequences[state.right_child[idx]]
            # Count positions where the bitwise AND of both children is zero.
            empty_overlap = (left_seq & right_seq) == 0
            node_cost = jnp.where(is_internal, jnp.sum(empty_overlap), 0.0)
            return running_total + node_cost, None

        node_ids = jnp.arange(2 * self.num_nodes - 1)
        total, _ = jax.lax.scan(add_node_cost, 0.0, node_ids)
        return total

    def delta_score(self, state: PhyloTreeEnvState) -> TReward:
        """Compute the delta score for every state along the leading axis.

        For one state the score is ``(C / num_nodes - mutations) / scale``,
        where ``mutations`` counts the positions at which the bitwise AND of
        the two child sequences of node ``length - 1`` is zero.
        """

        def score_one(single):
            node = single.length - 1
            left_seq = single.sequences[single.left_child[node]]
            right_seq = single.sequences[single.right_child[node]]
            mutations = jnp.sum((left_seq & right_seq) == 0)
            per_node_offset = self.C / self.num_nodes
            return (per_node_offset - mutations) / self.scale

        return jax.vmap(score_one)(state)

    def log_reward(
        self,
        state: PhyloTreeEnvState,
        env_params: PhyloTreeEnvParams,
    ) -> TLogReward:
        """Compute the log reward ``(C - total_mutations) / scale`` per state.

        ``env_params`` is accepted but not read by this implementation.
        """

        def per_state(single):
            shifted = self.C - self._get_mutations(single)
            return shifted / self.scale

        return jax.vmap(per_state)(state)

    def reward(
        self,
        state: PhyloTreeEnvState,
        env_params: PhyloTreeEnvParams,
    ) -> TReward:
        """Compute the reward ``exp((C - total_mutations) / scale)`` per state."""
        log_r = self.log_reward(state, env_params)
        return jnp.exp(log_r)
|
delta_score(state)
Compute the delta score for the node at index length - 1 of each state: (C / num_nodes - mutations) / scale
Source code in gfnx/reward/phylogenetic_tree.py
51
52
53
54
55
56
57
58
59
60
61
def delta_score(self, state: PhyloTreeEnvState) -> TReward:
    """Compute the delta score for every state along the leading axis.

    For one state the score is ``(C / num_nodes - mutations) / scale``,
    where ``mutations`` counts the positions at which the bitwise AND of
    the two child sequences of node ``length - 1`` is zero.
    """

    def score_one(single):
        node = single.length - 1
        left_seq = single.sequences[single.left_child[node]]
        right_seq = single.sequences[single.right_child[node]]
        mutations = jnp.sum((left_seq & right_seq) == 0)
        per_node_offset = self.C / self.num_nodes
        return (per_node_offset - mutations) / self.scale

    return jax.vmap(score_one)(state)
|
init(rng_key, dummy_state)
Initialize reward parameters
Source code in gfnx/reward/phylogenetic_tree.py
def init(self, rng_key: chex.PRNGKey, dummy_state: PhyloTreeEnvState) -> TRewardParams:
    """Return the reward parameters; this reward has none, so the dict is empty."""
    return {}
|
log_reward(state, env_params)
Compute log reward: (C - total_mutations) / scale
Source code in gfnx/reward/phylogenetic_tree.py
64
65
66
67
68
69
70
71
72
73
74
def log_reward(
    self,
    state: PhyloTreeEnvState,
    env_params: PhyloTreeEnvParams,
) -> TLogReward:
    """Compute the log reward ``(C - total_mutations) / scale`` per state.

    ``env_params`` is accepted but not read by this implementation.
    """

    def per_state(single):
        shifted = self.C - self._get_mutations(single)
        return shifted / self.scale

    return jax.vmap(per_state)(state)
|
reward(state, env_params)
Compute reward: exp((C - total_mutations) / scale)
Source code in gfnx/reward/phylogenetic_tree.py
def reward(
    self,
    state: PhyloTreeEnvState,
    env_params: PhyloTreeEnvParams,
) -> TReward:
    """Compute the reward ``exp((C - total_mutations) / scale)`` per state."""
    log_r = self.log_reward(state, env_params)
    return jnp.exp(log_r)
|