Skip to content

Commit d5fa2ac

Browse files
authored
Improve Windows ETW callback registration and fix issues (#24877)
### Description

- `EtwRegistrationManager`: make sure all fields are initialized by the constructor.
- Register a callback object instead of a pointer to it, and store it in the map under a session-unique key.
- Register `ML_Ort_Provider_Etw_Callback` once for all the sessions: the first session registers it, and the last one to go away removes the callback used to log all sessions. For this we make callbacks ref-counted inside the map they are stored in. This is done to prevent a deadlock where `active_sessions_mutex_` and `callback_mutex_` are acquired from different threads in a different order.
- Create a registration guard that removes callbacks in case the `InferenceSession` constructor does not finish.

### Motivation and Context

<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->

This PR is inspired by #24773. The current code exhibits multiple issues:

- The `EtwRegistrationManager` constructor does not initialize all of the fields, including the `InitializationStatus`.
- The global callback object is registered and re-created by every session. Customers sometimes run thousands of models in the same sessions, which results in quadratic ETW costs. The callback object is destroyed and recreated every time a session is created.
- There is a chance that the `InferenceSession` constructor does not finish, in which case the callback would remain registered. This may result in intermittent, hard-to-diagnose bugs.
- `active_sessions_lock_` and the `callback` lock are not acquired/released in the same order by different threads, which is a classic deadlock scenario.
1 parent 4a3b63f commit d5fa2ac

15 files changed

+368
-294
lines changed

onnxruntime/core/framework/execution_provider.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33
#include "core/framework/execution_provider.h"
4+
#include "core/framework/execution_providers.h"
45

56
#include "core/graph/graph_viewer.h"
67
#include "core/framework/compute_capability.h"
@@ -9,6 +10,8 @@
910
#include "core/framework/murmurhash3.h"
1011
#include "core/framework/op_kernel.h"
1112

13+
#include <stdint.h>
14+
1215
namespace onnxruntime {
1316

1417
std::vector<std::unique_ptr<ComputeCapability>>
@@ -37,4 +40,99 @@ common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
3740
}
3841

3942
#endif
43+
44+
ExecutionProviders::ExecutionProviders() {
45+
#ifdef _WIN32
46+
// Register callback for ETW capture state (rundown)
47+
etw_callback_key_ = "ExecutionProviders_rundown_";
48+
etw_callback_key_.append(std::to_string(reinterpret_cast<uintptr_t>(this)));
49+
WindowsTelemetry::RegisterInternalCallback(
50+
etw_callback_key_,
51+
[this](LPCGUID SourceId,
52+
ULONG IsEnabled,
53+
UCHAR Level,
54+
ULONGLONG MatchAnyKeyword,
55+
ULONGLONG MatchAllKeyword,
56+
PEVENT_FILTER_DESCRIPTOR FilterData,
57+
PVOID CallbackContext) { this->EtwProvidersCallback(SourceId, IsEnabled, Level,
58+
MatchAnyKeyword, MatchAllKeyword,
59+
FilterData, CallbackContext); });
60+
#endif
61+
}
62+
63+
ExecutionProviders::~ExecutionProviders() {
64+
#ifdef _WIN32
65+
WindowsTelemetry::UnregisterInternalCallback(etw_callback_key_);
66+
#endif
67+
}
68+
69+
common::Status ExecutionProviders::Add(const std::string& provider_id,
70+
const std::shared_ptr<IExecutionProvider>& p_exec_provider) {
71+
// make sure there are no issues before we change any internal data structures
72+
if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) {
73+
auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered.");
74+
LOGS_DEFAULT(ERROR) << status.ErrorMessage();
75+
return status;
76+
}
77+
78+
// index that provider will have after insertion
79+
auto new_provider_idx = exec_providers_.size();
80+
81+
ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx}));
82+
83+
// update execution provider options
84+
auto providerOptions = p_exec_provider->GetProviderOptions();
85+
exec_provider_options_[provider_id] = providerOptions;
86+
87+
#ifdef _WIN32
88+
LogProviderOptions(provider_id, providerOptions, false);
89+
#endif
90+
91+
exec_provider_ids_.push_back(provider_id);
92+
exec_providers_.push_back(p_exec_provider);
93+
return Status::OK();
94+
}
95+
96+
#ifdef _WIN32
97+
void ExecutionProviders::EtwProvidersCallback(LPCGUID /* SourceId */,
98+
ULONG IsEnabled,
99+
UCHAR /* Level */,
100+
ULONGLONG MatchAnyKeyword,
101+
ULONGLONG /* MatchAllKeyword */,
102+
PEVENT_FILTER_DESCRIPTOR /* FilterData */,
103+
PVOID /* CallbackContext */) {
104+
// Check if this callback is for capturing state
105+
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
106+
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
107+
for (size_t i = 0; i < exec_providers_.size(); ++i) {
108+
const auto& provider_id = exec_provider_ids_[i];
109+
110+
auto it = exec_provider_options_.find(provider_id);
111+
if (it != exec_provider_options_.end()) {
112+
const auto& options = it->second;
113+
114+
LogProviderOptions(provider_id, options, true);
115+
}
116+
}
117+
}
118+
}
119+
120+
void ExecutionProviders::LogProviderOptions(const std::string& provider_id,
121+
const ProviderOptions& providerOptions,
122+
bool captureState) {
123+
for (const auto& config_pair : providerOptions) {
124+
TraceLoggingWrite(
125+
telemetry_provider_handle,
126+
"ProviderOptions",
127+
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
128+
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
129+
TraceLoggingString(provider_id.c_str(), "ProviderId"),
130+
TraceLoggingString(config_pair.first.c_str(), "Key"),
131+
TraceLoggingString(config_pair.second.c_str(), "Value"),
132+
TraceLoggingBool(captureState, "isCaptureState"));
133+
}
134+
}
135+
136+
#endif
137+
40138
} // namespace onnxruntime

