Skip to content

Commit 4dc12f4

Browse files
amitksingh1490ForgeCodeautofix-ci[bot]
authored
fix: auto-refresh OAuth token for models API calls (#2018)
Co-authored-by: ForgeCode <[email protected]> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent a38ed15 commit 4dc12f4

File tree

6 files changed

+108
-84
lines changed

6 files changed

+108
-84
lines changed

crates/forge_api/src/forge_api.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::path::{Path, PathBuf};
22
use std::sync::Arc;
33
use std::time::Duration;
44

5-
use anyhow::{Context, Result};
5+
use anyhow::Result;
66
use forge_app::dto::ToolsOverview;
77
use forge_app::{
88
AgentProviderResolver, AgentRegistry, AppConfigService, AuthService, CommandInfra,
@@ -67,14 +67,7 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + AppConf
6767
}
6868

6969
async fn get_models(&self) -> Result<Vec<Model>> {
70-
Ok(self
71-
.services
72-
.models(
73-
self.get_default_provider()
74-
.await
75-
.context("Failed to fetch models")?,
76-
)
77-
.await?)
70+
self.app().get_models().await
7871
}
7972
async fn get_agents(&self) -> Result<Vec<Agent>> {
8073
self.services.get_agents().await

crates/forge_app/src/agent_provider_resolver.rs

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,36 +39,6 @@ where
3939
self.0.get_default_provider().await?
4040
};
4141

42-
// Check if credential needs refresh (5 minute buffer before expiry)
43-
if let Some(credential) = &provider.credential {
44-
let buffer = chrono::Duration::minutes(5);
45-
46-
if credential.needs_refresh(buffer) {
47-
for auth_method in &provider.auth_methods {
48-
match auth_method {
49-
forge_domain::AuthMethod::OAuthDevice(_)
50-
| forge_domain::AuthMethod::OAuthCode(_) => {
51-
match self
52-
.0
53-
.refresh_provider_credential(&provider, auth_method.clone())
54-
.await
55-
{
56-
Ok(refreshed_credential) => {
57-
let mut updated_provider = provider.clone();
58-
updated_provider.credential = Some(refreshed_credential);
59-
return Ok(updated_provider);
60-
}
61-
Err(_) => {
62-
return Ok(provider);
63-
}
64-
}
65-
}
66-
forge_domain::AuthMethod::ApiKey => {}
67-
}
68-
}
69-
}
70-
}
71-
7242
Ok(provider)
7343
}
7444

crates/forge_app/src/app.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ use crate::changed_files::ChangedFiles;
1212
use crate::dto::ToolsOverview;
1313
use crate::init_conversation_metrics::InitConversationMetrics;
1414
use crate::orch::Orchestrator;
15-
use crate::services::{AgentRegistry, CustomInstructionsService, TemplateService};
15+
use crate::services::{
16+
AgentRegistry, CustomInstructionsService, ProviderAuthService, TemplateService,
17+
};
1618
use crate::set_conversation_id::SetConversationId;
1719
use crate::system_prompt::SystemPrompt;
1820
use crate::tool_registry::ToolRegistry;
@@ -104,6 +106,12 @@ impl<S: Services> ForgeApp<S> {
104106
let agent_provider = agent_provider_resolver
105107
.get_provider(Some(agent.id.clone()))
106108
.await?;
109+
let agent_provider = self
110+
.services
111+
.provider_auth_service()
112+
.refresh_provider_credential(agent_provider)
113+
.await?;
114+
107115
let models = services.models(agent_provider).await?;
108116

109117
// Get system and mcp tool definitions and resolve them for the agent
@@ -265,6 +273,20 @@ impl<S: Services> ForgeApp<S> {
265273
pub async fn list_tools(&self) -> Result<ToolsOverview> {
266274
self.tool_registry.tools_overview().await
267275
}
276+
277+
/// Gets available models for the default provider with automatic credential
278+
/// refresh.
279+
pub async fn get_models(&self) -> Result<Vec<Model>> {
280+
let agent_provider_resolver = AgentProviderResolver::new(self.services.clone());
281+
let provider = agent_provider_resolver.get_provider(None).await?;
282+
let provider = self
283+
.services
284+
.provider_auth_service()
285+
.refresh_provider_credential(provider)
286+
.await?;
287+
288+
self.services.models(provider).await
289+
}
268290
pub async fn login(&self, init_auth: &InitAuth) -> Result<()> {
269291
self.authenticator.login(init_auth).await
270292
}

crates/forge_app/src/git_app.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ where
250250
agent_provider_resolver.get_provider(agent_id.clone()),
251251
agent_provider_resolver.get_model(agent_id)
252252
)?;
253-
253+
let provider = self.services.refresh_provider_credential(provider).await?;
254254
// Build git diff content with optional truncation notice
255255
// Build user message using Element
256256
let mut user_message = Element::new("user_message")

crates/forge_app/src/services.rs

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ use std::time::Duration;
44
use bytes::Bytes;
55
use derive_setters::Setters;
66
use forge_domain::{
7-
AgentId, AnyProvider, Attachment, AuthContextRequest, AuthContextResponse, AuthCredential,
8-
AuthMethod, ChatCompletionMessage, CommandOutput, Context, Conversation, ConversationId,
9-
Environment, File, Image, InitAuth, LoginInfo, McpConfig, McpServers, Model, ModelId,
10-
PatchOperation, Provider, ProviderId, ResultStream, Scope, Template, ToolCallFull, ToolOutput,
11-
Workflow,
7+
AgentId, AnyProvider, Attachment, AuthContextRequest, AuthContextResponse, AuthMethod,
8+
ChatCompletionMessage, CommandOutput, Context, Conversation, ConversationId, Environment, File,
9+
Image, InitAuth, LoginInfo, McpConfig, McpServers, Model, ModelId, PatchOperation, Provider,
10+
ProviderId, ResultStream, Scope, Template, ToolCallFull, ToolOutput, Workflow,
1211
};
1312
use merge::Merge;
1413
use reqwest::Response;
@@ -474,11 +473,16 @@ pub trait ProviderAuthService: Send + Sync {
474473
context: AuthContextResponse,
475474
timeout: Duration,
476475
) -> anyhow::Result<()>;
476+
477+
/// Refreshes provider credentials if they're about to expire.
478+
/// Checks if credential needs refresh (5 minute buffer before expiry),
479+
/// iterates through provider's auth methods, and attempts to refresh.
480+
/// Returns the provider with updated credentials, or original if refresh
481+
/// fails or isn't needed.
477482
async fn refresh_provider_credential(
478483
&self,
479-
provider: &Provider<Url>,
480-
method: AuthMethod,
481-
) -> anyhow::Result<AuthCredential>;
484+
provider: Provider<Url>,
485+
) -> anyhow::Result<Provider<Url>>;
482486
}
483487

484488
/// Core app trait providing access to services and repositories.
@@ -986,11 +990,10 @@ impl<I: Services> ProviderAuthService for I {
986990
}
987991
async fn refresh_provider_credential(
988992
&self,
989-
provider: &Provider<Url>,
990-
method: AuthMethod,
991-
) -> anyhow::Result<AuthCredential> {
993+
provider: Provider<Url>,
994+
) -> anyhow::Result<Provider<Url>> {
992995
self.provider_auth_service()
993-
.refresh_provider_credential(provider, method)
996+
.refresh_provider_credential(provider)
994997
.await
995998
}
996999
}

crates/forge_services/src/provider_auth.rs

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ use std::time::Duration;
33

