Commit f1fc7337 authored by Yuxin Wu's avatar Yuxin Wu

fix debugging bug in expreplay

parent e5a48033
...@@ -72,9 +72,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -72,9 +72,7 @@ class ExpReplay(DataFlow, Callback):
with tqdm(total=self.init_memory_size) as pbar: with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size: while len(self.mem) < self.init_memory_size:
from copy import deepcopy self._populate_exp()
self.mem.append(deepcopy(self.mem[0]))
#self._populate_exp()
pbar.update() pbar.update()
self._init_memory_flag.set() self._init_memory_flag.set()
......
...@@ -64,7 +64,6 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -64,7 +64,6 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
ema_apply_op = ema.apply([batch_mean, batch_var]) ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
G = tf.get_default_graph() G = tf.get_default_graph()
try: try:
mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name) mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
......
...@@ -90,9 +90,9 @@ class EnqueueThread(threading.Thread): ...@@ -90,9 +90,9 @@ class EnqueueThread(threading.Thread):
pass pass
except Exception: except Exception:
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
finally:
self.sess.run(self.close_op) self.sess.run(self.close_op)
self.coord.request_stop() self.coord.request_stop()
finally:
logger.info("Enqueue Thread Exited.") logger.info("Enqueue Thread Exited.")
class QueueInputTrainer(Trainer): class QueueInputTrainer(Trainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment