Commit e97ce38d authored by Yuxin Wu's avatar Yuxin Wu

fix typo in last commit (fix #1202)

parent b1c82a9c
...@@ -43,6 +43,7 @@ def get_keras_model(): ...@@ -43,6 +43,7 @@ def get_keras_model():
with clear_tower0_name_scope(): with clear_tower0_name_scope():
M = keras.models.Sequential() M = keras.models.Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.BatchNormalization())
M.add(KL.MaxPooling2D()) M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
...@@ -68,7 +69,7 @@ class Model(ModelDesc): ...@@ -68,7 +69,7 @@ class Model(ModelDesc):
logits = M(image) logits = M(image)
if ctx.is_main_training_tower: if ctx.is_main_training_tower:
for op in M.updates: for op in M.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS) tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
# build cost function by tensorflow # build cost function by tensorflow
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment