Skip to content

Commit ac34c63

Browse files
authored
Update pipeline import to create a new version instead of patching (#291)
* Update pipeline import to create a new version instead of patching
* Take draft status into account when importing
* Change logging and refactor
* Update version_response handling / status check
* Update mock in import-with-overwrite unit test to account for implementation change
* Fix linting issues
1 parent 27a6034 commit ac34c63

File tree

2 files changed

+99
-29
lines changed

2 files changed

+99
-29
lines changed

deepset_cloud_sdk/_service/pipeline_service.py

Lines changed: 36 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from io import StringIO
88
from typing import Any, List, Optional, Protocol, runtime_checkable
99

10+
import httpx
1011
import structlog
1112
from httpx import Response
1213
from pydantic import BaseModel
@@ -376,28 +377,47 @@ async def _create_index(self, name: str, pipeline_yaml: str) -> Response:
376377
async def _overwrite_pipeline(self, name: str, pipeline_yaml: str) -> Response:
377378
"""Overwrite a pipeline in deepset AI Platform.
378379
379-
:param name: Name of the pipeline.
380-
:param pipeline_yaml: Generated pipeline YAML string.
380+
Behavior:
381+
- First try to fetch the latest version.
382+
- If the pipeline doesn't exist (404), create it instead.
383+
- If the latest version is a draft (is_draft == True), PATCH that version.
384+
- Otherwise, create a new version via POST /pipelines/{name}/versions.
381385
"""
382-
# First get the (last) version id if available
383-
version_response = await self._api.get(
384-
workspace_name=self._workspace_name, endpoint=f"pipelines/{name}/versions"
385-
)
386-
387-
# If pipeline doesn't exist (404), create it instead
388-
if version_response.status_code == HTTPStatus.NOT_FOUND:
389-
logger.debug(f"Pipeline {name} not found, creating new pipeline.")
390-
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
391-
else:
392-
version_body = version_response.json()
393-
version_id = version_body["data"][0]["version_id"]
394-
response = await self._api.patch(
386+
# Fetch versions
387+
try:
388+
version_response = await self._api.get(
389+
workspace_name=self._workspace_name,
390+
endpoint=f"pipelines/{name}/versions",
391+
)
392+
version_response.raise_for_status()
393+
except httpx.HTTPStatusError as e:
394+
if e.response.status_code != HTTPStatus.NOT_FOUND:
395+
raise
396+
# the pipeline does not exist, let's create it.
397+
logger.debug(f"Pipeline '{name}' not found, creating new pipeline.")
398+
return await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
399+
400+
version_body = version_response.json()
401+
latest_version = version_body["data"][0]
402+
version_id = latest_version["version_id"]
403+
is_draft = latest_version.get("is_draft", False)
404+
405+
if is_draft:
406+
# Patch existing draft version
407+
logger.debug(f"Patching existing draft version '{version_id}' of pipeline '{name}'.")
408+
return await self._api.patch(
395409
workspace_name=self._workspace_name,
396410
endpoint=f"pipelines/{name}/versions/{version_id}",
397411
json={"config_yaml": pipeline_yaml},
398412
)
399413

400-
return response
414+
# Create a new version
415+
logger.debug(f"Latest version '{version_id}' of pipeline '{name}' is not a draft, creating new version.")
416+
return await self._api.post(
417+
workspace_name=self._workspace_name,
418+
endpoint=f"pipelines/{name}/versions",
419+
json={"config_yaml": pipeline_yaml},
420+
)
401421

402422
async def _create_pipeline(self, name: str, pipeline_yaml: str) -> Response:
403423
"""Create a pipeline in deepset AI Platform.

tests/unit/service/test_pipeline_service.py

Lines changed: 63 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77
from unittest.mock import AsyncMock, Mock
88

9+
import httpx
910
import pytest
1011
from haystack import AsyncPipeline, Pipeline
1112
from haystack.components.converters import CSVToDocument, TextFileToDocument
@@ -14,10 +15,7 @@
1415
from httpx import Response
1516
from structlog.testing import capture_logs
1617

17-
from deepset_cloud_sdk._service.pipeline_service import (
18-
DeepsetValidationError,
19-
PipelineService,
20-
)
18+
from deepset_cloud_sdk._service.pipeline_service import DeepsetValidationError, PipelineService
2119
from deepset_cloud_sdk.models import (
2220
IndexConfig,
2321
IndexInputs,
@@ -587,7 +585,7 @@ async def test_import_index_with_overwrite_fallback_to_create(
587585
async def test_import_pipeline_with_overwrite_true(
588586
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
589587
) -> None:
590-
"""Test importing a pipeline with overwrite=True uses PUT endpoint."""
588+
"""Test importing a pipeline with overwrite=True patches latest draft version."""
591589
config = PipelineConfig(
592590
name="test_pipeline_overwrite",
593591
inputs=PipelineInputs(query=["retriever.query"]),
@@ -600,24 +598,25 @@ async def test_import_pipeline_with_overwrite_true(
600598
validation_response = Mock(spec=Response)
601599
validation_response.status_code = HTTPStatus.NO_CONTENT.value
602600

603-
# Mock successful versions response
601+
# Mock successful versions response, latest version is a draft
604602
versions_response = Mock(status_code=HTTPStatus.OK.value)
605603
versions_response.json.return_value = {
606-
"data": [{"version_id": "42abcd"}],
604+
"data": [{"version_id": "42abcd", "is_draft": True}],
607605
}
608606

609-
# Mock successful overwrite response
607+
# Mock successful overwrite (PATCH) response
610608
overwrite_response = Mock(spec=Response)
611609
overwrite_response.status_code = HTTPStatus.OK.value
612610

613611
mock_api.post.return_value = validation_response
614612
mock_api.get.return_value = versions_response
615-
mock_api.put.return_value = overwrite_response
613+
mock_api.patch.return_value = overwrite_response
616614

617615
await pipeline_service.import_async(index_pipeline, config)
618616

619-
# Should call validation endpoint first, then overwrite endpoint
617+
# validation + GET versions + PATCH draft version
620618
assert mock_api.post.call_count == 1
619+
assert mock_api.get.call_count == 1
621620
assert mock_api.patch.call_count == 1
622621

623622
# Check validation call
@@ -627,11 +626,60 @@ async def test_import_pipeline_with_overwrite_true(
627626
# When overwrite=True, name should be excluded from validation payload
628627
assert "name" not in validation_call.kwargs["json"]
629628

630-
# Check overwrite call
629+
# Check PATCH call
631630
overwrite_call = mock_api.patch.call_args_list[0]
632631
assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions/42abcd"
633632
assert "config_yaml" in overwrite_call.kwargs["json"]
634633

634+
@pytest.mark.asyncio
635+
async def test_import_pipeline_with_overwrite_true_creates_new_version_when_not_draft(
636+
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
637+
) -> None:
638+
"""Test importing a pipeline with overwrite=True creates a new version when latest version is not draft."""
639+
config = PipelineConfig(
640+
name="test_pipeline_overwrite",
641+
inputs=PipelineInputs(query=["retriever.query"]),
642+
outputs=PipelineOutputs(documents="meta_ranker.documents"),
643+
strict_validation=False,
644+
overwrite=True,
645+
)
646+
647+
# Mock successful validation response
648+
validation_response = Mock(spec=Response)
649+
validation_response.status_code = HTTPStatus.NO_CONTENT.value
650+
651+
# Mock versions response, latest version is NOT a draft
652+
versions_response = Mock(status_code=HTTPStatus.OK.value)
653+
versions_response.json.return_value = {
654+
"data": [{"version_id": "42abcd", "is_draft": False}],
655+
}
656+
657+
# Mock successful "create new version" response
658+
new_version_response = Mock(spec=Response)
659+
new_version_response.status_code = HTTPStatus.CREATED.value
660+
661+
# First POST is validation, second POST is "create new version"
662+
mock_api.post.side_effect = [validation_response, new_version_response]
663+
mock_api.get.return_value = versions_response
664+
665+
await pipeline_service.import_async(index_pipeline, config)
666+
667+
# validation + GET versions + POST versions (new version)
668+
assert mock_api.post.call_count == 2
669+
assert mock_api.get.call_count == 1
670+
assert mock_api.patch.call_count == 0
671+
672+
# Check validation call
673+
validation_call = mock_api.post.call_args_list[0]
674+
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
675+
assert "query_yaml" in validation_call.kwargs["json"]
676+
assert "name" not in validation_call.kwargs["json"]
677+
678+
# Check create-version POST call
679+
create_version_call = mock_api.post.call_args_list[1]
680+
assert create_version_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions"
681+
assert "config_yaml" in create_version_call.kwargs["json"]
682+
635683
@pytest.mark.asyncio
636684
async def test_import_pipeline_with_overwrite_fallback_to_create(
637685
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
@@ -653,6 +701,9 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(
653701
# Mock 404 response for GET (resource not found)
654702
not_found_response = Mock(spec=Response)
655703
not_found_response.status_code = HTTPStatus.NOT_FOUND.value
704+
not_found_response.raise_for_status.side_effect = httpx.HTTPStatusError(
705+
"Not Found", request=Mock(), response=not_found_response
706+
)
656707

657708
# Mock successful creation response
658709
create_response = Mock(spec=Response)
@@ -663,15 +714,14 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(
663714

664715
await pipeline_service.import_async(index_pipeline, config)
665716

666-
# Should call validation endpoint, then GET (which returns 404), then POST to create
717+
# validation + GET (404) + POST create
667718
assert mock_api.post.call_count == 2
668719
assert mock_api.get.call_count == 1
669720

670721
# Check validation call
671722
validation_call = mock_api.post.call_args_list[0]
672723
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
673724
assert "query_yaml" in validation_call.kwargs["json"]
674-
# When overwrite=True, name should be excluded from validation payload
675725
assert "name" not in validation_call.kwargs["json"]
676726

677727
# Check GET versions attempt

0 commit comments

Comments
 (0)