Branch data Line data Source code
1 : : /*
2 : : * Copyright 2022 Garena Online Private Limited
3 : : *
4 : : * Licensed under the Apache License, Version 2.0 (the "License");
5 : : * you may not use this file except in compliance with the License.
6 : : * You may obtain a copy of the License at
7 : : *
8 : : * http://www.apache.org/licenses/LICENSE-2.0
9 : : *
10 : : * Unless required by applicable law or agreed to in writing, software
11 : : * distributed under the License is distributed on an "AS IS" BASIS,
12 : : * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 : : * See the License for the specific language governing permissions and
14 : : * limitations under the License.
15 : : */
16 : :
17 : : #ifndef ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
18 : : #define ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
19 : :
20 : : #include "envpool/box2d/lunar_lander_env.h"
21 : : #include "envpool/core/async_envpool.h"
22 : : #include "envpool/core/env.h"
23 : :
24 : : namespace box2d {
25 : :
26 : : class LunarLanderDiscreteEnvFns {
27 : : public:
28 : : static decltype(auto) DefaultConfig() {
29 : : return MakeDict("reward_threshold"_.Bind(200.0));
30 : : }
31 : : template <typename Config>
32 : 19 : static decltype(auto) StateSpec(const Config& conf) {
33 : 38 : return MakeDict("obs"_.Bind(Spec<float>({8})));
34 : : }
35 : : template <typename Config>
36 : 19 : static decltype(auto) ActionSpec(const Config& conf) {
37 : 38 : return MakeDict("action"_.Bind(Spec<int>({-1}, {0, 3})));
38 : : }
39 : : };
40 : :
41 : : using LunarLanderDiscreteEnvSpec = EnvSpec<LunarLanderDiscreteEnvFns>;
42 : :
43 : : class LunarLanderDiscreteEnv : public Env<LunarLanderDiscreteEnvSpec>,
44 : : public LunarLanderBox2dEnv {
45 : : public:
46 : 56 : LunarLanderDiscreteEnv(const Spec& spec, int env_id)
47 : 56 : : Env<LunarLanderDiscreteEnvSpec>(spec, env_id),
48 [ + - ]: 56 : LunarLanderBox2dEnv(false, spec.config["max_episode_steps"_]) {}
49 : :
50 : 149194 : bool IsDone() override { return done_; }
51 : :
52 : 718 : void Reset() override {
53 : 718 : LunarLanderReset(&gen_);
54 : 719 : WriteState();
55 : 719 : }
56 : :
57 : 73830 : void Step(const Action& action) override {
58 : 73830 : int act = action["action"_];
59 : 73830 : LunarLanderStep(&gen_, act, 0, 0);
60 : 73855 : WriteState();
61 : 73881 : }
62 : :
63 : : private:
64 : 74584 : void WriteState() {
65 : 74584 : auto state = Allocate();
66 : : state["reward"_] = reward_;
67 : : state["obs"_].Assign(obs_.data(), obs_.size());
68 : 74606 : }
69 : : };
70 : :
71 : : using LunarLanderDiscreteEnvPool = AsyncEnvPool<LunarLanderDiscreteEnv>;
72 : :
73 : : } // namespace box2d
74 : :
75 : : #endif // ENVPOOL_BOX2D_LUNAR_LANDER_DISCRETE_H_
|