Commit 9972b150 authored by Yuxin Wu

(fix #685)

parent 760b924a
...@@ -33,7 +33,7 @@ param.seq_len = 50 ...@@ -33,7 +33,7 @@ param.seq_len = 50
param.grad_clip = 5. param.grad_clip = 5.
param.vocab_size = None param.vocab_size = None
param.softmax_temprature = 1 param.softmax_temprature = 1
param.corpus = 'input.txt' param.corpus = None
class CharRNNData(RNGDataFlow): class CharRNNData(RNGDataFlow):
...@@ -182,6 +182,7 @@ if __name__ == '__main__': ...@@ -182,6 +182,7 @@ if __name__ == '__main__':
parser_sample.add_argument('-t', '--temperature', type=float, parser_sample.add_argument('-t', '--temperature', type=float,
default=1, help='softmax temperature') default=1, help='softmax temperature')
parser_train = subparsers.add_parser('train', help='train') parser_train = subparsers.add_parser('train', help='train')
parser_train.add_argument('--corpus', help='corpus file', default='input.txt')
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...@@ -192,6 +193,7 @@ if __name__ == '__main__': ...@@ -192,6 +193,7 @@ if __name__ == '__main__':
sample(args.load, args.start, args.num) sample(args.load, args.start, args.num)
sys.exit() sys.exit()
else: else:
param.corpus = args.corpus
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
......
...@@ -156,6 +156,7 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -156,6 +156,7 @@ class InferenceRunner(InferenceRunnerBase):
for inf in self.infs: for inf in self.infs:
inf.before_epoch() inf.before_epoch()
self._input_source.reset_state()
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
with _inference_context(), \ with _inference_context(), \
tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar: tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar:
...@@ -265,6 +266,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -265,6 +266,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
total = self._size total = self._size
nr_tower = len(self._gpus) nr_tower = len(self._gpus)
self._input_source.reset_state()
with _inference_context(): with _inference_context():
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar: with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower: while total >= nr_tower:
......
...@@ -196,25 +196,25 @@ class QueueInput(FeedfreeInput): ...@@ -196,25 +196,25 @@ class QueueInput(FeedfreeInput):
self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset') self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
def _reset_state(self): def refill_queue(self):
if self._started: # do not try to clear the queue if there is nothing """
self.thread.pause() # pause enqueue Clear the queue, then call dataflow.get_data() again and fill into the queue.
"""
opt = tf.RunOptions() self.thread.pause() # pause enqueue
opt.timeout_in_ms = 2000 # 2s
sess = tf.get_default_session() opt = tf.RunOptions()
# dequeue until empty opt.timeout_in_ms = 2000 # 2s
try: sess = tf.get_default_session()
while True: # dequeue until empty
sess.run(self._dequeue_op, options=opt) try:
except tf.errors.DeadlineExceededError: while True:
pass sess.run(self._dequeue_op, options=opt)
except tf.errors.DeadlineExceededError:
# reset dataflow, start thread pass
self.thread.reinitialize_dataflow()
self.thread.resume() # reset dataflow, start thread
else: self.thread.reinitialize_dataflow()
self._started = True self.thread.resume()
def _create_ema_callback(self): def _create_ema_callback(self):
""" """
......
...@@ -137,7 +137,7 @@ class InputSource(object): ...@@ -137,7 +137,7 @@ class InputSource(object):
Initialize/reinitialize this InputSource. Initialize/reinitialize this InputSource.
Must be called under a default session. Must be called under a default session.
For training, it will get called by the trainer in `before_train` callbacks. For training, it will get called once by the trainer in `before_train` callbacks.
For inference, the :class:`InferenceRunner` will call this method each time it is triggered. For inference, the :class:`InferenceRunner` will call this method each time it is triggered.
""" """
self._reset_state() self._reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment