
Commit 3e0468a

Update quantization to force gpu usage for blockwise8 (#3256)
Fixes # .

### Description

Account for QA finding and [bug from bitsandbytes](bitsandbytes-foundation/bitsandbytes#1540). Also add info to supported precisions.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
1 parent 36a40a2 commit 3e0468a
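The change itself is small: bitsandbytes' 8-bit blockwise quantization is only reliably supported on GPU (see the linked issue), so the quantize/dequantize filters now move tensors to CUDA before calling it and move the payload and quantization state back to CPU for transfer. Below is a minimal sketch of that round trip, assuming a CUDA device and the bitsandbytes functional API; the tensor and variable names are illustrative, not taken from the NVFlare code.

```python
# Sketch of the GPU-backed blockwise8 round trip (assumes torch, bitsandbytes, and CUDA).
import torch
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

weights = torch.randn(1024, dtype=torch.float32)

# Quantize on GPU; CPU support for 8-bit blockwise quantization is limited.
quantized, state = quantize_blockwise(weights.cuda())

# Move the payload and quantization state to CPU so they can be serialized and sent.
payload = quantized.cpu()
absmax, code = state.absmax.cpu(), state.code.cpu()

# Receiving side: move everything back to GPU, dequantize, then restore dtype/device.
restored = dequantize_blockwise(payload.cuda(), absmax=absmax.cuda(), code=code.cuda())
restored = restored.cpu().float()
print("max abs error:", (weights - restored).abs().max().item())
```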

File tree

7 files changed (+71, -69 lines)

examples/advanced/llm_hf/sft_job.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner
 
 
@@ -67,10 +67,10 @@ def main():
 
     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
 
     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])
 
         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
 
     # Export the job
     print("job_dir=", job_dir)
```

examples/tutorials/self-paced-training/part-4_advanced_federated_learning/chapter-8_federated_LLM_training/08.2_llm_sft/sft_job.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner
 
 
@@ -67,10 +67,10 @@ def main():
 
     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
 
     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])
 
         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
 
     # Export the job
     print("job_dir=", job_dir)
```

examples/tutorials/self-paced-training/part-4_advanced_federated_learning/chapter-8_federated_LLM_training/08.3_llm_peft/peft_job.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner
 
 
@@ -67,10 +67,10 @@ def main():
 
     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
 
     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])
 
         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
 
     # Export the job
     print("job_dir=", job_dir)
```

examples/tutorials/self-paced-training/part-4_advanced_federated_learning/chapter-8_federated_LLM_training/08.4_llm_quantization/sft_job.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner
 
 
@@ -67,10 +67,10 @@ def main():
 
     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
 
     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])
 
         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
 
     # Export the job
     print("job_dir=", job_dir)
```

nvflare/app_opt/pt/quantization/constant.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -12,15 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Supported Input Data Type
+# Message quantization is mainly for reducing the message that can be
+# significantly large, e.g. LLMs. Thus, the supported input data types
+# we consider are common ones during LLM training, including fp32, fp16, and bf16.
 DATA_TYPE = [
-    "FLOAT64",
     "FLOAT32",
     "FLOAT16",
     "BFLOAT16",
-    "UINT8",
-    "INT8",
 ]
 
+# Supported Quantization Type to reduce the above input data types
+# The quantization types are mainly for reducing the model size,
+# Hence, we support 16-, 8-, and 4-bits quantization.
+# Note that 8- and 4-bits quantization needs GPU support.
 QUANTIZATION_TYPE = [
     "FLOAT16",
     "BLOCKWISE8",
```
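The lists above are what the filters check incoming data and the requested quantization mode against. A hedged sketch of that kind of validation follows; the helper function is illustrative only and not part of the NVFlare API:

```python
# Illustrative validation against the supported lists; check_supported is hypothetical.
from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE

def check_supported(source_data_type: str, quantization_type: str) -> None:
    if source_data_type.upper() not in DATA_TYPE:
        raise ValueError(f"unsupported input data type: {source_data_type}")
    if quantization_type.upper() not in QUANTIZATION_TYPE:
        raise ValueError(f"unsupported quantization type: {quantization_type}")

check_supported("bfloat16", "blockwise8")  # passes with the updated lists
# check_supported("float64", "blockwise8") would now raise, since FLOAT64 was removed
```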

nvflare/app_opt/pt/quantization/dequantizor.py renamed to nvflare/app_opt/pt/quantization/dequantizer.py

Lines changed: 9 additions & 11 deletions
```diff
@@ -26,7 +26,7 @@
 from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE
 
 
-class ModelDequantizor(DXOFilter):
+class ModelDequantizer(DXOFilter):
     def __init__(self):
         """Filter to dequantize Shareable object to recover from quantization
 
@@ -84,17 +84,18 @@ def dequantization(
                 params[param_name] = values
             elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
                 # use bitsandbytes to dequantize the values
+                # need GPU for general support
                 # extract quantization state
                 if quantization_type == "blockwise8":
                     if source_data_format == "numpy":
                         # first convert numpy array to tensor if numpy
-                        quantized = torch.as_tensor(values)
-                        absmax = torch.as_tensor(quant_state[param_name]["absmax"])
-                        code = torch.as_tensor(quant_state[param_name]["code"])
+                        quantized = torch.as_tensor(values).cuda()
+                        absmax = torch.as_tensor(quant_state[param_name]["absmax"]).cuda()
+                        code = torch.as_tensor(quant_state[param_name]["code"]).cuda()
                     elif source_data_format == "torch":
-                        quantized = values
-                        absmax = quant_state[param_name]["absmax"]
-                        code = quant_state[param_name]["code"]
+                        quantized = values.cuda()
+                        absmax = quant_state[param_name]["absmax"].cuda()
+                        code = quant_state[param_name]["code"].cuda()
                     # de-quanitze
                     dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
                 else:
@@ -125,6 +126,7 @@ def dequantization(
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
                     else:
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
+
                 if source_data_format == "numpy":
                     params[param_name] = dequantized.cpu().numpy()
                 elif source_data_format == "torch":
@@ -135,16 +137,12 @@ def dequantization(
                 # convert back to original data type
                 if source_data_type == "float32":
                     params[param_name] = params[param_name].astype(np.float32)
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].astype(np.float64)
                 elif source_data_type == "float16":
                     params[param_name] = params[param_name].astype(np.float16)
             elif source_data_format == "torch":
                 # convert back to original data type
                 if source_data_type == "float32":
                     params[param_name] = params[param_name].float()
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].double()
                 elif source_data_type == "float16":
                     params[param_name] = params[param_name].half()
                 elif source_data_type == "bfloat16":
```

nvflare/app_opt/pt/quantization/quantizor.py renamed to nvflare/app_opt/pt/quantization/quantizer.py

Lines changed: 22 additions & 23 deletions
```diff
@@ -26,7 +26,7 @@
 from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE
 
 
-class ModelQuantizor(DXOFilter):
+class ModelQuantizer(DXOFilter):
     def __init__(
         self,
         quantization_type="float16",
@@ -120,41 +120,39 @@ def quantization(self, params: dict, fl_ctx: FLContext):
             elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]:
                 # use bitsandbytes to quantize the values
                 # input is a tensor, output is a tuple of (quantized tensor, quantized_state)
-                if self.quantization_type == "blockwise8":
-                    if source_data_format == "numpy":
-                        # if numpy, first convert numpy array to tensor
-                        values_tensor = torch.as_tensor(values)
-                    elif source_data_format == "torch":
-                        values_tensor = values
 
-                    # then quantize the tensor
+                # CPU has limited support for 8- and 4-bits quantization
+                # For general purpose, here we use GPU
+                if source_data_format == "numpy":
+                    # if numpy, first convert numpy array to tensor, need to use GPU
+                    values_tensor = torch.as_tensor(values).cuda()
+                elif source_data_format == "torch":
+                    # if torch, directly use the tensor, need to use GPU
+                    values_tensor = values.cuda()
+
+                if self.quantization_type == "blockwise8":
+                    # quantize the tensor
                     quantized, quantized_state = quantize_blockwise(values_tensor)
                     # add the quantization state and values, keep source data format
                     if source_data_format == "numpy":
-                        quant_state[param_name]["absmax"] = quantized_state.absmax.numpy()
-                        quant_state[param_name]["code"] = quantized_state.code.numpy()
-                        values = quantized.numpy()
+                        quant_state[param_name]["absmax"] = quantized_state.absmax.cpu().numpy()
+                        quant_state[param_name]["code"] = quantized_state.code.cpu().numpy()
+                        values = quantized.cpu().numpy()
                     elif source_data_format == "torch":
-                        quant_state[param_name]["absmax"] = quantized_state.absmax
-                        quant_state[param_name]["code"] = quantized_state.code
-                        values = quantized
+                        quant_state[param_name]["absmax"] = quantized_state.absmax.cpu()
+                        quant_state[param_name]["code"] = quantized_state.code.cpu()
+                        values = quantized.cpu()
                     n_bytes_meta += quant_state[param_name]["absmax"].nbytes
                     n_bytes_meta += quant_state[param_name]["code"].nbytes
                 else:
-                    if source_data_format == "numpy":
-                        # if numpy, first convert numpy array to tensor, need to use GPU
-                        values_tensor = torch.as_tensor(values).cuda()
-                    elif source_data_format == "torch":
-                        # if torch, directly use the tensor, need to use GPU
-                        values_tensor = values.cuda()
                     # then quantize the tensor
                     if self.quantization_type == "float4":
                         quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4")
                     else:
                         quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4")
                     # add the quantization state and values, keep source data format
                     quantized_state = quantized_state.as_dict()
-
+                    # prepared the message
                     for state_name, state in quantized_state.items():
                         if isinstance(state, torch.Tensor):
                             if source_data_format == "numpy":
@@ -171,6 +169,7 @@ def quantization(self, params: dict, fl_ctx: FLContext):
                         values = quantized.cpu().numpy()
                     elif source_data_format == "torch":
                         values = quantized.cpu()
+
                 params[param_name] = values
                 n_bytes_after += params[param_name].nbytes
 
@@ -203,8 +202,8 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
         # thus the subsequent communications to the rest of clients will no longer need to apply quantization
         # This will not apply to client job, since the client job will be 1-1 and quantization applies to each client
         # Potentially:
-        # If clients talks to each other, it will also be 1-N and same rule applies
-        # If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
+        # - If clients talks to each other, it will also be 1-N and same rule applies
+        # - If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
         # a deep copy of the server data should be made by filter before applying a different filter
 
         # quantized_flag None if does not exist in meta
```
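Because blockwise8 (and the 4-bit modes) now call .cuda() inside the filter, a job that requests one of these modes on a CPU-only host will fail at runtime. A small guard like the one below, which is an assumption of this sketch and not part of the commit, can fail fast or fall back to the CPU-friendly float16 mode before the filters are attached:

```python
# Hypothetical guard for CPU-only hosts; not part of the commit.
import torch

def resolve_quantize_mode(requested: str) -> str:
    gpu_only = {"blockwise8", "float4", "normfloat4"}
    if requested in gpu_only and not torch.cuda.is_available():
        # fall back instead of failing inside the quantization filter
        return "float16"
    return requested
```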
