LCOV - code coverage report
Current view: top level - envpool/atari - atari_pretrain_test.py (source / functions) Coverage Total Hit
Test: EnvPool coverage report Lines: 95.1 % 61 58
Test Date: 2026-04-07 08:10:29 Functions: - 0 0
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: - 0 0

             Branch data     Line data    Source code
       1                 :             : # Copyright 2021 Garena Online Private Limited
       2                 :             : #
       3                 :             : # Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : # you may not use this file except in compliance with the License.
       5                 :             : # You may obtain a copy of the License at
       6                 :             : #
       7                 :             : #      http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : #
       9                 :             : # Unless required by applicable law or agreed to in writing, software
      10                 :             : # distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : # See the License for the specific language governing permissions and
      13                 :             : # limitations under the License.
      14                 :           1 : """Test EnvPool by well-trained RL agents."""
      15                 :             : 
      16                 :           1 : import os
      17                 :           1 : import sys
      18                 :           1 : from typing import Any, cast
      19                 :             : 
      20                 :           1 : import numpy as np
      21                 :           1 : import torch
      22                 :           1 : from absl import logging
      23                 :           1 : from absl.testing import absltest
      24                 :           1 : from tianshou.data import Batch
      25                 :           1 : from tianshou.policy import QRDQNPolicy
      26                 :             : 
      27                 :           1 : import envpool.atari.registration  # noqa: F401
      28                 :           1 : from envpool.atari.atari_network import QRDQN
      29                 :           1 : from envpool.registration import make_gym
      30                 :             : 
      31                 :             : # try:
      32                 :             : #   import cv2
      33                 :             : # except ImportError:
      34                 :             : #   cv2 = None
      35                 :             : 
      36                 :             : 
class _AtariPretrainTest(absltest.TestCase):
    """Regression tests that replay well-trained QR-DQN checkpoints through
    EnvPool Atari environments and assert the historical mean episode reward
    is reproduced.

    NOTE(review): rollouts are deterministic only because env/numpy/torch are
    all seeded and statements consume RNG draws in a fixed order — do not
    reorder the loop body.
    """

    def eval_qrdqn(
        self,
        task: str,
        resume_path: str,
        num_envs: int = 10,
        seed: int = 0,
        target_reward: float = 0.0,
        delta: float | None = None,
    ) -> None:
        """Roll out a saved QR-DQN policy on ``task`` and check mean reward.

        Args:
          task: lowercase Atari task name; capitalized and suffixed with
            "-v5" to form the EnvPool env id (e.g. "pong" -> "Pong-v5").
          resume_path: path of the ``state_dict`` checkpoint to load.
          num_envs: number of parallel environments rolled out at once.
          seed: seed applied to the env pool, numpy, and torch RNGs.
          target_reward: expected mean episode reward across all envs.
          delta: when given, assert ``|mean - target| <= delta`` instead of
            near-exact equality.
        """
        env = make_gym(task.capitalize() + "-v5", num_envs=num_envs, seed=seed)
        state_shape = env.observation_space.shape
        action_shape = env.action_space.n
        device = "cuda" if torch.cuda.is_available() else "cpu"
        np.random.seed(seed)
        torch.manual_seed(seed)
        logging.info(state_shape)
        # The constants (200 here, and 0.99/200/3/500 below) presumably mirror
        # the configuration the checkpoint was trained with — keep in sync.
        net = QRDQN(*state_shape, action_shape, 200, device)  # type: ignore
        # Optimizer is required by the QRDQNPolicy constructor; it is never
        # stepped since we only run inference.
        optim = torch.optim.Adam(net.parameters(), lr=1e-4)
        policy = QRDQNPolicy(
            net, optim, 0.99, 200, 3, target_update_freq=500
        ).to(device)
        policy.load_state_dict(torch.load(resume_path, map_location=device))
        policy.eval()
        # ``ids`` tracks which envs still have an unfinished episode.
        ids = cast(Any, np.arange(num_envs))
        reward = np.zeros(num_envs)
        obs, _ = env.reset()
        for _ in range(25000):
            # With tiny probability take a uniformly random action —
            # presumably to break out of deterministic action loops.
            if np.random.rand() < 5e-3:
                act = np.random.randint(action_shape, size=len(ids))
            else:
                act = policy(Batch(obs=obs, info={})).act
            obs, rew, terminated, truncated, info = env.step(act, ids)
            done = np.logical_or(terminated, truncated)
            # Re-read env ids from the step info so ``rew`` is credited to the
            # envs the batch actually corresponds to.
            ids = cast(Any, np.asarray(info["env_id"]))
            reward[ids] += rew
            # Drop finished envs from the next step's batch.
            obs = obs[~done]
            ids = ids[~done]
            if len(ids) == 0:
                break
            # if cv2 is not None:
            #   obs_all = np.zeros((84, 84 * num_envs, 3), np.uint8)
            #   for i, j in enumerate(ids):
            #     obs_all[:, 84 * j:84 * (j + 1)] = obs[i, 1:].transpose(1, 2, 0)
            #   cv2.imwrite(f"/tmp/{task}-{t}.png", obs_all)

        # Note: ``rew`` (per-step array) is deliberately rebound to the scalar
        # mean here; the per-step value is no longer needed.
        rew = reward.mean()
        logging.info(f"Mean reward of {task}: {rew}")
        if delta is None:
            self.assertAlmostEqual(rew, target_reward)
        else:
            self.assertAlmostEqual(rew, target_reward, delta=delta)

    def test_pong(self) -> None:
        """Pong checkpoint must reproduce its historical mean reward."""
        model_path = os.path.join("envpool", "atari", "policy-pong.pth")
        self.assertTrue(os.path.exists(model_path))
        if sys.platform == "darwin":
            # Apple Silicon torch inference lands a tick below the historical
            # Linux golden, but remains stable across reruns.
            self.eval_qrdqn("pong", model_path, target_reward=20.5, delta=0.5)
        else:
            self.eval_qrdqn("pong", model_path, target_reward=20.6)

    def test_breakout(self) -> None:
        """Breakout checkpoint must reproduce its historical mean reward."""
        model_path = os.path.join("envpool", "atari", "policy-breakout.pth")
        self.assertTrue(os.path.exists(model_path))
        if sys.platform == "darwin":
            # Same macOS tolerance rationale as test_pong above.
            self.eval_qrdqn(
                "breakout",
                model_path,
                target_reward=365.1,
                delta=5.0,
            )
        else:
            self.eval_qrdqn("breakout", model_path, target_reward=367.8)
     112                 :             : 
     113                 :             : 
# Allow running this test file directly (e.g. `python atari_pretrain_test.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    absltest.main()
        

Generated by: LCOV version 2.0-1