Commit 8efd12b1 authored by Yuxin Wu's avatar Yuxin Wu

add progress bar

parent 26edfabe
......@@ -140,7 +140,7 @@ def get_config():
get_model_func=get_model,
batched_model_input=False,
step_per_epoch=step_per_epoch,
max_epoch=100,
max_epoch=500,
)
if __name__ == '__main__':
......
......@@ -2,3 +2,4 @@ pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-
termcolor
numpy
protobuf~=3.0.0a1
tqdm
......@@ -7,6 +7,7 @@ import tensorflow as tf
from itertools import count
import argparse
import tqdm
from utils import *
from utils.concurrency import EnqueueThread,coordinator_guard
from utils.callback import Callbacks
......@@ -134,7 +135,8 @@ def start_train(config):
callbacks.before_train()
for epoch in xrange(1, config.max_epoch):
with timed_operation('epoch {}'.format(epoch)):
for step in xrange(config.step_per_epoch):
for step in tqdm.trange(
config.step_per_epoch, leave=True, mininterval=0.2):
if coord.should_stop():
return
# TODO if no one uses trigger_step, train_op can be
......
......@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from tqdm import tqdm
from .stat import *
from .callback import PeriodicCallback, Callback
from .naming import *
......@@ -42,6 +44,7 @@ class ValidationError(PeriodicCallback):
cnt = 0
err_stat = Accuracy()
cost_sum = 0
with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp))
......@@ -53,6 +56,7 @@ class ValidationError(PeriodicCallback):
err_stat.feed(wrong, batch_size)
# each batch might not have the same size in validation
cost_sum += cost * batch_size
pbar.update()
cost_avg = cost_sum / cnt
self.writer.add_summary(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment