Branch data Line data Source code
1 : : # Copyright 2026 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 1 : """Render tests for Box2D environments."""
15 : :
16 : 1 : import importlib.machinery
17 : 1 : import importlib.util
18 : 1 : import re
19 : 1 : import sys
20 : 1 : import types
21 : 1 : from pathlib import Path
22 : 1 : from typing import Any, cast
23 : :
24 : 1 : import gymnasium as gym
25 : 1 : import numpy as np
26 : 1 : from absl.testing import absltest
27 : :
28 : 1 : import envpool.box2d.registration # noqa: F401
29 : 1 : from envpool.registration import make_gym
30 : :
31 : 1 : _RENDER_STEPS = 3
32 : 1 : _TASK_IDS = (
33 : : "CarRacing-v2",
34 : : "CarRacing-v3",
35 : : "BipedalWalker-v3",
36 : : "BipedalWalkerHardcore-v3",
37 : : "LunarLander-v2",
38 : : "LunarLander-v3",
39 : : "LunarLanderContinuous-v2",
40 : : "LunarLanderContinuous-v3",
41 : : )
42 : 1 : _BOX2D_SWIGCONSTANT_RE = re.compile(r"_Box2D\.(\w+_swigconstant)\(")
43 : :
44 : :
45 : 1 : def _patch_box2d_swigconstant_shims(module: Any, pathname: str) -> None:
46 : 0 : wrapper_path = Path(pathname).with_name("Box2D.py")
47 : 0 : try:
48 : 0 : names = set(_BOX2D_SWIGCONSTANT_RE.findall(wrapper_path.read_text()))
49 : 0 : except OSError:
50 : 0 : return
51 [ # ][ # ]: 0 : for attr in names:
52 [ # ]: 0 : if not hasattr(module, attr):
53 [ # ]: 0 : setattr(module, attr, lambda _target, _attr=attr: None)
54 : :
55 : :
56 : 1 : def _install_imp_compat() -> None:
57 : 1 : try:
58 : 1 : import imp # noqa: F401
59 : :
60 : 0 : return
61 : 1 : except ModuleNotFoundError:
62 : 1 : pass
63 : :
64 : 1 : compat_imp: Any = types.ModuleType("imp")
65 : 1 : compat_imp.C_EXTENSION = 3
66 : :
67 : 1 : def find_module(
68 : : name: str, path: Any = None
69 : : ) -> tuple[Any, str, tuple[str, str, int]]:
70 : 0 : spec = importlib.machinery.PathFinder.find_spec(name, path)
71 : 0 : if spec is None or spec.origin is None:
72 [ # ]: 0 : raise ImportError(name)
73 [ # ]: 0 : return (
74 : : open(spec.origin, "rb"),
75 : : spec.origin,
76 : : ("", "rb", compat_imp.C_EXTENSION),
77 : : )
78 : :
79 : 1 : def load_module(
80 : : name: str, file: Any, pathname: str, description: Any
81 : : ) -> Any:
82 : 0 : del file, description
83 : 0 : module = sys.modules.get(name)
84 : 0 : if module is not None:
85 [ # ]: 0 : return module
86 [ # ]: 0 : spec = importlib.util.spec_from_file_location(name, pathname)
87 : 0 : if spec is None or spec.loader is None:
88 [ # ]: 0 : raise ImportError(pathname)
89 [ # ]: 0 : module = importlib.util.module_from_spec(spec)
90 : 0 : sys.modules[name] = module
91 : 0 : spec.loader.exec_module(module)
92 : 0 : if name == "_Box2D":
93 [ # ]: 0 : _patch_box2d_swigconstant_shims(module, pathname)
94 [ # ]: 0 : return module
95 : :
96 : 1 : compat_imp.find_module = find_module
97 : 1 : compat_imp.load_module = load_module
98 : 1 : sys.modules["imp"] = compat_imp
99 : :
100 : :
101 : 1 : _install_imp_compat()
102 : :
103 : :
104 : 1 : def _render_array(env: Any, env_ids: Any = None) -> np.ndarray:
105 : 1 : frame = env.render(env_ids=env_ids)
106 : 1 : assert frame is not None
107 : 1 : return cast(np.ndarray, frame)
108 : :
109 : :
110 : 1 : def _make_oracle_env(task_id: str) -> gym.Env[Any, Any]:
111 : 1 : return gym.make(task_id, render_mode="rgb_array")
112 : :
113 : :
114 : 1 : def _zero_action(space: Any, num_envs: int) -> np.ndarray:
115 : 1 : sample = np.asarray(space.sample())
116 : 1 : zero = np.zeros_like(sample)
117 : 1 : if sample.ndim == 0:
118 [ + ]: 1 : return np.full((num_envs,), zero.item(), dtype=sample.dtype)
119 [ + ]: 1 : return np.repeat(zero[np.newaxis, ...], num_envs, axis=0)
120 : :
121 : :
122 : 1 : class Box2DRenderTest(absltest.TestCase):
123 : 1 : """Render regression tests for Box2D environments."""
124 : :
125 : 1 : def _assert_batch_consistent_render(self, task_id: str) -> None:
126 : 1 : env = make_gym(
127 : : task_id,
128 : : num_envs=2,
129 : : render_mode="rgb_array",
130 : : render_width=64,
131 : : render_height=48,
132 : : )
133 : 1 : try:
134 : 1 : env.reset()
135 [ + ]: 1 : for step_idx in range(_RENDER_STEPS):
136 [ + ]: 1 : frame0 = _render_array(env)
137 : 1 : frame1 = _render_array(env, env_ids=1)
138 : 1 : frames = _render_array(env, env_ids=[0, 1])
139 : 1 : frame0_again = _render_array(env)
140 : 1 : self.assertEqual(frame0.shape, (1, 48, 64, 3))
141 : 1 : self.assertEqual(frame1.shape, (1, 48, 64, 3))
142 : 1 : self.assertEqual(frames.shape, (2, 48, 64, 3))
143 : 1 : self.assertEqual(frame0.dtype, np.uint8)
144 : 1 : self.assertEqual(frames.dtype, np.uint8)
145 : 1 : np.testing.assert_array_equal(frame0[0], frames[0])
146 : 1 : np.testing.assert_array_equal(frame1[0], frames[1])
147 : 1 : np.testing.assert_array_equal(frame0, frame0_again)
148 : 1 : if step_idx + 1 < _RENDER_STEPS:
149 [ + ]: 1 : env.step(_zero_action(env.action_space, 2))
150 : : finally:
151 [ + ]: 1 : env.close()
152 : :
153 : 1 : def test_render_succeeds_for_multiple_steps_for_all_tasks(self) -> None:
154 : : """Every Box2D task should render repeatedly across several steps."""
155 [ + ]: 1 : for task_id in _TASK_IDS:
156 [ + ]: 1 : with self.subTest(task_id=task_id):
157 [ + ]: 1 : env = make_gym(
158 : : task_id,
159 : : num_envs=1,
160 : : seed=0,
161 : : render_mode="rgb_array",
162 : : )
163 : 1 : try:
164 : 1 : env.reset()
165 [ + ]: 1 : for step_idx in range(_RENDER_STEPS):
166 [ + ]: 1 : frame = _render_array(env)[0]
167 : 1 : frame_again = _render_array(env)[0]
168 : 1 : self.assertEqual(frame.dtype, np.uint8)
169 : 1 : self.assertEqual(frame.ndim, 3)
170 : 1 : self.assertEqual(frame.shape[-1], 3)
171 : 1 : np.testing.assert_array_equal(frame, frame_again)
172 : 1 : self.assertGreater(
173 : : int(frame.max()) - int(frame.min()), 0
174 : : )
175 : 1 : if step_idx + 1 < _RENDER_STEPS:
176 [ + ]: 1 : env.step(_zero_action(env.action_space, 1))
177 : : finally:
178 [ + ]: 1 : env.close()
179 : :
180 : 1 : def test_car_racing_render(self) -> None:
181 : : """CarRacing should support consistent batched rendering."""
182 : 1 : self._assert_batch_consistent_render("CarRacing-v3")
183 : :
184 : 1 : def test_bipedal_walker_render(self) -> None:
185 : : """BipedalWalker should support consistent batched rendering."""
186 : 1 : self._assert_batch_consistent_render("BipedalWalker-v3")
187 : :
188 : 1 : def test_bipedal_walker_render_matches_official_ground_profile(
189 : : self,
190 : : ) -> None:
191 : : """BipedalWalker renders should preserve the official ground profile."""
192 [ + ]: 1 : for task_id in ("BipedalWalker-v3", "BipedalWalkerHardcore-v3"):
193 [ + ]: 1 : with self.subTest(task_id=task_id):
194 [ + ]: 1 : env = make_gym(
195 : : task_id,
196 : : num_envs=1,
197 : : seed=0,
198 : : render_mode="rgb_array",
199 : : render_width=600,
200 : : render_height=400,
201 : : )
202 : 1 : oracle = _make_oracle_env(task_id)
203 : 1 : try:
204 : 1 : env.reset()
205 : 1 : oracle.reset(seed=0)
206 : 1 : frame = _render_array(env)[0].astype(np.int16)
207 : 1 : expected = np.asarray(oracle.render(), dtype=np.int16)
208 : 1 : lower_band = slice(-96, None)
209 : 1 : diff = np.abs(
210 : : frame[lower_band] - expected[lower_band]
211 : : ).mean()
212 : 1 : agent_crop = np.abs(
213 : : frame[100:320, :220] - expected[100:320, :220]
214 : : ).mean()
215 : 1 : self.assertLess(diff, 40.0)
216 : 1 : self.assertLess(agent_crop, 15.0)
217 : : finally:
218 : 1 : env.close()
219 : 1 : oracle.close()
220 : :
221 : 1 : def test_render_matches_official_first_frame(self) -> None:
222 : : """Representative Box2D renders should stay close to Gymnasium."""
223 : 1 : thresholds = {
224 : : "CarRacing-v3": 15.0,
225 : : "LunarLander-v3": 30.0,
226 : : "LunarLanderContinuous-v3": 30.0,
227 : : }
228 [ + ]: 1 : for task_id, threshold in thresholds.items():
229 [ + ]: 1 : with self.subTest(task_id=task_id):
230 [ + ]: 1 : env = make_gym(
231 : : task_id,
232 : : num_envs=1,
233 : : seed=0,
234 : : render_mode="rgb_array",
235 : : )
236 : 1 : oracle = _make_oracle_env(task_id)
237 : 1 : try:
238 : 1 : env.reset()
239 : 1 : oracle.reset(seed=0)
240 : 1 : frame = _render_array(env)[0].astype(np.int16)
241 : 1 : expected = np.asarray(oracle.render(), dtype=np.int16)
242 : 1 : self.assertEqual(frame.shape, expected.shape)
243 : 1 : self.assertLess(np.abs(frame - expected).mean(), threshold)
244 : : finally:
245 : 1 : env.close()
246 : 1 : oracle.close()
247 : :
248 : 1 : def test_lunar_lander_render(self) -> None:
249 : : """LunarLander should support consistent batched rendering."""
250 : 1 : self._assert_batch_consistent_render("LunarLander-v3")
251 : :
252 : :
253 : 1 : if __name__ == "__main__":
254 [ + ]: 1 : absltest.main()
|