Skip to content

Commit 64bd95d

Browse files
chr1sj0nesjax authors
authored andcommitted
Use int.bit_length() in next_power_of_2.
- Added docstring explaining its behaviour. - Check for negative inputs. See https://docs.python.org/3/library/stdtypes.html#int.bit_length. PiperOrigin-RevId: 615731376
1 parent 694df0d commit 64bd95d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

jax/_src/pallas/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Pallas utility functions."""
16-
import math
16+
1717
from jax import lax
1818
from jax._src import core as jax_core
1919
from jax._src.util import split_list
@@ -45,9 +45,10 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
4545

4646

4747
def next_power_of_2(x: int) -> int:
48-
if x == 0:
49-
return 1
50-
return int(2 ** math.ceil(math.log2(x)))
48+
"""Returns the next power of two greater than or equal to `x`."""
49+
if x < 0:
50+
raise ValueError("`next_power_of_2` requires a non-negative integer.")
51+
return 1 if x == 0 else 2 ** (x - 1).bit_length()
5152

5253

5354
def pattern_match_scan_to_fori_loop(

0 commit comments

Comments
 (0)