Branch data Line data Source code
1 : : /*
2 : : * Copyright 2022 Garena Online Private Limited
3 : : *
4 : : * Licensed under the Apache License, Version 2.0 (the "License");
5 : : * you may not use this file except in compliance with the License.
6 : : * You may obtain a copy of the License at
7 : : *
8 : : * http://www.apache.org/licenses/LICENSE-2.0
9 : : *
10 : : * Unless required by applicable law or agreed to in writing, software
11 : : * distributed under the License is distributed on an "AS IS" BASIS,
12 : : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 : : * See the License for the specific language governing permissions and
14 : : * limitations under the License.
15 : : */
16 : :
17 : : #ifndef ENVPOOL_BOX2D_BIPEDAL_WALKER_H_
18 : : #define ENVPOOL_BOX2D_BIPEDAL_WALKER_H_
19 : :
20 : : #include <algorithm>
21 : : #include <array>
22 : : #include <vector>
23 : :
24 : : #include "envpool/box2d/bipedal_walker_env.h"
25 : : #include "envpool/core/async_envpool.h"
26 : : #include "envpool/core/env.h"
27 : :
28 : : namespace box2d {
29 : :
// Static spec provider for the BipedalWalker environment. It supplies the
// default config, the state (observation/info) layout and the action layout
// that are plugged into EnvSpec<> below.
30 : : class BipedalWalkerEnvFns {
31 : : public:
// Returns the default config dict: "reward_threshold" is bound to 300.0 and
// "hardcore" is bound to false.
32 : : static decltype(auto) DefaultConfig() {
33 : : return MakeDict("reward_threshold"_.Bind(300.0), "hardcore"_.Bind(false));
34 : : }
// Returns the dict describing every array written by the env per step.
// `conf` is not read in the visible body. A dimension of -1 is presumably a
// placeholder for a scalar / variable-length axis -- TODO confirm against the
// Spec<> definition.
35 : : template <typename Config>
36 : 30 : static decltype(auto) StateSpec(const Config& conf) {
// Test builds additionally expose internal simulator state under "info:*"
// keys; these keys must match the ones filled in by
// BipedalWalkerEnv::WriteState().
37 : : #ifdef ENVPOOL_TEST
38 [ + - ]: 60 : return MakeDict("obs"_.Bind(Spec<float>({24})),
39 [ + - ]: 60 : "info:scroll"_.Bind(Spec<float>({-1})),
40 [ + - ]: 60 : "info:path2"_.Bind(Spec<float>({199, 2, 2})),
41 : : "info:path4"_.Bind(
42 [ + - + - ]: 60 : Spec<Container<float>>({-1}, Spec<float>({-1, 4, 2}))),
43 [ + - ]: 60 : "info:path5"_.Bind(Spec<float>({1, 5, 2})),
44 [ + - ]: 60 : "info:cloud_poly"_.Bind(Spec<float>({10, 5, 2})),
45 [ + - ]: 60 : "info:hull_state"_.Bind(Spec<float>({7})),
46 [ + - ]: 60 : "info:leg_states"_.Bind(Spec<float>({4, 7})),
47 [ + - ]: 60 : "info:joint_states"_.Bind(Spec<float>({4, 4})),
48 [ + - ]: 60 : "info:ground_contact"_.Bind(Spec<float>({4})),
49 [ + - ]: 60 : "info:prev_shaping"_.Bind(Spec<float>({-1})),
50 [ + - ]: 60 : "info:game_over"_.Bind(Spec<float>({-1})),
51 [ + - ]: 450 : "info:initial_force"_.Bind(Spec<float>({-1})));
// Non-test builds only expose the 24-element float observation.
52 : : #else
53 : : return MakeDict("obs"_.Bind(Spec<float>({24})));
54 : : #endif
55 : : }
// Returns the dict describing the action: a single "action" entry holding 4
// floats. The second Spec argument {-1.0, 1.0} is presumably the lower/upper
// bound of each element -- TODO confirm against the Spec<> constructor.
// `conf` is not read in the visible body.
56 : : template <typename Config>
57 : 30 : static decltype(auto) ActionSpec(const Config& conf) {
58 : 60 : return MakeDict("action"_.Bind(Spec<float>({4}, {-1.0, 1.0})));
59 : : }
60 : : };
61 : :
// Environment spec type: EnvSpec instantiated with the BipedalWalker spec
// functions declared above.
62 : : using BipedalWalkerEnvSpec = EnvSpec<BipedalWalkerEnvFns>;
63 : :
// EnvPool-facing wrapper around the Box2D bipedal walker simulation.
// It inherits the generic Env<> interface (IsDone/Reset/Step) and the
// simulator itself (BipedalWalkerBox2dEnv), and copies simulator state into
// the output buffer obtained from Allocate() after every reset and step.
64 : : class BipedalWalkerEnv : public Env<BipedalWalkerEnvSpec>,
65 : : public BipedalWalkerBox2dEnv {
66 : : public:
// Builds the env from `spec`, forwarding the "hardcore" and
// "max_episode_steps" config entries to the simulator base class.
// NOTE(review): "max_episode_steps" is not part of DefaultConfig() in this
// file; presumably it comes from common config keys added by EnvSpec<> --
// confirm.
67 : 243 : BipedalWalkerEnv(const Spec& spec, int env_id)
68 : 243 : : Env<BipedalWalkerEnvSpec>(spec, env_id),
69 : : BipedalWalkerBox2dEnv(spec.config["hardcore"_],
70 [ + - ]: 243 : spec.config["max_episode_steps"_]) {}
71 : :
// Reports the simulator's done_ flag.
72 : 343363 : bool IsDone() override { return done_; }
73 : :
// Resets the simulation using the env's random generator gen_, then
// publishes the resulting state.
74 : 294 : void Reset() override {
75 : 294 : BipedalWalkerReset(&gen_);
76 : 296 : WriteState();
77 : 296 : }
78 : :
// Advances the simulation by one step with the four elements of
// action["action"], then publishes the resulting state.
79 : 171134 : void Step(const Action& action) override {
80 [ + - + - : 342576 : BipedalWalkerStep(&gen_, action["action"_][0], action["action"_][1],
+ - + - ]
81 : : action["action"_][2], action["action"_][3]);
82 : 171463 : WriteState();
83 : 171468 : }
84 : :
85 : : private:
// Allocates an output state slot and fills it from the simulator members.
// Always writes "reward" and "obs"; test builds also write every "info:*"
// entry declared in BipedalWalkerEnvFns::StateSpec().
86 : 171758 : void WriteState() {
87 : 171758 : auto state = Allocate();
88 : : state["reward"_] = reward_;
89 : : state["obs"_].Assign(obs_.data(), obs_.size());
90 : : #ifdef ENVPOOL_TEST
91 : : state["info:scroll"_] = scroll_;
92 : : state["info:path2"_].Assign(terrain_edge_path2_.data(),
93 : : terrain_edge_path2_.size());
94 : : state["info:path5"_].Assign(hull_path5_.data(), hull_path5_.size());
95 [ + - ]: 171088 : auto cloud_poly = CloudPolyState();
96 : : state["info:cloud_poly"_].Assign(cloud_poly.data(), cloud_poly.size());
97 [ + - ]: 171717 : auto hull_state = BodyState(hull_);
98 : : state["info:hull_state"_].Assign(hull_state.data(), hull_state.size());
// Pack the four per-leg body states back to back. The 28-float buffer matches
// the {4, 7} spec, so this assumes BodyState() yields 7 values per body --
// TODO confirm.
99 : : std::array<float, 28> leg_states;
100 [ + + ]: 856660 : for (int i = 0; i < 4; ++i) {
101 [ + - ]: 684964 : auto leg_state = BodyState(legs_[i]);
102 : 685048 : std::copy(leg_state.begin(), leg_state.end(),
103 : 685048 : leg_states.begin() + i * leg_state.size());
104 : : }
105 : : state["info:leg_states"_].Assign(leg_states.data(), leg_states.size());
// Four values per joint, in this order: angle, speed, motor speed,
// max motor torque.
106 : : std::array<float, 16> joint_states;
107 [ + + ]: 855700 : for (int i = 0; i < 4; ++i) {
108 [ + - ]: 684783 : joint_states[i * 4 + 0] = joints_[i]->GetJointAngle();
109 [ + - ]: 683863 : joint_states[i * 4 + 1] = joints_[i]->GetJointSpeed();
110 : 684004 : joint_states[i * 4 + 2] = joints_[i]->GetMotorSpeed();
111 : 684004 : joint_states[i * 4 + 3] = joints_[i]->GetMaxMotorTorque();
112 : : }
113 : : state["info:joint_states"_].Assign(joint_states.data(),
114 : : joint_states.size());
115 : : state["info:ground_contact"_].Assign(ground_contact_.data(),
116 : : ground_contact_.size());
117 : : state["info:prev_shaping"_] = prev_shaping_;
// done_ is exported as a float flag (1.0f / 0.0f) to match the float spec.
118 [ + + ]: 170917 : state["info:game_over"_] = done_ ? 1.0f : 0.0f;
119 : : state["info:initial_force"_] = initial_force_;
120 : :
// "info:path4" is a variable-length container: concatenate the terrain
// polygons and the leg polygons into one flat buffer, then wrap it in a
// freshly allocated array of shape {count, 4, 2}.
121 : : Container<float>& path4 = state["info:path4"_];
122 : : std::vector<float> path4_data;
123 [ + - ]: 170917 : path4_data.reserve(terrain_poly_path4_.size() + leg_path4_.size());
124 [ + - ]: 171518 : path4_data.insert(path4_data.end(), terrain_poly_path4_.begin(),
125 : : terrain_poly_path4_.end());
126 : 171344 : path4_data.insert(path4_data.end(), leg_path4_.begin(), leg_path4_.end());
// 8 floats per polygon (4 points x 2 coordinates). The integer division
// assumes path4_data.size() is a multiple of 8 -- TODO confirm.
// NOTE(review): `array` is a raw owning pointer until path4.reset(array);
// if Assign() can throw, the allocation is leaked. If Container<float> is a
// std::unique_ptr-like type, std::make_unique would avoid the raw new --
// verify against the Container definition.
127 : 343173 : auto* array = new TArray<float>(::Spec<float>(
128 [ + - + - ]: 342651 : std::vector<int>{static_cast<int>(path4_data.size()) / 8, 4, 2}));
129 : : array->Assign(path4_data.data(), path4_data.size());
130 : : path4.reset(array);
131 : : #endif
132 : 342877 : }
133 : : };
134 : :
// Pool type for this environment: AsyncEnvPool instantiated with
// BipedalWalkerEnv.
135 : : using BipedalWalkerEnvPool = AsyncEnvPool<BipedalWalkerEnv>;
136 : :
137 : : } // namespace box2d
138 : :
139 : : #endif // ENVPOOL_BOX2D_BIPEDAL_WALKER_H_
|