Branch data Line data Source code
1 : : /*
2 : : * Copyright 2022 Garena Online Private Limited
3 : : *
4 : : * Licensed under the Apache License, Version 2.0 (the "License");
5 : : * you may not use this file except in compliance with the License.
6 : : * You may obtain a copy of the License at
7 : : *
8 : : * http://www.apache.org/licenses/LICENSE-2.0
9 : : *
10 : : * Unless required by applicable law or agreed to in writing, software
11 : : * distributed under the License is distributed on an "AS IS" BASIS,
12 : : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 : : * See the License for the specific language governing permissions and
14 : : * limitations under the License.
15 : : */
16 : :
17 : : #ifndef ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
18 : : #define ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
19 : :
20 : : #include "envpool/box2d/lunar_lander_env.h"
21 : : #include "envpool/core/async_envpool.h"
22 : : #include "envpool/core/env.h"
23 : :
24 : : namespace box2d {
25 : :
26 : : class LunarLanderContinuousEnvFns {
27 : : public:
28 : : static decltype(auto) DefaultConfig() {
29 : : return MakeDict("reward_threshold"_.Bind(200.0));
30 : : }
31 : : template <typename Config>
32 : 16 : static decltype(auto) StateSpec(const Config& conf) {
33 : 32 : return MakeDict("obs"_.Bind(Spec<float>({8})));
34 : : }
35 : : template <typename Config>
36 : 16 : static decltype(auto) ActionSpec(const Config& conf) {
37 : 32 : return MakeDict("action"_.Bind(Spec<float>({2}, {-1.0, 1.0})));
38 : : }
39 : : };
40 : :
41 : : using LunarLanderContinuousEnvSpec = EnvSpec<LunarLanderContinuousEnvFns>;
42 : :
43 : : class LunarLanderContinuousEnv : public Env<LunarLanderContinuousEnvSpec>,
44 : : public LunarLanderBox2dEnv {
45 : : public:
46 : 52 : LunarLanderContinuousEnv(const Spec& spec, int env_id)
47 : 52 : : Env<LunarLanderContinuousEnvSpec>(spec, env_id),
48 [ + - ]: 52 : LunarLanderBox2dEnv(true, spec.config["max_episode_steps"_]) {}
49 : :
50 : 145452 : bool IsDone() override { return done_; }
51 : :
52 : 619 : void Reset() override {
53 : 619 : LunarLanderReset(&gen_);
54 : 619 : WriteState();
55 : 619 : }
56 : :
57 : 72068 : void Step(const Action& action) override {
58 : 72119 : float action0 = action["action"_][0];
59 : 72124 : float action1 = action["action"_][1];
60 : 72094 : LunarLanderStep(&gen_, 0, action0, action1);
61 : 72102 : WriteState();
62 : 72132 : }
63 : :
64 : : private:
65 : 72753 : void WriteState() {
66 : 72753 : auto state = Allocate();
67 : : state["reward"_] = reward_;
68 : : state["obs"_].Assign(obs_.data(), obs_.size());
69 : 72748 : }
70 : : };
71 : :
72 : : using LunarLanderContinuousEnvPool = AsyncEnvPool<LunarLanderContinuousEnv>;
73 : :
74 : : } // namespace box2d
75 : :
76 : : #endif // ENVPOOL_BOX2D_LUNAR_LANDER_CONTINUOUS_H_
|