Branch data Line data Source code
1 : : # Copyright 2021 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 1 : """Unit test for atari envpool and speed benchmark."""
15 : :
16 : 1 : from typing import no_type_check
17 : :
18 : 1 : import dm_env
19 : 1 : import gymnasium as gym
20 : 1 : import numpy as np
21 : 1 : from absl import logging
22 : 1 : from absl.testing import absltest
23 : :
24 : 1 : import envpool.atari.registration # noqa: F401
25 : 1 : from envpool.atari import AtariEnvSpec
26 : 1 : from envpool.registration import make_dm, make_gym, make_spec
27 : :
28 : :
class _SpecTest(absltest.TestCase):
    """Tests for the spec objects and metadata of Atari envpool envs."""

    @no_type_check
    def test_spec(self) -> None:
        """Pong/Breakout specs expose matching dm_env specs and gym spaces."""
        action_nums = {"pong": 6, "breakout": 4}
        for task, action_num in action_nums.items():
            with self.subTest(task=task):
                spec = make_spec(task.capitalize() + "-v5")
                logging.info(spec)
                # Three action fields are expected; the "action" field's
                # maximum is the largest valid index, so count = maximum + 1.
                self.assertEqual(len(spec.action_array_spec), 3)
                self.assertEqual(
                    spec.action_array_spec["action"].maximum + 1, action_num
                )
                # check dm spec
                dm_obs_spec = spec.observation_spec().obs
                dm_act_spec = spec.action_spec()
                self.assertIsInstance(dm_obs_spec, dm_env.specs.BoundedArray)
                self.assertEqual(dm_obs_spec.dtype, np.uint8)
                self.assertEqual(dm_obs_spec.maximum, 255)
                self.assertIsInstance(dm_act_spec, dm_env.specs.DiscreteArray)
                self.assertEqual(dm_act_spec.num_values, action_num)
                # check gym space
                gym_obs_space: gym.spaces.Box = spec.observation_space
                gym_act_space: gym.spaces.Discrete = spec.action_space
                self.assertIsInstance(gym_obs_space, gym.spaces.Box)
                self.assertEqual(gym_obs_space.dtype, np.uint8)
                np.testing.assert_allclose(gym_obs_space.high, 255)
                self.assertIsInstance(gym_act_space, gym.spaces.Discrete)
                self.assertEqual(gym_act_space.n, action_num)

    def test_seed_warning(self) -> None:
        """Calling ``seed`` on a constructed env emits a UserWarning."""
        num_envs = 4
        env = make_dm("Pong-v5", num_envs=num_envs)
        with self.assertWarns(UserWarning):
            env.seed(1)
        env = make_gym("Pong-v5", num_envs=num_envs)
        with self.assertWarns(UserWarning):
            env.seed()

    def test_invalid_batch_size(self) -> None:
        """A batch_size larger than num_envs is rejected with ValueError."""
        num_envs = 4
        batch_size = 5
        config = AtariEnvSpec.gen_config(
            task="pong", num_envs=num_envs, batch_size=batch_size
        )
        self.assertRaises(ValueError, AtariEnvSpec, config)

    def test_mode_and_difficulty_defaults_and_overrides(self) -> None:
        """mode/difficulty default to -1 and can be overridden via kwargs."""
        spec = make_spec("Breakout-v5")
        self.assertEqual(spec.config.mode, -1)
        self.assertEqual(spec.config.difficulty, -1)

        spec = make_spec("Breakout-v5", mode=4, difficulty=1)
        self.assertEqual(spec.config.mode, 4)
        self.assertEqual(spec.config.difficulty, 1)

    def test_metadata(self) -> None:
        """len() reports num_envs; is_async is True when batch_size < num_envs."""
        num_envs = 4
        env = make_gym("Pong-v5", num_envs=num_envs)
        self.assertEqual(len(env), num_envs)
        self.assertFalse(env.is_async)
        num_envs = 8
        batch_size = 4
        env = make_gym("Pong-v5", num_envs=num_envs, batch_size=batch_size)
        self.assertEqual(len(env), num_envs)
        self.assertTrue(env.is_async)
        self.assertIsNone(env.spec.reward_threshold)
