Skip to content

Commit 733e123

Browse files
lmproviders timeout as well
1 parent f101933 commit 733e123

File tree

5 files changed

+108
-49
lines changed

5 files changed

+108
-49
lines changed

src/cascadia/QueryExtension/AzureLLMProvider.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -146,35 +146,48 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
146146
try
147147
{
148148
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
149+
150+
// if the caller cancels this operation, make sure to cancel the http request as well
149151
auto cancellationToken{ co_await winrt::get_cancellation_token() };
150152
cancellationToken.callback([sendRequestOperation] {
151153
sendRequestOperation.Cancel();
152154
});
153-
const auto response{ co_await sendRequestOperation };
154-
// Parse out the suggestion from the response
155-
const auto string{ co_await response.Content().ReadAsStringAsync() };
156-
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
157-
if (jsonResult.HasKey(errorString))
158-
{
159-
const auto errorObject = jsonResult.GetNamedObject(errorString);
160-
message = errorObject.GetNamedString(messageString);
161-
errorType = ErrorTypes::FromProvider;
162-
}
163-
else
155+
156+
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
164157
{
165-
if (_verifyModelIsValidHelper(jsonResult))
158+
// Parse out the suggestion from the response
159+
const auto response = sendRequestOperation.GetResults();
160+
const auto string{ co_await response.Content().ReadAsStringAsync() };
161+
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
162+
if (jsonResult.HasKey(errorString))
166163
{
167-
const auto choices = jsonResult.GetNamedArray(L"choices");
168-
const auto firstChoice = choices.GetAt(0).GetObject();
169-
const auto messageObject = firstChoice.GetNamedObject(messageString);
170-
message = messageObject.GetNamedString(contentString);
164+
const auto errorObject = jsonResult.GetNamedObject(errorString);
165+
message = errorObject.GetNamedString(messageString);
166+
errorType = ErrorTypes::FromProvider;
171167
}
172168
else
173169
{
174-
message = RS_(L"InvalidModelMessage");
175-
errorType = ErrorTypes::InvalidModel;
170+
if (_verifyModelIsValidHelper(jsonResult))
171+
{
172+
const auto choices = jsonResult.GetNamedArray(L"choices");
173+
const auto firstChoice = choices.GetAt(0).GetObject();
174+
const auto messageObject = firstChoice.GetNamedObject(messageString);
175+
message = messageObject.GetNamedString(contentString);
176+
}
177+
else
178+
{
179+
message = RS_(L"InvalidModelMessage");
180+
errorType = ErrorTypes::InvalidModel;
181+
}
176182
}
177183
}
184+
else
185+
{
186+
// if the http request takes too long, cancel the http request and return an error
187+
sendRequestOperation.Cancel();
188+
message = RS_(L"UnknownErrorMessage");
189+
errorType = ErrorTypes::Unknown;
190+
}
178191
}
179192
catch (...)
180193
{

src/cascadia/QueryExtension/ExtensionPalette.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
160160
if (_lmProvider)
161161
{
162162
const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy);
163-
if (asyncOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
163+
if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed)
164164
{
165165
result = asyncOperation.GetResults();
166166
}

src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
247247

248248
// Make sure we are on the background thread for the http request
249249
auto strongThis = get_strong();
250+
250251
co_await winrt::resume_background();
252+
auto cancellationToken{ co_await winrt::get_cancellation_token() };
251253

252254
for (bool refreshAttempted = false;;)
253255
{
@@ -276,24 +278,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
276278
};
277279

278280
// Send the request
279-
const auto jsonResultOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
280-
auto cancellationToken{ co_await winrt::get_cancellation_token() };
281-
cancellationToken.callback([jsonResultOperation] {
282-
jsonResultOperation.Cancel();
281+
const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
282+
283+
// if the caller cancels this operation, make sure to cancel the http request as well
284+
cancellationToken.callback([sendRequestOperation] {
285+
sendRequestOperation.Cancel();
283286
});
284-
const auto jsonResult = co_await jsonResultOperation;
285-
if (jsonResult.HasKey(errorKey))
287+
288+
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
286289
{
287-
const auto errorObject = jsonResult.GetNamedObject(errorKey);
288-
message = errorObject.GetNamedString(messageKey);
289-
errorType = ErrorTypes::FromProvider;
290+
// Parse out the suggestion from the response
291+
const auto jsonResult = sendRequestOperation.GetResults();
292+
if (jsonResult.HasKey(errorKey))
293+
{
294+
const auto errorObject = jsonResult.GetNamedObject(errorKey);
295+
message = errorObject.GetNamedString(messageKey);
296+
errorType = ErrorTypes::FromProvider;
297+
}
298+
else
299+
{
300+
const auto choices = jsonResult.GetNamedArray(L"ayy");
301+
const auto firstChoice = choices.GetAt(0).GetObject();
302+
const auto messageObject = firstChoice.GetNamedObject(messageKey);
303+
message = messageObject.GetNamedString(contentKey);
304+
}
290305
}
291306
else
292307
{
293-
const auto choices = jsonResult.GetNamedArray(choicesKey);
294-
const auto firstChoice = choices.GetAt(0).GetObject();
295-
const auto messageObject = firstChoice.GetNamedObject(messageKey);
296-
message = messageObject.GetNamedString(contentKey);
308+
// if the http request takes too long, cancel the http request and return an error
309+
sendRequestOperation.Cancel();
310+
message = RS_(L"UnknownErrorMessage");
311+
errorType = ErrorTypes::Unknown;
297312
}
298313
break;
299314
}
@@ -310,8 +325,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
310325
break;
311326
}
312327

313-
co_await _refreshAuthTokens();
314-
refreshAttempted = true;
328+
const auto refreshTokensAction = _refreshAuthTokens();
329+
cancellationToken.callback([refreshTokensAction] {
330+
refreshTokensAction.Cancel();
331+
});
332+
// allow up to 10 seconds for reauthentication
333+
if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed)
334+
{
335+
refreshAttempted = true;
336+
}
337+
else
338+
{
339+
// if the refresh action takes too long, cancel it and return an error
340+
refreshTokensAction.Cancel();
341+
message = RS_(L"UnknownErrorMessage");
342+
errorType = ErrorTypes::Unknown;
343+
break;
344+
}
315345
}
316346