44
use forge_app::{AuthStrategy, ProviderAuthService, StrategyFactory};
55
use forge_domain::{
6-
AuthContextRequest, AuthContextResponse, AuthCredential, AuthMethod, Provider, ProviderId,
7-
ProviderRepository,
6+
AuthContextRequest, AuthContextResponse, AuthMethod, Provider, ProviderId, ProviderRepository,
87
};
98

109
/// Forge Provider Authentication Service
@@ -103,37 +102,74 @@ where
103102
self.infra.upsert_credential(credential).await
104103
}
105104

106-
/// Refresh provider credential
105+
/// Refreshes provider credentials if they're about to expire.
106+
/// Checks if credential needs refresh (5 minute buffer before expiry),
107+
/// iterates through provider's auth methods, and attempts to refresh.
108+
/// Returns the provider with updated credentials, or original if refresh
109+
/// fails or isn't needed.
107110
async fn refresh_provider_credential(
108111
&self,
109-
provider: &Provider<url::Url>,
110-
auth_method: AuthMethod,
111-
) -> anyhow::Result<AuthCredential> {
112-
// Get existing credential
113-
let credential = self
114-
.infra
115-
.get_credential(&provider.id)
116-
.await?
117-
.ok_or_else(|| forge_domain::Error::ProviderNotAvailable {
118-
provider: provider.id.clone(),
119-
})?;
120-
121-
// Get required params (only used for API key, but needed for factory)
122-
let required_params = if matches!(auth_method, AuthMethod::ApiKey) {
123-
provider.url_params.clone()
124-
} else {
125-
vec![]
126-
};
127-
128-
// Create strategy and refresh credential
129-
let strategy =
130-
self.infra
131-
.create_auth_strategy(provider.id.clone(), auth_method, required_params)?;
132-
let refreshed = strategy.refresh(&credential).await?;
133-
134-
// Store refreshed credential
135-
self.infra.upsert_credential(refreshed.clone()).await?;
112+
mut provider: Provider<url::Url>,
113+
) -> anyhow::Result<Provider<url::Url>> {
114+
// Check if credential needs refresh (5 minute buffer before expiry)
115+
if let Some(credential) = &provider.credential {
116+
let buffer = chrono::Duration::minutes(5);
117+
118+
if credential.needs_refresh(buffer) {
119+
// Iterate through auth methods and try to refresh
120+
for auth_method in &provider.auth_methods {
121+
match auth_method {
122+
AuthMethod::OAuthDevice(_) | AuthMethod::OAuthCode(_) => {
123+
// Get existing credential
124+
let existing_credential =
125+
self.infra.get_credential(&provider.id).await?.ok_or_else(
126+
|| forge_domain::Error::ProviderNotAvailable {
127+
provider: provider.id.clone(),
128+
},
129+
)?;
130+
131+
// Get required params (only used for API key, but needed for factory)
132+
let required_params = if matches!(auth_method, AuthMethod::ApiKey) {
133+
provider.url_params.clone()
134+
} else {
135+
vec![]
136+
};
137+
138+
// Create strategy and refresh credential
139+
if let Ok(strategy) = self.infra.create_auth_strategy(
140+
provider.id.clone(),
141+
auth_method.clone(),
142+
required_params,
143+
) {
144+
match strategy.refresh(&existing_credential).await {
145+
Ok(refreshed) => {
146+
// Store refreshed credential
147+
if self
148+
.infra
149+
.upsert_credential(refreshed.clone())
150+
.await
151+
.is_err()
152+
{
153+
continue;
154+
}
155+
156+
// Update provider with refreshed credential
157+
provider.credential = Some(refreshed);
158+
break; // Success, stop trying other methods
159+
}
160+
Err(_) => {
161+
// If refresh fails, continue with
162+
// existing credentials
163+
}
164+
}
165+
}
166+
}
167+
_ => {}
168+
}
169+
}
170+
}
171+
}
136172

137-
Ok(refreshed)
173+
Ok(provider)
138174
}
139175
}

0 commit comments

Comments
 (0)