96 : :
97 : :
class _DMSyncTest(absltest.TestCase):
    """Tests for the synchronous dm_env interface of Atari envpool."""

    @no_type_check
    def test_spec(self) -> None:
        """make_dm envs report uint8 observations and discrete actions."""
        action_nums = {"pong": 6, "breakout": 4}
        for task, action_num in action_nums.items():
            with self.subTest(task=task):
                env = make_dm(task.capitalize() + "-v5")
                self.assertIsInstance(env, dm_env.Environment)
                logging.info(env)
                # check dm spec
                dm_obs_spec = env.observation_spec().obs
                dm_act_spec = env.action_spec()
                self.assertIsInstance(dm_obs_spec, dm_env.specs.BoundedArray)
                self.assertEqual(dm_obs_spec.dtype, np.uint8)
                self.assertEqual(dm_obs_spec.maximum, 255)
                self.assertIsInstance(dm_act_spec, dm_env.specs.DiscreteArray)
                self.assertEqual(dm_act_spec.num_values, action_num)

    def _assert_reset_timestep(
        self, ts: dm_env.TimeStep, num_envs: int
    ) -> None:
        """Check shapes and dtypes of the batched TimeStep returned on reset."""
        self.assertTrue(np.all(ts.first()))
        np.testing.assert_allclose(ts.step_type.shape, (num_envs,))
        np.testing.assert_allclose(ts.reward.shape, (num_envs,))
        self.assertEqual(ts.reward.dtype, np.float32)
        np.testing.assert_allclose(ts.discount.shape, (num_envs,))
        self.assertEqual(ts.discount.dtype, np.float32)
        # Presumably 4 stacked 84x84 frames per env.
        np.testing.assert_allclose(
            ts.observation.obs.shape, (num_envs, 4, 84, 84)
        )
        self.assertEqual(ts.observation.obs.dtype, np.uint8)
        np.testing.assert_allclose(ts.observation.lives.shape, (num_envs,))
        self.assertEqual(ts.observation.lives.dtype, np.int32)
        np.testing.assert_allclose(ts.observation.env_id, np.arange(num_envs))
        self.assertEqual(ts.observation.env_id.dtype, np.int32)
        np.testing.assert_allclose(
            ts.observation.players.env_id.shape, (num_envs,)
        )
        self.assertEqual(ts.observation.players.env_id.dtype, np.int32)

    @staticmethod
    def _dict_action(num_envs: int, dtype: type) -> dict:
        """Build a dict action addressing every env, all fields of ``dtype``."""
        return {
            "env_id": np.arange(num_envs, dtype=dtype),
            "players.env_id": np.arange(num_envs, dtype=dtype),
            "action": np.ones(num_envs, dtype),
        }

    def _assert_auto_reset(
        self, last_ts: dm_env.TimeStep, next_ts: dm_env.TimeStep
    ) -> None:
        """Envs that ended in ``last_ts`` must start a new episode next."""
        index = np.where(last_ts.last())
        np.testing.assert_allclose(last_ts.discount[index], 0)
        np.testing.assert_allclose(
            next_ts.step_type[index], dm_env.StepType.FIRST
        )
        np.testing.assert_allclose(next_ts.discount[index], 1)

    def test_lowlevel_step(self) -> None:
        """async_reset/send/recv round-trip on Pong, including auto-reset."""
        num_envs = 4
        env = make_dm("Pong-v5", num_envs=num_envs)
        logging.info(env)
        # Sample random actions from the env's own action range.
        num_actions = env.action_spec().num_values
        env.async_reset()
        ts: dm_env.TimeStep = env.recv()
        self._assert_reset_timestep(ts, num_envs)
        # The C++ side declares actions as int32, so int64 arrays are rejected.
        self.assertRaises(
            RuntimeError, env.send, self._dict_action(num_envs, int)
        )
        env.send(self._dict_action(num_envs, np.int32))
        ts1: dm_env.TimeStep = env.recv()
        self.assertTrue(np.all(ts1.mid()))
        # A plain (float64) array is also accepted as the action.
        env.send(np.ones(num_envs))
        ts2: dm_env.TimeStep = env.recv()
        self.assertTrue(np.all(ts2.mid()))
        # Step randomly until at least one env leaves the mid-episode state.
        while np.all(ts2.mid()):
            env.send(np.random.randint(num_actions, size=num_envs))
            ts2 = env.recv()
        env.send(np.random.randint(num_actions, size=num_envs))
        tsp1: dm_env.TimeStep = env.recv()
        self._assert_auto_reset(ts2, tsp1)

    def test_highlevel_step(self) -> None:
        """reset/step round-trip on Defender, including auto-reset."""
        num_envs = 4
        # defender game hangs infinitely in gym.make("Defender-v0")
        env = make_dm("Defender-v5", num_envs=num_envs)
        logging.info(env)
        # Sample random actions from the env's own action range.
        num_actions = env.action_spec().num_values
        ts: dm_env.TimeStep = env.reset()
        self._assert_reset_timestep(ts, num_envs)
        # The C++ side declares actions as int32, so int64 arrays are rejected.
        self.assertRaises(
            RuntimeError, env.step, self._dict_action(num_envs, int)
        )
        ts1: dm_env.TimeStep = env.step(self._dict_action(num_envs, np.int32))
        self.assertTrue(np.all(ts1.mid()))
        # A plain (float64) array is also accepted as the action.
        ts2: dm_env.TimeStep = env.step(np.ones(num_envs))
        self.assertTrue(np.all(ts2.mid()))
        # Step randomly until at least one env leaves the mid-episode state.
        while np.all(ts2.mid()):
            ts2 = env.step(np.random.randint(num_actions, size=num_envs))
        tsp1: dm_env.TimeStep = env.step(
            np.random.randint(num_actions, size=num_envs)
        )
        self._assert_auto_reset(ts2, tsp1)
