Commit 20becf84 authored by ppwwyyxx's avatar ppwwyyxx

rollback

parent dd1ac6b0
...@@ -47,14 +47,14 @@ def get_model(inputs, is_training): ...@@ -47,14 +47,14 @@ def get_model(inputs, is_training):
image, label = inputs image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel image = tf.expand_dims(image, 3) # add a single channel
if is_training: #if is_training: # slow
# augmentations ## augmentations
image, label = tf.train.slice_input_producer( #image, label = tf.train.slice_input_producer(
[image, label], name='slice_queue') #[image, label], name='slice_queue')
image = tf.image.random_brightness(image, 0.1) #image = tf.image.random_brightness(image, 0.1)
image, label = tf.train.shuffle_batch( #image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE, #[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=2, enqueue_many=False) #num_threads=2, enqueue_many=False)
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5) conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
pool0 = MaxPooling('pool0', conv0, 2) pool0 = MaxPooling('pool0', conv0, 2)
...@@ -100,9 +100,9 @@ def get_config(): ...@@ -100,9 +100,9 @@ def get_config():
IMAGE_SIZE = 28 IMAGE_SIZE = 28
dataset_train = Mnist('train') dataset_train = BatchData(Mnist('train'), 128)
dataset_test = BatchData(Mnist('test'), 256, remainder=True) dataset_test = BatchData(Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() / BATCH_SIZE step_per_epoch = dataset_train.size()
#step_per_epoch = 20 #step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20) #dataset_test = FixedSizeData(dataset_test, 20)
......
...@@ -39,10 +39,7 @@ class EnqueueThread(threading.Thread): ...@@ -39,10 +39,7 @@ class EnqueueThread(threading.Thread):
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = {} feed = dict(izip(self.input_vars, dp))
for var, data in izip(self.input_vars, dp):
data = expand_dim_if_necessary(var, data)
feed[var] = data
self.sess.run([self.op], feed_dict=feed) self.sess.run([self.op], feed_dict=feed)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment