Skip to content

Commit 039145c

Browse files
feat: cfg_meta 支持 tuple 类型注释
1 parent ddafbab commit 039145c

File tree

1 file changed

+45
-21
lines changed

1 file changed

+45
-21
lines changed

tooldelta/utils/cfg_meta.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from types import GenericAlias, UnionType
32
from typing import Generic, TypeVar, Any, get_args
43
from .cfg import (
@@ -11,7 +10,7 @@
1110
VERSION,
1211
)
1312

14-
__all__ = ["JsonSchema", "field", "get_plugin_config_and_version"]
13+
__all__ = ["JsonSchema", "field", "get_plugin_config_and_version", "load_by_schema"]
1514

1615
T = TypeVar("T")
1716
JsonSchemaT = TypeVar("JsonSchemaT", bound="JsonSchema")
@@ -48,23 +47,25 @@ def __str__(self):
4847

4948

5049
class _Field(Generic[T]):
51-
def __init__(self, field_name: str, default_value: type[T] | type[_missing]):
50+
def __init__(self, field_name: str, default_value: type[T] | type[_missing], optional=False):
5251
self.field_name = field_name
5352
self.default_value = default_value
53+
self.optional = optional
5454
self._annotation = None
5555

56-
def __call__(self, annotation):
56+
def __call__(self, annotation: Any):
5757
self._annotation = annotation
5858
return self
5959

6060

61-
def field(field_name: str, default: T | type[_missing] = _missing) -> T:
61+
def field(field_name: str, default: T | type[_missing] = _missing, optional=False) -> T:
6262
"""
6363
为 `JsonSchema` 模版类标注字段。
6464
6565
Args:
6666
field_name (str): 模版字段对应的配置文件键名
6767
default: 该字段的默认值 (注意, 如果不填写的话, 生成配置文件时就不会生成关于它的默认配置)
68+
optional: 该字段是否为可选字段, 若所取配置中无此字段则自动补全 default
6869
6970
>>> class MyConfig(JsonSchema):
7071
... cfg_a: str = field("配置A")
@@ -76,7 +77,7 @@ def field(field_name: str, default: T | type[_missing] = _missing) -> T:
7677
>>> cfg.cfg_b
7778
350
7879
"""
79-
return _Field(field_name, default) # type: ignore
80+
return _Field(field_name, default, optional) # type: ignore
8081

8182