219 : :
220 : :
class _GymSyncTest(absltest.TestCase):
    """Tests for the synchronous gymnasium interface of Atari envpool."""

    @no_type_check
    def test_spec(self) -> None:
        """make_gym envs expose Box/Discrete spaces with seedable sampling."""
        action_nums = {"pong": 6, "breakout": 4}
        for task, action_num in action_nums.items():
            with self.subTest(task=task):
                env = make_gym(task.capitalize() + "-v5")
                self.assertIsInstance(env, gym.Env)
                logging.info(env)
                # check gym space
                gym_obs_space: gym.spaces.Box = env.observation_space
                gym_act_space: gym.spaces.Discrete = env.action_space
                self.assertEqual(len(env.spec.action_array_spec), 3)
                self.assertIsInstance(gym_obs_space, gym.spaces.Box)
                self.assertEqual(gym_obs_space.dtype, np.uint8)
                np.testing.assert_allclose(gym_obs_space.high, 255)
                self.assertIsInstance(gym_act_space, gym.spaces.Discrete)
                self.assertEqual(gym_act_space.n, action_num)
                # Issue 207: re-seeding the action space with the same seed
                # must reproduce the same sample.
                gym_act_space.seed(1)
                action0 = gym_act_space.sample()
                gym_act_space.seed(1)
                action1 = gym_act_space.sample()
                self.assertEqual(action0, action1)
                env.action_space.seed(2)
                action2 = env.action_space.sample()
                env.action_space.seed(2)
                action3 = env.action_space.sample()
                self.assertEqual(action2, action3)

    def _assert_step_output(
        self,
        obs: np.ndarray,
        rew: np.ndarray,
        terminated: np.ndarray,
        truncated: np.ndarray,
        info: dict,
        num_envs: int,
    ) -> None:
        """Check types, dtypes and shapes of one batched gym step result."""
        done = np.logical_or(terminated, truncated)
        self.assertIsInstance(obs, np.ndarray)
        self.assertEqual(obs.dtype, np.uint8)
        np.testing.assert_allclose(rew.shape, (num_envs,))
        self.assertEqual(rew.dtype, np.float32)
        np.testing.assert_allclose(done.shape, (num_envs,))
        self.assertEqual(done.dtype, np.bool_)
        self.assertEqual(terminated.dtype, np.bool_)
        self.assertEqual(truncated.dtype, np.bool_)
        np.testing.assert_allclose(truncated.shape, (num_envs,))
        self.assertIsInstance(info, dict)
        self.assertEqual(len(info), 7)
        self.assertEqual(info["env_id"].dtype, np.int32)
        self.assertEqual(info["lives"].dtype, np.int32)
        self.assertEqual(info["ram"].dtype, np.uint8)
        self.assertEqual(info["players"]["env_id"].dtype, np.int32)
        np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
        np.testing.assert_allclose(info["lives"].shape, (num_envs,))
        np.testing.assert_allclose(info["ram"].shape, (num_envs, 128))
        np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))

    def test_lowlevel_step(self) -> None:
        """async_reset/send/recv on Breakout, including auto-reset after done."""
        num_envs = 4
        env = make_gym("Breakout-v5", num_envs=num_envs)
        self.assertIsInstance(env, gym.Env)
        logging.info(env)
        # Breakout has only 4 actions (see test_spec); sampling from range(6)
        # would send out-of-range action indices.
        num_actions = env.action_space.n
        env.async_reset()
        obs, rew, terminated, truncated, info = env.recv()
        self._assert_step_output(obs, rew, terminated, truncated, info, num_envs)
        done = np.logical_or(terminated, truncated)
        while not np.any(done):
            env.send(np.random.randint(num_actions, size=num_envs))
            _, _, terminated, truncated, _ = env.recv()
            done = np.logical_or(terminated, truncated)
        env.send(np.random.randint(num_actions, size=num_envs))
        _, _, terminated1, truncated1, _ = env.recv()
        done1 = np.logical_or(terminated1, truncated1)
        # Envs that finished on the previous step must not be done again.
        index = np.where(done)[0]
        self.assertTrue(np.all(~done1[index]))

    def test_highlevel_step(self) -> None:
        """reset/step on Pong, including auto-reset after done."""
        num_envs = 4
        env = make_gym("Pong-v5", num_envs=num_envs)
        self.assertIsInstance(env, gym.Env)
        logging.info(env)
        num_actions = env.action_space.n
        obs, _ = env.reset()
        # check shape
        self.assertIsInstance(obs, np.ndarray)
        self.assertEqual(obs.dtype, np.uint8)
        obs, rew, terminated, truncated, info = env.step(
            np.random.randint(num_actions, size=num_envs)
        )
        self._assert_step_output(obs, rew, terminated, truncated, info, num_envs)
        done = np.logical_or(terminated, truncated)
        while not np.any(done):
            _, _, terminated, truncated, _ = env.step(
                np.random.randint(num_actions, size=num_envs)
            )
            done = np.logical_or(terminated, truncated)
        _, _, terminated1, truncated1, _ = env.step(
            np.random.randint(num_actions, size=num_envs)
        )
        done1 = np.logical_or(terminated1, truncated1)
        # Envs that finished on the previous step must not be done again.
        index = np.where(done)[0]
        self.assertTrue(np.all(~done1[index]))
331 : :
332 : :
if "__main__" == __name__:
    # Run this module's test cases through absl's test runner.
    absltest.main()
|