Commit 9ce7f032 authored by Yuxin Wu's avatar Yuxin Wu

use tf.keras rather than keras

parent 627ad534
......@@ -6,9 +6,8 @@
import numpy as np
import tensorflow as tf
from keras.models import Sequential
import keras.layers as KL
from keras import regularizers
from tensorflow import keras
KL = keras.layers
from tensorpack.train import SimpleTrainer
......@@ -35,7 +34,7 @@ def get_data():
if __name__ == '__main__':
logger.auto_set_dir()
M = Sequential()
M = keras.models.Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
......@@ -43,9 +42,9 @@ if __name__ == '__main__':
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
M.add(KL.Flatten())
M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Activation('softmax'))
dataset_train, dataset_test = get_data()
......
......@@ -4,7 +4,7 @@
import tensorflow as tf
from six.moves import zip
import keras
from tensorflow import keras
from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment