Commit 5683e4d9 authored by Yuxin Wu's avatar Yuxin Wu

tqdm & humanparam

parent 67dd4887
......@@ -107,7 +107,7 @@ class InferenceRunner(Callback):
sess = tf.get_default_session()
self.ds.reset_state()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
#outputs = sess.run(self.output_tensors, feed_dict=feed)
......
......@@ -122,7 +122,7 @@ class HumanHyperParamSetter(HyperParamSetter):
"""
Set hyperparameters by loading the value from a file each time it get called.
"""
def __init__(self, param, file_name):
def __init__(self, param, file_name='hyper.txt'):
"""
:param file_name: a file containing the value of the variable.
Each line in the file is a k:v pair, where k is
......
......@@ -144,12 +144,12 @@ class ILSVRC12(RNGDataFlow):
Produce original images or shape [h, w, 3], and label
"""
idxs = np.arange(len(self.imglist))
isTrain = self.name == 'train'
add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
fname, label = self.imglist[k]
if not isTrain and self.dir_structure != 'original':
if add_label_to_fname:
fname = os.path.join(self.full_dir, self.synset[label], fname)
else:
fname = os.path.join(self.full_dir, fname)
......
......@@ -107,7 +107,6 @@ class CaffeLMDB(LMDBData):
:param shuffle: about 3 times slower
"""
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
self.cpb = get_caffe_pb()
def get_data(self):
......
......@@ -98,7 +98,7 @@ class Gamma(ImageAugmentor):
return self._rand_range(*self.range)
def _augment(self, img, gamma):
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
img = (img * 255.0).astype('uint8')
img = cv2.LUT(img, lut).astype('float32') / 255.0
img = img.astype('uint8')
img = cv2.LUT(img, lut).astype('float32')
return img
......@@ -126,11 +126,7 @@ class Trainer(object):
epoch, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
leave=True, mininterval=0.5,
smoothing=0.5,
dynamic_ncols=True,
ascii=True):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
**get_tqdm_kwargs(leave=True)):
if self.coord.should_stop():
return
self.run_step()
......
......@@ -17,7 +17,9 @@ __all__ = ['change_env', 'map_arg',
'get_rng', 'memoized',
'get_nr_gpu',
'get_gpus',
'get_dataset_dir']
'get_dataset_dir',
'get_tqdm_kwargs'
]
#def expand_dim_if_necessary(var, dp):
# """
......@@ -113,3 +115,18 @@ def get_dataset_dir(*args):
assert os.path.isdir(d), d
return os.path.join(d, *args)
def get_tqdm_kwargs(**kwargs):
default = dict(
smoothing=0.5,
dynamic_ncols=True,
ascii=True,
bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_noinv_fmt}]'
)
f = kwargs.get('file', sys.stderr)
if f.isatty():
default['mininterval'] = 0.5
else:
default['mininterval'] = 60
default.update(kwargs)
return default
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment