Skip to content

Commit 30d0476

Browse files
committed
Take draft status into account when importing
1 parent e5a89c0 commit 30d0476

File tree

3 files changed

+149
-66
lines changed

3 files changed

+149
-66
lines changed

deepset_cloud_sdk/_service/pipeline_service.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -374,32 +374,59 @@ async def _create_index(self, name: str, pipeline_yaml: str) -> Response:
374374
)
375375

376376
async def _overwrite_pipeline(self, name: str, pipeline_yaml: str) -> Response:
377-
"""Overwrite a pipeline in deepset AI Platform by creating a new version.
378-
If creating a new version fails (e.g. pipeline doesn't exist), create the
379-
pipeline instead.
377+
"""Overwrite a pipeline in deepset AI Platform.
378+
379+
Behavior:
380+
- First try to fetch the latest version.
381+
- If the pipeline doesn't exist (404), create it instead.
382+
- If the latest version is a draft (is_draft == True), PATCH that version.
383+
- Otherwise, create a new version via POST /pipelines/{name}/versions.
380384
381385
:param name: Name of the pipeline.
382386
:param pipeline_yaml: Generated pipeline YAML string.
383387
"""
384-
# First try to create a new version of the existing pipeline
385-
version_response = await self._api.post(
388+
# First get the (last) version id if available
389+
version_response = await self._api.get(
386390
workspace_name=self._workspace_name,
387391
endpoint=f"pipelines/{name}/versions",
388-
json={"config_yaml": pipeline_yaml},
389392
)
390393

391-
if version_response.status_code == HTTPStatus.CREATED:
392-
logger.debug("Created new version for pipeline %s.", name)
393-
return version_response
394-
# If creating a version fails, assume the pipeline doesn't exist and create it
395-
logger.debug(
396-
"Failed to create new version for pipeline %s (status %s). "
397-
"Assuming pipeline does not exist and creating it instead.",
398-
name,
399-
version_response.status_code,
400-
)
394+
# If pipeline doesn't exist (404), create it instead
395+
if version_response.status_code == HTTPStatus.NOT_FOUND:
396+
logger.debug("Pipeline %s not found, creating new pipeline.", name)
397+
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
398+
return response
399+
400+
version_body = version_response.json()
401+
latest_version = version_body["data"][0]
402+
version_id = latest_version["version_id"]
403+
is_draft = latest_version.get("is_draft", False)
404+
405+
if is_draft:
406+
# If the latest version is a draft, patch that version
407+
logger.debug(
408+
"Latest version %s of pipeline %s is a draft, patching existing version.",
409+
version_id,
410+
name,
411+
)
412+
response = await self._api.patch(
413+
workspace_name=self._workspace_name,
414+
endpoint=f"pipelines/{name}/versions/{version_id}",
415+
json={"config_yaml": pipeline_yaml},
416+
)
417+
else:
418+
# Otherwise, create a new version
419+
logger.debug(
420+
"Latest version %s of pipeline %s is not a draft, creating new version.",
421+
version_id,
422+
name,
423+
)
424+
response = await self._api.post(
425+
workspace_name=self._workspace_name,
426+
endpoint=f"pipelines/{name}/versions",
427+
json={"config_yaml": pipeline_yaml, "is_draft": True},
428+
)
401429

402-
response = await self._create_pipeline(name=name, pipeline_yaml=pipeline_yaml)
403430
return response
404431

405432
async def _create_pipeline(self, name: str, pipeline_yaml: str) -> Response:

tests/integration/workflows/test_integration_pipeline_client.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -484,17 +484,23 @@ async def test_import_pipeline_with_overwrite_fallback_to_create_async(
484484
) -> None:
485485
"""Test overwriting a pipeline that doesn't exist, falling back to creation."""
486486
# Mock validation success
487-
validation_route = respx.post("https://test-api-url.com/workspaces/test-workspace/pipeline_validations").mock(
487+
validation_route = respx.post(
488+
"https://test-api-url.com/workspaces/test-workspace/pipeline_validations"
489+
).mock(
488490
return_value=Response(status_code=HTTPStatus.NO_CONTENT)
489491
)
490492

491-
# Mock failed version creation (non-201 for POST /pipelines/{name}/versions)
492-
version_create_route = respx.post(
493+
# Mock 404 response for GET (resource not found)
494+
version_check_route = respx.get(
493495
"https://test-api-url.com/workspaces/test-workspace/pipelines/test-pipeline-fallback/versions"
494-
).mock(return_value=Response(status_code=HTTPStatus.BAD_REQUEST))
496+
).mock(
497+
return_value=Response(status_code=HTTPStatus.NOT_FOUND)
498+
)
495499

496500
# Mock successful creation
497-
create_route = respx.post("https://test-api-url.com/workspaces/test-workspace/pipelines").mock(
501+
create_route = respx.post(
502+
"https://test-api-url.com/workspaces/test-workspace/pipelines"
503+
).mock(
498504
return_value=Response(status_code=HTTPStatus.OK, json={"id": "test-pipeline-id"})
499505
)
500506

@@ -508,24 +514,20 @@ async def test_import_pipeline_with_overwrite_fallback_to_create_async(
508514

509515
await test_async_client.import_into_deepset(sample_pipeline, pipeline_config)
510516

511-
# Verify all three endpoints were called
517+
# Verify all three endpoints were called in sequence
512518
assert validation_route.called
513-
assert version_create_route.called
519+
assert version_check_route.called
514520
assert create_route.called
515521

516522
# Check validation request
517523
validation_request = validation_route.calls[0].request
518524
assert validation_request.headers["Authorization"] == "Bearer test-api-key"
519525
validation_body = json.loads(validation_request.content)
520526
assert "query_yaml" in validation_body
521-
# When overwrite=True, name should be excluded from validation payload (if your code does that)
522-
# assert "name" not in validation_body
523-
524-
# Check attempted version creation request
525-
version_create_request = version_create_route.calls[0].request
526-
assert version_create_request.headers["Authorization"] == "Bearer test-api-key"
527-
version_body = json.loads(version_create_request.content)
528-
assert "config_yaml" in version_body
527+
528+
# Check GET attempt
529+
version_check_request = version_check_route.calls[0].request
530+
assert version_check_request.headers["Authorization"] == "Bearer test-api-key"
529531

530532
# Check fallback creation
531533
create_request = create_route.calls[0].request

tests/unit/service/test_pipeline_service.py

Lines changed: 88 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ async def test_import_index_with_overwrite_fallback_to_create(
587587
async def test_import_pipeline_with_overwrite_true(
588588
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
589589
) -> None:
590-
"""Test importing a pipeline with overwrite=True creates a new version via POST endpoint."""
590+
"""Test importing a pipeline with overwrite=True patches latest draft version."""
591591
config = PipelineConfig(
592592
name="test_pipeline_overwrite",
593593
inputs=PipelineInputs(query=["retriever.query"]),
@@ -600,39 +600,99 @@ async def test_import_pipeline_with_overwrite_true(
600600
validation_response = Mock(spec=Response)
601601
validation_response.status_code = HTTPStatus.NO_CONTENT.value
602602

603-
# Mock successful "create new version" response
603+
# Mock successful versions response, latest version is a draft
604+
versions_response = Mock(status_code=HTTPStatus.OK.value)
605+
versions_response.json.return_value = {
606+
"data": [{"version_id": "42abcd", "is_draft": True}],
607+
}
608+
609+
# Mock successful overwrite (PATCH) response
604610
overwrite_response = Mock(spec=Response)
605-
overwrite_response.status_code = HTTPStatus.CREATED.value
611+
overwrite_response.status_code = HTTPStatus.OK.value
612+
613+
mock_api.post.return_value = validation_response
614+
mock_api.get.return_value = versions_response
615+
mock_api.patch.return_value = overwrite_response
616+
617+
await pipeline_service.import_async(index_pipeline, config)
618+
619+
# validation + GET versions + PATCH draft version
620+
assert mock_api.post.call_count == 1
621+
assert mock_api.get.call_count == 1
622+
assert mock_api.patch.call_count == 1
623+
624+
# Check validation call
625+
validation_call = mock_api.post.call_args_list[0]
626+
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
627+
assert "query_yaml" in validation_call.kwargs["json"]
628+
# When overwrite=True, name should be excluded from validation payload
629+
assert "name" not in validation_call.kwargs["json"]
630+
631+
# Check PATCH call
632+
overwrite_call = mock_api.patch.call_args_list[0]
633+
assert (
634+
overwrite_call.kwargs["endpoint"]
635+
== "pipelines/test_pipeline_overwrite/versions/42abcd"
636+
)
637+
assert "config_yaml" in overwrite_call.kwargs["json"]
638+
639+
@pytest.mark.asyncio
640+
async def test_import_pipeline_with_overwrite_true_creates_new_version_when_not_draft(
641+
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
642+
) -> None:
643+
"""Test importing a pipeline with overwrite=True creates a new version when latest version is not draft."""
644+
config = PipelineConfig(
645+
name="test_pipeline_overwrite",
646+
inputs=PipelineInputs(query=["retriever.query"]),
647+
outputs=PipelineOutputs(documents="meta_ranker.documents"),
648+
strict_validation=False,
649+
overwrite=True,
650+
)
651+
652+
# Mock successful validation response
653+
validation_response = Mock(spec=Response)
654+
validation_response.status_code = HTTPStatus.NO_CONTENT.value
655+
656+
# Mock versions response, latest version is NOT a draft
657+
versions_response = Mock(status_code=HTTPStatus.OK.value)
658+
versions_response.json.return_value = {
659+
"data": [{"version_id": "42abcd", "is_draft": False}],
660+
}
661+
662+
# Mock successful "create new version" response
663+
new_version_response = Mock(spec=Response)
664+
new_version_response.status_code = HTTPStatus.CREATED.value
606665

607666
# First POST is validation, second POST is "create new version"
608-
mock_api.post.side_effect = [validation_response, overwrite_response]
667+
mock_api.post.side_effect = [validation_response, new_version_response]
668+
mock_api.get.return_value = versions_response
609669

610670
await pipeline_service.import_async(index_pipeline, config)
611671

612-
# Should call validation endpoint first, then create-version endpoint
672+
# validation + GET versions + POST versions (new version)
613673
assert mock_api.post.call_count == 2
614-
# No GET/PATCH/PUT calls in the overwrite path anymore
615-
assert mock_api.get.call_count == 0
674+
assert mock_api.get.call_count == 1
616675
assert mock_api.patch.call_count == 0
617-
assert mock_api.put.call_count == 0
618676

619677
# Check validation call
620678
validation_call = mock_api.post.call_args_list[0]
621679
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
622680
assert "query_yaml" in validation_call.kwargs["json"]
623-
# When overwrite=True, name should be excluded from validation payload
624681
assert "name" not in validation_call.kwargs["json"]
625682

626-
# Check create-version call
627-
overwrite_call = mock_api.post.call_args_list[1]
628-
assert overwrite_call.kwargs["endpoint"] == "pipelines/test_pipeline_overwrite/versions"
629-
assert "config_yaml" in overwrite_call.kwargs["json"]
683+
# Check create-version POST call
684+
create_version_call = mock_api.post.call_args_list[1]
685+
assert (
686+
create_version_call.kwargs["endpoint"]
687+
== "pipelines/test_pipeline_overwrite/versions"
688+
)
689+
assert "config_yaml" in create_version_call.kwargs["json"]
630690

631691
@pytest.mark.asyncio
632692
async def test_import_pipeline_with_overwrite_fallback_to_create(
633693
self, pipeline_service: PipelineService, index_pipeline: Pipeline, mock_api: AsyncMock
634694
) -> None:
635-
"""Test importing a pipeline with overwrite=True that falls back to create when version creation fails."""
695+
"""Test importing a pipeline with overwrite=True that falls back to create when resource doesn't exist."""
636696

637697
config = PipelineConfig(
638698
name="test_pipeline_fallback",
@@ -646,41 +706,35 @@ async def test_import_pipeline_with_overwrite_fallback_to_create(
646706
validation_response = Mock(spec=Response)
647707
validation_response.status_code = HTTPStatus.NO_CONTENT.value
648708

649-
# Mock non-201 response for POST /pipelines/{name}/versions (version creation fails)
650-
version_fail_response = Mock(spec=Response)
651-
version_fail_response.status_code = HTTPStatus.BAD_REQUEST.value
709+
# Mock 404 response for GET (resource not found)
710+
not_found_response = Mock(spec=Response)
711+
not_found_response.status_code = HTTPStatus.NOT_FOUND.value
652712

653-
# Mock successful creation response for POST /pipelines
713+
# Mock successful creation response
654714
create_response = Mock(spec=Response)
655715
create_response.status_code = HTTPStatus.CREATED.value
656716

657-
# POST calls: validation, create-version (fails), create-pipeline (fallback)
658-
mock_api.post.side_effect = [validation_response, version_fail_response, create_response]
717+
mock_api.post.side_effect = [validation_response, create_response]
718+
mock_api.get.return_value = not_found_response
659719

660720
await pipeline_service.import_async(index_pipeline, config)
661721

662-
# Should call validation endpoint, then POST to create new version (fails),
663-
# then POST to create the pipeline
664-
assert mock_api.post.call_count == 3
665-
# No GET anymore; overwrite logic doesn't fetch versions
666-
assert mock_api.get.call_count == 0
667-
assert mock_api.patch.call_count == 0
668-
assert mock_api.put.call_count == 0
722+
# validation + GET (404) + POST create
723+
assert mock_api.post.call_count == 2
724+
assert mock_api.get.call_count == 1
669725

670726
# Check validation call
671727
validation_call = mock_api.post.call_args_list[0]
672728
assert validation_call.kwargs["endpoint"] == "pipeline_validations"
673729
assert "query_yaml" in validation_call.kwargs["json"]
674-
# When overwrite=True, name should be excluded from validation payload
675730
assert "name" not in validation_call.kwargs["json"]
676731

677-
# Check attempted version creation call
678-
version_call = mock_api.post.call_args_list[1]
679-
assert version_call.kwargs["endpoint"] == "pipelines/test_pipeline_fallback/versions"
680-
assert "config_yaml" in version_call.kwargs["json"]
732+
# Check GET versions attempt
733+
get_call = mock_api.get.call_args_list[0]
734+
assert get_call.kwargs["endpoint"] == "pipelines/test_pipeline_fallback/versions"
681735

682-
# Check fallback create-pipeline call
683-
create_call = mock_api.post.call_args_list[2]
736+
# Check fallback POST call
737+
create_call = mock_api.post.call_args_list[1]
684738
assert create_call.kwargs["endpoint"] == "pipelines"
685739
assert create_call.kwargs["json"]["name"] == "test_pipeline_fallback"
686740
assert "query_yaml" in create_call.kwargs["json"]

0 commit comments

Comments
 (0)