From d0e27049d2621990dd09d3654ba8f44ccc0260ea Mon Sep 17 00:00:00 2001 From: Isha Devyani Chirimar Date: Tue, 16 Jul 2024 18:53:39 -0700 Subject: [PATCH] Add functionality to load scheduler cfg from a json string Summary: add a function `cfg_from_json_repr()` to create scheduler runopts from a json string representation. currently, there is only `cfg_from_str()` which takes a string in the format "k=v,k1=v1" (designed for cmd line). however, the runcfg is stored in scuba logs in a json string representation. so in order to read it back when cloning jobs we need to read from json string. Differential Revision: D59832084 --- torchx/specs/api.py | 17 +++++++++++++++++ torchx/specs/test/api_test.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/torchx/specs/api.py b/torchx/specs/api.py index ba084ffe7..8b3b947a2 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -913,6 +913,23 @@ def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal: cfg[key] = _cast_to_type(val, runopt_.opt_type) return cfg + def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: + """ + Converts the given dict to a valid cfg for this ``runopts`` object. + """ + cfg: Dict[str, CfgVal] = {} + cfg_dict = json.loads(json_repr) + for key, val in cfg_dict.items(): + runopt_ = self.get(key) + if runopt_: + if runopt_.opt_type == List[str]: + cfg[key] = [str(v) for v in val] + elif runopt_.opt_type == Dict[str, str]: + cfg[key] = {str(k): str(v) for k, v in val.items()} + else: + cfg[key] = val + return cfg + def add( self, cfg_key: str, diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index af8e34b49..3ff5caea1 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -485,6 +485,41 @@ def test_resolve_from_str(self) -> None: ), ), + def test_config_from_json_repr(self) -> None: + opts = runopts() + opts.add("foo", type_=str, default="", help="") + opts.add("test_key", type_=str, default="", help="") + opts.add("default_time", type_=int, default=0, help="") + opts.add("enable", type_=bool, default=True, help="") + opts.add("disable", type_=bool, default=True, help="") + opts.add("complex_list", type_=List[str], default=[], help="") + opts.add("complex_dict", type_=Dict[str, str], default={}, help="") + + self.assertDictEqual( + { + "foo": "bar", + "test_key": "test_value", + "default_time": 42, + "enable": True, + "disable": False, + "complex_list": ["v1", "v2", "v3"], + "complex_dict": {"k1": "v1", "k2": "v2"}, + }, + opts.resolve( + opts.cfg_from_json_repr( + """{ + "foo": "bar", + "test_key": "test_value", + "default_time": 42, + "enable": true, + "disable": false, + "complex_list": ["v1", "v2", "v3"], + "complex_dict": {"k1": "v1", "k2": "v2"} + }""" + ) + ), + ) + def test_runopts_is_type(self) -> None: # primitive types self.assertTrue(runopts.is_type(3, int))