Shashank Suhas / seminar-breakout / Commits / 087dc382

Commit 087dc382 authored Jan 02, 2016 by ppwwyyxx

[WIP] tower training
parent 95037482

Showing 7 changed files with 82 additions and 30 deletions (+82 / -30):

    example_cifar10.py               +5   -3
    requirements.txt                 +0   -1
    tensorpack/callbacks/common.py   +3   -0
    tensorpack/callbacks/group.py    +0   -2
    tensorpack/train.py              +73  -23
    tensorpack/utils/concurrency.py  +1   -0
    tensorpack/utils/naming.py       +0   -1

example_cifar10.py

@@ -56,7 +56,7 @@ def get_model(inputs, is_training):
    y = one_hot(label, 10)
    cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
    cost = tf.reduce_mean(cost, name='cross_entropy_loss')
    tf.add_to_collection(COST_VARS_KEY, cost)
    tf.add_to_collection(SUMMARY_VARS_KEY, cost)

    # compute the number of failed samples, for ValidationError to use at test time
    wrong = tf.not_equal(
...
@@ -71,7 +71,7 @@ def get_model(inputs, is_training):
    wd_cost = tf.mul(1e-4, regularize_cost('fc.*/W', tf.nn.l2_loss), name='regularize_loss')
    tf.add_to_collection(COST_VARS_KEY, wd_cost)
    tf.add_to_collection(SUMMARY_VARS_KEY, wd_cost)

    add_histogram_summary('.*/W')   # monitor histogram of all W
    return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
...
@@ -105,6 +105,7 @@ def get_config():
    sess_config = get_default_sess_config()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess_config.device_count['GPU'] = 2

    # prepare model
    input_vars = [
...
@@ -149,6 +150,7 @@ if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            config = get_config()
            if args.load:
                config.session_init = SaverRestore(args.load)
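
For context on the cost construction above: the model's total cost is the cross-entropy loss plus a 1e-4-weighted L2 penalty over the fully-connected weights selected by the regex 'fc.*/W', and both terms are registered in the COST/SUMMARY collections before being summed with tf.add_n. Below is a minimal, framework-free sketch of that composition; the variable names, values, and the regularize_cost stand-in are made up for illustration and are not tensorpack's actual implementation.

import re

# Hypothetical per-variable L2 penalties, standing in for tf.nn.l2_loss(W).
l2_by_var = {
    'conv0/W': 12.5,
    'fc0/W': 40.0,
    'fc1/W': 8.0,
}
cross_entropy = 2.3   # made-up batch loss

def regularize_cost(pattern, l2_by_var):
    # Sum the penalty of every variable whose name matches the regex,
    # mirroring regularize_cost('fc.*/W', tf.nn.l2_loss) in the diff.
    return sum(v for name, v in l2_by_var.items() if re.match(pattern, name))

wd_cost = 1e-4 * regularize_cost('fc.*/W', l2_by_var)   # 1e-4 * (40.0 + 8.0) = 0.0048
total_cost = cross_entropy + wd_cost                     # analogous to tf.add_n([wd_cost, cost], name='cost')
print(wd_cost, total_cost)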

requirements.txt

pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-none-linux_x86_64.whl
termcolor
numpy
protobuf~=3.0.0a1
...

tensorpack/callbacks/common.py

@@ -5,6 +5,7 @@
import tensorflow as tf
import os
import re

from .base import Callback, PeriodicCallback
from ..utils import *
...
@@ -47,6 +48,8 @@ class SummaryWriter(Callback):
        summary_str = self.summary_op.eval()
        summary = tf.Summary.FromString(summary_str)
        for val in summary.value:
            #print val.tag
            val.tag = re.sub('tower[0-9]*/', '', val.tag)
            if val.tag in self.print_tag:
                assert val.WhichOneof('value') == 'simple_value', \
                    'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
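
The re.sub call in SummaryWriter strips any per-tower name-scope prefix from summary tags, so summaries produced inside the tower scopes keep printing and logging under their original names. A small standalone illustration (the tag strings here are hypothetical examples):

import re

for tag in ['tower0/cross_entropy_loss', 'tower1/regularize_loss', 'train_error']:
    # Same substitution as in SummaryWriter: drop a 'towerN/' scope if present.
    print(re.sub('tower[0-9]*/', '', tag))
# -> cross_entropy_loss
# -> regularize_loss
# -> train_error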

tensorpack/callbacks/group.py

@@ -32,8 +32,6 @@ def create_test_graph():
        for v in input_vars:
            Gtest.add_to_collection(INPUT_VARS_KEY, v)
        output_vars, cost = forward_func(input_vars, is_training=False)
        for v in output_vars:
            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
        yield Gtest

@contextmanager
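
create_test_graph above is a generator-based context manager: it builds a separate graph for testing by re-running the model's forward function with is_training=False and registering its inputs and outputs in collections. A rough sketch of the same shape using contextlib follows; the FakeGraph class, key strings, and lambda model are stand-ins invented for this example, not the tensorpack objects.

from contextlib import contextmanager

class FakeGraph:
    """Stand-in for tf.Graph with named collections."""
    def __init__(self):
        self.collections = {}
    def add_to_collection(self, key, value):
        self.collections.setdefault(key, []).append(value)

@contextmanager
def create_test_graph(forward_func, input_vars):
    g = FakeGraph()
    for v in input_vars:
        g.add_to_collection('INPUT_VARIABLES', v)
    output_vars, cost = forward_func(input_vars, is_training=False)
    for v in output_vars:
        g.add_to_collection('OUTPUT_VARIABLES', v)
    yield g

# usage: a dummy forward function returning (output_vars, cost)
with create_test_graph(lambda inp, is_training: (['prob', 'nr_wrong'], 'cost'), ['image', 'label']) as g:
    print(g.collections)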

tensorpack/train.py

@@ -5,6 +5,7 @@
import tensorflow as tf
from itertools import count
import copy
import argparse
import tqdm
...
@@ -71,20 +72,43 @@ class TrainConfig(object):
         assert self.step_per_epoch > 0 and self.max_epoch > 0
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

-def get_train_op(optimizer, cost_var):
-    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    avg_maintain_op = summary_moving_average(cost_var)  # maintain average in each step
-    with tf.control_dependencies([avg_maintain_op]):
-        grads = optimizer.compute_gradients(cost_var)
-        for grad, var in grads:
-            if grad:
-                tf.histogram_summary(var.op.name + '/gradients', grad)
-        return optimizer.apply_gradients(grads, global_step_var)
+def average_gradients(tower_grads):
+    """Calculate the average gradient for each shared variable across all towers.
+
+    Note that this function provides a synchronization point across all towers.
+
+    Args:
+        tower_grads: List of lists of (gradient, variable) tuples. The outer list
+            is over individual gradients. The inner list is over the gradient
+            calculation for each tower.
+    Returns:
+        List of pairs of (gradient, variable) where the gradient has been averaged
+        across all towers.
+    """
+    average_grads = []
+    for grad_and_vars in zip(*tower_grads):
+        # Note that each grad_and_vars looks like the following:
+        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+        grads = []
+        for g, _ in grad_and_vars:
+            # Add 0 dimension to the gradients to represent the tower.
+            expanded_g = tf.expand_dims(g, 0)
+            # Append on a 'tower' dimension which we will average over below.
+            grads.append(expanded_g)
+        # Average over the 'tower' dimension.
+        grad = tf.concat(0, grads)
+        grad = tf.reduce_mean(grad, 0)
+        # Keep in mind that the Variables are redundant because they are shared
+        # across towers. So .. we will just return the first tower's pointer to
+        # the Variable.
+        v = grad_and_vars[0][1]
+        grad_and_var = (grad, v)
+        average_grads.append(grad_and_var)
+    return average_grads

 def start_train(config):
     """
...
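
The new average_gradients helper mirrors the gradient-averaging function from TensorFlow's multi-GPU CIFAR-10 tutorial: tower_grads holds one list of (gradient, variable) pairs per tower, in the same variable order, and zip(*tower_grads) regroups them per variable before averaging along the tower dimension. A tiny numeric demonstration of that regrouping, with plain Python numbers and made-up variable names standing in for gradient tensors:

# One (gradient, variable_name) list per tower; both towers share the same variables.
tower0 = [(0.2, 'fc0/W'), (1.0, 'fc0/b')]
tower1 = [(0.4, 'fc0/W'), (3.0, 'fc0/b')]
tower_grads = [tower0, tower1]

average_grads = []
for grad_and_vars in zip(*tower_grads):
    # grad_and_vars == ((0.2, 'fc0/W'), (0.4, 'fc0/W')) on the first iteration
    grads = [g for g, _ in grad_and_vars]
    avg = sum(grads) / len(grads)       # tf.concat + tf.reduce_mean do this on tensors
    # The variable is shared across towers, so keep the first tower's reference.
    average_grads.append((avg, grad_and_vars[0][1]))

print(average_grads)   # [(~0.3, 'fc0/W'), (2.0, 'fc0/b')]
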
@@ -96,27 +120,53 @@ def start_train(config):
     input_queue = config.input_queue
     callbacks = config.callbacks

-    if config.batched_model_input:
-        enqueue_op = input_queue.enqueue(input_vars)
+    def get_model_inputs():
         model_inputs = input_queue.dequeue()
         for qv, v in zip(model_inputs, input_vars):
+            if config.batched_model_input:
                 qv.set_shape(v.get_shape())
-    else:
-        enqueue_op = input_queue.enqueue_many(input_vars)
-        model_inputs = input_queue.dequeue()
-        for qv, v in zip(model_inputs, input_vars):
-            qv.set_shape(v.get_shape().as_list()[1:])
+            else:
+                qv.set_shape(v.get_shape().as_list()[1:])
+        return model_inputs
+
+    if config.batched_model_input:
+        enqueue_op = input_queue.enqueue(input_vars)
+    else:
+        enqueue_op = input_queue.enqueue_many(input_vars)
+
+    keys_to_maintain = [tf.GraphKeys.SUMMARIES, SUMMARY_VARS_KEY]
+    olds = {}
+    for k in keys_to_maintain:
+        olds[k] = copy.copy(tf.get_collection(k))
+    all_grads = []
+    n_tower = 1
+    for i in range(n_tower):
+        with tf.device('/gpu:{}'.format(i)):
+            with tf.name_scope('tower{}'.format(i)):
+                for k in keys_to_maintain:
+                    del tf.get_collection(k)[:]
+                model_inputs = get_model_inputs()
+                output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
+                tf.get_variable_scope().reuse_variables()
+                grads = config.optimizer.compute_gradients(cost_var)
+                all_grads.append(grads)
+    for k in keys_to_maintain:
+        tf.get_collection(k).extend(olds[k])
+    grads = average_gradients(all_grads)
+    for grad, var in grads:
+        if grad:
+            tf.histogram_summary(var.op.name + '/gradients', grad)
+    avg_maintain_op = summary_moving_average(cost_var)

     # build graph
     tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
     for v in input_vars:
         tf.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
         tf.add_to_collection(OUTPUT_VARS_KEY, v)
     describe_model()

-    train_op = get_train_op(config.optimizer, cost_var)
+    # train_op = get_train_op(config.optimizer, cost_var)
+    with tf.control_dependencies([avg_maintain_op]):
+        train_op = config.optimizer.apply_gradients(grads, get_global_step_var())

     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
...
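
One detail worth noting in the tower loop above: the summary-related collections appear to be snapshotted into olds before the loop, cleared with del tf.get_collection(k)[:] before each tower is built, and restored with extend afterwards, so that duplicated per-tower summaries don't pile up and only the last tower's entries plus the pre-existing ones remain. The same bookkeeping on plain Python lists, as a sketch (the dict and entry strings stand in for the graph's collections):

import copy

collections = {'SUMMARIES': ['lr_summary']}        # entries that existed before building the towers
keys_to_maintain = ['SUMMARIES']

olds = {k: copy.copy(collections[k]) for k in keys_to_maintain}   # snapshot
for i in range(2):                                                # two towers
    for k in keys_to_maintain:
        del collections[k][:]                                     # clear before building this tower
    collections['SUMMARIES'].append('tower{}/cost_summary'.format(i))   # what building the model would add
for k in keys_to_maintain:
    collections[k].extend(olds[k])                                # restore the original entries

print(collections['SUMMARIES'])   # ['tower1/cost_summary', 'lr_summary']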

tensorpack/utils/concurrency.py

@@ -44,6 +44,7 @@ class EnqueueThread(threading.Thread):
                    return
                feed = dict(izip(self.input_vars, dp))
                self.sess.run([self.op], feed_dict=feed)
            #print '\nExauhsted!!!'
        except tf.errors.CancelledError as e:
            pass
        except Exception:
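
The hunk above shows the enqueue thread's producer loop: each datapoint is turned into a feed dict and pushed through the enqueue op, and a CancelledError raised when the queue is closed at shutdown is swallowed rather than reported as a failure. A rough standard-library sketch of the same shape; queue.Queue, the Event flag, and the local CancelledError class are stand-ins for the TensorFlow queue and tf.errors.CancelledError, not the real API.

import threading
import queue

class CancelledError(Exception):
    """Stand-in for tf.errors.CancelledError raised when the queue is closed."""

def enqueue_loop(q, datapoints, input_vars, cancelled):
    try:
        for dp in datapoints:
            if cancelled.is_set():
                raise CancelledError()
            feed = dict(zip(input_vars, dp))   # like dict(izip(self.input_vars, dp))
            q.put(feed)                        # like self.sess.run([self.op], feed_dict=feed)
    except CancelledError:
        pass                                   # shutting down: exit quietly, no error log
    except Exception as e:
        print('enqueue thread failed:', e)

q, cancelled = queue.Queue(), threading.Event()
t = threading.Thread(target=enqueue_loop,
                     args=(q, [([0.1], 1), ([0.2], 0)], ['image', 'label'], cancelled))
t.start(); t.join()
print(q.qsize())   # 2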

tensorpack/utils/naming.py

@@ -9,7 +9,6 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'

INPUT_VARS_KEY = 'INPUT_VARIABLES'
OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
COST_VARS_KEY = 'COST_VARIABLES'          # keep track of each individual cost
SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'    # extra variables to summarize during training

FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
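
These module-level constants name the graph collections used throughout the diff (for example the tf.add_to_collection(COST_VARS_KEY, ...) calls in example_cifar10.py). A minimal sketch of the registry pattern they support, with a plain dict standing in for TensorFlow's graph collections rather than the real tf.add_to_collection/tf.get_collection:

COST_VARS_KEY = 'COST_VARIABLES'
SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'

_collections = {}

def add_to_collection(key, value):          # analogous to tf.add_to_collection
    _collections.setdefault(key, []).append(value)

def get_collection(key):                    # analogous to tf.get_collection
    return _collections.get(key, [])

add_to_collection(COST_VARS_KEY, 'cross_entropy_loss')
add_to_collection(COST_VARS_KEY, 'regularize_loss')
add_to_collection(SUMMARY_VARS_KEY, 'train_error')

print(get_collection(COST_VARS_KEY))        # ['cross_entropy_loss', 'regularize_loss']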