8283
class JsonSchema:
@@ -88,21 +89,21 @@ class JsonSchema:
8889
... cfg_c: str | int = field("配置C", "Hello dream")
8990
9091
基本类型标注仅接受 `str`, `int`, `float`, `bool` 基本类型。
91-
你也可以使用 `str | int`, `list[float]` 这样的复合类型和 `JsonSchema` 嵌套。
92+
你也可以使用 `str | int`, `list[float]`, `dict[str, int]`, `tuple[int, str]` 这样的复合类型和 `JsonSchema` 嵌套。
9293
"""
9394

94-
def __init__(self, **kwargs):
95-
for k, v in kwargs.items():
95+
def __init__(self, **obj):
96+
for k, v in obj.items():
9697
if k not in self._fields:
9798
raise ValueError(f"设置默认配置时遇到未知字段 {k}")
9899
try:
99100
setattr(
100-
self, k, load_param_and_type_check(v, self._fields[k]._annotation)
101+
self, k, load_by_schema(v, self._fields[k]._annotation)
101102
)
102103
except ConfigError as e:
103104
raise ValueError("设置默认配置传参出错: " + e.msg)
104105
for k, v in self._fields.items():
105-
if k not in kwargs:
106+
if k not in obj:
106107
if v.default_value is _missing:
107108
raise ValueError(f'字段 "{k}" 缺失默认值')
108109
setattr(self, k, v.default_value)
@@ -151,18 +152,22 @@ def _annotation_type_check(typ):
151152
if typ in checkable_types:
152153
return
153154
elif isinstance(typ, GenericAlias):
154-
# list[...] or dict[str, ...]
155+
# list[...] or dict[str, ...] or tuple[...]
155156
orig = typ.__origin__
156157
args = get_args(typ)
157158
if typ.__origin__ is list:
158159
if len(args) != 1:
159160
raise ValueError("不支持的泛型类型个数, 最多只能为 1 个")
161+
_annotation_type_check(args[0])
160162
elif typ.__origin__ is dict:
161163
if len(args) != 2:
162164
raise ValueError("不支持的泛型类型个数, 最多只能为 2 个")
163165
if args[0] is not str:
164166
raise ValueError("dict 泛型首项参数只能为 str")
165167
_annotation_type_check(args[1])
168+
elif typ.__origin__ is tuple:
169+
for arg in args:
170+
_annotation_type_check(arg)
166171
else:
167172
raise ValueError(f"不支持的泛型类型: {orig}")
168173
_annotation_type_check(args[0])
@@ -181,7 +186,7 @@ def _annotation_type_check(typ):
181186
raise TypeError(f"不支持的类型注释 {typ}")
182187

183188

184-
def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") -> T:
189+
def load_by_schema(obj, typ: type[T] | None, field_name: str = "") -> T:
185190
if typ in checkable_types:
186191
if isinstance(obj, int) and typ is float:
187192
return obj # type: ignore
@@ -194,7 +199,7 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
194199
elif isinstance(typ, UnionType):
195200
for t in get_args(typ):
196201
try:
197-
return load_param_and_type_check(obj, t)
202+
return load_by_schema(obj, t)
198203
except ConfigError:
199204
pass
200205
raise ConfigError(
@@ -214,7 +219,7 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
214219
lst = []
215220
for i, v in enumerate(obj):
216221
try:
217-
lst.append(load_param_and_type_check(v, sub_type))
222+
lst.append(load_by_schema(v, sub_type))
218223
except ConfigError as e:
219224
raise ConfigError(current_key_or_index=i, fromerr=e)
220225
return lst # type: ignore
@@ -228,12 +233,31 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
228233
dic = {}
229234
for k, v in obj.items():
230235
try:
231-
dic[k] = load_param_and_type_check(v, sub_type)
236+
dic[k] = load_by_schema(v, sub_type)
232237
except ConfigError as e:
233238
raise ConfigError(current_key_or_index=k, fromerr=e)
234239
return dic # type: ignore
240+
elif orig is tuple:
241+
if not isinstance(obj, list):
242+
raise ConfigError(
243+
f"值 {obj} 类型错误, 需为列表, 得到 {_get_cfg_type_name(type(obj))}",
244+
field_name,
245+
)
246+
lst = obj.copy()
247+
sub_types = get_args(typ)
248+
if len(obj) != len(sub_types):
249+
raise ConfigError(
250+
f"值 {obj} 类型错误, 需为长度为 {len(sub_types)} 的列表, 实际上为 {len(obj)}",
251+
field_name,
252+
)
253+
for i, (schema, obj_i) in enumerate(zip(sub_types, obj)):
254+
try:
255+
obj[i] = load_by_schema(obj_i, schema)
256+
except ConfigError as e:
257+
raise ConfigError(current_key_or_index=i, fromerr=e)
258+
return lst # type: ignore
235259
else:
236-
raise ValueError(f"未知泛型类型 {typ}")
260+
raise RuntimeError(f"未知泛型类型 {typ}")
237261
elif type(typ) is type and issubclass(typ, JsonSchema):
238262
if not isinstance(obj, dict):
239263
raise ConfigError(
@@ -250,14 +274,14 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
250274
setattr(
251275
instance,
252276
k,
253-
load_param_and_type_check(
277+
load_by_schema(
254278
obj[v.field_name], annotation, field_name
255279
),
256280
)
257281
except ConfigError as e:
258282
raise ConfigError(current_key_or_index=field_name, fromerr=e)
259283
else:
260-
if v.default_value is _missing:
284+
if v.default_value is _missing or not v.optional:
261285
raise ConfigError(f"{v.field_name} 缺少必填字段")
262286
setattr(instance, k, v.default_value)
263287
return instance
@@ -310,7 +334,7 @@ def get_plugin_config_and_version(
310334
"""
311335
p = TOOLDELTA_PLUGIN_CFG_DIR / plugin_name
312336
if not _jsonfile_exists(p):
313-
s = load_param_and_type_check({}, schema)
337+
s = load_by_schema({}, schema)
314338
defaultCfg = PLUGINCFG_DEFAULT.copy()
315339
defaultCfg["配置项"] = dump_param(s)
316340
defaultCfg["配置版本"] = ".".join([str(n) for n in default_vers])
@@ -320,4 +344,4 @@ def get_plugin_config_and_version(
320344
VERSION_LENGTH = 3 # 版本长度
321345
if len(cfg_vers) != VERSION_LENGTH:
322346
raise ValueError("配置文件出错:版本出错")
323-
return load_param_and_type_check(cfg_get["配置项"], schema), cfg_vers
347+
return load_by_schema(cfg_get["配置项"], schema), cfg_vers

0 commit comments

Comments
 (0)