LCOV - code coverage report
Current view: top level - envpool/python - xla_template.py (source / functions) Coverage Total Hit
Test: EnvPool coverage report Lines: 100.0 % 36 36
Test Date: 2026-04-07 08:10:29 Functions: - 0 0
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 63.6 % 11 7

             Branch data     Line data    Source code
       1                 :             : # Copyright 2022 Garena Online Private Limited
       2                 :             : #
       3                 :             : # Licensed under the Apache License, Version 2.0 (the "License");
       4                 :             : # you may not use this file except in compliance with the License.
       5                 :             : # You may obtain a copy of the License at
       6                 :             : #
       7                 :             : #      http://www.apache.org/licenses/LICENSE-2.0
       8                 :             : #
       9                 :             : # Unless required by applicable law or agreed to in writing, software
      10                 :             : # distributed under the License is distributed on an "AS IS" BASIS,
      11                 :             : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12                 :             : # See the License for the specific language governing permissions and
      13                 :             : # limitations under the License.
      14                 :          35 : """xla template on python side."""
      15                 :             : 
      16                 :          35 : import sys
      17                 :          35 : from collections import namedtuple
      18                 :          35 : from typing import Any, Callable
      19                 :             : 
      20                 :          35 : import numpy as np
      21                 :          35 : from jax import ShapeDtypeStruct, dtypes, ffi
      22                 :             : 
      23                 :             : 
      24                 :          35 : def _normalize_specs(
      25                 :             :     specs: tuple[tuple[Any, list[int]], ...],
      26                 :             : ) -> tuple[tuple[tuple[int, ...], Any], ...]:
      27            [ + ]:           1 :     return tuple(
      28                 :             :         (tuple(shape), dtypes.canonicalize_dtype(dtype))
      29                 :             :         for dtype, shape in specs
      30                 :             :     )
      31                 :             : 
      32                 :             : 
      33                 :          35 : def _shape_dtype_struct(shape: tuple[int, ...], dtype: Any) -> ShapeDtypeStruct:
      34                 :           1 :     return ShapeDtypeStruct(shape, dtype)
      35                 :             : 
      36                 :             : 
      37                 :          35 : def _layout(shape: tuple[int, ...]) -> tuple[int, ...]:
      38                 :           1 :     return tuple(range(len(shape)))
      39                 :             : 
      40                 :             : 
      41                 :          35 : def _make_xla_function(
      42                 :             :     obj: Any,
      43                 :             :     handle: bytes,
      44                 :             :     name: str,
      45                 :             :     specs: tuple[tuple[Any, ...], tuple[Any, ...]],
      46                 :             :     capsules: tuple[Any, Any],
      47                 :             : ) -> Callable:
      48                 :           1 :     in_specs, out_specs = specs
      49                 :           1 :     in_specs = _normalize_specs(in_specs)
      50                 :           1 :     out_specs = _normalize_specs(out_specs)
      51                 :           1 :     cpu_capsule, gpu_capsule = capsules
      52                 :           1 :     call_target_name = f"{type(obj).__name__}_{id(obj)}_{name}"
      53                 :           1 :     ffi.register_ffi_target(
      54                 :             :         call_target_name,
      55                 :             :         cpu_capsule,
      56                 :             :         platform="cpu",
      57                 :             :         api_version=1,
      58                 :             :     )
      59                 :           1 :     if gpu_capsule is not None:
      60            [ # ]:           1 :         ffi.register_ffi_target(
      61                 :             :             call_target_name,
      62                 :             :             gpu_capsule,
      63                 :             :             platform="gpu",
      64                 :             :             api_version=1,
      65                 :             :         )
      66            [ # ]:           1 :     result_specs = tuple(_shape_dtype_struct(*spec) for spec in out_specs)
      67            [ # ]:           1 :     xla_func = ffi.ffi_call(
      68                 :             :         call_target_name,
      69                 :             :         result_specs if len(result_specs) > 1 else result_specs[0],
      70                 :             :         has_side_effect=True,
      71                 :             :         input_layouts=tuple(_layout(shape) for shape, _ in in_specs),
      72                 :             :         output_layouts=(
      73                 :             :             tuple(_layout(shape) for shape, _ in out_specs)
      74                 :             :             if len(out_specs) > 1
      75                 :             :             else _layout(out_specs[0][0])
      76                 :             :         ),
      77                 :             :         input_output_aliases={0: 0},
      78                 :             :     )
      79            [ # ]:           1 :     handle_value = int.from_bytes(handle, byteorder=sys.byteorder, signed=False)
      80                 :             : 
      81                 :           1 :     def call(*args: Any) -> Any:
      82                 :           1 :         return xla_func(*args, handle=handle_value)
      83                 :             : 
      84                 :           1 :     return call
      85                 :             : 
      86                 :             : 
      87                 :          35 : def make_xla(obj: Any) -> Any:
      88                 :             :     """Return callables that can be jitted in a namedtuple.
      89                 :             : 
      90                 :             :     Args:
      91                 :             :       obj: The object that has a `_xla` function.
      92                 :             :         All instances of envpool has a `_xla` function that returns
      93                 :             :         the necessary information for creating jittable send/recv functions.
      94                 :             : 
      95                 :             :     Returns:
      96                 :             :       XlaFunctions: A namedtuple, the first element is a handle
      97                 :             :         representing `obj`. The rest of the elements are jittable functions.
      98                 :             :     """
      99                 :           1 :     xla_native = obj._xla()
     100                 :           1 :     method_names = []
     101                 :           1 :     methods = []
     102                 :           1 :     for name, (handle, specs, capsules) in xla_native:
     103            [ + ]:           1 :         method_names.append(name)
     104                 :           1 :         methods.append(_make_xla_function(obj, handle, name, specs, capsules))
     105            [ + ]:           1 :     XlaFunctions = namedtuple(  # type: ignore
     106                 :             :         "XlaFunctions", ["handle", *method_names]
     107                 :             :     )
     108                 :           1 :     return XlaFunctions(  # type: ignore
     109                 :             :         np.frombuffer(handle, dtype=np.uint8), *methods
     110                 :             :     )
        

Generated by: LCOV version 2.0-1