Skip to content

Commit c06bde6

Browse files
committed
update
1 parent c9449a6 commit c06bde6

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

ann_class2/keras_functional.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# https://deeplearningcourses.com/c/data-science-deep-learning-in-theano-tensorflow
2+
# https://www.udemy.com/data-science-deep-learning-in-theano-tensorflow
3+
from __future__ import print_function, division
4+
from builtins import range
5+
# Note: you may need to update your version of future
6+
# sudo pip install -U future
7+
8+
from keras.models import Model
9+
from keras.layers import Dense, Input
10+
from util import get_normalized_data, y2indicator
11+
12+
import matplotlib.pyplot as plt
13+
14+
# NOTE: do NOT name your file keras.py because it will conflict
15+
# with importing keras
16+
17+
# installation is easy! just the usual "sudo pip(3) install keras"
18+
19+
20+
# get the data, same as Theano + Tensorflow examples
21+
# no need to split now, the fit() function will do it
22+
Xtrain, Xtest, Ytrain, Ytest = get_normalized_data()
23+
24+
# get shapes
25+
N, D = Xtrain.shape
26+
K = len(set(Ytrain))
27+
28+
# by default Keras wants one-hot encoded labels
29+
# there's another cost function we can use
30+
# where we can just pass in the integer labels directly
31+
# just like Tensorflow / Theano
32+
Ytrain = y2indicator(Ytrain)
33+
Ytest = y2indicator(Ytest)
34+
35+
36+
# ANN with layers [784] -> [500] -> [300] -> [10]
37+
i = Input(shape=(D,))
38+
x = Dense(500, activation='relu')(i)
39+
x = Dense(300, activation='relu')(x)
40+
x = Dense(K, activation='softmax')(x)
41+
model = Model(inputs=i, outputs=x)
42+
43+
44+
# list of losses: https://keras.io/losses/
45+
# list of optimizers: https://keras.io/optimizers/
46+
# list of metrics: https://keras.io/metrics/
47+
model.compile(
48+
loss='categorical_crossentropy',
49+
optimizer='adam',
50+
metrics=['accuracy']
51+
)
52+
53+
# note: multiple ways to choose a backend
54+
# either theano, tensorflow, or cntk
55+
# https://keras.io/backend/
56+
57+
58+
# gives us back a <keras.callbacks.History object at 0x112e61a90>
59+
r = model.fit(Xtrain, Ytrain, validation_data=(Xtest, Ytest), epochs=15, batch_size=32)
60+
print("Returned:", r)
61+
62+
# print the available keys
63+
# should see: dict_keys(['val_loss', 'acc', 'loss', 'val_acc'])
64+
print(r.history.keys())
65+
66+
# plot some data
67+
plt.plot(r.history['loss'], label='loss')
68+
plt.plot(r.history['val_loss'], label='val_loss')
69+
plt.legend()
70+
plt.show()
71+
72+
# accuracies
73+
plt.plot(r.history['acc'], label='acc')
74+
plt.plot(r.history['val_acc'], label='val_acc')
75+
plt.legend()
76+
plt.show()
77+
78+

0 commit comments

Comments
 (0)