Skip to content

Commit 694df0d

Browse files
chr1sj0nesjax authors
authored andcommitted
Revert change to next_power_of_two.
Reverts 168f30a PiperOrigin-RevId: 615722408
1 parent 4a35c12 commit 694df0d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

jax/_src/pallas/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Pallas utility functions."""
16+
import math
1617
from jax import lax
1718
from jax._src import core as jax_core
1819
from jax._src.util import split_list
@@ -44,7 +45,9 @@ def strides_from_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
4445

4546

4647
def next_power_of_2(x: int) -> int:
47-
return 2**x.bit_length()
48+
if x == 0:
49+
return 1
50+
return int(2 ** math.ceil(math.log2(x)))
4851

4952

5053
def pattern_match_scan_to_fori_loop(

0 commit comments

Comments
 (0)