Skip to content

Commit 0d059d3

Browse files
committed
Pending Changes for Active Stack
1 parent e4ed83f commit 0d059d3

File tree

7 files changed

+129
-71
lines changed

7 files changed

+129
-71
lines changed

src/zenml/cli/pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
)
3636
from zenml.pipelines.pipeline_definition import Pipeline
3737
from zenml.utils import source_utils, uuid_utils
38+
from zenml.utils.stack_utils import temporary_active_stack
3839
from zenml.utils.yaml_utils import write_yaml
3940

4041
logger = get_logger(__name__)
@@ -190,7 +191,7 @@ def build_pipeline(
190191
"your source code root."
191192
)
192193

193-
with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
194+
with temporary_active_stack(stack_name_or_id=stack_name_or_id):
194195
pipeline_instance = _import_pipeline(source=source)
195196

196197
pipeline_instance = pipeline_instance.with_options(
@@ -276,7 +277,7 @@ def run_pipeline(
276277
"your source code root."
277278
)
278279

279-
with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
280+
with temporary_active_stack(stack_name_or_id=stack_name_or_id):
280281
pipeline_instance = _import_pipeline(source=source)
281282

282283
build: Union[str, PipelineBuildBase, None] = None
@@ -353,7 +354,7 @@ def create_run_template(
353354
"init` at your source code root."
354355
)
355356

356-
with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
357+
with temporary_active_stack(stack_name_or_id=stack_name_or_id):
357358
pipeline_instance = _import_pipeline(source=source)
358359

359360
pipeline_instance = pipeline_instance.with_options(

src/zenml/cli/utils.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# permissions and limitations under the License.
1414
"""Utility functions for the CLI."""
1515

16-
import contextlib
1716
import json
1817
import os
1918
import platform
@@ -27,7 +26,6 @@
2726
Any,
2827
Callable,
2928
Dict,
30-
Iterator,
3129
List,
3230
NoReturn,
3331
Optional,
@@ -83,8 +81,6 @@
8381
from zenml.utils.time_utils import expires_in
8482

8583
if TYPE_CHECKING:
86-
from uuid import UUID
87-
8884
from rich.text import Text
8985

9086
from zenml.enums import ExecutionStatus
@@ -2468,33 +2464,6 @@ def wrapper(function: F) -> F:
24682464
return inner_decorator
24692465

24702466

2471-
@contextlib.contextmanager
2472-
def temporary_active_stack(
2473-
stack_name_or_id: Union["UUID", str, None] = None,
2474-
) -> Iterator["Stack"]:
2475-
"""Contextmanager to temporarily activate a stack.
2476-
2477-
Args:
2478-
stack_name_or_id: The name or ID of the stack to activate. If not given,
2479-
this contextmanager will not do anything.
2480-
2481-
Yields:
2482-
The active stack.
2483-
"""
2484-
from zenml.client import Client
2485-
2486-
try:
2487-
if stack_name_or_id:
2488-
old_stack_id = Client().active_stack_model.id
2489-
Client().activate_stack(stack_name_or_id)
2490-
else:
2491-
old_stack_id = None
2492-
yield Client().active_stack
2493-
finally:
2494-
if old_stack_id:
2495-
Client().activate_stack(old_stack_id)
2496-
2497-
24982467
def get_package_information(
24992468
package_names: Optional[List[str]] = None,
25002469
) -> Dict[str, str]:

src/zenml/config/pipeline_run_configuration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from zenml.config.strict_base_model import StrictBaseModel
2727
from zenml.model.model import Model
2828
from zenml.models import PipelineBuildBase
29+
from zenml.stack.stack import Stack
2930
from zenml.utils import pydantic_utils
3031

3132

@@ -46,6 +47,7 @@ class PipelineRunConfiguration(
4647
steps: Dict[str, StepConfigurationUpdate] = {}
4748
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4849
tags: Optional[List[str]] = None
50+
stack: Optional[str] = None
4951
extra: Dict[str, Any] = {}
5052
model: Optional[Model] = None
5153
parameters: Optional[Dict[str, Any]] = None

src/zenml/pipelines/pipeline_definition.py

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
source_utils,
9393
yaml_utils,
9494
)
95+
from zenml.utils.stack_utils import temporary_active_stack
9596
from zenml.utils.string_utils import format_name_template
9697

9798
if TYPE_CHECKING:
@@ -565,19 +566,26 @@ def build(
565566
if settings:
566567
compile_args["settings"] = settings
567568

568-
deployment, _, _ = self._compile(**compile_args)
569-
pipeline_id = self._register().id
570-
571-
local_repo = code_repository_utils.find_active_code_repository()
572-
code_repository = build_utils.verify_local_repository_context(
573-
deployment=deployment, local_repo_context=local_repo
569+
_from_config_file = self._parse_config_file(
570+
config_path=config_path,
571+
matcher=list(PipelineRunConfiguration.model_fields.keys()),
574572
)
573+
run_config = PipelineRunConfiguration(**_from_config_file)
575574

576-
return build_utils.create_pipeline_build(
577-
deployment=deployment,
578-
pipeline_id=pipeline_id,
579-
code_repository=code_repository,
580-
)
575+
with temporary_active_stack(stack_name_or_id=run_config.stack):
576+
deployment, _, _ = self._compile(**compile_args)
577+
pipeline_id = self._register().id
578+
579+
local_repo = code_repository_utils.find_active_code_repository()
580+
code_repository = build_utils.verify_local_repository_context(
581+
deployment=deployment, local_repo_context=local_repo
582+
)
583+
584+
return build_utils.create_pipeline_build(
585+
deployment=deployment,
586+
pipeline_id=pipeline_id,
587+
code_repository=code_repository,
588+
)
581589

582590
def _create_deployment(
583591
self,
@@ -800,35 +808,42 @@ def _run(
800808
logger.info(f"Initiating a new run for the pipeline: `{self.name}`.")
801809

802810
with track_handler(AnalyticsEvent.RUN_PIPELINE) as analytics_handler:
803-
stack = Client().active_stack
804-
deployment = self._create_deployment(**self._run_args)
811+
_from_config_file = self._parse_config_file(
812+
config_path=self._run_args.get("config_path"),
813+
matcher=list(PipelineRunConfiguration.model_fields.keys()),
814+
)
815+
run_config = PipelineRunConfiguration(**_from_config_file)
805816

806-
self.log_pipeline_deployment_metadata(deployment)
807-
run = create_placeholder_run(deployment=deployment)
817+
with temporary_active_stack(stack_name_or_id=run_config.stack):
818+
stack = Client().active_stack
819+
deployment = self._create_deployment(**self._run_args)
808820

809-
analytics_handler.metadata = self._get_pipeline_analytics_metadata(
810-
deployment=deployment,
811-
stack=stack,
812-
run_id=run.id if run else None,
813-
)
821+
self.log_pipeline_deployment_metadata(deployment)
822+
run = create_placeholder_run(deployment=deployment)
814823

815-
if run:
816-
run_url = dashboard_utils.get_run_url(run)
817-
if run_url:
818-
logger.info(f"Dashboard URL for Pipeline Run: {run_url}")
819-
else:
820-
logger.info(
821-
"You can visualize your pipeline runs in the `ZenML "
822-
"Dashboard`. In order to try it locally, please run "
823-
"`zenml login --local`."
824-
)
824+
analytics_handler.metadata = self._get_pipeline_analytics_metadata(
825+
deployment=deployment,
826+
stack=stack,
827+
run_id=run.id if run else None,
828+
)
825829

826-
deploy_pipeline(
827-
deployment=deployment, stack=stack, placeholder_run=run
828-
)
829-
if run:
830-
return Client().get_pipeline_run(run.id)
831-
return None
830+
if run:
831+
run_url = dashboard_utils.get_run_url(run)
832+
if run_url:
833+
logger.info(f"Dashboard URL for Pipeline Run: {run_url}")
834+
else:
835+
logger.info(
836+
"You can visualize your pipeline runs in the `ZenML "
837+
"Dashboard`. In order to try it locally, please run "
838+
"`zenml login --local`."
839+
)
840+
841+
deploy_pipeline(
842+
deployment=deployment, stack=stack, placeholder_run=run
843+
)
844+
if run:
845+
return Client().get_pipeline_run(run.id)
846+
return None
832847

833848
@staticmethod
834849
def log_pipeline_deployment_metadata(

src/zenml/pipelines/run_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
114
"""Utility functions for running pipelines."""
215

316
import time
@@ -248,6 +261,11 @@ def validate_run_config_is_runnable_from_server(
248261
"Can't set schedule when running pipeline via Rest API."
249262
)
250263

264+
if run_configuration.stack:
265+
raise ValueError(
266+
"Can't switch stack when running pipeline via Rest API."
267+
)
268+
251269
if run_configuration.settings.get("docker"):
252270
raise ValueError(
253271
"Can't set DockerSettings when running pipeline via Rest API."

src/zenml/utils/stack_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""Utilities for stack."""
15+
16+
import contextlib
17+
from typing import (
18+
TYPE_CHECKING,
19+
Iterator,
20+
Union,
21+
)
22+
23+
if TYPE_CHECKING:
24+
from uuid import UUID
25+
26+
from zenml.stack import Stack
27+
28+
29+
@contextlib.contextmanager
30+
def temporary_active_stack(
31+
stack_name_or_id: Union["UUID", str, None] = None,
32+
) -> Iterator["Stack"]:
33+
"""Contextmanager to temporarily activate a stack.
34+
35+
Args:
36+
stack_name_or_id: The name or ID of the stack to activate. If not given,
37+
this contextmanager will not do anything.
38+
39+
Yields:
40+
The active stack.
41+
"""
42+
from zenml.client import Client
43+
44+
try:
45+
if stack_name_or_id:
46+
old_stack_id = Client().active_stack_model.id
47+
Client().activate_stack(stack_name_or_id)
48+
else:
49+
old_stack_id = None
50+
yield Client().active_stack
51+
finally:
52+
if old_stack_id:
53+
Client().activate_stack(old_stack_id)

tests/integration/functional/cli/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from zenml.cli import cli
1919
from zenml.cli.utils import (
2020
parse_name_and_extra_arguments,
21-
temporary_active_stack,
2221
)
2322
from zenml.client import Client
2423
from zenml.models import (
@@ -27,6 +26,7 @@
2726
UserResponse,
2827
WorkspaceResponse,
2928
)
29+
from zenml.utils.stack_utils import temporary_active_stack
3030
from zenml.utils.string_utils import random_str
3131

3232
SAMPLE_CUSTOM_ARGUMENTS = [

0 commit comments

Comments
 (0)