Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3f743301
You need to sign in or sign up before continuing.
Commit
3f743301
authored
Apr 24, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
async training.
parent
08821b55
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
17 deletions
+78
-17
tensorpack/train/base.py
tensorpack/train/base.py
+2
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+52
-15
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+24
-1
No files found.
tensorpack/train/base.py
View file @
3f743301
...
@@ -48,6 +48,7 @@ class Trainer(object):
...
@@ -48,6 +48,7 @@ class Trainer(object):
@
abstractmethod
@
abstractmethod
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
""" This is called right after all steps in an epoch are finished"""
pass
pass
def
_init_summary
(
self
):
def
_init_summary
(
self
):
...
@@ -94,7 +95,7 @@ class Trainer(object):
...
@@ -94,7 +95,7 @@ class Trainer(object):
if
self
.
coord
.
should_stop
():
if
self
.
coord
.
should_stop
():
return
return
self
.
run_step
()
self
.
run_step
()
callbacks
.
trigger_step
()
#callbacks.trigger_step() # not useful?
self
.
global_step
+=
1
self
.
global_step
+=
1
self
.
trigger_epoch
()
self
.
trigger_epoch
()
except
(
KeyboardInterrupt
,
Exception
):
except
(
KeyboardInterrupt
,
Exception
):
...
...
tensorpack/train/trainer.py
View file @
3f743301
...
@@ -11,6 +11,7 @@ from six.moves import zip
...
@@ -11,6 +11,7 @@ from six.moves import zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..dataflow.common
import
RepeatedData
from
..dataflow.common
import
RepeatedData
from
..utils
import
*
from
..utils
import
*
from
..utils.concurrency
import
LoopThread
from
..tfutils.summary
import
summary_moving_average
from
..tfutils.summary
import
summary_moving_average
from
..tfutils
import
*
from
..tfutils
import
*
...
@@ -79,13 +80,14 @@ class EnqueueThread(threading.Thread):
...
@@ -79,13 +80,14 @@ class EnqueueThread(threading.Thread):
finally
:
finally
:
logger
.
info
(
"Enqueue Thread Exited."
)
logger
.
info
(
"Enqueue Thread Exited."
)
class
QueueInputTrainer
(
Trainer
):
class
QueueInputTrainer
(
Trainer
):
"""
"""
Trainer which builds a FIFO queue for input.
Trainer which builds a FIFO queue for input.
Support multi GPU.
Support multi GPU.
"""
"""
def
__init__
(
self
,
config
,
input_queue
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
async
=
False
):
"""
"""
:param config: a `TrainConfig` instance
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
...
@@ -98,6 +100,9 @@ class QueueInputTrainer(Trainer):
...
@@ -98,6 +100,9 @@ class QueueInputTrainer(Trainer):
100
,
[
x
.
dtype
for
x
in
self
.
input_vars
],
name
=
'input_queue'
)
100
,
[
x
.
dtype
for
x
in
self
.
input_vars
],
name
=
'input_queue'
)
else
:
else
:
self
.
input_queue
=
input_queue
self
.
input_queue
=
input_queue
self
.
async
=
async
if
self
.
async
:
assert
self
.
config
.
nr_tower
>
1
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
):
def
_average_grads
(
tower_grads
):
...
@@ -122,14 +127,15 @@ class QueueInputTrainer(Trainer):
...
@@ -122,14 +127,15 @@ class QueueInputTrainer(Trainer):
qv
.
set_shape
(
v
.
get_shape
())
qv
.
set_shape
(
v
.
get_shape
())
return
ret
return
ret
def
_single_tower_grad
_cost
(
self
):
def
_single_tower_grad
(
self
):
""" Get grad and cost for single-tower case"""
""" Get grad and cost for single-tower case"""
model_inputs
=
self
.
_get_model_inputs
()
model_inputs
=
self
.
_get_model_inputs
()
cost_var
=
self
.
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
cost_var
=
self
.
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
grads
=
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
)
grads
=
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
)
return
(
grads
,
cost_var
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost_var
)
return
grads
def
_multi_tower_grad
_cost
(
self
):
def
_multi_tower_grad
s
(
self
):
logger
.
info
(
"Training a model of {} tower"
.
format
(
self
.
config
.
nr_tower
))
logger
.
info
(
"Training a model of {} tower"
.
format
(
self
.
config
.
nr_tower
))
# to avoid repeated summary from each device
# to avoid repeated summary from each device
...
@@ -140,6 +146,7 @@ class QueueInputTrainer(Trainer):
...
@@ -140,6 +146,7 @@ class QueueInputTrainer(Trainer):
for
i
in
range
(
self
.
config
.
nr_tower
):
for
i
in
range
(
self
.
config
.
nr_tower
):
with
tf
.
device
(
'/gpu:{}'
.
format
(
i
)),
\
with
tf
.
device
(
'/gpu:{}'
.
format
(
i
)),
\
tf
.
name_scope
(
'tower{}'
.
format
(
i
))
as
scope
:
tf
.
name_scope
(
'tower{}'
.
format
(
i
))
as
scope
:
logger
.
info
(
"Building graph for tower {}..."
.
format
(
i
))
model_inputs
=
self
.
_get_model_inputs
()
# each tower dequeue from input queue
model_inputs
=
self
.
_get_model_inputs
()
# each tower dequeue from input queue
cost_var
=
self
.
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
# build tower
cost_var
=
self
.
model
.
get_cost
(
model_inputs
,
is_training
=
True
)
# build tower
...
@@ -148,30 +155,49 @@ class QueueInputTrainer(Trainer):
...
@@ -148,30 +155,49 @@ class QueueInputTrainer(Trainer):
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
,
gate_gradients
=
0
))
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
,
gate_gradients
=
0
))
if
i
==
0
:
if
i
==
0
:
cost_var_t0
=
cost_var
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost_var
)
tf
.
get_variable_scope
()
.
reuse_variables
()
tf
.
get_variable_scope
()
.
reuse_variables
()
for
k
in
collect_dedup
:
for
k
in
collect_dedup
:
kept_summaries
[
k
]
=
copy
.
copy
(
tf
.
get_collection
(
k
))
kept_summaries
[
k
]
=
copy
.
copy
(
tf
.
get_collection
(
k
))
logger
.
info
(
"Graph built for tower {}."
.
format
(
i
))
for
k
in
collect_dedup
:
for
k
in
collect_dedup
:
del
tf
.
get_collection_ref
(
k
)[:]
del
tf
.
get_collection_ref
(
k
)[:]
tf
.
get_collection_ref
(
k
)
.
extend
(
kept_summaries
[
k
])
tf
.
get_collection_ref
(
k
)
.
extend
(
kept_summaries
[
k
])
grads
=
QueueInputTrainer
.
_average_grads
(
grad_list
)
return
grad_list
return
(
grads
,
cost_var_t0
)
def
train
(
self
):
def
train
(
self
):
enqueue_op
=
self
.
input_queue
.
enqueue
(
self
.
input_vars
)
enqueue_op
=
self
.
input_queue
.
enqueue
(
self
.
input_vars
)
grads
,
cost_var
=
self
.
_single_tower_grad_cost
()
\
if
self
.
config
.
nr_tower
>
1
:
if
self
.
config
.
nr_tower
==
0
else
self
.
_multi_tower_grad_cost
()
grad_list
=
self
.
_multi_tower_grads
()
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost_var
)
if
not
self
.
async
:
avg_maintain_op
=
summary_moving_average
()
grads
=
QueueInputTrainer
.
_average_grads
(
grad_list
)
grads
=
self
.
process_grads
(
grads
)
else
:
grad_list
=
[
self
.
process_grads
(
g
)
for
g
in
grad_list
]
# pretend to average the grads, in order to make async and
# sync have consistent semantics
def
scale
(
grads
):
return
[(
grad
/
self
.
config
.
nr_tower
,
var
)
for
grad
,
var
in
grads
]
grad_list
=
map
(
scale
,
grad_list
)
grads
=
grad_list
[
0
]
# use the grads from the first tower for the routine stuff
else
:
grads
=
self
.
_single_tower_grad
()
grads
=
self
.
process_grads
(
grads
)
grads
=
self
.
process_grads
(
grads
)
self
.
train_op
=
tf
.
group
(
self
.
train_op
=
tf
.
group
(
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
avg_maintain_op
)
summary_moving_average
())
if
self
.
async
:
self
.
threads
=
[]
for
k
in
range
(
1
,
self
.
config
.
nr_tower
):
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
k
])
f
=
lambda
:
self
.
sess
.
run
([
train_op
])
th
=
LoopThread
(
f
)
th
.
pause
()
th
.
start
()
self
.
threads
.
append
(
th
)
self
.
async_running
=
False
self
.
init_session_and_coord
()
self
.
init_session_and_coord
()
# create a thread that keeps filling the queue
# create a thread that keeps filling the queue
...
@@ -183,14 +209,25 @@ class QueueInputTrainer(Trainer):
...
@@ -183,14 +209,25 @@ class QueueInputTrainer(Trainer):
self
.
input_th
.
start
()
self
.
input_th
.
start
()
def
run_step
(
self
):
def
run_step
(
self
):
if
self
.
async
:
if
not
self
.
async_running
:
self
.
async_running
=
True
for
th
in
self
.
threads
:
# resume all threads
th
.
resume
()
self
.
sess
.
run
([
self
.
train_op
])
# faster since train_op return None
self
.
sess
.
run
([
self
.
train_op
])
# faster since train_op return None
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
# note that summary_op will take a data from the queue
# note that summary_op will take a data from the queue
if
self
.
async
:
self
.
async_running
=
False
for
th
in
self
.
threads
:
th
.
pause
()
if
self
.
summary_op
is
not
None
:
if
self
.
summary_op
is
not
None
:
summary_str
=
self
.
summary_op
.
eval
()
summary_str
=
self
.
summary_op
.
eval
()
self
.
_process_summary
(
summary_str
)
self
.
_process_summary
(
summary_str
)
def
start_train
(
config
):
def
start_train
(
config
):
tr
=
QueueInputTrainer
(
config
)
tr
=
QueueInputTrainer
(
config
)
tr
.
train
()
tr
.
train
()
tensorpack/utils/concurrency.py
View file @
3f743301
...
@@ -9,7 +9,7 @@ import atexit
...
@@ -9,7 +9,7 @@ import atexit
import
bisect
import
bisect
import
weakref
import
weakref
__all__
=
[
'StoppableThread'
,
'ensure_proc_terminate'
,
__all__
=
[
'StoppableThread'
,
'
LoopThread'
,
'
ensure_proc_terminate'
,
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
]
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
]
class
StoppableThread
(
threading
.
Thread
):
class
StoppableThread
(
threading
.
Thread
):
...
@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread):
...
@@ -23,6 +23,29 @@ class StoppableThread(threading.Thread):
def
stopped
(
self
):
def
stopped
(
self
):
return
self
.
_stop
.
isSet
()
return
self
.
_stop
.
isSet
()
class LoopThread(threading.Thread):
    """ A pausable thread that simply runs a loop.

    The thread is a daemon and calls ``func`` over and over, forever.
    Pausing is implemented with a lock used as a gate: the controlling
    thread holds ``self.lock`` while the loop is paused, and the loop has
    to pass through the lock before every call to ``func``.

    ``pause()`` only prevents *further* calls; a call to ``func`` that is
    already in flight is not interrupted and may still be running when
    ``pause()`` returns.

    ``pause()`` and ``resume()`` are idempotent, but they are meant to be
    driven from a single controlling thread (the ``_paused`` flag itself is
    not protected against concurrent controllers).
    """

    def __init__(self, func):
        """
        :param func: the function to run. Called with no arguments; its
            return value is ignored.
        """
        super(LoopThread, self).__init__()
        self.func = func
        # Gate for the loop: held by the controller while paused.
        self.lock = threading.Lock()
        # Tracks whether the controller currently holds ``self.lock``, so
        # that repeated pause()/resume() calls neither deadlock nor raise.
        self._paused = False
        self.daemon = True

    def run(self):
        """ Call ``func`` forever, blocking at the gate while paused. """
        while True:
            # Pass through the gate; this blocks for as long as the
            # controller holds the lock. The lock is released right away so
            # that pause() never has to wait for ``func`` to finish.
            with self.lock:
                pass
            self.func()

    def pause(self):
        """ Stop the loop from starting new calls to ``func``.

        Does nothing if the thread is already paused (a second acquire of
        the non-reentrant lock would otherwise block the caller forever).
        """
        if self._paused:
            return
        self.lock.acquire()
        self._paused = True

    def resume(self):
        """ Let the loop continue.

        Does nothing if the thread is not paused (releasing an unlocked
        lock would otherwise raise).
        """
        if not self._paused:
            return
        self._paused = False
        self.lock.release()
class
DIE
(
object
):
class
DIE
(
object
):
""" A placeholder class indicating end of queue """
""" A placeholder class indicating end of queue """
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment