1- import os
21from types import GenericAlias , UnionType
32from typing import Generic , TypeVar , Any , get_args
43from .cfg import (
1110 VERSION ,
1211)
1312
14- __all__ = ["JsonSchema" , "field" , "get_plugin_config_and_version" ]
13+ __all__ = ["JsonSchema" , "field" , "get_plugin_config_and_version" , "load_by_schema" ]
1514
1615T = TypeVar ("T" )
1716JsonSchemaT = TypeVar ("JsonSchemaT" , bound = "JsonSchema" )
@@ -48,23 +47,25 @@ def __str__(self):
4847
4948
5049class _Field (Generic [T ]):
51- def __init__ (self , field_name : str , default_value : type [T ] | type [_missing ]):
50+ def __init__ (self , field_name : str , default_value : type [T ] | type [_missing ], optional = False ):
5251 self .field_name = field_name
5352 self .default_value = default_value
53+ self .optional = optional
5454 self ._annotation = None
5555
56- def __call__ (self , annotation ):
56+ def __call__ (self , annotation : Any ):
5757 self ._annotation = annotation
5858 return self
5959
6060
61- def field (field_name : str , default : T | type [_missing ] = _missing ) -> T :
61+ def field (field_name : str , default : T | type [_missing ] = _missing , optional = False ) -> T :
6262 """
6363 为 `JsonSchema` 模版类标注字段。
6464
6565 Args:
6666 field_name (str): 模版字段对应的配置文件键名
6767 default: 该字段的默认值 (注意, 如果不填写的话, 生成配置文件时就不会生成关于它的默认配置)
68+ optional: 该字段是否为可选字段, 若所取配置中无此字段则自动补全 default
6869
6970 >>> class MyConfig(JsonSchema):
7071 ... cfg_a: str = field("配置A")
@@ -76,7 +77,7 @@ def field(field_name: str, default: T | type[_missing] = _missing) -> T:
7677 >>> cfg.cfg_b
7778 350
7879 """
79- return _Field (field_name , default ) # type: ignore
80+ return _Field (field_name , default , optional ) # type: ignore
8081
8182
8283class JsonSchema :
@@ -88,21 +89,21 @@ class JsonSchema:
8889 ... cfg_c: str | int = field("配置C", "Hello dream")
8990
9091 基本类型标注仅接受 `str`, `int`, `float`, `bool` 基本类型。
91- 你也可以使用 `str | int`, `list[float]` 这样的复合类型和 `JsonSchema` 嵌套。
92+ 你也可以使用 `str | int`, `list[float]`, `dict[str, int]`, `tuple[int, str]` 这样的复合类型和 `JsonSchema` 嵌套。
9293 """
9394
94- def __init__ (self , ** kwargs ):
95- for k , v in kwargs .items ():
95+ def __init__ (self , ** obj ):
96+ for k , v in obj .items ():
9697 if k not in self ._fields :
9798 raise ValueError (f"设置默认配置时遇到未知字段 { k } " )
9899 try :
99100 setattr (
100- self , k , load_param_and_type_check (v , self ._fields [k ]._annotation )
101+ self , k , load_by_schema (v , self ._fields [k ]._annotation )
101102 )
102103 except ConfigError as e :
103104 raise ValueError ("设置默认配置传参出错: " + e .msg )
104105 for k , v in self ._fields .items ():
105- if k not in kwargs :
106+ if k not in obj :
106107 if v .default_value is _missing :
107108 raise ValueError (f'字段 "{ k } " 缺失默认值' )
108109 setattr (self , k , v .default_value )
@@ -151,18 +152,22 @@ def _annotation_type_check(typ):
151152 if typ in checkable_types :
152153 return
153154 elif isinstance (typ , GenericAlias ):
154- # list[...] or dict[str, ...]
155+ # list[...] or dict[str, ...] or tuple[...]
155156 orig = typ .__origin__
156157 args = get_args (typ )
157158 if typ .__origin__ is list :
158159 if len (args ) != 1 :
159160 raise ValueError ("不支持的泛型类型个数, 最多只能为 1 个" )
161+ _annotation_type_check (args [0 ])
160162 elif typ .__origin__ is dict :
161163 if len (args ) != 2 :
162164 raise ValueError ("不支持的泛型类型个数, 最多只能为 2 个" )
163165 if args [0 ] is not str :
164166 raise ValueError ("dict 泛型首项参数只能为 str" )
165167 _annotation_type_check (args [1 ])
168+ elif typ .__origin__ is tuple :
169+ for arg in args :
170+ _annotation_type_check (arg )
166171 else :
167172 raise ValueError (f"不支持的泛型类型: { orig } " )
168173 _annotation_type_check (args [0 ])
@@ -181,7 +186,7 @@ def _annotation_type_check(typ):
181186 raise TypeError (f"不支持的类型注释 { typ } " )
182187
183188
184- def load_param_and_type_check (obj , typ : type [T ] | None , field_name : str = "" ) -> T :
189+ def load_by_schema (obj , typ : type [T ] | None , field_name : str = "" ) -> T :
185190 if typ in checkable_types :
186191 if isinstance (obj , int ) and typ is float :
187192 return obj # type: ignore
@@ -194,7 +199,7 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
194199 elif isinstance (typ , UnionType ):
195200 for t in get_args (typ ):
196201 try :
197- return load_param_and_type_check (obj , t )
202+ return load_by_schema (obj , t )
198203 except ConfigError :
199204 pass
200205 raise ConfigError (
@@ -214,7 +219,7 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
214219 lst = []
215220 for i , v in enumerate (obj ):
216221 try :
217- lst .append (load_param_and_type_check (v , sub_type ))
222+ lst .append (load_by_schema (v , sub_type ))
218223 except ConfigError as e :
219224 raise ConfigError (current_key_or_index = i , fromerr = e )
220225 return lst # type: ignore
@@ -228,12 +233,31 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
228233 dic = {}
229234 for k , v in obj .items ():
230235 try :
231- dic [k ] = load_param_and_type_check (v , sub_type )
236+ dic [k ] = load_by_schema (v , sub_type )
232237 except ConfigError as e :
233238 raise ConfigError (current_key_or_index = k , fromerr = e )
234239 return dic # type: ignore
240+ elif orig is tuple :
241+ if not isinstance (obj , list ):
242+ raise ConfigError (
243+ f"值 { obj } 类型错误, 需为列表, 得到 { _get_cfg_type_name (type (obj ))} " ,
244+ field_name ,
245+ )
246+ lst = obj .copy ()
247+ sub_types = get_args (typ )
248+ if len (obj ) != len (sub_types ):
249+ raise ConfigError (
250+ f"值 { obj } 类型错误, 需为长度为 { len (sub_types )} 的列表, 实际上为 { len (obj )} " ,
251+ field_name ,
252+ )
253+ for i , (schema , obj_i ) in enumerate (zip (sub_types , obj )):
254+ try :
255+ obj [i ] = load_by_schema (obj_i , schema )
256+ except ConfigError as e :
257+ raise ConfigError (current_key_or_index = i , fromerr = e )
258+ return lst # type: ignore
235259 else :
236- raise ValueError (f"未知泛型类型 { typ } " )
260+ raise RuntimeError (f"未知泛型类型 { typ } " )
237261 elif type (typ ) is type and issubclass (typ , JsonSchema ):
238262 if not isinstance (obj , dict ):
239263 raise ConfigError (
@@ -250,14 +274,14 @@ def load_param_and_type_check(obj, typ: type[T] | None, field_name: str = "") ->
250274 setattr (
251275 instance ,
252276 k ,
253- load_param_and_type_check (
277+ load_by_schema (
254278 obj [v .field_name ], annotation , field_name
255279 ),
256280 )
257281 except ConfigError as e :
258282 raise ConfigError (current_key_or_index = field_name , fromerr = e )
259283 else :
260- if v .default_value is _missing :
284+ if v .default_value is _missing or not v . optional :
261285 raise ConfigError (f"{ v .field_name } 缺少必填字段" )
262286 setattr (instance , k , v .default_value )
263287 return instance
@@ -310,7 +334,7 @@ def get_plugin_config_and_version(
310334 """
311335 p = TOOLDELTA_PLUGIN_CFG_DIR / plugin_name
312336 if not _jsonfile_exists (p ):
313- s = load_param_and_type_check ({}, schema )
337+ s = load_by_schema ({}, schema )
314338 defaultCfg = PLUGINCFG_DEFAULT .copy ()
315339 defaultCfg ["配置项" ] = dump_param (s )
316340 defaultCfg ["配置版本" ] = "." .join ([str (n ) for n in default_vers ])
@@ -320,4 +344,4 @@ def get_plugin_config_and_version(
320344 VERSION_LENGTH = 3 # 版本长度
321345 if len (cfg_vers ) != VERSION_LENGTH :
322346 raise ValueError ("配置文件出错:版本出错" )
323- return load_param_and_type_check (cfg_get ["配置项" ], schema ), cfg_vers
347+ return load_by_schema (cfg_get ["配置项" ], schema ), cfg_vers
0 commit comments