Commit c2661527 authored by Yuxin Wu's avatar Yuxin Wu

update enhancenet to be able to use zip directly.

parent 6fb2261e
@@ -26,6 +26,7 @@ See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unaware
 | [Spatial Transformer Networks on MNIST addition](SpatialTransformer) | reproduce paper |
 | [Visualize CNN saliency maps](Saliency) | visually reproduce |
 | [Similarity learning on MNIST](SimilarityLearning) | |
+| Single-image super-resolution using [EnhanceNet](SuperResolution) | visually reproduce |
 | Learn steering filters with [Dynamic Filter Networks](DynamicFilterNetwork) | visually reproduce |
 | Load a pre-trained [AlexNet](load-alexnet.py), [VGG16](load-vgg16.py), or [Convolutional Pose Machines](ConvolutionalPoseMachines/) | |
...
@@ -20,15 +20,17 @@ produce a 4x resolution image using different loss functions.
 ```bash
 wget http://images.cocodataset.org/zips/train2017.zip
-python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create
 wget http://models.tensorpack.com/caffe/vgg19.npy
 ```
 2. Train an EnhanceNet-PAT using:
 ```bash
-python enet-pat.py --vgg19 /path/to/vgg19.npy --lmdb train2017.lmdb
+python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.zip
+# or: convert to an lmdb first and train with lmdb:
+python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create
+python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.lmdb
 ```
 Training is highly unstable and does not often give results as good as the pretrained model.
...
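For orientation: the `--create` step above converts the zip into an LMDB. A minimal sketch of what that conversion plausibly amounts to, assuming `data_sampler.py` uses tensorpack's `dftools.dump_dataflow_to_lmdb` (it does import `dftools`, as the next diff shows) and stores the encoded JPEG bytes as-is, since the training script later decodes them with `ImageDecode`:

```python
# Sketch only: what `data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create`
# plausibly boils down to. Paths are placeholders.
from tensorpack import dftools
from data_sampler import ImageDataFromZIPFile

ds = ImageDataFromZIPFile('train2017.zip')           # one datapoint per .jpg: the raw encoded bytes
dftools.dump_dataflow_to_lmdb(ds, 'train2017.lmdb')  # serialize every datapoint into the LMDB
```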
@@ -3,28 +3,33 @@ import os
 import argparse
 import numpy as np
 import zipfile
-import random

 from tensorpack import RNGDataFlow, MapDataComponent, dftools


 class ImageDataFromZIPFile(RNGDataFlow):
     """ Produce images read from a list of zip files. """
-    def __init__(self, zip_file, shuffle=False, max_files=None):
+    def __init__(self, zip_file, shuffle=False):
         """
         Args:
             zip_file (list): list of zip file paths.
         """
         assert os.path.isfile(zip_file)
+        self._file = zip_file
         self.shuffle = shuffle
-        self.max = max_files
+        self.open()
+
+    def open(self):
         self.archivefiles = []
-        archive = zipfile.ZipFile(zip_file)
+        archive = zipfile.ZipFile(self._file)
         imagesInArchive = archive.namelist()
         for img_name in imagesInArchive:
             if img_name.endswith('.jpg'):
                 self.archivefiles.append((archive, img_name))
-        if self.max is None:
-            self.max = self.size()
+
+    def reset_state(self):
+        super(ImageDataFromZIPFile, self).reset_state()
+        # Seems necessary to reopen the zip file in forked processes.
+        self.open()

     def size(self):
         return len(self.archivefiles)
@@ -32,7 +37,6 @@ class ImageDataFromZIPFile(RNGDataFlow):
     def get_data(self):
         if self.shuffle:
             self.rng.shuffle(self.archivefiles)
-            self.archivefiles = random.sample(self.archivefiles, self.max)
         for archive in self.archivefiles:
             im_data = archive[0].read(archive[1])
             im_data = np.asarray(bytearray(im_data), dtype='uint8')
...
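To make the new dataflow concrete, here is a minimal usage sketch (paths are placeholders; it assumes `ImageDecode` from the same `data_sampler.py` decodes the raw bytes via OpenCV, which is why `enet-pat.py` applies it right after this class):

```python
from data_sampler import ImageDataFromZIPFile, ImageDecode

ds = ImageDataFromZIPFile('train2017.zip', shuffle=True)
ds = ImageDecode(ds, index=0)   # decode the JPEG bytes in component 0 into an image array
ds.reset_state()                # seeds the RNG and, per the comment above, reopens the zip
for dp in ds.get_data():
    img = dp[0]                 # HWC uint8 image
    print(img.shape)
    break
```

Reopening the archive in `reset_state()` matters because `PrefetchDataZMQ` forks worker processes, and a `zipfile.ZipFile` handle inherited across a fork is not safe to share.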
@@ -13,7 +13,9 @@ from tensorpack import *
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils import logger
-from data_sampler import ImageDecode
+from data_sampler import (
+    ImageDecode, ImageDataFromZIPFile,
+    RejectTooSmallImages, CenterSquareResize)
 from GAN import SeparateGANTrainer, GANModelDesc

 Reduction = tf.losses.Reduction
@@ -236,14 +238,22 @@ def apply(model_path, lowres_path="", output_path='.'):
     cv2.imwrite(os.path.join(output_path, "baseline.png"), baseline)


-def get_data(lmdb):
-    ds = LMDBDataPoint(lmdb, shuffle=True)
-    ds = ImageDecode(ds, index=0)
+def get_data(file_name):
+    if file_name.endswith('.lmdb'):
+        ds = LMDBDataPoint(file_name, shuffle=True)
+        ds = ImageDecode(ds, index=0)
+    elif file_name.endswith('.zip'):
+        ds = ImageDataFromZIPFile(file_name, shuffle=True)
+        ds = ImageDecode(ds, index=0)
+        ds = RejectTooSmallImages(ds, index=0)
+        ds = CenterSquareResize(ds, index=0)
+    else:
+        raise ValueError("Unknown file format " + file_name)
     augmentors = [imgaug.RandomCrop(128),
                   imgaug.Flip(horiz=True)]
     ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
     ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
-    ds = PrefetchDataZMQ(ds, 8)
+    ds = PrefetchDataZMQ(ds, 3)
     ds = BatchData(ds, BATCH_SIZE)
     return ds
@@ -253,7 +263,8 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--load', help='load model')
     parser.add_argument('--apply', action='store_true')
-    parser.add_argument('--lmdb', help='path to lmdb_file')
+    parser.add_argument('--data', help='path to the dataset. '
+                        'Can be either a LMDB generated by `data_sampler.py` or the original COCO zip.')
     parser.add_argument('--vgg19', help='load model', default="")
     parser.add_argument('--lowres', help='low resolution image as input', default="", type=str)
     parser.add_argument('--output', help='directory for saving predicted high-res image', default=".", type=str)
@@ -276,7 +287,7 @@ if __name__ == '__main__':
     session_init = DictRestore(param_dict)

     nr_tower = max(get_nr_gpu(), 1)
-    data = QueueInput(get_data(args.lmdb))
+    data = QueueInput(get_data(args.data))
     model = Model()
     trainer = SeparateGANTrainer(data, model, d_period=3)
@@ -287,5 +298,5 @@ if __name__ == '__main__':
         ],
         session_init=session_init,
         steps_per_epoch=data.size() // 4,
-        max_epoch=2000
+        max_epoch=300
     )
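With the dispatch in place, the same `get_data()` accepts either format. A hypothetical smoke test (placeholder paths; `BATCH_SIZE` as defined in `enet-pat.py`):

```python
# Either input format should yield batches with the same structure.
ds = get_data('train2017.zip')     # stream JPEGs straight from the archive
# ds = get_data('train2017.lmdb')  # or use the pre-converted LMDB
ds.reset_state()
lowres, hires = next(ds.get_data())
print(lowres.shape, hires.shape)   # (BATCH_SIZE, 32, 32, 3) and (BATCH_SIZE, 128, 128, 3)
```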