Commit d04661e3 authored by Yuxin Wu's avatar Yuxin Wu

fix prefetch bug

parent b81c2263
......@@ -16,17 +16,12 @@ from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
A small cifar10 convnet model.
90% validation accuracy after 40k step.
"""
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = int(50000 * 0.4)
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
......@@ -134,7 +129,7 @@ def get_config():
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=200,
max_epoch=3,
)
if __name__ == '__main__':
......
......@@ -16,7 +16,6 @@ from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
CIFAR10-resnet example.
......@@ -45,7 +44,7 @@ class Model(ModelDesc):
def _get_cost(self, input_vars, is_training):
image, label = input_vars
image = image / 255.0
image = image / 128.0 - 1
def conv(name, l, channel, stride):
return Conv2D(name, l, channel, 3, stride=stride,
......@@ -117,10 +116,10 @@ class Model(ModelDesc):
# weight decay on all W of fc layers
wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
480000, 0.2, True)
wd_cost = wd_w * regularize_cost('.*/W', tf.nn.l2_loss)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor W
add_param_summary([('.*/W', ['histogram'])]) # monitor W
return tf.add_n([cost, wd_cost], name='cost')
def get_data(train_or_test):
......@@ -146,8 +145,6 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 3, 2)
return ds
def get_config():
# prepare dataset
dataset_train = get_data('train')
......@@ -170,7 +167,7 @@ def get_config():
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
session_config=sess_config,
model=Model(n=18),
model=Model(n=30),
step_per_epoch=step_per_epoch,
max_epoch=500,
)
......
......@@ -9,9 +9,6 @@ from ..utils.concurrency import ensure_procs_terminate
__all__ = ['PrefetchData']
class Sentinel:
pass
class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue):
"""
......@@ -24,11 +21,9 @@ class PrefetchProcess(multiprocessing.Process):
def run(self):
self.ds.reset_state()
try:
while True:
for dp in self.ds.get_data():
self.queue.put(dp)
finally:
self.queue.put(Sentinel())
class PrefetchData(ProxyDataFlow):
"""
......@@ -52,17 +47,11 @@ class PrefetchData(ProxyDataFlow):
x.start()
def get_data(self):
end_cnt = 0
tot_cnt = 0
while True:
dp = self.queue.get()
if isinstance(dp, Sentinel):
end_cnt += 1
if end_cnt == self.nr_proc:
break
continue
tot_cnt += 1
yield dp
tot_cnt += 1
if tot_cnt == self._size:
break
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment