Commit 8efd12b1 authored by Yuxin Wu's avatar Yuxin Wu

add progress bar

parent 26edfabe
...@@ -140,7 +140,7 @@ def get_config(): ...@@ -140,7 +140,7 @@ def get_config():
get_model_func=get_model, get_model_func=get_model,
batched_model_input=False, batched_model_input=False,
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=100, max_epoch=500,
) )
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,3 +2,4 @@ pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27- ...@@ -2,3 +2,4 @@ pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-
termcolor termcolor
numpy numpy
protobuf~=3.0.0a1 protobuf~=3.0.0a1
tqdm
...@@ -7,6 +7,7 @@ import tensorflow as tf ...@@ -7,6 +7,7 @@ import tensorflow as tf
from itertools import count from itertools import count
import argparse import argparse
import tqdm
from utils import * from utils import *
from utils.concurrency import EnqueueThread,coordinator_guard from utils.concurrency import EnqueueThread,coordinator_guard
from utils.callback import Callbacks from utils.callback import Callbacks
...@@ -134,7 +135,8 @@ def start_train(config): ...@@ -134,7 +135,8 @@ def start_train(config):
callbacks.before_train() callbacks.before_train()
for epoch in xrange(1, config.max_epoch): for epoch in xrange(1, config.max_epoch):
with timed_operation('epoch {}'.format(epoch)): with timed_operation('epoch {}'.format(epoch)):
for step in xrange(config.step_per_epoch): for step in tqdm.trange(
config.step_per_epoch, leave=True, mininterval=0.2):
if coord.should_stop(): if coord.should_stop():
return return
# TODO if no one uses trigger_step, train_op can be # TODO if no one uses trigger_step, train_op can be
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from tqdm import tqdm
from .stat import * from .stat import *
from .callback import PeriodicCallback, Callback from .callback import PeriodicCallback, Callback
from .naming import * from .naming import *
...@@ -42,17 +44,19 @@ class ValidationError(PeriodicCallback): ...@@ -42,17 +44,19 @@ class ValidationError(PeriodicCallback):
cnt = 0 cnt = 0
err_stat = Accuracy() err_stat = Accuracy()
cost_sum = 0 cost_sum = 0
for dp in self.ds.get_data(): with tqdm(total=self.ds.size()) as pbar:
feed = dict(zip(self.input_vars, dp)) for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
cnt += batch_size cnt += batch_size
wrong, cost = self.sess.run( wrong, cost = self.sess.run(
[self.wrong_var, self.cost_var], feed_dict=feed) [self.wrong_var, self.cost_var], feed_dict=feed)
err_stat.feed(wrong, batch_size) err_stat.feed(wrong, batch_size)
# each batch might not have the same size in validation # each batch might not have the same size in validation
cost_sum += cost * batch_size cost_sum += cost * batch_size
pbar.update()
cost_avg = cost_sum / cnt cost_avg = cost_sum / cnt
self.writer.add_summary( self.writer.add_summary(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment