forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dnc.py
83 lines (67 loc) · 2.37 KB
/
test_dnc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gymnasium as gym
import unittest
import ray
from ray import air
from ray import tune
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.neural_computer import DNCMemory
from ray.rllib.utils.framework import try_import_torch
torch, _ = try_import_torch()
class TestDNC(unittest.TestCase):
    """Tests for the example ``DNCMemory`` custom model.

    ``test_pack_unpack`` checks that a packed memory state survives an
    ``unpack_state`` / ``pack_state`` round trip unchanged.
    ``test_dnc_learning`` registers the model as the ``"dnc"`` custom model
    and trains A2C with it on ``StatelessCartPole`` through Tune.
    """

    # Tune stopping criteria used by ``test_dnc_learning``.
    stop = {
        "episode_reward_mean": 100.0,
        "timesteps_total": 10000000,
    }

    @classmethod
    def setUpClass(cls) -> None:
        # One Ray session shared by every test in this class.
        ray.init(num_cpus=4, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_pack_unpack(self):
        """Check that ``pack_state(*unpack_state(s))`` reproduces ``s`` exactly."""
        # NOTE(review): positional args presumably follow the ModelV2
        # constructor (obs_space, action_space, num_outputs, model_config,
        # name) -- confirm against DNCMemory.
        d = DNCMemory(gym.spaces.Discrete(1), gym.spaces.Discrete(1), 1, {}, "")
        # Add batch dim
        packed_state = [m.unsqueeze(0) for m in d.get_initial_state()]
        # Overwrite the initial state in place with random values so the
        # round trip is checked against non-trivial contents.
        for m in packed_state:
            m.random_()
        original_packed = [m.clone() for m in packed_state]

        unpacked = d.unpack_state(packed_state)
        packed = d.pack_state(*unpacked)

        self.assertGreater(len(packed), 0)
        self.assertEqual(len(packed), len(original_packed))

        for m_idx, (after, before) in enumerate(zip(packed, original_packed)):
            # torch.equal also requires identical shapes; an elementwise
            # ``==`` broadcasts and could hide a shape change.
            self.assertTrue(
                torch.equal(after, before),
                msg=f"State tensor {m_idx} changed after unpack/pack round trip.",
            )

    def test_dnc_learning(self):
        """Train A2C with the ``"dnc"`` custom model via Tune.

        Fails if any Tune trial ends in an error.
        """
        ModelCatalog.register_custom_model("dnc", DNCMemory)

        config = (
            A2CConfig()
            .environment(StatelessCartPole)
            .framework("torch")
            .rollouts(num_envs_per_worker=5, num_rollout_workers=1)
            .training(
                gamma=0.99,
                lr=0.01,
                entropy_coeff=0.0005,
                vf_loss_coeff=1e-5,
                model={
                    "custom_model": "dnc",
                    "max_seq_len": 64,
                    "custom_model_config": {
                        "nr_cells": 10,
                        "cell_size": 8,
                    },
                },
            )
            .resources(num_cpus_per_worker=2.0)
        )

        results = tune.Tuner(
            "A2C",
            param_space=config,
            run_config=air.RunConfig(stop=self.stop, verbose=1),
        ).fit()

        # Tuner.fit() records trial exceptions in the returned ResultGrid
        # instead of raising them. Without this check, a crashing trial
        # would still let the test pass.
        self.assertEqual(
            results.num_errors,
            0,
            msg=f"Tune trial(s) failed: {results.errors}",
        )
        # NOTE(review): this does not assert that the reward threshold in
        # ``self.stop`` was reached; training may also end at the
        # ``timesteps_total`` limit.
if __name__ == "__main__":
    # Allow running this module directly: delegate to pytest in verbose
    # mode and propagate its exit status to the shell.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)