onnxruntime/core/framework/execution_providers.h

Lines changed: 13 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -26,91 +26,24 @@ Class for managing lookup of the execution providers in a session.
2626
*/
2727
class ExecutionProviders {
2828
public:
29-
ExecutionProviders() {
30-
#ifdef _WIN32
31-
// Register callback for ETW capture state (rundown)
32-
etw_callback_ = onnxruntime::WindowsTelemetry::EtwInternalCallback(
33-
[this](
34-
LPCGUID SourceId,
35-
ULONG IsEnabled,
36-
UCHAR Level,
37-
ULONGLONG MatchAnyKeyword,
38-
ULONGLONG MatchAllKeyword,
39-
PEVENT_FILTER_DESCRIPTOR FilterData,
40-
PVOID CallbackContext) {
41-
(void)SourceId;
42-
(void)Level;
43-
(void)MatchAnyKeyword;
44-
(void)MatchAllKeyword;
45-
(void)FilterData;
46-
(void)CallbackContext;
47-
48-
// Check if this callback is for capturing state
49-
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
50-
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
51-
for (size_t i = 0; i < exec_providers_.size(); ++i) {
52-
const auto& provider_id = exec_provider_ids_[i];
53-
54-
auto it = exec_provider_options_.find(provider_id);
55-
if (it != exec_provider_options_.end()) {
56-
const auto& options = it->second;
57-
58-
LogProviderOptions(provider_id, options, true);
59-
}
60-
}
61-
}
62-
});
63-
WindowsTelemetry::RegisterInternalCallback(etw_callback_);
64-
#endif
65-
}
66-
67-
~ExecutionProviders() {
68-
#ifdef _WIN32
69-
WindowsTelemetry ::UnregisterInternalCallback(etw_callback_);
70-
#endif
71-
}
29+
ExecutionProviders();
7230

73-
common::Status
74-
Add(const std::string& provider_id, const std::shared_ptr<IExecutionProvider>& p_exec_provider) {
75-
// make sure there are no issues before we change any internal data structures
76-
if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) {
77-
auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered.");
78-
LOGS_DEFAULT(ERROR) << status.ErrorMessage();
79-
return status;
80-
}
31+
~ExecutionProviders();
8132

82-
// index that provider will have after insertion
83-
auto new_provider_idx = exec_providers_.size();
84-
85-
ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx}));
86-
87-
// update execution provider options
88-
auto providerOptions = p_exec_provider->GetProviderOptions();
89-
exec_provider_options_[provider_id] = providerOptions;
33+
common::Status Add(const std::string& provider_id, const std::shared_ptr<IExecutionProvider>& p_exec_provider);
9034

9135
#ifdef _WIN32
92-
LogProviderOptions(provider_id, providerOptions, false);
93-
#endif
9436

95-
exec_provider_ids_.push_back(provider_id);
96-
exec_providers_.push_back(p_exec_provider);
97-
return Status::OK();
98-
}
37+
void EtwProvidersCallback(LPCGUID SourceId,
38+
ULONG IsEnabled,
39+
UCHAR Level,
40+
ULONGLONG MatchAnyKeyword,
41+
ULONGLONG MatchAllKeyword,
42+
PEVENT_FILTER_DESCRIPTOR FilterData,
43+
PVOID CallbackContext);
9944

100-
#ifdef _WIN32
101-
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) {
102-
for (const auto& config_pair : providerOptions) {
103-
TraceLoggingWrite(
104-
telemetry_provider_handle,
105-
"ProviderOptions",
106-
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
107-
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
108-
TraceLoggingString(provider_id.c_str(), "ProviderId"),
109-
TraceLoggingString(config_pair.first.c_str(), "Key"),
110-
TraceLoggingString(config_pair.second.c_str(), "Value"),
111-
TraceLoggingBool(captureState, "isCaptureState"));
112-
}
113-
}
45+
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions,
46+
bool captureState);
11447
#endif
11548

11649
const IExecutionProvider* Get(const onnxruntime::Node& node) const {
@@ -169,7 +102,7 @@ class ExecutionProviders {
169102
bool cpu_execution_provider_was_implicitly_added_ = false;
170103

171104
#ifdef _WIN32
172-
WindowsTelemetry::EtwInternalCallback etw_callback_;
105+
std::string etw_callback_key_;
173106
#endif
174107
};
175108
} // namespace onnxruntime

onnxruntime/core/platform/windows/logging/etw_sink.cc

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,15 @@ HRESULT EtwRegistrationManager::Status() const {
106106
return etw_status_;
107107
}
108108

109-
void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) {
109+
void EtwRegistrationManager::RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) {
110110
std::lock_guard<std::mutex> lock(callbacks_mutex_);
111-
callbacks_.push_back(&callback);
111+
[[maybe_unused]] auto result = callbacks_.emplace(cb_key, std::move(callback));
112+
assert(result.second);
112113
}
113114

114-
void EtwRegistrationManager::UnregisterInternalCallback(const EtwInternalCallback& callback) {
115+
void EtwRegistrationManager::UnregisterInternalCallback(const std::string& cb_key) {
115116
std::lock_guard<std::mutex> lock(callbacks_mutex_);
116-
auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(),
117-
[&callback](const EtwInternalCallback* ptr) {
118-
return ptr == &callback;
119-
});
120-
callbacks_.erase(new_end, callbacks_.end());
117+
callbacks_.erase(cb_key);
121118
}
122119

