Skip to content

Commit 1572769

Browse files
authored
Merge pull request #2402 from JosuaRieder/fix_inference_csv_export
disable abbreviating csv inference output with ellipses
2 parents 53c3c89 + fc0609b commit 1572769

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

inference.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import time
1313
from contextlib import suppress
1414
from functools import partial
15+
from sys import maxsize
1516

1617
import numpy as np
1718
import pandas as pd
@@ -104,6 +105,8 @@
104105
help='use Native AMP for mixed precision training')
105106
parser.add_argument('--amp-dtype', default='float16', type=str,
106107
help='lower precision AMP dtype (default: float16)')
108+
parser.add_argument('--model-dtype', default=None, type=str,
109+
help='Model dtype override (non-AMP) (default: float32)')
107110
parser.add_argument('--fuser', default='', type=str,
108111
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
109112
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
@@ -160,9 +163,15 @@ def main():
160163

161164
device = torch.device(args.device)
162165

166+
model_dtype = None
167+
if args.model_dtype:
168+
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
169+
model_dtype = getattr(torch, args.model_dtype)
170+
163171
# resolve AMP arguments based on PyTorch / Apex availability
164172
amp_autocast = suppress
165173
if args.amp:
174+
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
166175
assert args.amp_dtype in ('float16', 'bfloat16')
167176
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
168177
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
@@ -200,7 +209,7 @@ def main():
200209
if args.test_pool:
201210
model, test_time_pool = apply_test_time_pool(model, data_config)
202211

203-
model = model.to(device)
212+
model = model.to(device=device, dtype=model_dtype)
204213
model.eval()
205214
if args.channels_last:
206215
model = model.to(memory_format=torch.channels_last)
@@ -236,6 +245,7 @@ def main():
236245
use_prefetcher=True,
237246
num_workers=workers,
238247
device=device,
248+
img_dtype=model_dtype or torch.float32,
239249
**data_config,
240250
)
241251

@@ -279,7 +289,7 @@ def main():
279289
np_labels = to_label(np_indices)
280290
all_labels.append(np_labels)
281291

282-
all_outputs.append(output.cpu().numpy())
292+
all_outputs.append(output.float().cpu().numpy())
283293

284294
# measure elapsed time
285295
batch_time.update(time.time() - end)
@@ -343,6 +353,7 @@ def main():
343353

344354

345355
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
356+
np.set_printoptions(threshold=maxsize)
346357
results_filename += _FMT_EXT[results_format]
347358
if results_format == 'parquet':
348359
df.set_index(filename_col).to_parquet(results_filename)

0 commit comments

Comments (0)