Commit a78d02ac authored by ppwwyyxx

use is_training

parent 9bf42054
@@ -9,6 +9,8 @@ import os
 sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
 import tensorflow as tf
+from tensorflow.python.ops import control_flow_ops
 import numpy as np
 import os
@@ -31,8 +33,9 @@ def get_model(inputs):
         outputs: a list of output variable
         cost: scalar variable
     """
-    # use this variable in dropout! Tensorpack will automatically set it to 1 at test time
-    keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
+    is_training = tf.get_default_graph().get_tensor_by_name(IS_TRAINING_VAR_NAME)
+    keep_prob = control_flow_ops.cond(
+        is_training, lambda: tf.constant(0.5), lambda: tf.constant(1.0), name='dropout_prob')
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
@@ -83,6 +86,7 @@ def get_config():
     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1
+    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
     sess_config.gpu_options.allocator_type = 'BFC'
     sess_config.allow_soft_placement = True
...
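The substantive change in this example model: rather than feeding a float `keep_prob` placeholder directly, the graph now holds a single boolean `is_training` flag and derives the dropout probability from it with `cond`, so callers toggle one flag instead of remembering the right keep probability for each phase. A minimal self-contained sketch of the pattern against the TF API of this era (the input shape and the 0.5 keep probability here are illustrative):

```python
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

# One boolean flag that callers feed; everything that behaves differently
# between training and testing branches on this single tensor.
is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# Derive the dropout keep probability from the flag:
# 0.5 while training, 1.0 (i.e. dropout disabled) at test time.
keep_prob = control_flow_ops.cond(
    is_training,
    lambda: tf.constant(0.5),
    lambda: tf.constant(1.0),
    name='dropout_prob')

x = tf.placeholder(tf.float32, shape=(None, 128))
y = tf.nn.dropout(x, keep_prob)
```

Training code then feeds `{is_training: True}` and evaluation feeds `{is_training: False}`, exactly as the train.py and callback hunks below do.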
@@ -10,8 +10,9 @@ from itertools import count
 import argparse

 def prepare():
-    keep_prob = tf.placeholder(
-        tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
+    #keep_prob = tf.placeholder(
+        #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)
@@ -49,19 +50,11 @@ def start_train(config):
         G.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
-    summary_model()
+    describe_model()

     global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)

-    # add some summary ops to the graph
-    averager = tf.train.ExponentialMovingAverage(
-        0.9, num_updates=global_step_var, name='avg')
-    vars_to_summary = [cost_var] + \
-        tf.get_collection(SUMMARY_VARS_KEY) + \
-        tf.get_collection(COST_VARS_KEY)
-    avg_maintain_op = averager.apply(vars_to_summary)
-    for c in vars_to_summary:
-        tf.scalar_summary(c.op.name, averager.average(c))
+    avg_maintain_op = summary_moving_average(cost_var)

     # maintain average in each step
     with tf.control_dependencies([avg_maintain_op]):
@@ -79,11 +72,11 @@ def start_train(config):
         sess.run(tf.initialize_all_variables())
         callbacks.before_train()

-        keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
+        is_training = G.get_tensor_by_name(IS_TRAINING_VAR_NAME)
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
                 for dp in dataset_train.get_data():
-                    feed = {keep_prob_var: 0.5}
+                    feed = {is_training: True}
                     feed.update(dict(zip(input_vars, dp)))
                     results = sess.run(
...
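Note that the training loop keeps the `tf.control_dependencies([avg_maintain_op])` pattern around the optimizer step, so the moving averages are refreshed by the same `sess.run` that applies the gradients. A hedged sketch of the idea (the variable, cost, and learning rate are stand-ins, not the repo's real model):

```python
import tensorflow as tf

w = tf.Variable(1.0)
cost = tf.square(w - 3.0)

# Stand-in for the avg_maintain_op returned by summary_moving_average():
# an op that refreshes the exponential moving average of the cost.
averager = tf.train.ExponentialMovingAverage(0.9, name='avg')
avg_maintain_op = averager.apply([cost])

# Ops created inside this block run only after avg_maintain_op, so a single
# sess.run(train_op) both updates the averages and applies the gradients.
with tf.control_dependencies([avg_maintain_op]):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(cost)
```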
@@ -31,8 +31,7 @@ def timed_operation(msg, log_start=False):
     logger.info('finished {}, time={:.2f}sec.'.format(
         msg, time.time() - start))

-def summary_model():
+def describe_model():
     train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
     msg = [""]
     total = 0
...
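The rename from `summary_model()` to `describe_model()` better matches what the function does: it walks the trainable variables and reports shapes and parameter counts rather than creating any summary ops. A sketch of what such a helper typically looks like (the exact log format of the real implementation is not shown in this diff):

```python
import numpy as np
import tensorflow as tf

def describe_model():
    """Print every trainable variable and the total parameter count."""
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    total = 0
    for v in train_vars:
        shape = v.get_shape().as_list()
        n = int(np.prod(shape))
        total += n
        print('{}: shape={}, params={}'.format(v.name, shape, n))
    print('total trainable params: {}'.format(total))
```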
@@ -3,8 +3,8 @@
 # File: naming.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-DROPOUT_PROB_OP_NAME = 'dropout_prob'
-DROPOUT_PROB_VAR_NAME = 'dropout_prob:0'
+IS_TRAINING_OP_NAME = 'is_training'
+IS_TRAINING_VAR_NAME = 'is_training:0'
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
...
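The `*_OP_NAME` / `*_VAR_NAME` pairs differ only by the `:0` suffix because a TensorFlow tensor name is `<op name>:<output index>`: the placeholder op named `is_training` has a single output, the tensor `is_training:0`. A quick illustration:

```python
import tensorflow as tf

IS_TRAINING_OP_NAME = 'is_training'
IS_TRAINING_VAR_NAME = 'is_training:0'   # <op name>:<output index>

is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
g = tf.get_default_graph()

# Operations are fetched by op name, tensors by op name plus output index.
assert g.get_operation_by_name(IS_TRAINING_OP_NAME) is is_training.op
assert g.get_tensor_by_name(IS_TRAINING_VAR_NAME) is is_training
```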
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
+from .naming import *

 def create_summary(name, v):
     """
@@ -42,3 +43,14 @@ def add_histogram_summary(regex):
         if re.search(regex, name):
             tf.histogram_summary(name, p)
+
+def summary_moving_average(cost_var):
+    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    averager = tf.train.ExponentialMovingAverage(
+        0.9, num_updates=global_step_var, name='avg')
+    vars_to_summary = [cost_var] + \
+        tf.get_collection(SUMMARY_VARS_KEY) + \
+        tf.get_collection(COST_VARS_KEY)
+    avg_maintain_op = averager.apply(vars_to_summary)
+    for c in vars_to_summary:
+        tf.scalar_summary(c.op.name, averager.average(c))
+    return avg_maintain_op
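Passing `num_updates=global_step_var` matters for the smoothed summaries: `ExponentialMovingAverage` then uses the effective decay `min(decay, (1 + num_updates) / (10 + num_updates))`, so early in training the averages track the raw values closely instead of being dragged toward the shadow variables' initial values. The schedule in plain Python:

```python
def effective_decay(decay, num_updates):
    # Documented TF behavior when num_updates is supplied.
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))

print(effective_decay(0.9, 0))     # 0.1  -> little smoothing at step 0
print(effective_decay(0.9, 100))   # 0.9  -> full smoothing later in training
```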
...
@@ -33,7 +33,7 @@ class ValidationError(PeriodicCallback):
     def _before_train(self):
         self.input_vars = tf.get_collection(INPUT_VARS_KEY)
-        self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
+        self.is_training_var = self.get_tensor(IS_TRAINING_VAR_NAME)
         self.wrong_var = self.get_tensor(self.wrong_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
@@ -43,7 +43,7 @@ class ValidationError(PeriodicCallback):
         err_stat = Accuracy()
         cost_sum = 0
         for dp in self.ds.get_data():
-            feed = {self.dropout_var: 1.0}
+            feed = {self.is_training_var: False}
             feed.update(dict(zip(self.input_vars, dp)))
             batch_size = dp[0].shape[0]    # assume batched input
...
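At validation time the callback flips the same flag to `False`, which makes the `cond` in get_model select `keep_prob = 1.0`, turning dropout into a no-op; error and cost are then accumulated over the whole validation set. A hedged sketch of that pattern as a standalone function — the repo's real `Accuracy` helper and summary writing are not shown in this diff, and `wrong_var` is assumed to count misclassified samples per batch while `cost_var` is the per-batch average cost:

```python
def run_validation(sess, ds, input_vars, is_training_var, wrong_var, cost_var):
    # Accumulate over the full validation set, weighting by batch size.
    wrong_sum, cost_sum, total = 0, 0.0, 0
    for dp in ds.get_data():
        feed = {is_training_var: False}   # cond() picks keep_prob = 1.0
        feed.update(dict(zip(input_vars, dp)))
        batch_size = dp[0].shape[0]       # assume batched input
        wrong, cost = sess.run([wrong_var, cost_var], feed_dict=feed)
        wrong_sum += wrong
        cost_sum += cost * batch_size
        total += batch_size
    return wrong_sum / float(total), cost_sum / total
```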