Skip to content

Reward API Reference

Bit Sequence reward

Reward functions used for the bit sequence environment.

BitseqRewardModule

Bases: BaseRewardModule[BitseqEnvState, BitseqEnvParams]

Source code in gfnx/reward/bitseq.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class BitseqRewardModule(BaseRewardModule[BitseqEnvState, BitseqEnvParams]):
    """Reward for bit sequences based on the distance to a set of modes.

    The log-reward of one sequence is ``-reward_exponent * d / n`` where ``d``
    is the smallest ``hamming_distance`` between the detokenized sequence and
    any entry of the mode set produced by :meth:`init`, and ``n`` is the
    length of the detokenized sequence.
    """

    def __init__(
        self,
        sentence_len: int = 120,
        k: int = 8,
        mode_set_size: int = 60,
        reward_exponent: float = 1.0,
    ):
        """
        General reward function for bitseqs

        Args:
            sentence_len: Sequence length forwarded to ``construct_mode_set``;
                must be a multiple of the block length (8).
            k: Second argument of ``detokenize`` when tokens are converted
                back to a bit sequence (presumably bits per token -- TODO
                confirm against ``detokenize``).
            mode_set_size: Forwarded to ``construct_mode_set`` (presumably the
                number of modes -- TODO confirm).
            reward_exponent: Factor multiplying the normalized distance in the
                log-reward.

        Raises:
            ValueError: If ``sentence_len`` is not a multiple of the block
                length.
        """
        block_len = 8
        if sentence_len % block_len != 0:
            # Raised explicitly (instead of ``assert``) so the check is still
            # enforced when Python runs with ``-O``.
            raise ValueError(
                f"sentence_len ({sentence_len}) must be a multiple of "
                f"the block length ({block_len})"
            )

        self.block_len = block_len
        # Candidate blocks handed to ``construct_mode_set``; one row per block.
        self.block_set = jnp.array(
            [
                [0, 0, 0, 0, 0, 0, 0, 0],
                [1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 1, 1, 1],
                [0, 0, 1, 1, 1, 1, 0, 0],
            ],
            dtype=jnp.bool,
        )
        self.sentence_len = sentence_len
        self.k = k
        self.mode_set_size = mode_set_size
        self.reward_exponent = reward_exponent

    def init(self, rng_key: chex.PRNGKey, dummy_state: BitseqEnvState) -> TRewardParams:
        """Build the reward parameters.

        Args:
            rng_key: Random key passed to ``construct_mode_set``.
            dummy_state: Unused here.

        Returns:
            A dict with the single key ``"mode_set"`` holding the result of
            ``construct_mode_set``.
        """
        return {
            "mode_set": construct_mode_set(
                self.sentence_len,
                self.block_len,
                self.block_set,
                self.mode_set_size,
                rng_key,
            )
        }

    def _mode_set_distance(self, s: chex.Array, mode_set: chex.Array):
        """Return the minimum ``hamming_distance`` from ``s`` to the mode set.

        ``mode_set`` is mapped over its leading axis, i.e. one mode per row.
        """
        distances = jax.vmap(lambda ms: hamming_distance(s, ms))(mode_set)
        return jnp.min(distances)

    def log_reward(self, state: BitseqEnvState, env_params: BitseqEnvParams) -> TLogReward:
        """Compute the log-reward for every entry of ``state.tokens``.

        ``state.tokens`` is mapped over its leading axis; the reward
        parameters in ``env_params.reward_params`` are shared by all entries.
        """

        def single_log_reward(tokens: chex.Array, reward_params: TRewardParams):
            bitseq = detokenize(tokens, self.k)
            mode_dist = self._mode_set_distance(bitseq, reward_params["mode_set"])
            # Distance is normalized by the detokenized sequence length.
            return -self.reward_exponent * mode_dist.astype(jnp.float32) / bitseq.shape[0]

        return jax.vmap(single_log_reward, in_axes=(0, None))(
            state.tokens, env_params.reward_params
        )

    def reward(self, state: BitseqEnvState, env_params: BitseqEnvParams) -> TReward:
        """Return ``exp(log_reward(state, env_params))``."""
        return jnp.exp(self.log_reward(state, env_params))

__init__(sentence_len=120, k=8, mode_set_size=60, reward_exponent=1.0)

General reward function for bitseqs

Source code in gfnx/reward/bitseq.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def __init__(
    self,
    sentence_len: int = 120,
    k: int = 8,
    mode_set_size: int = 60,
    reward_exponent: float = 1.0,
):
    """
    General reward function for bitseqs

    Args:
        sentence_len: Sequence length; must be a multiple of the block
            length (8).
        k: Stored as ``self.k`` (presumably bits per token -- TODO confirm
            against the code that reads it).
        mode_set_size: Stored as ``self.mode_set_size``.
        reward_exponent: Stored as ``self.reward_exponent``.

    Raises:
        ValueError: If ``sentence_len`` is not a multiple of the block length.
    """
    block_len = 8
    if sentence_len % block_len != 0:
        # Raised explicitly (instead of ``assert``) so the check is still
        # enforced when Python runs with ``-O``.
        raise ValueError(
            f"sentence_len ({sentence_len}) must be a multiple of "
            f"the block length ({block_len})"
        )

    self.block_len = block_len
    # Fixed table of 8-bit blocks, one block per row.
    self.block_set = jnp.array(
        [
            [0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 0, 0],
        ],
        dtype=jnp.bool,
    )
    self.sentence_len = sentence_len
    self.k = k
    self.mode_set_size = mode_set_size
    self.reward_exponent = reward_exponent

TFBind-8 reward

Reward functions used for TFBind-8 environment.

TFBind8RewardModule

Bases: BaseRewardModule[TFBind8EnvState, TFBind8EnvParams]

Source code in gfnx/reward/tfbind.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class TFBind8RewardModule(BaseRewardModule[TFBind8EnvState, TFBind8EnvParams]):
    """Tabulated reward for the TFBind-8 environment.

    :meth:`init` loads oracle scores for every sequence of length
    ``max_length`` over ``nchar`` characters and turns them into a reward
    table; :meth:`reward` looks sequences up in that table.
    """

    def __init__(
        self,
        nchar: int = 4,
        max_length: int = 8,
        min_reward: float = 1e-3,
        reward_exponent: float = 3.0,
        reward_scale: float = 10.0,
    ):
        """
        Store the reward hyper-parameters.

        Args:
            nchar: Number of distinct characters (token values ``0..nchar-1``).
            max_length: Length of the sequences that are scored.
            min_reward: Lower bound applied to the reward table, also used as
                the fill value for out-of-range lookups.
            reward_exponent: Power applied to the raw oracle scores.
            reward_scale: Value the largest reward is rescaled to (before
                clipping).
        """
        self.nchar = nchar
        self.max_length = max_length
        self.min_reward = min_reward
        self.reward_exponent = reward_exponent
        self.reward_scale = reward_scale

    def init(
        self, rng_key: chex.PRNGKey, dummy_state: TFBind8EnvState
    ) -> dict[str, chex.Array]:
        """Build the table of rewards for all possible sequences.

        Args:
            rng_key: Unused here.
            dummy_state: Unused here.

        Returns:
            ``{"rewards": table}`` where ``table[i]`` is the reward of the
            ``i``-th sequence in ``itertools.product`` order, which matches
            the index computed in :meth:`reward`.

        Raises:
            FileNotFoundError: If the oracle pickle is not found (the path is
                relative to the current working directory).
            KeyError: If the oracle lacks a score for some sequence.
        """
        # Source: https://github.com/maxwshen/gflownet/blob/main/datasets/tfbind8/tfbind8-exact-v0-all.pkl
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this with a trusted copy of the dataset file.
        with open("proxy/weights/tfbind/tfbind8-exact-v0-all.pkl", "rb") as f:
            oracle_d = pickle.load(f)
        oracle = {tuple(x): float(y[0]) for x, y in zip(oracle_d["x"], oracle_d["y"])}

        # Enumerate every sequence lazily, in the same order as the index
        # computed in ``reward`` (first position is the most significant).
        all_sequences = itertools.product(range(self.nchar), repeat=self.max_length)
        values_raw = jnp.array([oracle[seq] for seq in all_sequences])

        # Normalization as in https://github.com/maxwshen/gflownet/blob/main/exps/tfbind8/tfbind8_oracle.py
        rewards = jnp.pow(values_raw, self.reward_exponent)
        rewards = rewards * self.reward_scale / rewards.max()
        rewards = jnp.clip(rewards, min=self.min_reward)
        return {"rewards": rewards}  # Dict with all possible values

    def reward(self, state: TFBind8EnvState, env_params: TFBind8EnvParams) -> TReward:
        """Look up the reward of each sequence in ``state.tokens``.

        The last axis of ``state.tokens`` is read as a base-``nchar`` number
        (most significant digit first) to index the reward table. Indices
        outside the table yield ``min_reward``.
        """
        tokens = state.tokens
        # Positional weights nchar**(max_length-1), ..., nchar**0.
        powers_array = jnp.array([
            self.nchar ** (self.max_length - i - 1) for i in range(self.max_length)
        ])
        indices = jnp.sum(tokens * powers_array, axis=-1)
        return jnp.take_along_axis(
            env_params.reward_params["rewards"],
            indices,
            axis=0,
            mode="fill",
            fill_value=self.min_reward,
        )

    def log_reward(self, state: TFBind8EnvState, env_params: TFBind8EnvParams) -> TLogReward:
        """Return the natural logarithm of :meth:`reward`."""
        return jnp.log(self.reward(state, env_params))

__init__(nchar=4, max_length=8, min_reward=0.001, reward_exponent=3.0, reward_scale=10.0)

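Tabulated reward for TFBind-8: oracle scores are raised to the power reward_exponent, rescaled so that the maximum equals reward_scale, and clipped from below at min_reward.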

Source code in gfnx/reward/tfbind.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def __init__(
    self,
    nchar: int = 4,
    max_length: int = 8,
    min_reward: float = 1e-3,
    reward_exponent: float = 3.0,
    reward_scale: float = 10.0,
):
    """
    Store the reward hyper-parameters on the instance.

    Args:
        nchar: Stored as ``self.nchar``.
        max_length: Stored as ``self.max_length``.
        min_reward: Stored as ``self.min_reward``.
        reward_exponent: Stored as ``self.reward_exponent``.
        reward_scale: Stored as ``self.reward_scale``.
    """
    # Sequence geometry.
    self.nchar, self.max_length = nchar, max_length
    # Reward shaping.
    self.min_reward = min_reward
    self.reward_exponent, self.reward_scale = reward_exponent, reward_scale

QM9 Small reward

Reward functions used for QM9Small environment.

QM9SmallRewardModule

Bases: BaseRewardModule[QM9SmallEnvState, QM9SmallEnvParams]

Source code in gfnx/reward/qm9_small.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class QM9SmallRewardModule(
    BaseRewardModule[QM9SmallEnvState, QM9SmallEnvParams]
):
    """Tabulated reward for the QM9Small environment.

    :meth:`init` loads oracle scores from a pickle and turns them into a
    reward table; :meth:`reward` looks sequences up in that table.
    """

    def __init__(
        self,
        nchar: int = 11,
        max_length: int = 5,
        min_reward: float = 1e-3,
        reward_exponent: float = 5.0,
        reward_scale: float = 100.0,
    ):
        """
        Store the reward hyper-parameters.

        Args:
            nchar: Base used to turn a token sequence into a table index.
            max_length: Number of token positions used for the index.
            min_reward: Lower bound applied to the raw oracle scores, also
                used as the fill value for out-of-range lookups.
            reward_exponent: Power applied to the clipped scores.
            reward_scale: Value the largest reward is rescaled to.
        """
        self.nchar = nchar
        self.max_length = max_length
        self.min_reward = min_reward
        self.reward_exponent = reward_exponent
        self.reward_scale = reward_scale

    def init(
        self, rng_key: chex.PRNGKey, dummy_state: QM9SmallEnvState
    ) -> dict[str, chex.Array]:
        """Build the reward table from the oracle pickle.

        Args:
            rng_key: Unused here.
            dummy_state: Unused here.

        Returns:
            ``{"rewards": table}`` with one entry per oracle item, in the
            iteration order of the loaded dict.

        Raises:
            FileNotFoundError: If the oracle pickle is not found (the path is
                relative to the current working directory).
        """
        # Source: https://github.com/maxwshen/gflownet/blob/main/datasets/qm9str/block_qm9str_v1_s5.pkl
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this with a trusted copy of the dataset file.
        with open('proxy/weights/qm9_small/block_qm9str_v1_s5.pkl', 'rb') as f:
            oracle_d = pickle.load(f)
        oracle = {tuple(x): float(y) for x, y in oracle_d.items()}

        # NOTE(review): the table follows the dict's iteration order, while
        # ``reward`` indexes it with the base-``nchar`` encoding of the
        # tokens. This relies on the pickle storing its entries in exactly
        # that order -- confirm against the dataset.
        values_raw = jnp.array(list(oracle.values()))

        # Normalization as in https://github.com/maxwshen/gflownet/blob/main/exps/qm9str/qm9str.py
        values = jnp.clip(values_raw, min=self.min_reward)
        values = jnp.pow(values, self.reward_exponent)
        values = values * self.reward_scale / values.max()
        return {"rewards": values}

    def reward(
        self, state: QM9SmallEnvState, env_params: QM9SmallEnvParams
    ) -> TReward:
        """Look up the reward of each sequence in ``state.tokens``.

        The last axis of ``state.tokens`` is read as a base-``nchar`` number
        (most significant digit first) to index the reward table. Indices
        outside the table yield ``min_reward``.
        """
        tokens = state.tokens
        # Positional weights nchar**(max_length-1), ..., nchar**0.
        powers_array = jnp.array([
            self.nchar ** (self.max_length - i - 1)
            for i in range(self.max_length)
        ])
        indices = jnp.sum(tokens * powers_array, axis=-1)
        return jnp.take_along_axis(
            env_params.reward_params["rewards"],
            indices,
            axis=0,
            mode="fill",
            fill_value=self.min_reward,
        )

    def log_reward(
        self, state: QM9SmallEnvState, env_params: QM9SmallEnvParams
    ) -> TLogReward:
        """Return the natural logarithm of :meth:`reward`."""
        return jnp.log(self.reward(state, env_params))

__init__(nchar=11, max_length=5, min_reward=0.001, reward_exponent=5.0, reward_scale=100.0)

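Tabulated reward for QM9Small: oracle scores are clipped from below at min_reward, raised to the power reward_exponent, and rescaled so that the maximum equals reward_scale.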

Source code in gfnx/reward/qm9_small.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def __init__(
    self,
    nchar: int = 11,
    max_length: int = 5,
    min_reward: float = 1e-3,
    reward_exponent: float = 5.0,
    reward_scale: float = 100.0,
):
    """
    Store the reward hyper-parameters on the instance.

    Args:
        nchar: Stored as ``self.nchar``.
        max_length: Stored as ``self.max_length``.
        min_reward: Stored as ``self.min_reward``.
        reward_exponent: Stored as ``self.reward_exponent``.
        reward_scale: Stored as ``self.reward_scale``.
    """
    # Sequence geometry.
    self.nchar, self.max_length = nchar, max_length
    # Reward shaping.
    self.min_reward = min_reward
    self.reward_exponent, self.reward_scale = reward_exponent, reward_scale

AMP reward

EqxProxyAMPRewardModule

Bases: BaseRewardModule[AMPEnvState, AMPEnvParams]

Source code in gfnx/reward/amp.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class EqxProxyAMPRewardModule(BaseRewardModule[AMPEnvState, AMPEnvParams]):
    """AMP reward computed by a pretrained ``EqxTransformerRewardModel`` proxy.

    The reward is ``sigmoid(proxy(tokens)) ** reward_exponent``, clipped from
    below at ``min_reward``.
    """

    def __init__(
        self,
        proxy_config_path: str,
        pretrained_proxy_path: str,
        reward_exponent: float = 1.0,
        min_reward: float = 1e-6,
    ):
        """
        Proxy reward model for amp

        Args:
            proxy_config_path: Config file loaded with OmegaConf; its
                ``network`` section is kept as ``self.network_params``.
            pretrained_proxy_path: Checkpoint location. A relative path is
                resolved against the directory three levels above this file.
            reward_exponent: Power applied to the sigmoid of the proxy output.
            min_reward: Lower bound of the returned reward.
        """
        # Imported locally so omegaconf is only needed when this class is used.
        import omegaconf

        proxy_cfg = omegaconf.OmegaConf.load(proxy_config_path)
        self.network_params = omegaconf.OmegaConf.to_container(proxy_cfg.network)

        checkpoint_path = pretrained_proxy_path
        if not os.path.isabs(checkpoint_path):
            # Assume that the path is relative to the root of the project
            this_dir = os.path.dirname(__file__)
            project_root = os.path.abspath(os.path.join(this_dir, "..", "..", ".."))
            checkpoint_path = os.path.join(project_root, checkpoint_path)
        self.pretrained_proxy_path = checkpoint_path

        self.reward_exponent = reward_exponent
        self.min_reward = min_reward

    def init(self, rng_key: chex.PRNGKey, dummy_state: AMPEnvState) -> TRewardParams:
        """Restore the pretrained proxy and return its array leaves.

        A freshly constructed model serves only as a shape/dtype template for
        the checkpoint restore. The non-array part of the restored model is
        kept on ``self.model_static`` and ``model.offset`` on ``self.offset``.

        Args:
            rng_key: Key used to construct the template model.
            dummy_state: Unused here.

        Returns:
            ``{"model_params": <array leaves of the restored model>}``.
        """
        # Lazy imports to avoid importing equinox in the main module
        import equinox as eqx
        import orbax.checkpoint as ocp

        from ..networks.reward_models import EqxTransformerRewardModel

        checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

        # Config entries come last so they take precedence over "pad_id".
        encoder_cfg = {"pad_id": len(PROTEINS_FULL_ALPHABET) - 1}
        encoder_cfg.update(self.network_params["encoder_params"])
        template = EqxTransformerRewardModel(
            encoder_params=encoder_cfg,
            output_dim=1,
            key=rng_key,
        )

        abstract_template = jax.tree_util.tree_map(
            ocp.utils.to_shape_dtype_struct, template
        )
        restored = checkpointer.restore(self.pretrained_proxy_path, abstract_template)

        arrays, static = eqx.partition(restored, eqx.is_array)
        self.model_static = static
        # Stored for parity with other proxy rewards; not read by ``reward``.
        self.offset = restored.offset

        return {"model_params": arrays}

    def log_reward(self, state: AMPEnvState, env_params: AMPEnvParams) -> TLogReward:
        """Return the natural logarithm of :meth:`reward`."""
        rewards = self.reward(state, env_params)
        return jnp.log(rewards)

    def reward(self, state: AMPEnvState, env_params: AMPEnvParams) -> TReward:
        """Score each entry of ``state.tokens`` with the proxy model.

        Requires :meth:`init` to have been called, since it reads
        ``self.model_static``.

        Returns:
            Array of shape ``(state.tokens.shape[0],)``.
        """
        # Lazy imports to avoid importing equinox in the main module
        import equinox as eqx

        proxy = eqx.combine(env_params.reward_params["model_params"], self.model_static)

        def score_one(tokens):
            # Inference mode: dropout disabled, so no PRNG key is needed.
            return proxy(tokens, enable_dropout=False, key=None)

        raw_scores = jax.vmap(score_one)(state.tokens)
        shaped = jnp.pow(jax.nn.sigmoid(raw_scores), self.reward_exponent)
        rewards = jnp.clip(shaped, min=self.min_reward).squeeze(axis=-1)
        chex.assert_shape(rewards, (state.tokens.shape[0],))  # [B]
        return rewards

__init__(proxy_config_path, pretrained_proxy_path, reward_exponent=1.0, min_reward=1e-06)

Proxy reward model for amp

Source code in gfnx/reward/amp.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(
    self,
    proxy_config_path: str,
    pretrained_proxy_path: str,
    reward_exponent: float = 1.0,
    min_reward: float = 1e-6,
):
    """
    Proxy reward model for amp

    Args:
        proxy_config_path: Config file loaded with OmegaConf; its ``network``
            section is kept as ``self.network_params``.
        pretrained_proxy_path: Checkpoint location. A relative path is
            resolved against the directory three levels above this file.
        reward_exponent: Stored as ``self.reward_exponent``.
        min_reward: Stored as ``self.min_reward``.
    """
    # Imported locally so omegaconf is only needed when this class is used.
    import omegaconf

    proxy_cfg = omegaconf.OmegaConf.load(proxy_config_path)
    self.network_params = omegaconf.OmegaConf.to_container(proxy_cfg.network)

    checkpoint_path = pretrained_proxy_path
    if not os.path.isabs(checkpoint_path):
        # Assume that the path is relative to the root of the project
        this_dir = os.path.dirname(__file__)
        project_root = os.path.abspath(os.path.join(this_dir, "..", "..", ".."))
        checkpoint_path = os.path.join(project_root, checkpoint_path)
    self.pretrained_proxy_path = checkpoint_path

    self.reward_exponent = reward_exponent
    self.min_reward = min_reward

GFP reward

EqxProxyGFPRewardModule

Bases: BaseRewardModule[GFPEnvState, GFPEnvParams]

Source code in gfnx/reward/gfp.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class EqxProxyGFPRewardModule(BaseRewardModule[GFPEnvState, GFPEnvParams]):
    """GFP reward computed by a pretrained ``EqxTransformerRewardModel`` proxy.

    The reward is ``proxy(tokens) + offset``, clipped from below at
    ``min_reward``, where ``offset`` is read from the restored model.
    """

    def __init__(
        self,
        proxy_config_path: str,
        pretrained_proxy_path: str,
        reward_exponent: float = 1.0,
        min_reward: float = 1e-6,
    ):
        """
        Proxy reward model for GFP

        Args:
            proxy_config_path: Config file loaded with OmegaConf; its
                ``network`` section is kept as ``self.network_params``.
            pretrained_proxy_path: Checkpoint location. A relative path is
                resolved against the directory three levels above this file.
            reward_exponent: Stored as ``self.reward_exponent``; ``reward``
                does not apply it.
            min_reward: Lower bound of the returned reward.
        """
        # Imported locally so omegaconf is only needed when this class is used.
        import omegaconf

        proxy_cfg = omegaconf.OmegaConf.load(proxy_config_path)
        self.network_params = omegaconf.OmegaConf.to_container(proxy_cfg.network)

        checkpoint_path = pretrained_proxy_path
        if not os.path.isabs(checkpoint_path):
            # Assume that the path is relative to the root of the project
            this_dir = os.path.dirname(__file__)
            project_root = os.path.abspath(os.path.join(this_dir, "..", "..", ".."))
            checkpoint_path = os.path.join(project_root, checkpoint_path)
        self.pretrained_proxy_path = checkpoint_path

        self.reward_exponent = reward_exponent
        self.min_reward = min_reward

    def init(self, rng_key: chex.PRNGKey, dummy_state: GFPEnvState) -> TRewardParams:
        """Restore the pretrained proxy and return its array leaves.

        A freshly constructed model serves only as a shape/dtype template for
        the checkpoint restore. The non-array part of the restored model is
        kept on ``self.model_static`` and ``model.offset`` on ``self.offset``.

        Args:
            rng_key: Key used to construct the template model.
            dummy_state: Unused here.

        Returns:
            ``{"model_params": <array leaves of the restored model>}``.
        """
        # Lazy imports to avoid importing equinox in the main module
        import equinox as eqx
        import orbax.checkpoint as ocp

        from ..networks.reward_models import EqxTransformerRewardModel

        checkpointer = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

        # Config entries come last so they take precedence over "pad_id".
        encoder_cfg = {"pad_id": len(PROTEINS_FULL_ALPHABET) - 1}
        encoder_cfg.update(self.network_params["encoder_params"])
        template = EqxTransformerRewardModel(
            encoder_params=encoder_cfg,
            output_dim=1,
            key=rng_key,
        )

        abstract_template = jax.tree_util.tree_map(
            ocp.utils.to_shape_dtype_struct, template
        )
        restored = checkpointer.restore(self.pretrained_proxy_path, abstract_template)

        arrays, static = eqx.partition(restored, eqx.is_array)
        self.model_static = static
        # Added to the proxy output in ``reward``.
        self.offset = restored.offset

        return {"model_params": arrays}

    def log_reward(self, state: GFPEnvState, env_params: GFPEnvParams) -> TLogReward:
        """Return the natural logarithm of :meth:`reward`."""
        rewards = self.reward(state, env_params)
        return jnp.log(rewards)

    def reward(self, state: GFPEnvState, env_params: GFPEnvParams) -> TReward:
        """Score each entry of ``state.tokens`` with the proxy model.

        Requires :meth:`init` to have been called, since it reads
        ``self.model_static`` and ``self.offset``.

        Returns:
            Array of shape ``(state.tokens.shape[0],)``.
        """
        # Lazy imports to avoid importing equinox in the main module
        import equinox as eqx

        proxy = eqx.combine(env_params.reward_params["model_params"], self.model_static)

        def score_one(tokens):
            # Inference mode: dropout disabled, so no PRNG key is needed.
            return proxy(tokens, enable_dropout=False, key=None)

        raw_scores = jax.vmap(score_one)(state.tokens)
        shifted = raw_scores + self.offset
        rewards = jnp.clip(shifted, min=self.min_reward).squeeze(axis=-1)
        chex.assert_shape(rewards, (state.tokens.shape[0],))  # [B]
        return rewards

__init__(proxy_config_path, pretrained_proxy_path, reward_exponent=1.0, min_reward=1e-06)

Proxy reward model for GFP

Source code in gfnx/reward/gfp.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(
    self,
    proxy_config_path: str,
    pretrained_proxy_path: str,
    reward_exponent: float = 1.0,
    min_reward: float = 1e-6,
):
    """
    Proxy reward model for GFP

    Args:
        proxy_config_path: Config file loaded with OmegaConf; its ``network``
            section is kept as ``self.network_params``.
        pretrained_proxy_path: Checkpoint location. A relative path is
            resolved against the directory three levels above this file.
        reward_exponent: Stored as ``self.reward_exponent``.
        min_reward: Stored as ``self.min_reward``.
    """
    # Imported locally so omegaconf is only needed when this class is used.
    import omegaconf

    proxy_cfg = omegaconf.OmegaConf.load(proxy_config_path)
    self.network_params = omegaconf.OmegaConf.to_container(proxy_cfg.network)

    checkpoint_path = pretrained_proxy_path
    if not os.path.isabs(checkpoint_path):
        # Assume that the path is relative to the root of the project
        this_dir = os.path.dirname(__file__)
        project_root = os.path.abspath(os.path.join(this_dir, "..", "..", ".."))
        checkpoint_path = os.path.join(project_root, checkpoint_path)
    self.pretrained_proxy_path = checkpoint_path

    self.reward_exponent = reward_exponent
    self.min_reward = min_reward