Branch data Line data Source code
1 : : # Copyright 2022 Garena Online Private Limited
2 : : #
3 : : # Licensed under the Apache License, Version 2.0 (the "License");
4 : : # you may not use this file except in compliance with the License.
5 : : # You may obtain a copy of the License at
6 : : #
7 : : # http://www.apache.org/licenses/LICENSE-2.0
8 : : #
9 : : # Unless required by applicable law or agreed to in writing, software
10 : : # distributed under the License is distributed on an "AS IS" BASIS,
11 : : # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : : # See the License for the specific language governing permissions and
13 : : # limitations under the License.
14 : 35 : """xla template on python side."""
15 : :
16 : 35 : import sys
17 : 35 : from collections import namedtuple
18 : 35 : from typing import Any, Callable
19 : :
20 : 35 : import numpy as np
21 : 35 : from jax import ShapeDtypeStruct, dtypes, ffi
22 : :
23 : :
24 : 35 : def _normalize_specs(
25 : : specs: tuple[tuple[Any, list[int]], ...],
26 : : ) -> tuple[tuple[tuple[int, ...], Any], ...]:
27 [ + ]: 1 : return tuple(
28 : : (tuple(shape), dtypes.canonicalize_dtype(dtype))
29 : : for dtype, shape in specs
30 : : )
31 : :
32 : :
33 : 35 : def _shape_dtype_struct(shape: tuple[int, ...], dtype: Any) -> ShapeDtypeStruct:
34 : 1 : return ShapeDtypeStruct(shape, dtype)
35 : :
36 : :
37 : 35 : def _layout(shape: tuple[int, ...]) -> tuple[int, ...]:
38 : 1 : return tuple(range(len(shape)))
39 : :
40 : :
41 : 35 : def _make_xla_function(
42 : : obj: Any,
43 : : handle: bytes,
44 : : name: str,
45 : : specs: tuple[tuple[Any, ...], tuple[Any, ...]],
46 : : capsules: tuple[Any, Any],
47 : : ) -> Callable:
48 : 1 : in_specs, out_specs = specs
49 : 1 : in_specs = _normalize_specs(in_specs)
50 : 1 : out_specs = _normalize_specs(out_specs)
51 : 1 : cpu_capsule, gpu_capsule = capsules
52 : 1 : call_target_name = f"{type(obj).__name__}_{id(obj)}_{name}"
53 : 1 : ffi.register_ffi_target(
54 : : call_target_name,
55 : : cpu_capsule,
56 : : platform="cpu",
57 : : api_version=1,
58 : : )
59 : 1 : if gpu_capsule is not None:
60 [ # ]: 1 : ffi.register_ffi_target(
61 : : call_target_name,
62 : : gpu_capsule,
63 : : platform="gpu",
64 : : api_version=1,
65 : : )
66 [ # ]: 1 : result_specs = tuple(_shape_dtype_struct(*spec) for spec in out_specs)
67 [ # ]: 1 : xla_func = ffi.ffi_call(
68 : : call_target_name,
69 : : result_specs if len(result_specs) > 1 else result_specs[0],
70 : : has_side_effect=True,
71 : : input_layouts=tuple(_layout(shape) for shape, _ in in_specs),
72 : : output_layouts=(
73 : : tuple(_layout(shape) for shape, _ in out_specs)
74 : : if len(out_specs) > 1
75 : : else _layout(out_specs[0][0])
76 : : ),
77 : : input_output_aliases={0: 0},
78 : : )
79 [ # ]: 1 : handle_value = int.from_bytes(handle, byteorder=sys.byteorder, signed=False)
80 : :
81 : 1 : def call(*args: Any) -> Any:
82 : 1 : return xla_func(*args, handle=handle_value)
83 : :
84 : 1 : return call
85 : :
86 : :
87 : 35 : def make_xla(obj: Any) -> Any:
88 : : """Return callables that can be jitted in a namedtuple.
89 : :
90 : : Args:
91 : : obj: The object that has a `_xla` function.
92 : : All instances of envpool has a `_xla` function that returns
93 : : the necessary information for creating jittable send/recv functions.
94 : :
95 : : Returns:
96 : : XlaFunctions: A namedtuple, the first element is a handle
97 : : representing `obj`. The rest of the elements are jittable functions.
98 : : """
99 : 1 : xla_native = obj._xla()
100 : 1 : method_names = []
101 : 1 : methods = []
102 : 1 : for name, (handle, specs, capsules) in xla_native:
103 [ + ]: 1 : method_names.append(name)
104 : 1 : methods.append(_make_xla_function(obj, handle, name, specs, capsules))
105 [ + ]: 1 : XlaFunctions = namedtuple( # type: ignore
106 : : "XlaFunctions", ["handle", *method_names]
107 : : )
108 : 1 : return XlaFunctions( # type: ignore
109 : : np.frombuffer(handle, dtype=np.uint8), *methods
110 : : )
|