Commit bbaf8d12 authored by Yuxin Wu

use Variable.load to avoid assign ops

parent d97c0081
# Efficient DataFlow
This tutorial gives an overview of how to build an efficient DataFlow, using the ImageNet
dataset as an example.
Our goal in the end is to have
a generator which yields ImageNet datapoints (after proper preprocessing) as fast as possible.
+Since it is simply a generator interface, you can use the DataFlow in other frameworks (e.g. Keras)
+or your own code as well.
We use the ILSVRC12 training set, which contains 1.28 million images.
Following the [ResNet example](../examples/ResNet), our pre-processing needs images in their original resolution,
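As an aside on the generator interface mentioned above, here is a minimal sketch of how any DataFlow is consumed (using the tensorpack API of this era; the path is a placeholder):

```
from tensorpack import *

ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds.reset_state()                  # required once, before iteration starts
for key, value in ds.get_data():  # get_data() is a plain Python generator
    pass                          # feed each datapoint to Keras or any custom loop
```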
@@ -120,7 +121,7 @@ It will generate a database file of 140G. We build a DataFlow to read the LMDB file
```
from tensorpack import *
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
TestDataSpeed(ds).start_test()
```
Depending on whether the OS has cached the file for you (and how large the RAM is), the above script
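(If you want to measure the uncached speed reproducibly, one option on Linux is to drop the page cache first, e.g. `sudo sync && sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'`; this is a general suggestion, not part of the tutorial.)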
@@ -134,7 +135,7 @@ As a reference, on Samsung SSD 850, the uncached speed is about 16it/s.
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds = LocallyShuffleData(ds, 50000)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
```
Instead of shuffling all the training data in every epoch (which would require random read),
the added line above maintains a buffer of datapoints and shuffles them once in a while.
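In spirit, the buffer behaves like this simplified sketch (illustrative only, not the actual `LocallyShuffleData` implementation):

```
import random

def locally_shuffle(gen, buffer_size, rng=random.Random(0)):
    """Yield items from `gen` in a locally-shuffled order, using a bounded buffer."""
    buf = []
    for dp in gen:
        if len(buf) < buffer_size:
            buf.append(dp)            # fill the buffer first
        else:
            idx = rng.randrange(buffer_size)
            yield buf[idx]            # emit a random buffered item
            buf[idx] = dp             # replace it with the incoming one
    rng.shuffle(buf)                  # drain whatever is left at the end
    for dp in buf:
        yield dp
```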
@@ -153,7 +154,7 @@ Then we add necessary transformations:
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = BatchData(ds, 256)
```
-1. `LMDBData` deserialized the datapoints (from string to [jpeg_string, label])
+1. `LMDBData` deserializes the datapoints (from string to [jpeg_string, label])
2. Use opencv to decode the first component into ndarray
3. Apply augmentations to the ndarray
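Steps 2 and 3 above can be written with `MapDataComponent` roughly as follows (a sketch; `lots_of_augmentors` stands for whatever augmentor list you use, and step 1 happens in lines omitted from this hunk):

```
import cv2
import numpy as np
from tensorpack import *

ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
# (step 1, deserialization into [jpeg_string, label], happens in omitted lines)
# step 2: decode the jpeg string (component 0) into an ndarray
ds = MapDataComponent(
    ds, lambda s: cv2.imdecode(np.frombuffer(s, np.uint8), cv2.IMREAD_COLOR), 0)
# step 3: apply the augmentations to the decoded image
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = BatchData(ds, 256)
```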
@@ -172,7 +173,7 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like this:
ds = BatchData(ds, 256)
```
-Since we are reading the database sequentially, have multiple identical instances of the
+Since we are reading the database sequentially, having multiple identical instances of the
underlying DataFlow will result in biased data distribution. Therefore we use `PrefetchData` to
launch the underlying DataFlow in one independent process, and only parallelize the transformations.
(`PrefetchDataZMQ` is faster but not fork-safe, so the first prefetch has to be `PrefetchData`. This is [issue#138])
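Put together, the parallel pipeline described above looks roughly like this (a sketch; the prefetch sizes are illustrative, and `PrefetchData(df, nr_prefetch, nr_proc)` is the argument order in tensorpack of this era):

```
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds = LocallyShuffleData(ds, 50000)
ds = PrefetchData(ds, 5000, 1)    # run the sequential reader in ONE extra process
# ... deserialize / imdecode / augment as in the previous section ...
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = PrefetchDataZMQ(ds, 25)      # parallelize only the transformations
ds = BatchData(ds, 256)
```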
...
@@ -20,9 +20,9 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
""" """
DCGAN on CelebA dataset. DCGAN on CelebA dataset.
The original code (dcgan.torch) uses kernel_shape=4, but I found the difference not significant.
1. Download the 'aligned&cropped' version of CelebA dataset. 1. Download the 'aligned&cropped' version of CelebA dataset
(or just use any directory of jpg files).
2. Start training: 2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/ ./DCGAN-CelebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model: 3. Visualize samples of a trained model:
...
@@ -132,7 +132,6 @@ def get_data(train_or_test):
    crop 8%~100% of the original image
    See `Going Deeper with Convolutions` by Google.
    """
    def _augment(self, img, _):
        h, w = img.shape[:2]
        area = h * w
...
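For context, `_augment` above typically continues with the GoogleNet-style area/aspect sampling; here is a standalone sketch of that technique (not the exact code from this file):

```
import numpy as np

def googlenet_style_crop(img, rng=np.random):
    """Sample a crop covering 8%~100% of the area, aspect ratio in [3/4, 4/3]."""
    h, w = img.shape[:2]
    area = h * w
    for _ in range(10):                       # retry a few times, then fall back
        target_area = rng.uniform(0.08, 1.0) * area
        aspect_ratio = rng.uniform(0.75, 1.333)
        ww = int(np.sqrt(target_area * aspect_ratio) + 0.5)
        hh = int(np.sqrt(target_area / aspect_ratio) + 0.5)
        if rng.uniform() < 0.5:
            ww, hh = hh, ww
        if hh <= h and ww <= w:
            x1 = 0 if w == ww else rng.randint(0, w - ww)
            y1 = 0 if h == hh else rng.randint(0, h - hh)
            return img[y1:y1 + hh, x1:x1 + ww]
    return img                                # fall back to the full image
```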
@@ -74,13 +74,9 @@ class GraphVarParam(HyperParam):
        else:
            raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
-        self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
-                                         name=self._readable_name + '_feed')
-        self.assign_op = self.var.assign(self.val_holder)

    def set_value(self, v):
        """ Assign the variable a new value. """
-        self.assign_op.eval(feed_dict={self.val_holder: v})
+        self.var.load(v)

    def get_value(self):
        """ Evaluate the variable. """
...
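Context for this hunk: the old code pre-built a placeholder plus an assign op per hyperparameter just to set a value; `Variable.load` feeds the new value through the variable's existing initializer, so no extra ops are added to the graph. A minimal standalone sketch of the new pattern (TF 1.x API):

```
import tensorflow as tf

lr = tf.Variable(0.1, name='learning_rate')
sess = tf.Session()
sess.run(tf.global_variables_initializer())

lr.load(0.01, session=sess)   # set a new value; no assign op is added to the graph
assert abs(sess.run(lr) - 0.01) < 1e-7
```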
@@ -146,7 +146,8 @@ class ILSVRC12(RNGDataFlow):
            mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
            find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
        """
-        assert name in ['train', 'test', 'val']
+        assert name in ['train', 'test', 'val'], name
+        assert os.path.isdir(dir), dir
        self.full_dir = os.path.join(dir, name)
        self.name = name
        assert os.path.isdir(self.full_dir), self.full_dir
...
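A quick usage sketch of what the strengthened assertions report (paths are placeholders):

```
from tensorpack.dataflow.dataset import ILSVRC12

ds = ILSVRC12('/path/to/ILSVRC12', 'train')
# a wrong split name now fails with the offending value, e.g. AssertionError: trainn
# a missing base directory fails with its path,  e.g. AssertionError: /path/to/ILSVRC12
```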
@@ -54,14 +54,10 @@ class SessionUpdate(object):
            vars_to_update: a collection of variables to update
        """
        self.sess = sess
-        self.assign_ops = defaultdict(list)
+        self.name_map = defaultdict(list)
        for v in vars_to_update:
-            # p = tf.placeholder(v.dtype, shape=v.get_shape())
-            with tf.device('/cpu:0'):
-                p = tf.placeholder(v.dtype)
-            savename = get_savename_from_varname(v.name)
-            # multiple vars might share one savename
-            self.assign_ops[savename].append((p, v, v.assign(p)))
+            savename = get_savename_from_varname(v.name)
+            self.name_map[savename].append(v)
    def update(self, prms):
        """
@@ -70,8 +66,8 @@ class SessionUpdate(object):
        Any name in prms must be in the graph and in vars_to_update.
        """
        for name, value in six.iteritems(prms):
-            assert name in self.assign_ops
-            for p, v, op in self.assign_ops[name]:
+            assert name in self.name_map
+            for v in self.name_map[name]:
                varshape = tuple(v.get_shape().as_list())
                if varshape != value.shape:
                    # TODO only allow reshape when shape different by empty axis
@@ -79,7 +75,7 @@ class SessionUpdate(object):
"{}: {}!={}".format(name, varshape, value.shape) "{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name)) logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape) value = value.reshape(varshape)
self.sess.run(op, feed_dict={p: value}) v.load(value, session=self.sess)
def dump_session_params(path): def dump_session_params(path):
...
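A usage sketch of the updated `SessionUpdate` (the savename `'conv1/W'` is hypothetical; variables with that savename must already exist in the graph):

```
import numpy as np
import tensorflow as tf

sess = tf.Session()
up = SessionUpdate(sess, tf.global_variables())   # SessionUpdate is defined in this file
up.update({'conv1/W': np.zeros((3, 3, 3, 64), dtype='float32')})
# after this commit, every variable sharing the savename is set via Variable.load,
# so no placeholder/assign ops are added to the graph
```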