Commit a6b091aa authored by Yuxin Wu

use tf.keras for both keras examples

parent 9ce7f032
@@ -10,10 +10,8 @@ import os
 import sys
 import argparse
-import keras
-import keras.layers as KL
-from keras.models import Sequential
-from keras import regularizers
+from tensorflow import keras
+KL = keras.layers
 """
 This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack.
@@ -31,7 +29,7 @@ IMAGE_SIZE = 28
 @memoized  # this is necessary for sonnet/Keras to work under tensorpack
 def get_keras_model():
-    M = Sequential()
+    M = keras.models.Sequential()
     M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
     M.add(KL.MaxPooling2D())
     M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
@@ -39,9 +37,9 @@ def get_keras_model():
     M.add(KL.MaxPooling2D())
     M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
     M.add(KL.Flatten())
-    M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
+    M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
     M.add(KL.Dropout(0.5))
-    M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
+    M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
     return M
@@ -67,7 +65,7 @@ class Model(ModelDesc):
         wd_cost = tf.add_n(M.losses, name='regularize_loss')  # this is how Keras manage regularizers
         self.cost = tf.add_n([wd_cost, cost], name='total_cost')
-        summary.add_moving_summary(self.cost)
+        summary.add_moving_summary(self.cost, wd_cost)

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
...
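For reference, a minimal sketch of the pattern this commit moves to, reconstructed only from what the diff shows. The model is defined once through `tf.keras` instead of the standalone `keras` package, tensorpack's `@memoized` decorator (not shown here) ensures repeated calls reuse the same layers, and the per-layer L2 regularization losses are collected from `M.losses` into the total cost. The `build_cost` helper and the `image`/`label` arguments below are hypothetical stand-ins for the surrounding `ModelDesc` plumbing, which is not part of this diff; graph-mode TF 1.x is assumed, matching the `tf.train.exponential_decay` call in the example.

```python
# Sketch only -- not the commit's full example.
import tensorflow as tf
from tensorflow import keras

KL = keras.layers          # same alias the commit introduces
IMAGE_SIZE = 28


def get_keras_model():
    # Sequential and regularizers are now reached through the tf.keras namespace.
    M = keras.models.Sequential()
    M.add(KL.Conv2D(32, 3, activation='relu',
                    input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
    M.add(KL.Flatten())
    M.add(KL.Dense(10, activation=None,
                   kernel_regularizer=keras.regularizers.l2(1e-5)))
    return M


def build_cost(image, label):
    # Hypothetical helper: shows how the Keras model's regularization losses
    # (M.losses) are folded into the training cost, mirroring the diff.
    M = get_keras_model()
    logits = M(image)
    cost = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
        name='cross_entropy_loss')
    wd_cost = tf.add_n(M.losses, name='regularize_loss')  # losses added by kernel_regularizer
    return tf.add_n([wd_cost, cost], name='total_cost')
```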