|
12 | 12 | import time
|
13 | 13 | from contextlib import suppress
|
14 | 14 | from functools import partial
|
| 15 | +from sys import maxsize |
15 | 16 |
|
16 | 17 | import numpy as np
|
17 | 18 | import pandas as pd
|
|
104 | 105 | help='use Native AMP for mixed precision training')
|
105 | 106 | parser.add_argument('--amp-dtype', default='float16', type=str,
|
106 | 107 | help='lower precision AMP dtype (default: float16)')
|
| 108 | +parser.add_argument('--model-dtype', default=None, type=str, |
| 109 | + help='Model dtype override (non-AMP) (default: float32)') |
107 | 110 | parser.add_argument('--fuser', default='', type=str,
|
108 | 111 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
109 | 112 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
@@ -160,9 +163,15 @@ def main():
|
160 | 163 |
|
161 | 164 | device = torch.device(args.device)
|
162 | 165 |
|
| 166 | + model_dtype = None |
| 167 | + if args.model_dtype: |
| 168 | + assert args.model_dtype in ('float32', 'float16', 'bfloat16') |
| 169 | + model_dtype = getattr(torch, args.model_dtype) |
| 170 | + |
163 | 171 | # resolve AMP arguments based on PyTorch / Apex availability
|
164 | 172 | amp_autocast = suppress
|
165 | 173 | if args.amp:
|
| 174 | + assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' |
166 | 175 | assert args.amp_dtype in ('float16', 'bfloat16')
|
167 | 176 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
|
168 | 177 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
@@ -200,7 +209,7 @@ def main():
|
200 | 209 | if args.test_pool:
|
201 | 210 | model, test_time_pool = apply_test_time_pool(model, data_config)
|
202 | 211 |
|
203 |
| - model = model.to(device) |
| 212 | + model = model.to(device=device, dtype=model_dtype) |
204 | 213 | model.eval()
|
205 | 214 | if args.channels_last:
|
206 | 215 | model = model.to(memory_format=torch.channels_last)
|
@@ -236,6 +245,7 @@ def main():
|
236 | 245 | use_prefetcher=True,
|
237 | 246 | num_workers=workers,
|
238 | 247 | device=device,
|
| 248 | + img_dtype=model_dtype or torch.float32, |
239 | 249 | **data_config,
|
240 | 250 | )
|
241 | 251 |
|
@@ -279,7 +289,7 @@ def main():
|
279 | 289 | np_labels = to_label(np_indices)
|
280 | 290 | all_labels.append(np_labels)
|
281 | 291 |
|
282 |
| - all_outputs.append(output.cpu().numpy()) |
| 292 | + all_outputs.append(output.float().cpu().numpy()) |
283 | 293 |
|
284 | 294 | # measure elapsed time
|
285 | 295 | batch_time.update(time.time() - end)
|
@@ -343,6 +353,7 @@ def main():
|
343 | 353 |
|
344 | 354 |
|
345 | 355 | def save_results(df, results_filename, results_format='csv', filename_col='filename'):
|
| 356 | + np.set_printoptions(threshold=maxsize) |
346 | 357 | results_filename += _FMT_EXT[results_format]
|
347 | 358 | if results_format == 'parquet':
|
348 | 359 | df.set_index(filename_col).to_parquet(results_filename)
|
|
0 commit comments