Commit c5de2ef9 authored by Yuxin Wu's avatar Yuxin Wu

download with tqdm

parent a349c558
......@@ -123,3 +123,6 @@ class ProgressBar(Callback):
self._bar.update()
if self.trainer.local_step == self._total - 1:
self._bar.close()
def _after_train(self):
self._bar.close()
......@@ -101,8 +101,12 @@ class EnqueueThread(ShareSessionThread):
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError:
except (tf.errors.CancelledError, tf.errors.OutOfRangeError):
try:
self.close_op.run()
except Exception:
pass
return
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
......
......@@ -4,9 +4,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import sys
from six.moves import urllib
import errno
import tqdm
from . import logger
from .utils import execute_only_once
......@@ -31,29 +31,32 @@ def mkdir_p(dirname):
def download(url, dir, filename=None):
"""
Download URL to a directory. Will figure out the filename automatically
from URL.
Download URL to a directory.
Will figure out the filename automatically from URL, if not given.
"""
mkdir_p(dir)
if filename is None:
filename = url.split('/')[-1]
fpath = os.path.join(dir, filename)
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(filename,
min(float(count * block_size) / total_size,
1.0) * 100.0))
sys.stdout.flush()
def hook(t):
last_b = [0]
def inner(b, bsize, tsize=None):
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return inner
try:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress)
with tqdm.tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=hook(t))
statinfo = os.stat(fpath)
size = statinfo.st_size
except:
logger.error("Failed to download {}".format(url))
raise
assert size > 0, "Download an empty file!"
sys.stdout.write('\n')
# TODO human-readable size
print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.')
return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment