Commit 3f743301 authored Apr 24, 2016 by Yuxin Wu
async training.
parent 08821b55
Showing 3 changed files with 78 additions and 17 deletions (+78 -17)

    tensorpack/train/base.py           +2   -1
    tensorpack/train/trainer.py        +52  -15
    tensorpack/utils/concurrency.py    +24  -1
tensorpack/train/base.py

@@ -48,6 +48,7 @@ class Trainer(object):
     @abstractmethod
     def _trigger_epoch(self):
         """ This is called right after all steps in an epoch are finished"""
         pass

     def _init_summary(self):

@@ -94,7 +95,7 @@ class Trainer(object):
                     if self.coord.should_stop():
                         return
                     self.run_step()
-                    callbacks.trigger_step()
+                    #callbacks.trigger_step()   # not useful?
                     self.global_step += 1
                 self.trigger_epoch()
         except (KeyboardInterrupt, Exception):
tensorpack/train/trainer.py

@@ -11,6 +11,7 @@ from six.moves import zip
 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
+from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 from ..tfutils import *
@@ -79,13 +80,14 @@ class EnqueueThread(threading.Thread):
         finally:
             logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
     """
     Trainer which builds a FIFO queue for input.
     Support multi GPU.
     """
-    def __init__(self, config, input_queue=None):
+    def __init__(self, config, input_queue=None, async=False):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -98,6 +100,9 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
+        self.async = async
+        if self.async:
+            assert self.config.nr_tower > 1

     @staticmethod
     def _average_grads(tower_grads):
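`_average_grads` (declared just above) is not touched by this commit, so its body does not appear in the diff. For orientation, a minimal sketch of the usual per-variable averaging such a helper performs; this is an illustration, not the repository's actual implementation:

    # Sketch only, not the method body from the repository: average a list of
    # per-tower [(grad, var), ...] lists into a single [(grad, var), ...] list.
    import tensorflow as tf

    def average_grads_sketch(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):   # entries for one variable across all towers
            grads = [g for g, _ in grad_and_vars]
            v = grad_and_vars[0][1]               # the variable itself is shared between towers
            ret.append((tf.add_n(grads) / float(len(grads)), v))
        return ret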
@@ -122,14 +127,15 @@ class QueueInputTrainer(Trainer):
             qv.set_shape(v.get_shape())
         return ret

-    def _single_tower_grad_cost(self):
+    def _single_tower_grad(self):
         """ Get grad and cost for single-tower case"""
         model_inputs = self._get_model_inputs()
         cost_var = self.model.get_cost(model_inputs, is_training=True)
         grads = self.config.optimizer.compute_gradients(cost_var)
-        return (grads, cost_var)
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+        return grads

-    def _multi_tower_grad_cost(self):
+    def _multi_tower_grads(self):
         logger.info("Training a model of {} tower".format(self.config.nr_tower))

         # to avoid repeated summary from each device
@@ -140,6 +146,7 @@ class QueueInputTrainer(Trainer):
         for i in range(self.config.nr_tower):
             with tf.device('/gpu:{}'.format(i)), \
                     tf.name_scope('tower{}'.format(i)) as scope:
                 logger.info("Building graph for tower {}...".format(i))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                 cost_var = self.model.get_cost(model_inputs, is_training=True)    # build tower
@@ -148,30 +155,49 @@ class QueueInputTrainer(Trainer):
                     self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
                 if i == 0:
-                    cost_var_t0 = cost_var
+                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                     tf.get_variable_scope().reuse_variables()
                     for k in collect_dedup:
                         kept_summaries[k] = copy.copy(tf.get_collection(k))
             logger.info("Graph built for tower {}.".format(i))
         for k in collect_dedup:
             del tf.get_collection_ref(k)[:]
             tf.get_collection_ref(k).extend(kept_summaries[k])
-        grads = QueueInputTrainer._average_grads(grad_list)
-        return (grads, cost_var_t0)
+        return grad_list

     def train(self):
         enqueue_op = self.input_queue.enqueue(self.input_vars)
-        grads, cost_var = self._single_tower_grad_cost() \
-                if self.config.nr_tower == 0 else self._multi_tower_grad_cost()
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-        avg_maintain_op = summary_moving_average()
-        grads = self.process_grads(grads)
+        if self.config.nr_tower > 1:
+            grad_list = self._multi_tower_grads()
+            if not self.async:
+                grads = QueueInputTrainer._average_grads(grad_list)
+                grads = self.process_grads(grads)
+            else:
+                grad_list = [self.process_grads(g) for g in grad_list]
+                # pretend to average the grads, in order to make async and
+                # sync have consistent semantics
+                def scale(grads):
+                    return [(grad / self.config.nr_tower, var) for grad, var in grads]
+                grad_list = map(scale, grad_list)
+                grads = grad_list[0]    # use grad from the first tower for routinely stuff
+        else:
+            grads = self._single_tower_grad()
+            grads = self.process_grads(grads)

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            avg_maintain_op)
+            summary_moving_average())
+
+        if self.async:
+            self.threads = []
+            for k in range(1, self.config.nr_tower):
+                train_op = self.config.optimizer.apply_gradients(grad_list[k])
+                f = lambda: self.sess.run([train_op])
+                th = LoopThread(f)
+                th.pause()
+                th.start()
+                self.threads.append(th)
+            self.async_running = False

         self.init_session_and_coord()
         # create a thread that keeps filling the queue
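The `scale` helper in the async branch above divides every tower's gradients by `nr_tower`; as the in-code comment says, this keeps asynchronous and synchronous training semantically consistent. A small sketch with hypothetical numbers (not from the commit) showing that one scaled update from each tower adds up to the single averaged update the synchronous path applies:

    # Hypothetical gradient values for one variable on 4 towers.
    nr_tower = 4
    per_tower = [1.0, 2.0, 3.0, 4.0]

    sync_step = sum(per_tower) / nr_tower              # sync: apply the averaged gradient once
    async_steps = [g / nr_tower for g in per_tower]    # async: each tower applies its scaled gradient
    assert abs(sync_step - sum(async_steps)) < 1e-9    # net effect per "round" is the same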
@@ -183,14 +209,25 @@ class QueueInputTrainer(Trainer):
         self.input_th.start()

     def run_step(self):
+        if self.async:
+            if not self.async_running:
+                self.async_running = True
+                for th in self.threads:    # resume all threads
+                    th.resume()
         self.sess.run([self.train_op])    # faster since train_op return None

     def _trigger_epoch(self):
         # note that summary_op will take a data from the queue
+        if self.async:
+            self.async_running = False
+            for th in self.threads:
+                th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)

 def start_train(config):
     tr = QueueInputTrainer(config)
     tr.train()
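`start_train` above still builds the trainer with its defaults, so asynchronous mode has to be requested explicitly. A hedged usage sketch (the `TrainConfig` construction is hypothetical and omitted; the trainer API follows the diff above, written against the Python 2 era when `async` was not yet a keyword):

    # Hypothetical usage; make_config() stands in for whatever builds a TrainConfig
    # with nr_tower > 1, which __init__ above asserts when async=True.
    from tensorpack.train.trainer import QueueInputTrainer

    config = make_config()    # hypothetical helper, not in the repository
    trainer = QueueInputTrainer(config, async=True)
    trainer.train()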
tensorpack/utils/concurrency.py

@@ -9,7 +9,7 @@ import atexit
 import bisect
 import weakref

-__all__ = ['StoppableThread', 'ensure_proc_terminate',
+__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
            'OrderedResultGatherProc', 'OrderedContainer', 'DIE']

 class StoppableThread(threading.Thread):
@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread):
     def stopped(self):
         return self._stop.isSet()

+class LoopThread(threading.Thread):
+    """ A pausable thread that simply runs a loop"""
+    def __init__(self, func):
+        """
+        :param func: the function to run
+        """
+        super(LoopThread, self).__init__()
+        self.func = func
+        self.lock = threading.Lock()
+        self.daemon = True
+
+    def run(self):
+        while True:
+            self.lock.acquire()
+            self.lock.release()
+            self.func()
+
+    def pause(self):
+        self.lock.acquire()
+
+    def resume(self):
+        self.lock.release()
+
 class DIE(object):
     """ A placeholder class indicating end of queue """