123120
void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback(
@@ -138,21 +135,12 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback(
138135
manager.InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
139136
}
140137

141-
EtwRegistrationManager::~EtwRegistrationManager() {
142-
std::lock_guard<std::mutex> lock(callbacks_mutex_);
143-
callbacks_.clear();
144-
if (initialization_status_ == InitializationStatus::Initialized ||
145-
initialization_status_ == InitializationStatus::Initializing) {
146-
std::lock_guard<std::mutex> init_lock(init_mutex_);
147-
assert(initialization_status_ != InitializationStatus::Initializing);
148-
if (initialization_status_ == InitializationStatus::Initialized) {
149-
::TraceLoggingUnregister(etw_provider_handle);
150-
initialization_status_ = InitializationStatus::NotInitialized;
151-
}
152-
}
153-
}
154-
155-
EtwRegistrationManager::EtwRegistrationManager() {
138+
EtwRegistrationManager::EtwRegistrationManager()
139+
: initialization_status_(InitializationStatus::NotInitialized),
140+
is_enabled_(false),
141+
level_(),
142+
keyword_(0),
143+
etw_status_(S_OK) {
156144
}
157145

158146
void EtwRegistrationManager::LazyInitialize() {
@@ -173,6 +161,13 @@ void EtwRegistrationManager::LazyInitialize() {
173161
}
174162
}
175163

164+
EtwRegistrationManager::~EtwRegistrationManager() {
165+
if (initialization_status_ == InitializationStatus::Initialized) {
166+
::TraceLoggingUnregister(etw_provider_handle);
167+
initialization_status_ = InitializationStatus::NotInitialized;
168+
}
169+
}
170+
176171
void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
177172
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
178173
PVOID CallbackContext) {
@@ -182,10 +177,9 @@ void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled,
182177
}
183178

184179
std::lock_guard<std::mutex> lock(callbacks_mutex_);
185-
for (const auto& callback : callbacks_) {
186-
if (callback != nullptr) {
187-
(*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
188-
}
180+
for (const auto& entry : callbacks_) {
181+
const auto& cb = entry.second;
182+
cb(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
189183
}
190184
}
191185

onnxruntime/core/platform/windows/logging/etw_sink.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <atomic>
2121
#include <iostream>
2222
#include <string>
23+
#include <unordered_map>
2324
#include <vector>
2425

2526
#include "core/common/logging/capture.h"
@@ -77,9 +78,9 @@ class EtwRegistrationManager {
7778
// Get the ETW registration status
7879
HRESULT Status() const;
7980

80-
void RegisterInternalCallback(const EtwInternalCallback& callback);
81+
void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback);
8182

82-
void UnregisterInternalCallback(const EtwInternalCallback& callback);
83+
void UnregisterInternalCallback(const std::string& cb_key);
8384

8485
private:
8586
EtwRegistrationManager();
@@ -100,11 +101,11 @@ class EtwRegistrationManager {
100101
_In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
101102
_In_opt_ PVOID CallbackContext);
102103

103-
std::vector<const EtwInternalCallback*> callbacks_;
104+
std::mutex init_mutex_;
105+
std::atomic<InitializationStatus> initialization_status_ = InitializationStatus::NotInitialized;
106+
std::unordered_map<std::string, EtwInternalCallback> callbacks_;
104107
std::mutex callbacks_mutex_;
105108
mutable std::mutex provider_change_mutex_;
106-
std::mutex init_mutex_;
107-
InitializationStatus initialization_status_ = InitializationStatus::NotInitialized;
108109
bool is_enabled_;
109110
UCHAR level_;
110111
ULONGLONG keyword_;
@@ -133,8 +134,8 @@ class EtwRegistrationManager {
133134
Severity MapLevelToSeverity() { return Severity::kFATAL; }
134135
uint64_t Keyword() const { return 0; }
135136
HRESULT Status() const { return 0; }
136-
void RegisterInternalCallback(const EtwInternalCallback& callback) {}
137-
void UnregisterInternalCallback(const EtwInternalCallback& callback) {}
137+
void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) {}
138+
void UnregisterInternalCallback(const std::string& cb_key) {}
138139

139140
private:
140141
EtwRegistrationManager() = default;

0 commit comments

Comments
 (0)