Skip to content

Commit ca28586

Browse files
authored
Fix filtering by run metadata (#3344)
* Fix filtering artifact versions by metadata * More fixes * Json encode after extracting filter operator * More fixes and tests * Linting * Linting
1 parent 2f521bf commit ca28586

File tree

6 files changed

+109
-20
lines changed

6 files changed

+109
-20
lines changed

src/zenml/client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3821,7 +3821,7 @@ def list_pipeline_runs(
38213821
templatable: Optional[bool] = None,
38223822
tag: Optional[str] = None,
38233823
user: Optional[Union[UUID, str]] = None,
3824-
run_metadata: Optional[Dict[str, str]] = None,
3824+
run_metadata: Optional[Dict[str, Any]] = None,
38253825
pipeline: Optional[Union[UUID, str]] = None,
38263826
code_repository: Optional[Union[UUID, str]] = None,
38273827
model: Optional[Union[UUID, str]] = None,
@@ -3974,6 +3974,7 @@ def list_run_steps(
39743974
user: Optional[Union[UUID, str]] = None,
39753975
model_version_id: Optional[Union[str, UUID]] = None,
39763976
model: Optional[Union[UUID, str]] = None,
3977+
run_metadata: Optional[Dict[str, Any]] = None,
39773978
hydrate: bool = False,
39783979
) -> Page[StepRunResponse]:
39793980
"""List all pipelines.
@@ -4000,6 +4001,7 @@ def list_run_steps(
40004001
cache_key: The cache key of the step run to filter by.
40014002
code_hash: The code hash of the step run to filter by.
40024003
status: The name of the run to filter by.
4004+
run_metadata: Filter by run metadata.
40034005
hydrate: Flag deciding whether to hydrate the output model(s)
40044006
by including metadata fields in the response.
40054007
@@ -4028,6 +4030,7 @@ def list_run_steps(
40284030
user=user,
40294031
model_version_id=model_version_id,
40304032
model=model,
4033+
run_metadata=run_metadata,
40314034
)
40324035
step_run_filter_model.set_scope_workspace(self.active_workspace.id)
40334036
return self.zen_store.list_run_steps(
@@ -4254,7 +4257,7 @@ def list_artifact_versions(
42544257
user: Optional[Union[UUID, str]] = None,
42554258
model: Optional[Union[UUID, str]] = None,
42564259
pipeline_run: Optional[Union[UUID, str]] = None,
4257-
run_metadata: Optional[Dict[str, str]] = None,
4260+
run_metadata: Optional[Dict[str, Any]] = None,
42584261
tag: Optional[str] = None,
42594262
hydrate: bool = False,
42604263
) -> Page[ArtifactVersionResponse]:
@@ -4320,6 +4323,7 @@ def list_artifact_versions(
43204323
user=user,
43214324
model=model,
43224325
pipeline_run=pipeline_run,
4326+
run_metadata=run_metadata,
43234327
)
43244328
artifact_version_filter_model.set_scope_workspace(
43254329
self.active_workspace.id

src/zenml/models/v2/base/filter.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
171171
class StrFilter(Filter):
172172
"""Filter for all string fields."""
173173

174+
json_encode_value: bool = False
175+
174176
ALLOWED_OPS: ClassVar[List[str]] = [
175177
GenericFilterOps.EQUALS,
176178
GenericFilterOps.NOT_EQUALS,
@@ -211,16 +213,6 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
211213
Raises:
212214
ValueError: the comparison of the column to a numeric value fails.
213215
"""
214-
if self.operation == GenericFilterOps.CONTAINS:
215-
return column.like(f"%{self.value}%")
216-
if self.operation == GenericFilterOps.STARTSWITH:
217-
return column.startswith(f"{self.value}")
218-
if self.operation == GenericFilterOps.ENDSWITH:
219-
return column.endswith(f"{self.value}")
220-
if self.operation == GenericFilterOps.NOT_EQUALS:
221-
return column != self.value
222-
if self.operation == GenericFilterOps.ONEOF:
223-
return column.in_(self.value)
224216
if self.operation in {
225217
GenericFilterOps.GT,
226218
GenericFilterOps.LT,
@@ -254,7 +246,33 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
254246
f"value '{self.value}' (must be numeric): {e}"
255247
)
256248

257-
return column == self.value
249+
if self.operation == GenericFilterOps.ONEOF:
250+
assert isinstance(self.value, list)
251+
# Convert the list of values to a list of json strings
252+
json_list = (
253+
[json.dumps(v) for v in self.value]
254+
if self.json_encode_value
255+
else self.value
256+
)
257+
return column.in_(json_list)
258+
259+
# Don't convert the value to a json string if the operation is contains
260+
# because the quotes around strings will mess with the comparison
261+
if self.operation == GenericFilterOps.CONTAINS:
262+
return column.like(f"%{self.value}%")
263+
264+
json_value = (
265+
json.dumps(self.value) if self.json_encode_value else self.value
266+
)
267+
268+
if self.operation == GenericFilterOps.STARTSWITH:
269+
return column.startswith(f"{json_value}")
270+
if self.operation == GenericFilterOps.ENDSWITH:
271+
return column.endswith(f"{json_value}")
272+
if self.operation == GenericFilterOps.NOT_EQUALS:
273+
return column != json_value
274+
275+
return column == json_value
258276

259277

260278
class UUIDFilter(StrFilter):
@@ -733,13 +751,15 @@ def generate_custom_query_conditions_for_column(
733751
value: Any,
734752
table: Type[SQLModel],
735753
column: str,
754+
json_encode_value: bool = False,
736755
) -> "ColumnElement[bool]":
737756
"""Generate custom filter conditions for a column of a table.
738757
739758
Args:
740759
value: The filter value.
741760
table: The table which contains the column.
742761
column: The column name.
762+
json_encode_value: Whether to json encode the value.
743763
744764
Returns:
745765
The query conditions.
@@ -748,6 +768,9 @@ def generate_custom_query_conditions_for_column(
748768
filter_ = FilterGenerator(table).define_filter(
749769
column=column, value=value, operator=operator
750770
)
771+
if isinstance(filter_, StrFilter):
772+
filter_.json_encode_value = json_encode_value
773+
751774
return filter_.generate_query_conditions(table=table)
752775

753776
@property

src/zenml/models/v2/core/artifact_version.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ class ArtifactVersionFilter(WorkspaceScopedFilter, TaggableFilter):
559559
description="Name/ID of a pipeline run that is associated with this "
560560
"artifact version.",
561561
)
562-
run_metadata: Optional[Dict[str, str]] = Field(
562+
run_metadata: Optional[Dict[str, Any]] = Field(
563563
default=None,
564564
description="The run_metadata to filter the artifact versions by.",
565565
)
@@ -683,13 +683,19 @@ def get_custom_filters(
683683
RunMetadataResourceSchema.resource_id
684684
== ArtifactVersionSchema.id,
685685
RunMetadataResourceSchema.resource_type
686-
== MetadataResourceTypes.ARTIFACT_VERSION,
686+
== MetadataResourceTypes.ARTIFACT_VERSION.value,
687687
RunMetadataResourceSchema.run_metadata_id
688688
== RunMetadataSchema.id,
689+
self.generate_custom_query_conditions_for_column(
690+
value=key,
691+
table=RunMetadataSchema,
692+
column="key",
693+
),
689694
self.generate_custom_query_conditions_for_column(
690695
value=value,
691696
table=RunMetadataSchema,
692697
column="value",
698+
json_encode_value=True,
693699
),
694700
)
695701
custom_filters.append(additional_filter)

src/zenml/models/v2/core/pipeline_run.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ class PipelineRunFilter(WorkspaceScopedFilter, TaggableFilter):
688688
union_mode="left_to_right",
689689
)
690690
unlisted: Optional[bool] = None
691-
run_metadata: Optional[Dict[str, str]] = Field(
691+
run_metadata: Optional[Dict[str, Any]] = Field(
692692
default=None,
693693
description="The run_metadata to filter the pipeline runs by.",
694694
)
@@ -915,13 +915,19 @@ def get_custom_filters(
915915
RunMetadataResourceSchema.resource_id
916916
== PipelineRunSchema.id,
917917
RunMetadataResourceSchema.resource_type
918-
== MetadataResourceTypes.PIPELINE_RUN,
918+
== MetadataResourceTypes.PIPELINE_RUN.value,
919919
RunMetadataResourceSchema.run_metadata_id
920920
== RunMetadataSchema.id,
921+
self.generate_custom_query_conditions_for_column(
922+
value=key,
923+
table=RunMetadataSchema,
924+
column="key",
925+
),
921926
self.generate_custom_query_conditions_for_column(
922927
value=value,
923928
table=RunMetadataSchema,
924929
column="value",
930+
json_encode_value=True,
925931
),
926932
)
927933
custom_filters.append(additional_filter)

src/zenml/models/v2/core/step_run.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from datetime import datetime
1717
from typing import (
1818
TYPE_CHECKING,
19+
Any,
1920
ClassVar,
2021
Dict,
2122
List,
@@ -574,7 +575,7 @@ class StepRunFilter(WorkspaceScopedFilter):
574575
default=None,
575576
description="Name/ID of the model associated with the step run.",
576577
)
577-
run_metadata: Optional[Dict[str, str]] = Field(
578+
run_metadata: Optional[Dict[str, Any]] = Field(
578579
default=None,
579580
description="The run_metadata to filter the step runs by.",
580581
)
@@ -619,13 +620,19 @@ def get_custom_filters(
619620
additional_filter = and_(
620621
RunMetadataResourceSchema.resource_id == StepRunSchema.id,
621622
RunMetadataResourceSchema.resource_type
622-
== MetadataResourceTypes.STEP_RUN,
623+
== MetadataResourceTypes.STEP_RUN.value,
623624
RunMetadataResourceSchema.run_metadata_id
624625
== RunMetadataSchema.id,
626+
self.generate_custom_query_conditions_for_column(
627+
value=key,
628+
table=RunMetadataSchema,
629+
column="key",
630+
),
625631
self.generate_custom_query_conditions_for_column(
626632
value=value,
627633
table=RunMetadataSchema,
628634
column="value",
635+
json_encode_value=True,
629636
),
630637
)
631638
custom_filters.append(additional_filter)

tests/integration/functional/zen_stores/test_zen_store.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2924,7 +2924,7 @@ def pipeline_to_log_metadata(metadata):
29242924
step_to_log_metadata(metadata)
29252925

29262926

2927-
def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client):
2927+
def test_pipeline_run_filters_with_oneof(clean_client):
29282928
store = clean_client.zen_store
29292929

29302930
metadata_values = [3, 25, 100, "random_string", True]
@@ -2961,6 +2961,49 @@ def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client):
29612961
PipelineRunFilter(name="oneof:random_value")
29622962

29632963

2964+
def test_run_metadata_filtering(clean_client):
2965+
store = clean_client.zen_store
2966+
2967+
metadata_values = [3, 25, 100, "random_string", True]
2968+
2969+
for v in metadata_values:
2970+
pipeline_to_log_metadata(v)
2971+
2972+
# Test run metadata filtering with string value
2973+
runs_filter = StepRunFilter(run_metadata={"blupus": "random_string"})
2974+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
2975+
assert len(runs.items) == 1
2976+
2977+
# Test run metadata filtering with boolean value
2978+
runs_filter = StepRunFilter(run_metadata={"blupus": True})
2979+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
2980+
assert len(runs.items) == 1
2981+
2982+
# Test run metadata filtering with int value
2983+
runs_filter = StepRunFilter(run_metadata={"blupus": 3})
2984+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
2985+
assert len(runs.items) == 1
2986+
2987+
# Test run metadata filtering for a non-existent key
2988+
runs_filter = StepRunFilter(run_metadata={"non-existent": 3})
2989+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
2990+
assert len(runs) == 0
2991+
2992+
# Test run metadata filtering with non-existent value
2993+
runs_filter = StepRunFilter(run_metadata={"blupus": "non-existent"})
2994+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
2995+
assert len(runs.items) == 0
2996+
2997+
# Test run metadata filtering with operator value
2998+
runs_filter = StepRunFilter(run_metadata={"blupus": "lt:30"})
2999+
runs = store.list_run_steps(step_run_filter_model=runs_filter)
3000+
assert len(runs) == 2 # The run with 3 and 25
3001+
3002+
for r in runs:
3003+
assert isinstance(r.run_metadata["blupus"], int)
3004+
assert r.run_metadata["blupus"] < 30
3005+
3006+
29643007
# .--------------------.
29653008
# | Pipeline run steps |
29663009
# '--------------------'

0 commit comments

Comments
 (0)