Skip to content

Commit ab14c52

Browse files
authored
Apply adjustments to run MNIST verification (#172)
1 parent 260a47c commit ab14c52

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ To start the testing process locally, you need to go to the directory
120120
You need to download [Alexnet-model.h5](https://github.com/moizahmed97/Convolutional-Neural-Net-Designer/blob/master/AlexNet-model.h5) to the folder *docs*
121121

122122
# **How do I launch the inference?**
123-
* You need to run the script *parcer.py* that is located in app/AlexNet to read weights from a model *Alexnet-model.h5* and the json file with the weights will be stored in the *docs* folder.
123+
* Make sure you install the project dependencies by running: *pip install -r requirements.txt*
124+
* You need to run the script *parser.py* that is located in app/AlexNet to read weights from a model *Alexnet-model.h5* and the json file with the weights will be stored in the *docs* folder.
124125
* Then put the test images in png format in the folder *docs/input*
125126

126127
# **Accuracy validation**

app/AlexNet/parser.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
1-
import tensorflow as tf
2-
from tensorflow.keras.models import load_model
31
import json
4-
import numpy as np
52
import os
63

4+
import tensorflow as tf
5+
from tensorflow.keras.initializers import GlorotUniform as OriginalGlorotUniform, Zeros as OriginalZeros
6+
from tensorflow.keras.models import load_model
7+
8+
class CustomGlorotUniform(OriginalGlorotUniform):
9+
def __init__(self, seed=None, **kwargs):
10+
kwargs.pop('dtype', None) # Remove the unexpected dtype keyword if present
11+
super().__init__(seed=seed, **kwargs)
12+
13+
class CustomZeros(OriginalZeros):
14+
def __init__(self, **kwargs):
15+
kwargs.pop('dtype', None) # Remove the unexpected dtype keyword if present
16+
super().__init__(**kwargs)
17+
718
# Пути к модели и JSON файлу
819
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
920
MODEL_PATH = os.path.join(BASE_DIR, 'docs', 'AlexNet-model.h5')
1021
MODEL_DATA_PATH = os.path.join(BASE_DIR, 'docs', 'model_data_alexnet_1.json')
1122

1223
# Загрузка модели
13-
model = load_model(MODEL_PATH)
24+
model = load_model(MODEL_PATH, custom_objects={'GlorotUniform': CustomGlorotUniform, 'Zeros': CustomZeros})
1425

1526
# Получение весов модели и информации о порядке слоев
1627
layer_info = []

app/AlexNet/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==2.19.0

0 commit comments

Comments
 (0)