Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c2661527
Commit
c2661527
authored
Jan 05, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update enhancenet to be able to use zip directly.
parent
6fb2261e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
18 deletions
+36
-18
examples/README.md
examples/README.md
+1
-0
examples/SuperResolution/README.md
examples/SuperResolution/README.md
+5
-3
examples/SuperResolution/data_sampler.py
examples/SuperResolution/data_sampler.py
+11
-7
examples/SuperResolution/enet-pat.py
examples/SuperResolution/enet-pat.py
+19
-8
No files found.
examples/README.md
View file @
c2661527
...
@@ -26,6 +26,7 @@ See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unaware
...
@@ -26,6 +26,7 @@ See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unaware
|
[
Spatial Transformer Networks on MNIST addition
](
SpatialTransformer
)
| reproduce paper |
|
[
Spatial Transformer Networks on MNIST addition
](
SpatialTransformer
)
| reproduce paper |
|
[
Visualize CNN saliency maps
](
Saliency
)
| visually reproduce |
|
[
Visualize CNN saliency maps
](
Saliency
)
| visually reproduce |
|
[
Similarity learning on MNIST
](
SimilarityLearning
)
| |
|
[
Similarity learning on MNIST
](
SimilarityLearning
)
| |
| Single-image super-resolution using
[
EnhanceNet
](
SuperResolution
)
| visually reproduce |
| Learn steering filters with
[
Dynamic Filter Networks
](
DynamicFilterNetwork
)
| visually reproduce |
| Learn steering filters with
[
Dynamic Filter Networks
](
DynamicFilterNetwork
)
| visually reproduce |
| Load a pre-trained
[
AlexNet
](
load-alexnet.py
)
,
[
VGG16
](
load-vgg16.py
)
, or
[
Convolutional Pose Machines
](
ConvolutionalPoseMachines/
)
| |
| Load a pre-trained
[
AlexNet
](
load-alexnet.py
)
,
[
VGG16
](
load-vgg16.py
)
, or
[
Convolutional Pose Machines
](
ConvolutionalPoseMachines/
)
| |
...
...
examples/SuperResolution/README.md
View file @
c2661527
...
@@ -20,15 +20,17 @@ produce a 4x resolution image using different loss functions.
...
@@ -20,15 +20,17 @@ produce a 4x resolution image using different loss functions.
```
bash
```
bash
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/train2017.zip
python data_sampler.py
--lmdb
train2017.lmdb
--input
train2017.zip
--create
wget http://models.tensorpack.com/caffe/vgg19.npy
wget http://models.tensorpack.com/caffe/vgg19.npy
```
```
2.
Train an EnhanceNet-PAT using:
2.
Train an EnhanceNet-PAT using:
```
bash
```
bash
python enet-pat.py
--vgg19
/path/to/vgg19.npy
--lmdb
train2017.lmdb
python enet-pat.py
--vgg19
/path/to/vgg19.npy
--data
train2017.zip
# or: convert to an lmdb first and train with lmdb:
python data_sampler.py
--lmdb
train2017.lmdb
--input
train2017.zip
--create
python enet-pat.py
--vgg19
/path/to/vgg19.npy
--data
train2017.lmdb
```
```
Training is highly unstable and does not often give results as good as the pretrained model.
Training is highly unstable and does not often give results as good as the pretrained model.
...
...
examples/SuperResolution/data_sampler.py
View file @
c2661527
...
@@ -3,28 +3,33 @@ import os
...
@@ -3,28 +3,33 @@ import os
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
import
zipfile
import
zipfile
import
random
from
tensorpack
import
RNGDataFlow
,
MapDataComponent
,
dftools
from
tensorpack
import
RNGDataFlow
,
MapDataComponent
,
dftools
class
ImageDataFromZIPFile
(
RNGDataFlow
):
class
ImageDataFromZIPFile
(
RNGDataFlow
):
""" Produce images read from a list of zip files. """
""" Produce images read from a list of zip files. """
def
__init__
(
self
,
zip_file
,
shuffle
=
False
,
max_files
=
None
):
def
__init__
(
self
,
zip_file
,
shuffle
=
False
):
"""
"""
Args:
Args:
zip_file (list): list of zip file paths.
zip_file (list): list of zip file paths.
"""
"""
assert
os
.
path
.
isfile
(
zip_file
)
assert
os
.
path
.
isfile
(
zip_file
)
self
.
_file
=
zip_file
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
self
.
max
=
max_files
self
.
open
()
def
open
(
self
):
self
.
archivefiles
=
[]
self
.
archivefiles
=
[]
archive
=
zipfile
.
ZipFile
(
zip
_file
)
archive
=
zipfile
.
ZipFile
(
self
.
_file
)
imagesInArchive
=
archive
.
namelist
()
imagesInArchive
=
archive
.
namelist
()
for
img_name
in
imagesInArchive
:
for
img_name
in
imagesInArchive
:
if
img_name
.
endswith
(
'.jpg'
):
if
img_name
.
endswith
(
'.jpg'
):
self
.
archivefiles
.
append
((
archive
,
img_name
))
self
.
archivefiles
.
append
((
archive
,
img_name
))
if
self
.
max
is
None
:
self
.
max
=
self
.
size
()
def
reset_state
(
self
):
super
(
ImageDataFromZIPFile
,
self
)
.
reset_state
()
# Seems necessary to reopen the zip file in forked processes.
self
.
open
()
def
size
(
self
):
def
size
(
self
):
return
len
(
self
.
archivefiles
)
return
len
(
self
.
archivefiles
)
...
@@ -32,7 +37,6 @@ class ImageDataFromZIPFile(RNGDataFlow):
...
@@ -32,7 +37,6 @@ class ImageDataFromZIPFile(RNGDataFlow):
def
get_data
(
self
):
def
get_data
(
self
):
if
self
.
shuffle
:
if
self
.
shuffle
:
self
.
rng
.
shuffle
(
self
.
archivefiles
)
self
.
rng
.
shuffle
(
self
.
archivefiles
)
self
.
archivefiles
=
random
.
sample
(
self
.
archivefiles
,
self
.
max
)
for
archive
in
self
.
archivefiles
:
for
archive
in
self
.
archivefiles
:
im_data
=
archive
[
0
]
.
read
(
archive
[
1
])
im_data
=
archive
[
0
]
.
read
(
archive
[
1
])
im_data
=
np
.
asarray
(
bytearray
(
im_data
),
dtype
=
'uint8'
)
im_data
=
np
.
asarray
(
bytearray
(
im_data
),
dtype
=
'uint8'
)
...
...
examples/SuperResolution/enet-pat.py
View file @
c2661527
...
@@ -13,7 +13,9 @@ from tensorpack import *
...
@@ -13,7 +13,9 @@ from tensorpack import *
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils
import
logger
from
tensorpack.utils
import
logger
from
data_sampler
import
ImageDecode
from
data_sampler
import
(
ImageDecode
,
ImageDataFromZIPFile
,
RejectTooSmallImages
,
CenterSquareResize
)
from
GAN
import
SeparateGANTrainer
,
GANModelDesc
from
GAN
import
SeparateGANTrainer
,
GANModelDesc
Reduction
=
tf
.
losses
.
Reduction
Reduction
=
tf
.
losses
.
Reduction
...
@@ -236,14 +238,22 @@ def apply(model_path, lowres_path="", output_path='.'):
...
@@ -236,14 +238,22 @@ def apply(model_path, lowres_path="", output_path='.'):
cv2
.
imwrite
(
os
.
path
.
join
(
output_path
,
"baseline.png"
),
baseline
)
cv2
.
imwrite
(
os
.
path
.
join
(
output_path
,
"baseline.png"
),
baseline
)
def
get_data
(
lmdb
):
def
get_data
(
file_name
):
ds
=
LMDBDataPoint
(
lmdb
,
shuffle
=
True
)
if
file_name
.
endswith
(
'.lmdb'
):
ds
=
ImageDecode
(
ds
,
index
=
0
)
ds
=
LMDBDataPoint
(
file_name
,
shuffle
=
True
)
ds
=
ImageDecode
(
ds
,
index
=
0
)
elif
file_name
.
endswith
(
'.zip'
):
ds
=
ImageDataFromZIPFile
(
file_name
,
shuffle
=
True
)
ds
=
ImageDecode
(
ds
,
index
=
0
)
ds
=
RejectTooSmallImages
(
ds
,
index
=
0
)
ds
=
CenterSquareResize
(
ds
,
index
=
0
)
else
:
raise
ValueError
(
"Unknown file format "
+
file_name
)
augmentors
=
[
imgaug
.
RandomCrop
(
128
),
augmentors
=
[
imgaug
.
RandomCrop
(
128
),
imgaug
.
Flip
(
horiz
=
True
)]
imgaug
.
Flip
(
horiz
=
True
)]
ds
=
AugmentImageComponent
(
ds
,
augmentors
,
index
=
0
,
copy
=
True
)
ds
=
AugmentImageComponent
(
ds
,
augmentors
,
index
=
0
,
copy
=
True
)
ds
=
MapData
(
ds
,
lambda
x
:
[
cv2
.
resize
(
x
[
0
],
(
32
,
32
),
interpolation
=
cv2
.
INTER_CUBIC
),
x
[
0
]])
ds
=
MapData
(
ds
,
lambda
x
:
[
cv2
.
resize
(
x
[
0
],
(
32
,
32
),
interpolation
=
cv2
.
INTER_CUBIC
),
x
[
0
]])
ds
=
PrefetchDataZMQ
(
ds
,
8
)
ds
=
PrefetchDataZMQ
(
ds
,
3
)
ds
=
BatchData
(
ds
,
BATCH_SIZE
)
ds
=
BatchData
(
ds
,
BATCH_SIZE
)
return
ds
return
ds
...
@@ -253,7 +263,8 @@ if __name__ == '__main__':
...
@@ -253,7 +263,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--apply'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--apply'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--lmdb'
,
help
=
'path to lmdb_file'
)
parser
.
add_argument
(
'--data'
,
help
=
'path to the dataset. '
'Can be either a LMDB generated by `data_sampler.py` or the original COCO zip.'
)
parser
.
add_argument
(
'--vgg19'
,
help
=
'load model'
,
default
=
""
)
parser
.
add_argument
(
'--vgg19'
,
help
=
'load model'
,
default
=
""
)
parser
.
add_argument
(
'--lowres'
,
help
=
'low resolution image as input'
,
default
=
""
,
type
=
str
)
parser
.
add_argument
(
'--lowres'
,
help
=
'low resolution image as input'
,
default
=
""
,
type
=
str
)
parser
.
add_argument
(
'--output'
,
help
=
'directory for saving predicted high-res image'
,
default
=
"."
,
type
=
str
)
parser
.
add_argument
(
'--output'
,
help
=
'directory for saving predicted high-res image'
,
default
=
"."
,
type
=
str
)
...
@@ -276,7 +287,7 @@ if __name__ == '__main__':
...
@@ -276,7 +287,7 @@ if __name__ == '__main__':
session_init
=
DictRestore
(
param_dict
)
session_init
=
DictRestore
(
param_dict
)
nr_tower
=
max
(
get_nr_gpu
(),
1
)
nr_tower
=
max
(
get_nr_gpu
(),
1
)
data
=
QueueInput
(
get_data
(
args
.
lmdb
))
data
=
QueueInput
(
get_data
(
args
.
data
))
model
=
Model
()
model
=
Model
()
trainer
=
SeparateGANTrainer
(
data
,
model
,
d_period
=
3
)
trainer
=
SeparateGANTrainer
(
data
,
model
,
d_period
=
3
)
...
@@ -287,5 +298,5 @@ if __name__ == '__main__':
...
@@ -287,5 +298,5 @@ if __name__ == '__main__':
],
],
session_init
=
session_init
,
session_init
=
session_init
,
steps_per_epoch
=
data
.
size
()
//
4
,
steps_per_epoch
=
data
.
size
()
//
4
,
max_epoch
=
20
00
max_epoch
=
3
00
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment