@@ -350,54 +350,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
350
350
351
351
if func is aten .slice .Tensor :
352
352
self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
353
- n_by_8 , k_by_inner_tiles , _ , _ = self .packed_weight .shape
354
- sz_dim1 , sz_dim0 , _ = self .scale_and_zero .shape
355
- data_len = self .shape [dim ]
356
- assert dim in [0 , 1 ], (
357
- f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run { func } , with dim={ dim } , that is not supported"
358
- )
359
-
360
- if dim == 0 :
361
- pw_len = n_by_8
362
- sz_len = sz_dim0
353
+ if dim in [0 , 1 ]:
354
+ int_data , scale , zero_point = self .get_plain ()
355
+ data_len = int_data .shape [dim ]
356
+ scale_len = scale .shape [dim ]
357
+ ratio = data_len / scale_len
358
+ start_scale = int (start / ratio )
359
+ end_scale = int (end / ratio )
360
+
361
+ int_data = aten .slice .Tensor (int_data , dim , start , end , step )
362
+ scale = aten .slice .Tensor (scale , dim , start_scale , end_scale , step )
363
+ zero_point = aten .slice .Tensor (
364
+ zero_point , dim , start_scale , end_scale , step
365
+ )
366
+ # this is to handle padding
367
+ int_data , scale , zero_point = self ._layout .post_process (
368
+ int_data , scale , zero_point , self .block_size
369
+ )
370
+ sliced = self .from_plain (int_data , scale , zero_point , self ._layout )
371
+ return return_and_correct_aliasing (func , args , kwargs , sliced )
363
372
else :
364
- pw_len = k_by_inner_tiles
365
- sz_len = sz_dim1
366
-
367
- if pw_len == 0 or sz_len == 0 :
368
- return return_and_correct_aliasing (
369
- func ,
370
- args ,
371
- kwargs ,
372
- TensorCoreTiledAQTTensorImpl (
373
- self .packed_weight ,
374
- self .scale_and_zero ,
375
- self .transposed ,
376
- self ._layout ,
377
- ),
373
+ raise NotImplementedError (
374
+ f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run { func } , with dim={ dim } , that is not supported"
378
375
)
379
376
380
- pw_ratio = data_len / pw_len
381
- start_pw = int (start / pw_ratio )
382
- end_pw = int (end / pw_ratio )
383
-
384
- sz_ratio = data_len / sz_len
385
- start_sz = int (start / sz_ratio )
386
- end_sz = int (end / sz_ratio )
387
-
388
- packed_weight = aten .slice (self .packed_weight , dim , start_pw , end_pw , step )
389
- scale_and_zero = aten .slice (
390
- self .scale_and_zero , 1 - dim , start_sz , end_sz , step
391
- )
392
- return return_and_correct_aliasing (
393
- func ,
394
- args ,
395
- kwargs ,
396
- TensorCoreTiledAQTTensorImpl (
397
- packed_weight , scale_and_zero , self .transposed , self ._layout
398
- ),
399
- )
400
-
401
377
raise NotImplementedError (
402
378
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run { func } , this is not supported"
403
379
)
0 commit comments