Commit c2661527 authored by Yuxin Wu

update enhancenet to be able to use the zip file directly.

parent 6fb2261e
@@ -26,6 +26,7 @@ See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unaware
| [Spatial Transformer Networks on MNIST addition](SpatialTransformer) | reproduce paper |
| [Visualize CNN saliency maps](Saliency) | visually reproduce |
| [Similarity learning on MNIST](SimilarityLearning) | |
| Single-image super-resolution using [EnhanceNet](SuperResolution) | visually reproduce |
| Learn steering filters with [Dynamic Filter Networks](DynamicFilterNetwork) | visually reproduce |
| Load a pre-trained [AlexNet](load-alexnet.py), [VGG16](load-vgg16.py), or [Convolutional Pose Machines](ConvolutionalPoseMachines/) | |
@@ -20,15 +20,17 @@ produce a 4x resolution image using different loss functions.
```bash
wget http://images.cocodataset.org/zips/train2017.zip
python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create
wget http://models.tensorpack.com/caffe/vgg19.npy
```
2. Train an EnhanceNet-PAT using:
```bash
python enet-pat.py --vgg19 /path/to/vgg19.npy --lmdb train2017.lmdb
python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.zip
# or: convert to an lmdb first and train with lmdb:
python data_sampler.py --lmdb train2017.lmdb --input train2017.zip --create
python enet-pat.py --vgg19 /path/to/vgg19.npy --data train2017.lmdb
```
Training is highly unstable and often does not give results as good as the pretrained model.
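For reference, a trained checkpoint can then be applied to a single low-resolution image through the `--apply`, `--load`, `--lowres` and `--output` flags of `enet-pat.py`; a minimal sketch with placeholder paths:

```bash
python enet-pat.py --apply \
    --load /path/to/checkpoint \
    --lowres lowres_input.png \
    --output out_dir/
# writes the prediction and a bicubic "baseline.png" into out_dir/
```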
@@ -3,28 +3,33 @@ import os
import argparse
import numpy as np
import zipfile
import random
from tensorpack import RNGDataFlow, MapDataComponent, dftools
class ImageDataFromZIPFile(RNGDataFlow):
    """ Produce images read from a zip file. """
    def __init__(self, zip_file, shuffle=False, max_files=None):
    def __init__(self, zip_file, shuffle=False):
        """
        Args:
            zip_file (str): path to the zip file.
        """
        assert os.path.isfile(zip_file)
        self._file = zip_file
        self.shuffle = shuffle
        self.max = max_files
        self.open()

    def open(self):
        self.archivefiles = []
        archive = zipfile.ZipFile(zip_file)
        archive = zipfile.ZipFile(self._file)
        imagesInArchive = archive.namelist()
        for img_name in imagesInArchive:
            if img_name.endswith('.jpg'):
                self.archivefiles.append((archive, img_name))
        if self.max is None:
            self.max = self.size()

    def reset_state(self):
        super(ImageDataFromZIPFile, self).reset_state()
        # Seems necessary to reopen the zip file in forked processes.
        self.open()

    def size(self):
        return len(self.archivefiles)
@@ -32,7 +37,6 @@ class ImageDataFromZIPFile(RNGDataFlow):
    def get_data(self):
        if self.shuffle:
            self.rng.shuffle(self.archivefiles)
        self.archivefiles = random.sample(self.archivefiles, self.max)
        for archive in self.archivefiles:
            im_data = archive[0].read(archive[1])
            im_data = np.asarray(bytearray(im_data), dtype='uint8')
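For illustration only (not part of this commit), a minimal usage sketch of the new dataflow, assuming the COCO `train2017.zip` from the README and the existing `ImageDecode` wrapper in `data_sampler.py`:

```python
from data_sampler import ImageDataFromZIPFile, ImageDecode

df = ImageDataFromZIPFile('train2017.zip', shuffle=True)  # yields [encoded JPEG bytes] per image
df = ImageDecode(df, index=0)                             # decode component 0 into an image array
df.reset_state()                                          # required before iterating any DataFlow
for dp in df.get_data():
    print(dp[0].shape)                                    # e.g. (h, w, 3), dtype uint8
    break
```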
@@ -13,7 +13,9 @@ from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from data_sampler import ImageDecode
from data_sampler import (
    ImageDecode, ImageDataFromZIPFile,
    RejectTooSmallImages, CenterSquareResize)
from GAN import SeparateGANTrainer, GANModelDesc
Reduction = tf.losses.Reduction
@@ -236,14 +238,22 @@ def apply(model_path, lowres_path="", output_path='.'):
    cv2.imwrite(os.path.join(output_path, "baseline.png"), baseline)
def get_data(lmdb):
    ds = LMDBDataPoint(lmdb, shuffle=True)
    ds = ImageDecode(ds, index=0)
def get_data(file_name):
    if file_name.endswith('.lmdb'):
        ds = LMDBDataPoint(file_name, shuffle=True)
        ds = ImageDecode(ds, index=0)
    elif file_name.endswith('.zip'):
        ds = ImageDataFromZIPFile(file_name, shuffle=True)
        ds = ImageDecode(ds, index=0)
        ds = RejectTooSmallImages(ds, index=0)
        ds = CenterSquareResize(ds, index=0)
    else:
        raise ValueError("Unknown file format " + file_name)
    augmentors = [imgaug.RandomCrop(128),
                  imgaug.Flip(horiz=True)]
    ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
    ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
    ds = PrefetchDataZMQ(ds, 8)
    ds = PrefetchDataZMQ(ds, 3)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
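A quick way to sanity-check the rewritten pipeline (a sketch, not part of the commit; `TestDataSpeed` is tensorpack's stock dataflow benchmark, and `train2017.zip` is assumed to be the archive from the README):

```python
from tensorpack import TestDataSpeed

df = get_data('train2017.zip')   # zip branch: decode, reject small images, center-crop, augment, batch
# each datapoint is [BATCH_SIZE low-res 32x32 crops, BATCH_SIZE high-res 128x128 crops]
TestDataSpeed(df, 100).start()
```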
@@ -253,7 +263,8 @@ if __name__ == '__main__':
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--apply', action='store_true')
    parser.add_argument('--lmdb', help='path to lmdb_file')
    parser.add_argument('--data', help='path to the dataset. '
                        'Can be either a LMDB generated by `data_sampler.py` or the original COCO zip.')
    parser.add_argument('--vgg19', help='load model', default="")
    parser.add_argument('--lowres', help='low resolution image as input', default="", type=str)
    parser.add_argument('--output', help='directory for saving predicted high-res image', default=".", type=str)
@@ -276,7 +287,7 @@ if __name__ == '__main__':
            session_init = DictRestore(param_dict)

        nr_tower = max(get_nr_gpu(), 1)
        data = QueueInput(get_data(args.lmdb))
        data = QueueInput(get_data(args.data))
        model = Model()
        trainer = SeparateGANTrainer(data, model, d_period=3)
@@ -287,5 +298,5 @@ if __name__ == '__main__':
            ],
            session_init=session_init,
            steps_per_epoch=data.size() // 4,
            max_epoch=2000
            max_epoch=300
        )