Skip to content

Commit dabbddc

Browse files
authored
feat: Add Support for Torch AOTI (#477)
This change adds support for the PyTorch backend to handle Torch AOTI models based on the specified backend and platform. Prior to this change the pytorch backend only supported the "pytorch_libtorch" platform which assumed default model file name of "model.pt". After this change, the pytorch backend with support a new platform "torch_aoti" which assumes a default model file name of "model.pt2".
1 parent 66f09f8 commit dabbddc

File tree

3 files changed

+49
-30
lines changed

3 files changed

+49
-30
lines changed

src/constants.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -57,6 +57,10 @@ constexpr char kPyTorchLibTorchPlatform[] = "pytorch_libtorch";
5757
constexpr char kPyTorchLibTorchFilename[] = "model.pt";
5858
constexpr char kPyTorchBackend[] = "pytorch";
5959

60+
constexpr char kPyTorchAotiPlatform[] = "torch_aoti";
61+
constexpr char kPyTorchAotiFilename[] = "model.pt2";
62+
constexpr char kPyTorchAotiBackend[] = "pytorch";
63+
6064
constexpr char kPythonFilename[] = "model.py";
6165
constexpr char kPythonBackend[] = "python";
6266

src/model_config_utils.cc

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -1072,7 +1072,7 @@ AutoCompleteBackendFields(
10721072

10731073
// Trying to fill the 'backend', 'default_model_filename' field.
10741074

1075-
// TensorFlow
1075+
// TensorFlow <-- TensorFlow backend is deprecated -->
10761076
// For TF backend, the platform is required
10771077
if (config->platform().empty()) {
10781078
// Check 'backend', 'default_model_filename', and the actual directory
@@ -1199,33 +1199,39 @@ AutoCompleteBackendFields(
11991199

12001200
// PyTorch
12011201
if (config->backend().empty()) {
1202-
if ((config->platform() == kPyTorchLibTorchPlatform) ||
1203-
(config->default_model_filename() == kPyTorchLibTorchFilename)) {
1202+
// Torch JIT interface
1203+
if (config->platform() == kPyTorchLibTorchPlatform ||
1204+
config->default_model_filename() == kPyTorchLibTorchFilename) {
12041205
config->set_backend(kPyTorchBackend);
1205-
} else if (
1206-
config->platform().empty() &&
1207-
config->default_model_filename().empty() && has_version) {
1208-
bool is_dir = false;
1209-
if (version_dir_content.find(kPyTorchLibTorchFilename) !=
1210-
version_dir_content.end()) {
1211-
RETURN_IF_ERROR(IsDirectory(
1212-
JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir));
1213-
if (!is_dir) {
1214-
config->set_backend(kPyTorchBackend);
1215-
}
1206+
} else
1207+
// Torch AOTI interface
1208+
if (config->platform() == kPyTorchAotiPlatform ||
1209+
config->default_model_filename() == kPyTorchAotiFilename) {
1210+
config->set_backend(kPyTorchAotiBackend);
1211+
} else if (
1212+
config->platform().empty() &&
1213+
config->default_model_filename().empty() && has_version) {
1214+
bool is_dir = false;
1215+
1216+
// Torch JIT interface
1217+
if (version_dir_content.find(kPyTorchLibTorchFilename) !=
1218+
version_dir_content.end()) {
1219+
RETURN_IF_ERROR(IsDirectory(
1220+
JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir));
1221+
if (!is_dir) {
1222+
config->set_backend(kPyTorchBackend);
1223+
}
1224+
} else
1225+
// Torch AOTI interface
1226+
if (version_dir_content.find(kPyTorchAotiFilename) !=
1227+
version_dir_content.end()) {
1228+
RETURN_IF_ERROR(IsDirectory(
1229+
JoinPath({version_path, kPyTorchAotiFilename}), &is_dir));
1230+
if (!is_dir) {
1231+
config->set_backend(kPyTorchAotiBackend);
1232+
}
1233+
}
12161234
}
1217-
}
1218-
}
1219-
if (config->backend() == kPyTorchBackend) {
1220-
if (config->platform().empty()) {
1221-
// do not introduce new platforms, new runtimes may ignore this field.
1222-
config->set_platform(kPyTorchLibTorchPlatform);
1223-
}
1224-
if (config->runtime() != kPythonFilename &&
1225-
config->default_model_filename().empty()) {
1226-
config->set_default_model_filename(kPyTorchLibTorchFilename);
1227-
}
1228-
return Status::Success;
12291235
}
12301236

12311237
// Python
@@ -2369,6 +2375,10 @@ GetBackendTypeFromPlatform(const std::string& platform_name)
23692375
return BackendType::BACKEND_TYPE_PYTORCH;
23702376
}
23712377

2378+
if (platform_name == kPyTorchAotiPlatform) {
2379+
return BackendType::BACKEND_TYPE_TORCHAOTI;
2380+
}
2381+
23722382
return BackendType::BACKEND_TYPE_UNKNOWN;
23732383
}
23742384

@@ -2395,6 +2405,10 @@ GetBackendType(const std::string& backend_name)
23952405
return BackendType::BACKEND_TYPE_PYTORCH;
23962406
}
23972407

2408+
if (backend_name == kPyTorchAotiBackend) {
2409+
return BackendType::BACKEND_TYPE_TORCHAOTI;
2410+
}
2411+
23982412
return BackendType::BACKEND_TYPE_UNKNOWN;
23992413
}
24002414

src/model_config_utils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -39,7 +39,8 @@ enum BackendType {
3939
BACKEND_TYPE_TENSORRT = 1,
4040
BACKEND_TYPE_TENSORFLOW = 2,
4141
BACKEND_TYPE_ONNXRUNTIME = 3,
42-
BACKEND_TYPE_PYTORCH = 4
42+
BACKEND_TYPE_PYTORCH = 4,
43+
BACKEND_TYPE_TORCHAOTI = 4, // Torch AOTI uses the same backend as PyTorch
4344
};
4445

4546
// Get version of a model from the path containing the model

0 commit comments

Comments
 (0)