Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
44a2c531
Commit
44a2c531
authored
Feb 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add remap_get_variable
parent
c59987a3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
43 additions
and
34 deletions
+43
-34
README.md
README.md
+2
-1
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+5
-5
examples/DoReFa-Net/resnet-dorefa.py
examples/DoReFa-Net/resnet-dorefa.py
+5
-5
examples/DoReFa-Net/svhn-digit-dorefa.py
examples/DoReFa-Net/svhn-digit-dorefa.py
+5
-5
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+9
-10
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+17
-8
No files found.
README.md
View file @
44a2c531
...
@@ -19,7 +19,8 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
...
@@ -19,7 +19,8 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
+
[
Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym
](
examples/A3C-Gym
)
+
[
Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym
](
examples/A3C-Gym
)
### Unsupervised Learning:
### Unsupervised Learning:
+
[
Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN
](
examples/GAN
)
+
[
Generative Adversarial Network(GAN) variants
](
examples/GAN
)
, including DCGAN, InfoGAN, Conditional GAN, Image to Image.
### Speech / NLP:
### Speech / NLP:
+
[
LSTM-CTC for speech recognition
](
examples/CTC-TIMIT
)
+
[
LSTM-CTC for speech recognition
](
examples/CTC-TIMIT
)
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
44a2c531
...
@@ -15,7 +15,7 @@ import sys
...
@@ -15,7 +15,7 @@ import sys
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.varreplace
import
re
place
_get_variable
from
tensorpack.tfutils.varreplace
import
re
map
_get_variable
from
dorefa
import
get_dorefa
from
dorefa
import
get_dorefa
"""
"""
...
@@ -87,10 +87,10 @@ class Model(ModelDesc):
...
@@ -87,10 +87,10 @@ class Model(ModelDesc):
old_get_variable
=
tf
.
get_variable
old_get_variable
=
tf
.
get_variable
# monkey-patch tf.get_variable to apply fw
# monkey-patch tf.get_variable to apply fw
def
new_get_variable
(
name
,
shape
=
None
,
**
kwargs
):
def
new_get_variable
(
v
):
v
=
old_get_variable
(
name
,
shape
,
**
kwargs
)
name
=
v
.
op
.
name
# don't binarize first and last layer
# don't binarize first and last layer
if
n
ame
!=
'W'
or
'conv0'
in
v
.
op
.
name
or
'fct'
in
v
.
op
.
name
:
if
n
ot
name
.
endswith
(
'W'
)
or
'conv0'
in
name
or
'fct'
in
name
:
return
v
return
v
else
:
else
:
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
...
@@ -104,7 +104,7 @@ class Model(ModelDesc):
...
@@ -104,7 +104,7 @@ class Model(ModelDesc):
def
activate
(
x
):
def
activate
(
x
):
return
fa
(
nonlin
(
x
))
return
fa
(
nonlin
(
x
))
with
re
place
_get_variable
(
new_get_variable
),
\
with
re
map
_get_variable
(
new_get_variable
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
([
Conv2D
,
FullyConnected
],
use_bias
=
False
,
nl
=
tf
.
identity
):
argscope
([
Conv2D
,
FullyConnected
],
use_bias
=
False
,
nl
=
tf
.
identity
):
logits
=
(
LinearWrap
(
image
)
logits
=
(
LinearWrap
(
image
)
...
...
examples/DoReFa-Net/resnet-dorefa.py
View file @
44a2c531
...
@@ -13,7 +13,7 @@ from tensorpack import *
...
@@ -13,7 +13,7 @@ from tensorpack import *
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.utils.stats
import
RatioCounter
from
tensorpack.utils.stats
import
RatioCounter
from
tensorpack.tfutils.varreplace
import
re
place
_get_variable
from
tensorpack.tfutils.varreplace
import
re
map
_get_variable
from
dorefa
import
get_dorefa
from
dorefa
import
get_dorefa
"""
"""
...
@@ -44,10 +44,10 @@ class Model(ModelDesc):
...
@@ -44,10 +44,10 @@ class Model(ModelDesc):
fw
,
fa
,
fg
=
get_dorefa
(
BITW
,
BITA
,
BITG
)
fw
,
fa
,
fg
=
get_dorefa
(
BITW
,
BITA
,
BITG
)
old_get_variable
=
tf
.
get_variable
old_get_variable
=
tf
.
get_variable
def
new_get_variable
(
name
,
shape
=
None
,
**
kwargs
):
def
new_get_variable
(
v
):
v
=
old_get_variable
(
name
,
shape
,
**
kwargs
)
name
=
v
.
op
.
name
# don't binarize first and last layer
# don't binarize first and last layer
if
n
ame
!=
'W'
or
'conv1'
in
v
.
op
.
name
or
'fct'
in
v
.
op
.
name
:
if
n
ot
name
.
endswith
(
'W'
)
or
'conv1'
in
name
or
'fct'
in
name
:
return
v
return
v
else
:
else
:
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
...
@@ -90,7 +90,7 @@ class Model(ModelDesc):
...
@@ -90,7 +90,7 @@ class Model(ModelDesc):
x
=
resblock
(
x
,
channel
,
1
)
x
=
resblock
(
x
,
channel
,
1
)
return
x
return
x
with
re
place
_get_variable
(
new_get_variable
),
\
with
re
map
_get_variable
(
new_get_variable
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
(
Conv2D
,
use_bias
=
False
,
nl
=
tf
.
identity
):
argscope
(
Conv2D
,
use_bias
=
False
,
nl
=
tf
.
identity
):
logits
=
(
LinearWrap
(
image
)
logits
=
(
LinearWrap
(
image
)
...
...
examples/DoReFa-Net/svhn-digit-dorefa.py
View file @
44a2c531
...
@@ -11,7 +11,7 @@ import os
...
@@ -11,7 +11,7 @@ import os
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.varreplace
import
re
place
_get_variable
from
tensorpack.tfutils.varreplace
import
re
map
_get_variable
from
dorefa
import
get_dorefa
from
dorefa
import
get_dorefa
"""
"""
...
@@ -56,10 +56,10 @@ class Model(ModelDesc):
...
@@ -56,10 +56,10 @@ class Model(ModelDesc):
old_get_variable
=
tf
.
get_variable
old_get_variable
=
tf
.
get_variable
# monkey-patch tf.get_variable to apply fw
# monkey-patch tf.get_variable to apply fw
def
new_get_variable
(
name
,
shape
=
None
,
**
kwargs
):
def
new_get_variable
(
v
):
v
=
old_get_variable
(
name
,
shape
,
**
kwargs
)
name
=
v
.
op
.
name
# don't binarize first and last layer
# don't binarize first and last layer
if
n
ame
!=
'W'
or
'conv0'
in
v
.
op
.
name
or
'fc'
in
v
.
op
.
name
:
if
n
ot
name
.
endswith
(
'W'
)
or
'conv0'
in
name
or
'fc'
in
name
:
return
v
return
v
else
:
else
:
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
logger
.
info
(
"Binarizing weight {}"
.
format
(
v
.
op
.
name
))
...
@@ -73,7 +73,7 @@ class Model(ModelDesc):
...
@@ -73,7 +73,7 @@ class Model(ModelDesc):
image
=
image
/
256.0
image
=
image
/
256.0
with
re
place
_get_variable
(
new_get_variable
),
\
with
re
map
_get_variable
(
new_get_variable
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
(
BatchNorm
,
decay
=
0.9
,
epsilon
=
1e-4
),
\
argscope
(
Conv2D
,
use_bias
=
False
,
nl
=
tf
.
identity
):
argscope
(
Conv2D
,
use_bias
=
False
,
nl
=
tf
.
identity
):
logits
=
(
LinearWrap
(
image
)
logits
=
(
LinearWrap
(
image
)
...
...
examples/GAN/DCGAN-CelebA.py
View file @
44a2c531
...
@@ -15,7 +15,6 @@ import cv2
...
@@ -15,7 +15,6 @@ import cv2
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.utils.viz
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.globvars
import
globalns
as
CFG
,
use_global_argument
import
tensorpack.tfutils.symbolic_functions
as
symbf
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
GAN
import
GANTrainer
,
RandomZData
,
GANModelDesc
from
GAN
import
GANTrainer
,
RandomZData
,
GANModelDesc
...
@@ -30,14 +29,14 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
...
@@ -30,14 +29,14 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load path/to/model --sample
./DCGAN-CelebA.py --load path/to/model --sample
"""
"""
CFG
.
SHAPE
=
64
SHAPE
=
64
CFG
.
BATCH
=
128
BATCH
=
128
CFG
.
Z_DIM
=
100
Z_DIM
=
100
class
Model
(
GANModelDesc
):
class
Model
(
GANModelDesc
):
def
_get_inputs
(
self
):
def
_get_inputs
(
self
):
return
[
InputVar
(
tf
.
float32
,
(
None
,
CFG
.
SHAPE
,
CFG
.
SHAPE
,
3
),
'input'
)]
return
[
InputVar
(
tf
.
float32
,
(
None
,
SHAPE
,
SHAPE
,
3
),
'input'
)]
def
generator
(
self
,
z
):
def
generator
(
self
,
z
):
""" return a image generated from z"""
""" return a image generated from z"""
...
@@ -73,8 +72,8 @@ class Model(GANModelDesc):
...
@@ -73,8 +72,8 @@ class Model(GANModelDesc):
image_pos
=
inputs
[
0
]
image_pos
=
inputs
[
0
]
image_pos
=
image_pos
/
128.0
-
1
image_pos
=
image_pos
/
128.0
-
1
z
=
tf
.
random_uniform
([
CFG
.
BATCH
,
CFG
.
Z_DIM
],
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
random_uniform
([
BATCH
,
Z_DIM
],
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
placeholder_with_default
(
z
,
[
None
,
CFG
.
Z_DIM
],
name
=
'z'
)
z
=
tf
.
placeholder_with_default
(
z
,
[
None
,
Z_DIM
],
name
=
'z'
)
with
argscope
([
Conv2D
,
Deconv2D
,
FullyConnected
],
with
argscope
([
Conv2D
,
Deconv2D
,
FullyConnected
],
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
...
@@ -91,12 +90,13 @@ class Model(GANModelDesc):
...
@@ -91,12 +90,13 @@ class Model(GANModelDesc):
def
get_data
():
def
get_data
():
datadir
=
CFG
.
data
global
args
datadir
=
args
.
data
imgs
=
glob
.
glob
(
datadir
+
'/*.jpg'
)
imgs
=
glob
.
glob
(
datadir
+
'/*.jpg'
)
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
ds
=
ImageFromFile
(
imgs
,
channel
=
3
,
shuffle
=
True
)
augs
=
[
imgaug
.
CenterCrop
(
140
),
imgaug
.
Resize
(
64
)]
augs
=
[
imgaug
.
CenterCrop
(
140
),
imgaug
.
Resize
(
64
)]
ds
=
AugmentImageComponent
(
ds
,
augs
)
ds
=
AugmentImageComponent
(
ds
,
augs
)
ds
=
BatchData
(
ds
,
CFG
.
BATCH
)
ds
=
BatchData
(
ds
,
BATCH
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
return
ds
return
ds
...
@@ -137,7 +137,6 @@ if __name__ == '__main__':
...
@@ -137,7 +137,6 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--sample'
,
action
=
'store_true'
,
help
=
'run sampling'
)
parser
.
add_argument
(
'--sample'
,
action
=
'store_true'
,
help
=
'run sampling'
)
parser
.
add_argument
(
'--data'
,
help
=
'`image_align_celeba` directory of the celebA dataset'
)
parser
.
add_argument
(
'--data'
,
help
=
'`image_align_celeba` directory of the celebA dataset'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
use_global_argument
(
args
)
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
if
args
.
sample
:
if
args
.
sample
:
...
...
tensorpack/tfutils/varreplace.py
View file @
44a2c531
...
@@ -7,7 +7,7 @@ import tensorflow as tf
...
@@ -7,7 +7,7 @@ import tensorflow as tf
from
tensorflow.python.ops
import
variable_scope
from
tensorflow.python.ops
import
variable_scope
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
__all__
=
[
'replace_get_variable'
,
'freeze_get_variable'
]
__all__
=
[
'replace_get_variable'
,
'freeze_get_variable'
,
'remap_get_variable'
]
_ORIG_GET_VARIABLE
=
tf
.
get_variable
_ORIG_GET_VARIABLE
=
tf
.
get_variable
...
@@ -16,7 +16,7 @@ _ORIG_GET_VARIABLE = tf.get_variable
...
@@ -16,7 +16,7 @@ _ORIG_GET_VARIABLE = tf.get_variable
def
replace_get_variable
(
fn
):
def
replace_get_variable
(
fn
):
"""
"""
Args:
Args:
fn: a function
taking the same arguments as
``tf.get_variable``.
fn: a function
compatible with
``tf.get_variable``.
Returns:
Returns:
a context where ``tf.get_variable`` and
a context where ``tf.get_variable`` and
``variable_scope.get_variable`` are replaced with ``fn``.
``variable_scope.get_variable`` are replaced with ``fn``.
...
@@ -36,6 +36,19 @@ def replace_get_variable(fn):
...
@@ -36,6 +36,19 @@ def replace_get_variable(fn):
variable_scope
.
get_variable
=
old_vars_getv
variable_scope
.
get_variable
=
old_vars_getv
def
remap_get_variable
(
fn
):
""" Similar to :func:`replace_get_variable`, but the function `fn`
takes the variable returned by the original `tf.get_variable` call
and return a tensor.
"""
old_getv
=
tf
.
get_variable
def
new_get_variable
(
name
,
shape
=
None
,
**
kwargs
):
v
=
old_getv
(
name
,
shape
,
**
kwargs
)
return
fn
(
v
)
return
replace_get_variable
(
new_get_variable
)
def
freeze_get_variable
():
def
freeze_get_variable
():
"""
"""
Return a context, where all variables (reused or not) returned by
Return a context, where all variables (reused or not) returned by
...
@@ -49,9 +62,5 @@ def freeze_get_variable():
...
@@ -49,9 +62,5 @@ def freeze_get_variable():
with varreplace.freeze_get_variable():
with varreplace.freeze_get_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained
x = FullyConnected('fc', x, 1000) # fc/* will not be trained
"""
"""
old_get_variable
=
tf
.
get_variable
return
remap_get_variable
(
lambda
v
:
tf
.
stop_gradient
(
v
))
def
fn
(
name
,
shape
=
None
,
**
kwargs
):
v
=
old_get_variable
(
name
,
shape
,
**
kwargs
)
return
tf
.
stop_gradient
(
v
)
return
replace_get_variable
(
fn
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment