diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 181546bea..58dd3e9e1 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -10,6 +10,7 @@ import copy import json import re +import typing from dataclasses import asdict, dataclass, field from datetime import datetime from enum import Enum @@ -189,8 +190,25 @@ def apply(self, role: "Role") -> "Role": role = copy.deepcopy(role) role.args = [self.substitute(arg) for arg in role.args] role.env = {key: self.substitute(arg) for key, arg in role.env.items()} + role.metadata = self._apply_nested(role.metadata) + return role + def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]: + stack = [d] + while stack: + current_dict = stack.pop() + for k, v in current_dict.items(): + if isinstance(v, dict): + stack.append(v) + elif isinstance(v, str): + current_dict[k] = self.substitute(v) + elif isinstance(v, list): + for i in range(len(v)): + if isinstance(v[i], str): + v[i] = self.substitute(v[i]) + return d + def substitute(self, arg: str) -> str: """ substitute applies the values to the template arg.