seminar-breakout · Commits

Commit 47f76e94
Authored Jun 08, 2016 by Yuxin Wu
Parent: e072d909

    fix everything (hopefully)

The commit moves the multi-GPU trainers out of tensorpack/train/trainer.py into a new tensorpack/train/multigpu.py, switches the multi-GPU example scripts over to SyncMultiGPUTrainer, and makes the prediction tower's `towerp` name prefix handled consistently by the model-saving, checkpoint-restoring, and summary code.

Showing 11 changed files with 179 additions and 171 deletions.
examples/Atari2600/DQN.py            +0    -4
examples/Inception/inception-bn.py   +6    -7
examples/ResNet/cifar10-resnet.py    +4    -4
examples/ResNet/svhn-resnet.py       +11   -20
tensorpack/callbacks/common.py       +1    -1
tensorpack/models/batch_norm.py      +14   -7
tensorpack/tfutils/sessinit.py       +3    -1
tensorpack/tfutils/summary.py        +1    -1
tensorpack/train/base.py             +1    -1
tensorpack/train/multigpu.py         +131  -0
tensorpack/train/trainer.py          +7    -125
examples/Atari2600/DQN.py

@@ -14,12 +14,8 @@ import multiprocessing, threading
 from collections import deque
 from tensorpack import *
-from tensorpack.models import *
-from tensorpack.utils import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.callbacks import *
 from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
examples/Inception/inception-bn.py

@@ -196,10 +196,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

-    with tf.Graph().as_default():
-        config = get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        if args.gpu:
-            config.nr_tower = len(args.gpu.split(','))
-        QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
examples/ResNet/cifar10-resnet.py

@@ -141,6 +141,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
+
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()

@@ -174,10 +178,6 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
-
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
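The logger setup that moves into get_config() derives the log directory from the script's own filename. A quick standalone sketch of what those two lines compute for this script:

    import os

    # e.g. for examples/ResNet/cifar10-resnet.py
    basename = os.path.basename('examples/ResNet/cifar10-resnet.py')
    print(os.path.join('train_log', basename[:basename.rfind('.')]))
    # -> train_log/cifar10-resnet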
examples/ResNet/svhn-resnet.py

@@ -8,16 +8,9 @@ import argparse
 import numpy as np
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *
+from tensorpack.dataflow import imgaug

 """
 ResNet-110 for SVHN Digit Classification.

@@ -151,6 +144,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
+
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()

@@ -184,18 +181,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
-            config = get_config()
-            if args.load:
-                config.session_init = SaverRestore(args.load)
-            if args.gpu:
-                config.nr_tower = len(args.gpu.split(','))
-            QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
tensorpack/callbacks/common.py

@@ -36,7 +36,7 @@ class ModelSaver(Callback):
         var_dict = {}
         for v in vars:
             name = v.op.name
-            if re.match('tower[1-9]', name):
+            if re.match('tower[p1-9]', name):
                 #logger.info("Skip {} when saving model.".format(name))
                 continue
             if 'tower0/' in name:
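The widened character class now skips prediction-tower variables (`towerp...`) in addition to replica towers `tower1`-`tower9`, while `tower0` still falls through to be saved. A standalone check of that behaviour (the variable names here are made up for illustration):

    import re

    for name in ['tower0/conv/W', 'tower3/conv/W', 'towerp0/conv/W']:
        skipped = bool(re.match('tower[p1-9]', name))
        print(name, '-> skipped' if skipped else '-> kept for saving')
    # tower0/conv/W  -> kept for saving
    # tower3/conv/W  -> skipped
    # towerp0/conv/W -> skipped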
tensorpack/models/batch_norm.py

@@ -56,6 +56,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         ema_apply_op = ema.apply([batch_mean, batch_var])
         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
     else:
+        # use training-statistics in prediction
         assert not use_local_stat
         # have to do this again to get actual name. see issue:
         # https://github.com/tensorflow/tensorflow/issues/2740

@@ -63,15 +64,21 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         ema_apply_op = ema.apply([batch_mean, batch_var])
         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-        mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
-        var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
-        #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
-        #logger.info("In prediction, using {} instead of {} for {}".format(
-            #mean_name, ema_mean.name, batch_mean.name))
         G = tf.get_default_graph()
-        ema_mean = G.get_tensor_by_name(mean_name)
-        ema_var = G.get_tensor_by_name(var_name)
+        try:
+            mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+            #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        except KeyError:
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+            #mean_name, ema_mean.name, batch_mean.name))

         if use_local_stat:
             with tf.control_dependencies([ema_apply_op]):
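The new try/except encodes a two-step lookup for the prediction tower: first try the EMA tensor with the `towerp.../` prefix stripped, and fall back to the `tower0/` copy if no such tensor exists. Roughly, the two name candidates are (the tensor name below is hypothetical):

    import re

    ema_name = 'towerp0/bn/mean/EMA:0'                   # hypothetical tensor name
    print(re.sub('towerp[0-9]+/', '', ema_name))         # bn/mean/EMA:0        (tried first)
    print(re.sub('towerp[0-9]+/', 'tower0/', ema_name))  # tower0/bn/mean/EMA:0 (KeyError fallback)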
tensorpack/tfutils/sessinit.py

@@ -106,8 +106,10 @@ class SaverRestore(SessionInit):
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
+            if 'towerp' in name:
+                logger.warn("Anything from prediction tower shouldn't be saved.")
             if 'tower' in name:
-                new_name = re.sub('tower[0-9]+/', '', name)
+                new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
             if name in vars_available:
                 var_dict[name].append(v)
tensorpack/tfutils/summary.py

@@ -90,7 +90,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        name = re.sub('tower[0-9]+/', '', c.op.name)
+        name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
tensorpack/train/base.py

@@ -88,7 +88,7 @@ class Trainer(object):
                 summary = tf.Summary.FromString(summary_str)
                 for val in summary.value:
                     if val.WhichOneof('value') == 'simple_value':
-                        val.tag = re.sub('tower[0-9]+/', '', val.tag)   # TODO move to subclasses
+                        val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                         self.stat_holder.add_stat(val.tag, val.simple_value)
                 self.summary_writer.add_summary(summary, self.global_step)
tensorpack/train/multigpu.py (new file, mode 100644, +131)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from six.moves import zip, range

from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *

from .trainer import QueueInputTrainer

__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

class MultiGPUTrainer(QueueInputTrainer):
    """ Base class for multi-gpu training"""
    def __init__(self, config, input_queue=None, predict_tower=None):
        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
        self.dequed_inputs = []

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]
            try:
                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            except:
                logger.error("Error while processing gradients of {}".format(v.name))
                raise
            ret.append((grad, v))
        return ret

    def _multi_tower_grads(self):
        logger.info("Training a model of {} tower".format(self.config.nr_tower))

        grad_list = []
        for i in range(self.config.nr_tower):
            with tf.device('/gpu:{}'.format(i)), \
                    tf.name_scope('tower{}'.format(i)) as scope:
                logger.info("Building graph for training tower {}...".format(i))
                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                self.dequed_inputs.append(model_inputs)
                self.model.build_graph(model_inputs, True)
                cost_var = self.model.get_cost() # build tower

                # TODO gate_gradienst=0 seems to be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if i == 0:
                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                    tf.get_variable_scope().reuse_variables()
                    # avoid repeated summary from each device
                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list

class SyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        grads = MultiGPUTrainer._average_grads(grad_list)
        grads = self.process_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()

class AsyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
        def scale(grads):
            return [(grad / self.config.nr_tower, var) for grad, var in grads]
        grad_list = map(scale, grad_list)
        grad_list = [self.process_grads(g) for g in grad_list]

        # use grad from the first tower for iteration in main thread
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
            summary_moving_average())
        describe_model()

        # prepare train_op for the rest of the towers
        self.training_threads = []
        for k in range(1, self.config.nr_tower):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
            f = lambda op=train_op: self.sess.run([op])   # avoid late-binding
            th = LoopThread(f)
            th.pause()
            th.start()
            self.training_threads.append(th)
        self.async_running = False

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()
        self.main_loop()

    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.training_threads: # resume all threads
                th.resume()
        super(AsyncMultiGPUTrainer, self).run_step()

    def _trigger_epoch(self):
        self.async_running = False
        for th in self.training_threads:
            th.pause()
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
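`_average_grads` regroups the per-tower `(grad, var)` lists by variable via `zip(*tower_grads)` and arithmetically averages the gradients. With plain numbers standing in for tensors, the reshaping it performs looks like this:

    # two towers, each holding gradients for variables 'W' and 'b'
    tower_grads = [
        [(1.0, 'W'), (2.0, 'b')],   # gradients computed on tower 0
        [(3.0, 'W'), (4.0, 'b')],   # gradients computed on tower 1
    ]
    # zip(*tower_grads) yields one group per variable, across towers
    averaged = [(sum(g for g, _ in gv) / float(len(tower_grads)), gv[0][1])
                for gv in zip(*tower_grads)]
    print(averaged)   # [(2.0, 'W'), (3.0, 'b')]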
tensorpack/train/trainer.py

@@ -5,20 +5,16 @@
 import tensorflow as tf
 import threading
 import time
-import re
-import functools
 from six.moves import zip

 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
-from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 from ..tfutils.modelutils import describe_model
 from ..tfutils import *

-__all__ = ['SimpleTrainer', 'QueueInputTrainer',
-           'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class SimpleTrainer(Trainer):
     def run_step(self):

@@ -110,6 +106,7 @@ class QueueInputTrainer(Trainer):
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
             Defaults to a FIFO queue of size 100.
+        :param predict_tower: list of gpu idx to run prediction. default to be [0].
         """
         super(QueueInputTrainer, self).__init__(config)
         self.input_vars = self.model.get_input_vars()

@@ -119,7 +116,7 @@ class QueueInputTrainer(Trainer):
         else:
             self.input_queue = input_queue
         if predict_tower is None:
-            # by default, use first training tower for prediction
+            # by default, use the first training gpu for prediction
             predict_tower = [0]
         self.predict_tower = predict_tower
         self.dequed_inputs = None

@@ -144,7 +141,7 @@ class QueueInputTrainer(Trainer):
         self.model.build_graph(inputs, False)

     def _single_tower_grad(self):
-        """ Get grad and cost for single-tower case"""
+        """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()

@@ -153,13 +150,14 @@ class QueueInputTrainer(Trainer):
         return grads

     def _build_enque_thread(self):
-        # create a thread that keeps filling the queue
+        """ create a thread that keeps filling the queue """
         enqueue_op = self.input_queue.enqueue(self.input_vars)
         self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
         self.extra_threads_procs.append(self.input_th)

     def train(self):
-        assert self.config.nr_tower == 1, "QueueInputTrainer only supports 1 tower!"
+        assert self.config.nr_tower == 1, \
+                "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
         self.init_session_and_coord()
         self._build_enque_thread()

@@ -207,119 +205,3 @@ class QueueInputTrainer(Trainer):
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]
-
-class MultiGPUTrainer(QueueInputTrainer):
-    """ Base class for multi-gpu training"""
-    def __init__(self, config, input_queue=None, predict_tower=None):
-        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
-        assert config.nr_tower > 1
-        self.dequed_inputs = []
-
-    @staticmethod
-    def _average_grads(tower_grads):
-        ret = []
-        for grad_and_vars in zip(*tower_grads):
-            v = grad_and_vars[0][1]
-            try:
-                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            except AssertionError:
-                logger.error("Error while processing gradients of {}".format(v.name))
-                raise
-            ret.append((grad, v))
-        return ret
-
-    def _multi_tower_grads(self):
-        logger.info("Training a model of {} tower".format(self.config.nr_tower))
-
-        grad_list = []
-        for i in range(self.config.nr_tower):
-            with tf.device('/gpu:{}'.format(i)), \
-                    tf.name_scope('tower{}'.format(i)) as scope:
-                logger.info("Building graph for training tower {}...".format(i))
-                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
-                self.dequed_inputs.append(model_inputs)
-                self.model.build_graph(model_inputs, True)
-                cost_var = self.model.get_cost() # build tower
-
-                # gate_gradienst=0 seems to be faster?
-                grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
-
-                if i == 0:
-                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-                    tf.get_variable_scope().reuse_variables()
-                    # avoid repeated summary from each device
-                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
-        restore_collection(backup)
-        return grad_list
-
-class SyncMultiGPUTrainer(MultiGPUTrainer):
-    def train(self):
-        self.init_session_and_coord()
-        self._build_enque_thread()
-
-        grad_list = self._multi_tower_grads()
-        grads = MultiGPUTrainer._average_grads(grad_list)
-        grads = self.process_grads(grads)
-
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average())
-        describe_model()
-        self._build_predict_tower()
-
-        # [debug]: do nothing in training
-        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
-        self.main_loop()
-
-class AsyncMultiGPUTrainer(MultiGPUTrainer):
-    def train(self):
-        self.init_session_and_coord()
-        self._build_enque_thread()
-
-        grad_list = self._multi_tower_grads()
-        # pretend to average the grads, in order to make async and
-        # sync have consistent effective learning rate
-        def scale(grads):
-            return [(grad / self.config.nr_tower, var) for grad, var in grads]
-        grad_list = map(scale, grad_list)
-        grad_list = [self.process_grads(g) for g in grad_list]
-        grads = grad_list[0]  # use grad from the first tower for the main iteration
-
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average())
-        describe_model()
-
-        # prepare train_op for the rest of the towers
-        self.threads = []
-        for k in range(1, self.config.nr_tower):
-            train_op = self.config.optimizer.apply_gradients(grad_list[k])
-            f = lambda op=train_op: self.sess.run([op])   # avoid late-binding
-            th = LoopThread(f)
-            th.pause()
-            th.start()
-            self.threads.append(th)
-        self.async_running = False
-
-        self._build_predict_tower()
-        # [debug]: do nothing in training
-        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
-        self.main_loop()
-
-    def run_step(self):
-        if not self.async_running:
-            self.async_running = True
-            for th in self.threads: # resume all threads
-                th.resume()
-        self.sess.run([self.train_op])      # faster since train_op return None
-
-    def _trigger_epoch(self):
-        self.async_running = False
-        for th in self.threads:
-            th.pause()
-        if self.summary_op is not None:
-            summary_str = self.summary_op.eval()
-            self._process_summary(summary_str)
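With the split, example scripts pick a trainer by tower count, as `inception-bn.py` and `svhn-resnet.py` now do. A minimal, hypothetical launcher following that pattern (assuming a `get_config()` like the example scripts define, and the trainers exported via `from tensorpack import *`):

    from tensorpack import *   # SyncMultiGPUTrainer, QueueInputTrainer, ...

    config = get_config()                     # hypothetical, script-specific
    if config.nr_tower > 1:
        SyncMultiGPUTrainer(config).train()   # or AsyncMultiGPUTrainer
    else:
        QueueInputTrainer(config).train()     # single tower only, per the new assert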