Commit cdd71bfe authored by Yuxin Wu

fix get_input_queue

parent 0a012166
@@ -33,6 +33,7 @@ class Model(ModelDesc):
     def _get_cost(self, input_vars, is_training):
         image, label = input_vars
+        keep_prob = tf.constant(0.5 if is_training else 1.0)
         if is_training:
             image, label = tf.train.shuffle_batch(
@@ -40,7 +41,7 @@ class Model(ModelDesc):
                 num_threads=6, enqueue_many=True)
             tf.image_summary("train_image", image, 10)
-        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3, padding='SAME')
+        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
         l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
         l = BatchNorm('bn1', l, is_training)
         l = tf.nn.relu(l)
@@ -56,8 +57,9 @@ class Model(ModelDesc):
         l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=tf.identity)
         l = BatchNorm('bn3', l, is_training)
         l = tf.nn.relu(l)
-        l = FullyConnected('fc0', l, 512,
+        l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
+        l = tf.nn.dropout(l, keep_prob)
         l = FullyConnected('fc1', l, out_dim=512,
                            b_init=tf.constant_initializer(0.1))
         # fc will have activation summary by default. disable for the output layer
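
Note on the two lines added above: keep_prob is a graph-time constant, 0.5 when the training graph is built and 1.0 (dropout disabled) when the inference graph is built, and it feeds the new tf.nn.dropout call placed after fc0. A minimal standalone sketch of that pattern, with an illustrative function name that is not taken from this repository:

import tensorflow as tf

def fc_dropout_sketch(x, is_training):
    # keep_prob is baked into the graph: 0.5 for training, 1.0 for inference,
    # so no placeholder needs to be fed at run time.
    keep_prob = tf.constant(0.5 if is_training else 1.0)
    # tf.nn.dropout randomly zeroes units and rescales the rest by 1/keep_prob.
    return tf.nn.dropout(x, keep_prob)
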
@@ -120,13 +122,13 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 30,
-        decay_rate=0.5, staircase=True, name='learning_rate')
+        decay_steps=dataset_train.size() * 40,
+        decay_rate=0.4, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

     return TrainConfig(
         dataset=dataset_train,
-        optimizer=tf.train.AdamOptimizer(lr),
+        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(),
             PeriodicSaver(),
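
For reference, the schedule change above multiplies the learning rate by 0.4 once every dataset_train.size() * 40 global steps (staircase decay), instead of 0.5 every 30. Assuming the global step advances once per minibatch and step_per_epoch equals dataset_train.size(), that is one decay roughly every 40 epochs; a plain-Python sketch of the resulting values:

# Staircase schedule implied by the new hyper-parameters (illustrative only).
base_lr, decay_rate, epochs_per_decay = 1e-2, 0.4, 40
for epoch in (0, 40, 80, 120):
    lr = base_lr * decay_rate ** (epoch // epochs_per_decay)
    print(epoch, lr)   # approximately 1e-2, 4e-3, 1.6e-3, 6.4e-4
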
@@ -135,7 +137,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=500,
+        max_epoch=300,
     )

 if __name__ == '__main__':
@@ -40,7 +40,7 @@ class StatHolder(object):
     def _print_stat(self):
         for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
             if self.print_tag is None or k in self.print_tag:
-                logger.info('{}: {:.4f}'.format(k, v))
+                logger.info('{}: {:.5f}'.format(k, v))

     def _write_stat(self):
         tmp_filename = self.filename + '.tmp'
@@ -93,7 +93,7 @@ class QueueInputTrainer(Trainer):
     def train(self):
         model = self.model
         input_vars = model.get_input_vars()
-        input_queue = model.get_input_queue()
+        input_queue = model.get_input_queue(input_vars)
         enqueue_op = input_queue.enqueue(input_vars)

         def get_model_inputs():
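
This hunk is the fix named in the commit message: the trainer now passes the model's input placeholders into get_input_queue, so the queue can be constructed to match them before enqueue is called. The body of get_input_queue is not part of this diff; the following is only a sketch of a compatible implementation, written as it would appear as a method on the model, where the use of tf.FIFOQueue and the capacity of 50 are assumptions rather than code from the repository:

import tensorflow as tf

def get_input_queue(self, input_vars):
    # Hypothetical default: one queue slot per input placeholder,
    # with dtypes taken from the placeholders themselves.
    return tf.FIFOQueue(50, [v.dtype for v in input_vars], name='input_queue')

# The trainer then uses it exactly as shown in the hunk above:
#   input_vars = model.get_input_vars()
#   input_queue = model.get_input_queue(input_vars)
#   enqueue_op = input_queue.enqueue(input_vars)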