Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/constants.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -57,6 +57,10 @@ constexpr char kPyTorchLibTorchPlatform[] = "pytorch_libtorch";
constexpr char kPyTorchLibTorchFilename[] = "model.pt";
constexpr char kPyTorchBackend[] = "pytorch";

constexpr char kPyTorchAotiPlatform[] = "torch_aoti";
constexpr char kPyTorchAotiFilename[] = "model.pt2";
constexpr char kPyTorchAotiBackend[] = "pytorch";

constexpr char kPythonFilename[] = "model.py";
constexpr char kPythonBackend[] = "python";

Expand Down
68 changes: 41 additions & 27 deletions src/model_config_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -1072,7 +1072,7 @@ AutoCompleteBackendFields(

// Trying to fill the 'backend', 'default_model_filename' field.

// TensorFlow
// TensorFlow <-- TensorFlow backend is deprecated -->
// For TF backend, the platform is required
if (config->platform().empty()) {
// Check 'backend', 'default_model_filename', and the actual directory
Expand Down Expand Up @@ -1199,33 +1199,39 @@ AutoCompleteBackendFields(

// PyTorch
if (config->backend().empty()) {
if ((config->platform() == kPyTorchLibTorchPlatform) ||
(config->default_model_filename() == kPyTorchLibTorchFilename)) {
// Torch JIT interface
if (config->platform() == kPyTorchLibTorchPlatform ||
config->default_model_filename() == kPyTorchLibTorchFilename) {
config->set_backend(kPyTorchBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
bool is_dir = false;
if (version_dir_content.find(kPyTorchLibTorchFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir));
if (!is_dir) {
config->set_backend(kPyTorchBackend);
}
} else
// Torch AOTI interface
if (config->platform() == kPyTorchAotiPlatform ||
config->default_model_filename() == kPyTorchAotiFilename) {
config->set_backend(kPyTorchAotiBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
bool is_dir = false;

// Torch JIT interface
if (version_dir_content.find(kPyTorchLibTorchFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir));
if (!is_dir) {
config->set_backend(kPyTorchBackend);
}
} else
// Torch AOTI interface
if (version_dir_content.find(kPyTorchAotiFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kPyTorchAotiFilename}), &is_dir));
if (!is_dir) {
config->set_backend(kPyTorchAotiBackend);
}
}
}
}
}
if (config->backend() == kPyTorchBackend) {
if (config->platform().empty()) {
// do not introduce new platforms, new runtimes may ignore this field.
config->set_platform(kPyTorchLibTorchPlatform);
}
if (config->runtime() != kPythonFilename &&
config->default_model_filename().empty()) {
config->set_default_model_filename(kPyTorchLibTorchFilename);
}
return Status::Success;
}

// Python
Expand Down Expand Up @@ -2369,6 +2375,10 @@ GetBackendTypeFromPlatform(const std::string& platform_name)
return BackendType::BACKEND_TYPE_PYTORCH;
}

if (platform_name == kPyTorchAotiPlatform) {
return BackendType::BACKEND_TYPE_TORCHAOTI;
}

return BackendType::BACKEND_TYPE_UNKNOWN;
}

Expand All @@ -2395,6 +2405,10 @@ GetBackendType(const std::string& backend_name)
return BackendType::BACKEND_TYPE_PYTORCH;
}

if (backend_name == kPyTorchAotiBackend) {
return BackendType::BACKEND_TYPE_TORCHAOTI;
}

return BackendType::BACKEND_TYPE_UNKNOWN;
}

Expand Down
5 changes: 3 additions & 2 deletions src/model_config_utils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -39,7 +39,8 @@ enum BackendType {
BACKEND_TYPE_TENSORRT = 1,
BACKEND_TYPE_TENSORFLOW = 2,
BACKEND_TYPE_ONNXRUNTIME = 3,
BACKEND_TYPE_PYTORCH = 4
BACKEND_TYPE_PYTORCH = 4,
BACKEND_TYPE_TORCHAOTI = 4, // Torch AOTI uses the same backend as PyTorch
};

// Get version of a model from the path containing the model
Expand Down
Loading