diff --git a/cnn_class/keras_example.py b/cnn_class/keras_example.py
index d0463588..f6b77f46 100644
--- a/cnn_class/keras_example.py
+++ b/cnn_class/keras_example.py
@@ -19,14 +19,8 @@
 
 from benchmark import get_data, error_rate
 
-
-# helper
-# def y2indicator(Y):
-#   N = len(Y)
-#   K = len(set(Y))
-#   I = np.zeros((N, K))
-#   I[np.arange(N), Y] = 1
-#   return I
+# get the data
+train, test = get_data()
 
 def rearrange(X):
     # input is (32, 32, 3, N)
@@ -39,10 +33,6 @@ def rearrange(X):
     # return out / 255
     return (X.transpose(3, 0, 1, 2) / 255.).astype(np.float32)
 
-
-# get the data
-train, test = get_data()
-
 # Need to scale! don't leave as 0..255
 # Y is a N x 1 matrix with values 1..10 (MATLAB indexes by 1)
 # So flatten it and make it 0..9
@@ -55,49 +45,39 @@ def rearrange(X):
 Ytest  = test['y'].flatten() - 1
 del test
 
-
-
 # get shapes
 K = len(set(Ytrain))
 
-
-
 # make the CNN
-i = Input(shape=Xtrain.shape[1:])
-x = Conv2D(filters=20, kernel_size=(5, 5))(i)
-x = BatchNormalization()(x)
-x = Activation('relu')(x)
-x = MaxPooling2D()(x)
-
-x = Conv2D(filters=50, kernel_size=(5, 5))(x)
-x = BatchNormalization()(x)
-x = Activation('relu')(x)
-x = MaxPooling2D()(x)
-
-x = Flatten()(x)
-x = Dense(units=500)(x)
-x = Activation('relu')(x)
-x = Dropout(0.3)(x)
-x = Dense(units=K)(x)
-x = Activation('softmax')(x)
-
-model = Model(inputs=i, outputs=x)
-
+model = Sequential([
+    Input(shape=Xtrain.shape[1:]),
+    Conv2D(filters=20, kernel_size=(5, 5)),  # First Conv layer
+    BatchNormalization(), 
+    Activation('relu'), 
+    MaxPooling2D(), 
+
+    Conv2D(filters=50, kernel_size=(5, 5)),  # Second Conv layer
+    BatchNormalization(),
+    Activation('relu'),
+    MaxPooling2D(), 
+
+    Flatten(), 
+    Dense(units=500),  # Fully connected layer
+    Activation('relu'), 
+    Dropout(0.3), 
+    Dense(units=K),  # Output layer
+    Activation('softmax')
+])
 
 # list of losses: https://keras.io/losses/
 # list of optimizers: https://keras.io/optimizers/
 # list of metrics: https://keras.io/metrics/
 model.compile(
-  loss='sparse_categorical_crossentropy',
-  optimizer='adam',
-  metrics=['accuracy']
+    loss='sparse_categorical_crossentropy',
+    optimizer='adam',
+    metrics=['accuracy']
 )
 
-# note: multiple ways to choose a backend
-# either theano, tensorflow, or cntk
-# https://keras.io/backend/
-
-
 # gives us back a <keras.callbacks.History object at 0x112e61a90>
 r = model.fit(Xtrain, Ytrain, validation_data=(Xtest, Ytest), epochs=10, batch_size=32)
 print("Returned:", r)
@@ -113,9 +93,7 @@ def rearrange(X):
 plt.show()
 
 # accuracies
-plt.plot(r.history['accuracy'], label='acc')
-plt.plot(r.history['val_accuracy'], label='val_acc')
+plt.plot(r.history['accuracy'], label='accuracy')
+plt.plot(r.history['val_accuracy'], label='val_accuracy')
 plt.legend()
 plt.show()
-
-