Shashank Suhas / seminar-breakout · Commit 81d5fbd8
Authored Dec 29, 2017 by Yuxin Wu

    [SuperResolution] closer to paper's settings (#541)

Parent: 6b10019e
Changes: 1 changed file, 38 additions (+) and 37 deletions (-)

examples/SuperResolution/enet-pat.py  (+38 / -37)
@@ -14,29 +14,27 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils import logger
 from data_sampler import ImageDecode
-from GAN import MultiGPUGANTrainer, GANModelDesc
+from GAN import SeparateGANTrainer, GANModelDesc
 
 Reduction = tf.losses.Reduction
 
-BATCH_SIZE = 6
+BATCH_SIZE = 16
 CHANNELS = 3
 SHAPE_LR = 32
 NF = 64
 VGG_MEAN = np.array([123.68, 116.779, 103.939])  # RGB
 GAN_FACTOR_PARAMETER = 2.
 
 
 def normalize(v):
     assert isinstance(v, tf.Tensor)
     v.get_shape().assert_has_rank(4)
-    dim = v.get_shape().as_list()
-    return v / (dim[1] * dim[2] * dim[3])
+    return v / tf.reduce_mean(v, axis=[1, 2, 3], keep_dims=True)
 
 
 def gram_matrix(v):
     assert isinstance(v, tf.Tensor)
     v.get_shape().assert_has_rank(4)
     dim = v.get_shape().as_list()
     v = normalize(v)
     v = tf.reshape(v, [-1, dim[1] * dim[2], dim[3]])
     return tf.matmul(v, v, transpose_a=True)
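Note on the `normalize` change: the old code rescaled every feature map by the constant volume `h * w * c`, while the new code rescales each sample by its own mean activation, so every sample entering `gram_matrix` has mean 1. A minimal NumPy sketch of the difference (illustration only, not part of the diff; the shapes are made up):

```python
import numpy as np

v = np.random.rand(2, 8, 8, 4).astype(np.float32)  # fake [b, h, w, c] feature map

old = v / (8 * 8 * 4)                            # scale fixed by shape alone
new = v / v.mean(axis=(1, 2, 3), keepdims=True)  # scale adapts to each sample

# After the new normalization every sample has mean activation 1, so the
# Gram/texture statistics are comparable across layers and images.
print(old.mean(axis=(1, 2, 3)), new.mean(axis=(1, 2, 3)))  # ~[5e-4 5e-4] vs ~[1. 1.]
```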
@@ -49,14 +47,17 @@ class Model(GANModelDesc):
         self.width = width
 
     def _get_inputs(self):
         # mean-subtracted images
         return [InputDesc(tf.float32, (None, self.height * 1, self.width * 1, CHANNELS), 'Ilr'),
                 InputDesc(tf.float32, (None, self.height * 4, self.width * 4, CHANNELS), 'Ihr')]
 
     def _build_graph(self, inputs):
         ctx = get_current_tower_context()
         Ilr, Ihr = inputs[0] / 255.0, inputs[1] / 255.0
-        Ibicubic = tf.image.resize_bicubic(Ilr, [4 * self.height, 4 * self.width])
+        Ibicubic = tf.image.resize_bicubic(Ilr, [4 * self.height, 4 * self.width],
+                                           align_corners=True, name='bicubic_baseline')  # (0,1)
+        VGG_MEAN_TENSOR = tf.constant(VGG_MEAN, dtype=tf.float32)
 
         def resnet_block(x, name):
             with tf.variable_scope(name):
@@ -66,10 +67,11 @@ class Model(GANModelDesc):
         def upsample(x, factor=2):
             _, h, w, _ = x.get_shape().as_list()
-            x = tf.image.resize_nearest_neighbor(x, [factor * h, factor * w])
+            x = tf.image.resize_nearest_neighbor(x, [factor * h, factor * w], align_corners=True)
             return x
 
         def generator(x, Ibicubic):
+            x = x - VGG_MEAN_TENSOR / 255.0
             with argscope(Conv2D, kernel_shape=3, stride=1, nl=tf.nn.relu):
                 x = Conv2D('conv1', x, NF)
                 for i in range(10):
@@ -81,10 +83,11 @@ class Model(GANModelDesc):
                 x = Conv2D('conv_post_3', x, NF)
                 Ires = Conv2D('conv_post_4', x, 3, nl=tf.identity)
                 Iest = tf.add(Ibicubic, Ires, name='Iest')
-                return Iest
+                return Iest  # [0,1]
 
         @auto_reuse_variable_scope
         def discriminator(x):
+            x = x - VGG_MEAN_TENSOR / 255.0
             with argscope(Conv2D, kernel_shape=3, stride=1, nl=tf.nn.leaky_relu):
                 x = Conv2D('conv0', x, 32)
                 x = Conv2D('conv0b', x, 32, stride=2)
@@ -104,7 +107,8 @@ class Model(GANModelDesc):
         def additional_losses(a, b):
             with tf.variable_scope('VGG19'):
                 x = tf.concat([a, b], axis=0)
-                x = tf.reshape(x, [2 * BATCH_SIZE, 128, 128, 3])
+                x = tf.reshape(x, [2 * BATCH_SIZE, SHAPE_LR * 4, SHAPE_LR * 4, 3]) * 255.0
+                x = x - VGG_MEAN_TENSOR
                 # VGG 19
                 with varreplace.freeze_variables():
                     with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
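Note on the two added lines: the networks now operate on images in [0, 1], but the frozen VGG19 expects RGB pixels in [0, 255] with the per-channel ImageNet mean removed, so the batch is rescaled before the VGG graph. A hypothetical NumPy sketch of that preprocessing (the batch values here are fabricated):

```python
import numpy as np

VGG_MEAN = np.array([123.68, 116.779, 103.939], dtype=np.float32)  # RGB

img01 = np.random.rand(1, 128, 128, 3).astype(np.float32)  # fake batch in (0, 1); 128 = SHAPE_LR * 4
vgg_in = img01 * 255.0 - VGG_MEAN                           # what `x` becomes in the diff above

print(vgg_in.min(), vgg_in.max())  # roughly within [-123.7, 151.3]
```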
@@ -132,6 +136,8 @@ class Model(GANModelDesc):
             # perceptual loss
             with tf.name_scope('perceptual_loss'):
+                pool2 = normalize(pool2)
+                pool5 = normalize(pool5)
                 phi_a_1, phi_b_1 = tf.split(pool2, 2, axis=0)
                 phi_a_2, phi_b_2 = tf.split(pool5, 2, axis=0)
@@ -143,23 +149,23 @@ class Model(GANModelDesc):
             # texture loss
             with tf.name_scope('texture_loss'):
                 def texture_loss(x, p=16):
-                    x = normalize(x)
                     _, h, w, c = x.get_shape().as_list()
+                    x = normalize(x)
                     assert h % p == 0 and w % p == 0
                     logger.info('Create texture loss for layer {} with shape {}'.format(x.name, x.get_shape()))
 
-                    x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])
-                    x = tf.reshape(x, [p, p, -1, h // p, w // p, c])
-                    x = tf.transpose(x, [2, 3, 4, 0, 1, 5])
-                    patches_a, patches_b = tf.split(x, 2)  # each is b,h/p,w/p,p,p,c
+                    x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])  # [b * ?, h/p, w/p, c]
+                    x = tf.reshape(x, [p, p, -1, h // p, w // p, c])  # [p, p, b, h/p, w/p, c]
+                    x = tf.transpose(x, [2, 3, 4, 0, 1, 5])  # [b * ?, p, p, c]
+                    patches_a, patches_b = tf.split(x, 2, axis=0)  # each is b,h/p,w/p,p,p,c
-                    patches_a = tf.reshape(patches_a, [-1, p, p, c])
-                    patches_b = tf.reshape(patches_b, [-1, p, p, c])
+                    patches_a = tf.reshape(patches_a, [-1, p, p, c])  # [b * ?, p, p, c]
+                    patches_b = tf.reshape(patches_b, [-1, p, p, c])  # [b * ?, p, p, c]
                     return tf.losses.mean_squared_error(
                         gram_matrix(patches_a),
                         gram_matrix(patches_b),
-                        reduction=Reduction.SUM) * (1.0 / BATCH_SIZE)
+                        reduction=Reduction.MEAN)
 
                 texture_loss_conv1_1 = tf.identity(texture_loss(conv1_1), name='normalized_conv1_1')
                 texture_loss_conv2_1 = tf.identity(texture_loss(conv2_1), name='normalized_conv2_1')
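For readers following the `space_to_batch_nd` / `reshape` / `transpose` chain: it cuts every feature map into non-overlapping p x p patches and stacks them along the batch axis, so one Gram matrix is computed per patch. A hypothetical NumPy re-derivation of the same patch extraction (shapes are made up; patch ordering may differ from TF's, which does not affect the loss):

```python
import numpy as np

b, h, w, c, p = 2, 32, 32, 8, 16
x = np.random.rand(b, h, w, c).astype(np.float32)

patches = x.reshape(b, h // p, p, w // p, p, c)
patches = patches.transpose(0, 1, 3, 2, 4, 5)  # [b, h/p, w/p, p, p, c]
patches = patches.reshape(-1, p, p, c)         # [b * (h/p) * (w/p), p, p, c]

print(patches.shape)  # (8, 16, 16, 8)
```

Note also the reduction change in the return: `Reduction.SUM` adds up every squared Gram-matrix entry (hence the old manual `* (1.0 / BATCH_SIZE)`), while `Reduction.MEAN` averages over all entries, so the loss lands on a different, batch-size-independent scale.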
@@ -171,8 +177,7 @@ class Model(GANModelDesc):
         fake_hr = generator(Ilr, Ibicubic)
         real_hr = Ihr
-        VGG_MEAN_TENSOR = tf.constant(VGG_MEAN, dtype=tf.float32)
-        tf.add(fake_hr, VGG_MEAN_TENSOR / 255.0, name='prediction')
+        tf.multiply(fake_hr, 255.0, name='prediction')
 
         if ctx.is_training:
             with tf.variable_scope('discrim'):
@@ -185,18 +190,19 @@ class Model(GANModelDesc):
             with tf.name_scope('additional_losses'):
                 # see table 2 from appendix
                 loss = []
-                loss.append(tf.multiply(1., self.g_loss, name="loss_LA"))
+                loss.append(tf.multiply(GAN_FACTOR_PARAMETER, self.g_loss, name="loss_LA"))
                 loss.append(tf.multiply(2e-1, additional_losses[0], name="loss_LP1"))
                 loss.append(tf.multiply(2e-2, additional_losses[1], name="loss_LP2"))
                 loss.append(tf.multiply(3e-7, additional_losses[2], name="loss_LT1"))
                 loss.append(tf.multiply(1e-6, additional_losses[3], name="loss_LT2"))
                 loss.append(tf.multiply(1e-6, additional_losses[4], name="loss_LT3"))
 
-            self.g_loss = self.g_loss + tf.add_n(loss, name='total_g_loss')
-            add_moving_summary(self.g_loss, *loss)
+            self.g_loss = tf.add_n(loss, name='total_g_loss')
+            self.d_loss = tf.multiply(self.d_loss, GAN_FACTOR_PARAMETER, name='d_loss')
+            add_moving_summary(self.g_loss, self.d_loss, *loss)
 
             # visualization
-            viz = (tf.concat([Ibicubic, fake_hr, real_hr], 2)) * 255. + VGG_MEAN_TENSOR
+            viz = (tf.concat([Ibicubic, fake_hr, real_hr], 2)) * 255.
             viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
             tf.summary.image('input,fake,real', viz, max_outputs=max(30, BATCH_SIZE))
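One subtlety in this hunk: the old code put an unweighted copy of the adversarial loss into the list (`tf.multiply(1., self.g_loss, ...)`) and then also added `self.g_loss` itself to `tf.add_n(loss)`, counting it twice; the new code counts it once, weighted by `GAN_FACTOR_PARAMETER`, and scales `d_loss` by the same factor. A plain-Python sketch of the new bookkeeping (the loss values below are fabricated):

```python
GAN_FACTOR_PARAMETER = 2.

g_adv, lp1, lp2, lt1, lt2, lt3 = 0.5, 1.0, 1.0, 1.0, 1.0, 1.0  # made-up values
weights = [GAN_FACTOR_PARAMETER, 2e-1, 2e-2, 3e-7, 1e-6, 1e-6]  # table 2 weights
terms = [g_adv, lp1, lp2, lt1, lt2, lt3]

# new: total_g_loss = sum of weighted terms only (old code added g_adv once more)
total_g_loss = sum(w * t for w, t in zip(weights, terms))
print(total_g_loss)
```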
@@ -207,8 +213,7 @@ class Model(GANModelDesc):
         lr = tf.get_variable('learning_rate', initializer=1e-4, trainable=False)
         opt = tf.train.AdamOptimizer(lr)
-        gradprocs = [gradproc.ScaleGradient([('discrim/*', 0.3)])]
-        return optimizer.apply_grad_processors(opt, gradprocs)
+        return opt
 
 
 def apply(model_path, lowres_path="", output_path='.'):
@@ -217,7 +222,6 @@ def apply(model_path, lowres_path="", output_path='.'):
     lr = cv2.imread(lowres_path).astype(np.float32)
     baseline = cv2.resize(lr, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
     LR_SIZE_H, LR_SIZE_W = lr.shape[:2]
-    lr -= VGG_MEAN
 
     predict_func = OfflinePredictor(PredictConfig(
         model=Model(LR_SIZE_H, LR_SIZE_W),
@@ -226,7 +230,7 @@ def apply(model_path, lowres_path="", output_path='.'):
         output_names=['prediction']))
 
     pred = predict_func(lr[None, ...])
-    p = np.clip(pred[0][0, ...] * 255, 0, 255)
+    p = np.clip(pred[0][0, ...], 0, 255)
     cv2.imwrite(os.path.join(output_path, "predition.png"), p)
     cv2.imwrite(os.path.join(output_path, "baseline.png"), baseline)
@@ -238,8 +242,7 @@ def get_data(lmdb):
     augmentors = [imgaug.RandomCrop(128), imgaug.Flip(horiz=True)]
     ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
-    ds = MapData(ds, lambda x: x - VGG_MEAN)
-    ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_AREA), x[0]])
+    ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
     ds = PrefetchDataZMQ(ds, 8)
     ds = BatchData(ds, BATCH_SIZE)
     return ds
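For context, this hunk changes how each training pair is made: the VGG mean subtraction moves out of the data pipeline (it now happens inside the graph), and the low-res input is produced by bicubic rather than area downsampling, matching the bicubic upsampling baseline. A hypothetical sketch of the effect on a single image (the crop is fabricated):

```python
import cv2
import numpy as np

hr = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)  # stand-in 128x128 crop
lr = cv2.resize(hr, (32, 32), interpolation=cv2.INTER_CUBIC)   # new: bicubic, no mean subtraction

print(lr.shape, hr.shape)  # (32, 32, 3) (128, 128, 3) -> the (Ilr, Ihr) pair
```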
@@ -275,16 +278,14 @@ if __name__ == '__main__':
     nr_tower = max(get_nr_gpu(), 1)
     data = QueueInput(get_data(args.lmdb))
     model = Model()
-    if nr_tower == 1:
-        trainer = GANTrainer(data, model)
-    else:
-        trainer = MultiGPUGANTrainer(nr_tower, data, model)
+    trainer = SeparateGANTrainer(data, model, d_period=3)
     trainer.train_with_defaults(
         callbacks=[ModelSaver(keep_checkpoint_every_n_hours=2)],
         session_init=session_init,
-        steps_per_epoch=data.size(),
+        steps_per_epoch=data.size() // 4,
         max_epoch=2000
     )
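The trainer swap replaces joint/multi-GPU updates with separate, alternating generator/discriminator updates. An illustration only of a `d_period`-style schedule, in the spirit of `SeparateGANTrainer(..., d_period=3)`; this is not tensorpack's code, and the exact semantics of `d_period` are defined by the library:

```python
def train_steps(num_steps, d_period=3):
    """Yield a generic alternating update schedule (illustrative, assumed semantics)."""
    for step in range(num_steps):
        if step % d_period == 0:
            yield 'update D'  # discriminator trained on every d_period-th step
        yield 'update G'      # generator trained on every step

print(list(train_steps(6)))
```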