diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 4258ec7a7..66900e7a4 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -274,12 +274,23 @@ def _get_validation_errors( linter_errors = validate(path, function_name, validators) return [linter_error.description for linter_error in linter_errors] + def _get_path_to_function_decl( + self, function: Callable[..., Any] # pyre-ignore[2] + ) -> str: + """ + Attempts to return the path to the file where the function is implemented. + This can be different from the path where the function is looked up, for example if we have: + my_component defined in some_file.py, imported in other_file.py + and the component is invoked as other_file.py:my_component + """ + path_to_function_decl = inspect.getabsfile(function) + if path_to_function_decl is None or not os.path.isfile(path_to_function_decl): + return self._filepath + return path_to_function_decl + def find( self, validators: Optional[List[TorchxFunctionValidator]] ) -> List[_Component]: - validation_errors = self._get_validation_errors( - self._filepath, self._function_name, validators - ) file_source = read_conf_file(self._filepath) namespace = copy.copy(globals()) @@ -292,6 +303,12 @@ def find( ) app_fn = namespace[self._function_name] fn_desc, _ = get_fn_docstring(app_fn) + + func_path = self._get_path_to_function_decl(app_fn) + validation_errors = self._get_validation_errors( + func_path, self._function_name, validators + ) + return [ _Component( name=f"{self._filepath}:{self._function_name}", diff --git a/torchx/specs/test/finder_test.py b/torchx/specs/test/finder_test.py index 69dfecf01..18f01b4c5 100644 --- a/torchx/specs/test/finder_test.py +++ b/torchx/specs/test/finder_test.py @@ -29,6 +29,7 @@ get_components, ModuleComponentsFinder, ) +from torchx.specs.test.components.a import comp_a from torchx.util.test.entrypoints_test import EntryPoint_from_text from torchx.util.types import none_throws @@ -238,6 +239,10 @@ def test_get_component_invalid(self) -> None: with self.assertRaises(ComponentValidationException): get_component(f"{current_file_path()}:invalid_component") + def test_get_component_imported_from_other_file(self) -> None: + component = get_component(f"{current_file_path()}:comp_a") + self.assertListEqual([], component.validation_errors) + class GetBuiltinSourceTest(unittest.TestCase): def setUp(self) -> None: