Commit c5de2ef9 authored by Yuxin Wu's avatar Yuxin Wu

download with tqdm

parent a349c558
...@@ -123,3 +123,6 @@ class ProgressBar(Callback): ...@@ -123,3 +123,6 @@ class ProgressBar(Callback):
self._bar.update() self._bar.update()
if self.trainer.local_step == self._total - 1: if self.trainer.local_step == self._total - 1:
self._bar.close() self._bar.close()
def _after_train(self):
self._bar.close()
...@@ -101,8 +101,12 @@ class EnqueueThread(ShareSessionThread): ...@@ -101,8 +101,12 @@ class EnqueueThread(ShareSessionThread):
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except tf.errors.CancelledError: except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
pass try:
self.close_op.run()
except Exception:
pass
return
except Exception: except Exception:
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
finally: finally:
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
import sys
from six.moves import urllib from six.moves import urllib
import errno import errno
import tqdm
from . import logger from . import logger
from .utils import execute_only_once from .utils import execute_only_once
...@@ -31,29 +31,32 @@ def mkdir_p(dirname): ...@@ -31,29 +31,32 @@ def mkdir_p(dirname):
def download(url, dir, filename=None): def download(url, dir, filename=None):
""" """
Download URL to a directory. Will figure out the filename automatically Download URL to a directory.
from URL. Will figure out the filename automatically from URL, if not given.
""" """
mkdir_p(dir) mkdir_p(dir)
if filename is None: if filename is None:
filename = url.split('/')[-1] filename = url.split('/')[-1]
fpath = os.path.join(dir, filename) fpath = os.path.join(dir, filename)
def _progress(count, block_size, total_size): def hook(t):
sys.stdout.write('\r>> Downloading %s %.1f%%' % last_b = [0]
(filename,
min(float(count * block_size) / total_size, def inner(b, bsize, tsize=None):
1.0) * 100.0)) if tsize is not None:
sys.stdout.flush() t.total = tsize
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return inner
try: try:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress) with tqdm.tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=hook(t))
statinfo = os.stat(fpath) statinfo = os.stat(fpath)
size = statinfo.st_size size = statinfo.st_size
except: except:
logger.error("Failed to download {}".format(url)) logger.error("Failed to download {}".format(url))
raise raise
assert size > 0, "Download an empty file!" assert size > 0, "Download an empty file!"
sys.stdout.write('\n')
# TODO human-readable size # TODO human-readable size
print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.') print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.')
return fpath return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment