Branch data Line data Source code
1 : : # Copyright 2026 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 7 : """MiniGrid env in EnvPool."""
15 : :
16 : 7 : from __future__ import annotations
17 : :
18 : 7 : from typing import Any, cast
19 : :
20 : 7 : import numpy as np
21 : :
22 : 7 : from envpool.python.api import py_env
23 : :
24 : 7 : from .minigrid_envpool import (
25 : : _MiniGridDebugState,
26 : : _MiniGridEnvPool,
27 : : _MiniGridEnvSpec,
28 : : )
29 : :
30 : :
31 : 7 : def _decode_mission_row(row: np.ndarray) -> str:
32 : 2 : mission = np.asarray(row, dtype=np.uint8).reshape(-1)
33 : 2 : zero = np.flatnonzero(mission == 0)
34 : 2 : end = int(zero[0]) if zero.size else int(mission.shape[0])
35 : 2 : return mission[:end].tobytes().decode("utf-8")
36 : :
37 : :
38 : 7 : def decode_mission(mission: np.ndarray) -> str | np.ndarray:
39 : : """Decode the fixed-size mission byte buffer returned by the C++ backend."""
40 : 2 : arr = np.asarray(mission, dtype=np.uint8)
41 : 2 : if arr.ndim == 1:
42 [ # ]: 0 : return _decode_mission_row(arr)
43 [ + ]: 2 : return np.asarray([_decode_mission_row(row) for row in arr], dtype=object)
44 : :
45 : :
46 : 7 : def _normalize_env_ids(
47 : : env_ids: np.ndarray | list[int] | None, num_envs: int
48 : : ) -> np.ndarray:
49 : 0 : if env_ids is None:
50 [ # ]: 0 : return np.arange(num_envs, dtype=np.int32)
51 [ # ]: 0 : return np.asarray(env_ids, dtype=np.int32)
52 : :
53 : :
54 : 7 : (
55 : : MiniGridEnvSpec,
56 : : MiniGridDMEnvPool,
57 : : MiniGridGymEnvPool,
58 : : MiniGridGymnasiumEnvPool,
59 : : ) = py_env(_MiniGridEnvSpec, _MiniGridEnvPool)
60 : :
61 : 7 : cast(Any, MiniGridEnvSpec).decode_mission = staticmethod(decode_mission)
62 : :
63 : :
64 : 7 : def _debug_states(
65 : : self: Any, env_ids: np.ndarray | list[int] | None = None
66 : : ) -> list[Any]:
67 : 0 : env_ids = _normalize_env_ids(env_ids, self.config["num_envs"])
68 : 0 : return self._debug_states(env_ids)
69 : :
70 : :
71 [ + ]: 7 : for _env_cls in (
72 : : MiniGridDMEnvPool,
73 : : MiniGridGymEnvPool,
74 : : MiniGridGymnasiumEnvPool,
75 : : ):
76 [ + ]: 7 : cast(Any, _env_cls).decode_mission = staticmethod(decode_mission)
77 : 7 : cast(Any, _env_cls).debug_states = _debug_states
78 : :
79 : :
80 [ + ]: 7 : __all__ = [
81 : : "MiniGridEnvSpec",
82 : : "MiniGridDMEnvPool",
83 : : "MiniGridGymEnvPool",
84 : : "MiniGridGymnasiumEnvPool",
85 : : "_MiniGridDebugState",
86 : : "decode_mission",
87 : : ]
|