Commit a812979a authored by Yuxin Wu's avatar Yuxin Wu

Delay stagingarea initialize to the first iteration. (fix #461)

parent ccef4d4f
...@@ -414,13 +414,18 @@ class StagingInput(FeedfreeInput): ...@@ -414,13 +414,18 @@ class StagingInput(FeedfreeInput):
self.stage_op = stage_op self.stage_op = stage_op
self.fetches = tf.train.SessionRunArgs( self.fetches = tf.train.SessionRunArgs(
fetches=[stage_op, unstage_op]) fetches=[stage_op, unstage_op])
self._initialized = False
def _before_train(self): def _prefill(self):
logger.info("Pre-filling staging area ...") logger.info("Pre-filling staging area ...")
for k in range(self.nr_stage): for k in range(self.nr_stage):
self.stage_op.run() self.stage_op.run()
def _before_run(self, ctx): def _before_run(self, ctx):
# This has to happen once, right before the first iteration.
if not self._initialized:
self._initialized = True
self._prefill()
return self.fetches return self.fetches
def __init__(self, input, towers, nr_stage=5): def __init__(self, input, towers, nr_stage=5):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment