diff --git a/magicbot/magic_tunable.py b/magicbot/magic_tunable.py index d8e3570..091e2d8 100644 --- a/magicbot/magic_tunable.py +++ b/magicbot/magic_tunable.py @@ -80,6 +80,7 @@ def execute(self): "_ntsubtable", "_ntwritedefault", # "__doc__", + "__orig_class__", "_topic_type", "_nt", ) @@ -100,13 +101,48 @@ def __init__( self._ntwritedefault = writeDefault # self.__doc__ = doc - self._topic_type = _get_topic_type_for_value(self._ntdefault) - if self._topic_type is None: - checked_type: type = type(self._ntdefault) + # Defer checks for empty sequences to check type hints. + # Report errors here when we can so the error points to the tunable line. + if default or not isinstance(default, collections.abc.Sequence): + topic_type = _get_topic_type_for_value(default) + if topic_type is None: + checked_type: type = type(default) + raise TypeError( + f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}" + ) + self._topic_type = topic_type + + def __set_name__(self, owner: type, name: str) -> None: + type_hint: Optional[type] = None + # __orig_class__ is set after __init__, check it here. + orig_class = getattr(self, "__orig_class__", None) + if orig_class is not None: + # Accept field = tunable[Sequence[int]]([]) + type_hint = typing.get_args(orig_class)[0] + else: + type_hint = typing.get_type_hints(owner).get(name) + origin = typing.get_origin(type_hint) + if origin is typing.ClassVar: + # Accept field: ClassVar[tunable[Sequence[int]]] = tunable([]) + type_hint = typing.get_args(type_hint)[0] + origin = typing.get_origin(type_hint) + if origin is tunable: + # Accept field: tunable[Sequence[int]] = tunable([]) + type_hint = typing.get_args(type_hint)[0] + + if type_hint is not None: + topic_type = _get_topic_type(type_hint) + else: + topic_type = _get_topic_type_for_value(self._ntdefault) + + if topic_type is None: + checked_type: type = type_hint or type(self._ntdefault) raise TypeError( f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}" ) + self._topic_type = topic_type + @overload def __get__(self, instance: None, owner=None) -> "tunable[V]": ... @@ -218,7 +254,7 @@ class MyComponent: navx: ... @feedback - def get_angle(self): + def get_angle(self) -> float: return self.navx.getYaw() class MyRobot(magicbot.MagicRobot): @@ -297,6 +333,8 @@ def _get_topic_type( if hasattr(inner_type, "WPIStruct"): return lambda topic: ntcore.StructArrayTopic(topic, inner_type) + return None + def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"): """ diff --git a/tests/test_magicbot_tunable.py b/tests/test_magicbot_tunable.py index fad999b..a90a13b 100644 --- a/tests/test_magicbot_tunable.py +++ b/tests/test_magicbot_tunable.py @@ -1,3 +1,5 @@ +from typing import ClassVar, List, Sequence + import ntcore import pytest from wpimath import geometry @@ -25,6 +27,7 @@ class Component: topic = nt.getTopic(name) assert topic.getTypeString() == type_str assert topic.genericSubscribe().get().value() == value + assert getattr(component, name) == value for name, value in [ ("rotation", geometry.Rotation2d()), @@ -33,6 +36,7 @@ class Component: assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}" topic = nt.getStructTopic(name, struct_type) assert topic.subscribe(None).get() == value + assert getattr(component, name) == value for name, struct_type, value in [ ("rotations", geometry.Rotation2d, [geometry.Rotation2d()]), @@ -40,6 +44,7 @@ class Component: assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]" topic = nt.getStructArrayTopic(name, struct_type) assert topic.subscribe([]).get() == value + assert getattr(component, name) == value def test_tunable_errors(): @@ -50,7 +55,44 @@ class Component: def test_tunable_errors_with_empty_sequence(): - with pytest.raises(ValueError): + with pytest.raises((RuntimeError, ValueError)): class Component: empty = tunable([]) + + +def test_type_hinted_empty_sequences() -> None: + class Component: + generic_seq = tunable[Sequence[int]](()) + class_var_seq: ClassVar[tunable[Sequence[int]]] = tunable(()) + inst_seq: Sequence[int] = tunable(()) + + generic_typing_list = tunable[List[int]]([]) + class_var_typing_list: ClassVar[tunable[List[int]]] = tunable([]) + inst_typing_list: List[int] = tunable([]) + + # TODO(davo): re-enable after py3.8 is dropped + # generic_list = tunable[list[int]]([]) + # class_var_list: ClassVar[tunable[list[int]]] = tunable([]) + # inst_list: list[int] = tunable([]) + + component = Component() + setup_tunables(component, "test_type_hinted_sequences") + NetworkTables = ntcore.NetworkTableInstance.getDefault() + nt = NetworkTables.getTable("/components/test_type_hinted_sequences") + + for name in [ + "generic_seq", + "class_var_seq", + "inst_seq", + "generic_typing_list", + "class_var_typing_list", + "inst_typing_list", + # "generic_list", + # "class_var_list", + # "inst_list", + ]: + assert nt.getTopic(name).getTypeString() == "int[]" + entry = nt.getEntry(name) + assert entry.getIntegerArray(None) == [] + assert getattr(component, name) == []