|
| 1 | +# Copyright 2024 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# Note: import <name> as <name> is required for names to be exported. |
| 16 | +# See PEP 484 & https://github.com/google/jax/issues/7570 |
| 17 | + |
| 18 | +from jax._src.ad_util import stop_gradient_p as stop_gradient_p |
| 19 | + |
| 20 | +from jax._src.core import ( |
| 21 | + call_p as call_p, |
| 22 | + closed_call_p as closed_call_p |
| 23 | +) |
| 24 | + |
| 25 | +from jax._src.custom_derivatives import ( |
| 26 | + custom_jvp_call_p as custom_jvp_call_p, |
| 27 | + custom_jvp_call_jaxpr_p as custom_jvp_call_jaxpr_p, |
| 28 | + custom_vjp_call_p as custom_vjp_call_p, |
| 29 | + custom_vjp_call_jaxpr_p as custom_vjp_call_jaxpr_p, |
| 30 | +) |
| 31 | + |
| 32 | +from jax._src.dispatch import device_put_p as device_put_p |
| 33 | + |
| 34 | +from jax._src.interpreters.ad import ( |
| 35 | + add_jaxvals_p as add_jaxvals_p, |
| 36 | + custom_lin_p as custom_lin_p, |
| 37 | + zeros_like_p as zeros_like_p, |
| 38 | +) |
| 39 | + |
| 40 | +from jax._src.interpreters.pxla import xla_pmap_p as xla_pmap_p |
| 41 | + |
| 42 | +from jax._src.lax.lax import ( |
| 43 | + abs_p as abs_p, |
| 44 | + acos_p as acos_p, |
| 45 | + acosh_p as acosh_p, |
| 46 | + add_p as add_p, |
| 47 | + after_all_p as after_all_p, |
| 48 | + and_p as and_p, |
| 49 | + argmax_p as argmax_p, |
| 50 | + argmin_p as argmin_p, |
| 51 | + asin_p as asin_p, |
| 52 | + asinh_p as asinh_p, |
| 53 | + atan_p as atan_p, |
| 54 | + atan2_p as atan2_p, |
| 55 | + atanh_p as atanh_p, |
| 56 | + bitcast_convert_type_p as bitcast_convert_type_p, |
| 57 | + broadcast_in_dim_p as broadcast_in_dim_p, |
| 58 | + cbrt_p as cbrt_p, |
| 59 | + ceil_p as ceil_p, |
| 60 | + clamp_p as clamp_p, |
| 61 | + clz_p as clz_p, |
| 62 | + complex_p as complex_p, |
| 63 | + concatenate_p as concatenate_p, |
| 64 | + conj_p as conj_p, |
| 65 | + convert_element_type_p as convert_element_type_p, |
| 66 | + copy_p as copy_p, |
| 67 | + cos_p as cos_p, |
| 68 | + cosh_p as cosh_p, |
| 69 | + create_token_p as create_token_p, |
| 70 | + div_p as div_p, |
| 71 | + dot_general_p as dot_general_p, |
| 72 | + eq_p as eq_p, |
| 73 | + eq_to_p as eq_to_p, |
| 74 | + exp_p as exp_p, |
| 75 | + exp2_p as exp2_p, |
| 76 | + expm1_p as expm1_p, |
| 77 | + floor_p as floor_p, |
| 78 | + ge_p as ge_p, |
| 79 | + gt_p as gt_p, |
| 80 | + imag_p as imag_p, |
| 81 | + infeed_p as infeed_p, |
| 82 | + integer_pow_p as integer_pow_p, |
| 83 | + iota_p as iota_p, |
| 84 | + is_finite_p as is_finite_p, |
| 85 | + le_p as le_p, |
| 86 | + le_to_p as le_to_p, |
| 87 | + log1p_p as log1p_p, |
| 88 | + log_p as log_p, |
| 89 | + logistic_p as logistic_p, |
| 90 | + lt_p as lt_p, |
| 91 | + lt_to_p as lt_to_p, |
| 92 | + max_p as max_p, |
| 93 | + min_p as min_p, |
| 94 | + mul_p as mul_p, |
| 95 | + ne_p as ne_p, |
| 96 | + neg_p as neg_p, |
| 97 | + nextafter_p as nextafter_p, |
| 98 | + not_p as not_p, |
| 99 | + or_p as or_p, |
| 100 | + outfeed_p as outfeed_p, |
| 101 | + pad_p as pad_p, |
| 102 | + population_count_p as population_count_p, |
| 103 | + pow_p as pow_p, |
| 104 | + real_p as real_p, |
| 105 | + reduce_and_p as reduce_and_p, |
| 106 | + reduce_max_p as reduce_max_p, |
| 107 | + reduce_min_p as reduce_min_p, |
| 108 | + reduce_or_p as reduce_or_p, |
| 109 | + reduce_p as reduce_p, |
| 110 | + reduce_precision_p as reduce_precision_p, |
| 111 | + reduce_prod_p as reduce_prod_p, |
| 112 | + reduce_sum_p as reduce_sum_p, |
| 113 | + reduce_xor_p as reduce_xor_p, |
| 114 | + rem_p as rem_p, |
| 115 | + reshape_p as reshape_p, |
| 116 | + rev_p as rev_p, |
| 117 | + rng_bit_generator_p as rng_bit_generator_p, |
| 118 | + rng_uniform_p as rng_uniform_p, |
| 119 | + round_p as round_p, |
| 120 | + rsqrt_p as rsqrt_p, |
| 121 | + select_n_p as select_n_p, |
| 122 | + shift_left_p as shift_left_p, |
| 123 | + shift_right_arithmetic_p as shift_right_arithmetic_p, |
| 124 | + shift_right_logical_p as shift_right_logical_p, |
| 125 | + sign_p as sign_p, |
| 126 | + sin_p as sin_p, |
| 127 | + sinh_p as sinh_p, |
| 128 | + sort_p as sort_p, |
| 129 | + sqrt_p as sqrt_p, |
| 130 | + squeeze_p as squeeze_p, |
| 131 | + sub_p as sub_p, |
| 132 | + tan_p as tan_p, |
| 133 | + tanh_p as tanh_p, |
| 134 | + top_k_p as top_k_p, |
| 135 | + transpose_p as transpose_p, |
| 136 | + xor_p as xor_p, |
| 137 | +) |
| 138 | + |
| 139 | +from jax._src.lax.special import ( |
| 140 | + bessel_i0e_p as bessel_i0e_p, |
| 141 | + bessel_i1e_p as bessel_i1e_p, |
| 142 | + digamma_p as digamma_p, |
| 143 | + erfc_p as erfc_p, |
| 144 | + erf_inv_p as erf_inv_p, |
| 145 | + erf_p as erf_p, |
| 146 | + igammac_p as igammac_p, |
| 147 | + igamma_grad_a_p as igamma_grad_a_p, |
| 148 | + igamma_p as igamma_p, |
| 149 | + lgamma_p as lgamma_p, |
| 150 | + polygamma_p as polygamma_p, |
| 151 | + random_gamma_grad_p as random_gamma_grad_p, |
| 152 | + regularized_incomplete_beta_p as regularized_incomplete_beta_p, |
| 153 | + zeta_p as zeta_p, |
| 154 | +) |
| 155 | + |
| 156 | +from jax._src.lax.slicing import ( |
| 157 | + dynamic_slice_p as dynamic_slice_p, |
| 158 | + dynamic_update_slice_p as dynamic_update_slice_p, |
| 159 | + gather_p as gather_p, |
| 160 | + scatter_add_p as scatter_add_p, |
| 161 | + scatter_max_p as scatter_max_p, |
| 162 | + scatter_min_p as scatter_min_p, |
| 163 | + scatter_mul_p as scatter_mul_p, |
| 164 | + scatter_p as scatter_p, |
| 165 | + slice_p as slice_p, |
| 166 | +) |
| 167 | + |
| 168 | +from jax._src.lax.convolution import ( |
| 169 | + conv_general_dilated_p as conv_general_dilated_p, |
| 170 | +) |
| 171 | + |
| 172 | +from jax._src.lax.windowed_reductions import ( |
| 173 | + reduce_window_max_p as reduce_window_max_p, |
| 174 | + reduce_window_min_p as reduce_window_min_p, |
| 175 | + reduce_window_p as reduce_window_p, |
| 176 | + reduce_window_sum_p as reduce_window_sum_p, |
| 177 | + select_and_gather_add_p as select_and_gather_add_p, |
| 178 | + select_and_scatter_p as select_and_scatter_p, |
| 179 | + select_and_scatter_add_p as select_and_scatter_add_p, |
| 180 | +) |
| 181 | + |
| 182 | +from jax._src.lax.control_flow import ( |
| 183 | + cond_p as cond_p, |
| 184 | + cumlogsumexp_p as cumlogsumexp_p, |
| 185 | + cummax_p as cummax_p, |
| 186 | + cummin_p as cummin_p, |
| 187 | + cumprod_p as cumprod_p, |
| 188 | + cumsum_p as cumsum_p, |
| 189 | + linear_solve_p as linear_solve_p, |
| 190 | + scan_p as scan_p, |
| 191 | + while_p as while_p, |
| 192 | +) |
| 193 | + |
| 194 | +from jax._src.lax.fft import ( |
| 195 | + fft_p as fft_p, |
| 196 | +) |
| 197 | + |
| 198 | +from jax._src.lax.parallel import ( |
| 199 | + all_gather_p as all_gather_p, |
| 200 | + all_to_all_p as all_to_all_p, |
| 201 | + axis_index_p as axis_index_p, |
| 202 | + pmax_p as pmax_p, |
| 203 | + pmin_p as pmin_p, |
| 204 | + ppermute_p as ppermute_p, |
| 205 | + psum_p as psum_p, |
| 206 | +) |
| 207 | + |
| 208 | +from jax._src.lax.ann import ( |
| 209 | + approx_top_k_p as approx_top_k_p |
| 210 | +) |
| 211 | + |
| 212 | +from jax._src.lax.linalg import ( |
| 213 | + cholesky_p as cholesky_p, |
| 214 | + eig_p as eig_p, |
| 215 | + eigh_p as eigh_p, |
| 216 | + hessenberg_p as hessenberg_p, |
| 217 | + lu_p as lu_p, |
| 218 | + householder_product_p as householder_product_p, |
| 219 | + qr_p as qr_p, |
| 220 | + svd_p as svd_p, |
| 221 | + triangular_solve_p as triangular_solve_p, |
| 222 | + tridiagonal_p as tridiagonal_p, |
| 223 | + tridiagonal_solve_p as tridiagonal_solve_p, |
| 224 | + schur_p as schur_p, |
| 225 | +) |
| 226 | + |
| 227 | +from jax._src.pjit import sharding_constraint_p as sharding_constraint_p |
| 228 | +from jax._src.prng import threefry2x32_p as threefry2x32_p |
| 229 | +from jax._src.random import random_gamma_p as random_gamma_p |
0 commit comments