Commit fbb73a8a
authored Apr 13, 2016 by Yuxin Wu

argscope

parent 23f2ccd6

Changes: 6 changed files, with 70 additions and 23 deletions (+70 −23)
examples/cifar10_convnet.py        +9   −6
tensorpack/models/_common.py       +12  −9
tensorpack/models/batch_norm.py    +1   −1
tensorpack/models/nonlin.py        +6   −7
tensorpack/tfutils/__init__.py     +1   −0
tensorpack/tfutils/argscope.py     +41  −0
examples/cifar10_convnet.py (view file @ fbb73a8a)

@@ -43,15 +43,18 @@ class Model(ModelDesc):
         tf.image_summary("train_image", image, 10)
         image = image / 4.0     # just to make range smaller
-        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
+        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3,
                    nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
-        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
+        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3,
                    nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
-        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
+        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID',
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID',
                    nl=BNReLU(is_training), use_bias=False)
         l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
@@ -80,7 +83,7 @@ class Model(ModelDesc):
                               name='regularize_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
-        add_param_summary([('.*/W', ['histogram', 'sparsity'])])   # monitor W
+        add_param_summary([('.*/W', ['histogram'])])   # monitor W
         return tf.add_n([cost, wd_cost], name='cost')


 def get_data(train_or_test):
@@ -123,7 +126,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -138,7 +141,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=500,
+        max_epoch=200,
     )

 if __name__ == '__main__':
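The hunks above still pass nl=BNReLU(is_training) and use_bias=False to every Conv2D call explicitly; the example is not yet converted to the argscope mechanism that this commit introduces. A hedged sketch of how the same block could be written once argscope is available (illustration only, not part of this diff):

# Illustration: shared Conv2D defaults set once via the new argscope.
# The with-block pushes defaults that layer_register then picks up.
with argscope(Conv2D, nl=BNReLU(is_training), use_bias=False, kernel_shape=3):
    l = Conv2D('conv1.1', image, out_channel=64)
    l = Conv2D('conv1.2', l, out_channel=64)
    l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
    l = Conv2D('conv2.1', l, out_channel=128)
    l = Conv2D('conv2.2', l, out_channel=128)
    l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
    l = Conv2D('conv3.1', l, out_channel=128, padding='VALID')
    l = Conv2D('conv3.2', l, out_channel=128, padding='VALID')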
tensorpack/models/_common.py (view file @ fbb73a8a)

@@ -6,6 +6,7 @@ import tensorflow as tf
 from functools import wraps
 import six

+from ..tfutils import *
 from ..tfutils.modelutils import *
 from ..tfutils.summary import *
 from ..utils import logger
@@ -13,14 +14,13 @@ from ..utils import logger
 # make sure each layer is only logged once
 _layer_logged = set()

-def layer_register(summary_activation=False):
+def layer_register(summary_activation=False, log_shape=True):
     """
     Register a layer.
-    Args:
-        summary_activation:
-            Define the default behavior of whether to
-            summary the output(activation) of this layer.
-            Can be overriden when creating the layer.
+    :param summary_activation: Define the default behavior of whether to
+        summary the output(activation) of this layer.
+        Can be overriden when creating the layer.
+    :param log_shape: log input/output shape of this layer
     """
     def wrapper(func):
         @wraps(func)
@@ -29,13 +29,16 @@ def layer_register(summary_activation=False):
             assert isinstance(name, six.string_types), \
                 'name must be the first argument. Args: {}'.format(str(args))
             args = args[1:]
             do_summary = kwargs.pop(
                 'summary_activation', summary_activation)
             inputs = args[0]

+            actual_args = get_arg_scope()[func.__name__]
+            actual_args.update(kwargs)
+
             with tf.variable_scope(name) as scope:
-                outputs = func(*args, **kwargs)
-                if scope.name not in _layer_logged:
+                outputs = func(*args, **actual_args)
+                if log_shape and scope.name not in _layer_logged:
                     # log shape info and add activation
                     logger.info("{} input: {}".format(
                         scope.name, get_shape_str(inputs)))
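The two inserted lines in the last hunk define the precedence between scoped defaults and call-site arguments: the wrapper first takes whatever the active argscope holds for this layer's name, then overrides it with the kwargs passed at the call. A minimal standalone sketch of that merge order, using plain dicts and made-up values rather than tensorpack objects:

# Standalone sketch of the merge performed by the modified wrapper.
scope_defaults = {'nl': 'BNReLU', 'use_bias': False, 'kernel_shape': 3}  # e.g. get_arg_scope()['Conv2D']
call_kwargs = {'kernel_shape': 5}                                        # explicit argument at the call site

actual_args = dict(scope_defaults)  # start from the argscope defaults
actual_args.update(call_kwargs)     # explicit kwargs win
print(actual_args)                  # {'nl': 'BNReLU', 'use_bias': False, 'kernel_shape': 5}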
tensorpack/models/batch_norm.py (view file @ fbb73a8a)

@@ -14,7 +14,7 @@ __all__ = ['BatchNorm']
 # TF batch_norm only works for 4D tensor right now: #804
 # decay: being too close to 1 leads to slow start-up, but ends up better
 # eps: torch: 1e-5. Lasagne: 1e-4
-@layer_register()
+@layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
     """
     Batch normalization layer as described in:
tensorpack/models/nonlin.py (view file @ fbb73a8a)

@@ -11,7 +11,7 @@ from .batch_norm import BatchNorm

 __all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']

-@layer_register()
+@layer_register(log_shape=False)
 def Maxout(x, num_unit):
     """
     Maxout networks as in `Maxout Networks <http://arxiv.org/abs/1302.4389>`_.
@@ -27,7 +27,7 @@ def Maxout(x, num_unit):
     x = tf.reshape(x, [-1, input_shape[1], input_shape[2], ch / 3, 3])
     return tf.reduce_max(x, 4, name='output')

-@layer_register()
+@layer_register(log_shape=False)
 def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     """
     Parameterized relu as in `Delving Deep into Rectifiers: Surpassing
@@ -44,7 +44,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     else:
         return tf.mul(x, 0.5, name=name)

-@layer_register()
+@layer_register(log_shape=False)
 def LeakyReLU(x, alpha, name=None):
     """
     Leaky relu as in `Rectifier Nonlinearities Improve Neural Network Acoustic
@@ -66,9 +66,8 @@ def BNReLU(is_training):
     """
     :returns: a activation function that performs BN + ReLU (a too common combination)
     """
-    def f(x, name=None):
-        with tf.variable_scope('bn'):
-            x = BatchNorm('bn', x, is_training)
+    def BNReLU(x, name=None):
+        x = BatchNorm.f(x, is_training)
         x = tf.nn.relu(x, name=name)
         return x
-    return f
+    return BNReLU
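The last hunk renames the inner function from f to BNReLU and calls BatchNorm.f, the undecorated function that registered layers expose (the same .f attribute the new argscope module relies on), so the returned activation carries a meaningful name. A short hedged usage sketch in the style of the CIFAR example above (layer names are illustrative):

# Illustrative usage: build the BN+ReLU activation once, pass it as `nl`.
act = BNReLU(is_training)   # returns a callable named 'BNReLU'
l = Conv2D('conv0', image, out_channel=64, kernel_shape=3,
           nl=act, use_bias=False)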
tensorpack/tfutils/__init__.py (view file @ fbb73a8a)

@@ -14,4 +14,5 @@ def _global_import(name):
 _global_import('sessinit')
 _global_import('common')
 _global_import('gradproc')
+_global_import('argscope')
tensorpack/tfutils/argscope.py (new file, 0 → 100644, view file @ fbb73a8a)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argscope.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from contextlib import contextmanager
from collections import defaultdict
import inspect
import copy
import six

__all__ = ['argscope', 'get_arg_scope']

_ArgScopeStack = []

@contextmanager
def argscope(layers, **kwargs):
    param = kwargs
    if not isinstance(layers, list):
        layers = [layers]

    def _check_args_exist(l):
        args = inspect.getargspec(l).args
        for k, v in six.iteritems(param):
            assert k in args, "No argument {} in {}".format(k, l.__name__)

    for l in layers:
        assert hasattr(l, 'f'), "{} is not a registered layer".format(l.__name__)
        _check_args_exist(l.f)

    new_scope = copy.copy(get_arg_scope())
    for l in layers:
        new_scope[l.__name__].update(param)
    _ArgScopeStack.append(new_scope)
    yield
    del _ArgScopeStack[-1]

def get_arg_scope():
    if len(_ArgScopeStack) > 0:
        return _ArgScopeStack[-1]
    else:
        return defaultdict(dict)
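A small hedged sketch of how the new module behaves in use (layer and argument names are illustrative; the stack push/pop on entering and leaving the with-block, and the empty default when no scope is active, are what argscope and get_arg_scope above implement):

# With no active scope, get_arg_scope() returns an empty defaultdict(dict),
# so registered layers see no extra defaults.
assert get_arg_scope()['Conv2D'] == {}

# Entering argscope pushes a new scope; nesting stacks further defaults on top.
with argscope(Conv2D, nl=BNReLU(is_training), use_bias=False):
    with argscope(Conv2D, kernel_shape=3):
        l = Conv2D('conv0', image, out_channel=64)  # sees nl, use_bias and kernel_shape

# Leaving the blocks pops the scopes again.
assert get_arg_scope()['Conv2D'] == {}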