-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathsplit_dataset.py
More file actions
101 lines (84 loc) · 2.45 KB
/
split_dataset.py
File metadata and controls
101 lines (84 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
import sys
from pathlib import Path
import torch
from configs import LDMDatasetConfig
from datasets.dataset_utils import save_dataset_charset
from datasets.loader import Loader
from utils.argparse.argparse_utils import update_config_from_args
from utils.hardware.hardware_utils import select_device
from utils.validation.project_validator import ProjectValidationError, ProjectValidator
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """
    Parse command-line arguments for charset extraction.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour; an explicit
            list lets callers/tests drive the parser directly.

    Returns:
        Namespace with ``target_font_path`` (str), ``split_ratios``
        (list of two floats, or None), ``split_random_seed`` (int, or None)
        and ``device`` (str, or None). Optional flags that are omitted are
        left as ``None``.

    Raises:
        SystemExit: raised by argparse when the arguments are invalid or
            ``--target_font_path`` is missing.
    """
    parser = argparse.ArgumentParser(
        description="Extract train/val charset from dataset",
    )
    # Required: the path is validated and forwarded to save_dataset_charset
    # unconditionally, so a missing value should fail here with a usage
    # message instead of passing None further down.
    parser.add_argument(
        "--target_font_path",
        type=str,
        required=True,
        help="Target font path",
    )
    parser.add_argument(
        "--split_ratios",
        type=float,
        nargs=2,
        metavar=("TRAIN", "VAL"),
        help="Train/val split ratios",
    )
    parser.add_argument(
        "--split_random_seed",
        type=int,
        help="Split random seed",
    )
    parser.add_argument(
        "--device",
        type=str,
        help="Training device (mps, cpu, cuda)",
    )
    return parser.parse_args(argv)
def split_and_extract_charset(
    target_font_path: str,
    dataset_config: LDMDatasetConfig,
    device: torch.device,
) -> None:
    """
    Build the dataset loaders and save the train/val charsets to files.

    Args:
        target_font_path: Target font path, forwarded unchanged to
            ``save_dataset_charset``.
        dataset_config: Dataset configuration used to construct the
            ``Loader``; its ``splits_root`` is passed as ``charset_root``.
        device: Device handed to ``Loader.from_dataset_config``.
    """
    dataset_loader = Loader.from_dataset_config(
        dataset_config=dataset_config,
        device=device,
    )

    # The loader object exposes separate train and val splits; both are
    # needed to compute the per-split charsets.
    train_split = dataset_loader.loader.train
    val_split = dataset_loader.loader.val

    save_dataset_charset(
        train_loader=train_split,
        val_loader=val_split,
        target_font_path=target_font_path,
        charset_root=dataset_config.splits_root,
    )
def main() -> None:
    """
    Main function to run the charset extraction process.

    Steps: parse the CLI, validate the target font file, build an
    ``LDMDatasetConfig`` updated from the parsed arguments, select the
    device, and delegate to ``split_and_extract_charset``.

    On ``ProjectValidationError`` the error is reported on stderr and the
    process exits with status 1. Other exceptions propagate unchanged.
    """
    # argparse reports bad CLI input and exits on its own, so parsing does
    # not need to sit inside the validation-error handler.
    args = parse_args()
    try:
        ProjectValidator.validate_font_file(
            file_path=args.target_font_path,
            name="Target font",
        )
        dataset_config = update_config_from_args(
            converting_config=LDMDatasetConfig(),
            args=args,
        )
        device = select_device(args.device)
        split_and_extract_charset(
            target_font_path=args.target_font_path,
            dataset_config=dataset_config,
            device=device,
        )
    except ProjectValidationError as e:
        # Diagnostics go to stderr so they are not mixed into normal output.
        print(f"❌ {e}", file=sys.stderr)
        print("❌ Charset splitting failed", file=sys.stderr)
        sys.exit(1)
    else:
        print("✅ Charset splitting completed successfully")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()