Skip to content

Commit 7df8ea5

Browse files
authored
Add functionality to load scheduler cfg from a json string
Differential Revision: D59832084 Pull Request resolved: #933
1 parent 512e40d commit 7df8ea5

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

torchx/specs/api.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,23 @@ def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
913913
cfg[key] = _cast_to_type(val, runopt_.opt_type)
914914
return cfg
915915

916+
def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
917+
"""
918+
Converts the given dict to a valid cfg for this ``runopts`` object.
919+
"""
920+
cfg: Dict[str, CfgVal] = {}
921+
cfg_dict = json.loads(json_repr)
922+
for key, val in cfg_dict.items():
923+
runopt_ = self.get(key)
924+
if runopt_:
925+
if runopt_.opt_type == List[str]:
926+
cfg[key] = [str(v) for v in val]
927+
elif runopt_.opt_type == Dict[str, str]:
928+
cfg[key] = {str(k): str(v) for k, v in val.items()}
929+
else:
930+
cfg[key] = val
931+
return cfg
932+
916933
def add(
917934
self,
918935
cfg_key: str,

torchx/specs/test/api_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,41 @@ def test_resolve_from_str(self) -> None:
485485
),
486486
),
487487

488+
def test_config_from_json_repr(self) -> None:
489+
opts = runopts()
490+
opts.add("foo", type_=str, default="", help="")
491+
opts.add("test_key", type_=str, default="", help="")
492+
opts.add("default_time", type_=int, default=0, help="")
493+
opts.add("enable", type_=bool, default=True, help="")
494+
opts.add("disable", type_=bool, default=True, help="")
495+
opts.add("complex_list", type_=List[str], default=[], help="")
496+
opts.add("complex_dict", type_=Dict[str, str], default={}, help="")
497+
498+
self.assertDictEqual(
499+
{
500+
"foo": "bar",
501+
"test_key": "test_value",
502+
"default_time": 42,
503+
"enable": True,
504+
"disable": False,
505+
"complex_list": ["v1", "v2", "v3"],
506+
"complex_dict": {"k1": "v1", "k2": "v2"},
507+
},
508+
opts.resolve(
509+
opts.cfg_from_json_repr(
510+
"""{
511+
"foo": "bar",
512+
"test_key": "test_value",
513+
"default_time": 42,
514+
"enable": true,
515+
"disable": false,
516+
"complex_list": ["v1", "v2", "v3"],
517+
"complex_dict": {"k1": "v1", "k2": "v2"}
518+
}"""
519+
)
520+
),
521+
)
522+
488523
def test_runopts_is_type(self) -> None:
489524
# primitive types
490525
self.assertTrue(runopts.is_type(3, int))

0 commit comments

Comments
 (0)