Skip to content

Commit 023fae6

Browse files
committed
Change overall approach
1 parent 2fe7cc3 commit 023fae6

9 files changed

Lines changed: 303 additions & 130 deletions

File tree

src/datachain/cli/__init__.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from multiprocessing import freeze_support
66
from typing import TYPE_CHECKING, Optional
77

8-
from datachain.cli.commands.storages import cp_storage
98
from datachain.cli.utils import get_logging_level
109

1110
from .commands import (
@@ -26,6 +25,8 @@
2625
logger = logging.getLogger("datachain")
2726

2827
if TYPE_CHECKING:
28+
from argparse import Namespace
29+
2930
from datachain.catalog import Catalog
3031

3132

@@ -82,7 +83,6 @@ def main(argv: Optional[list[str]] = None) -> int:
8283

8384
def handle_command(args, catalog, client_config) -> int:
8485
"""Handle the different CLI commands."""
85-
from datachain.cli.commands.storages import mv_storage, rm_storage
8686
from datachain.studio import process_auth_cli_args, process_jobs_args
8787

8888
command_handlers = {
@@ -101,8 +101,8 @@ def handle_command(args, catalog, client_config) -> int:
101101
"gc": lambda: garbage_collect(catalog),
102102
"auth": lambda: process_auth_cli_args(args),
103103
"job": lambda: process_jobs_args(args),
104-
"mv": lambda: mv_storage(args),
105-
"rm": lambda: rm_storage(args),
104+
"mv": lambda: handle_mv_command(args, catalog),
105+
"rm": lambda: handle_rm_command(args, catalog),
106106
}
107107

108108
handler = command_handlers.get(args.command)
@@ -115,23 +115,36 @@ def handle_command(args, catalog, client_config) -> int:
115115
return 1
116116

117117

118-
def handle_cp_command(args, catalog: "Catalog"):
118+
def _get_storage_implementation(args: "Namespace", catalog: "Catalog"):
    """Select the storage backend used by the cp/mv/rm commands.

    The Studio-backed implementation is used only when a Studio auth token
    is configured AND the user requested Studio cloud auth; in every other
    case the local (direct-filesystem) implementation is returned.
    """
    from datachain.cli.commands.storage import (
        LocalStorageImplementation,
        StudioStorageImplementation,
    )
    from datachain.config import Config

    studio_config = Config().read().get("studio", {})
    token = studio_config.get("token")
    # Without a token the Studio path cannot authenticate, so force local.
    use_studio = bool(token) and args.studio_cloud_auth
    if use_studio:
        return StudioStorageImplementation(args, catalog)
    return LocalStorageImplementation(args, catalog)
133+
134+
135+
def handle_cp_command(args, catalog):
    """Copy files/directories using the backend chosen for these args."""
    return _get_storage_implementation(args, catalog).cp()
138+
139+
140+
def handle_mv_command(args, catalog):
    """Move/rename a path using the backend chosen for these args."""
    return _get_storage_implementation(args, catalog).mv()
143+
133144

134-
return cp_storage(args)
145+
def handle_rm_command(args, catalog):
    """Remove a path using the backend chosen for these args."""
    return _get_storage_implementation(args, catalog).rm()
135148

136149

137150
def handle_clone_command(args, catalog):
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .local import LocalStorageImplementation
2+
from .studio import StudioStorageImplementation
3+
from .utils import build_file_paths, validate_upload_args
4+
5+
__all__ = [
6+
"LocalStorageImplementation",
7+
"StudioStorageImplementation",
8+
"build_file_paths",
9+
"validate_upload_args",
10+
]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import TYPE_CHECKING, Optional
2+
3+
from datachain.error import DataChainError
4+
5+
if TYPE_CHECKING:
6+
from argparse import Namespace
7+
8+
from fsspec import AbstractFileSystem
9+
10+
from datachain.catalog import Catalog
11+
from datachain.client.fsspec import Client
12+
from datachain.remote.studio import StudioClient
13+
14+
15+
class StorageImplementation:
    """Protocol-dispatching base for the cp/mv/rm storage commands.

    `cp` routes a copy to one of four hooks depending on whether the source
    and destination use the local ``file`` protocol. Subclasses override the
    remote-transfer hooks and rm/mv; a purely local copy is handled here.
    """

    def __init__(self, args: "Namespace", catalog: "Catalog"):
        self.args = args
        self.catalog = catalog

    def rm(self):
        raise NotImplementedError("Remove is not implemented")

    def mv(self):
        raise NotImplementedError("Move is not implemented")

    def cp(self):
        """Dispatch the copy based on the source/destination protocols."""
        from datachain.client.fsspec import Client

        source_cls = Client.get_implementation(self.args.source_path)
        destination_cls = Client.get_implementation(self.args.destination_path)

        source_is_local = source_cls.protocol == "file"
        destination_is_local = destination_cls.protocol == "file"

        if source_is_local and destination_is_local:
            self.copy_local_to_local(source_cls)
        elif source_is_local:
            self.upload_to_remote(source_cls, destination_cls)
        elif destination_is_local:
            self.download_from_remote(destination_cls)
        else:
            self.copy_remote_to_remote(source_cls)

    def copy_local_to_local(self, source_cls: "Client"):
        """Copy within the local filesystem via fsspec."""
        fs = source_cls.create_fs()
        fs.copy(
            self.args.source_path,
            self.args.destination_path,
            recursive=self.args.recursive,
        )
        print(f"Copied {self.args.source_path} to {self.args.destination_path}")

    def upload_to_remote(self, source_cls: "Client", destination_cls: "Client"):
        raise NotImplementedError("Upload to remote is not implemented")

    def download_from_remote(self, destination_cls: "Client"):
        raise NotImplementedError("Download from remote is not implemented")

    def copy_remote_to_remote(self, source_cls: "Client"):
        raise NotImplementedError("Copy remote to remote is not implemented")

    def save_upload_log(
        self,
        studio_client: Optional["StudioClient"],
        destination_path: str,
        file_paths: dict,
        local_fs: "AbstractFileSystem",
    ):
        """Record uploaded files (destination path + size) with Studio.

        Best effort: when no Studio client is supplied and one cannot be
        created (e.g. not authenticated), logging is silently skipped —
        the uploads themselves have already succeeded.
        """
        from datachain.remote.storages import get_studio_client

        if studio_client is None:
            try:
                studio_client = get_studio_client(self.args)
            except DataChainError:
                return

        entries = [
            {"path": dst, "size": local_fs.info(src).get("size", 0)}
            for dst, src in file_paths.items()
        ]
        studio_client.save_upload_log(destination_path, entries)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import TYPE_CHECKING
2+
3+
from datachain.cli.commands.storage.base import StorageImplementation
4+
from datachain.cli.commands.storage.utils import build_file_paths, validate_upload_args
5+
6+
if TYPE_CHECKING:
7+
from datachain.client.fsspec import Client
8+
9+
10+
class LocalStorageImplementation(StorageImplementation):
    """cp/mv/rm executed directly against filesystems via fsspec."""

    def upload_to_remote(self, source_cls: "Client", destination_cls: "Client"):
        """Stream local files to the remote destination with a progress bar."""
        from tqdm import tqdm

        source_fs = source_cls.create_fs()
        destination_fs = destination_cls.create_fs()
        is_dir = validate_upload_args(self.args, source_fs)
        file_paths = build_file_paths(self.args, source_fs, is_dir)
        destination_bucket, _ = destination_cls.parse_url(self.args.destination_path)

        chunk_size = 8192
        for dest_path, source_path in file_paths.items():
            file_size = source_fs.info(source_path)["size"]
            print(f"Uploading {source_path} to {dest_path}")

            with tqdm(total=file_size, unit="B", unit_scale=True) as pbar:
                with source_fs.open(source_path, "rb") as src_f, destination_fs.open(
                    f"{destination_bucket}/{dest_path}", "wb"
                ) as dst_f:
                    while chunk := src_f.read(chunk_size):
                        dst_f.write(chunk)
                        pbar.update(len(chunk))

        self.save_upload_log(None, self.args.destination_path, file_paths, source_fs)

    def download_from_remote(self, destination_cls: "Client"):
        """Delegate the remote-to-local download to the catalog's cp."""
        self.catalog.cp(
            [self.args.source_path],
            self.args.destination_path,
            force=bool(self.args.force),
            update=bool(self.args.update),
            recursive=bool(self.args.recursive),
            no_glob=self.args.no_glob,
        )

    def copy_remote_to_remote(self, source_cls: "Client"):
        # NOTE(review): the copy runs on the *source* filesystem; presumably
        # source and destination share a protocol — confirm the cross-cloud case.
        source_cls.create_fs().copy(
            self.args.source_path,
            self.args.destination_path,
            recursive=self.args.recursive,
        )

    def rm(self):
        """Remove a path directly on its backing filesystem."""
        from datachain.client.fsspec import Client

        fs = Client.get_implementation(self.args.path).create_fs()
        fs.rm(self.args.path, recursive=self.args.recursive)
        # TODO: Add storage logging.

    def mv(self):
        """Move/rename a path directly on its backing filesystem."""
        from datachain.client.fsspec import Client

        fs = Client.get_implementation(self.args.path).create_fs()
        fs.mv(self.args.path, self.args.new_path, recursive=self.args.recursive)
        # TODO: Add storage logging.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import TYPE_CHECKING
2+
3+
from datachain.cli.commands.storage.base import StorageImplementation
4+
from datachain.error import DataChainError
5+
6+
if TYPE_CHECKING:
7+
from datachain.client.fsspec import Client
8+
9+
10+
class StudioStorageImplementation(StorageImplementation):
    """cp/mv/rm executed through the Studio API."""

    def upload_to_remote(self, source_cls: "Client", destination_cls: "Client"):
        """Upload via Studio, then record the uploaded files."""
        from datachain.remote.storages import upload_to_storage

        source_fs = source_cls.create_fs()
        file_paths = upload_to_storage(self.args, source_fs)
        self.save_upload_log(None, self.args.destination_path, file_paths, source_fs)

    def download_from_remote(self, destination_cls: "Client"):
        """Download via Studio onto the local destination filesystem."""
        from datachain.remote.storages import download_from_storage

        destination_fs = destination_cls.create_fs()
        download_from_storage(self.args, destination_fs)

    def copy_remote_to_remote(self, source_cls: "Client"):
        """Server-side copy within a storage, performed by Studio."""
        from datachain.remote.storages import copy_inside_storage

        copy_inside_storage(self.args)

    def _check_response(self, response, success_message: str):
        """Raise DataChainError on a failed Studio response, else report success.

        Shared by rm/mv so the error-handling stays consistent.
        """
        if not response.ok:
            raise DataChainError(response.message)
        print(success_message)

    def rm(self):
        """Delete a storage file/prefix through the Studio API."""
        from datachain.remote.storages import get_studio_client

        client = get_studio_client(self.args)
        response = client.delete_storage_file(
            self.args.path,
            recursive=self.args.recursive,
        )
        self._check_response(response, f"Deleted {self.args.path}")

    def mv(self):
        """Move a storage file/prefix through the Studio API."""
        from datachain.remote.storages import get_studio_client

        client = get_studio_client(self.args)
        response = client.move_storage_file(
            self.args.path,
            self.args.new_path,
            recursive=self.args.recursive,
        )
        self._check_response(
            response, f"Moved {self.args.path} to {self.args.new_path}"
        )
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os.path
2+
from typing import TYPE_CHECKING
3+
4+
if TYPE_CHECKING:
5+
from argparse import Namespace
6+
7+
from fsspec.spec import AbstractFileSystem
8+
9+
from datachain.error import DataChainError
10+
11+
12+
def validate_upload_args(args: "Namespace", local_fs: "AbstractFileSystem"):
    """Validate upload arguments; return whether the source is a directory.

    Raises:
        DataChainError: the source is a directory but --recursive was not given.
    """
    source_is_dir = local_fs.isdir(args.source_path)
    if source_is_dir and not args.recursive:
        raise DataChainError("Cannot copy directory without --recursive")
    return source_is_dir
18+
19+
20+
def build_file_paths(args: "Namespace", local_fs: "AbstractFileSystem", is_dir: bool):
    """Build a mapping of destination (remote) paths to local source paths.

    Destination keys are remote object paths, so they are joined with forward
    slashes regardless of host OS. Previously `os.path.join` was used, which
    produces backslash-separated keys on Windows; on POSIX behavior is
    unchanged.
    """
    import posixpath

    from datachain.client.fsspec import Client

    client = Client.get_implementation(args.destination_path)
    _, subpath = client.split_url(args.destination_path)

    def _remote_rel(path: str) -> str:
        # Normalize the local relative path to forward slashes for the remote key.
        return os.path.relpath(path, args.source_path).replace(os.sep, "/")

    if is_dir:
        return {
            posixpath.join(subpath, _remote_rel(path)): path
            for path in local_fs.find(args.source_path)
        }

    destination_path = (
        posixpath.join(subpath, os.path.basename(args.source_path))
        if args.destination_path.endswith(("/", "\\")) or not subpath
        else subpath
    )
    return {destination_path: args.source_path}

0 commit comments

Comments
 (0)