Shashank Suhas / seminar-breakout · Commits

Commit 8868a9e4
authored Mar 21, 2017 by Yuxin Wu
add gamma_init in BN to imitate torch
parent 4c82fb50
Showing 4 changed files with 20 additions and 17 deletions
    examples/GAN/GAN.py              +6   -6
    examples/GAN/Image2Image.py      +1   -2
    examples/GAN/WGAN-CelebA.py      +2   -2
    tensorpack/models/batch_norm.py  +11  -7
examples/GAN/GAN.py  (view file @ 8868a9e4)

@@ -71,8 +71,8 @@ class GANTrainer(FeedfreeTrainerBase):
         self.train_op = self.d_min
 
-class SplitGANTrainer(FeedfreeTrainerBase):
-    """ A new trainer which runs two optimization ops with a certain ratio.
-    """
+class SeparateGANTrainer(FeedfreeTrainerBase):
+    """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step.
+    """
     def __init__(self, config, d_interval=1):
         """
         Args:
@@ -80,10 +80,10 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         """
         self._input_method = QueueInput(config.dataflow)
         self._d_interval = d_interval
-        super(SplitGANTrainer, self).__init__(config)
+        super(SeparateGANTrainer, self).__init__(config)
 
     def _setup(self):
-        super(SplitGANTrainer, self)._setup()
+        super(SeparateGANTrainer, self)._setup()
         self.build_train_tower()
         opt = self.model.get_optimizer()
@@ -94,11 +94,11 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         self._cnt = 0
 
     def run_step(self):
-        self._cnt += 1
-        if self._cnt % (self._d_interval) == 0:
+        if self._cnt % (self._d_interval + 1) == 0:
             self.hooked_sess.run(self.d_min)
         else:
             self.hooked_sess.run(self.g_min)
+        self._cnt += 1
 
 
 class RandomZData(DataFlow):
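For reference, the schedule implied by the new run_step can be checked in isolation: the counter starts at 0 (see the `self._cnt = 0` context line above), the op is chosen before the increment, and d_min now fires once every `d_interval + 1` steps. The sketch below only replays that counter arithmetic; the schedule() helper and the step count are illustrative and not part of GAN.py.

# Illustrative only: replay the counter logic of the new run_step to see
# which of the two ops SeparateGANTrainer would run at each step.
def schedule(d_interval, num_steps=12):
    ops = []
    cnt = 0
    for _ in range(num_steps):
        if cnt % (d_interval + 1) == 0:
            ops.append('d_min')
        else:
            ops.append('g_min')
        cnt += 1
    return ops

print(schedule(1))  # alternates: d_min, g_min, d_min, g_min, ...
print(schedule(5))  # one d_min followed by five g_min, repeating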
examples/GAN/Image2Image.py  (view file @ 8868a9e4)

@@ -205,14 +205,13 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--load', help='load model')
     parser.add_argument('--sample', action='store_true', help='run sampling')
-    parser.add_argument('--data', help='Image directory')
+    parser.add_argument('--data', help='Image directory', required=True)
     parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
     parser.add_argument('-b', '--batch', type=int, default=1)
     global args
     args = parser.parse_args()
 
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    assert args.data
     BATCH = args.batch
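The `required=True` flag makes argparse itself reject a missing `--data` at parse time, which is what the removed bare `assert args.data` used to enforce much less helpfully. A minimal standalone sketch of the same pattern; the parser and the placeholder path below are illustrative and not taken from Image2Image.py.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--data', help='Image directory', required=True)

# Omitting --data now fails at parse time with a clear message:
#   error: the following arguments are required: --data
args = parser.parse_args(['--data', '/path/to/images'])  # placeholder path
print(args.data)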
examples/GAN/WGAN-CelebA.py  (view file @ 8868a9e4)

@@ -9,7 +9,7 @@ import argparse
 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 import tensorflow as tf
-from GAN import SplitGANTrainer
+from GAN import SeparateGANTrainer
 
 """
 Wasserstein-GAN.
@@ -84,4 +84,4 @@ if __name__ == '__main__':
     This is to be consistent with the original code, but I found just
     running them 1:1 (i.e. just using the existing GANTrainer) also works well.
     """
-    SplitGANTrainer(config, d_interval=5).train()
+    SeparateGANTrainer(config, d_interval=5).train()
tensorpack/models/batch_norm.py  (view file @ 8868a9e4)

@@ -96,13 +96,13 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
 
 
-def get_bn_variables(n_out, use_scale, use_bias):
+def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     if use_bias:
         beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
     else:
         beta = tf.zeros([n_out], name='beta')
     if use_scale:
-        gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
+        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
     else:
         gamma = tf.ones([n_out], name='gamma')
     # x * gamma + beta
@@ -132,7 +132,8 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True, data_format='NHWC'):
+              use_scale=True, use_bias=True,
+              gamma_init=tf.constant_initializer(1.0), data_format='NHWC'):
     """
     Batch Normalization layer, as described in the paper:
     `Batch Normalization: Accelerating Deep Network Training by
@@ -145,14 +146,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         decay (float): decay rate of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        gamma_init: initializer for gamma (the scale).
 
     Returns:
         tf.Tensor: a tensor named ``output`` with the same shape of x.
 
     Variable Names:
 
-    * ``beta``: the bias term.
-    * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
+    * ``beta``: the bias term. Will be zero-inited by default.
+    * ``gamma``: the scale term. Will be one-inited by default.
+      Input will be transformed by ``x * gamma + beta``.
     * ``mean/EMA``: the moving average of mean.
     * ``variance/EMA``: the moving average of variance.
@@ -176,7 +179,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)
 
     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -245,7 +248,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     n_out = shape[-1]
     if len(shape) == 2:
         x = tf.reshape(x, [-1, 1, 1, n_out])
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(
+        n_out, use_scale, use_bias, tf.constant_initializer(1.0))
 
     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
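With the new gamma_init argument, callers can override the default one-initialization of the BN scale. A minimal sketch of how it might be passed, assuming the "imitate torch" in the commit message refers to old Torch's uniform initialization of the BN scale; the uniform initializer and the helper below are assumptions and not part of this diff, and tensorpack layers are normally called inside a model's tower context.

import tensorflow as tf
from tensorpack import *  # BatchNorm, argscope

def bn_torch_style(name, x):
    # Hypothetical helper: use a uniform gamma init (torch-like) instead of
    # the default tf.constant_initializer(1.0) added in this commit.
    return BatchNorm(name, x, gamma_init=tf.random_uniform_initializer(0.0, 1.0))

# Or set it once for every BatchNorm inside a scope:
# with argscope(BatchNorm, gamma_init=tf.random_uniform_initializer(0.0, 1.0)):
#     ...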