LCOV - code coverage report
Current view: top level - envpool/box2d - box2d_render_test.py (source / functions) Coverage Total Hit
Test: EnvPool coverage report Lines: 82.6 % 149 123
Test Date: 2026-04-15 02:10:58 Functions: - 0 0
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: - 0 0

             Branch data     Line data    Source code
       1                 :             : # Copyright 2026 Garena Online Private Limited
       2                 :             : #
       3                 :             : # Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : # you may not use this file except in compliance with the License.
       5                 :             : # You may obtain a copy of the License at
       6                 :             : #
       7                 :             : #      http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : #
       9                 :             : # Unless required by applicable law or agreed to in writing, software
      10                 :             : # distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : # See the License for the specific language governing permissions and
      13                 :             : # limitations under the License.
      14                 :           1 : """Render tests for Box2D environments."""
      15                 :             : 
      16                 :           1 : import importlib.machinery
      17                 :           1 : import importlib.util
      18                 :           1 : import re
      19                 :           1 : import sys
      20                 :           1 : import types
      21                 :           1 : from pathlib import Path
      22                 :           1 : from typing import Any, cast
      23                 :             : 
      24                 :           1 : import gymnasium as gym
      25                 :           1 : import numpy as np
      26                 :           1 : from absl.testing import absltest
      27                 :             : 
      28                 :           1 : import envpool.box2d.registration  # noqa: F401
      29                 :           1 : from envpool.registration import make_gym
      30                 :             : 
      31                 :           1 : _RENDER_STEPS = 3
      32                 :           1 : _TASK_IDS = (
      33                 :             :     "CarRacing-v2",
      34                 :             :     "CarRacing-v3",
      35                 :             :     "BipedalWalker-v3",
      36                 :             :     "BipedalWalkerHardcore-v3",
      37                 :             :     "LunarLander-v2",
      38                 :             :     "LunarLander-v3",
      39                 :             :     "LunarLanderContinuous-v2",
      40                 :             :     "LunarLanderContinuous-v3",
      41                 :             : )
      42                 :           1 : _OFFICIAL_RENDER_TASK_IDS = (
      43                 :             :     "CarRacing-v3",
      44                 :             :     "BipedalWalker-v3",
      45                 :             :     "BipedalWalkerHardcore-v3",
      46                 :             :     "LunarLander-v3",
      47                 :             :     "LunarLanderContinuous-v3",
      48                 :             : )
      49                 :           1 : _BOX2D_SWIGCONSTANT_RE = re.compile(r"_Box2D\.(\w+_swigconstant)\(")
      50                 :             : 
      51                 :             : 
      52                 :           1 : def _patch_box2d_swigconstant_shims(module: Any, pathname: str) -> None:
      53                 :           0 :     wrapper_path = Path(pathname).with_name("Box2D.py")
      54                 :           0 :     try:
      55                 :           0 :         names = set(_BOX2D_SWIGCONSTANT_RE.findall(wrapper_path.read_text()))
      56                 :           0 :     except OSError:
      57                 :           0 :         return
      58       [ # ][ # ]:           0 :     for attr in names:
      59            [ # ]:           0 :         if not hasattr(module, attr):
      60            [ # ]:           0 :             setattr(module, attr, lambda _target, _attr=attr: None)
      61                 :             : 
      62                 :             : 
      63                 :           1 : def _install_imp_compat() -> None:
      64                 :           1 :     try:
      65                 :           1 :         import imp  # noqa: F401
      66                 :             : 
      67                 :           0 :         return
      68                 :           1 :     except ModuleNotFoundError:
      69                 :           1 :         pass
      70                 :             : 
      71                 :           1 :     compat_imp: Any = types.ModuleType("imp")
      72                 :           1 :     compat_imp.C_EXTENSION = 3
      73                 :             : 
      74                 :           1 :     def find_module(
      75                 :             :         name: str, path: Any = None
      76                 :             :     ) -> tuple[Any, str, tuple[str, str, int]]:
      77                 :           0 :         spec = importlib.machinery.PathFinder.find_spec(name, path)
      78                 :           0 :         if spec is None or spec.origin is None:
      79            [ # ]:           0 :             raise ImportError(name)
      80            [ # ]:           0 :         return (
      81                 :             :             open(spec.origin, "rb"),
      82                 :             :             spec.origin,
      83                 :             :             ("", "rb", compat_imp.C_EXTENSION),
      84                 :             :         )
      85                 :             : 
      86                 :           1 :     def load_module(
      87                 :             :         name: str, file: Any, pathname: str, description: Any
      88                 :             :     ) -> Any:
      89                 :           0 :         del file, description
      90                 :           0 :         module = sys.modules.get(name)
      91                 :           0 :         if module is not None:
      92            [ # ]:           0 :             return module
      93            [ # ]:           0 :         spec = importlib.util.spec_from_file_location(name, pathname)
      94                 :           0 :         if spec is None or spec.loader is None:
      95            [ # ]:           0 :             raise ImportError(pathname)
      96            [ # ]:           0 :         module = importlib.util.module_from_spec(spec)
      97                 :           0 :         sys.modules[name] = module
      98                 :           0 :         spec.loader.exec_module(module)
      99                 :           0 :         if name == "_Box2D":
     100            [ # ]:           0 :             _patch_box2d_swigconstant_shims(module, pathname)
     101            [ # ]:           0 :         return module
     102                 :             : 
     103                 :           1 :     compat_imp.find_module = find_module
     104                 :           1 :     compat_imp.load_module = load_module
     105                 :           1 :     sys.modules["imp"] = compat_imp
     106                 :             : 
     107                 :             : 
     108                 :           1 : _install_imp_compat()
     109                 :             : 
     110                 :             : 
     111                 :           1 : def _render_array(env: Any, env_ids: Any = None) -> np.ndarray:
     112                 :           1 :     frame = env.render(env_ids=env_ids)
     113                 :           1 :     assert frame is not None
     114                 :           1 :     return cast(np.ndarray, frame)
     115                 :             : 
     116                 :             : 
     117                 :           1 : def _make_oracle_env(task_id: str) -> gym.Env[Any, Any]:
     118                 :           1 :     return gym.make(task_id, render_mode="rgb_array")
     119                 :             : 
     120                 :             : 
     121                 :           1 : def _zero_action(space: Any, num_envs: int) -> np.ndarray:
     122                 :           1 :     sample = np.asarray(space.sample())
     123                 :           1 :     zero = np.zeros_like(sample)
     124                 :           1 :     if sample.ndim == 0:
     125            [ + ]:           1 :         return np.full((num_envs,), zero.item(), dtype=sample.dtype)
     126            [ + ]:           1 :     return np.repeat(zero[np.newaxis, ...], num_envs, axis=0)
     127                 :             : 
     128                 :             : 
     129                 :           1 : class Box2DRenderTest(absltest.TestCase):
     130                 :           1 :     """Render regression tests for Box2D environments."""
     131                 :             : 
     132                 :           1 :     def _assert_batch_consistent_render(self, task_id: str) -> None:
     133                 :           1 :         env = make_gym(
     134                 :             :             task_id,
     135                 :             :             num_envs=2,
     136                 :             :             render_mode="rgb_array",
     137                 :             :             render_width=64,
     138                 :             :             render_height=48,
     139                 :             :         )
     140                 :           1 :         try:
     141                 :           1 :             env.reset()
     142            [ + ]:           1 :             for step_idx in range(_RENDER_STEPS):
     143            [ + ]:           1 :                 frame0 = _render_array(env)
     144                 :           1 :                 frame1 = _render_array(env, env_ids=1)
     145                 :           1 :                 frames = _render_array(env, env_ids=[0, 1])
     146                 :           1 :                 frame0_again = _render_array(env)
     147                 :           1 :                 self.assertEqual(frame0.shape, (1, 48, 64, 3))
     148                 :           1 :                 self.assertEqual(frame1.shape, (1, 48, 64, 3))
     149                 :           1 :                 self.assertEqual(frames.shape, (2, 48, 64, 3))
     150                 :           1 :                 self.assertEqual(frame0.dtype, np.uint8)
     151                 :           1 :                 self.assertEqual(frames.dtype, np.uint8)
     152                 :           1 :                 np.testing.assert_array_equal(frame0[0], frames[0])
     153                 :           1 :                 np.testing.assert_array_equal(frame1[0], frames[1])
     154                 :           1 :                 np.testing.assert_array_equal(frame0, frame0_again)
     155                 :           1 :                 if step_idx + 1 < _RENDER_STEPS:
     156            [ + ]:           1 :                     env.step(_zero_action(env.action_space, 2))
     157                 :             :         finally:
     158            [ + ]:           1 :             env.close()
     159                 :             : 
     160                 :           1 :     def test_render_succeeds_for_multiple_steps_for_all_tasks(self) -> None:
     161                 :             :         """Every Box2D task should render repeatedly across several steps."""
     162            [ + ]:           1 :         for task_id in _TASK_IDS:
     163            [ + ]:           1 :             with self.subTest(task_id=task_id):
     164            [ + ]:           1 :                 env = make_gym(
     165                 :             :                     task_id,
     166                 :             :                     num_envs=1,
     167                 :             :                     seed=0,
     168                 :             :                     render_mode="rgb_array",
     169                 :             :                 )
     170                 :           1 :                 try:
     171                 :           1 :                     env.reset()
     172            [ + ]:           1 :                     for step_idx in range(_RENDER_STEPS):
     173            [ + ]:           1 :                         frame = _render_array(env)[0]
     174                 :           1 :                         frame_again = _render_array(env)[0]
     175                 :           1 :                         self.assertEqual(frame.dtype, np.uint8)
     176                 :           1 :                         self.assertEqual(frame.ndim, 3)
     177                 :           1 :                         self.assertEqual(frame.shape[-1], 3)
     178                 :           1 :                         np.testing.assert_array_equal(frame, frame_again)
     179                 :           1 :                         self.assertGreater(
     180                 :             :                             int(frame.max()) - int(frame.min()), 0
     181                 :             :                         )
     182                 :           1 :                         if step_idx + 1 < _RENDER_STEPS:
     183            [ + ]:           1 :                             env.step(_zero_action(env.action_space, 1))
     184                 :             :                 finally:
     185            [ + ]:           1 :                     env.close()
     186                 :             : 
     187                 :           1 :     def test_car_racing_render(self) -> None:
     188                 :             :         """CarRacing should support consistent batched rendering."""
     189                 :           1 :         self._assert_batch_consistent_render("CarRacing-v3")
     190                 :             : 
     191                 :           1 :     def test_bipedal_walker_render(self) -> None:
     192                 :             :         """BipedalWalker should support consistent batched rendering."""
     193                 :           1 :         self._assert_batch_consistent_render("BipedalWalker-v3")
     194                 :             : 
     195                 :           1 :     def test_bipedal_walker_render_matches_official_ground_profile(
     196                 :             :         self,
     197                 :             :     ) -> None:
     198                 :             :         """BipedalWalker renders should preserve the official ground profile."""
     199            [ + ]:           1 :         for task_id in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
     200            [ + ]:           1 :             with self.subTest(task_id=task_id):
     201            [ + ]:           1 :                 env = make_gym(
     202                 :             :                     task_id,
     203                 :             :                     num_envs=1,
     204                 :             :                     seed=0,
     205                 :             :                     render_mode="rgb_array",
     206                 :             :                     render_width=600,
     207                 :             :                     render_height=400,
     208                 :             :                 )
     209                 :           1 :                 oracle = _make_oracle_env(task_id)
     210                 :           1 :                 try:
     211                 :           1 :                     env.reset()
     212                 :           1 :                     oracle.reset(seed=0)
     213                 :           1 :                     frame = _render_array(env)[0].astype(np.int16)
     214                 :           1 :                     expected = np.asarray(oracle.render(), dtype=np.int16)
     215                 :           1 :                     lower_band = slice(-96, None)
     216                 :           1 :                     diff = np.abs(
     217                 :             :                         frame[lower_band] - expected[lower_band]
     218                 :             :                     ).mean()
     219                 :           1 :                     agent_crop = np.abs(
     220                 :             :                         frame[100:320, :220] - expected[100:320, :220]
     221                 :             :                     ).mean()
     222                 :           1 :                     self.assertLess(diff, 40.0)
     223                 :           1 :                     self.assertLess(agent_crop, 15.0)
     224                 :             :                 finally:
     225                 :           1 :                     env.close()
     226                 :           1 :                     oracle.close()
     227                 :             : 
     228                 :           1 :     def test_render_matches_official_first_frame(self) -> None:
     229                 :             :         """All Box2D reset renders should stay close to Gymnasium."""
     230                 :           1 :         thresholds = {
     231                 :             :             "BipedalWalker-v3": 40.0,
     232                 :             :             "BipedalWalkerHardcore-v3": 40.0,
     233                 :             :             "CarRacing-v3": 15.0,
     234                 :             :             "LunarLander-v3": 30.0,
     235                 :             :             "LunarLanderContinuous-v3": 30.0,
     236                 :             :         }
     237                 :           1 :         self.assertEqual(
     238                 :             :             tuple(sorted(thresholds)),
     239                 :             :             tuple(sorted(_OFFICIAL_RENDER_TASK_IDS)),
     240                 :             :         )
     241            [ + ]:           1 :         for task_id in _OFFICIAL_RENDER_TASK_IDS:
     242            [ + ]:           1 :             with self.subTest(task_id=task_id):
     243            [ + ]:           1 :                 env = make_gym(
     244                 :             :                     task_id,
     245                 :             :                     num_envs=1,
     246                 :             :                     seed=0,
     247                 :             :                     render_mode="rgb_array",
     248                 :             :                 )
     249                 :           1 :                 oracle = _make_oracle_env(task_id)
     250                 :           1 :                 try:
     251                 :           1 :                     env.reset()
     252                 :           1 :                     oracle.reset(seed=0)
     253                 :           1 :                     frame = _render_array(env)[0].astype(np.int16)
     254                 :           1 :                     expected = np.asarray(oracle.render(), dtype=np.int16)
     255                 :           1 :                     self.assertEqual(frame.shape, expected.shape)
     256                 :           1 :                     self.assertLess(
     257                 :             :                         np.abs(frame - expected).mean(),
     258                 :             :                         thresholds[task_id],
     259                 :             :                     )
     260                 :             :                 finally:
     261                 :           1 :                     env.close()
     262                 :           1 :                     oracle.close()
     263                 :             : 
     264                 :           1 :     def test_lunar_lander_render(self) -> None:
     265                 :             :         """LunarLander should support consistent batched rendering."""
     266                 :           1 :         self._assert_batch_consistent_render("LunarLander-v3")
     267                 :             : 
     268                 :             : 
     269                 :           1 : if __name__ == "__main__":
     270            [ + ]:           1 :     absltest.main()
        

Generated by: LCOV version 2.0-1