1010from src .models .registry import MODELS
1111from src .trainers .registry import TRAINERS
1212
13- # Configuration constants (simplified)
1413DERIVED_CFG_KEYS = {'input_shape' , 'num_classes' }
1514REQUIRED_KEYS = {
1615 'dataset' ,
2019 'base_lr' ,
2120 'batch_size' ,
2221 'epochs' ,
23- 'seed' , # a default seed (can be overridden by sweep dimension `seed` list)
22+ 'seed' ,
2423 'num_workers' ,
2524}
2625KEY_ORDER = [
4039# Expansion constants
4140LOGGING_LIST_KEYS = {'batch_metrics' } # list-valued keys not treated as sweep dimensions
4241HASH_IGNORE_KEYS = DERIVED_CFG_KEYS | LOGGING_LIST_KEYS
43- SEED_KEY = 'seed' # may be scalar or list (list => sweep over seeds)
4442
4543
4644@dataclass
@@ -110,7 +108,7 @@ def cfg_hash(cfg: dict) -> str:
110108 configuration, with mapping keys sorted and without ignored keys.
111109 """
112110 norm = normalize_resolved_cfg (cfg )
113- data = json .dumps (norm , sort_keys = True , separators = ("," , ":" ), ensure_ascii = False , allow_nan = False )
111+ data = json .dumps (norm , sort_keys = True , separators = (',' , ':' ), ensure_ascii = False , allow_nan = False )
114112 return hashlib .sha256 (data .encode ('utf-8' )).hexdigest ()
115113
116114
@@ -120,7 +118,8 @@ def _extract_path_info(path: Path) -> tuple[str, str, str, str]:
120118 Expected pattern (after first 'cfg' segment):
121119 cfg/<mode>/<dataset>/<model>/<trainer>/.../cfg.yaml
122120
123- <mode> must currently be one of {'sweep', 'experiment'}.
121+ The `<mode>` segment is a free-form folder namespace (e.g., 'sweep',
122+ 'experiment', 'sweep_1'). Expansion semantics do not depend on this value.
124123 """
125124 parts = list (path .parts )
126125 try :
@@ -131,17 +130,9 @@ def _extract_path_info(path: Path) -> tuple[str, str, str, str]:
131130 raise ValueError (f'Invalid cfg path (no cfg segment): { path } ' ) from e
132131
133132 tail = parts [cfg_idx + 1 :]
134- # Legacy directory prefixes ('tune', 'experiments') are no longer supported and not skipped.
135- # 'sweep' and 'experiment' are treated as real dataset names now, so we do not strip anything.
136- if tail and tail [0 ] in {'tune' , 'experiments' }:
137- raise ValueError (
138- f"Legacy path prefix no longer supported (remove 'tune/' or 'experiments/' from cfg path): { path } "
139- )
140133 if len (tail ) < 4 :
141134 raise ValueError (f'Insufficient path depth for mode/dataset/model/trainer: { path } ' )
142135 mode , dataset , model , trainer = tail [:4 ]
143- if mode not in {'sweep' , 'experiment' }:
144- raise ValueError (f"Unsupported mode '{ mode } ' in cfg path (expected 'sweep' or 'experiment'): { path } " )
145136 return mode , dataset , model , Path (trainer ).stem
146137
147138
@@ -181,27 +172,17 @@ def _dimension_items(cfg: dict) -> list[tuple[str, list]]:
181172 return dims
182173
183174
184- def expand_cfg_from_dict (cfg : dict , mode : str ) -> list [TrialSpec ]:
175+ def expand_cfg_from_dict (cfg : dict ) -> list [TrialSpec ]:
185176 """Expand a (possibly merged) cfg dict into trial specs.
186177
187- For 'sweep' mode: all list-valued keys (except structural allowlist) become
188- dimensions. 'seed' must be scalar.
189- For 'experiment' mode: only 'seed' may be list-valued (creating per-seed trials).
178+ Unified behavior (mode-agnostic):
179+ - Every list-valued key at the top level (except those in LOGGING_LIST_KEYS)
180+ is treated as a sweep dimension.
181+ - 'seed' may be scalar or list; if list, it is simply another dimension.
182+ - If no list-valued keys are present, produce a single TrialSpec.
190183 """
191- dims : list [tuple [str , list ]] = []
192- for k , v in cfg .items ():
193- if isinstance (v , list ) and k not in LOGGING_LIST_KEYS :
194- if mode == 'sweep' :
195- if k == SEED_KEY :
196- raise ValueError ("sweep cfg.yaml must not contain list-valued 'seed'; provide a single scalar seed" )
197- dims .append ((k , v ))
198- elif mode == 'experiment' :
199- if k == SEED_KEY :
200- dims .append ((k , v ))
201- else : # pragma: no cover - defensive, normally enforced earlier
202- raise ValueError ('experiment cfg.yaml may only vary seed; offending list key: ' + k )
203- else : # pragma: no cover - future mode types
204- raise ValueError (f'Unsupported mode for expansion: { mode } ' )
184+ # Identify dimensions using the shared helper for a single source of truth
185+ dims = _dimension_items (cfg )
205186
206187 if not dims :
207188 return [TrialSpec (idx = 1 , assignments = {})]
@@ -215,9 +196,9 @@ def expand_cfg_from_dict(cfg: dict, mode: str) -> list[TrialSpec]:
215196def create_trials (base_cfg_path : str | Path ) -> list [Path ]:
216197 """Write resolved trial cfgs beneath out/log/<mode>/<dataset>/<model>/<trainer>/trial_###/cfg.yaml.
217198
218- Mode-specific rules:
219- - sweep: if 'seed' provided it must be scalar (not list); other list hyperparams allowed.
220- - experiment: only 'seed' may be list-valued (aside from structural allowlist) .
199+ Expansion uses unified rules (see `expand_cfg_from_dict`): all list-valued
200+ top-level keys except logging lists become sweep dimensions. Resolved cfgs
201+ must not contain any remaining list-valued hyperparameters .
221202 """
222203 path = Path (base_cfg_path )
223204 if path .name != 'cfg.yaml' :
@@ -234,7 +215,7 @@ def create_trials(base_cfg_path: str | Path) -> list[Path]:
234215 # properly swept. Previously only the leaf cfg file was inspected which
235216 # caused inherited lists (e.g. batch_size) to remain unresolved in the
236217 # written trial cfgs.
237- trials = expand_cfg_from_dict (merged_base , mode )
218+ trials = expand_cfg_from_dict (merged_base )
238219 # Root output directory for this cfg tree
239220 out_root = Path ('out' ) / 'log' / mode / dataset / model / trainer
240221 out_root .mkdir (parents = True , exist_ok = True )
@@ -269,21 +250,19 @@ def _parse_idx(p: Path) -> int:
269250 resolved = merged_base .copy ()
270251 resolved .update (spec .assignments ) # apply concrete assignments
271252 # Ensure that any list-valued dimension keys have been resolved to scalars
272- for k , v in list (resolved .items ()):
273- if isinstance (v , list ) and k not in LOGGING_LIST_KEYS :
274- # If still a list here, it was not selected as a dimension (e.g. leftover due to mode rules)
275- # For experiment mode only 'seed' may remain list; others are invalid.
276- if mode == 'experiment' and k == SEED_KEY :
277- continue
278- raise ValueError (
279- f'Unresolved list-valued hyperparameter { k !r} remained in resolved cfg. Check expansion rules.'
280- )
253+ unresolved_dims = _dimension_items (resolved )
254+ if unresolved_dims :
255+ # With unified expansion, no list-valued hyperparameters should remain.
256+ names = [name for name , _ in unresolved_dims ]
257+ raise ValueError (
258+ f'Unresolved list-valued hyperparameters remained in resolved cfg: { names } . Check expansion rules.'
259+ )
281260 # Compute identity hash for de-duplication
282261 new_hash = cfg_hash (resolved )
283262
284263 # If this configuration already exists anywhere under out_root, reuse that path
285264 if new_hash in hash_to_dir :
286- existing_path = ( hash_to_dir [new_hash ] / 'cfg.yaml' )
265+ existing_path = hash_to_dir [new_hash ] / 'cfg.yaml'
287266 written .append (existing_path )
288267 continue
289268
@@ -356,23 +335,14 @@ def get_output_path(cfg_path: str | Path) -> Path:
356335 raise ValueError ('Only resolved trial cfg paths are supported.' )
357336
358337
359- def _validate_expansion_rules (_ : dict ) -> None : # pragma: no cover - retained for potential future rules
360- return
361-
362-
363338def has_cfg_been_run (cfg_path : str | Path ) -> tuple [bool , str ]:
364339 """A trial is considered run if its cfg exists and batch_log.csv is present."""
365340 try :
366341 cfg_path = Path (cfg_path )
367342 if not cfg_path .exists ():
368343 return False , 'Resolved cfg missing'
369- with cfg_path .open () as f :
370- resolved = yaml .safe_load (f ) or {}
371344 if not (cfg_path .parent / 'batch_log.csv' ).exists ():
372345 return False , 'No batch_log.csv'
373- # Lightweight validation of registries
374- if not cfgs_equal (resolved , resolved ): # always true, placeholder for future diff
375- return False , 'Internal mismatch'
376346 return True , 'Outputs present'
377- except Exception as e : # noqa: BLE001
347+ except Exception as e :
378348 return False , f'Error checking cfg: { e } '
0 commit comments