|
| 1 | +"""Component to download a model from HuggingFace Hub to a local directory (PVC).""" |
| 2 | + |
| 3 | +from kfp.dsl import component |
| 4 | + |
| 5 | + |
@component(base_image="python:3.12-slim-bullseye", packages_to_install=["huggingface-hub"])
def download_model_from_hf(model_identifier: str, local_model_dir: str = "/models", if_exists: str = "skip"):
    """Download a model from HuggingFace Hub to a local directory.

    Use a mounted PVC path with sufficient storage for larger models.

    Args:
        model_identifier: HuggingFace model identifier (e.g., "Qwen/Qwen3-VL-2B-Instruct")
        local_model_dir: Local directory to save the model files to (default: "/models")
        if_exists: Behavior if files already exist in local_model_dir.
            Options: ["skip" (default), "overwrite", "error"]

    Raises:
        ValueError: If if_exists is not one of "skip", "overwrite" or "error".
        RuntimeError: If files already exist and if_exists is "error", or if
            local_model_dir does not exist after the download.

    Environment Variables (Assumed mounted as secret via kfp.kubernetes.use_secret_as_env):
        HUGGINGFACE_TOKEN: HuggingFace token (optional)
    """
    # Imports and helpers live inside the function on purpose: the module-level
    # code of this file is not part of the component that gets executed.
    import os
    import pathlib

    from huggingface_hub import snapshot_download

    # Validate up front so a typo fails on the first run, not only once the
    # target directory happens to contain files.
    if if_exists not in ("skip", "overwrite", "error"):
        raise ValueError(
            f"Invalid value for if_exists: '{if_exists}'. " f"Must be one of: 'skip', 'overwrite', 'error'"
        )

    def _list_files(root):
        """Return every regular file below root, sorted by path."""
        return sorted(p for p in root.rglob("*") if p.is_file())

    def _print_files(root, files):
        """Print each file's path relative to root together with its size in MB."""
        for path in files:
            size_mb = path.stat().st_size / (1024 * 1024)
            print(f" {path.relative_to(root)} ({size_mb:.2f} MB)")

    print(f"Checking for existing model files in: {local_model_dir}")

    # Check if model files already exist
    output_path = pathlib.Path(local_model_dir)
    existing_files = _list_files(output_path) if output_path.is_dir() else []

    if existing_files:
        print(f"Found {len(existing_files)} existing files in {local_model_dir}")

        if if_exists == "skip":
            print("Skipping download - model files already exist (if_exists='skip')")
            print(f"\nExisting files ({len(existing_files)}):")
            _print_files(output_path, existing_files)
            print(f"\nUsing existing model files from: {local_model_dir}")
            return
        if if_exists == "error":
            raise RuntimeError(
                f"Model files already exist in {local_model_dir}. "
                f"Found {len(existing_files)} files. "
                f"Use if_exists='skip' to use existing files or if_exists='overwrite' to replace them."
            )
        # Only "overwrite" can reach this point.
        print("Overwriting existing files (if_exists='overwrite')")
    else:
        print(f"No existing files found in {local_model_dir}")

    print(f"Downloading model: {model_identifier}")

    # Download all model files (config, tokenizer, model weights, etc.).
    # - force_download makes "overwrite" really replace files that are already
    #   present instead of letting the library skip them as up to date.
    # - local_dir_use_symlinks is intentionally not passed: it is deprecated in
    #   huggingface_hub and, since the package is installed unpinned, may not
    #   be accepted by the installed release.
    # - An empty token variable is treated like an unset one.
    snapshot_download(
        repo_id=model_identifier,
        local_dir=local_model_dir,
        force_download=bool(existing_files),
        token=os.getenv("HUGGINGFACE_TOKEN") or None,
    )

    # Verify downloaded files
    if not output_path.exists():
        raise RuntimeError(f"Output directory {local_model_dir} was not created")

    files = _list_files(output_path)
    print(f"\nDownloaded {len(files)} files:")
    _print_files(output_path, files)

    print(f"\nModel download complete. Files saved to: {local_model_dir}")
| 86 | + |
| 87 | + |
if __name__ == "__main__":
    # Running this file as a script compiles the component definition to a
    # YAML package written next to the current working directory.
    from kfp.compiler import Compiler

    Compiler().compile(download_model_from_hf, package_path="download_model_from_hf.yaml")
0 commit comments