Commit cdd71bfe authored by Yuxin Wu's avatar Yuxin Wu

fix get_input_queue

parent 0a012166
......@@ -33,6 +33,7 @@ class Model(ModelDesc):
def _get_cost(self, input_vars, is_training):
image, label = input_vars
keep_prob = tf.constant(0.5 if is_training else 1.0)
if is_training:
image, label = tf.train.shuffle_batch(
......@@ -40,7 +41,7 @@ class Model(ModelDesc):
num_threads=6, enqueue_many=True)
tf.image_summary("train_image", image, 10)
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3, padding='SAME')
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
l = BatchNorm('bn1', l, is_training)
l = tf.nn.relu(l)
......@@ -56,8 +57,9 @@ class Model(ModelDesc):
l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=tf.identity)
l = BatchNorm('bn3', l, is_training)
l = tf.nn.relu(l)
l = FullyConnected('fc0', l, 512,
l = FullyConnected('fc0', l, 1024 + 512,
b_init=tf.constant_initializer(0.1))
l = tf.nn.dropout(l, keep_prob)
l = FullyConnected('fc1', l, out_dim=512,
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
......@@ -120,13 +122,13 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 30,
decay_rate=0.5, staircase=True, name='learning_rate')
decay_steps=dataset_train.size() * 40,
decay_rate=0.4, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
......@@ -135,7 +137,7 @@ def get_config():
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=500,
max_epoch=300,
)
if __name__ == '__main__':
......
......@@ -40,7 +40,7 @@ class StatHolder(object):
def _print_stat(self):
for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
if self.print_tag is None or k in self.print_tag:
logger.info('{}: {:.4f}'.format(k, v))
logger.info('{}: {:.5f}'.format(k, v))
def _write_stat(self):
tmp_filename = self.filename + '.tmp'
......
......@@ -93,7 +93,7 @@ class QueueInputTrainer(Trainer):
def train(self):
model = self.model
input_vars = model.get_input_vars()
input_queue = model.get_input_queue()
input_queue = model.get_input_queue(input_vars)
enqueue_op = input_queue.enqueue(input_vars)
def get_model_inputs():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment