Commit 5683e4d9 authored by Yuxin Wu's avatar Yuxin Wu

tqdm & humanparam

parent 67dd4887
...@@ -107,7 +107,7 @@ class InferenceRunner(Callback): ...@@ -107,7 +107,7 @@ class InferenceRunner(Callback):
sess = tf.get_default_session() sess = tf.get_default_session()
self.ds.reset_state() self.ds.reset_state()
with tqdm(total=self.ds.size(), ascii=True) as pbar: with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping? #feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed) #outputs = sess.run(self.output_tensors, feed_dict=feed)
......
...@@ -122,7 +122,7 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -122,7 +122,7 @@ class HumanHyperParamSetter(HyperParamSetter):
""" """
Set hyperparameters by loading the value from a file each time it get called. Set hyperparameters by loading the value from a file each time it get called.
""" """
def __init__(self, param, file_name): def __init__(self, param, file_name='hyper.txt'):
""" """
:param file_name: a file containing the value of the variable. :param file_name: a file containing the value of the variable.
Each line in the file is a k:v pair, where k is Each line in the file is a k:v pair, where k is
......
...@@ -144,12 +144,12 @@ class ILSVRC12(RNGDataFlow): ...@@ -144,12 +144,12 @@ class ILSVRC12(RNGDataFlow):
Produce original images or shape [h, w, 3], and label Produce original images or shape [h, w, 3], and label
""" """
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
isTrain = self.name == 'train' add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
fname, label = self.imglist[k] fname, label = self.imglist[k]
if not isTrain and self.dir_structure != 'original': if add_label_to_fname:
fname = os.path.join(self.full_dir, self.synset[label], fname) fname = os.path.join(self.full_dir, self.synset[label], fname)
else: else:
fname = os.path.join(self.full_dir, fname) fname = os.path.join(self.full_dir, fname)
......
...@@ -107,7 +107,6 @@ class CaffeLMDB(LMDBData): ...@@ -107,7 +107,6 @@ class CaffeLMDB(LMDBData):
:param shuffle: about 3 times slower :param shuffle: about 3 times slower
""" """
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle) super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
self.cpb = get_caffe_pb() self.cpb = get_caffe_pb()
def get_data(self): def get_data(self):
......
...@@ -98,7 +98,7 @@ class Gamma(ImageAugmentor): ...@@ -98,7 +98,7 @@ class Gamma(ImageAugmentor):
return self._rand_range(*self.range) return self._rand_range(*self.range)
def _augment(self, img, gamma): def _augment(self, img, gamma):
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8') lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = (img * 255.0).astype('uint8') img = img.astype('uint8')
img = cv2.LUT(img, lut).astype('float32') / 255.0 img = cv2.LUT(img, lut).astype('float32')
return img return img
...@@ -126,11 +126,7 @@ class Trainer(object): ...@@ -126,11 +126,7 @@ class Trainer(object):
epoch, self.global_step + self.config.step_per_epoch)): epoch, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
leave=True, mininterval=0.5, **get_tqdm_kwargs(leave=True)):
smoothing=0.5,
dynamic_ncols=True,
ascii=True):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if self.coord.should_stop(): if self.coord.should_stop():
return return
self.run_step() self.run_step()
......
...@@ -17,7 +17,9 @@ __all__ = ['change_env', 'map_arg', ...@@ -17,7 +17,9 @@ __all__ = ['change_env', 'map_arg',
'get_rng', 'memoized', 'get_rng', 'memoized',
'get_nr_gpu', 'get_nr_gpu',
'get_gpus', 'get_gpus',
'get_dataset_dir'] 'get_dataset_dir',
'get_tqdm_kwargs'
]
#def expand_dim_if_necessary(var, dp): #def expand_dim_if_necessary(var, dp):
# """ # """
...@@ -113,3 +115,18 @@ def get_dataset_dir(*args): ...@@ -113,3 +115,18 @@ def get_dataset_dir(*args):
assert os.path.isdir(d), d assert os.path.isdir(d), d
return os.path.join(d, *args) return os.path.join(d, *args)
def get_tqdm_kwargs(**kwargs):
default = dict(
smoothing=0.5,
dynamic_ncols=True,
ascii=True,
bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_noinv_fmt}]'
)
f = kwargs.get('file', sys.stderr)
if f.isatty():
default['mininterval'] = 0.5
else:
default['mininterval'] = 60
default.update(kwargs)
return default
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment