Branch data Line data Source code
1 : : // Copyright 2026 Garena Online Private Limited
2 : : //
3 : : // Licensed under the Apache License, Version 2.0 (the "License");
4 : : // you may not use this file except in compliance with the License.
5 : : // You may obtain a copy of the License at
6 : : //
7 : : // http://www.apache.org/licenses/LICENSE-2.0
8 : : //
9 : : // Unless required by applicable law or agreed to in writing, software
10 : : // distributed under the License is distributed on an "AS IS" BASIS,
11 : : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : // See the License for the specific language governing permissions and
13 : : // limitations under the License.
14 : :
15 : : #include "envpool/minigrid/impl/babyai_core.h"
16 : :
17 : : #include <cstdlib>
18 : : #include <memory>
19 : : #include <optional>
20 : : #include <string>
21 : : #include <utility>
22 : : #include <vector>
23 : :
24 : : #include "absl/log/log.h"
25 : :
26 : : namespace minigrid {
27 : :
28 : 390 : BabyAIRejectSampling::BabyAIRejectSampling(const std::string& msg)
29 : 390 : : std::runtime_error(msg) {}
30 : :
31 : 1332 : BabyAILevelTask::BabyAILevelTask(std::string env_name, int room_size,
32 : : int num_rows, int num_cols, int max_steps,
33 : 1332 : int mission_bytes)
34 : : : RoomGridTask(std::move(env_name), room_size, num_rows, num_cols,
35 : : max_steps, 7, mission_bytes),
36 [ + - ]: 2660 : fixed_max_steps_(max_steps > 0) {} // NOLINT(whitespace/indent_namespace)
37 : :
38 : 13 : std::pair<Pos, std::pair<Type, Color>> BabyAILevelTask::AddExistingObject(
39 : : int i, int j, const WorldObj& obj) {
40 : 13 : const Room& room = GetRoom(i, j);
41 [ + - ]: 13 : Pos pos = PlaceObj(
42 [ + - ]: 13 : obj, room.top.first, room.top.second, room.size.first, room.size.second,
43 : : [&](const Pos& pos_candidate) {
44 : 13 : return std::abs(agent_pos_.first - pos_candidate.first) +
45 : 13 : std::abs(agent_pos_.second - pos_candidate.second) <
46 : 13 : 2;
47 : : },
48 : : 1000);
49 : 13 : GetRoom(i, j).objs.emplace_back(obj.GetType(), obj.GetColor());
50 : 13 : return {pos, {obj.GetType(), obj.GetColor()}};
51 : : }
52 : :
53 : : std::vector<std::pair<Pos, std::pair<Type, Color>>>
54 : 2157 : BabyAILevelTask::AddDistractorsOrReject(int i, int j, int num_distractors,
55 : : bool all_unique) {
56 : : try {
57 [ + - ]: 2157 : return AddDistractors(i, j, num_distractors, all_unique);
58 [ # # ]: 0 : } catch (const std::runtime_error&) {
59 [ # # # # ]: 0 : throw BabyAIRejectSampling("failed to add distractors");
60 : 0 : }
61 : : }
62 : :
63 : 849 : void BabyAILevelTask::ConnectAllOrReject(
64 : : const std::vector<Color>& door_colors) {
65 [ - + ]: 849 : if (!TryConnectAll(door_colors)) {
66 [ # # # # ]: 0 : throw BabyAIRejectSampling("connect_all failed");
67 : : }
68 : 849 : }
69 : :
70 : 945 : void BabyAILevelTask::CheckObjsReachableOrReject() const {
71 [ + + ]: 945 : if (!CheckObjsReachable()) {
72 [ + - + - ]: 628 : throw BabyAIRejectSampling("unreachable object");
73 : : }
74 : 631 : }
75 : :
76 : 32 : void BabyAILevelTask::OpenAllDoors() {
77 [ + + ]: 128 : for (int i = 0; i < num_cols_; ++i) {
78 [ + + ]: 384 : for (int j = 0; j < num_rows_; ++j) {
79 : 288 : const Room& room = GetRoom(i, j);
80 [ + + ]: 1440 : for (int door_idx = 0; door_idx < 4; ++door_idx) {
81 [ + + ]: 1152 : if (!room.connected[door_idx]) {
82 : 572 : continue;
83 : : }
84 : 580 : const Pos pos = room.door_pos[door_idx];
85 : 580 : WorldObj& obj = Cell(pos.first, pos.second);
86 [ + - ]: 580 : if (obj.GetType() == kDoor) {
87 : : obj.SetDoorLocked(false);
88 : : obj.SetDoorOpen(true);
89 : : }
90 : : }
91 : : }
92 : : }
93 : 32 : }
94 : :
95 : 1675 : void BabyAILevelTask::GenGrid() {
96 : 1675 : std::string last_error = "unknown";
97 [ + - ]: 2065 : for (int retry = 0; retry < 1000; ++retry) {
98 [ + - ]: 2065 : RoomGridTask::GenGrid();
99 [ + + ]: 2066 : locked_room_ = nullptr;
100 : : instrs_.reset();
101 : : try {
102 [ + + ]: 2066 : GenMission();
103 [ + - - - ]: 1710 : CHECK(instrs_ != nullptr);
104 : 1674 : std::vector<Color> locked_colors;
105 [ + + ]: 5562 : for (int i = 0; i < num_cols_; ++i) {
106 [ + + ]: 13119 : for (int j = 0; j < num_rows_; ++j) {
107 [ + - ]: 9267 : const Room& room = GetRoom(i, j);
108 [ + + ]: 46267 : for (int door_idx = 0; door_idx < 4; ++door_idx) {
109 [ + + ]: 37001 : if (!room.connected[door_idx]) {
110 : 24674 : continue;
111 : : }
112 : 12327 : const Pos pos = room.door_pos[door_idx];
113 [ + - ]: 12327 : const WorldObj door = GetCell(pos.first, pos.second);
114 [ + + + + ]: 12321 : if (door.GetType() == kDoor && door.GetDoorLocked()) {
115 [ + - - - ]: 1011 : locked_colors.push_back(door.GetColor());
116 : : }
117 : : }
118 : : }
119 : : }
120 [ + - + + ]: 1709 : instrs_->Validate(*this, locked_colors, UnblockingEnabled());
121 [ + - + - ]: 3381 : SetMission(instrs_->Surface(*this));
122 [ + - ]: 1671 : instrs_->ResetVerifier(*this);
123 [ + + ]: 1676 : if (!fixed_max_steps_) {
124 [ + - ]: 1128 : max_steps_ = instrs_->NumNavsNeeded() * room_size_ * room_size_ *
125 : 1127 : num_rows_ * num_cols_;
126 : : }
127 [ + - ]: 1675 : AfterResetVerifier();
128 : 1672 : return;
129 [ - + - ]: 390 : } catch (const BabyAIRejectSampling& e) {
130 : 390 : last_error = e.what();
131 : 390 : } catch (const std::runtime_error& e) {
132 : 0 : last_error = e.what();
133 : 0 : }
134 : : }
135 [ # # ]: 0 : throw std::runtime_error("BabyAI mission generation failed for " + env_name_ +
136 [ # # ]: 0 : ": " + last_error);
137 : : }
138 : :
139 : 367482 : void BabyAILevelTask::AfterStep(Act act, const WorldObj& /*pre_fwd*/,
140 : : const Pos& /*fwd_pos*/,
141 : : const WorldObj& pre_carrying, float* reward,
142 : : bool* terminated) {
143 [ + + ]: 367482 : if (act == kDrop) {
144 : 50337 : instrs_->UpdateObjPoss(*this);
145 : : }
146 : 367486 : BabyAIStatus status = instrs_->Verify(*this, act, pre_carrying);
147 [ + + ]: 367124 : if (status == BabyAIStatus::kSuccess) {
148 : 211 : *reward = SuccessReward();
149 : 211 : *terminated = true;
150 [ + + ]: 366913 : } else if (status == BabyAIStatus::kFailure) {
151 : 97 : *reward = 0.0f;
152 : 97 : *terminated = true;
153 : : }
154 : 367124 : }
155 : :
156 : 93 : void BabyAILevelTask::AddLockedRoom() {
157 : : while (true) {
158 : 140 : int i = RandInt(0, num_cols_);
159 : 139 : int j = RandInt(0, num_rows_);
160 : 140 : int door_idx = RandInt(0, 4);
161 : 140 : Room& room = GetRoom(i, j);
162 [ + + ]: 140 : if (!room.has_neighbor[door_idx]) {
163 : 47 : continue;
164 : : }
165 : 93 : locked_room_ = &room;
166 : 93 : Pos door_pos = AddDoor(i, j, door_idx, kUnassigned, true);
167 : 186 : Color door_color = GetCell(door_pos.first, door_pos.second).GetColor();
168 : : while (true) {
169 : 97 : int key_i = RandInt(0, num_cols_);
170 : 97 : int key_j = RandInt(0, num_rows_);
171 [ + + ]: 97 : if (key_i == i && key_j == j) {
172 : : continue;
173 : : }
174 : 93 : AddObject(key_i, key_j, kKey, door_color);
175 : 93 : return;
176 : 4 : }
177 : : }
178 : : }
179 : :
180 : 164 : BabyAILevelGenTask::BabyAILevelGenTask(const BabyAITaskConfig& config)
181 : 164 : : BabyAILevelTask(config.env_name, config.room_size, config.num_rows,
182 : 164 : config.num_cols, config.max_steps, config.mission_bytes),
183 : 164 : num_dists_(config.num_dists),
184 : 164 : locked_room_prob_(config.locked_room_prob),
185 : 164 : locations_(config.locations),
186 : 164 : unblocking_(config.unblocking),
187 [ + - + - ]: 328 : implicit_unlock_(config.implicit_unlock) {
188 [ + - ]: 164 : action_kinds_ = SplitKinds(config.action_kinds);
189 [ + - ]: 162 : instr_kinds_ = SplitKinds(config.instr_kinds);
190 : 164 : }
191 : :
192 : 327 : std::vector<std::string> BabyAILevelGenTask::SplitKinds(
193 : : const std::string& csv) const {
194 : : std::vector<std::string> kinds;
195 : : std::string cur;
196 [ + + ]: 5093 : for (char ch : csv) {
197 [ + + ]: 4768 : if (ch == ',') {
198 [ + - ]: 551 : if (!cur.empty()) {
199 [ + - ]: 554 : kinds.push_back(cur);
200 : : cur.clear();
201 : : }
202 : 551 : continue;
203 : : }
204 [ + - ]: 4217 : cur.push_back(ch);
205 : : }
206 [ + - ]: 325 : if (!cur.empty()) {
207 [ + - ]: 325 : kinds.push_back(cur);
208 : : }
209 : 326 : return kinds;
210 : 0 : }
211 : :
212 : 436 : BabyAIObjDesc BabyAILevelGenTask::RandObj(const std::vector<Type>& types,
213 : : const std::vector<Color>& colors,
214 : : int max_tries) {
215 [ + - ]: 1118 : for (int num_tries = 0; num_tries <= max_tries; ++num_tries) {
216 : 1118 : std::optional<Color> color;
217 [ + + ]: 1118 : if (RandInt(0, static_cast<int>(colors.size()) + 1) > 0) {
218 : 960 : color = RandElem(colors);
219 : : }
220 : 1118 : std::optional<Type> type = RandElem(types);
221 : : BabyAILoc loc = BabyAILoc::kNone;
222 [ + + + + ]: 1118 : if (locations_ && RandBool()) {
223 [ + - ]: 412 : loc = RandElem(std::vector<BabyAILoc>{BabyAILoc::kLeft, BabyAILoc::kRight,
224 : : BabyAILoc::kFront,
225 : : BabyAILoc::kBehind});
226 : : }
227 : 1118 : BabyAIObjDesc desc(type, color, loc);
228 [ + - ]: 1118 : desc.FindMatchingObjs(*this);
229 [ + + ]: 1118 : if (desc.ObjUids().empty()) {
230 : 680 : continue;
231 : : }
232 [ + + + + ]: 438 : if (!implicit_unlock_ && locked_room_ != nullptr) {
233 : : bool has_unlocked_match = false;
234 [ + + ]: 122 : for (const Pos& pos : desc.ObjPoss()) {
235 [ + + ]: 120 : if (!locked_room_->PosInside(pos.first, pos.second)) {
236 : : has_unlocked_match = true;
237 : : break;
238 : : }
239 : : }
240 [ + + ]: 108 : if (!has_unlocked_match) {
241 : 2 : continue;
242 : : }
243 : : }
244 : 436 : return desc;
245 : 1118 : }
246 [ # # # # ]: 0 : throw BabyAIRejectSampling("failed to find suitable object");
247 : : }
248 : :
249 : 379 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandActionInstr(
250 : : const std::vector<std::string>& action_kinds) {
251 : 379 : const std::string& action = RandElem(action_kinds);
252 [ + + ]: 379 : if (action == "goto") {
253 [ + - + - : 171 : return std::make_unique<BabyAIGoToInstr>(RandObj());
+ - ]
254 : : }
255 [ + + ]: 208 : if (action == "pickup") {
256 [ + - ]: 102 : return std::make_unique<BabyAIPickupInstr>(
257 [ + - + - ]: 204 : RandObj(std::vector<Type>{kKey, kBall, kBox}));
258 : : }
259 [ + + ]: 106 : if (action == "open") {
260 [ + - + - : 49 : return std::make_unique<BabyAIOpenInstr>(RandObj(std::vector<Type>{kDoor}));
+ - ]
261 : : }
262 [ + - ]: 57 : if (action == "putnext") {
263 [ + - ]: 57 : return std::make_unique<BabyAIPutNextInstr>(
264 [ + - + - : 114 : RandObj(std::vector<Type>{kKey, kBall, kBox}), RandObj());
+ - + - +
- ]
265 : : }
266 [ # # ]: 0 : LOG(FATAL) << "Unknown BabyAI action kind: " << action;
267 : : return nullptr;
268 : : }
269 : :
270 : 520 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandInstr(
271 : : const std::vector<std::string>& action_kinds,
272 : : const std::vector<std::string>& instr_kinds, int depth) {
273 : 520 : const std::string& kind = RandElem(instr_kinds);
274 [ + + ]: 520 : if (kind == "action") {
275 : 379 : return RandActionInstr(action_kinds);
276 : : }
277 [ + + ]: 141 : if (kind == "and") {
278 [ + - ]: 98 : return std::make_unique<BabyAIAndInstr>(
279 [ + - + - : 294 : RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1),
+ + - - ]
280 [ + - + - : 294 : RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1));
+ + - - ]
281 : : }
282 [ + - ]: 43 : if (kind == "seq") {
283 : : auto instr_a = RandInstr(
284 [ + - + - : 129 : action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
+ + - - ]
285 : : auto instr_b = RandInstr(
286 [ + - + - : 129 : action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
+ + - - -
- ]
287 [ + + ]: 43 : if (RandBool()) {
288 [ + - ]: 18 : return std::make_unique<BabyAIBeforeInstr>(std::move(instr_a),
289 : : std::move(instr_b));
290 : : }
291 [ + - - - ]: 25 : return std::make_unique<BabyAIAfterInstr>(std::move(instr_a),
292 : : std::move(instr_b));
293 : : }
294 [ # # ]: 0 : LOG(FATAL) << "Unknown BabyAI instruction kind: " << kind;
295 : : return nullptr;
296 [ + - - + : 521 : }
+ - + - +
- + - + -
- - - - -
- ]
297 : :
298 : 311 : void BabyAILevelGenTask::GenMission() {
299 [ + + ]: 311 : if (RandFloat(0.0f, 1.0f) < locked_room_prob_) {
300 : 93 : AddLockedRoom();
301 : : }
302 [ + - ]: 311 : ConnectAllOrReject();
303 : 311 : AddDistractorsOrReject(-1, -1, num_dists_, false);
304 : : while (true) {
305 : 325 : PlaceAgentInRoom();
306 [ + + + + ]: 650 : if (locked_room_ == nullptr ||
307 [ + + ]: 107 : &RoomFromPos(agent_pos_.first, agent_pos_.second) != locked_room_) {
308 : : break;
309 : : }
310 : : }
311 [ + + ]: 311 : if (!unblocking_) {
312 : 144 : CheckObjsReachableOrReject();
313 : : }
314 : 238 : instrs_ = RandInstr(action_kinds_, instr_kinds_);
315 : 238 : }
316 : :
317 : 647 : std::pair<Pos, std::pair<Type, Color>> ObjAtDoor(const BabyAILevelTask& env,
318 : : const Pos& pos) {
319 : 647 : const WorldObj door = env.CellAt(pos.first, pos.second);
320 [ - + ]: 646 : return {pos, {door.GetType(), door.GetColor()}};
321 : : }
322 : :
323 : 26 : BabyAITaskConfig MakeGoToSeqConfig(BabyAITaskConfig config) {
324 : 26 : config.action_kinds = "goto";
325 : 26 : config.locked_room_prob = 0.0f;
326 : 26 : config.locations = false;
327 : 26 : config.unblocking = false;
328 : 26 : return config;
329 : : }
330 : :
331 : 21 : BabyAITaskConfig MakePickupLocConfig(BabyAITaskConfig config) {
332 : 21 : config.action_kinds = "pickup";
333 : 21 : config.instr_kinds = "action";
334 : 21 : config.num_rows = 1;
335 : 21 : config.num_cols = 1;
336 : 21 : config.num_dists = 8;
337 : 21 : config.locked_room_prob = 0.0f;
338 : 21 : config.locations = true;
339 : 21 : config.unblocking = false;
340 : 21 : return config;
341 : : }
342 : :
343 : 26 : BabyAITaskConfig MakeSynthConfig(BabyAITaskConfig config) {
344 : 26 : config.instr_kinds = "action";
345 : 26 : config.locations = false;
346 : 26 : config.unblocking = true;
347 : 26 : config.implicit_unlock = false;
348 : 26 : return config;
349 : : }
350 : :
351 : 13 : BabyAITaskConfig MakeSynthLocConfig(BabyAITaskConfig config) {
352 : 13 : config.instr_kinds = "action";
353 : 13 : config.locations = true;
354 : 13 : config.unblocking = true;
355 : 13 : config.implicit_unlock = false;
356 : 13 : return config;
357 : : }
358 : :
359 : 21 : BabyAITaskConfig MakeSynthSeqConfig(BabyAITaskConfig config) {
360 : 21 : config.locations = true;
361 : 21 : config.unblocking = true;
362 : 21 : config.implicit_unlock = false;
363 : 21 : return config;
364 : : }
365 : :
366 : 21 : BabyAITaskConfig MakeMiniBossConfig(BabyAITaskConfig config) {
367 : 21 : config.num_cols = 2;
368 : 21 : config.num_rows = 2;
369 : 21 : config.room_size = 5;
370 : 21 : config.num_dists = 7;
371 : 21 : config.locked_room_prob = 0.25f;
372 : 21 : return config;
373 : : }
374 : :
375 : 13 : BabyAITaskConfig MakeBossNoUnlockConfig(BabyAITaskConfig config) {
376 : 13 : config.locked_room_prob = 0.0f;
377 : 13 : config.implicit_unlock = false;
378 : 13 : return config;
379 : : }
380 : :
381 : : } // namespace minigrid
|