@@ -484,3 +484,47 @@ def test_quantized_linear_sidecar_patches(
484
484
output_linear_patched = linear_layer_custom (input )
485
485
output_quantized_patched = quantized_linear_layer_custom (input )
486
486
assert torch .allclose (output_linear_patched , output_quantized_patched , rtol = 0.2 , atol = 0.2 )
487
@parameterize_cuda_and_mps
def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(
    device: str,
    quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module],
    patch_under_test: PatchUnderTest,
):
    """Test that the output of a quantized linear layer with sidecar patches is the same when
    the layer is on the device as when the layer lives on the CPU and its weights/patches are
    autocast to the device at inference time.
    """
    patches, input_tensor = patch_under_test  # renamed from `input` to avoid shadowing the builtin

    _, quantized_linear_layer = quantized_linear_layer_under_test

    # Move everything to the device.
    layer_to_device_via_state_dict(quantized_linear_layer, device)
    input_tensor = input_tensor.to(torch.device(device))

    # Wrap the quantized linear layer in a custom layer and add the patches to it.
    quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer)
    for patch, weight in patches:
        patch.to(torch.device(device))
        quantized_linear_layer_custom.add_patch(patch, weight)

    # Run inference with the custom layer fully on the device to get the reference output.
    expected_output = quantized_linear_layer_custom(input_tensor)

    # Move the custom layer's weights back to the CPU.
    layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu")

    # Re-register the patches after moving them to the CPU.
    quantized_linear_layer_custom.clear_patches()
    for patch, weight in patches:
        patch.to(torch.device("cpu"))
        quantized_linear_layer_custom.add_patch(patch, weight)

    # Run inference with the input on the device while all layer weights live on the CPU.
    # The weights should be autocast to the device.
    autocast_output = quantized_linear_layer_custom(input_tensor)
    assert autocast_output.device.type == device

    # Assert that the outputs with and without autocasting are the same.
    assert torch.allclose(expected_output, autocast_output, atol=1e-6)
0 commit comments