Skip to content

Commit 84aefb8

Browse files
committed
Use ParamSpec for more clarity in the typing of BaseStep
1 parent d4e3cdd commit 84aefb8

File tree

10 files changed

+78
-44
lines changed

10 files changed

+78
-44
lines changed

src/zenml/integrations/huggingface/steps/accelerate_runner.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,17 @@
1717
"""Step function to run any ZenML step using Accelerate."""
1818

1919
import functools
20-
from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast
20+
from typing import (
21+
Any,
22+
Callable,
23+
Dict,
24+
Optional,
25+
ParamSpec,
26+
TypeVar,
27+
Union,
28+
cast,
29+
overload,
30+
)
2131

2232
import cloudpickle as pickle
2333
from accelerate.commands.launch import (
@@ -32,13 +42,27 @@
3242

3343
logger = get_logger(__name__)
3444
F = TypeVar("F", bound=Callable[..., Any])
35-
T = TypeVar("T", bound=BaseStep[Any])
45+
P = ParamSpec("P")
46+
R = TypeVar("R")
47+
48+
49+
@overload
50+
def run_with_accelerate(
51+
**accelerate_launch_kwargs: Any,
52+
) -> Callable[[BaseStep[P, R]], BaseStep[P, R]]: ...
53+
54+
55+
@overload
56+
def run_with_accelerate(
57+
step_function_top_level: BaseStep[P, R],
58+
/,
59+
) -> BaseStep[P, R]: ...
3660

3761

3862
def run_with_accelerate(
39-
step_function_top_level: Optional[T] = None,
63+
step_function_top_level: Optional[BaseStep[P, R]] = None,
4064
**accelerate_launch_kwargs: Any,
41-
) -> Union[Callable[[T], T], T]:
65+
) -> Union[Callable[[BaseStep[P, R]], BaseStep[P, R]], BaseStep[P, R]]:
4266
"""Run a function with accelerate.
4367
4468
Accelerate package: https://huggingface.co/docs/accelerate/en/index
@@ -71,9 +95,10 @@ def training_pipeline(some_param: int, ...):
7195
The accelerate-enabled version of the step.
7296
"""
7397

74-
def _decorator(step_function: T) -> T:
98+
def _decorator(step_function: BaseStep[P, R]) -> BaseStep[P, R]:
7599
def _wrapper(
76-
entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
100+
entrypoint: F,
101+
accelerate_launch_kwargs: Dict[str, Any],
77102
) -> F:
78103
@functools.wraps(entrypoint)
79104
def inner(*args: Any, **kwargs: Any) -> Any:

src/zenml/integrations/whylogs/steps/whylogs_profiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_whylogs_profiler_step(
5858
dataset_timestamp: Optional[datetime.datetime] = None,
5959
dataset_id: Optional[str] = None,
6060
enable_whylabs: bool = True,
61-
) -> BaseStep[Any]:
61+
) -> BaseStep[..., Any]:
6262
"""Shortcut function to create a new instance of the WhylogsProfilerStep step.
6363
6464
The returned WhylogsProfilerStep can be used in a pipeline to generate a

src/zenml/orchestrators/step_run_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _get_docstring_and_source_code_from_step_instance(
185185
"""
186186
from zenml.steps.base_step import BaseStep
187187

188-
step_instance = BaseStep[Any].load_from_source(step.spec.source)
188+
step_instance = BaseStep[..., Any].load_from_source(step.spec.source)
189189

190190
docstring = step_instance.docstring
191191
if docstring and len(docstring) > TEXT_FIELD_MAX_LENGTH:

