LCOV - code coverage report
Current view: top level - envpool/minigrid/impl - babyai_core.cc (source / functions) Coverage Total Hit
Test: EnvPool coverage report Lines: 94.3 % 228 215
Test Date: 2026-04-07 20:03:58 Functions: 100.0 % 24 24
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 57.8 % 277 160

             Branch data     Line data    Source code
       1                 :             : // Copyright 2026 Garena Online Private Limited
       2                 :             : //
       3                 :             : // Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : // you may not use this file except in compliance with the License.
       5                 :             : // You may obtain a copy of the License at
       6                 :             : //
       7                 :             : //      http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : //
       9                 :             : // Unless required by applicable law or agreed to in writing, software
      10                 :             : // distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : // See the License for the specific language governing permissions and
      13                 :             : // limitations under the License.
      14                 :             : 
      15                 :             : #include "envpool/minigrid/impl/babyai_core.h"
      16                 :             : 
      17                 :             : #include <cstdlib>
      18                 :             : #include <memory>
      19                 :             : #include <optional>
      20                 :             : #include <string>
      21                 :             : #include <utility>
      22                 :             : #include <vector>
      23                 :             : 
      24                 :             : #include "absl/log/log.h"
      25                 :             : 
      26                 :             : namespace minigrid {
      27                 :             : 
      28                 :         275 : BabyAIRejectSampling::BabyAIRejectSampling(const std::string& msg)
      29                 :         275 :     : std::runtime_error(msg) {}
      30                 :             : 
      31                 :        1333 : BabyAILevelTask::BabyAILevelTask(std::string env_name, int room_size,
      32                 :             :                                  int num_rows, int num_cols, int max_steps,
      33                 :        1333 :                                  int mission_bytes)
      34                 :             :     : RoomGridTask(std::move(env_name), room_size, num_rows, num_cols,
      35                 :             :                    max_steps, 7, mission_bytes),
      36         [ +  - ]:        2664 :       fixed_max_steps_(max_steps > 0) {}  // NOLINT(whitespace/indent_namespace)
      37                 :             : 
      38                 :          13 : std::pair<Pos, std::pair<Type, Color>> BabyAILevelTask::AddExistingObject(
      39                 :             :     int i, int j, const WorldObj& obj) {
      40                 :          13 :   const Room& room = GetRoom(i, j);
      41         [ +  - ]:          13 :   Pos pos = PlaceObj(
      42         [ +  - ]:          13 :       obj, room.top.first, room.top.second, room.size.first, room.size.second,
      43                 :             :       [&](const Pos& pos_candidate) {
      44                 :          13 :         return std::abs(agent_pos_.first - pos_candidate.first) +
      45                 :          13 :                    std::abs(agent_pos_.second - pos_candidate.second) <
      46                 :          13 :                2;
      47                 :             :       },
      48                 :             :       1000);
      49                 :          13 :   GetRoom(i, j).objs.emplace_back(obj.GetType(), obj.GetColor());
      50                 :          13 :   return {pos, {obj.GetType(), obj.GetColor()}};
      51                 :             : }
      52                 :             : 
      53                 :             : std::vector<std::pair<Pos, std::pair<Type, Color>>>
      54                 :        2062 : BabyAILevelTask::AddDistractorsOrReject(int i, int j, int num_distractors,
      55                 :             :                                         bool all_unique) {
      56                 :             :   try {
      57         [ +  - ]:        2062 :     return AddDistractors(i, j, num_distractors, all_unique);
      58         [ #  # ]:           0 :   } catch (const std::runtime_error&) {
      59   [ #  #  #  # ]:           0 :     throw BabyAIRejectSampling("failed to add distractors");
      60                 :           0 :   }
      61                 :             : }
      62                 :             : 
      63                 :         721 : void BabyAILevelTask::ConnectAllOrReject(
      64                 :             :     const std::vector<Color>& door_colors) {
      65         [ -  + ]:         721 :   if (!TryConnectAll(door_colors)) {
      66   [ #  #  #  # ]:           0 :     throw BabyAIRejectSampling("connect_all failed");
      67                 :             :   }
      68                 :         721 : }
      69                 :             : 
      70                 :         872 : void BabyAILevelTask::CheckObjsReachableOrReject() const {
      71         [ +  + ]:         872 :   if (!CheckObjsReachable()) {
      72   [ +  -  +  - ]:         420 :     throw BabyAIRejectSampling("unreachable object");
      73                 :             :   }
      74                 :         664 : }
      75                 :             : 
      76                 :          26 : void BabyAILevelTask::OpenAllDoors() {
      77         [ +  + ]:         104 :   for (int i = 0; i < num_cols_; ++i) {
      78         [ +  + ]:         312 :     for (int j = 0; j < num_rows_; ++j) {
      79                 :         234 :       const Room& room = GetRoom(i, j);
      80         [ +  + ]:        1170 :       for (int door_idx = 0; door_idx < 4; ++door_idx) {
      81         [ +  + ]:         936 :         if (!room.connected[door_idx]) {
      82                 :         476 :           continue;
      83                 :             :         }
      84                 :         460 :         const Pos pos = room.door_pos[door_idx];
      85                 :         460 :         WorldObj& obj = Cell(pos.first, pos.second);
      86         [ +  - ]:         460 :         if (obj.GetType() == kDoor) {
      87                 :             :           obj.SetDoorLocked(false);
      88                 :             :           obj.SetDoorOpen(true);
      89                 :             :         }
      90                 :             :       }
      91                 :             :     }
      92                 :             :   }
      93                 :          26 : }
      94                 :             : 
      95                 :        1603 : void BabyAILevelTask::GenGrid() {
      96                 :        1603 :   std::string last_error = "unknown";
      97         [ +  - ]:        1876 :   for (int retry = 0; retry < 1000; ++retry) {
      98         [ +  - ]:        1876 :     RoomGridTask::GenGrid();
      99         [ +  + ]:        1879 :     locked_room_ = nullptr;
     100                 :             :     instrs_.reset();
     101                 :             :     try {
     102         [ +  + ]:        1879 :       GenMission();
     103   [ +  -  -  - ]:        1638 :       CHECK(instrs_ != nullptr);
     104                 :        1603 :       std::vector<Color> locked_colors;
     105         [ +  + ]:        4995 :       for (int i = 0; i < num_cols_; ++i) {
     106         [ +  + ]:       10886 :         for (int j = 0; j < num_rows_; ++j) {
     107         [ +  - ]:        7529 :           const Room& room = GetRoom(i, j);
     108         [ +  + ]:       37632 :           for (int door_idx = 0; door_idx < 4; ++door_idx) {
     109         [ +  + ]:       30102 :             if (!room.connected[door_idx]) {
     110                 :       19951 :               continue;
     111                 :             :             }
     112                 :       10151 :             const Pos pos = room.door_pos[door_idx];
     113         [ +  - ]:       10151 :             const WorldObj door = GetCell(pos.first, pos.second);
     114   [ +  +  +  + ]:       10152 :             if (door.GetType() == kDoor && door.GetDoorLocked()) {
     115   [ +  -  -  - ]:         844 :               locked_colors.push_back(door.GetColor());
     116                 :             :             }
     117                 :             :           }
     118                 :             :         }
     119                 :             :       }
     120   [ +  -  +  + ]:        1639 :       instrs_->Validate(*this, locked_colors, UnblockingEnabled());
     121   [ +  -  +  - ]:        3235 :       SetMission(instrs_->Surface(*this));
     122         [ +  - ]:        1598 :       instrs_->ResetVerifier(*this);
     123         [ +  + ]:        1604 :       if (!fixed_max_steps_) {
     124         [ +  - ]:        1166 :         max_steps_ = instrs_->NumNavsNeeded() * room_size_ * room_size_ *
     125                 :        1164 :                      num_rows_ * num_cols_;
     126                 :             :       }
     127         [ +  - ]:        1602 :       AfterResetVerifier();
     128                 :        1600 :       return;
     129      [ -  +  - ]:         275 :     } catch (const BabyAIRejectSampling& e) {
     130                 :         275 :       last_error = e.what();
     131                 :         275 :     } catch (const std::runtime_error& e) {
     132                 :           0 :       last_error = e.what();
     133                 :           0 :     }
     134                 :             :   }
     135         [ #  # ]:           0 :   throw std::runtime_error("BabyAI mission generation failed for " + env_name_ +
     136         [ #  # ]:           0 :                            ": " + last_error);
     137                 :             : }
     138                 :             : 
     139                 :       58733 : void BabyAILevelTask::AfterStep(Act act, const WorldObj& /*pre_fwd*/,
     140                 :             :                                 const Pos& /*fwd_pos*/,
     141                 :             :                                 const WorldObj& pre_carrying, float* reward,
     142                 :             :                                 bool* terminated) {
     143         [ +  + ]:       58733 :   if (act == kDrop) {
     144                 :        7267 :     instrs_->UpdateObjPoss(*this);
     145                 :             :   }
     146                 :       58734 :   BabyAIStatus status = instrs_->Verify(*this, act, pre_carrying);
     147         [ +  + ]:       58633 :   if (status == BabyAIStatus::kSuccess) {
     148                 :         108 :     *reward = SuccessReward();
     149                 :         108 :     *terminated = true;
     150         [ +  + ]:       58525 :   } else if (status == BabyAIStatus::kFailure) {
     151                 :          25 :     *reward = 0.0f;
     152                 :          25 :     *terminated = true;
     153                 :             :   }
     154                 :       58633 : }
     155                 :             : 
     156                 :          77 : void BabyAILevelTask::AddLockedRoom() {
     157                 :             :   while (true) {
     158                 :         110 :     int i = RandInt(0, num_cols_);
     159                 :         110 :     int j = RandInt(0, num_rows_);
     160                 :         110 :     int door_idx = RandInt(0, 4);
     161                 :         110 :     Room& room = GetRoom(i, j);
     162         [ +  + ]:         110 :     if (!room.has_neighbor[door_idx]) {
     163                 :          33 :       continue;
     164                 :             :     }
     165                 :          77 :     locked_room_ = &room;
     166                 :          77 :     Pos door_pos = AddDoor(i, j, door_idx, kUnassigned, true);
     167                 :         154 :     Color door_color = GetCell(door_pos.first, door_pos.second).GetColor();
     168                 :             :     while (true) {
     169                 :          81 :       int key_i = RandInt(0, num_cols_);
     170                 :          81 :       int key_j = RandInt(0, num_rows_);
     171         [ +  + ]:          81 :       if (key_i == i && key_j == j) {
     172                 :             :         continue;
     173                 :             :       }
     174                 :          77 :       AddObject(key_i, key_j, kKey, door_color);
     175                 :          76 :       return;
     176                 :           4 :     }
     177                 :             :   }
     178                 :             : }
     179                 :             : 
     180                 :         164 : BabyAILevelGenTask::BabyAILevelGenTask(const BabyAITaskConfig& config)
     181                 :         164 :     : BabyAILevelTask(config.env_name, config.room_size, config.num_rows,
     182                 :         164 :                       config.num_cols, config.max_steps, config.mission_bytes),
     183                 :         164 :       num_dists_(config.num_dists),
     184                 :         164 :       locked_room_prob_(config.locked_room_prob),
     185                 :         164 :       locations_(config.locations),
     186                 :         164 :       unblocking_(config.unblocking),
     187   [ +  -  +  - ]:         328 :       implicit_unlock_(config.implicit_unlock) {
     188         [ +  - ]:         164 :   action_kinds_ = SplitKinds(config.action_kinds);
     189         [ +  - ]:         162 :   instr_kinds_ = SplitKinds(config.instr_kinds);
     190                 :         164 : }
     191                 :             : 
     192                 :         326 : std::vector<std::string> BabyAILevelGenTask::SplitKinds(
     193                 :             :     const std::string& csv) const {
     194                 :             :   std::vector<std::string> kinds;
     195                 :             :   std::string cur;
     196         [ +  + ]:        5113 :   for (char ch : csv) {
     197         [ +  + ]:        4786 :     if (ch == ',') {
     198         [ +  - ]:         552 :       if (!cur.empty()) {
     199         [ +  - ]:         553 :         kinds.push_back(cur);
     200                 :             :         cur.clear();
     201                 :             :       }
     202                 :         551 :       continue;
     203                 :             :     }
     204         [ +  - ]:        4234 :     cur.push_back(ch);
     205                 :             :   }
     206         [ +  - ]:         327 :   if (!cur.empty()) {
     207         [ +  - ]:         328 :     kinds.push_back(cur);
     208                 :             :   }
     209                 :         327 :   return kinds;
     210                 :           0 : }
     211                 :             : 
     212                 :         319 : BabyAIObjDesc BabyAILevelGenTask::RandObj(const std::vector<Type>& types,
     213                 :             :                                           const std::vector<Color>& colors,
     214                 :             :                                           int max_tries) {
     215         [ +  - ]:         863 :   for (int num_tries = 0; num_tries <= max_tries; ++num_tries) {
     216                 :         863 :     std::optional<Color> color;
     217         [ +  + ]:         863 :     if (RandInt(0, static_cast<int>(colors.size()) + 1) > 0) {
     218                 :         734 :       color = RandElem(colors);
     219                 :             :     }
     220                 :         863 :     std::optional<Type> type = RandElem(types);
     221                 :             :     BabyAILoc loc = BabyAILoc::kNone;
     222   [ +  +  +  + ]:         863 :     if (locations_ && RandBool()) {
     223         [ +  - ]:         334 :       loc = RandElem(std::vector<BabyAILoc>{BabyAILoc::kLeft, BabyAILoc::kRight,
     224                 :             :                                             BabyAILoc::kFront,
     225                 :             :                                             BabyAILoc::kBehind});
     226                 :             :     }
     227                 :         863 :     BabyAIObjDesc desc(type, color, loc);
     228         [ +  - ]:         863 :     desc.FindMatchingObjs(*this);
     229         [ +  + ]:         863 :     if (desc.ObjUids().empty()) {
     230                 :         544 :       continue;
     231                 :             :     }
     232   [ +  +  +  + ]:         319 :     if (!implicit_unlock_ && locked_room_ != nullptr) {
     233                 :             :       bool has_unlocked_match = false;
     234         [ +  - ]:          94 :       for (const Pos& pos : desc.ObjPoss()) {
     235         [ +  + ]:          94 :         if (!locked_room_->PosInside(pos.first, pos.second)) {
     236                 :             :           has_unlocked_match = true;
     237                 :             :           break;
     238                 :             :         }
     239                 :             :       }
     240         [ -  + ]:          82 :       if (!has_unlocked_match) {
     241                 :           0 :         continue;
     242                 :             :       }
     243                 :             :     }
     244                 :         319 :     return desc;
     245                 :         863 :   }
     246   [ #  #  #  # ]:           0 :   throw BabyAIRejectSampling("failed to find suitable object");
     247                 :             : }
     248                 :             : 
     249                 :         280 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandActionInstr(
     250                 :             :     const std::vector<std::string>& action_kinds) {
     251                 :         280 :   const std::string& action = RandElem(action_kinds);
     252         [ +  + ]:         280 :   if (action == "goto") {
     253   [ +  -  +  -  :         111 :     return std::make_unique<BabyAIGoToInstr>(RandObj());
                   +  - ]
     254                 :             :   }
     255         [ +  + ]:         169 :   if (action == "pickup") {
     256         [ +  - ]:          87 :     return std::make_unique<BabyAIPickupInstr>(
     257   [ +  -  +  - ]:         174 :         RandObj(std::vector<Type>{kKey, kBall, kBox}));
     258                 :             :   }
     259         [ +  + ]:          82 :   if (action == "open") {
     260   [ +  -  +  -  :          43 :     return std::make_unique<BabyAIOpenInstr>(RandObj(std::vector<Type>{kDoor}));
                   +  - ]
     261                 :             :   }
     262         [ +  - ]:          39 :   if (action == "putnext") {
     263         [ +  - ]:          39 :     return std::make_unique<BabyAIPutNextInstr>(
     264   [ +  -  +  -  :          78 :         RandObj(std::vector<Type>{kKey, kBall, kBox}), RandObj());
          +  -  +  -  +  
                      - ]
     265                 :             :   }
     266         [ #  # ]:           0 :   LOG(FATAL) << "Unknown BabyAI action kind: " << action;
     267                 :             :   return nullptr;
     268                 :             : }
     269                 :             : 
     270                 :         377 : std::unique_ptr<BabyAIInstr> BabyAILevelGenTask::RandInstr(
     271                 :             :     const std::vector<std::string>& action_kinds,
     272                 :             :     const std::vector<std::string>& instr_kinds, int depth) {
     273                 :         377 :   const std::string& kind = RandElem(instr_kinds);
     274         [ +  + ]:         377 :   if (kind == "action") {
     275                 :         280 :     return RandActionInstr(action_kinds);
     276                 :             :   }
     277         [ +  + ]:          97 :   if (kind == "and") {
     278         [ +  - ]:          72 :     return std::make_unique<BabyAIAndInstr>(
     279   [ +  -  +  -  :         216 :         RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1),
             +  +  -  - ]
     280   [ +  -  +  -  :         216 :         RandInstr(action_kinds, std::vector<std::string>{"action"}, depth + 1));
             +  +  -  - ]
     281                 :             :   }
     282         [ +  - ]:          25 :   if (kind == "seq") {
     283                 :             :     auto instr_a = RandInstr(
     284   [ +  -  +  -  :          75 :         action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
             +  +  -  - ]
     285                 :             :     auto instr_b = RandInstr(
     286   [ +  -  +  -  :          75 :         action_kinds, std::vector<std::string>{"action", "and"}, depth + 1);
          +  +  -  -  -  
                      - ]
     287         [ +  + ]:          25 :     if (RandBool()) {
     288         [ +  - ]:          10 :       return std::make_unique<BabyAIBeforeInstr>(std::move(instr_a),
     289                 :             :                                                  std::move(instr_b));
     290                 :             :     }
     291   [ +  -  -  - ]:          15 :     return std::make_unique<BabyAIAfterInstr>(std::move(instr_a),
     292                 :             :                                               std::move(instr_b));
     293                 :             :   }
     294         [ #  # ]:           0 :   LOG(FATAL) << "Unknown BabyAI instruction kind: " << kind;
     295                 :             :   return nullptr;
     296   [ +  -  -  +  :         363 : }
          +  -  +  -  +  
          -  +  -  +  -  
          -  -  -  -  -  
                      - ]
     297                 :             : 
     298                 :         232 : void BabyAILevelGenTask::GenMission() {
     299         [ +  + ]:         232 :   if (RandFloat(0.0f, 1.0f) < locked_room_prob_) {
     300                 :          77 :     AddLockedRoom();
     301                 :             :   }
     302         [ +  - ]:         232 :   ConnectAllOrReject();
     303                 :         232 :   AddDistractorsOrReject(-1, -1, num_dists_, false);
     304                 :             :   while (true) {
     305                 :         246 :     PlaceAgentInRoom();
     306   [ +  +  +  + ]:         492 :     if (locked_room_ == nullptr ||
     307         [ +  + ]:          91 :         &RoomFromPos(agent_pos_.first, agent_pos_.second) != locked_room_) {
     308                 :             :       break;
     309                 :             :     }
     310                 :             :   }
     311         [ +  + ]:         232 :   if (!unblocking_) {
     312                 :          99 :     CheckObjsReachableOrReject();
     313                 :             :   }
     314                 :         183 :   instrs_ = RandInstr(action_kinds_, instr_kinds_);
     315                 :         183 : }
     316                 :             : 
     317                 :         274 : std::pair<Pos, std::pair<Type, Color>> ObjAtDoor(const BabyAILevelTask& env,
     318                 :             :                                                  const Pos& pos) {
     319                 :         274 :   const WorldObj door = env.CellAt(pos.first, pos.second);
     320         [ -  + ]:         275 :   return {pos, {door.GetType(), door.GetColor()}};
     321                 :             : }
     322                 :             : 
     323                 :          26 : BabyAITaskConfig MakeGoToSeqConfig(BabyAITaskConfig config) {
     324                 :          26 :   config.action_kinds = "goto";
     325                 :          25 :   config.locked_room_prob = 0.0f;
     326                 :          25 :   config.locations = false;
     327                 :          25 :   config.unblocking = false;
     328                 :          25 :   return config;
     329                 :             : }
     330                 :             : 
     331                 :          21 : BabyAITaskConfig MakePickupLocConfig(BabyAITaskConfig config) {
     332                 :          21 :   config.action_kinds = "pickup";
     333                 :          21 :   config.instr_kinds = "action";
     334                 :          21 :   config.num_rows = 1;
     335                 :          21 :   config.num_cols = 1;
     336                 :          21 :   config.num_dists = 8;
     337                 :          21 :   config.locked_room_prob = 0.0f;
     338                 :          21 :   config.locations = true;
     339                 :          21 :   config.unblocking = false;
     340                 :          21 :   return config;
     341                 :             : }
     342                 :             : 
     343                 :          26 : BabyAITaskConfig MakeSynthConfig(BabyAITaskConfig config) {
     344                 :          26 :   config.instr_kinds = "action";
     345                 :          26 :   config.locations = false;
     346                 :          26 :   config.unblocking = true;
     347                 :          26 :   config.implicit_unlock = false;
     348                 :          26 :   return config;
     349                 :             : }
     350                 :             : 
     351                 :          13 : BabyAITaskConfig MakeSynthLocConfig(BabyAITaskConfig config) {
     352                 :          13 :   config.instr_kinds = "action";
     353                 :          13 :   config.locations = true;
     354                 :          13 :   config.unblocking = true;
     355                 :          13 :   config.implicit_unlock = false;
     356                 :          13 :   return config;
     357                 :             : }
     358                 :             : 
     359                 :          21 : BabyAITaskConfig MakeSynthSeqConfig(BabyAITaskConfig config) {
     360                 :          21 :   config.locations = true;
     361                 :          21 :   config.unblocking = true;
     362                 :          21 :   config.implicit_unlock = false;
     363                 :          21 :   return config;
     364                 :             : }
     365                 :             : 
     366                 :          21 : BabyAITaskConfig MakeMiniBossConfig(BabyAITaskConfig config) {
     367                 :          21 :   config.num_cols = 2;
     368                 :          21 :   config.num_rows = 2;
     369                 :          21 :   config.room_size = 5;
     370                 :          21 :   config.num_dists = 7;
     371                 :          21 :   config.locked_room_prob = 0.25f;
     372                 :          21 :   return config;
     373                 :             : }
     374                 :             : 
     375                 :          13 : BabyAITaskConfig MakeBossNoUnlockConfig(BabyAITaskConfig config) {
     376                 :          13 :   config.locked_room_prob = 0.0f;
     377                 :          13 :   config.implicit_unlock = false;
     378                 :          13 :   return config;
     379                 :             : }
     380                 :             : 
     381                 :             : }  // namespace minigrid
        

Generated by: LCOV version 2.0-1