Branch data Line data Source code
1 : : // Copyright 2026 Garena Online Private Limited
2 : : //
3 : : // Licensed under the Apache License, Version 2.0 (the "License");
4 : : // you may not use this file except in compliance with the License.
5 : : // You may obtain a copy of the License at
6 : : //
7 : : // http://www.apache.org/licenses/LICENSE-2.0
8 : : //
9 : : // Unless required by applicable law or agreed to in writing, software
10 : : // distributed under the License is distributed on an "AS IS" BASIS,
11 : : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : // See the License for the specific language governing permissions and
13 : : // limitations under the License.
14 : :
15 : : #include "envpool/minigrid/impl/babyai_core.h"
16 : :
17 : : #include <cstdlib>
18 : : #include <memory>
19 : : #include <optional>
20 : : #include <string>
21 : : #include <utility>
22 : : #include <vector>
23 : :
24 : : #include "absl/log/log.h"
25 : :
26 : : namespace minigrid {
27 : :
28 : 275 : BabyAIRejectSampling::BabyAIRejectSampling(const std::string& msg)
29 : 275 : : std::runtime_error(msg) {}
30 : :
31 : 1333 : BabyAILevelTask::BabyAILevelTask(std::string env_name, int room_size,
32 : : int num_rows, int num_cols, int max_steps,
33 : 1333 : int mission_bytes)
34 : : : RoomGridTask(std::move(env_name), room_size, num_rows, num_cols,
35 : : max_steps, 7, mission_bytes),
36 [ + - ]: 2664 : fixed_max_steps_(max_steps > 0) {} // NOLINT(whitespace/indent_namespace)
37 : :
38 : 13 : std::pair<Pos, std::pair<Type, Color>> BabyAILevelTask::AddExistingObject(
39 : : int i, int j, const WorldObj& obj) {
40 : 13 : const Room& room = GetRoom(i, j);
41 [ + - ]: 13 : Pos pos = PlaceObj(
42 [ + - ]: 13 : obj, room.top.first, room.top.second, room.size.first, room.size.second,
43 : : [&](const Pos& pos_candidate) {
44 : 13 : return std::abs(agent_pos_.first - pos_candidate.first) +
45 : 13 : std::abs(agent_pos_.second - pos_candidate.second) <
46 : 13 : 2;
47 : : },
48 : : 1000);
49 : 13 : GetRoom(i, j).objs.emplace_back(obj.GetType(), obj.GetColor());
50 : 13 : return {pos, {obj.GetType(), obj.GetColor()}};
51 : : }
52 : :
53 : : std::vector<std::pair<Pos, std::pair<Type, Color>>>
54 : 2062 : BabyAILevelTask::AddDistractorsOrReject(int i, int j, int num_distractors,
55 : : bool all_unique) {
56 : : try {
57 [ + - ]: 2062 : return AddDistractors(i, j, num_distractors, all_unique);
58 [ # # ]: 0 : } catch (const std::runtime_error&) {
59 [ # # # # ]: 0 : throw BabyAIRejectSampling("failed to add distractors");
60 : 0 : }
61 : : }
62 : :
63 : 721 : void BabyAILevelTask::ConnectAllOrReject(
64 : : const std::vector<Color>& door_colors) {
65 [ - + ]: 721 : if (!TryConnectAll(door_colors)) {
66 [ # # # # ]: 0 : throw BabyAIRejectSampling("connect_all failed");
67 : : }
68 : 721 : }
69 : :
70 : 872 : void BabyAILevelTask::CheckObjsReachableOrReject() const {
71 [ + + ]: 872 : if (!CheckObjsReachable()) {
72 [ + - + - ]: 420 : throw BabyAIRejectSampling("unreachable object");
73 : : }
74 : 664 : }
75 : :
76 : 26 : void BabyAILevelTask::OpenAllDoors() {
77 [ + + ]: 104 : for (int i = 0; i < num_cols_; ++i) {
78 [ + + ]: 312 : for (int j = 0; j < num_rows_; ++j) {
79 : 234 : const Room& room = GetRoom(i, j);
80 [ + + ]: 1170 : for (int door_idx = 0; door_idx < 4; ++door_idx) {
81 [ + + ]: 936 : if (!room.connected[door_idx]) {
82 : 476 : continue;
83 : : }
84 : 460 : const Pos pos = room.door_pos[door_idx];
85 : 460 : WorldObj& obj = Cell(pos.first, pos.second);
86 [ + - ]: 460 : if (obj.GetType() == kDoor) {
87 : : obj.SetDoorLocked(false);
88 : : obj.SetDoorOpen(true);
89 : : }
90 : : }
91 : : }
92 : : }
93 : 26 : }
94 : :
95 : 1603 : void BabyAILevelTask::GenGrid() {
96 : 1603 : std::string last_error = "unknown";
97 [ + - ]: 1876 : for (int retry = 0; retry < 1000; ++retry) {
98 [ + - ]: 1876 : RoomGridTask::GenGrid();
99 [ + + ]: 1879 : locked_room_ = nullptr;
100 : : instrs_.reset();
101 : : try {
102 [ + + ]: 1879 : GenMission();
103 [ + - - - ]: 1638 : CHECK(instrs_ != nullptr);
104 : 1603 : std::vector<Color> locked_colors;
105 [ + + ]: 4995 : for (int i = 0; i < num_cols_; ++i) {
106 [ + + ]: 10886 : for (int j = 0; j < num_rows_; ++j) {
107 [ + - ]: 7529 : const Room& room = GetRoom(i, j);
108 [ + + ]: 37632 : for (int door_idx = 0; door_idx < 4; ++door_idx) {
109 [ + + ]: 30102 : if (!room.connected[door_idx]) {
110 : 19951 : continue;
111 : : }
112 : 10151 : const Pos pos = room.door_pos[door_idx];
113 [ + - ]: 10151 : const WorldObj door = GetCell(pos.first, pos.second);
114 [ + + + + ]: 10152 : if (door.GetType() == kDoor && door.GetDoorLocked()) {
115 [ + - - - ]: 844 : locked_colors.push_back(door.GetColor());
116 : : }
117 : : }
118 : : }
119 : : }
120 [ + - + + ]: 1639 : instrs_->Validate(*this, locked_colors, UnblockingEnabled());
121 [ + - + - ]: 3235 : SetMission(instrs_->Surface(*this));
122 [ + - ]: 1598 : instrs_->ResetVerifier(*this);
123 [ + + ]: 1604 : if (!fixed_max_steps_) {
124 [ + - ]: 1166 : max_steps_ = instrs_->NumNavsNeeded() * room_size_ * room_size_ *
125 : 1164 : num_rows_ * num_cols_;
126 : : }
127 [ + - ]: 1602 : AfterResetVerifier();
128 : 1600 : return;
129 [ - + - ]: 275 : } catch (const BabyAIRejectSampling& e) {
130 : 275 : last_error = e.what();
131 : 275 : } catch (const std::runtime_error& e) {
132 : 0 : last_error = e.what();
133 : 0 : }
134 : : }
135 [ # # ]: 0 : throw std::runtime_error("BabyAI mission generation failed for " + env_name_ +
136 [ # # ]: 0 : ": " + last_error);
137 : : }
138 : :
139 : 58733 : void BabyAILevelTask::AfterStep(Act act, const WorldObj& /*pre_fwd*/,
140 : : const Pos& /*fwd_pos*/,
141 : : const WorldObj& pre_carrying, float* reward,
142 : : bool* terminated) {
143 [ + + ]: 58733 : if (act == kDrop) {
144 : 7267 : instrs_->UpdateObjPoss(*this);
145 : : }
146 : 58734 : BabyAIStatus status = instrs_->Verify(*this, act, pre_carrying);
147 [ + + ]: 58633 : if (status == BabyAIStatus::kSuccess) {
148 : 108 : *reward = SuccessReward();
149 : 108 : *terminated = true;
150 [ + + ]: 58525 : } else if (status == BabyAIStatus::kFailure) {
151 : 25 : *reward = 0.0f;
152 : 25 : *terminated = true;
153 : : }
154 : 58633 : }
155 : :
156 : 77 : void BabyAILevelTask::AddLockedRoom() {
157 : : while (true) {
158 : 110 : int i = RandInt(0, num_cols_);
159 : 110 : int j = RandInt(0, num_rows_);
160 : 110 : int door_idx = RandInt(0, 4);
161 : 110 : Room& room = GetRoom(i, j);
162 [ + + ]: 110 : if (!room.has_neighbor[door_idx]) {
163 : 33 : continue;
164 : : }
165 : 77 : locked_room_ = &room;
166 : 77 : Pos door_pos = AddDoor(i, j, door_idx, kUnassigned, true);
167 : 154 : Color door_color = GetCell(door_pos.first, door_pos.second).GetColor();
168 : : while (true) {
169 : 81 : int key_i = RandInt(0, num_cols_);
170 : 81 : int key_j = RandInt(0, num_rows_);
171 [ + + ]: 81 : if (key_i == i && key_j == j) {
172 : : continue;
173 : : }
174 : 77 : AddObject(key_i, key_j, kKey, door_color);
175 : 76 : return;
176 : 4 : }
177 : : }
178 : : }
179 : :
180 : 164 : BabyAILevelGenTask::BabyAILevelGenTask(const BabyAITaskConfig& config)
181 : 164 : : BabyAILevelTask(config.env_name, config.room_size, config.num_rows,
182 : 164 : config.num_cols, config.max_steps, config.mission_bytes),
183 : 164 : num_dists_(config.num_dists),
184 : 164 : locked_room_prob_(config.locked_room_prob),
185 : 164 : locations_(config.locations),
186 : 164 : unblocking_(config.unblocking),
187 [ + - + - ]: 328 : implicit_unlock_(config.implicit_unlock) {
188 [ + - ]: 164 : action_kinds_ = SplitKinds(config.action_kinds);
189 [ + - ]: 162 : instr_kinds_ = SplitKinds(config.instr_kinds);
190 : 164 : }
191 : :
192 : 326 : std::vector<std::string> BabyAILevelGenTask::SplitKinds(
193 : : const std::string& csv) const {
194 : : std::vector<std::string> kinds;
195 : : std::string cur;
196 [ + + ]: 5113 : for (char ch : csv) {
197 [ + + ]: 4786 : if (ch == ',') {
198 [ + - ]: 552 : if (!cur.empty()) {
199 [ + - ]: 553 : kinds.push_back(cur);
200 : : cur.clear();
201 : : }
202 : 551 : continue;
203 : : }
204 [ + - ]: 4234 : cur.push_back(ch);
205 : : }
206 [ + - ]: 327 : if (!cur.empty()) {
207 [ + - ]: 328 : kinds.push_back(cur);
208 : : }
209 : 327 : return kinds;
210 : 0 : }
211 : :
212 : 319 : BabyAIObjDesc BabyAILevelGenTask::RandObj(const std::vector<Type>& types,
213 : : const std::vector<Color>& colors,
214 : : int max_tries) {
215 [ + - ]: 863 : for (int num_tries = 0; num_tries <= max_tries; ++num_tries) {
216 : 863 : std::optional<Color> color;
217 [ + + ]: 863 : if (RandInt(0, static_cast<int>(colors.size()) + 1) > 0) {
218 : 734 : color = RandElem(colors);
219 : : }
220 : 863 : std::optional<Type> type = RandElem(types);
221 : : BabyAILoc loc = BabyAILoc::kNone;
222 [ + + + + ]: 863 : if (locations_ && RandBool()) {
223 [ + - ]: 334 : loc = RandElem(std::vector<BabyAILoc>{BabyAILoc::kLeft, BabyAILoc::kRight,
224 : : BabyAILoc::kFront,
225 : : BabyAILoc::kBehind});
226 : : }
227 : 863 : BabyAIObjDesc desc(type, color, loc);
228 [ + - ]: 863 : desc.FindMatchingObjs(*this);
229 [ + + ]: 863 : if (desc.ObjUids().empty()) {
230 : 544 : continue;
231 : : }
232 [ + + + + ]: 319 : if (!implicit_unlock_ && locked_room_ != nullptr) {
233 : : bool has_unlocked_match = false;
234 [ + - ]: 94 : for (const Pos& pos : desc.ObjPoss()) {
235 [ + + ]: 94 : if (!locked_room_->PosInside(pos.first, pos.second)) {
236 : : has_unlocked_match = true;
237 : : break;
238 : : }
239 : : }
240 [ - + ]: 82 : if (!has_unlocked_match) {
241 : 0 : continue;
242 : : }
243 : : }
244 : 319 : return desc;
245 : 863 : }
246 [ # # # # ]: 0 : throw BabyAIRejectSampling("failed to find suitable object");
247 : : }
248 : :
249 : 280 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandActionInstr(
250 : : const std::vector<std::string>& action_kinds) {
251 : 280 : const std::string& action = RandElem(action_kinds);
252 [ + + ]: 280 : if (action == "goto") {
253 [ + - + - : 111 : return std::make_unique<BabyAIGoToInstr>(RandObj());
+ - ]
254 : : }
255 [ + + ]: 169 : if (action == "pickup") {
256 [ + - ]: 87 : return std::make_unique<BabyAIPickupInstr>(
257 [ + - + - ]: 174 : RandObj(std::vector<Type>{kKey, kBall, kBox}));
258 : : }
259 [ + + ]: 82 : if (action == "open") {
260 [ + - + - : 43 : return std::make_unique<BabyAIOpenInstr>(RandObj(std::vector<Type>{kDoor}));
+ - ]
261 : : }
262 [ + - ]: 39 : if (action == "putnext") {
263 [ + - ]: 39 : return std::make_unique<BabyAIPutNextInstr>(
264 [ + - + - : 78 : RandObj(std::vector<Type>{kKey, kBall, kBox}), RandObj());
+ - + - +
- ]
265 : : }
266 [ # # ]: 0 : LOG(FATAL) << "Unknown BabyAI action kind: " << action;
267 : : return nullptr;
268 : : }
269 : :
270 : 377 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandInstr(
271 : : const std::vector<std::string>& action_kinds,
272 : : const std::vector<std::string>& instr_kinds, int depth) {
273 : 377 : const std::string& kind = RandElem(instr_kinds);
274 [ + + ]: 377 : if (kind == "action") {
275 : 280 : return RandActionInstr(action_kinds);
276 : : }
277 [ + + ]: 97 : if (kind == "and") {
278 [ + - ]: 72 : return std::make_unique<BabyAIAndInstr>(
279 [ + - + - : 216 : RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1),
+ + - - ]
280 [ + - + - : 216 : RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1));
+ + - - ]
281 : : }
282 [ + - ]: 25 : if (kind == "seq") {
283 : : auto instr_a = RandInstr(
284 [ + - + - : 75 : action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
+ + - - ]
285 : : auto instr_b = RandInstr(
286 [ + - + - : 75 : action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
+ + - - -
- ]
287 [ + + ]: 25 : if (RandBool()) {
288 [ + - ]: 10 : return std::make_unique<BabyAIBeforeInstr>(std::move(instr_a),
289 : : std::move(instr_b));
290 : : }
291 [ + - - - ]: 15 : return std::make_unique<BabyAIAfterInstr>(std::move(instr_a),
292 : : std::move(instr_b));
293 : : }
294 [ # # ]: 0 : LOG(FATAL) << "Unknown BabyAI instruction kind: " << kind;
295 : : return nullptr;
296 [ + - - + : 363 : }
+ - + - +
- + - + -
- - - - -
- ]
297 : :
298 : 232 : void BabyAILevelGenTask::GenMission() {
299 [ + + ]: 232 : if (RandFloat(0.0f, 1.0f) < locked_room_prob_) {
300 : 77 : AddLockedRoom();
301 : : }
302 [ + - ]: 232 : ConnectAllOrReject();
303 : 232 : AddDistractorsOrReject(-1, -1, num_dists_, false);
304 : : while (true) {
305 : 246 : PlaceAgentInRoom();
306 [ + + + + ]: 492 : if (locked_room_ == nullptr ||
307 [ + + ]: 91 : &RoomFromPos(agent_pos_.first, agent_pos_.second) != locked_room_) {
308 : : break;
309 : : }
310 : : }
311 [ + + ]: 232 : if (!unblocking_) {
312 : 99 : CheckObjsReachableOrReject();
313 : : }
314 : 183 : instrs_ = RandInstr(action_kinds_, instr_kinds_);
315 : 183 : }
316 : :
317 : 274 : std::pair<Pos, std::pair<Type, Color>> ObjAtDoor(const BabyAILevelTask& env,
318 : : const Pos& pos) {
319 : 274 : const WorldObj door = env.CellAt(pos.first, pos.second);
320 [ - + ]: 275 : return {pos, {door.GetType(), door.GetColor()}};
321 : : }
322 : :
323 : 26 : BabyAITaskConfig MakeGoToSeqConfig(BabyAITaskConfig config) {
324 : 26 : config.action_kinds = "goto";
325 : 25 : config.locked_room_prob = 0.0f;
326 : 25 : config.locations = false;
327 : 25 : config.unblocking = false;
328 : 25 : return config;
329 : : }
330 : :
331 : 21 : BabyAITaskConfig MakePickupLocConfig(BabyAITaskConfig config) {
332 : 21 : config.action_kinds = "pickup";
333 : 21 : config.instr_kinds = "action";
334 : 21 : config.num_rows = 1;
335 : 21 : config.num_cols = 1;
336 : 21 : config.num_dists = 8;
337 : 21 : config.locked_room_prob = 0.0f;
338 : 21 : config.locations = true;
339 : 21 : config.unblocking = false;
340 : 21 : return config;
341 : : }
342 : :
343 : 26 : BabyAITaskConfig MakeSynthConfig(BabyAITaskConfig config) {
344 : 26 : config.instr_kinds = "action";
345 : 26 : config.locations = false;
346 : 26 : config.unblocking = true;
347 : 26 : config.implicit_unlock = false;
348 : 26 : return config;
349 : : }
350 : :
351 : 13 : BabyAITaskConfig MakeSynthLocConfig(BabyAITaskConfig config) {
352 : 13 : config.instr_kinds = "action";
353 : 13 : config.locations = true;
354 : 13 : config.unblocking = true;
355 : 13 : config.implicit_unlock = false;
356 : 13 : return config;
357 : : }
358 : :
359 : 21 : BabyAITaskConfig MakeSynthSeqConfig(BabyAITaskConfig config) {
360 : 21 : config.locations = true;
361 : 21 : config.unblocking = true;
362 : 21 : config.implicit_unlock = false;
363 : 21 : return config;
364 : : }
365 : :
366 : 21 : BabyAITaskConfig MakeMiniBossConfig(BabyAITaskConfig config) {
367 : 21 : config.num_cols = 2;
368 : 21 : config.num_rows = 2;
369 : 21 : config.room_size = 5;
370 : 21 : config.num_dists = 7;
371 : 21 : config.locked_room_prob = 0.25f;
372 : 21 : return config;
373 : : }
374 : :
375 : 13 : BabyAITaskConfig MakeBossNoUnlockConfig(BabyAITaskConfig config) {
376 : 13 : config.locked_room_prob = 0.0f;
377 : 13 : config.implicit_unlock = false;
378 : 13 : return config;
379 : : }
380 : :
381 : : } // namespace minigrid
|