Commit 9972b150 authored by Yuxin Wu's avatar Yuxin Wu

(fix #685)

parent 760b924a
......@@ -33,7 +33,7 @@ param.seq_len = 50
param.grad_clip = 5.
param.vocab_size = None
param.softmax_temprature = 1
param.corpus = 'input.txt'
param.corpus = None
class CharRNNData(RNGDataFlow):
......@@ -182,6 +182,7 @@ if __name__ == '__main__':
parser_sample.add_argument('-t', '--temperature', type=float,
default=1, help='softmax temperature')
parser_train = subparsers.add_parser('train', help='train')
parser_train.add_argument('--corpus', help='corpus file', default='input.txt')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -192,6 +193,7 @@ if __name__ == '__main__':
sample(args.load, args.start, args.num)
sys.exit()
else:
param.corpus = args.corpus
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
......
......@@ -156,6 +156,7 @@ class InferenceRunner(InferenceRunnerBase):
for inf in self.infs:
inf.before_epoch()
self._input_source.reset_state()
# iterate over the data, and run the hooked session
with _inference_context(), \
tqdm.tqdm(total=self._size, **get_tqdm_kwargs()) as pbar:
......@@ -265,6 +266,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
total = self._size
nr_tower = len(self._gpus)
self._input_source.reset_state()
with _inference_context():
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower:
......
......@@ -196,8 +196,10 @@ class QueueInput(FeedfreeInput):
self._dequeue_op = self.queue.dequeue(name='dequeue_for_reset')
def _reset_state(self):
if self._started: # do not try to clear the queue if there is nothing
def refill_queue(self):
"""
Clear the queue, then call dataflow.get_data() again and fill into the queue.
"""
self.thread.pause() # pause enqueue
opt = tf.RunOptions()
......@@ -213,8 +215,6 @@ class QueueInput(FeedfreeInput):
# reset dataflow, start thread
self.thread.reinitialize_dataflow()
self.thread.resume()
else:
self._started = True
def _create_ema_callback(self):
"""
......
......@@ -137,7 +137,7 @@ class InputSource(object):
Initialize/reinitialize this InputSource.
Must be called under a default session.
For training, it will get called by the trainer in `before_train` callbacks.
For training, it will get called once by the trainer in `before_train` callbacks.
For inference, the :class:`InferenceRunner` will call this method each time it is triggered.
"""
self._reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment