Skip to content

Commit 8e2fb2d

Browse files
authored
feat(backend): Speed up graph create/update (#10025)
- Resolves #10024 Caching the repeated DB calls by the graph lifecycle hooks significantly speeds up graph update/create calls with many authenticated blocks (~300ms saved per authenticated block) ### Changes 🏗️ - Add and use `IntegrationCredentialsManager.cached_getter(user_id)` in lifecycle hooks - Split `refresh_if_needed(..)` method out of `IntegrationCredentialsManager.get(..)` - Simplify interface of lifecycle hooks: change `get_credentials` parameter to `user_id` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Save a graph with nodes with credentials
1 parent 767d2f2 commit 8e2fb2d

File tree

4 files changed

+62
-76
lines changed

4 files changed

+62
-76
lines changed

autogpt_platform/backend/backend/integrations/creds_manager.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import logging
22
from contextlib import contextmanager
33
from datetime import datetime
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Callable
55

66
from autogpt_libs.utils.synchronize import RedisKeyedMutex
77
from redis.lock import Lock as RedisLock
88

99
from backend.data import redis
10-
from backend.data.model import Credentials
10+
from backend.data.model import Credentials, OAuth2Credentials
1111
from backend.integrations.credentials_store import IntegrationCredentialsStore
1212
from backend.integrations.oauth import HANDLERS_BY_NAME
1313
from backend.integrations.providers import ProviderName
@@ -78,25 +78,7 @@ def get(
7878
f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
7979
f"current time is {datetime.now()}"
8080
)
81-
82-
with self._locked(user_id, credentials_id, "refresh"):
83-
oauth_handler = _get_provider_oauth_handler(credentials.provider)
84-
if oauth_handler.needs_refresh(credentials):
85-
logger.debug(
86-
f"Refreshing '{credentials.provider}' "
87-
f"credentials #{credentials.id}"
88-
)
89-
_lock = None
90-
if lock:
91-
# Wait until the credentials are no longer in use anywhere
92-
_lock = self._acquire_lock(user_id, credentials_id)
93-
94-
fresh_credentials = oauth_handler.refresh_tokens(credentials)
95-
self.store.update_creds(user_id, fresh_credentials)
96-
if _lock and _lock.locked() and _lock.owned():
97-
_lock.release()
98-
99-
credentials = fresh_credentials
81+
credentials = self.refresh_if_needed(user_id, credentials, lock)
10082
else:
10183
logger.debug(f"Credentials #{credentials.id} never expire")
10284

@@ -121,6 +103,50 @@ def acquire(
121103
)
122104
return credentials, lock
123105

106+
def cached_getter(self, user_id: str) -> Callable[[str], "Credentials | None"]:
107+
all_credentials = None
108+
109+
def get_credentials(creds_id: str) -> "Credentials | None":
110+
nonlocal all_credentials
111+
if not all_credentials:
112+
# Fetch credentials on first necessity
113+
all_credentials = self.store.get_all_creds(user_id)
114+
115+
credential = next((c for c in all_credentials if c.id == creds_id), None)
116+
if not credential:
117+
return None
118+
if credential.type != "oauth2" or not credential.access_token_expires_at:
119+
# Credential doesn't expire
120+
return credential
121+
122+
# Credential is OAuth2 credential and has expiration timestamp
123+
return self.refresh_if_needed(user_id, credential)
124+
125+
return get_credentials
126+
127+
def refresh_if_needed(
128+
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
129+
) -> OAuth2Credentials:
130+
with self._locked(user_id, credentials.id, "refresh"):
131+
oauth_handler = _get_provider_oauth_handler(credentials.provider)
132+
if oauth_handler.needs_refresh(credentials):
133+
logger.debug(
134+
f"Refreshing '{credentials.provider}' "
135+
f"credentials #{credentials.id}"
136+
)
137+
_lock = None
138+
if lock:
139+
# Wait until the credentials are no longer in use anywhere
140+
_lock = self._acquire_lock(user_id, credentials.id)
141+
142+
fresh_credentials = oauth_handler.refresh_tokens(credentials)
143+
self.store.update_creds(user_id, fresh_credentials)
144+
if _lock and _lock.locked() and _lock.owned():
145+
_lock.release()
146+
147+
credentials = fresh_credentials
148+
return credentials
149+
124150
def update(self, user_id: str, updated: Credentials) -> None:
125151
with self._locked(user_id, updated.id):
126152
self.store.update_creds(user_id, updated)

autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import logging
2-
from typing import TYPE_CHECKING, Callable, Optional, cast
2+
from typing import TYPE_CHECKING, Optional, cast
33

44
from backend.data.block import BlockSchema, BlockWebhookConfig
55
from backend.data.graph import set_node_webhook
6+
from backend.integrations.creds_manager import IntegrationCredentialsManager
67
from backend.integrations.webhooks import get_webhook_manager, supports_webhooks
78

89
if TYPE_CHECKING:
@@ -12,21 +13,17 @@
1213
from ._base import BaseWebhooksManager
1314

1415
logger = logging.getLogger(__name__)
16+
credentials_manager = IntegrationCredentialsManager()
1517

1618

17-
async def on_graph_activate(
18-
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
19-
):
19+
async def on_graph_activate(graph: "GraphModel", user_id: str):
2020
"""
2121
Hook to be called when a graph is activated/created.
2222
2323
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
2424
this hook calls `on_node_activate` on all nodes in this graph.
25-
26-
Params:
27-
get_credentials: `credentials_id` -> Credentials
2825
"""
29-
# Compare nodes in new_graph_version with previous_graph_version
26+
get_credentials = credentials_manager.cached_getter(user_id)
3027
updated_nodes = []
3128
for new_node in graph.nodes:
3229
block_input_schema = cast(BlockSchema, new_node.block.input_schema)
@@ -56,18 +53,14 @@ async def on_graph_activate(
5653
return graph
5754

5855

59-
async def on_graph_deactivate(
60-
graph: "GraphModel", get_credentials: Callable[[str], "Credentials | None"]
61-
):
56+
async def on_graph_deactivate(graph: "GraphModel", user_id: str):
6257
"""
6358
Hook to be called when a graph is deactivated/deleted.
6459
6560
⚠️ Assuming node entities are not re-used between graph versions, ⚠️
6661
this hook calls `on_node_deactivate` on all nodes in `graph`.
67-
68-
Params:
69-
get_credentials: `credentials_id` -> Credentials
7062
"""
63+
get_credentials = credentials_manager.cached_getter(user_id)
7164
updated_nodes = []
7265
for node in graph.nodes:
7366
block_input_schema = cast(BlockSchema, node.block.input_schema)

autogpt_platform/backend/backend/server/routers/v1.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from collections import defaultdict
44
from datetime import datetime
5-
from typing import TYPE_CHECKING, Annotated, Any, Sequence
5+
from typing import Annotated, Any, Sequence
66

77
import pydantic
88
import stripe
@@ -60,7 +60,6 @@
6060
from backend.executor import scheduler
6161
from backend.executor import utils as execution_utils
6262
from backend.executor.utils import create_execution_queue_config
63-
from backend.integrations.creds_manager import IntegrationCredentialsManager
6463
from backend.integrations.webhooks.graph_lifecycle_hooks import (
6564
on_graph_activate,
6665
on_graph_deactivate,
@@ -78,9 +77,6 @@
7877
from backend.util.service import get_service_client
7978
from backend.util.settings import Settings
8079

81-
if TYPE_CHECKING:
82-
from backend.data.model import Credentials
83-
8480

8581
@thread_cached
8682
def execution_scheduler_client() -> scheduler.SchedulerClient:
@@ -101,7 +97,6 @@ def execution_event_bus() -> AsyncRedisExecutionEventBus:
10197

10298
settings = Settings()
10399
logger = logging.getLogger(__name__)
104-
integration_creds_manager = IntegrationCredentialsManager()
105100

106101
_user_credit_model = get_user_credit_model()
107102

@@ -466,10 +461,7 @@ async def create_new_graph(
466461
library_db.add_generated_agent_image(graph, library_agent.id)
467462
)
468463

469-
graph = await on_graph_activate(
470-
graph,
471-
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
472-
)
464+
graph = await on_graph_activate(graph, user_id=user_id)
473465
return graph
474466

475467

@@ -480,11 +472,7 @@ async def delete_graph(
480472
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
481473
) -> DeleteGraphResponse:
482474
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
483-
484-
def get_credentials(credentials_id: str) -> "Credentials | None":
485-
return integration_creds_manager.get(user_id, credentials_id)
486-
487-
await on_graph_deactivate(active_version, get_credentials)
475+
await on_graph_deactivate(active_version, user_id=user_id)
488476

489477
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
490478

@@ -521,24 +509,15 @@ async def update_graph(
521509
user_id, graph.id, graph.version
522510
)
523511

524-
def get_credentials(credentials_id: str) -> "Credentials | None":
525-
return integration_creds_manager.get(user_id, credentials_id)
526-
527512
# Handle activation of the new graph first to ensure continuity
528-
new_graph_version = await on_graph_activate(
529-
new_graph_version,
530-
get_credentials=get_credentials,
531-
)
513+
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
532514
# Ensure new version is the only active version
533515
await graph_db.set_graph_active_version(
534516
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
535517
)
536518
if current_active_version:
537519
# Handle deactivation of the previously active version
538-
await on_graph_deactivate(
539-
current_active_version,
540-
get_credentials=get_credentials,
541-
)
520+
await on_graph_deactivate(current_active_version, user_id=user_id)
542521

543522
return new_graph_version
544523

@@ -562,14 +541,8 @@ async def set_graph_active_version(
562541

563542
current_active_graph = await graph_db.get_graph(graph_id, user_id=user_id)
564543

565-
def get_credentials(credentials_id: str) -> "Credentials | None":
566-
return integration_creds_manager.get(user_id, credentials_id)
567-
568544
# Handle activation of the new graph first to ensure continuity
569-
await on_graph_activate(
570-
new_active_graph,
571-
get_credentials=get_credentials,
572-
)
545+
await on_graph_activate(new_active_graph, user_id=user_id)
573546
# Ensure new version is the only active version
574547
await graph_db.set_graph_active_version(
575548
graph_id=graph_id,
@@ -584,10 +557,7 @@ def get_credentials(credentials_id: str) -> "Credentials | None":
584557

585558
if current_active_graph and current_active_graph.version != new_active_version:
586559
# Handle deactivation of the previously active version
587-
await on_graph_deactivate(
588-
current_active_graph,
589-
get_credentials=get_credentials,
590-
)
560+
await on_graph_deactivate(current_active_graph, user_id=user_id)
591561

592562

593563
@v1_router.post(

autogpt_platform/backend/backend/server/v2/library/db.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -736,10 +736,7 @@ async def fork_library_agent(library_agent_id: str, user_id: str):
736736
new_graph = await graph_db.fork_graph(
737737
original_agent.graph_id, original_agent.graph_version, user_id
738738
)
739-
new_graph = await on_graph_activate(
740-
new_graph,
741-
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
742-
)
739+
new_graph = await on_graph_activate(new_graph, user_id=user_id)
743740

744741
# Create a library agent for the new graph
745742
return await create_library_agent(new_graph, user_id)

0 commit comments

Comments
 (0)