1- // Copyright 2018-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ // Copyright 2018-2026 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22//
33// Redistribution and use in source and binary forms, with or without
44// modification, are permitted provided that the following conditions
@@ -1072,7 +1072,7 @@ AutoCompleteBackendFields(
10721072
10731073 // Trying to fill the 'backend', 'default_model_filename' field.
10741074
1075- // TensorFlow
1075+ // TensorFlow <-- TensorFlow backend is deprecated -->
10761076 // For TF backend, the platform is required
10771077 if (config->platform ().empty ()) {
10781078 // Check 'backend', 'default_model_filename', and the actual directory
@@ -1199,33 +1199,39 @@ AutoCompleteBackendFields(
11991199
12001200 // PyTorch
12011201 if (config->backend ().empty ()) {
1202- if ((config->platform () == kPyTorchLibTorchPlatform ) ||
1203- (config->default_model_filename () == kPyTorchLibTorchFilename )) {
1202+ // Torch JIT interface
1203+ if (config->platform () == kPyTorchLibTorchPlatform ||
1204+ config->default_model_filename () == kPyTorchLibTorchFilename ) {
12041205 config->set_backend (kPyTorchBackend );
1205- } else if (
1206- config->platform ().empty () &&
1207- config->default_model_filename ().empty () && has_version) {
1208- bool is_dir = false ;
1209- if (version_dir_content.find (kPyTorchLibTorchFilename ) !=
1210- version_dir_content.end ()) {
1211- RETURN_IF_ERROR (IsDirectory (
1212- JoinPath ({version_path, kPyTorchLibTorchFilename }), &is_dir));
1213- if (!is_dir) {
1214- config->set_backend (kPyTorchBackend );
1215- }
1206+ } else
1207+ // Torch AOTI interface
1208+ if (config->platform () == kPyTorchAotiPlatform ||
1209+ config->default_model_filename () == kPyTorchAotiFilename ) {
1210+ config->set_backend (kPyTorchAotiBackend );
1211+ } else if (
1212+ config->platform ().empty () &&
1213+ config->default_model_filename ().empty () && has_version) {
1214+ bool is_dir = false ;
1215+
1216+ // Torch JIT interface
1217+ if (version_dir_content.find (kPyTorchLibTorchFilename ) !=
1218+ version_dir_content.end ()) {
1219+ RETURN_IF_ERROR (IsDirectory (
1220+ JoinPath ({version_path, kPyTorchLibTorchFilename }), &is_dir));
1221+ if (!is_dir) {
1222+ config->set_backend (kPyTorchBackend );
1223+ }
1224+ } else
1225+ // Torch AOTI interface
1226+ if (version_dir_content.find (kPyTorchAotiFilename ) !=
1227+ version_dir_content.end ()) {
1228+ RETURN_IF_ERROR (IsDirectory (
1229+ JoinPath ({version_path, kPyTorchAotiFilename }), &is_dir));
1230+ if (!is_dir) {
1231+ config->set_backend (kPyTorchAotiBackend );
1232+ }
1233+ }
12161234 }
1217- }
1218- }
1219- if (config->backend () == kPyTorchBackend ) {
1220- if (config->platform ().empty ()) {
1221- // do not introduce new platforms, new runtimes may ignore this field.
1222- config->set_platform (kPyTorchLibTorchPlatform );
1223- }
1224- if (config->runtime () != kPythonFilename &&
1225- config->default_model_filename ().empty ()) {
1226- config->set_default_model_filename (kPyTorchLibTorchFilename );
1227- }
1228- return Status::Success;
12291235 }
12301236
12311237 // Python
@@ -2369,6 +2375,10 @@ GetBackendTypeFromPlatform(const std::string& platform_name)
23692375 return BackendType::BACKEND_TYPE_PYTORCH;
23702376 }
23712377
2378+ if (platform_name == kPyTorchAotiPlatform ) {
2379+ return BackendType::BACKEND_TYPE_TORCHAOTI;
2380+ }
2381+
23722382 return BackendType::BACKEND_TYPE_UNKNOWN;
23732383}
23742384
@@ -2395,6 +2405,10 @@ GetBackendType(const std::string& backend_name)
23952405 return BackendType::BACKEND_TYPE_PYTORCH;
23962406 }
23972407
2408+ if (backend_name == kPyTorchAotiBackend ) {
2409+ return BackendType::BACKEND_TYPE_TORCHAOTI;
2410+ }
2411+
23982412 return BackendType::BACKEND_TYPE_UNKNOWN;
23992413}
24002414
0 commit comments