Skip to content

Commit 2948a80

Browse files
author
jax authors
committed
Merge pull request #20217 from froystig:jex-primitives
PiperOrigin-RevId: 624323399
2 parents 0017320 + 65034b3 commit 2948a80

File tree

3 files changed

+240
-3
lines changed

3 files changed

+240
-3
lines changed

jax/extend/BUILD

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
load(
1616
"//jaxlib:jax.bzl",
17+
"py_library_providing_imports_info",
1718
"pytype_strict_library",
1819
)
1920

@@ -34,10 +35,15 @@ pytype_strict_library(
3435
],
3536
)
3637

37-
pytype_strict_library(
38+
py_library_providing_imports_info(
3839
name = "core",
39-
srcs = ["core.py"],
40-
deps = ["//jax:abstract_arrays"],
40+
srcs = glob(["core/**/*.py"]),
41+
deps = [
42+
"//jax",
43+
"//jax:abstract_arrays",
44+
"//jax:ad_util",
45+
"//jax:core",
46+
],
4147
)
4248

4349
pytype_strict_library(

jax/extend/core.py renamed to jax/extend/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@
1818
from jax._src.abstract_arrays import (
1919
array_types as array_types
2020
)
21+
22+
from . import primitives as primitives

jax/extend/core/primitives.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: import <name> as <name> is required for names to be exported.
16+
# See PEP 484 & https://github.com/google/jax/issues/7570
17+
18+
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
19+
20+
from jax._src.core import (
21+
call_p as call_p,
22+
closed_call_p as closed_call_p
23+
)
24+
25+
from jax._src.custom_derivatives import (
26+
custom_jvp_call_p as custom_jvp_call_p,
27+
custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p,
28+
custom_vjp_call_p as custom_vjp_call_p,
29+
custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p,
30+
)
31+
32+
from jax._src.dispatch import device_put_p as device_put_p
33+
34+
from jax._src.interpreters.ad import (
35+
add_jaxvals_p as add_jaxvals_p,
36+
custom_lin_p as custom_lin_p,
37+
zeros_like_p as zeros_like_p,
38+
)
39+
40+
from jax._src.interpreters.pxla import xla_pmap_p as xla_pmap_p
41+
42+
from jax._src.lax.lax import (
43+
abs_p as abs_p,
44+
acos_p as acos_p,
45+
acosh_p as acosh_p,
46+
add_p as add_p,
47+
after_all_p as after_all_p,
48+
and_p as and_p,
49+
argmax_p as argmax_p,
50+
argmin_p as argmin_p,
51+
asin_p as asin_p,
52+
asinh_p as asinh_p,
53+
atan_p as atan_p,
54+
atan2_p as atan2_p,
55+
atanh_p as atanh_p,
56+
bitcast_convert_type_p as bitcast_convert_type_p,
57+
broadcast_in_dim_p as broadcast_in_dim_p,
58+
cbrt_p as cbrt_p,
59+
ceil_p as ceil_p,
60+
clamp_p as clamp_p,
61+
clz_p as clz_p,
62+
complex_p as complex_p,
63+
concatenate_p as concatenate_p,
64+
conj_p as conj_p,
65+
convert_element_type_p as convert_element_type_p,
66+
copy_p as copy_p,
67+
cos_p as cos_p,
68+
cosh_p as cosh_p,
69+
create_token_p as create_token_p,
70+
div_p as div_p,
71+
dot_general_p as dot_general_p,
72+
eq_p as eq_p,
73+
eq_to_p as eq_to_p,
74+
exp_p as exp_p,
75+
exp2_p as exp2_p,
76+
expm1_p as expm1_p,
77+
floor_p as floor_p,
78+
ge_p as ge_p,
79+
gt_p as gt_p,
80+
imag_p as imag_p,
81+
infeed_p as infeed_p,
82+
integer_pow_p as integer_pow_p,
83+
iota_p as iota_p,
84+
is_finite_p as is_finite_p,
85+
le_p as le_p,
86+
le_to_p as le_to_p,
87+
log1p_p as log1p_p,
88+
log_p as log_p,
89+
logistic_p as logistic_p,
90+
lt_p as lt_p,
91+
lt_to_p as lt_to_p,
92+
max_p as max_p,
93+
min_p as min_p,
94+
mul_p as mul_p,
95+
ne_p as ne_p,
96+
neg_p as neg_p,
97+
nextafter_p as nextafter_p,
98+
not_p as not_p,
99+
or_p as or_p,
100+
outfeed_p as outfeed_p,
101+
pad_p as pad_p,
102+
population_count_p as population_count_p,
103+
pow_p as pow_p,
104+
real_p as real_p,
105+
reduce_and_p as reduce_and_p,
106+
reduce_max_p as reduce_max_p,
107+
reduce_min_p as reduce_min_p,
108+
reduce_or_p as reduce_or_p,
109+
reduce_p as reduce_p,
110+
reduce_precision_p as reduce_precision_p,
111+
reduce_prod_p as reduce_prod_p,
112+
reduce_sum_p as reduce_sum_p,
113+
reduce_xor_p as reduce_xor_p,
114+
rem_p as rem_p,
115+
reshape_p as reshape_p,
116+
rev_p as rev_p,
117+
rng_bit_generator_p as rng_bit_generator_p,
118+
rng_uniform_p as rng_uniform_p,
119+
round_p as round_p,
120+
rsqrt_p as rsqrt_p,
121+
select_n_p as select_n_p,
122+
shift_left_p as shift_left_p,
123+
shift_right_arithmetic_p as shift_right_arithmetic_p,
124+
shift_right_logical_p as shift_right_logical_p,
125+
sign_p as sign_p,
126+
sin_p as sin_p,
127+
sinh_p as sinh_p,
128+
sort_p as sort_p,
129+
sqrt_p as sqrt_p,
130+
squeeze_p as squeeze_p,
131+
sub_p as sub_p,
132+
tan_p as tan_p,
133+
tanh_p as tanh_p,
134+
top_k_p as top_k_p,
135+
transpose_p as transpose_p,
136+
xor_p as xor_p,
137+
)
138+
139+
from jax._src.lax.special import (
140+
bessel_i0e_p as bessel_i0e_p,
141+
bessel_i1e_p as bessel_i1e_p,
142+
digamma_p as digamma_p,
143+
erfc_p as erfc_p,
144+
erf_inv_p as erf_inv_p,
145+
erf_p as erf_p,
146+
igammac_p as igammac_p,
147+
igamma_grad_a_p as igamma_grad_a_p,
148+
igamma_p as igamma_p,
149+
lgamma_p as lgamma_p,
150+
polygamma_p as polygamma_p,
151+
random_gamma_grad_p as random_gamma_grad_p,
152+
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
153+
zeta_p as zeta_p,
154+
)
155+
156+
from jax._src.lax.slicing import (
157+
dynamic_slice_p as dynamic_slice_p,
158+
dynamic_update_slice_p as dynamic_update_slice_p,
159+
gather_p as gather_p,
160+
scatter_add_p as scatter_add_p,
161+
scatter_max_p as scatter_max_p,
162+
scatter_min_p as scatter_min_p,
163+
scatter_mul_p as scatter_mul_p,
164+
scatter_p as scatter_p,
165+
slice_p as slice_p,
166+
)
167+
168+
from jax._src.lax.convolution import (
169+
conv_general_dilated_p as conv_general_dilated_p,
170+
)
171+
172+
from jax._src.lax.windowed_reductions import (
173+
reduce_window_max_p as reduce_window_max_p,
174+
reduce_window_min_p as reduce_window_min_p,
175+
reduce_window_p as reduce_window_p,
176+
reduce_window_sum_p as reduce_window_sum_p,
177+
select_and_gather_add_p as select_and_gather_add_p,
178+
select_and_scatter_p as select_and_scatter_p,
179+
select_and_scatter_add_p as select_and_scatter_add_p,
180+
)
181+
182+
from jax._src.lax.control_flow import (
183+
cond_p as cond_p,
184+
cumlogsumexp_p as cumlogsumexp_p,
185+
cummax_p as cummax_p,
186+
cummin_p as cummin_p,
187+
cumprod_p as cumprod_p,
188+
cumsum_p as cumsum_p,
189+
linear_solve_p as linear_solve_p,
190+
scan_p as scan_p,
191+
while_p as while_p,
192+
)
193+
194+
from jax._src.lax.fft import (
195+
fft_p as fft_p,
196+
)
197+
198+
from jax._src.lax.parallel import (
199+
all_gather_p as all_gather_p,
200+
all_to_all_p as all_to_all_p,
201+
axis_index_p as axis_index_p,
202+
pmax_p as pmax_p,
203+
pmin_p as pmin_p,
204+
ppermute_p as ppermute_p,
205+
psum_p as psum_p,
206+
)
207+
208+
from jax._src.lax.ann import (
209+
approx_top_k_p as approx_top_k_p
210+
)
211+
212+
from jax._src.lax.linalg import (
213+
cholesky_p as cholesky_p,
214+
eig_p as eig_p,
215+
eigh_p as eigh_p,
216+
hessenberg_p as hessenberg_p,
217+
lu_p as lu_p,
218+
householder_product_p as householder_product_p,
219+
qr_p as qr_p,
220+
svd_p as svd_p,
221+
triangular_solve_p as triangular_solve_p,
222+
tridiagonal_p as tridiagonal_p,
223+
tridiagonal_solve_p as tridiagonal_solve_p,
224+
schur_p as schur_p,
225+
)
226+
227+
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
228+
from jax._src.prng import threefry2x32_p as threefry2x32_p
229+
from jax._src.random import random_gamma_p as random_gamma_p

0 commit comments

Comments
 (0)