Branch data Line data Source code
1 : : # Copyright 2026 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 1 : """Render tests for Box2D environments."""
15 : :
16 : 1 : import importlib.machinery
17 : 1 : import importlib.util
18 : 1 : import re
19 : 1 : import sys
20 : 1 : import types
21 : 1 : from pathlib import Path
22 : 1 : from typing import Any, cast
23 : :
24 : 1 : import gymnasium as gym
25 : 1 : import numpy as np
26 : 1 : from absl.testing import absltest
27 : :
28 : 1 : import envpool.box2d.registration # noqa: F401
29 : 1 : from envpool.registration import make_gym
30 : :
31 : 1 : _RENDER_STEPS = 3
32 : 1 : _TASK_IDS = (
33 : : "CarRacing-v2",
34 : : "CarRacing-v3",
35 : : "BipedalWalker-v3",
36 : : "BipedalWalkerHardcore-v3",
37 : : "LunarLander-v2",
38 : : "LunarLander-v3",
39 : : "LunarLanderContinuous-v2",
40 : : "LunarLanderContinuous-v3",
41 : : )
42 : 1 : _OFFICIAL_RENDER_TASK_IDS = (
43 : : "CarRacing-v3",
44 : : "BipedalWalker-v3",
45 : : "BipedalWalkerHardcore-v3",
46 : : "LunarLander-v3",
47 : : "LunarLanderContinuous-v3",
48 : : )
49 : 1 : _BOX2D_SWIGCONSTANT_RE = re.compile(r"_Box2D\.(\w+_swigconstant)\(")
50 : :
51 : :
52 : 1 : def _patch_box2d_swigconstant_shims(module: Any, pathname: str) -> None:
53 : 0 : wrapper_path = Path(pathname).with_name("Box2D.py")
54 : 0 : try:
55 : 0 : names = set(_BOX2D_SWIGCONSTANT_RE.findall(wrapper_path.read_text()))
56 : 0 : except OSError:
57 : 0 : return
58 [ # ][ # ]: 0 : for attr in names:
59 [ # ]: 0 : if not hasattr(module, attr):
60 [ # ]: 0 : setattr(module, attr, lambda _target, _attr=attr: None)
61 : :
62 : :
63 : 1 : def _install_imp_compat() -> None:
64 : 1 : try:
65 : 1 : import imp # noqa: F401
66 : :
67 : 0 : return
68 : 1 : except ModuleNotFoundError:
69 : 1 : pass
70 : :
71 : 1 : compat_imp: Any = types.ModuleType("imp")
72 : 1 : compat_imp.C_EXTENSION = 3
73 : :
74 : 1 : def find_module(
75 : : name: str, path: Any = None
76 : : ) -> tuple[Any, str, tuple[str, str, int]]:
77 : 0 : spec = importlib.machinery.PathFinder.find_spec(name, path)
78 : 0 : if spec is None or spec.origin is None:
79 [ # ]: 0 : raise ImportError(name)
80 [ # ]: 0 : return (
81 : : open(spec.origin, "rb"),
82 : : spec.origin,
83 : : ("", "rb", compat_imp.C_EXTENSION),
84 : : )
85 : :
86 : 1 : def load_module(
87 : : name: str, file: Any, pathname: str, description: Any
88 : : ) -> Any:
89 : 0 : del file, description
90 : 0 : module = sys.modules.get(name)
91 : 0 : if module is not None:
92 [ # ]: 0 : return module
93 [ # ]: 0 : spec = importlib.util.spec_from_file_location(name, pathname)
94 : 0 : if spec is None or spec.loader is None:
95 [ # ]: 0 : raise ImportError(pathname)
96 [ # ]: 0 : module = importlib.util.module_from_spec(spec)
97 : 0 : sys.modules[name] = module
98 : 0 : spec.loader.exec_module(module)
99 : 0 : if name == "_Box2D":
100 [ # ]: 0 : _patch_box2d_swigconstant_shims(module, pathname)
101 [ # ]: 0 : return module
102 : :
103 : 1 : compat_imp.find_module = find_module
104 : 1 : compat_imp.load_module = load_module
105 : 1 : sys.modules["imp"] = compat_imp
106 : :
107 : :
108 : 1 : _install_imp_compat()
109 : :
110 : :
111 : 1 : def _render_array(env: Any, env_ids: Any = None) -> np.ndarray:
112 : 1 : frame = env.render(env_ids=env_ids)
113 : 1 : assert frame is not None
114 : 1 : return cast(np.ndarray, frame)
115 : :
116 : :
117 : 1 : def _make_oracle_env(task_id: str) -> gym.Env[Any, Any]:
118 : 1 : return gym.make(task_id, render_mode="rgb_array")
119 : :
120 : :
121 : 1 : def _zero_action(space: Any, num_envs: int) -> np.ndarray:
122 : 1 : sample = np.asarray(space.sample())
123 : 1 : zero = np.zeros_like(sample)
124 : 1 : if sample.ndim == 0:
125 [ + ]: 1 : return np.full((num_envs,), zero.item(), dtype=sample.dtype)
126 [ + ]: 1 : return np.repeat(zero[np.newaxis, ...], num_envs, axis=0)
127 : :
128 : :
129 : 1 : class Box2DRenderTest(absltest.TestCase):
130 : 1 : """Render regression tests for Box2D environments."""
131 : :
132 : 1 : def _assert_batch_consistent_render(self, task_id: str) -> None:
133 : 1 : env = make_gym(
134 : : task_id,
135 : : num_envs=2,
136 : : render_mode="rgb_array",
137 : : render_width=64,
138 : : render_height=48,
139 : : )
140 : 1 : try:
141 : 1 : env.reset()
142 [ + ]: 1 : for step_idx in range(_RENDER_STEPS):
143 [ + ]: 1 : frame0 = _render_array(env)
144 : 1 : frame1 = _render_array(env, env_ids=1)
145 : 1 : frames = _render_array(env, env_ids=[0, 1])
146 : 1 : frame0_again = _render_array(env)
147 : 1 : self.assertEqual(frame0.shape, (1, 48, 64, 3))
148 : 1 : self.assertEqual(frame1.shape, (1, 48, 64, 3))
149 : 1 : self.assertEqual(frames.shape, (2, 48, 64, 3))
150 : 1 : self.assertEqual(frame0.dtype, np.uint8)
151 : 1 : self.assertEqual(frames.dtype, np.uint8)
152 : 1 : np.testing.assert_array_equal(frame0[0], frames[0])
153 : 1 : np.testing.assert_array_equal(frame1[0], frames[1])
154 : 1 : np.testing.assert_array_equal(frame0, frame0_again)
155 : 1 : if step_idx + 1 < _RENDER_STEPS:
156 [ + ]: 1 : env.step(_zero_action(env.action_space, 2))
157 : : finally:
158 [ + ]: 1 : env.close()
159 : :
160 : 1 : def test_render_succeeds_for_multiple_steps_for_all_tasks(self) -> None:
161 : : """Every Box2D task should render repeatedly across several steps."""
162 [ + ]: 1 : for task_id in _TASK_IDS:
163 [ + ]: 1 : with self.subTest(task_id=task_id):
164 [ + ]: 1 : env = make_gym(
165 : : task_id,
166 : : num_envs=1,
167 : : seed=0,
168 : : render_mode="rgb_array",
169 : : )
170 : 1 : try:
171 : 1 : env.reset()
172 [ + ]: 1 : for step_idx in range(_RENDER_STEPS):
173 [ + ]: 1 : frame = _render_array(env)[0]
174 : 1 : frame_again = _render_array(env)[0]
175 : 1 : self.assertEqual(frame.dtype, np.uint8)
176 : 1 : self.assertEqual(frame.ndim, 3)
177 : 1 : self.assertEqual(frame.shape[-1], 3)
178 : 1 : np.testing.assert_array_equal(frame, frame_again)
179 : 1 : self.assertGreater(
180 : : int(frame.max()) - int(frame.min()), 0
181 : : )
182 : 1 : if step_idx + 1 < _RENDER_STEPS:
183 [ + ]: 1 : env.step(_zero_action(env.action_space, 1))
184 : : finally:
185 [ + ]: 1 : env.close()
186 : :
187 : 1 : def test_car_racing_render(self) -> None:
188 : : """CarRacing should support consistent batched rendering."""
189 : 1 : self._assert_batch_consistent_render("CarRacing-v3")
190 : :
191 : 1 : def test_bipedal_walker_render(self) -> None:
192 : : """BipedalWalker should support consistent batched rendering."""
193 : 1 : self._assert_batch_consistent_render("BipedalWalker-v3")
194 : :
195 : 1 : def test_bipedal_walker_render_matches_official_ground_profile(
196 : : self,
197 : : ) -> None:
198 : : """BipedalWalker renders should preserve the official ground profile."""
199 [ + ]: 1 : for task_id in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
200 [ + ]: 1 : with self.subTest(task_id=task_id):
201 [ + ]: 1 : env = make_gym(
202 : : task_id,
203 : : num_envs=1,
204 : : seed=0,
205 : : render_mode="rgb_array",
206 : : render_width=600,
207 : : render_height=400,
208 : : )
209 : 1 : oracle = _make_oracle_env(task_id)
210 : 1 : try:
211 : 1 : env.reset()
212 : 1 : oracle.reset(seed=0)
213 : 1 : frame = _render_array(env)[0].astype(np.int16)
214 : 1 : expected = np.asarray(oracle.render(), dtype=np.int16)
215 : 1 : lower_band = slice(-96, None)
216 : 1 : diff = np.abs(
217 : : frame[lower_band] - expected[lower_band]
218 : : ).mean()
219 : 1 : agent_crop = np.abs(
220 : : frame[100:320, :220] - expected[100:320, :220]
221 : : ).mean()
222 : 1 : self.assertLess(diff, 40.0)
223 : 1 : self.assertLess(agent_crop, 15.0)
224 : : finally:
225 : 1 : env.close()
226 : 1 : oracle.close()
227 : :
228 : 1 : def test_render_matches_official_first_frame(self) -> None:
229 : : """All Box2D reset renders should stay close to Gymnasium."""
230 : 1 : thresholds = {
231 : : "BipedalWalker-v3": 40.0,
232 : : "BipedalWalkerHardcore-v3": 40.0,
233 : : "CarRacing-v3": 15.0,
234 : : "LunarLander-v3": 30.0,
235 : : "LunarLanderContinuous-v3": 30.0,
236 : : }
237 : 1 : self.assertEqual(
238 : : tuple(sorted(thresholds)),
239 : : tuple(sorted(_OFFICIAL_RENDER_TASK_IDS)),
240 : : )
241 [ + ]: 1 : for task_id in _OFFICIAL_RENDER_TASK_IDS:
242 [ + ]: 1 : with self.subTest(task_id=task_id):
243 [ + ]: 1 : env = make_gym(
244 : : task_id,
245 : : num_envs=1,
246 : : seed=0,
247 : : render_mode="rgb_array",
248 : : )
249 : 1 : oracle = _make_oracle_env(task_id)
250 : 1 : try:
251 : 1 : env.reset()
252 : 1 : oracle.reset(seed=0)
253 : 1 : frame = _render_array(env)[0].astype(np.int16)
254 : 1 : expected = np.asarray(oracle.render(), dtype=np.int16)
255 : 1 : self.assertEqual(frame.shape, expected.shape)
256 : 1 : self.assertLess(
257 : : np.abs(frame - expected).mean(),
258 : : thresholds[task_id],
259 : : )
260 : : finally:
261 : 1 : env.close()
262 : 1 : oracle.close()
263 : :
264 : 1 : def test_lunar_lander_render(self) -> None:
265 : : """LunarLander should support consistent batched rendering."""
266 : 1 : self._assert_batch_consistent_render("LunarLander-v3")
267 : :
268 : :
269 : 1 : if __name__ == "__main__":
270 [ + ]: 1 : absltest.main()
|