Skip to content

Commit d209ebe

Browse files
committed
fix parser onnx
1 parent 33869af commit d209ebe

File tree

1 file changed

+60
-69
lines changed

1 file changed

+60
-69
lines changed

app/Converters/parser_onnx.py

Lines changed: 60 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -11,89 +11,79 @@ def onnx_to_json(model_path, output_json_path):
1111
# Проверка валидности модели
1212
onnx.checker.check_model(model)
1313

14-
# Создаем словарь для хранения всей информации
15-
model_info = {
16-
"model_metadata": {
17-
"ir_version": model.ir_version,
18-
"opset_version": model.opset_import[0].version,
19-
"producer_name": model.producer_name,
20-
"producer_version": model.producer_version
21-
},
22-
"graph": {
23-
"name": model.graph.name,
24-
"inputs": [],
25-
"outputs": [],
26-
"nodes": [],
27-
"initializers": []
14+
# Создаем словарь для быстрого доступа к инициализаторам по их именам
15+
initializers_dict = {
16+
init.name: {
17+
"data_type": init.data_type,
18+
"dims": list(init.dims),
19+
"values": numpy_helper.to_array(init).tolist()
2820
}
21+
for init in model.graph.initializer
2922
}
3023

31-
# Обработка входных тензоров
32-
for input in model.graph.input:
33-
tensor_type = input.type.tensor_type
34-
model_info["graph"]["inputs"].append({
35-
"name": input.name,
36-
"elem_type": tensor_type.elem_type,
37-
"shape": [dim.dim_value if dim.HasField("dim_value") else dim.dim_param
38-
for dim in tensor_type.shape.dim]
39-
})
40-
41-
# Обработка выходных тензоров
42-
for output in model.graph.output:
43-
tensor_type = output.type.tensor_type
44-
model_info["graph"]["outputs"].append({
45-
"name": output.name,
46-
"elem_type": tensor_type.elem_type,
47-
"shape": [dim.dim_value if dim.HasField("dim_value") else dim.dim_param
48-
for dim in tensor_type.shape.dim]
49-
})
50-
51-
# Обработка узлов (операций)
24+
# Создаем список слоев в формате Keras
25+
layer_info = []
26+
27+
# Обрабатываем входные данные как первый слой
28+
input_layer = {
29+
"index": 0,
30+
"name": "input_1",
31+
"type": "InputLayer",
32+
"weights": [],
33+
"attributes": {}
34+
}
35+
layer_info.append(input_layer)
36+
37+
# Обработка узлов (операций) как слоев
5238
for node in model.graph.node:
53-
node_info = {
54-
"name": node.name,
55-
"op_type": node.op_type,
56-
"inputs": list(node.input), # Convert to list
57-
"outputs": list(node.output), # Convert to list
58-
"attributes": []
39+
# Создаем запись слоя
40+
layer_data = {
41+
"index": len(layer_info),
42+
"name": node.name.replace('/', '_'),
43+
"type": node.op_type,
44+
"weights": [],
45+
"attributes": {} # Сохраняем все атрибуты здесь
5946
}
6047

48+
# Обрабатываем все атрибуты узла
6149
for attr in node.attribute:
6250
attr_value = helper.get_attribute_value(attr)
63-
# Handle different attribute types
51+
52+
# Преобразуем разные типы атрибутов
6453
if isinstance(attr_value, bytes):
6554
attr_value = attr_value.decode('utf-8', errors='ignore')
6655
elif hasattr(attr_value, 'tolist'):
6756
attr_value = attr_value.tolist()
6857
elif str(type(attr_value)).endswith("RepeatedScalarContainer'>"):
6958
attr_value = list(attr_value)
7059

71-
node_info["attributes"].append({
72-
"name": attr.name,
73-
"value": attr_value
74-
})
75-
76-
model_info["graph"]["nodes"].append(node_info)
77-
78-
# Обработка инициализаторов (весов)
79-
for initializer in model.graph.initializer:
80-
# Получаем значения весов в виде списка
81-
weights = numpy_helper.to_array(initializer).tolist()
82-
83-
model_info["graph"]["initializers"].append({
84-
"name": initializer.name,
85-
"data_type": initializer.data_type,
86-
"dims": list(initializer.dims),
87-
"values": weights # Внимание: для больших моделей это может занять много памяти!
88-
})
89-
90-
# Обработка метаданных
91-
if model.metadata_props:
92-
model_info["metadata"] = {}
93-
for prop in model.metadata_props:
94-
model_info["metadata"][prop.key] = prop.value
95-
96-
# Custom JSON encoder to handle remaining non-serializable objects
60+
# Сохраняем атрибут
61+
layer_data["attributes"][attr.name] = attr_value
62+
63+
# Специальная обработка для удобства (можно использовать или игнорировать)
64+
if attr.name == "pads":
65+
layer_data["padding"] = "same" if any(p > 0 for p in attr_value) else "valid"
66+
elif attr.name == "kernel_shape":
67+
layer_data["kernel_size"] = attr_value
68+
elif attr.name == "strides":
69+
layer_data["strides"] = attr_value
70+
71+
# Добавляем веса в формате Keras (один список с ядрами и bias)
72+
layer_weights = []
73+
for input_name in node.input:
74+
if input_name in initializers_dict:
75+
init = initializers_dict[input_name]
76+
if len(init["dims"]) > 1: # Ядра свертки/матрицы весов
77+
layer_weights.extend(init["values"])
78+
else: # Bias
79+
layer_weights.append(init["values"])
80+
81+
if layer_weights:
82+
layer_data["weights"] = layer_weights
83+
84+
layer_info.append(layer_data)
85+
86+
# Custom JSON encoder
9787
class CustomEncoder(json.JSONEncoder):
9888
def default(self, obj):
9989
if hasattr(obj, 'tolist'):
@@ -104,10 +94,11 @@ def default(self, obj):
10494

10595
# Сохранение в JSON файл
10696
with open(output_json_path, 'w') as f:
107-
json.dump(model_info, f, indent=2, cls=CustomEncoder)
97+
json.dump(layer_info, f, indent=2, cls=CustomEncoder)
10898

10999
print(f"Модель успешно сохранена в {output_json_path}")
110100

101+
111102
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
112103
MODEL_PATH = os.path.join(BASE_DIR, 'docs\\models', 'GoogLeNet.onnx')
113104
MODEL_DATA_PATH = os.path.join(BASE_DIR, 'docs\\jsons', 'googlenet_onnx_model.json')

0 commit comments

Comments
 (0)