Commit 5854c7de authored by Yuxin Wu's avatar Yuxin Wu

misc small fixes for old TF version (#810)

parent 2a712afe
...@@ -53,7 +53,7 @@ class ILSVRCMeta(object): ...@@ -53,7 +53,7 @@ class ILSVRCMeta(object):
return dict(enumerate(lines)) return dict(enumerate(lines))
def _download_caffe_meta(self): def _download_caffe_meta(self):
fpath = download(CAFFE_ILSVRC12_URL, self.dir) fpath = download(CAFFE_ILSVRC12_URL, self.dir, expect_size=17858008)
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
def get_image_list(self, name, dir_structure='original'): def get_image_list(self, name, dir_structure='original'):
......
...@@ -504,7 +504,7 @@ class StagingInput(FeedfreeInput): ...@@ -504,7 +504,7 @@ class StagingInput(FeedfreeInput):
def _setup_graph(self): def _setup_graph(self):
self.stage_op = self._input._get_stage_op() self.stage_op = self._input._get_stage_op()
unstage_ops = self._input._get_unstage_ops() unstage_ops = self._input._get_unstage_ops()
unstage_op = tf.group(unstage_ops, name='unstage_all') unstage_op = tf.group(*unstage_ops, name='unstage_all')
self._check_dependency_op = unstage_ops[0] self._check_dependency_op = unstage_ops[0]
self.fetches = tf.train.SessionRunArgs( self.fetches = tf.train.SessionRunArgs(
fetches=[self.stage_op, unstage_op]) fetches=[self.stage_op, unstage_op])
......
...@@ -166,7 +166,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -166,7 +166,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
center=center, scale=scale, center=center, scale=scale,
beta_initializer=beta_initializer, beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer, gamma_initializer=gamma_initializer,
fused=True, fused=(ndims == 4 and axis in [1, 3]),
_reuse=tf.get_variable_scope().reuse) _reuse=tf.get_variable_scope().reuse)
if TF_version >= 1.5: if TF_version >= 1.5:
tf_args['virtual_batch_size'] = virtual_batch_size tf_args['virtual_batch_size'] = virtual_batch_size
......
...@@ -28,7 +28,7 @@ def mkdir_p(dirname): ...@@ -28,7 +28,7 @@ def mkdir_p(dirname):
raise e raise e
def download(url, dir, filename=None): def download(url, dir, filename=None, expect_size=None):
""" """
Download URL to a directory. Download URL to a directory.
Will figure out the filename automatically from URL, if not given. Will figure out the filename automatically from URL, if not given.
...@@ -55,7 +55,12 @@ def download(url, dir, filename=None): ...@@ -55,7 +55,12 @@ def download(url, dir, filename=None):
except IOError: except IOError:
logger.error("Failed to download {}".format(url)) logger.error("Failed to download {}".format(url))
raise raise
assert size > 0, "Download an empty file!" assert size > 0, "Downloaded an empty file from {}!".format(url)
if expect_size is not None and size != expect_size:
logger.error("File downloaded from {} does not match the expected size!".format(url))
logger.error("You may have downloaded a broken file, or the upstream may have modified the file.")
# TODO human-readable size # TODO human-readable size
print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.') print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
return fpath return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment