Shashank Suhas / seminar-breakout / Commits / c279dbfe

Commit c279dbfe
Author:  Yuxin Wu
Date:    Nov 05, 2016
Parent:  82b418fd

    use _setup and small other refactors

Showing 5 changed files with 77 additions and 70 deletions (+77 -70):

    tensorpack/tfutils/gradproc.py   +20  -2
    tensorpack/tfutils/summary.py    +2   -3
    tensorpack/train/base.py         +8   -17
    tensorpack/train/multigpu.py     +24  -24
    tensorpack/train/trainer.py      +23  -24
tensorpack/tfutils/gradproc.py (+20 -2)

@@ -12,7 +12,23 @@ from .symbolic_functions import rms
 from .summary import add_moving_summary

 __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
-           'ScaleGradient', 'MapGradient']
+           'ScaleGradient', 'MapGradient', 'apply_grad_processors']
+
+
+def apply_grad_processors(grads, gradprocs):
+    """
+    :param grads: list of (grad, var).
+    :param gradprocs: list of `GradientProcessor` instances.
+    :returns: list of (grad, var) went through the processors
+    """
+    g = []
+    for grad, var in grads:
+        if grad is None:
+            logger.warn("No Gradient w.r.t {}".format(var.op.name))
+        else:
+            g.append((grad, var))
+    for proc in gradprocs:
+        g = proc.process(g)
+    return g
+

 class GradientProcessor(object):
     __metaclass__ = ABCMeta

@@ -98,12 +114,14 @@ class CheckGradient(MapGradient):
 class ScaleGradient(MapGradient):
     """
-    Scale gradient by a multiplier
+    Scale certain gradient by a multiplier
     """
     def __init__(self, multipliers):
         """
         :param multipliers: list of (regex, float)
         """
+        if not isinstance(multipliers, list):
+            multipliers = [multipliers]
         self.multipliers = multipliers
         super(ScaleGradient, self).__init__(self._mapper)
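The new module-level apply_grad_processors replaces the old Trainer.process_grads method, so gradient post-processing no longer needs a trainer instance. Below is a minimal usage sketch, not taken from the repository: the toy variable and cost are made up, and ScaleGradient is handed a single (regex, multiplier) tuple, which this commit now accepts without wrapping it in a list.

import tensorflow as tf
from tensorpack.tfutils.gradproc import apply_grad_processors, ScaleGradient, CheckGradient

# Toy variable and cost, only so there is something to differentiate.
W = tf.get_variable('fc0/W', shape=[10], initializer=tf.constant_initializer(1.0))
cost = tf.reduce_sum(tf.square(W))

opt = tf.train.GradientDescentOptimizer(0.01)
grads = opt.compute_gradients(cost)            # list of (grad, var) pairs
grads = apply_grad_processors(grads, [
    ScaleGradient(('fc.*/W', 0.1)),            # scale gradients of variables matching the regex
    CheckGradient(),                           # run a numeric sanity check on each gradient
])
train_op = opt.apply_gradients(grads)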
tensorpack/tfutils/summary.py (+2 -3)

@@ -98,11 +98,11 @@ def add_moving_summary(v, *args):
         v = [v]
     v.extend(args)
     for x in v:
+        assert x.get_shape().ndims == 0
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)

 def summary_moving_average():
-    """ Create a MovingAverage op and summary for all variables in
-    MOVING_SUMMARY_VARS_KEY.
+    """ Create a MovingAverage op and summary for all variables in MOVING_SUMMARY_VARS_KEY.
     :returns: a op to maintain these average.
     """
     with tf.name_scope('EMA_summary'):

@@ -113,7 +113,6 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        # TODO assert scalar
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
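add_moving_summary now asserts up front that everything registered for a moving average is a scalar, instead of deferring the check to a TODO in summary_moving_average. A rough sketch of how the pair is typically wired into a training op; the tf.no_op placeholder stands in for a real optimizer update and is not part of this commit.

import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary, summary_moving_average

W = tf.get_variable('W', shape=[4], initializer=tf.constant_initializer(0.5))
cost = tf.reduce_mean(tf.square(W))        # must be a scalar, or the new assert fires

add_moving_summary(cost)                   # registers cost under MOVING_SUMMARY_VARS_KEY
avg_maintain_op = summary_moving_average() # EMA update + scalar summaries for all registered tensors

# In the trainers below, the real optimizer step is grouped with this op.
train_op = tf.group(tf.no_op(name='optimize'), avg_maintain_op, name='train_op')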
tensorpack/train/base.py (+8 -17)

@@ -45,10 +45,10 @@ class Trainer(object):
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()

-    @abstractmethod
     def train(self):
         """ Start training"""
-        pass
+        self.setup()
+        self.main_loop()

     @abstractmethod
     def run_step(self):

@@ -92,7 +92,8 @@ class Trainer(object):
             create_summary(name, val), get_global_step())
         self.stat_holder.add_stat(name, val)

-    def finalize(self):
+    def setup(self):
+        self._setup()
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))

@@ -112,8 +113,11 @@ class Trainer(object):
         tf.train.start_queue_runners(
             sess=self.sess, coord=self.coord, daemon=True, start=True)

+    @abstractmethod
+    def _setup(self):
+        """ setup Trainer-specific stuff for training"""
+
     def main_loop(self):
-        self.finalize()
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:

@@ -139,16 +143,3 @@ class Trainer(object):
             self.coord.request_stop()
             self.summary_writer.close()
             self.sess.close()
-
-    def process_grads(self, grads):
-        g = []
-        for grad, var in grads:
-            if grad is None:
-                logger.warn("No Gradient w.r.t {}".format(var.op.name))
-            else:
-                g.append((grad, var))
-        procs = self.config.model.get_gradient_processor()
-        for proc in procs:
-            g = proc.process(g)
-        return g
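With this change, Trainer.train() becomes concrete (setup(), which calls the new abstract _setup() hook, followed by main_loop()), and gradient processing moves out of the class as apply_grad_processors. A sketch of the shape a subclass takes after the refactor; the class below is illustrative and not part of the commit.

from tensorpack.train.base import Trainer

class MyTrainer(Trainer):
    """ Hypothetical subclass: build the graph in _setup(), run one step in run_step(). """

    def _setup(self):
        # Build the training graph here and create self.train_op.
        # The base class setup() then wires up callbacks, summaries and queue runners.
        pass  # graph construction goes here

    def run_step(self):
        # One training iteration; called repeatedly by main_loop().
        self.sess.run(self.train_op)

# Callers keep doing trainer.train(); the base class now runs
# setup() -> _setup() -> main_loop() instead of each subclass doing it by hand.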
tensorpack/train/multigpu.py (+24 -24)

@@ -14,6 +14,7 @@ from ..tfutils.summary import summary_moving_average
 from ..tfutils.modelutils import describe_model
 from ..tfutils import (backup_collection, restore_collection,
         get_global_step_var, TowerContext)
+from ..tfutils.gradproc import apply_grad_processors
 from .trainer import QueueInputTrainer

@@ -32,11 +33,16 @@ class MultiGPUTrainer(QueueInputTrainer):
         with tf.name_scope('AvgGrad'):
             for grad_and_vars in zip(*tower_grads):
                 v = grad_and_vars[0][1]
-                for x in grad_and_vars:
-                    assert x[0] is not None, \
-                        "Gradient w.r.t {} is None!".format(v.name)
+                all_grad = [k[0] for k in grad_and_vars]
+
+                nones = list(set(all_grad))
+                if None in nones and len(nones) != 1:
+                    raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
+                elif nones[0] is None:
+                    logger.warn("No Gradient w.r.t {}".format(var.op.name))
+                    continue
                 try:
-                    grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
+                    grad = tf.add_n(all_grad) / float(len(tower_grads))
                 except:
                     logger.error("Error while processing gradients of {}".format(v.name))
                     raise

@@ -44,8 +50,7 @@ class MultiGPUTrainer(QueueInputTrainer):
         return ret

     def _multi_tower_grads(self):
         logger.info("Training a model of {} tower".format(
             len(self.config.tower)))
         grad_list = []
         global_scope = tf.get_variable_scope()

@@ -60,59 +65,54 @@ class MultiGPUTrainer(QueueInputTrainer):
                     self.model.build_graph(model_inputs)
                     cost_var = self.model.get_cost()  # build tower

-                    # TODO gate_gradienst=0 seems to be faster?
+                    # TODO gate_gradienst=0 might be faster?
                     grad_list.append(
                         self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                     if idx == 0:
-                        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
+                        add_moving_summary(cost_var)
                         # avoid repeated summary from each device
                         backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
         return grad_list


 class SyncMultiGPUTrainer(MultiGPUTrainer):
-    def train(self):
+    def _setup(self):
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()
         grads = MultiGPUTrainer._average_grads(grad_list)
-        grads = self.process_grads(grads)
+        grads = apply_grad_processors(grads,
+                self.model.get_gradient_processor())

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
         describe_model()

         # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
-        self.main_loop()


 class AsyncMultiGPUTrainer(MultiGPUTrainer):
-    def train(self):
+    def _setup(self):
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()
+        gradprocs = self.model.get_gradient_processor()
         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
-        def scale(grads):
-            return [(grad / len(self.config.tower) if grad is not None else None, var)
-                    for grad, var in grads]
-        grad_list = map(scale, grad_list)
-        grad_list = [self.process_grads(g) for g in grad_list]
+        gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower)))
+        with tf.name_scope('AsyncScaleGrad'):
+            grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]

         # use grad from the first tower for iteration in main thread
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
             summary_moving_average(), name='train_op')
         describe_model()
         self._start_async_threads(grad_list)
-        self.main_loop()

     def _start_async_threads(self, grad_list):
         # prepare train_op for the rest of the towers
         # itertools.count is atomic w.r.t. python threads

@@ -145,7 +145,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             async_step_total_cnt = int(re.findall(
                 '[0-9]+', self.async_step_counter.__str__())[0])
             self.write_scalar_summary(
-                'async_global_step', async_step_total_cnt)
+                'async-global-step', async_step_total_cnt)
         except:
-            pass
+            logger.exception("Cannot log async-global-step")
         super(AsyncMultiGPUTrainer, self)._trigger_epoch()
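The gradient-averaging logic now collects every tower's gradient for a variable up front and distinguishes "None everywhere" (skip the variable with a warning) from "None in only some towers" (an error). A standalone sketch of that per-variable rule, with a made-up function name and taken out of the trainer class:

import tensorflow as tf

def average_tower_grads(tower_grads):
    """tower_grads: one [(grad, var), ...] list per tower, all in the same variable order."""
    ret = []
    for grad_and_vars in zip(*tower_grads):       # every tower's entry for one variable
        v = grad_and_vars[0][1]
        all_grad = [g for g, _ in grad_and_vars]
        distinct = set(all_grad)
        if None in distinct and len(distinct) != 1:
            # a gradient exists in some towers but not in others: the tower graphs disagree
            raise RuntimeError(
                "Gradient w.r.t {} is None in some but not all towers!".format(v.name))
        if all_grad[0] is None:
            continue                              # no gradient in any tower: drop this variable
        ret.append((tf.add_n(all_grad) / float(len(tower_grads)), v))
    return ret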
tensorpack/train/trainer.py (+23 -24)

@@ -18,6 +18,7 @@ from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..callbacks.concurrency import StartProcOrThread
+from ..tfutils.gradproc import apply_grad_processors

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']

@@ -51,37 +52,39 @@ class PredictorFactory(object):
         # build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
             build_multi_tower_prediction_graph(
                 self.model, self.towers)
         self.tower_built = True

 class SimpleTrainer(Trainer):
+    def __init__(self, config):
+        super(SimpleTrainer, self).__init__(config)
+        self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
+
     def run_step(self):
         data = next(self.data_producer)
         feed = dict(zip(self.input_vars, data))
         self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None

-    def train(self):
+    def _setup(self):
         model = self.model
         self.input_vars = model.get_input_vars()
         with TowerContext(''):
             model.build_graph(self.input_vars)
-            cost_var = model.get_cost() # TODO assert scalar
+            cost_var = model.get_cost()
         add_moving_summary(cost_var)

         grads = self.config.optimizer.compute_gradients(cost_var)
-        grads = self.process_grads(grads)
+        grads = apply_grad_processors(grads,
+                self.model.get_gradient_processor())

-        avg_maintain_op = summary_moving_average()
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            avg_maintain_op)
+            summary_moving_average())
         describe_model()
         # create an infinte data producer
         self.config.dataset.reset_state()
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
-        self.main_loop()

     def _trigger_epoch(self):
         if self.summary_op is not None:

@@ -91,14 +94,14 @@ class SimpleTrainer(Trainer):
         self._process_summary(summary_str)

     def get_predict_func(self, input_names, output_names):
-        if not hasattr(self, 'predictor_factory'):
-            self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
-        return self.predictor_factory.get_predictor(input_names, output_names, 0)
+        return self._predictor_factory.get_predictor(input_names, output_names, 0)

 class EnqueueThread(threading.Thread):
     def __init__(self, trainer):
         super(EnqueueThread, self).__init__()
         self.name = 'EnqueueThread'
-        self.daemon = True

         self.sess = trainer.sess
         self.coord = trainer.coord
         self.dataflow = RepeatedData(trainer.config.dataset, -1)

@@ -109,7 +112,8 @@ class EnqueueThread(threading.Thread):
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
         self.size_op = self.queue.size()
+        self.daemon = True
+        add_moving_summary(tf.cast(
+            self.size_op, tf.float32, name='input_queue_size'))

     def run(self):
         self.dataflow.reset_state()

@@ -155,7 +159,9 @@ class QueueInputTrainer(Trainer):
         self.input_queue = input_queue

         # by default, use the first training gpu for prediction
-        self.predict_tower = predict_tower or [0]
+        predict_tower = predict_tower or [0]
+        self._predictor_factory = PredictorFactory(
+            self.sess, self.model, predict_tower)
         self.dequed_inputs = None

     def _get_dequeued_inputs(self):

@@ -171,8 +177,6 @@ class QueueInputTrainer(Trainer):
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
-        add_moving_summary(tf.cast(
-            self.input_queue.size(), tf.float32, name='input-queue-size'))

         # test the overhead of queue
         #with tf.device('/gpu:0'):

@@ -192,24 +196,22 @@ class QueueInputTrainer(Trainer):
         self.input_th = EnqueueThread(self)
         self.config.callbacks.append(StartProcOrThread(self.input_th))

-    def train(self):
+    def _setup(self):
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
         self._build_enque_thread()

         grads = self._single_tower_grad()
-        grads = self.process_grads(grads)
+        grads = apply_grad_processors(grads,
+                self.model.get_gradient_processor())
         describe_model()

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')

         # skip training
         #self.train_op = tf.group(*self.dequed_inputs)
-        self.main_loop()

     def run_step(self):
         """ Simply run self.train_op"""
         self.sess.run(self.train_op)

@@ -236,10 +238,7 @@ class QueueInputTrainer(Trainer):
         :param tower: return the kth predict_func
         :returns: an `OnlinePredictor`
         """
-        if not hasattr(self, 'predictor_factory'):
-            self.predictor_factory = PredictorFactory(
-                self.sess, self.model, self.predict_tower)
-        return self.predictor_factory.get_predictor(input_names, output_names, tower)
+        return self._predictor_factory.get_predictor(input_names, output_names, tower)

     def get_predict_funcs(self, input_names, output_names, n):
         return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
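SimpleTrainer and QueueInputTrainer now build their PredictorFactory once in __init__ (kept as self._predictor_factory) instead of creating it lazily inside get_predict_func. A brief usage sketch; the TrainConfig contents and the tensor names 'input' and 'prob' are placeholders, not taken from this commit.

from tensorpack.train.trainer import QueueInputTrainer

# config: a TrainConfig describing model, dataset, optimizer and callbacks (construction omitted)
trainer = QueueInputTrainer(config)

pred = trainer.get_predict_func(['input'], ['prob'])           # OnlinePredictor on the default tower
preds = trainer.get_predict_funcs(['input'], ['prob'], n=2)    # one predictor per tower, k = 0..n-1

trainer.train()   # now runs setup() -> _setup() -> main_loop() under the hood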