Branch data Line data Source code
1 : : # Copyright 2021 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 1 : """Test EnvPool by well-trained RL agents."""
15 : :
16 : 1 : import os
17 : 1 : import sys
18 : 1 : from typing import Any, cast
19 : :
20 : 1 : import numpy as np
21 : 1 : import torch
22 : 1 : from absl import logging
23 : 1 : from absl.testing import absltest
24 : 1 : from tianshou.data import Batch
25 : 1 : from tianshou.policy import QRDQNPolicy
26 : :
27 : 1 : import envpool.atari.registration # noqa: F401
28 : 1 : from envpool.atari.atari_network import QRDQN
29 : 1 : from envpool.registration import make_gym
30 : :
31 : : # try:
32 : : # import cv2
33 : : # except ImportError:
34 : : # cv2 = None
35 : :
36 : :
class _AtariPretrainTest(absltest.TestCase):
  """Evaluate pre-trained QRDQN agents on Atari and check golden rewards."""

  def eval_qrdqn(
    self,
    task: str,
    resume_path: str,
    num_envs: int = 10,
    seed: int = 0,
    target_reward: float = 0.0,
    delta: float | None = None,
  ) -> None:
    """Run a saved QRDQN policy on ``task`` and assert its mean reward.

    Args:
      task: lower-case Atari task name, e.g. ``"pong"``; capitalized and
        suffixed with ``-v5`` to form the envpool task id.
      resume_path: path to the torch ``state_dict`` checkpoint to load.
      num_envs: number of parallel environments evaluated together.
      seed: seed applied to numpy, torch and the environments.
      target_reward: expected mean episode reward across ``num_envs``.
      delta: if given, assert ``|mean - target| <= delta`` instead of
        near-exact equality (used where results vary per platform).
    """
    env = make_gym(task.capitalize() + "-v5", num_envs=num_envs, seed=seed)
    state_shape = env.observation_space.shape
    action_shape = env.action_space.n
    device = "cuda" if torch.cuda.is_available() else "cpu"
    np.random.seed(seed)
    torch.manual_seed(seed)
    logging.info(state_shape)
    # 200 quantiles, matching the n_quantiles used below for QRDQNPolicy.
    net = QRDQN(*state_shape, action_shape, 200, device)  # type: ignore
    optim = torch.optim.Adam(net.parameters(), lr=1e-4)
    policy = QRDQNPolicy(
      net, optim, 0.99, 200, 3, target_update_freq=500
    ).to(device)
    policy.load_state_dict(torch.load(resume_path, map_location=device))
    policy.eval()
    ids = cast(Any, np.arange(num_envs))
    reward = np.zeros(num_envs)
    obs, _ = env.reset()
    for _ in range(25000):
      # Tiny epsilon-greedy exploration so a deterministic policy cannot
      # get stuck in a loop that never terminates an episode.
      if np.random.rand() < 5e-3:
        act = np.random.randint(action_shape, size=len(ids))
      else:
        act = policy(Batch(obs=obs, info={})).act
      obs, rew, terminated, truncated, info = env.step(act, ids)
      done = np.logical_or(terminated, truncated)
      # env.step returns results keyed by env_id; accumulate per-env reward.
      ids = cast(Any, np.asarray(info["env_id"]))
      reward[ids] += rew
      # Drop finished environments; stop once every episode has ended.
      obs = obs[~done]
      ids = ids[~done]
      if len(ids) == 0:
        break

    mean_reward = reward.mean()
    logging.info("Mean reward of %s: %f", task, mean_reward)
    if delta is None:
      self.assertAlmostEqual(mean_reward, target_reward)
    else:
      self.assertAlmostEqual(mean_reward, target_reward, delta=delta)

  def test_pong(self) -> None:
    """Pong golden-score check (exact on Linux, toleranced on macOS)."""
    model_path = os.path.join("envpool", "atari", "policy-pong.pth")
    self.assertTrue(os.path.exists(model_path))
    if sys.platform == "darwin":
      # Apple Silicon torch inference lands a tick below the historical
      # Linux golden, but remains stable across reruns.
      self.eval_qrdqn("pong", model_path, target_reward=20.5, delta=0.5)
    else:
      self.eval_qrdqn("pong", model_path, target_reward=20.6)

  def test_breakout(self) -> None:
    """Breakout golden-score check (exact on Linux, toleranced on macOS)."""
    model_path = os.path.join("envpool", "atari", "policy-breakout.pth")
    self.assertTrue(os.path.exists(model_path))
    if sys.platform == "darwin":
      self.eval_qrdqn(
        "breakout",
        model_path,
        target_reward=365.1,
        delta=5.0,
      )
    else:
      self.eval_qrdqn("breakout", model_path, target_reward=367.8)
112 : :
113 : :
# Run all absltest test cases in this module when executed as a script.
if __name__ == "__main__":
  absltest.main()
|