diff --git a/dev-requirements.txt b/dev-requirements.txt index 474edd3bb..f6dec1675 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -30,6 +30,7 @@ torchmetrics==1.6.3 torchserve>=0.10.0 torchtext==0.18.0 torchvision==0.22.0 +typing-extensions ts==0.5.1 ray[default] wheel diff --git a/requirements.txt b/requirements.txt index 9c2d29854..51ecf3fb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ -pyre-extensions docstring-parser>=0.8.1 -importlib-metadata pyyaml docker filelock diff --git a/torchx/components/structured_arg.py b/torchx/components/structured_arg.py index 29a26ea1b..638a7ff9d 100644 --- a/torchx/components/structured_arg.py +++ b/torchx/components/structured_arg.py @@ -30,8 +30,6 @@ from pathlib import Path from typing import Optional -from pyre_extensions import none_throws - from torchx import specs @@ -148,7 +146,8 @@ def parse_from( if m: # use the last module name run_name = m.rpartition(".")[2] else: # use script name w/ no extension - run_name = Path(none_throws(script)).stem + assert script, "`script` can't be `None` here due checks above" + run_name = Path(script).stem return StructuredNameArgument( experiment_name or default_experiment_name, run_name ) diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 1731c4657..32f9a5314 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -14,7 +14,18 @@ import warnings from datetime import datetime from types import TracebackType -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar +from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, +) from torchx.runner.events import log_event from torchx.schedulers import get_scheduler_factories, SchedulerFactory @@ -43,7 +54,9 @@ from torchx.util.types import none_throws from torchx.workspace.api import PkgInfo, WorkspaceBuilder, WorkspaceMixin -from typing_extensions import Self + +if TYPE_CHECKING: + from typing_extensions import Self from .config import get_config, get_configs @@ -121,7 +134,7 @@ def _get_scheduler_params_from_env(self) -> Dict[str, str]: scheduler_params[lower_case_key.strip("torchx_")] = value return scheduler_params - def __enter__(self) -> Self: + def __enter__(self) -> "Self": return self def __exit__( diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 23af81d4e..c48cebaeb 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -8,11 +8,10 @@ # pyre-strict import importlib -from typing import Dict, Mapping +from typing import Mapping, Protocol from torchx.schedulers.api import Scheduler from torchx.util.entrypoints import load_group -from typing_extensions import Protocol DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = { "local_docker": "torchx.schedulers.docker_scheduler", @@ -44,7 +43,7 @@ def run(*args: object, **kwargs: object) -> Scheduler: def get_scheduler_factories( group: str = "torchx.schedulers", skip_defaults: bool = False -) -> Dict[str, SchedulerFactory]: +) -> dict[str, SchedulerFactory]: """ get_scheduler_factories returns all the available schedulers names under `group` and the method to instantiate them. @@ -52,7 +51,7 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ - default_schedulers: Dict[str, SchedulerFactory] = {} + default_schedulers: dict[str, SchedulerFactory] = {} for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): default_schedulers[scheduler] = _defer_load_scheduler(path) diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 359390a87..48ca64849 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -16,7 +16,6 @@ from torchx.specs import ( AppDef, - AppDryRunInfo, AppState, NONE, NULL_RESOURCE, diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index ecd5ce2c1..76e285539 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -53,13 +53,13 @@ Optional, Tuple, TYPE_CHECKING, + TypedDict, TypeVar, ) import torchx import yaml from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -71,6 +71,7 @@ from torchx.schedulers.ids import make_unique from torchx.specs.api import ( AppDef, + AppDryRunInfo, AppState, BindMount, CfgVal, @@ -86,7 +87,6 @@ from torchx.specs.named_resources_aws import instance_type_from_resource from torchx.util.types import none_throws from torchx.workspace.docker_workspace import DockerWorkspaceMixin -from typing_extensions import TypedDict ENV_TORCHX_ROLE_IDX = "TORCHX_ROLE_IDX" diff --git a/torchx/schedulers/aws_sagemaker_scheduler.py b/torchx/schedulers/aws_sagemaker_scheduler.py index a67509520..083ea0f7c 100644 --- a/torchx/schedulers/aws_sagemaker_scheduler.py +++ b/torchx/schedulers/aws_sagemaker_scheduler.py @@ -25,6 +25,7 @@ OrderedDict, Tuple, TYPE_CHECKING, + TypedDict, TypeVar, ) @@ -34,16 +35,14 @@ from sagemaker.pytorch import PyTorch from torchx.components.structured_arg import StructuredNameArgument from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, ListAppResponse, Scheduler, Stream, ) from torchx.schedulers.ids import make_unique -from torchx.specs.api import AppDef, AppState, CfgVal, runopts +from torchx.specs.api import AppDef, AppDryRunInfo, AppState, CfgVal, runopts from torchx.workspace.docker_workspace import DockerWorkspaceMixin -from typing_extensions import TypedDict if TYPE_CHECKING: diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index 454f43f92..bed591c72 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -13,12 +13,11 @@ import tempfile from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, TypedDict, Union import torchx import yaml from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -30,6 +29,7 @@ from torchx.schedulers.ids import make_unique from torchx.specs.api import ( AppDef, + AppDryRunInfo, AppState, BindMount, DeviceMount, @@ -42,7 +42,6 @@ VolumeMount, ) from torchx.workspace.docker_workspace import DockerWorkspaceMixin -from typing_extensions import TypedDict if TYPE_CHECKING: diff --git a/torchx/schedulers/gcp_batch_scheduler.py b/torchx/schedulers/gcp_batch_scheduler.py index a8fdc99f9..f5599556a 100644 --- a/torchx/schedulers/gcp_batch_scheduler.py +++ b/torchx/schedulers/gcp_batch_scheduler.py @@ -24,22 +24,28 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, TypedDict import torchx import yaml from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, ListAppResponse, Scheduler, Stream, ) from torchx.schedulers.ids import make_unique -from torchx.specs.api import AppDef, AppState, macros, Resource, Role, runopts +from torchx.specs.api import ( + AppDef, + AppDryRunInfo, + AppState, + macros, + Resource, + Role, + runopts, +) from torchx.util.strings import normalize_str -from typing_extensions import TypedDict if TYPE_CHECKING: diff --git a/torchx/schedulers/kubernetes_mcad_scheduler.py b/torchx/schedulers/kubernetes_mcad_scheduler.py index 53f1b5deb..e0ee17eb1 100644 --- a/torchx/schedulers/kubernetes_mcad_scheduler.py +++ b/torchx/schedulers/kubernetes_mcad_scheduler.py @@ -17,8 +17,8 @@ TorchX Kubernetes_MCAD scheduler depends on AppWrapper + MCAD. -Install MCAD: -See deploying Multi-Cluster-Application-Dispatcher guide +Install MCAD: +See deploying Multi-Cluster-Application-Dispatcher guide https://github.com/project-codeflare/multi-cluster-app-dispatcher/blob/main/doc/deploy/deployment.md This implementation requires MCAD v1.34.1 or higher. @@ -46,12 +46,12 @@ Optional, Tuple, TYPE_CHECKING, + TypedDict, ) import torchx import yaml from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -62,6 +62,7 @@ from torchx.schedulers.ids import make_unique from torchx.specs.api import ( AppDef, + AppDryRunInfo, AppState, BindMount, CfgVal, @@ -78,7 +79,6 @@ ) from torchx.workspace.docker_workspace import DockerWorkspaceMixin -from typing_extensions import TypedDict if TYPE_CHECKING: from docker import DockerClient @@ -600,7 +600,7 @@ def app_to_resource( """ Create Service: - The selector will have the key 'appwrapper.workload.codeflare.dev', and the value will be + The selector will have the key 'appwrapper.workload.codeflare.dev', and the value will be the appwrapper name """ @@ -797,7 +797,8 @@ class KubernetesMCADOpts(TypedDict, total=False): class KubernetesMCADScheduler( - DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo] + DockerWorkspaceMixin, + Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo[KubernetesMCADJob]], ): """ KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes. @@ -994,7 +995,7 @@ def _submit_dryrun( if image_secret is not None and service_account is not None: msg = """Service Account and Image Secret names are both provided. Depending on the Service Account configuration, an ImagePullSecret may be defined in your Service Account. - If this is the case, check service account and image secret configurations to understand the expected behavior for + If this is the case, check service account and image secret configurations to understand the expected behavior for patched image push access.""" warnings.warn(msg) namespace = cfg.get("namespace") diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index 699e0d500..a0582755c 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -44,12 +44,12 @@ Optional, Tuple, TYPE_CHECKING, + TypedDict, ) import torchx import yaml from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -60,6 +60,7 @@ from torchx.schedulers.ids import make_unique from torchx.specs.api import ( AppDef, + AppDryRunInfo, AppState, BindMount, CfgVal, @@ -75,7 +76,6 @@ ) from torchx.util.strings import normalize_str from torchx.workspace.docker_workspace import DockerWorkspaceMixin -from typing_extensions import TypedDict if TYPE_CHECKING: diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index 9250ee72a..c7cf7cc76 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -40,10 +40,10 @@ Protocol, TextIO, Tuple, + TypedDict, ) from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -53,10 +53,10 @@ ) from torchx.schedulers.ids import make_unique from torchx.schedulers.streams import Tee +from torchx.specs import AppDryRunInfo from torchx.specs.api import AppDef, AppState, is_terminal, macros, NONE, Role, runopts from torchx.util.types import none_throws -from typing_extensions import TypedDict log: logging.Logger = logging.getLogger(__name__) diff --git a/torchx/schedulers/lsf_scheduler.py b/torchx/schedulers/lsf_scheduler.py index 0ff8b905c..b2700a316 100644 --- a/torchx/schedulers/lsf_scheduler.py +++ b/torchx/schedulers/lsf_scheduler.py @@ -29,11 +29,10 @@ import tempfile from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, TypedDict import torchx from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -45,6 +44,7 @@ from torchx.schedulers.local_scheduler import LogIterator from torchx.specs import ( AppDef, + AppDryRunInfo, AppState, BindMount, DeviceMount, @@ -57,7 +57,6 @@ VolumeMount, ) from torchx.util import shlex -from typing_extensions import TypedDict JOB_STATE: Dict[str, AppState] = { "DONE": AppState.SUCCEEDED, diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index 53f881749..ca726dc01 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -14,7 +14,17 @@ from dataclasses import dataclass, field from datetime import datetime from shutil import copy2, rmtree -from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple # noqa +from typing import ( # noqa + Any, + cast, + Dict, + Final, + Iterable, + List, + Optional, + Tuple, + TypedDict, +) import urllib3 @@ -23,7 +33,6 @@ from ray.dashboard.modules.job.sdk import JobSubmissionClient from torchx.schedulers.api import ( - AppDryRunInfo, AppState, DescribeAppResponse, filter_regex, @@ -34,9 +43,17 @@ ) from torchx.schedulers.ids import make_unique from torchx.schedulers.ray.ray_common import RayActor, TORCHX_RANK0_HOST -from torchx.specs import AppDef, macros, NONE, ReplicaStatus, Role, RoleStatus, runopts +from torchx.specs import ( + AppDef, + AppDryRunInfo, + macros, + NONE, + ReplicaStatus, + Role, + RoleStatus, + runopts, +) from torchx.workspace.dir_workspace import TmpDirWorkspaceMixin -from typing_extensions import TypedDict class RayOpts(TypedDict, total=False): diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index 40c4f12bc..7d1e1833d 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -21,11 +21,10 @@ from dataclasses import dataclass from datetime import datetime from subprocess import CalledProcessError, PIPE -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict import torchx from torchx.schedulers.api import ( - AppDryRunInfo, DescribeAppResponse, filter_regex, ListAppResponse, @@ -36,6 +35,7 @@ from torchx.schedulers.local_scheduler import LogIterator from torchx.specs import ( AppDef, + AppDryRunInfo, AppState, macros, NONE, @@ -46,7 +46,6 @@ runopts, ) from torchx.workspace.dir_workspace import DirWorkspaceMixin -from typing_extensions import TypedDict SLURM_JOB_DIRS = ".torchxslurmjobdirs" diff --git a/torchx/schedulers/test/aws_sagemaker_scheduler_test.py b/torchx/schedulers/test/aws_sagemaker_scheduler_test.py index 59bee0976..f29a896bc 100644 --- a/torchx/schedulers/test/aws_sagemaker_scheduler_test.py +++ b/torchx/schedulers/test/aws_sagemaker_scheduler_test.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import threading import unittest from collections import OrderedDict @@ -7,8 +13,6 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from torchx.schedulers.api import AppDryRunInfo - from torchx.schedulers.aws_sagemaker_scheduler import ( _local_session, AWSSageMakerJob, @@ -17,7 +21,7 @@ create_scheduler, JOB_STATE, ) -from torchx.specs.api import runopts +from torchx.specs.api import AppDryRunInfo, runopts ENV_TORCHX_ROLE_NAME = "TORCHX_ROLE_NAME" MODULE = "torchx.schedulers.aws_sagemaker_scheduler" diff --git a/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py b/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py index 6101642b0..38328a922 100644 --- a/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_mcad_scheduler_test.py @@ -19,7 +19,7 @@ # @manual=//torchx/schedulers:kubernetes_mcad_scheduler from torchx.schedulers import kubernetes_mcad_scheduler -from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse +from torchx.schedulers.api import DescribeAppResponse, ListAppResponse from torchx.schedulers.docker_scheduler import has_docker from torchx.schedulers.kubernetes_mcad_scheduler import ( app_to_resource, @@ -38,7 +38,7 @@ mcad_svc, role_to_pod, ) -from torchx.specs import AppState, Resource, Role +from torchx.specs import AppDryRunInfo, AppState, Resource, Role SKIP_DOCKER: bool = not has_docker() diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index 9492ca962..88983e95a 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -19,7 +19,7 @@ # @manual=//torchx/schedulers:kubernetes_scheduler from torchx.schedulers import kubernetes_scheduler -from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse +from torchx.schedulers.api import DescribeAppResponse, ListAppResponse from torchx.schedulers.docker_scheduler import has_docker from torchx.schedulers.kubernetes_scheduler import ( app_to_resource, @@ -31,7 +31,7 @@ PLACEHOLDER_FIELD_PATH, role_to_pod, ) -from torchx.specs import AppState +from torchx.specs import AppDryRunInfo, AppState SKIP_DOCKER: bool = not has_docker() diff --git a/torchx/schedulers/test/ray_scheduler_test.py b/torchx/schedulers/test/ray_scheduler_test.py index 6205c6e75..4f847025c 100644 --- a/torchx/schedulers/test/ray_scheduler_test.py +++ b/torchx/schedulers/test/ray_scheduler_test.py @@ -19,7 +19,7 @@ from ray.util.placement_group import remove_placement_group from torchx.schedulers import get_scheduler_factories -from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse +from torchx.schedulers.api import DescribeAppResponse, ListAppResponse from torchx.schedulers.ray import ray_driver from torchx.schedulers.ray.ray_common import RayActor from torchx.schedulers.ray_scheduler import ( @@ -29,7 +29,7 @@ RayScheduler, serialize, ) -from torchx.specs import AppDef, Resource, Role, runopts +from torchx.specs import AppDef, AppDryRunInfo, Resource, Role, runopts class RaySchedulerRegistryTest(TestCase): diff --git a/torchx/util/entrypoints.py b/torchx/util/entrypoints.py index 9da5626c4..f3bfcac70 100644 --- a/torchx/util/entrypoints.py +++ b/torchx/util/entrypoints.py @@ -5,14 +5,13 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +# pyre-ignore-all-errors[3, 2, 16] +from importlib import metadata +from importlib.metadata import EntryPoint from typing import Any, Dict, Optional -import importlib_metadata as metadata -from importlib_metadata import EntryPoint - -# pyre-ignore-all-errors[3, 2] def load(group: str, name: str, default=None): """ Loads the entry point specified by @@ -30,13 +29,34 @@ def load(group: str, name: str, default=None): raises an error. """ - entrypoints = metadata.entry_points().select(group=group) + # [note_on_entrypoints] + # return type of importlib.metadata.entry_points() is different between python-3.9 and python-3.10 + # https://docs.python.org/3.9/library/importlib.metadata.html#importlib.metadata.entry_points + # https://docs.python.org/3.10/library/importlib.metadata.html#importlib.metadata.entry_points + if hasattr(metadata.entry_points(), "select"): + # python>=3.10 + entrypoints = metadata.entry_points().select(group=group) - if name not in entrypoints.names and default is not None: - return default + if name not in entrypoints.names and default is not None: + return default + + ep = entrypoints[name] + return ep.load() - ep = entrypoints[name] - return ep.load() + else: + # python<3.10 (e.g. 3.9) + # metadata.entry_points() returns dict[str, tuple[EntryPoint]] (not EntryPoints) in python-3.9 + entrypoints = metadata.entry_points().get(group, ()) + + for ep in entrypoints: + if ep.name == name: + return ep.load() + + # [group].name not found + if default is not None: + return default + else: + raise KeyError(f"entrypoint {group}.{name} not found") def _defer_load_ep(ep: EntryPoint) -> object: @@ -49,7 +69,6 @@ def run(*args: object, **kwargs: object) -> object: return run -# pyre-ignore-all-errors[3, 2] def load_group( group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False ): @@ -87,7 +106,13 @@ def load_group( """ - entrypoints = metadata.entry_points().select(group=group) + # see [note_on_entrypoints] above + if hasattr(metadata.entry_points(), "select"): + # python>=3.10 + entrypoints = metadata.entry_points().select(group=group) + else: + # python<3.10 (e.g. 3.9) + entrypoints = metadata.entry_points().get(group, ()) if len(entrypoints) == 0: if skip_defaults: diff --git a/torchx/util/test/entrypoints_test.py b/torchx/util/test/entrypoints_test.py index 45c456c67..e6327168c 100644 --- a/torchx/util/test/entrypoints_test.py +++ b/torchx/util/test/entrypoints_test.py @@ -8,16 +8,16 @@ import unittest from configparser import ConfigParser + +from importlib.metadata import EntryPoint from types import ModuleType -from typing import List -from unittest.mock import MagicMock, patch -from importlib_metadata import EntryPoint, EntryPoints +from unittest.mock import MagicMock, patch from torchx.util.entrypoints import load, load_group -def EntryPoint_from_config(config: ConfigParser) -> List[EntryPoint]: +def EntryPoint_from_config(config: ConfigParser) -> list[EntryPoint]: # from stdlib, Copyright (c) Python Authors return [ EntryPoint(name, value, group) @@ -26,7 +26,7 @@ def EntryPoint_from_config(config: ConfigParser) -> List[EntryPoint]: ] -def EntryPoint_from_text(text: str) -> List[EntryPoint]: +def EntryPoint_from_text(text: str) -> list[EntryPoint]: # from stdlib, Copyright (c) Python Authors config = ConfigParser(delimiters="=") config.read_string(text) @@ -66,7 +66,8 @@ def barbaz() -> str: [ep.grp.missing.mod.test] baz = torchx.util.test.entrypoints_test.missing_module """ -_ENTRY_POINTS: EntryPoints = EntryPoints( + +_EPS: list[EntryPoint] = ( EntryPoint_from_text(_EP_TXT) + EntryPoint_from_text(_EP_GRP_TXT) + EntryPoint_from_text(_EP_GRP_IGN_ATTR_TXT) @@ -74,6 +75,17 @@ def barbaz() -> str: + EntryPoint_from_text(_EP_GRP_IGN_MOD_TXT) ) +try: + from importlib.metadata import EntryPoints +except ImportError: + # python<=3.9 + _ENTRY_POINTS: dict[str, list[EntryPoint]] = {} + for ep in _EPS: + _ENTRY_POINTS.setdefault(ep.group, []).append(ep) +else: + # python>=3.10 + _ENTRY_POINTS: EntryPoints = EntryPoints(_EPS) + _METADATA_EPS: str = "torchx.util.entrypoints.metadata.entry_points" diff --git a/torchx/util/test/types_test.py b/torchx/util/test/types_test.py index 29a257c64..399b988b7 100644 --- a/torchx/util/test/types_test.py +++ b/torchx/util/test/types_test.py @@ -8,9 +8,8 @@ import inspect import unittest -from typing import cast, Dict, List, Optional, Union +from typing import cast, Optional, Union -import typing_inspect from torchx.util.types import ( decode, decode_from_string, @@ -26,25 +25,25 @@ def _test_complex_args( arg1: int, - arg2: Optional[List[str]], + arg2: Optional[list[str]], arg3: Union[float, int], ) -> int: return 42 -def _test_dict(arg1: Dict[int, float]) -> int: +def _test_dict(arg1: dict[int, float]) -> int: return 42 -def _test_nested_object(arg1: Dict[str, List[str]]) -> int: +def _test_nested_object(arg1: dict[str, list[str]]) -> int: return 42 -def _test_list(arg1: List[float]) -> int: +def _test_list(arg1: list[float]) -> int: return 42 -def _test_complex_list(arg1: List[List[float]]) -> int: +def _test_complex_list(arg1: list[list[float]]) -> int: return 42 @@ -59,24 +58,21 @@ def test_decode_optional(self) -> None: arg1_parameter = parameters["arg1"] arg1_type = decode_optional(arg1_parameter.annotation) self.assertTrue(arg1_type is int) - - arg2_parameter = parameters["arg2"] arg2_type = decode_optional(parameters["arg2"].annotation) - self.assertTrue(typing_inspect.get_origin(arg2_type) is list) - + self.assertTrue(getattr(arg2_type, "__origin__", None) is list) arg3_parameter = parameters["arg3"] arg3_type = decode_optional(arg3_parameter.annotation) - self.assertTrue(typing_inspect.get_origin(arg3_type) is Union) + self.assertTrue( + hasattr(arg3_type, "__origin__") and arg3_type.__origin__ is Union + ) def test_is_primitive(self) -> None: parameters = inspect.signature(_test_complex_args).parameters arg1_parameter = parameters["arg1"] - arg1_type = decode_optional(arg1_parameter.annotation) self.assertTrue(is_primitive(arg1_parameter.annotation)) arg2_parameter = parameters["arg2"] - arg2_type = decode_optional(parameters["arg2"].annotation) self.assertFalse(is_primitive(arg2_parameter.annotation)) def test_is_bool(self) -> None: @@ -89,7 +85,7 @@ def test_decode_from_string_dict(self) -> None: encoded_value = "1=1.0,2=42.1,3=10" value = decode_from_string(encoded_value, parameters["arg1"].annotation) - value = cast(Dict[int, float], value) + value = cast(dict[int, float], value) self.assertEqual(3, len(value)) self.assertEqual(float(1.0), value[1]) self.assertEqual(float(42.1), value[2]) @@ -101,7 +97,7 @@ def test_decode_from_string_list(self) -> None: encoded_value = "1.0,42.2,3.9" value = decode_from_string(encoded_value, parameters["arg1"].annotation) - value = cast(List[float], value) + value = cast(list[float], value) self.assertEqual(3, len(value)) self.assertEqual(float(1.0), value[0]) self.assertEqual(float(42.2), value[1]) @@ -217,8 +213,8 @@ def fake_component( f: float, s: str, b: bool, - l: List[str], - m: Dict[str, str], + l: list[str], + m: dict[str, str], o: Optional[int], ) -> None: # component has to return an AppDef diff --git a/torchx/util/types.py b/torchx/util/types.py index 599d6ad49..a057b9eb9 100644 --- a/torchx/util/types.py +++ b/torchx/util/types.py @@ -8,12 +8,10 @@ import inspect import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Optional, Tuple, TypeVar, Union -import typing_inspect - -def to_list(arg: str) -> List[str]: +def to_list(arg: str) -> list[str]: conf = [] if len(arg.strip()) == 0: return [] @@ -22,9 +20,9 @@ def to_list(arg: str) -> List[str]: return conf -def to_dict(arg: str) -> Dict[str, str]: +def to_dict(arg: str) -> dict[str, str]: """ - Parses the given ``arg`` string literal into a ``Dict[str, str]`` of + Parses the given ``arg`` string literal into a ``dict[str, str]`` of key-value pairs delimited by ``"="`` (equals). The values may be a list literal where the list elements are delimited by ``","`` (comma) or ``";"`` (semi-colon). The same delimiters (``","`` and ``";"``) are used @@ -85,14 +83,14 @@ def to_val(val: str) -> str: return val[1:-1] return val if val != '""' and val != "''" else "" - arg_map: Dict[str, str] = {} + arg_map: dict[str, str] = {} if not arg: return arg_map # find quoted values quoted_pattern = r'([\'"])((?:\\.|(?!\1).)*?)\1' - quoted_values: List[str] = [] + quoted_values: list[str] = [] def replace_quoted(match): quoted_values.append(match.group(0)) @@ -133,9 +131,13 @@ def replace_quoted(match): # pyre-ignore-all-errors[3, 2] def _decode_string_to_dict( - encoded_value: str, param_type: Type[Dict[Any, Any]] -) -> Dict[Any, Any]: - key_type, value_type = typing_inspect.get_args(param_type) + encoded_value: str, param_type: type[dict[Any, Any]] +) -> dict[Any, Any]: + # pyre-ignore[16] + if not hasattr(param_type, "__args__") or len(param_type.__args__) != 2: + raise ValueError(f"param_type must be a `dict` type, but was `{param_type}`") + + key_type, value_type = param_type.__args__ arg_values = {} for key, value in to_dict(encoded_value).items(): arg_values[key_type(key)] = value_type(value) @@ -143,9 +145,12 @@ def _decode_string_to_dict( def _decode_string_to_list( - encoded_value: str, param_type: Type[List[Any]] -) -> List[Any]: - value_type = typing_inspect.get_args(param_type)[0] + encoded_value: str, param_type: type[list[Any]] +) -> list[Any]: + # pyre-ignore[16] + if not hasattr(param_type, "__args__") or len(param_type.__args__) != 1: + raise ValueError(f"param_type must be a `list` type, but was `{param_type}`") + value_type = param_type.__args__[0] if not is_primitive(value_type): raise ValueError("List types support only primitives: int, str, float") arg_values = [] @@ -166,7 +171,7 @@ def decode(encoded_value: Any, annotation: Any): def decode_from_string( encoded_value: str, annotation: Any -) -> Union[Dict[Any, Any], List[Any], None]: +) -> Union[dict[Any, Any], list[Any], None]: """Decodes string representation to the underlying type(Dict or List) Given a string representation of the value, the method decodes it according @@ -191,13 +196,13 @@ def decode_from_string( if not encoded_value: return None value_type = annotation - value_origin = typing_inspect.get_origin(value_type) - if value_origin is dict: - return _decode_string_to_dict(encoded_value, value_type) - elif value_origin is list: - return _decode_string_to_list(encoded_value, value_type) - else: - raise ValueError("Unknown") + if hasattr(value_type, "__origin__"): + value_origin = value_type.__origin__ + if value_origin is dict: + return _decode_string_to_dict(encoded_value, value_type) + elif value_origin is list: + return _decode_string_to_list(encoded_value, value_type) + raise ValueError("Unknown") def is_bool(param_type: Any) -> bool: @@ -229,12 +234,13 @@ def decode_optional(param_type: Any) -> Any: If ``param_type`` is type Optional[INNER_TYPE], method returns INNER_TYPE Otherwise returns ``param_type`` """ - param_origin = typing_inspect.get_origin(param_type) - if param_origin is not Union: + if not hasattr(param_type, "__origin__"): + return param_type + if param_type.__origin__ is not Union: return param_type - key_type, value_type = typing_inspect.get_args(param_type) - if value_type is type(None): - return key_type + args = param_type.__args__ + if len(args) == 2 and args[1] is type(None): + return args[0] else: return param_type