Commit 24b139f

refactoring table naming methods
1 parent 471caaa commit 24b139f

9 files changed: 88 additions & 56 deletions

src/datachain/catalog/catalog.py

Lines changed: 23 additions & 16 deletions
@@ -25,6 +25,7 @@
 from tqdm.auto import tqdm

 from datachain.cache import Cache
+from datachain.checkpoint import Checkpoint
 from datachain.client import Client
 from datachain.dataset import (
     DATASET_PREFIX,
@@ -2096,22 +2097,15 @@ def index(
     ):
         pass

-    def _cleanup_udf_tables(self, prefix: str, suffix: str) -> None:
-        """Remove all UDF tables matching a prefix and optional suffix."""
-        tables = self.warehouse.db.list_tables(prefix=prefix)
-        tables = [t for t in tables if t.endswith(suffix)]
-        if tables:
-            logger.info("Removing %d UDF tables: %s", len(tables), tables)
-            self.warehouse.cleanup_tables(tables)
-
     def cleanup_checkpoints(self, ttl_seconds: int | None = None) -> int:
         """Clean up outdated checkpoints and their associated UDF tables.

        Algorithm:
        1. Find inactive jobs (all checkpoints older than TTL) — single query
-       2. For each inactive job: remove output/partial tables and checkpoints
+       2. Collect output/partial table names from checkpoints and remove them
        3. For run groups where all member jobs are inactive: also remove
-          shared input tables
+          shared input tables (prefix-based, since input hash is not on checkpoint)
+       4. Soft-delete checkpoint metadata
        """
        if ttl_seconds is None:
            ttl_seconds = CHECKPOINTS_TTL
@@ -2127,24 +2121,37 @@ def cleanup_checkpoints(self, ttl_seconds: int | None = None) -> int:
        inactive_job_ids = [job.id for job in inactive_jobs]
        run_group_ids = {job.run_group_id for job in inactive_jobs if job.run_group_id}

+       checkpoints = list(self.metastore.list_checkpoints(job_id=inactive_job_ids))
+
        logger.info(
-           "Cleaning %d inactive jobs across %d run groups",
+           "Cleaning %d inactive jobs across %d run groups (%d checkpoints)",
            len(inactive_jobs),
            len(run_group_ids),
+           len(checkpoints),
        )

-       for job in inactive_jobs:
-           self._cleanup_udf_tables(f"udf_{job.id}_", "_output")
-           self._cleanup_udf_tables(f"udf_{job.id}_", "_output_partial")
+       # Remove output/partial tables using exact names from checkpoints
+       tables = [ch.table_name for ch in checkpoints]
+       if tables:
+           logger.info("Removing %d UDF output tables: %s", len(tables), tables)
+           self.warehouse.cleanup_tables(tables)

        # Shared input tables — only when entire run group is inactive
        for group_id in run_group_ids:
            if not self.metastore.has_active_checkpoints_in_run_group(
                group_id, ttl_threshold
            ):
-               self._cleanup_udf_tables(f"udf_{group_id}_", "_input")
+               input_tables = self.warehouse.db.list_tables(
+                   pattern=Checkpoint.input_table_pattern(group_id)
+               )
+               if input_tables:
+                   logger.info(
+                       "Removing %d shared input tables: %s",
+                       len(input_tables),
+                       input_tables,
+                   )
+                   self.warehouse.cleanup_tables(input_tables)

-       checkpoints = list(self.metastore.list_checkpoints(job_id=inactive_job_ids))
        self.metastore.remove_checkpoints([ch.id for ch in checkpoints])

        logger.info(
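
For readers skimming the diff, here is a self-contained sketch of the two-phase removal the new code performs, with the warehouse and metastore stubbed out by plain data. Only the `Checkpoint` helpers mirror this commit; the job IDs, hashes, and run-group ID are invented:

    from datachain.checkpoint import Checkpoint

    # Pretend checkpoint records for two inactive jobs: (job_id, hash, partial?)
    inactive_checkpoints = [
        ("job1", "aaa", False),
        ("job2", "bbb", True),
    ]

    # Phase 1: exact output/partial table names come straight from checkpoints.
    output_tables = [
        Checkpoint.partial_output_table_name(job, h)
        if partial
        else Checkpoint.output_table_name(job, h)
        for job, h, partial in inactive_checkpoints
    ]
    print(output_tables)
    # ['udf_job1_aaa_output', 'udf_job2_bbb_output_partial']

    # Phase 2: shared input tables are discovered by LIKE pattern instead,
    # because the input hash is not stored on the checkpoint record.
    print(Checkpoint.input_table_pattern("group9"))
    # 'udf_group9_%_input'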

src/datachain/checkpoint.py

Lines changed: 27 additions & 0 deletions
@@ -32,6 +32,33 @@ class Checkpoint:
     created_at: datetime
     status: int = CheckpointStatus.ACTIVE

+    @staticmethod
+    def output_table_name(job_id: str, _hash: str) -> str:
+        """Final UDF output table. Job-specific, created when UDF completes."""
+        return f"udf_{job_id}_{_hash}_output"
+
+    @staticmethod
+    def partial_output_table_name(job_id: str, _hash: str) -> str:
+        """Partial UDF output table. Temporary, renamed to final on completion."""
+        return f"udf_{job_id}_{_hash}_output_partial"
+
+    @staticmethod
+    def input_table_name(group_id: str, _hash: str) -> str:
+        """Shared UDF input table. Scoped to run group, reused across jobs."""
+        return f"udf_{group_id}_{_hash}_input"
+
+    @staticmethod
+    def input_table_pattern(group_id: str) -> str:
+        """LIKE pattern for finding all input tables in a run group."""
+        return f"udf_{group_id}_%_input"
+
+    @property
+    def table_name(self) -> str:
+        """UDF output table name associated with this checkpoint."""
+        if self.partial:
+            return self.partial_output_table_name(self.job_id, self.hash)
+        return self.output_table_name(self.job_id, self.hash)
+
     @classmethod
     def parse(
         cls,
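
To make the naming scheme concrete, here is what the new helpers return; the IDs and hashes are made up:

    from datachain.checkpoint import Checkpoint

    Checkpoint.output_table_name("job1", "abc")          # 'udf_job1_abc_output'
    Checkpoint.partial_output_table_name("job1", "abc")  # 'udf_job1_abc_output_partial'
    Checkpoint.input_table_name("group9", "abc")         # 'udf_group9_abc_input'
    Checkpoint.input_table_pattern("group9")             # 'udf_group9_%_input'

Centralizing these on `Checkpoint` replaces the free functions removed from `query/dataset.py` below, giving a single source of truth for the `udf_*` naming convention.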

src/datachain/data_storage/db_engine.py

Lines changed: 4 additions & 4 deletions
@@ -123,15 +123,15 @@ def has_table(self, name: str) -> bool:
         return sa.inspect(self.engine).has_table(name)

     @abstractmethod
-    def list_tables(self, prefix: str = "") -> list[str]:
+    def list_tables(self, pattern: str = "") -> list[str]:
         """
-        List all table names, optionally filtered by prefix.
+        List all table names, optionally filtered by a SQL LIKE pattern.

         Args:
-            prefix: Optional prefix to filter table names
+            pattern: SQL LIKE pattern to filter table names (e.g. 'udf_%')

         Returns:
-            List of table names matching the prefix
+            List of table names matching the pattern
         """

     @abstractmethod
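
The mechanical migration rule applied at every call site in this commit is prefix to f"{prefix}%". One caveat, which is an observation rather than something the diff addresses: '%' and '_' are LIKE wildcards, so an unescaped prefix matches loosely. A minimal sketch with a hypothetical helper:

    def prefix_to_pattern(prefix: str) -> str:
        # Hypothetical helper (not in the diff): the rewrite applied
        # throughout this commit, expressed as a function.
        return f"{prefix}%"

    assert prefix_to_pattern("udf_") == "udf_%"
    # Note: in LIKE, '_' matches any single character, so 'udf_%'
    # also matches names like 'udfXfoo'; harmless for these tables,
    # but worth knowing when choosing patterns.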

src/datachain/data_storage/sqlite.py

Lines changed: 4 additions & 4 deletions
@@ -285,8 +285,8 @@ def execute_str(self, sql: str, parameters=None) -> sqlite3.Cursor:
             return self.db.execute(sql)
         return self.db.execute(sql, parameters)

-    def list_tables(self, prefix: str = "") -> list[str]:
-        """List all table names, optionally filtered by prefix."""
+    def list_tables(self, pattern: str = "") -> list[str]:
+        """List all table names, optionally filtered by a SQL LIKE pattern."""
         sqlite_master = sqlalchemy.table(
             "sqlite_master",
             sqlalchemy.column("type"),
@@ -295,8 +295,8 @@ def list_tables(self, prefix: str = "") -> list[str]:
         query = sqlalchemy.select(sqlite_master.c.name).where(
             sqlite_master.c.type == "table"
         )
-        if prefix:
-            query = query.where(sqlite_master.c.name.like(f"{prefix}%"))
+        if pattern:
+            query = query.where(sqlite_master.c.name.like(pattern))
         result = self.execute(query)
         return [row[0] for row in result.fetchall()]
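
A standalone demonstration of what the rewritten query does, using the stdlib sqlite3 module directly; the table names are invented:

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE udf_g9_h1_input (x)")
    con.execute("CREATE TABLE udf_g9_h1_output (x)")

    # Same query shape as list_tables(), with a pattern that pins the
    # suffix as well as the prefix, which the old prefix API could not do.
    rows = con.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' AND name LIKE ?",
        ("udf_g9_%_input",),
    ).fetchall()
    print(rows)  # [('udf_g9_h1_input',)]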

src/datachain/query/dataset.py

Lines changed: 10 additions & 20 deletions
@@ -101,18 +101,6 @@
 T = TypeVar("T", bound="DatasetQuery")


-def udf_input_table_name(run_group_id: str, _hash: str) -> str:
-    return f"udf_{run_group_id}_{_hash}_input"
-
-
-def udf_output_table_name(job_id: str, _hash: str) -> str:
-    return f"udf_{job_id}_{_hash}_output"
-
-
-def udf_partial_output_table_name(job_id: str, _hash: str) -> str:
-    return f"udf_{job_id}_{_hash}_output_partial"
-
-
 def detach(
     method: "Callable[Concatenate[T, P], T]",
 ) -> "Callable[Concatenate[T, P], T]":
@@ -970,7 +958,7 @@ def get_or_create_input_table(self, query: Select, _hash: str) -> "Table":
         Returns the input table.
         """
         group_id = self.job.run_group_id or self.job.id
-        input_table_name = udf_input_table_name(group_id, _hash)
+        input_table_name = Checkpoint.input_table_name(group_id, _hash)

         # Check if input table already exists (created by ancestor job)
         if self.warehouse.db.has_table(input_table_name):
@@ -1063,10 +1051,10 @@ def _skip_udf(
             checkpoint.job_id,
         )
         existing_output_table = self.warehouse.get_table(
-            udf_output_table_name(checkpoint.job_id, checkpoint.hash)
+            Checkpoint.output_table_name(checkpoint.job_id, checkpoint.hash)
         )
         output_table = self.warehouse.create_table_from_query(
-            udf_output_table_name(self.job.id, checkpoint.hash),
+            Checkpoint.output_table_name(self.job.id, checkpoint.hash),
             sa.select(existing_output_table),
             create_fn=self.create_output_table,
         )
@@ -1140,7 +1128,7 @@ def _run_from_scratch(
         input_table = self.get_or_create_input_table(query, hash_input)

         partial_output_table = self.create_output_table(
-            udf_partial_output_table_name(self.job.id, partial_hash),
+            Checkpoint.partial_output_table_name(self.job.id, partial_hash),
         )

         if self.partition_by is not None:
@@ -1151,7 +1139,7 @@
         self.populate_udf_output_table(partial_output_table, input_query)

         output_table = self.warehouse.rename_table(
-            partial_output_table, udf_output_table_name(self.job.id, hash_output)
+            partial_output_table, Checkpoint.output_table_name(self.job.id, hash_output)
         )

         if partial_checkpoint:
@@ -1218,7 +1206,7 @@ def _continue_udf(

         try:
             parent_partial_table = self.warehouse.get_table(
-                udf_partial_output_table_name(
+                Checkpoint.partial_output_table_name(
                     self.job.rerun_from_job_id, checkpoint.hash
                 )
             )
@@ -1246,7 +1234,9 @@
             len(incomplete_input_ids),
         )

-        partial_table_name = udf_partial_output_table_name(self.job.id, checkpoint.hash)
+        partial_table_name = Checkpoint.partial_output_table_name(
+            self.job.id, checkpoint.hash
+        )
         if incomplete_input_ids:
             # Filter out incomplete inputs - they will be re-processed
             filtered_query = sa.select(parent_partial_table).where(
@@ -1288,7 +1278,7 @@
         )

         output_table = self.warehouse.rename_table(
-            partial_table, udf_output_table_name(self.job.id, hash_output)
+            partial_table, Checkpoint.output_table_name(self.job.id, hash_output)
         )

         self.metastore.remove_checkpoints([partial_checkpoint.id])
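
These call sites all encode the same table lifecycle. A condensed sketch, with invented IDs and hashes and the warehouse rename reduced to a print:

    from datachain.checkpoint import Checkpoint

    job_id, partial_hash, hash_output = "job1", "p123", "o456"

    # 1. UDF rows accumulate in a job-scoped partial table...
    partial = Checkpoint.partial_output_table_name(job_id, partial_hash)

    # 2. ...which is renamed to the final output table on completion.
    final = Checkpoint.output_table_name(job_id, hash_output)

    print(f"{partial} -> {final}")
    # udf_job1_p123_output_partial -> udf_job1_o456_output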

tests/conftest.py

Lines changed: 3 additions & 1 deletion
@@ -220,7 +220,9 @@ def cleanup_udf_tables(warehouse):
     UDF tables are shared across jobs and persist after chain finishes,
     so we need to clean them up after each test to prevent interference.
     """
-    for table_name in warehouse.db.list_tables(prefix=warehouse.UDF_TABLE_NAME_PREFIX):
+    for table_name in warehouse.db.list_tables(
+        pattern=f"{warehouse.UDF_TABLE_NAME_PREFIX}%"
+    ):
         table = warehouse.db.get_table(table_name)
         warehouse.db.drop_table(table, if_exists=True)

tests/func/checkpoints/test_checkpoint_cleanup.py

Lines changed: 13 additions & 7 deletions
@@ -24,7 +24,7 @@ def test_cleanup_checkpoints_with_ttl(test_session, nums_dataset):
     job_id = test_session.get_or_create_job().id

     assert len(list(metastore.list_checkpoints(job_id))) == 4
-    assert len(warehouse.db.list_tables(prefix="udf_")) > 0
+    assert len(warehouse.db.list_tables(pattern="udf_%")) > 0

     # Make all checkpoints older than the default TTL (4h)
     old_time = datetime.now(timezone.utc) - timedelta(hours=5)
@@ -33,7 +33,7 @@ def test_cleanup_checkpoints_with_ttl(test_session, nums_dataset):
     catalog.cleanup_checkpoints()

     assert len(list(metastore.list_checkpoints(job_id))) == 0
-    assert len(warehouse.db.list_tables(prefix=f"udf_{job_id}_")) == 0
+    assert len(warehouse.db.list_tables(pattern=f"udf_{job_id}_%")) == 0


 def test_cleanup_checkpoints_with_custom_ttl(test_session, nums_dataset):
@@ -101,7 +101,7 @@ def test_cleanup_does_not_remove_unrelated_tables(test_session, nums_dataset):
     catalog.cleanup_checkpoints()

     assert len(list(metastore.list_checkpoints(job_id))) == 0
-    assert len(warehouse.db.list_tables(prefix=f"udf_{job_id}_")) == 0
+    assert len(warehouse.db.list_tables(pattern=f"udf_{job_id}_%")) == 0
     assert warehouse.db.has_table(fake_table_name)

     warehouse.cleanup_tables([fake_table_name])
@@ -194,9 +194,6 @@ def test_cleanup_preserves_input_tables_when_run_group_active(
     warehouse.create_udf_table(name=input_table)
     assert warehouse.db.has_table(input_table)

-    output_table = f"udf_{job1_id}_somehash_output"
-    warehouse.create_udf_table(name=output_table)
-
     # Create a second job in the same run group with active checkpoints
     reset_session_job_state()
     job2_id = metastore.create_job(
@@ -208,6 +205,14 @@
     monkeypatch.setenv("DATACHAIN_JOB_ID", job2_id)
     chain.map(tripled=lambda num: num * 3, output=int).save("nums_tripled")

+    # Record job1's UDF output tables before cleanup
+    job1_output_tables = [
+        t
+        for t in warehouse.db.list_tables(pattern=f"udf_{job1_id}_%")
+        if t.endswith("_output")
+    ]
+    assert len(job1_output_tables) > 0
+
     # Make only job 1's checkpoints outdated
     old_time = datetime.now(timezone.utc) - timedelta(hours=5)
     for ch in metastore.list_checkpoints(job1_id):
@@ -219,7 +224,8 @@

     catalog.cleanup_checkpoints()

-    assert not warehouse.db.has_table(output_table)
+    # Job 1's output tables removed
+    assert all(not warehouse.db.has_table(t) for t in job1_output_tables)
     assert len(list(metastore.list_checkpoints(job1_id))) == 0
     assert len(list(metastore.list_checkpoints(job2_id))) > 0
     # Shared input table preserved (run group still active)

tests/func/checkpoints/test_checkpoint_recovery.py

Lines changed: 2 additions & 2 deletions
@@ -586,7 +586,7 @@ def mapper(num) -> int:
     catalog = test_session.catalog

     # Drop all UDF output tables from first run
-    for table_name in catalog.warehouse.db.list_tables(prefix="udf_"):
+    for table_name in catalog.warehouse.db.list_tables(pattern="udf_%"):
         if "_output" in table_name and "_partial" not in table_name:
             table = catalog.warehouse.db.get_table(table_name)
             catalog.warehouse.db.drop_table(table, if_exists=True)
@@ -622,7 +622,7 @@ def mapper(num) -> int:
     test_session.get_or_create_job()

     # Drop all partial output tables from first run
-    for table_name in catalog.warehouse.db.list_tables(prefix="udf_"):
+    for table_name in catalog.warehouse.db.list_tables(pattern="udf_%"):
         if "_partial" in table_name:
             table = catalog.warehouse.db.get_table(table_name)
             catalog.warehouse.db.drop_table(table, if_exists=True)

tests/unit/test_data_storage.py

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ def test_list_tables(catalog):
         tables_after = db.list_tables()
         assert "test_list_tables_abc" in tables_after

-        # Test with prefix filter
-        filtered = db.list_tables(prefix="test_list_tables")
+        # Test with pattern filter
+        filtered = db.list_tables(pattern="test_list_tables%")
         assert "test_list_tables_abc" in filtered
     finally:
         db.drop_table(table)
