Branch data Line data Source code
1 : : /*
2 : : * Copyright 2022 Garena Online Private Limited
3 : : *
4 : : * Licensed under the Apache License, Version 2.0 (the "License");
5 : : * you may not use this file except in compliance with the License.
6 : : * You may obtain a copy of the License at
7 : : *
8 : : * http://www.apache.org/licenses/LICENSE-2.0
9 : : *
10 : : * Unless required by applicable law or agreed to in writing, software
11 : : * distributed under the License is distributed on an "AS IS" BASIS,
12 : : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 : : * See the License for the specific language governing permissions and
14 : : * limitations under the License.
15 : : */
16 : :
17 : : #ifndef ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
18 : : #define ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
19 : :
20 : : #include <algorithm>
21 : : #include <array>
22 : :
23 : : #include "envpool/box2d/lunar_lander_env.h"
24 : : #include "envpool/core/async_envpool.h"
25 : : #include "envpool/core/env.h"
26 : :
27 : : namespace box2d {
28 : :
29 : : class LunarLanderDiscreteEnvFns {
30 : : public:
31 : : static decltype(auto) DefaultConfig() {
32 : : return MakeDict("reward_threshold"_.Bind(200.0));
33 : : }
34 : : template <typename Config>
35 : 26 : static decltype(auto) StateSpec(const Config& conf) {
36 : : #ifdef ENVPOOL_TEST
37 [ + - ]: 52 : return MakeDict("obs"_.Bind(Spec<float>({8})),
38 [ + - ]: 52 : "info:sky_polys"_.Bind(Spec<float>({10, 4, 2})),
39 [ + - ]: 52 : "info:lander_state"_.Bind(Spec<float>({7})),
40 [ + - ]: 52 : "info:leg_states"_.Bind(Spec<float>({2, 7})),
41 [ + - ]: 52 : "info:ground_contact"_.Bind(Spec<float>({2})),
42 [ + - ]: 52 : "info:prev_shaping"_.Bind(Spec<float>({-1})),
43 [ + - ]: 52 : "info:game_over"_.Bind(Spec<float>({-1})),
44 [ + - ]: 52 : "info:last_dispersion"_.Bind(Spec<float>({2})),
45 : 260 : "info:initial_force"_.Bind(Spec<float>({2})));
46 : : #else
47 : : return MakeDict("obs"_.Bind(Spec<float>({8})));
48 : : #endif
49 : : }
50 : : template <typename Config>
51 : 26 : static decltype(auto) ActionSpec(const Config& conf) {
52 : 52 : return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 3})));
53 : : }
54 : : };
55 : :
56 : : using LunarLanderDiscreteEnvSpec = EnvSpec<LunarLanderDiscreteEnvFns>;
57 : :
58 : : class LunarLanderDiscreteEnv : public Env<LunarLanderDiscreteEnvSpec>,
59 : : public LunarLanderBox2dEnv {
60 : : public:
61 : 70 : LunarLanderDiscreteEnv(const Spec& spec, int env_id)
62 : 70 : : Env<LunarLanderDiscreteEnvSpec>(spec, env_id),
63 [ + - ]: 70 : LunarLanderBox2dEnv(false, spec.config["max_episode_steps"_]) {}
64 : :
65 : 73681 : bool IsDone() override { return done_; }
66 : :
67 : 357 : void Reset() override {
68 : 357 : LunarLanderReset(&gen_);
69 : 356 : WriteState();
70 : 357 : }
71 : :
72 : 36522 : void Step(const Action& action) override {
73 : 36522 : int act = action["action"_];
74 : 36522 : LunarLanderStep(&gen_, act, 0, 0);
75 : 36509 : WriteState();
76 : 36534 : }
77 : :
78 : : private:
79 : 36868 : void WriteState() {
80 : 36868 : auto state = Allocate();
81 : : state["reward"_] = reward_;
82 : : state["obs"_].Assign(obs_.data(), obs_.size());
83 : : #ifdef ENVPOOL_TEST
84 [ + - ]: 36858 : auto sky_poly = SkyPolyState();
85 : : state["info:sky_polys"_].Assign(sky_poly.data(), sky_poly.size());
86 [ + - ]: 36877 : auto lander_state = BodyState(lander_);
87 : : state["info:lander_state"_].Assign(lander_state.data(),
88 : : lander_state.size());
89 : : std::array<float, 14> leg_states;
90 [ + + ]: 110556 : for (int i = 0; i < 2; ++i) {
91 [ + - ]: 73713 : auto leg_state = BodyState(legs_[i]);
92 : 73693 : std::copy(leg_state.begin(), leg_state.end(),
93 : 73693 : leg_states.begin() + i * leg_state.size());
94 : : }
95 : : state["info:leg_states"_].Assign(leg_states.data(), leg_states.size());
96 : : state["info:ground_contact"_].Assign(ground_contact_.data(),
97 : : ground_contact_.size());
98 : : state["info:prev_shaping"_] = prev_shaping_;
99 [ + + ]: 36843 : state["info:game_over"_] = done_ ? 1.0f : 0.0f;
100 : : state["info:last_dispersion"_].Assign(last_dispersion_.data(),
101 : : last_dispersion_.size());
102 : : state["info:initial_force"_].Assign(initial_force_.data(),
103 : : initial_force_.size());
104 : : #endif
105 : 73732 : }
106 : : };
107 : :
108 : : using LunarLanderDiscreteEnvPool = AsyncEnvPool<LunarLanderDiscreteEnv>;
109 : :
110 : : } // namespace box2d
111 : :
112 : : #endif // ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
|