src/zenml/orchestrators/step_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,17 @@ def _evaluate_artifact_names_in_collections(
302302
for d in collections:
303303
d[name] = d.pop(k)
304304

305-
def _load_step(self) -> "BaseStep[Any]":
305+
def _load_step(self) -> "BaseStep[..., Any]":
306306
"""Load the step instance.
307307
308308
Returns:
309309
The step instance.
310310
"""
311311
from zenml.steps import BaseStep
312312

313-
step_instance = BaseStep[Any].load_from_source(self._step.spec.source)
313+
step_instance = BaseStep[..., Any].load_from_source(
314+
self._step.spec.source
315+
)
314316
step_instance = copy.deepcopy(step_instance)
315317
step_instance._configuration = self._step.config
316318
return step_instance

src/zenml/pipelines/pipeline_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,7 @@ def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str:
11381138

11391139
def add_step_invocation(
11401140
self,
1141-
step: "BaseStep[Any]",
1141+
step: "BaseStep[..., Any]",
11421142
input_artifacts: Dict[str, StepArtifact],
11431143
external_artifacts: Dict[
11441144
str, Union["ExternalArtifact", "ArtifactVersionResponse"]
@@ -1208,7 +1208,7 @@ def add_step_invocation(
12081208

12091209
def _compute_invocation_id(
12101210
self,
1211-
step: "BaseStep[Any]",
1211+
step: "BaseStep[..., Any]",
12121212
custom_id: Optional[str] = None,
12131213
allow_suffix: bool = True,
12141214
) -> str:

src/zenml/steps/base_step.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
import copy
1717
import hashlib
1818
import inspect
19+
from abc import abstractmethod
1920
from collections import defaultdict
2021
from typing import (
2122
TYPE_CHECKING,
2223
Any,
23-
Callable,
2424
Dict,
2525
Generic,
2626
List,
2727
Mapping,
2828
Optional,
29-
Protocol,
29+
ParamSpec,
3030
Sequence,
3131
Tuple,
3232
Type,
@@ -93,24 +93,11 @@
9393

9494
logger = get_logger(__name__)
9595

96-
T = TypeVar("T", bound="BaseStep[Any]")
97-
F = TypeVar("F", bound=Callable[..., Any])
96+
P = ParamSpec("P")
97+
R = TypeVar("R")
9898

9999

100-
class _AbstractEntrypoint(Protocol[F]):
101-
entrypoint: F
102-
"""Abstract method for core step logic.
103-
104-
Args:
105-
*args: Positional arguments passed to the step.
106-
**kwargs: Keyword arguments passed to the step.
107-
108-
Returns:
109-
The output of the step.
110-
"""
111-
112-
113-
class BaseStep(Generic[F], _AbstractEntrypoint[F]):
100+
class BaseStep(Generic[P, R]):
114101
"""Abstract base class for all ZenML steps."""
115102

116103
def __init__(
@@ -227,8 +214,20 @@ def __init__(
227214

228215
notebook_utils.try_to_save_notebook_cell_code(self.source_object)
229216

217+
@abstractmethod
218+
def entrypoint(self, *args: P.args, **kwargs: P.kwargs) -> R:
219+
"""Abstract method for core step logic.
220+
221+
Args:
222+
*args: Positional arguments passed to the step.
223+
**kwargs: Keyword arguments passed to the step.
224+
225+
Returns:
226+
The output of the step.
227+
"""
228+
230229
@classmethod
231-
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep[F]":
230+
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep[P, R]":
232231
"""Loads a step from source.
233232
234233
Args:
@@ -585,7 +584,7 @@ def configuration(self) -> "PartialStepConfiguration":
585584
return self._configuration
586585

587586
def configure(
588-
self: T,
587+
self: "BaseStep[P,R]",
589588
enable_cache: Optional[bool] = None,
590589
enable_artifact_metadata: Optional[bool] = None,
591590
enable_artifact_visualization: Optional[bool] = None,
@@ -604,7 +603,7 @@ def configure(
604603
merge: bool = True,
605604
retry: Optional[StepRetryConfig] = None,
606605
substitutions: Optional[Dict[str, str]] = None,
607-
) -> T:
606+
) -> "BaseStep[P,R]":
608607
"""Configures the step.
609608
610609
Configuration merging example:
@@ -737,7 +736,7 @@ def with_options(
737736
model: Optional["Model"] = None,
738737
merge: bool = True,
739738
substitutions: Optional[Dict[str, str]] = None,
740-
) -> "BaseStep[F]":
739+
) -> "BaseStep[P, R]":
741740
"""Copies the step and applies the given configurations.
742741
743742
Args:
@@ -793,7 +792,7 @@ def with_options(
793792
)
794793
return step_copy
795794

796-
def copy(self) -> "BaseStep[F]":
795+
def copy(self) -> "BaseStep[P, R]":
797796
"""Copies the step.
798797
799798
Returns:

src/zenml/steps/decorated_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from zenml.steps import BaseStep
2020

2121

22-
class _DecoratedStep(BaseStep[Any]):
22+
class _DecoratedStep(BaseStep[..., Any]):
2323
"""Internal BaseStep subclass used by the step decorator."""
2424

2525
@property

src/zenml/steps/step_decorator.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Dict,
2121
Mapping,
2222
Optional,
23+
ParamSpec,
2324
Sequence,
2425
Type,
2526
TypeVar,
@@ -47,13 +48,15 @@
4748
Mapping[str, Sequence[MaterializerClassOrSource]],
4849
]
4950
F = TypeVar("F", bound=Callable[..., Any])
51+
P = ParamSpec("P")
52+
R = TypeVar("R")
5053

5154

5255
logger = get_logger(__name__)
5356

5457

5558
@overload
56-
def step(_func: "F") -> "BaseStep[F]": ...
59+
def step(_func: Callable[P, R]) -> "BaseStep[P,R]": ...
5760

5861

5962
@overload
@@ -74,11 +77,11 @@ def step(
7477
model: Optional["Model"] = None,
7578
retry: Optional["StepRetryConfig"] = None,
7679
substitutions: Optional[Dict[str, str]] = None,
77-
) -> Callable[["F"], "BaseStep[F]"]: ...
80+
) -> Callable[[Callable[P, R]], "BaseStep[P,R]"]: ...
7881

7982

8083
def step(
81-
_func: Optional["F"] = None,
84+
_func: Optional[Callable[P, R]] = None,
8285
*,
8386
name: Optional[str] = None,
8487
enable_cache: Optional[bool] = None,
@@ -95,7 +98,7 @@ def step(
9598
model: Optional["Model"] = None,
9699
retry: Optional["StepRetryConfig"] = None,
97100
substitutions: Optional[Dict[str, str]] = None,
98-
) -> Union["BaseStep[F]", Callable[["F"], "BaseStep[F]"]]:
101+
) -> Union["BaseStep[P,R]", Callable[[Callable[P, R]], "BaseStep[P,R]"]]:
99102
"""Decorator to create a ZenML step.
100103
101104
Args:
@@ -132,10 +135,10 @@ def step(
132135
The step instance.
133136
"""
134137

135-
def inner_decorator(func: "F") -> "BaseStep[F]":
138+
def inner_decorator(func: Callable[P, R]) -> "BaseStep[P,R]":
136139
from zenml.steps.decorated_step import _DecoratedStep
137140

138-
class_: Type["BaseStep[F]"] = type(
141+
class_: Type["BaseStep[P,R]"] = type(
139142
func.__name__,
140143
(_DecoratedStep,),
141144
{

src/zenml/steps/step_invocation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class StepInvocation:
3333
def __init__(
3434
self,
3535
id: str,
36-
step: "BaseStep[Any]",
36+
step: "BaseStep[..., Any]",
3737
input_artifacts: Dict[str, "StepArtifact"],
3838
external_artifacts: Dict[
3939
str, Union["ExternalArtifact", "ArtifactVersionResponse"]

src/zenml/steps/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
Callable,
2525
Dict,
2626
Optional,
27+
ParamSpec,
2728
Tuple,
29+
TypeVar,
2830
Union,
2931
)
3032
from uuid import UUID
@@ -49,6 +51,9 @@
4951
if TYPE_CHECKING:
5052
from zenml.steps import BaseStep
5153

54+
P = ParamSpec("P")
55+
R = TypeVar("R")
56+
5257

5358
logger = get_logger(__name__)
5459

@@ -499,7 +504,7 @@ def log_step_metadata(
499504

500505

501506
def run_as_single_step_pipeline(
502-
__step: "BaseStep[Any]", *args: Any, **kwargs: Any
507+
__step: "BaseStep[P,R]", *args: Any, **kwargs: Any
503508
) -> Any:
504509
"""Runs the step as a single step pipeline.
505510

0 commit comments

Comments
 (0)