Shashank Suhas / seminar-breakout · Commits

Commit 9952c6c6, authored Feb 18, 2016 by Yuxin Wu

    add trainer with shared mainloop

Parent: de37446f

Showing 1 changed file with 177 additions and 105 deletions.

tensorpack/train.py  (+177, -105)
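The commit's core change is structural: the monolithic start_train function is split into a Trainer base class that owns one shared main_loop, with subclasses supplying only train() and run_step(). A minimal sketch of that pattern (illustrative stand-in names, not tensorpack code):

    from abc import ABCMeta, abstractmethod

    class LoopBase(object):
        __metaclass__ = ABCMeta      # Python 2 idiom, matching the diff below

        @abstractmethod
        def run_step(self):
            pass

        def main_loop(self, n_steps):
            # the loop logic lives once, in the base class
            for _ in range(n_steps):
                self.run_step()

    class PrintStep(LoopBase):
        def run_step(self):
            print("one step")

    PrintStep().main_loop(3)         # prints "one step" three times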
@@ -9,6 +9,7 @@ import copy
 import argparse
 import re
 import tqdm
+from abc import ABCMeta, abstractmethod
 from .models import ModelDesc
 from .dataflow.common import RepeatedData
@@ -63,14 +64,6 @@ class TrainConfig(object):
         self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
-def average_grads(tower_grads):
-    ret = []
-    for grad_and_vars in zip(*tower_grads):
-        grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-        v = grad_and_vars[0][1]
-        ret.append((grad, v))
-    return ret
-
 def summary_grads(grads):
     for grad, var in grads:
         if grad:
@@ -95,17 +88,117 @@ def scale_grads(grads, multiplier):
             ret.append((grad, var))
     return ret
 
-def start_train(config):
-    """
-    Start training with a config
-    Config: a `TrainConfig` instance
-    """
-    model = config.model
+class Trainer(object):
+    __metaclass__ = ABCMeta
+
+    def __init__(self, config):
+        """
+        Args:
+            config: a TrainConfig instance
+        """
+        assert isinstance(config, TrainConfig), type(config)
+        self.config = config
+        tf.add_to_collection(MODEL_KEY, config.model)
+
+    @abstractmethod
+    def train(self):
+        pass
+
+    @abstractmethod
+    def run_step(self):
+        pass
+
+    def main_loop(self):
+        callbacks = self.config.callbacks
+        with self.sess.as_default():
+            try:
+                logger.info("Start training with global_step={}".format(get_global_step()))
+                callbacks.before_train()
+                tf.get_default_graph().finalize()
+                for epoch in xrange(1, self.config.max_epoch):
+                    with timed_operation(
+                            'Epoch {}, global_step={}'.format(
+                                epoch, get_global_step() + self.config.step_per_epoch)):
+                        for step in tqdm.trange(
+                                self.config.step_per_epoch,
+                                leave=True, mininterval=0.5,
+                                dynamic_ncols=True, ascii=True):
+                            if self.coord.should_stop():
+                                return
+                            self.run_step()
+                            callbacks.trigger_step()
+                        # note that summary_op will take a data from the queue
+                        callbacks.trigger_epoch()
+            except (KeyboardInterrupt, Exception):
+                raise
+            finally:
+                self.coord.request_stop()
+                # Do I need to run queue.close?
+                callbacks.after_train()
+                self.sess.close()
+
+    def init_session_and_coord(self):
+        self.sess = tf.Session(config=self.config.session_config)
+        self.config.session_init.init(self.sess)
+
+        # start training:
+        self.coord = tf.train.Coordinator()
+        tf.train.start_queue_runners(
+            sess=self.sess, coord=self.coord, daemon=True, start=True)
+
+
+class SimpleTrainer(Trainer):
+    def run_step(self):
+        try:
+            data = next(self.data_producer)
+        except StopIteration:
+            self.data_producer = self.config.dataset.get_data()
+            data = next(self.data_producer)
+        feed = dict(zip(self.input_vars, data))
+        self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None
+
+    def train(self):
+        model = self.config.model
+        input_vars = model.get_input_vars()
+        self.input_vars = input_vars
+        cost_var = model.get_cost(input_vars, is_training=True)
+        avg_maintain_op = summary_moving_average(cost_var)
+
+        grads = self.config.optimizer.compute_gradients(cost_var)
+        check_grads(grads)
+        grads = scale_grads(grads, model.get_lr_multiplier())
+        summary_grads(grads)
+
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            avg_maintain_op)
+
+        describe_model()
+        self.init_session_and_coord()
+        self.data_producer = self.config.dataset.get_data()
+        self.main_loop()
+
+
+class QueueInputTrainer(Trainer):
+    """ Trainer which builds a queue for input.
+        Support multi GPU.
+    """
+
+    @staticmethod
+    def _average_grads(tower_grads):
+        ret = []
+        for grad_and_vars in zip(*tower_grads):
+            grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+            v = grad_and_vars[0][1]
+            ret.append((grad, v))
+        return ret
+
+    def train(self):
+        model = self.config.model
         input_vars = model.get_input_vars()
         input_queue = model.get_input_queue()
-    callbacks = config.callbacks
-    tf.add_to_collection(MODEL_KEY, model)
         enqueue_op = input_queue.enqueue(input_vars)
 
         def get_model_inputs():
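SimpleTrainer.run_step above builds its feed_dict by zipping the model's input placeholders against one datapoint from the dataflow. A toy illustration of that pairing with stand-in values (not real placeholders):

    input_vars = ['input:0', 'label:0']   # stand-ins for the model's placeholders
    data = [[0.1, 0.2, 0.3], 7]           # one datapoint: one value per placeholder

    feed = dict(zip(input_vars, data))
    print(feed)                           # {'input:0': [0.1, 0.2, 0.3], 'label:0': 7}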
@@ -117,19 +210,19 @@ def start_train(config):
             return model_inputs
 
         # get gradients to update:
-    if config.nr_tower > 1:
-        logger.info("Training a model of {} tower".format(config.nr_tower))
+        if self.config.nr_tower > 1:
+            logger.info("Training a model of {} tower".format(self.config.nr_tower))
             # to avoid repeated summary from each device
             coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
             kept_summaries = {}
             grad_list = []
-        for i in range(config.nr_tower):
+            for i in range(self.config.nr_tower):
                 with tf.device('/gpu:{}'.format(i)), \
                         tf.name_scope('tower{}'.format(i)) as scope:
                     model_inputs = get_model_inputs()
                     cost_var = model.get_cost(model_inputs, is_training=True)
                     grad_list.append(
-                    config.optimizer.compute_gradients(cost_var))
+                        self.config.optimizer.compute_gradients(cost_var))
                     if i == 0:
                         tf.get_variable_scope().reuse_variables()
@@ -138,60 +231,39 @@ def start_train(config):
             for k in coll_keys:
                 del tf.get_collection(k)[:]
                 tf.get_collection(k).extend(kept_summaries[k])
-        grads = average_grads(grad_list)
-    else:
-        model_inputs = get_model_inputs()
-        cost_var = model.get_cost(model_inputs, is_training=True)
-        grads = config.optimizer.compute_gradients(cost_var)
-    avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?
+            grads = QueueInputTrainer._average_grads(grad_list)
+        else:
+            model_inputs = get_model_inputs()
+            cost_var = model.get_cost(model_inputs, is_training=True)
+            grads = self.config.optimizer.compute_gradients(cost_var)
+        avg_maintain_op = summary_moving_average(cost_var)  # TODO(multigpu) average the cost from each device?
 
-    check_grads(grads)
-    grads = scale_grads(grads, model.get_lr_multiplier())
-    summary_grads(grads)
+        check_grads(grads)
+        grads = scale_grads(grads, model.get_lr_multiplier())
+        summary_grads(grads)
 
-    train_op = tf.group(
-        config.optimizer.apply_gradients(grads, get_global_step_var()),
-        avg_maintain_op)
+        self.train_op = tf.group(
+            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
+            avg_maintain_op)
 
         describe_model()
 
-    sess = tf.Session(config=config.session_config)
-    config.session_init.init(sess)
-
-    # start training:
-    coord = tf.train.Coordinator()
-    tf.train.start_queue_runners(
-        sess=sess, coord=coord, daemon=True, start=True)
+        self.init_session_and_coord()
 
         # create a thread that keeps filling the queue
-    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
-    input_th.start()
+        input_th = EnqueueThread(self.sess, self.coord, enqueue_op, self.config.dataset, input_queue)
+        input_th.start()
+        self.main_loop()
 
-    with sess.as_default():
-        try:
-            logger.info("Start training with global_step={}".format(get_global_step()))
-            callbacks.before_train()
-            tf.get_default_graph().finalize()
-            for epoch in xrange(1, config.max_epoch):
-                with timed_operation(
-                        'Epoch {}, global_step={}'.format(
-                            epoch, get_global_step() + config.step_per_epoch)):
-                    for step in tqdm.trange(
-                            config.step_per_epoch,
-                            leave=True, mininterval=0.5,
-                            dynamic_ncols=True, ascii=True):
-                        if coord.should_stop():
-                            return
-                        sess.run([train_op])    # faster since train_op return None
-                        callbacks.trigger_step()
-                    # note that summary_op will take a data from the queue
-                    callbacks.trigger_epoch()
-        except (KeyboardInterrupt, Exception):
-            raise
-        finally:
-            coord.request_stop()
-            # Do I need to run queue.close?
-            callbacks.after_train()
-            sess.close()
+    def run_step(self):
+        self.sess.run([self.train_op])    # faster since train_op return None
+
+
+def start_train(config):
+    #if config.model.get_input_queue() is not None:
+        ## XXX get_input_queue is called twice
+        #tr = QueueInputTrainer()
+    #else:
+        #tr = SimpleTrainer()
+    #tr = SimpleTrainer(config)
+    tr = QueueInputTrainer(config)
+    tr.train()
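For context, the public entry point start_train(config) keeps its signature; only the internals change. A self-contained sketch of the kwargs-validation idiom TrainConfig uses for optional arguments like nr_tower (StubConfig is a stand-in, not tensorpack's class):

    class StubConfig(object):
        def __init__(self, **kwargs):
            self.max_epoch = kwargs.pop('max_epoch')
            self.step_per_epoch = kwargs.pop('step_per_epoch')
            # optional arguments are popped with a default, exactly as in the diff:
            self.nr_tower = int(kwargs.pop('nr_tower', 1))
            # anything left over is a typo on the caller's side:
            assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

    config = StubConfig(max_epoch=50, step_per_epoch=100, nr_tower=2)
    print(config.nr_tower)   # 2 -- a value > 1 selects the multi-tower GPU path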