Commit efc6f17

Update keras_example.py
1 parent bf9388a commit efc6f17

1 file changed: +26 -48 lines changed

cnn_class/keras_example.py

Lines changed: 26 additions & 48 deletions
@@ -19,14 +19,8 @@
 
 from benchmark import get_data, error_rate
 
-
-# helper
-# def y2indicator(Y):
-#     N = len(Y)
-#     K = len(set(Y))
-#     I = np.zeros((N, K))
-#     I[np.arange(N), Y] = 1
-#     return I
+# get the data
+train, test = get_data()
 
 def rearrange(X):
     # input is (32, 32, 3, N)
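The block removed above was a commented-out one-hot helper that the script never called; integer labels work as-is because the model is compiled with sparse_categorical_crossentropy further down. For reference, a runnable version of that helper, reconstructed from the deleted comments:

    import numpy as np

    def y2indicator(Y):
        # one-hot encode integer labels 0..K-1 into an N x K indicator matrix
        N = len(Y)
        K = len(set(Y))
        I = np.zeros((N, K))
        I[np.arange(N), Y] = 1
        return I

One-hot targets like these would only be needed if the loss were switched to categorical_crossentropy.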
@@ -39,10 +33,6 @@ def rearrange(X):
     # return out / 255
     return (X.transpose(3, 0, 1, 2) / 255.).astype(np.float32)
 
-
-# get the data
-train, test = get_data()
-
 # Need to scale! don't leave as 0..255
 # Y is a N x 1 matrix with values 1..10 (MATLAB indexes by 1)
 # So flatten it and make it 0..9
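rearrange converts the MATLAB-style (32, 32, 3, N) arrays from the .mat files into the (N, 32, 32, 3) layout Keras expects, scaled to [0, 1]. A quick sanity check on dummy data, with shapes assumed from the comments above:

    import numpy as np

    def rearrange(X):
        # (32, 32, 3, N) -> (N, 32, 32, 3), scaled to [0, 1]
        return (X.transpose(3, 0, 1, 2) / 255.).astype(np.float32)

    X = np.random.randint(0, 256, size=(32, 32, 3, 5))  # 5 fake images
    print(rearrange(X).shape)         # (5, 32, 32, 3)
    print(float(rearrange(X).max()))  # <= 1.0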
@@ -55,49 +45,39 @@ def rearrange(X):
 Ytest = test['y'].flatten() - 1
 del test
 
-
-
 # get shapes
 K = len(set(Ytrain))
 
-
-
 # make the CNN
-i = Input(shape=Xtrain.shape[1:])
-x = Conv2D(filters=20, kernel_size=(5, 5))(i)
-x = BatchNormalization()(x)
-x = Activation('relu')(x)
-x = MaxPooling2D()(x)
-
-x = Conv2D(filters=50, kernel_size=(5, 5))(x)
-x = BatchNormalization()(x)
-x = Activation('relu')(x)
-x = MaxPooling2D()(x)
-
-x = Flatten()(x)
-x = Dense(units=500)(x)
-x = Activation('relu')(x)
-x = Dropout(0.3)(x)
-x = Dense(units=K)(x)
-x = Activation('softmax')(x)
-
-model = Model(inputs=i, outputs=x)
-
+model = Sequential([
+    Input(shape=Xtrain.shape[1:]),
+    Conv2D(filters=20, kernel_size=(5, 5)),  # First Conv layer
+    BatchNormalization(),
+    Activation('relu'),
+    MaxPooling2D(),
+
+    Conv2D(filters=50, kernel_size=(5, 5)),  # Second Conv layer
+    BatchNormalization(),
+    Activation('relu'),
+    MaxPooling2D(),
+
+    Flatten(),
+    Dense(units=500),  # Fully connected layer
+    Activation('relu'),
+    Dropout(0.3),
+    Dense(units=K),  # Output layer
+    Activation('softmax')
+])
 
 # list of losses: https://keras.io/losses/
 # list of optimizers: https://keras.io/optimizers/
 # list of metrics: https://keras.io/metrics/
 model.compile(
-    loss='sparse_categorical_crossentropy',
-    optimizer='adam',
-    metrics=['accuracy']
+    loss='sparse_categorical_crossentropy',
+    optimizer='adam',
+    metrics=['accuracy']
 )
 
-# note: multiple ways to choose a backend
-# either theano, tensorflow, or cntk
-# https://keras.io/backend/
-
-
 # gives us back a <keras.callbacks.History object at 0x112e61a90>
 r = model.fit(Xtrain, Ytrain, validation_data=(Xtest, Ytest), epochs=10, batch_size=32)
 print("Returned:", r)
@@ -113,9 +93,7 @@ def rearrange(X):
 plt.show()
 
 # accuracies
-plt.plot(r.history['accuracy'], label='acc')
-plt.plot(r.history['val_accuracy'], label='val_acc')
+plt.plot(r.history['accuracy'], label='accuracy')
+plt.plot(r.history['val_accuracy'], label='val_accuracy')
 plt.legend()
 plt.show()
-
-
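The relabeling above makes the legend match the History keys. The keys themselves depend on the Keras version: tf.keras and Keras >= 2.3 record 'accuracy'/'val_accuracy', while older releases used 'acc'/'val_acc'. A version-agnostic lookup, as a small sketch rather than part of the commit:

    acc_key = 'accuracy' if 'accuracy' in r.history else 'acc'
    plt.plot(r.history[acc_key], label=acc_key)
    plt.plot(r.history['val_' + acc_key], label='val_' + acc_key)
    plt.legend()
    plt.show()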
