LCOV - code coverage report
Current view: top level - envpool/atari - api_test.py (source / functions) Coverage Total Hit
Test: EnvPool coverage report Lines: 100.0 % 255 255
Test Date: 2026-04-07 08:10:29 Functions: - 0 0
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: - 0 0

             Branch data     Line data    Source code
       1                 :             : # Copyright 2021 Garena Online Private Limited
       2                 :             : #
       3                 :             : # Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : # you may not use this file except in compliance with the License.
       5                 :             : # You may obtain a copy of the License at
       6                 :             : #
       7                 :             : #      http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : #
       9                 :             : # Unless required by applicable law or agreed to in writing, software
      10                 :             : # distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : # See the License for the specific language governing permissions and
      13                 :             : # limitations under the License.
      14                 :           1 : """Unit test for atari envpool and speed benchmark."""
      15                 :             : 
      16                 :           1 : from typing import no_type_check
      17                 :             : 
      18                 :           1 : import dm_env
      19                 :           1 : import gymnasium as gym
      20                 :           1 : import numpy as np
      21                 :           1 : from absl import logging
      22                 :           1 : from absl.testing import absltest
      23                 :             : 
      24                 :           1 : import envpool.atari.registration  # noqa: F401
      25                 :           1 : from envpool.atari import AtariEnvSpec
      26                 :           1 : from envpool.registration import make_dm, make_gym, make_spec
      27                 :             : 
      28                 :             : 
      29                 :           1 : class _SpecTest(absltest.TestCase):
      30            [ + ]:           1 :     @no_type_check
      31                 :           1 :     def test_spec(self) -> None:
      32                 :           1 :         action_nums = {"pong": 6, "breakout": 4}
      33                 :           1 :         for task in ["pong", "breakout"]:
      34            [ + ]:           1 :             action_num = action_nums[task]
      35                 :           1 :             spec = make_spec(task.capitalize() + "-v5")
      36                 :           1 :             logging.info(spec)
      37                 :           1 :             self.assertEqual(
      38                 :             :                 spec.action_array_spec["action"].maximum + 1, action_num
      39                 :             :             )
      40                 :             :             # check dm spec
      41                 :           1 :             dm_obs_spec = spec.observation_spec().obs
      42                 :           1 :             dm_act_spec = spec.action_spec()
      43                 :           1 :             self.assertEqual(len(spec.action_array_spec), 3)
      44                 :           1 :             self.assertIsInstance(dm_obs_spec, dm_env.specs.BoundedArray)
      45                 :           1 :             self.assertEqual(dm_obs_spec.dtype, np.uint8)
      46                 :           1 :             self.assertEqual(dm_obs_spec.maximum, 255)
      47                 :           1 :             self.assertIsInstance(dm_act_spec, dm_env.specs.DiscreteArray)
      48                 :           1 :             self.assertEqual(dm_act_spec.num_values, action_num)
      49                 :             :             # check gym space
      50                 :           1 :             gym_obs_space: gym.spaces.Box = spec.observation_space
      51                 :           1 :             gym_act_space: gym.spaces.Discrete = spec.action_space
      52                 :           1 :             self.assertEqual(len(spec.action_array_spec), 3)
      53                 :           1 :             self.assertIsInstance(gym_obs_space, gym.spaces.Box)
      54                 :           1 :             self.assertEqual(gym_obs_space.dtype, np.uint8)
      55                 :           1 :             np.testing.assert_allclose(gym_obs_space.high, 255)
      56                 :           1 :             self.assertIsInstance(gym_act_space, gym.spaces.Discrete)
      57                 :           1 :             self.assertEqual(gym_act_space.n, action_num)
      58                 :             : 
      59            [ + ]:           1 :     def test_seed_warning(self) -> None:
      60                 :           1 :         num_envs = 4
      61                 :           1 :         env = make_dm("Pong-v5", num_envs=num_envs)
      62                 :           1 :         with self.assertWarns(UserWarning):
      63            [ + ]:           1 :             env.seed(1)
      64            [ + ]:           1 :         env = make_gym("Pong-v5", num_envs=num_envs)
      65                 :           1 :         with self.assertWarns(UserWarning):
      66            [ + ]:           1 :             env.seed()
      67                 :             : 
      68                 :           1 :     def test_invalid_batch_size(self) -> None:
      69                 :           1 :         num_envs = 4
      70                 :           1 :         batch_size = 5
      71                 :           1 :         config = AtariEnvSpec.gen_config(
      72                 :             :             task="pong", num_envs=num_envs, batch_size=batch_size
      73                 :             :         )
      74                 :           1 :         self.assertRaises(ValueError, AtariEnvSpec, config)
      75                 :             : 
      76                 :           1 :     def test_mode_and_difficulty_defaults_and_overrides(self) -> None:
      77                 :           1 :         spec = make_spec("Breakout-v5")
      78                 :           1 :         self.assertEqual(spec.config.mode, -1)
      79                 :           1 :         self.assertEqual(spec.config.difficulty, -1)
      80                 :             : 
      81                 :           1 :         spec = make_spec("Breakout-v5", mode=4, difficulty=1)
      82                 :           1 :         self.assertEqual(spec.config.mode, 4)
      83                 :           1 :         self.assertEqual(spec.config.difficulty, 1)
      84                 :             : 
      85                 :           1 :     def test_metadata(self) -> None:
      86                 :           1 :         num_envs = 4
      87                 :           1 :         env = make_gym("Pong-v5", num_envs=num_envs)
      88                 :           1 :         self.assertEqual(len(env), num_envs)
      89                 :           1 :         self.assertFalse(env.is_async)
      90                 :           1 :         num_envs = 8
      91                 :           1 :         batch_size = 4
      92                 :           1 :         env = make_gym("Pong-v5", num_envs=num_envs, batch_size=batch_size)
      93                 :           1 :         self.assertEqual(len(env), num_envs)
      94                 :           1 :         self.assertTrue(env.is_async)
      95                 :           1 :         self.assertIsNone(env.spec.reward_threshold)
      96                 :             : 
      97                 :             : 
      98                 :           1 : class _DMSyncTest(absltest.TestCase):
      99            [ + ]:           1 :     @no_type_check
     100                 :           1 :     def test_spec(self) -> None:
     101                 :           1 :         action_nums = {"pong": 6, "breakout": 4}
     102                 :           1 :         for task in ["pong", "breakout"]:
     103            [ + ]:           1 :             action_num = action_nums[task]
     104                 :           1 :             env = make_dm(task.capitalize() + "-v5")
     105                 :           1 :             self.assertIsInstance(env, dm_env.Environment)
     106                 :           1 :             logging.info(env)
     107                 :             :             # check dm spec
     108                 :           1 :             dm_obs_spec = env.observation_spec().obs
     109                 :           1 :             dm_act_spec = env.action_spec()
     110                 :           1 :             self.assertIsInstance(dm_obs_spec, dm_env.specs.BoundedArray)
     111                 :           1 :             self.assertEqual(dm_obs_spec.dtype, np.uint8)
     112                 :           1 :             self.assertEqual(dm_obs_spec.maximum, 255)
     113                 :           1 :             self.assertIsInstance(dm_act_spec, dm_env.specs.DiscreteArray)
     114                 :           1 :             self.assertEqual(dm_act_spec.num_values, action_num)
     115                 :             : 
     116            [ + ]:           1 :     def test_lowlevel_step(self) -> None:
     117                 :           1 :         num_envs = 4
     118                 :           1 :         env = make_dm("Pong-v5", num_envs=num_envs)
     119                 :           1 :         logging.info(env)
     120                 :           1 :         env.async_reset()
     121                 :           1 :         ts: dm_env.TimeStep = env.recv()
     122                 :             :         # check ts structure
     123                 :           1 :         self.assertTrue(np.all(ts.first()))
     124                 :           1 :         np.testing.assert_allclose(ts.step_type.shape, (num_envs,))
     125                 :           1 :         np.testing.assert_allclose(ts.reward.shape, (num_envs,))
     126                 :           1 :         self.assertEqual(ts.reward.dtype, np.float32)
     127                 :           1 :         np.testing.assert_allclose(ts.discount.shape, (num_envs,))
     128                 :           1 :         self.assertEqual(ts.discount.dtype, np.float32)
     129                 :           1 :         np.testing.assert_allclose(
     130                 :             :             ts.observation.obs.shape, (num_envs, 4, 84, 84)
     131                 :             :         )
     132                 :           1 :         self.assertEqual(ts.observation.obs.dtype, np.uint8)
     133                 :           1 :         np.testing.assert_allclose(ts.observation.lives.shape, (num_envs,))
     134                 :           1 :         self.assertEqual(ts.observation.lives.dtype, np.int32)
     135                 :           1 :         np.testing.assert_allclose(ts.observation.env_id, np.arange(num_envs))
     136                 :           1 :         self.assertEqual(ts.observation.env_id.dtype, np.int32)
     137                 :           1 :         np.testing.assert_allclose(
     138                 :             :             ts.observation.players.env_id.shape, (num_envs,)
     139                 :             :         )
     140                 :           1 :         self.assertEqual(ts.observation.players.env_id.dtype, np.int32)
     141                 :           1 :         action = {
     142                 :             :             "env_id": np.arange(num_envs),
     143                 :             :             "players.env_id": np.arange(num_envs),
     144                 :             :             "action": np.ones(num_envs, int),
     145                 :             :         }
     146                 :             :         # because in c++ side we define action is int32 instead of int64
     147                 :           1 :         self.assertRaises(RuntimeError, env.send, action)
     148                 :           1 :         action = {
     149                 :             :             "env_id": np.arange(num_envs, dtype=np.int32),
     150                 :             :             "players.env_id": np.arange(num_envs, dtype=np.int32),
     151                 :             :             "action": np.ones(num_envs, np.int32),
     152                 :             :         }
     153                 :           1 :         env.send(action)
     154                 :           1 :         ts1: dm_env.TimeStep = env.recv()
     155                 :           1 :         self.assertTrue(np.all(ts1.mid()))
     156                 :           1 :         action = np.ones(num_envs)
     157                 :           1 :         env.send(action)
     158                 :           1 :         ts2: dm_env.TimeStep = env.recv()
     159                 :           1 :         self.assertTrue(np.all(ts2.mid()))
     160                 :           1 :         while np.all(ts2.mid()):
     161            [ + ]:           1 :             env.send(np.random.randint(6, size=num_envs))
     162                 :           1 :             ts2 = env.recv()
     163            [ + ]:           1 :         env.send(np.random.randint(6, size=num_envs))
     164                 :           1 :         tsp1: dm_env.TimeStep = env.recv()
     165                 :           1 :         index = np.where(ts2.last())
     166                 :           1 :         np.testing.assert_allclose(ts2.discount[index], 0)
     167                 :           1 :         np.testing.assert_allclose(tsp1.step_type[index], dm_env.StepType.FIRST)
     168                 :           1 :         np.testing.assert_allclose(tsp1.discount[index], 1)
     169                 :             : 
     170                 :           1 :     def test_highlevel_step(self) -> None:
     171                 :           1 :         num_envs = 4
     172                 :             :         # defender game hangs infinitely in gym.make("Defender-v0")
     173                 :           1 :         env = make_dm("Defender-v5", num_envs=num_envs)
     174                 :           1 :         logging.info(env)
     175                 :           1 :         ts: dm_env.TimeStep = env.reset()
     176                 :             :         # check ts structure
     177                 :           1 :         self.assertTrue(np.all(ts.first()))
     178                 :           1 :         np.testing.assert_allclose(ts.step_type.shape, (num_envs,))
     179                 :           1 :         np.testing.assert_allclose(ts.reward.shape, (num_envs,))
     180                 :           1 :         self.assertEqual(ts.reward.dtype, np.float32)
     181                 :           1 :         np.testing.assert_allclose(ts.discount.shape, (num_envs,))
     182                 :           1 :         self.assertEqual(ts.discount.dtype, np.float32)
     183                 :           1 :         np.testing.assert_allclose(
     184                 :             :             ts.observation.obs.shape, (num_envs, 4, 84, 84)
     185                 :             :         )
     186                 :           1 :         self.assertEqual(ts.observation.obs.dtype, np.uint8)
     187                 :           1 :         np.testing.assert_allclose(ts.observation.lives.shape, (num_envs,))
     188                 :           1 :         self.assertEqual(ts.observation.lives.dtype, np.int32)
     189                 :           1 :         np.testing.assert_allclose(ts.observation.env_id, np.arange(num_envs))
     190                 :           1 :         self.assertEqual(ts.observation.env_id.dtype, np.int32)
     191                 :           1 :         np.testing.assert_allclose(
     192                 :             :             ts.observation.players.env_id.shape, (num_envs,)
     193                 :             :         )
     194                 :           1 :         self.assertEqual(ts.observation.players.env_id.dtype, np.int32)
     195                 :           1 :         action = {
     196                 :             :             "env_id": np.arange(num_envs),
     197                 :             :             "players.env_id": np.arange(num_envs),
     198                 :             :             "action": np.ones(num_envs, int),
     199                 :             :         }
     200                 :             :         # because in c++ side we define action is int32 instead of int64
     201                 :           1 :         self.assertRaises(RuntimeError, env.step, action)
     202                 :           1 :         action = {
     203                 :             :             "env_id": np.arange(num_envs, dtype=np.int32),
     204                 :             :             "players.env_id": np.arange(num_envs, dtype=np.int32),
     205                 :             :             "action": np.ones(num_envs, np.int32),
     206                 :             :         }
     207                 :           1 :         ts1: dm_env.TimeStep = env.step(action)
     208                 :           1 :         self.assertTrue(np.all(ts1.mid()))
     209                 :           1 :         action = np.ones(num_envs)
     210                 :           1 :         ts2: dm_env.TimeStep = env.step(action)
     211                 :           1 :         self.assertTrue(np.all(ts2.mid()))
     212                 :           1 :         while np.all(ts2.mid()):
     213            [ + ]:           1 :             ts2 = env.step(np.random.randint(18, size=num_envs))
     214            [ + ]:           1 :         tsp1: dm_env.TimeStep = env.step(np.random.randint(18, size=num_envs))
     215                 :           1 :         index = np.where(ts2.last())
     216                 :           1 :         np.testing.assert_allclose(ts2.discount[index], 0)
     217                 :           1 :         np.testing.assert_allclose(tsp1.step_type[index], dm_env.StepType.FIRST)
     218                 :           1 :         np.testing.assert_allclose(tsp1.discount[index], 1)
     219                 :             : 
     220                 :             : 
     221                 :           1 : class _GymSyncTest(absltest.TestCase):
     222            [ + ]:           1 :     @no_type_check
     223                 :           1 :     def test_spec(self) -> None:
     224                 :           1 :         action_nums = {"pong": 6, "breakout": 4}
     225                 :           1 :         for task in ["pong", "breakout"]:
     226            [ + ]:           1 :             action_num = action_nums[task]
     227                 :           1 :             env = make_gym(task.capitalize() + "-v5")
     228                 :           1 :             self.assertIsInstance(env, gym.Env)
     229                 :           1 :             logging.info(env)
     230                 :             :             # check gym space
     231                 :           1 :             gym_obs_space: gym.spaces.Box = env.observation_space
     232                 :           1 :             gym_act_space: gym.spaces.Discrete = env.action_space
     233                 :           1 :             self.assertEqual(len(env.spec.action_array_spec), 3)
     234                 :           1 :             self.assertIsInstance(gym_obs_space, gym.spaces.Box)
     235                 :           1 :             self.assertEqual(gym_obs_space.dtype, np.uint8)
     236                 :           1 :             np.testing.assert_allclose(gym_obs_space.high, 255)
     237                 :           1 :             self.assertIsInstance(gym_act_space, gym.spaces.Discrete)
     238                 :           1 :             self.assertEqual(gym_act_space.n, action_num)
     239                 :             :             # Issue 207
     240                 :           1 :             gym_act_space.seed(1)
     241                 :           1 :             action0 = gym_act_space.sample()
     242                 :           1 :             gym_act_space.seed(1)
     243                 :           1 :             action1 = gym_act_space.sample()
     244                 :           1 :             self.assertEqual(action0, action1)
     245                 :           1 :             env.action_space.seed(2)
     246                 :           1 :             action2 = env.action_space.sample()
     247                 :           1 :             env.action_space.seed(2)
     248                 :           1 :             action3 = env.action_space.sample()
     249                 :           1 :             self.assertEqual(action2, action3)
     250                 :             : 
     251            [ + ]:           1 :     def test_lowlevel_step(self) -> None:
     252                 :           1 :         num_envs = 4
     253                 :           1 :         env = make_gym("Breakout-v5", num_envs=num_envs)
     254                 :           1 :         self.assertTrue(isinstance(env, gym.Env))
     255                 :           1 :         logging.info(env)
     256                 :           1 :         env.async_reset()
     257                 :           1 :         obs, rew, terminated, truncated, info = env.recv()
     258                 :           1 :         done = np.logical_or(terminated, truncated)
     259                 :             :         # check shape
     260                 :           1 :         self.assertIsInstance(obs, np.ndarray)
     261                 :           1 :         self.assertEqual(obs.dtype, np.uint8)
     262                 :           1 :         np.testing.assert_allclose(rew.shape, (num_envs,))
     263                 :           1 :         self.assertEqual(rew.dtype, np.float32)
     264                 :           1 :         np.testing.assert_allclose(done.shape, (num_envs,))
     265                 :           1 :         self.assertEqual(done.dtype, np.bool_)
     266                 :           1 :         self.assertEqual(terminated.dtype, np.bool_)
     267                 :           1 :         self.assertEqual(truncated.dtype, np.bool_)
     268                 :           1 :         self.assertIsInstance(info, dict)
     269                 :           1 :         self.assertEqual(len(info), 7)
     270                 :           1 :         self.assertEqual(info["env_id"].dtype, np.int32)
     271                 :           1 :         self.assertEqual(info["lives"].dtype, np.int32)
     272                 :           1 :         self.assertEqual(info["ram"].dtype, np.uint8)
     273                 :           1 :         self.assertEqual(info["players"]["env_id"].dtype, np.int32)
     274                 :           1 :         np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
     275                 :           1 :         np.testing.assert_allclose(info["lives"].shape, (num_envs,))
     276                 :           1 :         np.testing.assert_allclose(info["ram"].shape, (num_envs, 128))
     277                 :           1 :         np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
     278                 :           1 :         np.testing.assert_allclose(truncated.shape, (num_envs,))
     279                 :           1 :         while not np.any(done):
     280            [ + ]:           1 :             env.send(np.random.randint(6, size=num_envs))
     281                 :           1 :             obs, rew, terminated, truncated, info = env.recv()
     282                 :           1 :             done = np.logical_or(terminated, truncated)
     283            [ + ]:           1 :         env.send(np.random.randint(6, size=num_envs))
     284                 :           1 :         obs1, rew1, terminated1, truncated1, info1 = env.recv()
     285                 :           1 :         done1 = np.logical_or(terminated1, truncated1)
     286                 :           1 :         index = np.where(done)[0]
     287                 :           1 :         self.assertTrue(np.all(~done1[index]))
     288                 :             : 
     289                 :           1 :     def test_highlevel_step(self) -> None:
     290                 :           1 :         num_envs = 4
     291                 :           1 :         env = make_gym("Pong-v5", num_envs=num_envs)
     292                 :           1 :         self.assertTrue(isinstance(env, gym.Env))
     293                 :           1 :         logging.info(env)
     294                 :           1 :         obs, _ = env.reset()
     295                 :             :         # check shape
     296                 :           1 :         self.assertIsInstance(obs, np.ndarray)
     297                 :           1 :         self.assertEqual(obs.dtype, np.uint8)
     298                 :           1 :         obs, rew, terminated, truncated, info = env.step(
     299                 :             :             np.random.randint(6, size=num_envs)
     300                 :             :         )
     301                 :           1 :         done = np.logical_or(terminated, truncated)
     302                 :           1 :         self.assertIsInstance(obs, np.ndarray)
     303                 :           1 :         self.assertEqual(obs.dtype, np.uint8)
     304                 :           1 :         np.testing.assert_allclose(rew.shape, (num_envs,))
     305                 :           1 :         self.assertEqual(rew.dtype, np.float32)
     306                 :           1 :         np.testing.assert_allclose(done.shape, (num_envs,))
     307                 :           1 :         self.assertEqual(done.dtype, np.bool_)
     308                 :           1 :         self.assertIsInstance(info, dict)
     309                 :           1 :         self.assertEqual(len(info), 7)
     310                 :           1 :         self.assertEqual(info["env_id"].dtype, np.int32)
     311                 :           1 :         self.assertEqual(info["lives"].dtype, np.int32)
     312                 :           1 :         self.assertEqual(info["ram"].dtype, np.uint8)
     313                 :           1 :         self.assertEqual(info["players"]["env_id"].dtype, np.int32)
     314                 :           1 :         self.assertEqual(truncated.dtype, np.bool_)
     315                 :           1 :         np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
     316                 :           1 :         np.testing.assert_allclose(info["lives"].shape, (num_envs,))
     317                 :           1 :         np.testing.assert_allclose(info["ram"].shape, (num_envs, 128))
     318                 :           1 :         np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
     319                 :           1 :         np.testing.assert_allclose(truncated.shape, (num_envs,))
     320                 :           1 :         while not np.any(done):
     321            [ + ]:           1 :             obs, rew, terminated, truncated, info = env.step(
     322                 :             :                 np.random.randint(6, size=num_envs)
     323                 :             :             )
     324                 :           1 :             done = np.logical_or(terminated, truncated)
     325            [ + ]:           1 :         obs1, rew1, terminated1, truncated1, info1 = env.step(
     326                 :             :             np.random.randint(6, size=num_envs)
     327                 :             :         )
     328                 :           1 :         done1 = np.logical_or(terminated1, truncated1)
     329                 :           1 :         index = np.where(done)[0]
     330                 :           1 :         self.assertTrue(np.all(~done1[index]))
     331                 :             : 
     332                 :             : 
     333                 :           1 : if __name__ == "__main__":
     334            [ + ]:           1 :     absltest.main()
        

Generated by: LCOV version 2.0-1