Commit d04661e3 authored by Yuxin Wu's avatar Yuxin Wu

fix prefetch bug

parent b81c2263
...@@ -16,17 +16,12 @@ from tensorpack.tfutils import * ...@@ -16,17 +16,12 @@ from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
""" """
A small cifar10 convnet model. A small cifar10 convnet model.
90% validation accuracy after 40k step. 90% validation accuracy after 40k step.
""" """
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = int(50000 * 0.4)
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, [None, 30, 30, 3], 'input'), return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
...@@ -134,7 +129,7 @@ def get_config(): ...@@ -134,7 +129,7 @@ def get_config():
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=200, max_epoch=3,
) )
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,7 +16,6 @@ from tensorpack.tfutils import * ...@@ -16,7 +16,6 @@ from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
""" """
CIFAR10-resnet example. CIFAR10-resnet example.
...@@ -45,7 +44,7 @@ class Model(ModelDesc): ...@@ -45,7 +44,7 @@ class Model(ModelDesc):
def _get_cost(self, input_vars, is_training): def _get_cost(self, input_vars, is_training):
image, label = input_vars image, label = input_vars
image = image / 255.0 image = image / 128.0 - 1
def conv(name, l, channel, stride): def conv(name, l, channel, stride):
return Conv2D(name, l, channel, 3, stride=stride, return Conv2D(name, l, channel, 3, stride=stride,
...@@ -117,10 +116,10 @@ class Model(ModelDesc): ...@@ -117,10 +116,10 @@ class Model(ModelDesc):
# weight decay on all W of fc layers # weight decay on all W of fc layers
wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(), wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
480000, 0.2, True) 480000, 0.2, True)
wd_cost = wd_w * regularize_cost('.*/W', tf.nn.l2_loss) wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor W add_param_summary([('.*/W', ['histogram'])]) # monitor W
return tf.add_n([cost, wd_cost], name='cost') return tf.add_n([cost, wd_cost], name='cost')
def get_data(train_or_test): def get_data(train_or_test):
...@@ -146,8 +145,6 @@ def get_data(train_or_test): ...@@ -146,8 +145,6 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 3, 2) ds = PrefetchData(ds, 3, 2)
return ds return ds
def get_config(): def get_config():
# prepare dataset # prepare dataset
dataset_train = get_data('train') dataset_train = get_data('train')
...@@ -170,7 +167,7 @@ def get_config(): ...@@ -170,7 +167,7 @@ def get_config():
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(n=18), model=Model(n=30),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=500, max_epoch=500,
) )
......
...@@ -9,9 +9,6 @@ from ..utils.concurrency import ensure_procs_terminate ...@@ -9,9 +9,6 @@ from ..utils.concurrency import ensure_procs_terminate
__all__ = ['PrefetchData'] __all__ = ['PrefetchData']
class Sentinel:
pass
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue):
""" """
...@@ -24,11 +21,9 @@ class PrefetchProcess(multiprocessing.Process): ...@@ -24,11 +21,9 @@ class PrefetchProcess(multiprocessing.Process):
def run(self): def run(self):
self.ds.reset_state() self.ds.reset_state()
try: while True:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.queue.put(dp) self.queue.put(dp)
finally:
self.queue.put(Sentinel())
class PrefetchData(ProxyDataFlow): class PrefetchData(ProxyDataFlow):
""" """
...@@ -52,17 +47,11 @@ class PrefetchData(ProxyDataFlow): ...@@ -52,17 +47,11 @@ class PrefetchData(ProxyDataFlow):
x.start() x.start()
def get_data(self): def get_data(self):
end_cnt = 0
tot_cnt = 0 tot_cnt = 0
while True: while True:
dp = self.queue.get() dp = self.queue.get()
if isinstance(dp, Sentinel):
end_cnt += 1
if end_cnt == self.nr_proc:
break
continue
tot_cnt += 1
yield dp yield dp
tot_cnt += 1
if tot_cnt == self._size: if tot_cnt == self._size:
break break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment