Branch data Line data Source code
1 : : /*
2 : : * Copyright 2022 Garena Online Private Limited
3 : : *
4 : : * Licensed under the Apache License, Version 2.0 (the "License");
5 : : * you may not use this file except in compliance with the License.
6 : : * You may obtain a copy of the License at
7 : : *
8 : : * http://www.apache.org/licenses/LICENSE-2.0
9 : : *
10 : : * Unless required by applicable law or agreed to in writing, software
11 : : * distributed under the License is distributed on an "AS IS" BASIS,
12 : : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 : : * See the License for the specific language governing permissions and
14 : : * limitations under the License.
15 : : */
16 : :
17 : : #ifndef ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
18 : : #define ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
19 : :
20 : : #include <algorithm>
21 : : #include <array>
22 : :
23 : : #include "envpool/box2d/lunar_lander_env.h"
24 : : #include "envpool/core/async_envpool.h"
25 : : #include "envpool/core/env.h"
26 : :
27 : : namespace box2d {
28 : :
29 : : class LunarLanderContinuousEnvFns {
30 : : public:
31 : : static decltype(auto) DefaultConfig() {
32 : : return MakeDict("reward_threshold"_.Bind(200.0));
33 : : }
34 : : template <typename Config>
35 : 23 : static decltype(auto) StateSpec(const Config& conf) {
36 : : #ifdef ENVPOOL_TEST
37 [ + - ]: 46 : return MakeDict("obs"_.Bind(Spec<float>({8})),
38 [ + - ]: 46 : "info:sky_polys"_.Bind(Spec<float>({10, 4, 2})),
39 [ + - ]: 46 : "info:lander_state"_.Bind(Spec<float>({7})),
40 [ + - ]: 46 : "info:leg_states"_.Bind(Spec<float>({2, 7})),
41 [ + - ]: 46 : "info:ground_contact"_.Bind(Spec<float>({2})),
42 [ + - ]: 46 : "info:prev_shaping"_.Bind(Spec<float>({-1})),
43 [ + - ]: 46 : "info:game_over"_.Bind(Spec<float>({-1})),
44 [ + - ]: 46 : "info:last_dispersion"_.Bind(Spec<float>({2})),
45 : 230 : "info:initial_force"_.Bind(Spec<float>({2})));
46 : : #else
47 : : return MakeDict("obs"_.Bind(Spec<float>({8})));
48 : : #endif
49 : : }
50 : : template <typename Config>
51 : 23 : static decltype(auto) ActionSpec(const Config& conf) {
52 : 46 : return MakeDict("action"_.Bind(Spec<float>({2}, {-1.0, 1.0})));
53 : : }
54 : : };
55 : :
56 : : using LunarLanderContinuousEnvSpec = EnvSpec<LunarLanderContinuousEnvFns>;
57 : :
58 : : class LunarLanderContinuousEnv : public Env<LunarLanderContinuousEnvSpec>,
59 : : public LunarLanderBox2dEnv {
60 : : public:
61 : 66 : LunarLanderContinuousEnv(const Spec& spec, int env_id)
62 : 66 : : Env<LunarLanderContinuousEnvSpec>(spec, env_id),
63 [ + - ]: 66 : LunarLanderBox2dEnv(true, spec.config["max_episode_steps"_]) {}
64 : :
65 : 72123 : bool IsDone() override { return done_; }
66 : :
67 : 309 : void Reset() override {
68 : 309 : LunarLanderReset(&gen_);
69 : 309 : WriteState();
70 : 309 : }
71 : :
72 : 35784 : void Step(const Action& action) override {
73 : 35784 : float action0 = action["action"_][0];
74 : 35792 : float action1 = action["action"_][1];
75 : 35776 : LunarLanderStep(&gen_, 0, action0, action1);
76 : 35778 : WriteState();
77 : 35796 : }
78 : :
79 : : private:
80 : 36092 : void WriteState() {
81 : 36092 : auto state = Allocate();
82 : : state["reward"_] = reward_;
83 : : state["obs"_].Assign(obs_.data(), obs_.size());
84 : : #ifdef ENVPOOL_TEST
85 [ + - ]: 36061 : auto sky_poly = SkyPolyState();
86 : : state["info:sky_polys"_].Assign(sky_poly.data(), sky_poly.size());
87 [ + - ]: 36105 : auto lander_state = BodyState(lander_);
88 : : state["info:lander_state"_].Assign(lander_state.data(),
89 : : lander_state.size());
90 : : std::array<float, 14> leg_states;
91 [ + + ]: 108189 : for (int i = 0; i < 2; ++i) {
92 [ + - ]: 72139 : auto leg_state = BodyState(legs_[i]);
93 : 72098 : std::copy(leg_state.begin(), leg_state.end(),
94 : 72098 : leg_states.begin() + i * leg_state.size());
95 : : }
96 : : state["info:leg_states"_].Assign(leg_states.data(), leg_states.size());
97 : : state["info:ground_contact"_].Assign(ground_contact_.data(),
98 : : ground_contact_.size());
99 : : state["info:prev_shaping"_] = prev_shaping_;
100 [ + + ]: 36050 : state["info:game_over"_] = done_ ? 1.0f : 0.0f;
101 : : state["info:last_dispersion"_].Assign(last_dispersion_.data(),
102 : : last_dispersion_.size());
103 : : state["info:initial_force"_].Assign(initial_force_.data(),
104 : : initial_force_.size());
105 : : #endif
106 : 72154 : }
107 : : };
108 : :
109 : : using LunarLanderContinuousEnvPool = AsyncEnvPool<LunarLanderContinuousEnv>;
110 : :
111 : : } // namespace box2d
112 : :
113 : : #endif // ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
|