317347
// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
@@ -339,7 +369,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
339369

340370
try
341371
{
342-
const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
372+
const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
373+
auto cancellationToken{ co_await winrt::get_cancellation_token() };
374+
cancellationToken.callback([reAuthOperation] {
375+
reAuthOperation.Cancel();
376+
});
377+
const auto jsonResult{ co_await reAuthOperation };
343378

344379
_authToken = jsonResult.GetNamedString(accessTokenKey);
345380
_refreshToken = jsonResult.GetNamedString(refreshTokenKey);
@@ -371,7 +406,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
371406
sendRequestOperation.Cancel();
372407
});
373408
const auto response{ co_await sendRequestOperation };
374-
_lastRequest = sendRequestOperation;
375409
const auto string{ co_await response.Content().ReadAsStringAsync() };
376410
_lastResponse = string;
377411
const auto jsonResult{ WDJ::JsonObject::Parse(string) };

src/cascadia/QueryExtension/GithubCopilotLLMProvider.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
5151
winrt::Windows::Web::Http::HttpClient _httpClient{ nullptr };
5252
IBrandingData _brandingData{ winrt::make<GithubCopilotBranding>() };
5353
winrt::hstring _lastResponse;
54-
winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Web::Http::HttpResponseMessage, winrt::Windows::Web::Http::HttpProgress> _lastRequest{ nullptr };
5554

5655
Extension::IContext _context;
5756

src/cascadia/QueryExtension/OpenAILLMProvider.cpp

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,39 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
101101
try
102102
{
103103
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
104+
105+
// if the caller cancels this operation, make sure to cancel the http request as well
104106
auto cancellationToken{ co_await winrt::get_cancellation_token() };
105107
cancellationToken.callback([sendRequestOperation] {
106108
sendRequestOperation.Cancel();
107109
});
108-
const auto response{ co_await sendRequestOperation };
109-
// Parse out the suggestion from the response
110-
const auto string{ co_await response.Content().ReadAsStringAsync() };
111-
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
112-
if (jsonResult.HasKey(L"error"))
110+
111+
if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
113112
{
114-
const auto errorObject = jsonResult.GetNamedObject(L"error");
115-
message = errorObject.GetNamedString(L"message");
116-
errorType = ErrorTypes::FromProvider;
113+
// Parse out the suggestion from the response
114+
const auto response = sendRequestOperation.GetResults();
115+
const auto string{ co_await response.Content().ReadAsStringAsync() };
116+
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
117+
if (jsonResult.HasKey(L"error"))
118+
{
119+
const auto errorObject = jsonResult.GetNamedObject(L"error");
120+
message = errorObject.GetNamedString(L"message");
121+
errorType = ErrorTypes::FromProvider;
122+
}
123+
else
124+
{
125+
const auto choices = jsonResult.GetNamedArray(L"choices");
126+
const auto firstChoice = choices.GetAt(0).GetObject();
127+
const auto messageObject = firstChoice.GetNamedObject(L"message");
128+
message = messageObject.GetNamedString(L"content");
129+
}
117130
}
118131
else
119132
{
120-
const auto choices = jsonResult.GetNamedArray(L"choices");
121-
const auto firstChoice = choices.GetAt(0).GetObject();
122-
const auto messageObject = firstChoice.GetNamedObject(L"message");
123-
message = messageObject.GetNamedString(L"content");
133+
// if the http request takes too long, cancel the http request and return an error
134+
sendRequestOperation.Cancel();
135+
message = RS_(L"UnknownErrorMessage");
136+
errorType = ErrorTypes::Unknown;
124137
}
125138
}
126139
catch (...)

0 commit comments

Comments
 (0)