Commit a83a999

Reduce peak memory used for unit tests.
1 parent f8a6acc commit a83a999
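
The change is mechanical: each test input shrinks from torch.randn(10, 32) to torch.randn(1, 32), so the input and activation buffers are a tenth the size while the tests exercise the same code paths. A minimal sketch of how the effect could be confirmed, using a plain nn.Linear stand-in (the actual tests use quantized InvokeLinear8bitLt / InvokeLinearNF4 layers):

# Hypothetical check, not part of this commit: compare peak CUDA memory
# for a batch of 10 rows vs. 1 row through a small linear layer.
import torch

def peak_bytes(batch_size: int) -> int:
    layer = torch.nn.Linear(32, 64).to("cuda")
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(batch_size, 32, device="cuda")
    _ = layer(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    print(peak_bytes(10), peak_bytes(1))  # the batch-1 run allocates less at peak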

File tree: 3 files changed (+8, -8 lines)


tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py

Lines changed: 4 additions & 4 deletions

@@ -40,7 +40,7 @@ def linear_8bit_lt_layer():
 def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt):
     """Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
     # Run inference on the original layer.
-    x = torch.randn(10, 32).to("cuda")
+    x = torch.randn(1, 32).to("cuda")
     y_quantized = linear_8bit_lt_layer(x)
 
     # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
@@ -54,7 +54,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer:
 def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt):
     """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
     # Run inference on the original layer.
-    x = torch.randn(10, 32).to("cuda")
+    x = torch.randn(1, 32).to("cuda")
     y_quantized = linear_8bit_lt_layer(x)
 
     # Copy the state dict to the CPU and reload it.
@@ -98,7 +98,7 @@ def linear_nf4_layer():
 def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
     """Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
     # Run inference on the original layer.
-    x = torch.randn(10, 32).to("cuda")
+    x = torch.randn(1, 32).to("cuda")
     y_quantized = linear_nf4_layer(x)
 
     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
@@ -112,7 +112,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi
 def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4):
     """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
     # Run inference on the original layer.
-    x = torch.randn(10, 32).to(device="cuda")
+    x = torch.randn(1, 32).to(device="cuda")
     y_quantized = linear_nf4_layer(x)
 
     # Copy the state dict to the CPU and reload it.

tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py

Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n
     assert all(p.device.type == "cpu" for p in model.parameters())
 
     # Run inference on the CPU.
-    x = torch.randn(10, 32, device="cpu")
+    x = torch.randn(1, 32, device="cpu")
     expected = model(x)
     assert expected.device.type == "cpu"
 
@@ -103,7 +103,7 @@ def test_torch_module_autocast_bnb_llm_int8_linear_layer():
     assert model.linear.weight.SCB is not None
 
     # Run inference on the GPU.
-    x = torch.randn(10, 32)
+    x = torch.randn(1, 32)
     expected = model(x.to("cuda"))
     assert expected.device.type == "cuda"

tests/backend/quantization/test_bnb_llm_int8.py

Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@ def test_invoke_linear_8bit_lt_quantization():
     assert quantized_layer.weight.CB.dtype == torch.int8
 
     # Run inference on both the original and quantized layers.
-    x = torch.randn(10, 32)
+    x = torch.randn(1, 32)
     y = orig_layer(x)
     y_quantized = quantized_layer(x.to("cuda"))
     assert y.shape == y_quantized.shape
@@ -53,7 +53,7 @@ def test_invoke_linear_8bit_lt_state_dict_roundtrip():
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Run inference on the original layer.
-    x = torch.randn(10, 32)
+    x = torch.randn(1, 32)
     y = orig_layer(x)
 
     # Prepare a quantized InvokeLinear8bitLt layer.
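
Why batch size 1 is still a valid test input: torch.nn.Linear maps (N, in_features) to (N, out_features) for any leading dimension N, so the shape and device assertions in these tests are unaffected by shrinking N from 10 to 1. A quick illustration (plain nn.Linear, not the quantized layers from the tests):

# Sketch: the batch dimension does not change what these tests assert.
import torch

layer = torch.nn.Linear(32, 64)
for n in (1, 10):
    y = layer(torch.randn(n, 32))
    assert y.shape == (n, 64)  # same check passes for either batch size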
