Skip to content

Commit 7c77488

Browse files
committed
tunable: Allow empty default lists when type-hinted
1 parent eb6098d commit 7c77488

File tree

2 files changed

+80
-5
lines changed

2 files changed

+80
-5
lines changed

magicbot/magic_tunable.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def execute(self):
8080
"_ntsubtable",
8181
"_ntwritedefault",
8282
# "__doc__",
83+
"__orig_class__",
8384
"_topic_type",
8485
"_nt",
8586
)
@@ -100,13 +101,47 @@ def __init__(
100101
self._ntwritedefault = writeDefault
101102
# self.__doc__ = doc
102103

103-
self._topic_type = _get_topic_type_for_value(self._ntdefault)
104-
if self._topic_type is None:
105-
checked_type: type = type(self._ntdefault)
104+
# Defer checks for empty sequences to check type hints.
105+
# Report errors here when we can so the error points to the tunable line.
106+
if default or not isinstance(default, collections.abc.Sequence):
107+
self._topic_type = _get_topic_type_for_value(default)
108+
if self._topic_type is None:
109+
checked_type: type = type(default)
110+
raise TypeError(
111+
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
112+
)
113+
114+
def __set_name__(self, owner: type, name: str) -> None:
115+
type_hint: Optional[type] = None
116+
# __orig_class__ is set after __init__, check it here.
117+
orig_class = getattr(self, "__orig_class__", None)
118+
if orig_class is not None:
119+
# Accept field = tunable[Sequence[int]]([])
120+
type_hint = typing.get_args(orig_class)[0]
121+
else:
122+
type_hint = typing.get_type_hints(owner).get(name)
123+
origin = typing.get_origin(type_hint)
124+
if origin is typing.ClassVar:
125+
# Accept field: ClassVar[tunable[Sequence[int]]] = tunable([])
126+
type_hint = typing.get_args(type_hint)[0]
127+
origin = typing.get_origin(type_hint)
128+
if origin is tunable:
129+
# Accept field: tunable[Sequence[int]] = tunable([])
130+
type_hint = typing.get_args(type_hint)[0]
131+
132+
if type_hint is not None:
133+
topic_type = _get_topic_type(type_hint)
134+
else:
135+
topic_type = _get_topic_type_for_value(self._ntdefault)
136+
137+
if topic_type is None:
138+
checked_type: type = type_hint or type(self._ntdefault)
106139
raise TypeError(
107140
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
108141
)
109142

143+
self._topic_type = topic_type
144+
110145
@overload
111146
def __get__(self, instance: None, owner=None) -> "tunable[V]": ...
112147

@@ -218,7 +253,7 @@ class MyComponent:
218253
navx: ...
219254
220255
@feedback
221-
def get_angle(self):
256+
def get_angle(self) -> float:
222257
return self.navx.getYaw()
223258
224259
class MyRobot(magicbot.MagicRobot):
@@ -297,6 +332,8 @@ def _get_topic_type(
297332
if hasattr(inner_type, "WPIStruct"):
298333
return lambda topic: ntcore.StructArrayTopic(topic, inner_type)
299334

335+
return None
336+
300337

301338
def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"):
302339
"""

tests/test_magicbot_tunable.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import ClassVar, List, Sequence
2+
13
import ntcore
24
import pytest
35
from wpimath import geometry
@@ -53,7 +55,43 @@ class Component:
5355

5456

5557
def test_tunable_errors_with_empty_sequence():
56-
with pytest.raises(ValueError):
58+
with pytest.raises(RuntimeError):
5759

5860
class Component:
5961
empty = tunable([])
62+
63+
64+
def test_type_hinted_empty_sequences() -> None:
65+
class Component:
66+
generic_seq = tunable[Sequence[int]](())
67+
class_var_seq: ClassVar[tunable[Sequence[int]]] = tunable(())
68+
inst_seq: Sequence[int] = tunable(())
69+
70+
generic_typing_list = tunable[List[int]]([])
71+
class_var_typing_list: ClassVar[tunable[List[int]]] = tunable([])
72+
inst_typing_list: List[int] = tunable([])
73+
74+
generic_list = tunable[list[int]]([])
75+
class_var_list: ClassVar[tunable[list[int]]] = tunable([])
76+
inst_list: list[int] = tunable([])
77+
78+
component = Component()
79+
setup_tunables(component, "test_type_hinted_sequences")
80+
NetworkTables = ntcore.NetworkTableInstance.getDefault()
81+
nt = NetworkTables.getTable("/components/test_type_hinted_sequences")
82+
83+
for name in [
84+
"generic_seq",
85+
"class_var_seq",
86+
"inst_seq",
87+
"generic_typing_list",
88+
"class_var_typing_list",
89+
"inst_typing_list",
90+
"generic_list",
91+
"class_var_list",
92+
"inst_list",
93+
]:
94+
assert nt.getTopic(name).getTypeString() == "int[]"
95+
entry = nt.getEntry(name)
96+
assert entry.getIntegerArray(None) == []
97+
assert getattr(component, name) == []

0 commit comments

Comments
 (0)