|
105 | 105 | help='use Native AMP for mixed precision training')
|
106 | 106 | parser.add_argument('--amp-dtype', default='float16', type=str,
|
107 | 107 | help='lower precision AMP dtype (default: float16)')
|
| 108 | +parser.add_argument('--model-dtype', default=None, type=str, |
| 109 | + help='Model dtype override (non-AMP) (default: float32)') |
108 | 110 | parser.add_argument('--fuser', default='', type=str,
|
109 | 111 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
110 | 112 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
@@ -161,9 +163,15 @@ def main():
|
161 | 163 |
|
162 | 164 | device = torch.device(args.device)
|
163 | 165 |
|
| 166 | + model_dtype = None |
| 167 | + if args.model_dtype: |
| 168 | + assert args.model_dtype in ('float32', 'float16', 'bfloat16') |
| 169 | + model_dtype = getattr(torch, args.model_dtype) |
| 170 | + |
164 | 171 | # resolve AMP arguments based on PyTorch / Apex availability
|
165 | 172 | amp_autocast = suppress
|
166 | 173 | if args.amp:
|
| 174 | + assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' |
167 | 175 | assert args.amp_dtype in ('float16', 'bfloat16')
|
168 | 176 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
|
169 | 177 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
@@ -201,7 +209,7 @@ def main():
|
201 | 209 | if args.test_pool:
|
202 | 210 | model, test_time_pool = apply_test_time_pool(model, data_config)
|
203 | 211 |
|
204 |
| - model = model.to(device) |
| 212 | + model = model.to(device=device, dtype=model_dtype) |
205 | 213 | model.eval()
|
206 | 214 | if args.channels_last:
|
207 | 215 | model = model.to(memory_format=torch.channels_last)
|
@@ -237,6 +245,7 @@ def main():
|
237 | 245 | use_prefetcher=True,
|
238 | 246 | num_workers=workers,
|
239 | 247 | device=device,
|
| 248 | + img_dtype=model_dtype or torch.float32, |
240 | 249 | **data_config,
|
241 | 250 | )
|
242 | 251 |
|
@@ -280,7 +289,7 @@ def main():
|
280 | 289 | np_labels = to_label(np_indices)
|
281 | 290 | all_labels.append(np_labels)
|
282 | 291 |
|
283 |
| - all_outputs.append(output.cpu().numpy()) |
| 292 | + all_outputs.append(output.float().cpu().numpy()) |
284 | 293 |
|
285 | 294 | # measure elapsed time
|
286 | 295 | batch_time.update(time.time() - end)
|
|
0 commit comments