@@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
401
401
return torch .tensor (data , dtype = torch .float32 )
402
402
403
403
404
- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
405
- def create_quantile_map (A , total_bits = 8 ):
406
- q = estimate_quantiles (A , num_quantiles = 2 ** total_bits - 1 )
407
- q = q .tolist ()
408
- q .append (0 )
409
-
410
- gap = 256 - len (q )
411
- for i in range (gap ):
412
- q .append (0 )
413
-
414
- q .sort ()
415
-
416
- q = Tensor (q )
417
- q = q / q .abs ().max ()
418
- return q
419
-
420
-
421
404
def is_on_gpu (tensors : Iterable [Optional [torch .Tensor ]]):
422
405
"""Verifies that the input tensors are all on the same device.
423
406
@@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
474
457
return ct .c_void_p (A .data_ptr ())
475
458
476
459
477
- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
478
- def estimate_quantiles (
479
- A : Tensor ,
480
- out : Optional [torch .Tensor ] = None ,
481
- offset : float = 1 / 512 ,
482
- num_quantiles = 256 ,
483
- ) -> Tensor :
484
- """
485
- Estimates 256 equidistant quantiles on the input tensor eCDF.
486
-
487
- Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
488
- via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
489
- and the extreme quantiles close to 0 and 1 have high variance / large estimation
490
- errors. These large errors can be avoided by using the offset variable which trims
491
- the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
492
- trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02
493
- usually has a much lower error but is not a minimum entropy encoding. Given an offset
494
- of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles.
495
-
496
- Parameters
497
- ----------
498
- A : torch.Tensor
499
- The input tensor. Any shape.
500
- out : torch.Tensor
501
- Tensor with the 256 estimated quantiles.
502
- offset : float
503
- The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
504
- num_quantiles : int
505
- The number of equally spaced quantiles.
506
-
507
- Returns
508
- -------
509
- torch.Tensor:
510
- The 256 quantiles in float32 datatype.
511
- """
512
- if A .numel () < 256 :
513
- raise NotImplementedError (
514
- f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only { A .numel ()} values." ,
515
- )
516
- if num_quantiles > 256 :
517
- raise NotImplementedError (
518
- f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={ num_quantiles } " ,
519
- )
520
- if num_quantiles < 256 and offset == 1 / (512 ):
521
- # override default arguments
522
- offset = 1 / (2 * num_quantiles )
523
-
524
- if out is None :
525
- out = torch .zeros ((256 ,), dtype = torch .float32 , device = A .device )
526
-
527
- with _cuda_device_of (A ):
528
- is_on_gpu ([A , out ])
529
-
530
- if A .dtype == torch .float32 :
531
- lib .cestimate_quantiles_fp32 (get_ptr (A ), get_ptr (out ), ct .c_float (offset ), ct .c_int (A .numel ()))
532
- elif A .dtype == torch .float16 :
533
- lib .cestimate_quantiles_fp16 (get_ptr (A ), get_ptr (out ), ct .c_float (offset ), ct .c_int (A .numel ()))
534
- else :
535
- raise NotImplementedError (f"Not supported data type { A .dtype } " )
536
-
537
- if num_quantiles < 256 :
538
- step = round (256 / num_quantiles )
539
- idx = torch .linspace (0 , 255 , num_quantiles ).long ().to (A .device )
540
- out = out [idx ]
541
-
542
- return out
543
-
544
-
545
460
class QuantState :
546
461
"""container for quantization state components to work with Params4bit and similar classes"""
547
462
@@ -1601,25 +1516,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
1601
1516
return current_gnorm , clip_value , gnorm_scale
1602
1517
1603
1518
1604
- @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning )
1605
- def histogram_scatter_add_2d (histogram : Tensor , index1 : Tensor , index2 : Tensor , source : Tensor ):
1606
- assert len (histogram .shape ) == 2
1607
- assert histogram .dtype == torch .float32
1608
- assert source .dtype == torch .float32
1609
- assert index1 .dtype == torch .int32
1610
- assert index2 .dtype == torch .int32
1611
-
1612
- assert histogram .device .type == "cuda"
1613
- assert index1 .device .type == "cuda"
1614
- assert index2 .device .type == "cuda"
1615
- assert source .device .type == "cuda"
1616
-
1617
- maxdim1 = ct .c_int32 (histogram .shape [0 ])
1618
- n = ct .c_int32 (index1 .numel ())
1619
- is_on_gpu ([histogram , index1 , index2 , source ])
1620
- lib .chistogram_scatter_add_2d (get_ptr (histogram ), get_ptr (index1 ), get_ptr (index2 ), get_ptr (source ), maxdim1 , n )
1621
-
1622
-
1623
1519
def check_matmul (A , B , out , transposed_A , transposed_B , expected_type = torch .int8 ):
1624
1520
if not torch .cuda .is_initialized ():
1625
1521
torch .cuda .init ()
@@ -2426,118 +2322,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
2426
2322
C = 127.0
2427
2323
2428
2324
2429
- @deprecated (
2430
- "This function is deprecated and will be removed in a future release. "
2431
- "Consider using `int8_vectorwise_quant` instead." ,
2432
- category = FutureWarning ,
2433
- )
2434
- def vectorwise_quant (x , dim = 1 , quant_type = "vector" ):
2435
- if quant_type == "linear" :
2436
- max1 = torch .abs (x ).max ().float ()
2437
- xq = torch .round (x / max1 * 127 ).to (torch .int8 )
2438
- return xq , max1
2439
- elif quant_type in ["vector" , "row" ]:
2440
- max1 = torch .amax (torch .abs (x ), dim = dim , keepdim = True )
2441
- xq = torch .round (x * (C / max1 )).to (torch .int8 )
2442
- return xq , max1
2443
- elif quant_type == "zeropoint" :
2444
- dtype = x .dtype
2445
- x = x .float ()
2446
- dyna = x .max () - x .min ()
2447
- if dyna == 0 :
2448
- dyna = 1
2449
- qx = 255.0 / dyna
2450
- minx = x .min ()
2451
- zpx = torch .round (minx * qx )
2452
- x = torch .round (qx * x - zpx ) + zpx
2453
- return x , qx
2454
- elif quant_type in ["vector-zeropoint" , "row-zeropoint" ]:
2455
- dtype = x .dtype
2456
- x = x .float ()
2457
- dyna = torch .amax (x , dim = dim , keepdim = True ) - torch .amin (x , dim = dim , keepdim = True )
2458
- dyna [dyna == 0 ] = 1
2459
- qx = 255.0 / dyna
2460
- minx = torch .amin (x , dim = dim , keepdim = True )
2461
- zpx = torch .round (minx * qx )
2462
- x = torch .round (qx * x - zpx ) + zpx
2463
- return x , qx
2464
- elif quant_type == "truncated-vector" :
2465
- with torch .no_grad ():
2466
- absx = torch .abs (x )
2467
- max1 = torch .amax (absx , dim = dim , keepdim = True )
2468
- max1 = max1 * 0.7
2469
- idx = absx > max1 .expand_as (absx )
2470
- sign = torch .sign (x [idx ])
2471
- x [idx ] = max1 .expand_as (absx )[idx ] * sign
2472
- xq = torch .round (x / max1 * C ).to (torch .int8 )
2473
- return xq , max1
2474
- else :
2475
- return None
2476
-
2477
-
2478
- @deprecated (
2479
- "This function is deprecated and will be removed in a future release." ,
2480
- category = FutureWarning ,
2481
- )
2482
- def vectorwise_mm_dequant (xq , S1 , S2 , dtype = torch .half , quant_type = "vector" ):
2483
- if quant_type == "linear" :
2484
- norm = S1 * S2 / (C * C )
2485
- # double cast needed to prevent overflows
2486
- return (xq .float () * norm ).to (dtype )
2487
- elif quant_type == "zeropoint" :
2488
- norm = 1.0 / (S1 * S2 )
2489
- return (xq .float () * norm ).to (dtype )
2490
- elif quant_type == "row-zeropoint" :
2491
- norm = 1.0 / (S1 * S2 )
2492
- x = xq .float ()
2493
- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2494
- S1 = S1 .squeeze (0 )
2495
- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2496
- S2 = S2 .squeeze (0 )
2497
- if len (S1 .shape ) == 2 :
2498
- x *= norm
2499
- else :
2500
- x *= norm
2501
- return x .to (dtype )
2502
- elif quant_type == "vector-zeropoint" :
2503
- x = xq .float ()
2504
- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2505
- S1 = S1 .squeeze (0 )
2506
- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2507
- S2 = S2 .squeeze (0 )
2508
- if len (S1 .shape ) == 2 :
2509
- x *= 1.0 / S1
2510
- else :
2511
- x *= 1.0 / S1
2512
- x *= 1.0 / S2 .t ()
2513
- return x .to (dtype )
2514
- elif quant_type == "row" :
2515
- x = xq .float ()
2516
- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2517
- S1 = S1 .squeeze (0 )
2518
- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2519
- S2 = S2 .squeeze (0 )
2520
- if len (S1 .shape ) == 2 :
2521
- x *= S1 * S2 / (C * C )
2522
- else :
2523
- x *= S1 * S2 / (C * C )
2524
- return x .to (dtype )
2525
- elif quant_type in ["truncated-vector" , "vector" ]:
2526
- x = xq .float ()
2527
- if len (S1 .shape ) == 3 and len (x .shape ) == 2 :
2528
- S1 = S1 .squeeze (0 )
2529
- if len (S2 .shape ) == 3 and len (x .shape ) == 2 :
2530
- S2 = S2 .squeeze (0 )
2531
- if len (S1 .shape ) == 2 :
2532
- x *= S1 / C
2533
- else :
2534
- x *= S1 / C
2535
- x *= S2 / C
2536
- return x .to (dtype )
2537
- else :
2538
- return None
2539
-
2540
-
2541
2325
def _enable_ipex_fusion (linear : torch .nn .Module , x : torch .Tensor ):
2542
2326
quant_state = linear .weight .quant_state
2543
2327
0 commit comments