Skip to content

Commit 050bf22

Browse files
committed
review and refactor export, fix tests again
1 parent 931406e commit 050bf22

6 files changed

Lines changed: 244 additions & 123 deletions

File tree

src/datachain/client/local.py

Lines changed: 37 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,53 +22,6 @@ class FileClient(Client):
2222
PREFIX = "file://"
2323
protocol = "file"
2424

25-
@classmethod
26-
def is_path_in(
27-
cls,
28-
output: str | os.PathLike[str],
29-
dst: str,
30-
) -> bool:
31-
"""Return whether `dst` is safe to write under local `output`.
32-
33-
Accepts both plain OS paths and `file://` urlpaths. For other schemes,
34-
returns False.
35-
"""
36-
37-
from fsspec.utils import stringify_path
38-
39-
output_str = stringify_path(output)
40-
41-
# Only handle local paths + file:// urlpaths.
42-
if "://" in dst and not dst.startswith(cls.PREFIX):
43-
return False
44-
if "://" in output_str and not output_str.startswith(cls.PREFIX):
45-
return False
46-
47-
output_os = (
48-
LocalFileSystem._strip_protocol(output_str)
49-
if output_str.startswith(cls.PREFIX)
50-
else output_str
51-
)
52-
dst_os = (
53-
LocalFileSystem._strip_protocol(dst) if dst.startswith(cls.PREFIX) else dst
54-
)
55-
56-
# Use abspath (makes relative paths absolute based on CWD, collapses
57-
# .. and .) followed by normcase (lowercases on Windows) for deterministic,
58-
# filesystem-independent comparison. resolve(strict=False) is avoided
59-
# because its symlink / junction behaviour varies across Python versions
60-
# and can differ depending on which path components already exist.
61-
output_normed = os.path.normcase(os.path.abspath(output_os))
62-
dst_normed = os.path.normcase(os.path.abspath(dst_os))
63-
64-
# Destination must be a file path under output, not the output dir itself.
65-
if dst_normed == output_normed:
66-
return False
67-
68-
# Ensure dst starts with the output prefix followed by a separator,
69-
# so that output="/foo/bar" does not match dst="/foo/bar2".
70-
return dst_normed.startswith(output_normed + os.sep)
71-
7225
def __init__(
7326
self,
7427
name: str,
@@ -171,7 +124,10 @@ async def get_current_etag(self, file: "File") -> str:
171124
@classmethod
172125
def rel_path_for_file(cls, file: "File") -> str:
173126
path = file.path
174-
cls.validate_local_relpath(file.source, path)
127+
try:
128+
cls.validate_file_relpath(path)
129+
except ValueError as exc:
130+
raise FileError(str(exc), file.source, path) from None
175131

176132
normpath = os.path.normpath(path)
177133
normpath = Path(normpath).as_posix()
@@ -184,43 +140,57 @@ def rel_path_for_file(cls, file: "File") -> str:
184140

185141
return normpath
186142

187-
@classmethod
188-
def validate_local_relpath(cls, source: str, path: str) -> None:
189-
"""Validate a local relative path string.
143+
@staticmethod
144+
def validate_file_relpath(path: str) -> None:
145+
"""Ensure a file's relative path is safe to use on the local filesystem.
190146
191-
Used both for local `File.path` values and for local export destination
192-
suffixes. Reject inputs that would require normalization (e.g. empty
193-
segments) or that could be unsafe on the local filesystem.
147+
Rejects absolute, traversal, and malformed paths that could escape the
148+
intended directory or require implicit normalization. On Windows,
149+
backslashes and drive-letter prefixes are handled as separators/absolute.
194150
"""
195151

196152
if not path:
197-
raise FileError("path must not be empty", source, path)
153+
raise ValueError(f"unsafe file path {path!r}: must not be empty")
198154

199-
if path.endswith("/"):
200-
raise FileError("path must not be a directory", source, path)
155+
# On Windows, backslash is a path separator — normalize so the
156+
# checks below work uniformly. On Linux/macOS, backslash is a legal
157+
# filename character and must not be reinterpreted as a separator.
158+
if os.name == "nt":
159+
canonical = path.replace("\\", "/")
160+
else:
161+
canonical = path
201162

202-
raw_posix = path.replace("\\", "/")
163+
if canonical.endswith("/"):
164+
raise ValueError(f"unsafe file path {path!r}: must not be a directory")
203165

204166
# Disallow absolute paths; local file paths are interpreted relative to
205167
# the source/output prefix.
206-
if raw_posix.startswith("/"):
207-
raise FileError("path must not be absolute", source, path)
168+
if canonical.startswith("/"):
169+
raise ValueError(f"unsafe file path {path!r}: must not be absolute")
208170

209171
# On Windows, a drive-letter prefix like "C:/" is absolute even
210-
# without a leading "/".
211-
if len(raw_posix) >= 2 and raw_posix[0].isalpha() and raw_posix[1] == ":":
212-
raise FileError("path must not be absolute", source, path)
172+
# without a leading "/". On Unix, colons are legal in filenames,
173+
# so only enforce this on Windows.
174+
if (
175+
os.name == "nt"
176+
and len(canonical) >= 2
177+
and canonical[0].isalpha()
178+
and canonical[1] == ":"
179+
):
180+
raise ValueError(f"unsafe file path {path!r}: must not be absolute")
213181

214182
# Disallow empty segments (e.g. 'dir//file.txt') to avoid implicit
215183
# normalization.
216-
if "//" in raw_posix:
217-
raise FileError("path must not contain empty segments", source, path)
184+
if "//" in canonical:
185+
raise ValueError(
186+
f"unsafe file path {path!r}: must not contain empty segments"
187+
)
218188

219189
# Disallow dot segments like '.' or '..' (even if they could be
220190
# normalized away) because they can make I/O and exports unsafe.
221-
raw_parts = raw_posix.split("/")
191+
raw_parts = canonical.split("/")
222192
if any(part in (".", "..") for part in raw_parts):
223-
raise FileError("path must not contain '.' or '..'", source, path)
193+
raise ValueError(f"unsafe file path {path!r}: must not contain '.' or '..'")
224194

225195
@staticmethod
226196
def _has_drive_letter(source: str) -> bool:

src/datachain/fs/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import TYPE_CHECKING
23

34
from fsspec.implementations.local import LocalFileSystem
@@ -6,6 +7,23 @@
67
from fsspec import AbstractFileSystem
78

89

10+
def is_subpath(parent: str, child: str) -> bool:
    """Return True iff *child* is strictly inside *parent* (traversal guard).

    Both paths must be absolute OS paths. They are compared lexically after
    ``os.path.normpath`` (collapses ``.``/``..`` segments and redundant or
    trailing separators) and ``os.path.normcase`` (case-folds on Windows).
    Symlinks are not resolved.

    Args:
        parent: Absolute path of the containing directory.
        child: Absolute path to test.

    Returns:
        True when *child* lies below *parent*; False when it is the same
        path as *parent* or lies outside it.

    Raises:
        ValueError: If either path is not absolute.
    """
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would silently disable this security precondition.
    if not os.path.isabs(parent):
        raise ValueError(f"parent must be absolute: {parent!r}")
    if not os.path.isabs(child):
        raise ValueError(f"child must be absolute: {child!r}")

    # Collapse ".." before comparing so that "/out/../etc/passwd" is not
    # mistaken for a path under "/out".
    parent_normed = os.path.normcase(os.path.normpath(parent))
    child_normed = os.path.normcase(os.path.normpath(child))

    if child_normed == parent_normed:
        return False

    # A root such as "/" (or "C:\\" on Windows) already ends with a
    # separator; appending another would make the prefix match nothing.
    if parent_normed.endswith(os.sep):
        prefix = parent_normed
    else:
        prefix = parent_normed + os.sep

    # Requiring the separator keeps "/foo/bar" from matching "/foo/bar2".
    return child_normed.startswith(prefix)
25+
26+
927
def _isdir(fs: "AbstractFileSystem", path: str) -> bool:
1028
info = fs.info(path)
1129
return info["type"] == "directory" or (

src/datachain/lib/file.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -633,45 +633,84 @@ def export(
633633
link_type: Literal["copy", "symlink"] = "copy",
634634
client_config: dict | None = None,
635635
) -> None:
636-
"""Export file to new location."""
636+
"""Copy or link this file into an output directory.
637+
638+
Args:
639+
output: Destination directory (local path or cloud prefix).
640+
placement: How to build the path under *output*:
641+
642+
- ``"fullpath"`` (default) — ``output/bucket/dir/file.txt``
643+
- ``"filepath"`` — ``output/dir/file.txt``
644+
- ``"filename"`` — ``output/file.txt``
645+
- ``"etag"`` — ``output/<etag>.txt``
646+
use_cache: If True, download to local cache first. Also
647+
required for symlinking remote files.
648+
link_type: ``"copy"`` (default) or ``"symlink"``.
649+
Symlink falls back to copy for virtual files and for
650+
remote files when *use_cache* is False.
651+
client_config: Extra kwargs forwarded to the storage client.
652+
653+
Example:
654+
```py
655+
# flat export by filename
656+
f.export("./export", placement="filename")
657+
658+
# export to a cloud prefix
659+
f.export("s3://output-bucket/results", placement="filepath")
660+
661+
# pass storage credentials via client_config
662+
f.export("s3://bucket/out", client_config={"aws_access_key_id": "…"})
663+
664+
# symlink from local cache (avoids re-downloading)
665+
f.export("./local_out", use_cache=True, link_type="symlink")
666+
```
667+
"""
637668
if self._catalog is None:
638669
raise RuntimeError("Cannot export file: catalog is not set")
639670

640671
self._caching_enabled = use_cache
641672

642673
suffix = self._get_destination_suffix(placement)
643-
# Normalize backslashes to forward slashes so posixpath.join produces
644-
# consistent separator paths on Windows.
645-
output_fspath = os.fspath(output).replace("\\", "/")
646-
dst = posixpath.join(output_fspath, suffix)
647-
dst_dir = posixpath.dirname(dst)
648-
client: Client = self._catalog.get_client(dst_dir, **(client_config or {}))
649-
650-
# If we're exporting to a local directory, ensure the resolved destination
651-
# stays within the chosen output directory. This protects against traversal
652-
# and absolute-path suffixes even when the source key is cloud/opaque.
674+
output_fspath = os.fspath(output)
675+
# On Windows, normalize backslash separators to forward slashes so
676+
# posixpath.join produces consistent paths. On Linux/macOS backslash
677+
# is a legal filename character and must not be replaced.
678+
if os.name == "nt":
679+
output_fspath = output_fspath.replace("\\", "/")
680+
681+
# Resolve output to absolute immediately so all derived paths are
682+
# absolute and don't need further normalization.
683+
output_abs = os.path.abspath(output_fspath)
684+
client: Client = self._catalog.get_client(output_abs, **(client_config or {}))
685+
dst_abs = posixpath.join(output_abs, suffix)
686+
687+
# Traversal safety: for local exports, validate the suffix
653688
if client.PREFIX == "file://":
654689
from datachain.client.local import FileClient
690+
from datachain.fs.utils import is_subpath
655691

656-
FileClient.validate_local_relpath(os.fspath(output), suffix)
692+
try:
693+
FileClient.validate_file_relpath(suffix)
694+
except ValueError as exc:
695+
raise FileError(str(exc), stringify_path(output), suffix) from None
657696

658-
if not FileClient.is_path_in(output, dst):
697+
if not is_subpath(output_abs, dst_abs):
659698
raise FileError(
660699
"destination is not within output directory",
661700
stringify_path(output),
662-
dst,
701+
dst_abs,
663702
)
664703

665-
client.fs.makedirs(dst_dir, exist_ok=True)
704+
client.fs.makedirs(posixpath.dirname(dst_abs), exist_ok=True)
666705

667706
if link_type == "symlink":
668707
try:
669-
return self._symlink_to(dst)
708+
return self._symlink_to(dst_abs)
670709
except OSError as exc:
671710
if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
672711
raise
673712

674-
self.save(dst, client_config=client_config)
713+
self.save(dst_abs, client_config=client_config)
675714

676715
def _set_stream(
677716
self,

tests/unit/lib/test_file.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,10 @@ def test_export_placements_build_expected_destination_for_each_source(
152152
file.export(output, placement="filepath", use_cache=False)
153153
file.export(output, placement="fullpath", use_cache=False)
154154

155-
# export() normalises backslashes to forward slashes via
156-
# os.fspath(output).replace("\\", "/"), so the expected prefix must match.
157-
expected_prefix = os.fspath(output).replace("\\", "/")
155+
# export() resolves output to absolute (os.path.abspath) and normalizes
156+
# backslashes to forward slashes on Windows, so build the expected prefix
157+
# the same way.
158+
expected_prefix = os.path.abspath(os.fspath(output)).replace("\\", "/")
158159

159160
expected_fullpath_suffix = (
160161
f"{expected_fullpath_prefix}/dir1/dir2/test.txt"
@@ -220,19 +221,19 @@ def test_export_local_output_allows_literal_percent_encoded_traversal(
220221
@pytest.mark.parametrize(
221222
"source,path,is_error,expected",
222223
[
223-
("", "", True, r"path must not be empty"),
224-
("", ".", True, r"path must not contain"),
225-
("", "..", True, r"path must not contain"),
226-
("", "/abs/file.txt", True, r"path must not be absolute"),
224+
("", "", True, r"must not be empty"),
225+
("", ".", True, r"must not contain"),
226+
("", "..", True, r"must not contain"),
227+
("", "/abs/file.txt", True, r"must not be absolute"),
227228
("", "file#hash.txt", False, "file#hash.txt"),
228229
("", "./dir/../file#hash.txt", True, r"must not contain"),
229230
("", "../escape.txt", True, r"must not contain"),
230231
("", "dir//file.txt", True, r"must not contain empty segments"),
231232
("", "dir/", True, r"must not be a directory"),
232-
("file:///bucket", "", True, r"path must not be empty"),
233-
("file:///bucket", ".", True, r"path must not contain"),
234-
("file:///bucket", "..", True, r"path must not contain"),
235-
("file:///bucket", "/abs/file.txt", True, r"path must not be absolute"),
233+
("file:///bucket", "", True, r"must not be empty"),
234+
("file:///bucket", ".", True, r"must not contain"),
235+
("file:///bucket", "..", True, r"must not contain"),
236+
("file:///bucket", "/abs/file.txt", True, r"must not be absolute"),
236237
("file:///bucket", "file#hash.txt", False, "file:///bucket/file#hash.txt"),
237238
(
238239
"file:///bucket",
@@ -306,17 +307,17 @@ def test_get_uri_contract(
306307
@pytest.mark.parametrize(
307308
"source,path,is_error,expected",
308309
[
309-
("", "", True, r"path must not be empty"),
310-
("", ".", True, r"path must not contain"),
311-
("", "..", True, r"path must not contain"),
312-
("", "/abs/file.txt", True, r"path must not be absolute"),
310+
("", "", True, r"must not be empty"),
311+
("", ".", True, r"must not contain"),
312+
("", "..", True, r"must not contain"),
313+
("", "/abs/file.txt", True, r"must not be absolute"),
313314
("", "file.txt", False, "file.txt"),
314315
("", "../escape.txt", True, r"must not contain"),
315316
("", "dir/", True, r"must not be a directory"),
316-
("file:///bucket", "", True, r"path must not be empty"),
317-
("file:///bucket", ".", True, r"path must not contain"),
318-
("file:///bucket", "..", True, r"path must not contain"),
319-
("file:///bucket", "/abs/file.txt", True, r"path must not be absolute"),
317+
("file:///bucket", "", True, r"must not be empty"),
318+
("file:///bucket", ".", True, r"must not contain"),
319+
("file:///bucket", "..", True, r"must not contain"),
320+
("file:///bucket", "/abs/file.txt", True, r"must not be absolute"),
320321
pytest.param(
321322
"file:///bucket",
322323
"file#hash.txt",

0 commit comments

Comments
 (0)