Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8db5bcd3
Commit
8db5bcd3
authored
Dec 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
separate out input method
parent
fdab3db2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
188 additions
and
122 deletions
+188
-122
examples/GAN/GAN.py
examples/GAN/GAN.py
+6
-5
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
tensorpack/models/fc.py
tensorpack/models/fc.py
+2
-2
tensorpack/train/inputmethod.py
tensorpack/train/inputmethod.py
+134
-0
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+12
-11
tensorpack/train/queue.py
tensorpack/train/queue.py
+7
-93
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+26
-10
No files found.
examples/GAN/GAN.py
View file @
8db5bcd3
...
@@ -5,15 +5,15 @@
...
@@ -5,15 +5,15 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
numpy
as
np
import
numpy
as
np
from
tensorpack
import
(
QueueInputTrainer
Base
,
TowerContext
,
from
tensorpack
import
(
QueueInputTrainer
,
TowerContext
,
get_global_step_var
)
get_global_step_var
,
QueueInput
)
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
tensorpack.dataflow
import
DataFlow
from
tensorpack.dataflow
import
DataFlow
class
GANTrainer
(
QueueInputTrainer
Base
):
class
GANTrainer
(
QueueInputTrainer
):
def
__init__
(
self
,
config
,
g_vs_d
=
1
):
def
__init__
(
self
,
config
,
g_vs_d
=
1
):
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
self
.
_
build_enque_thread
(
)
self
.
_
input_method
=
QueueInput
(
config
.
dataset
)
if
g_vs_d
>
1
:
if
g_vs_d
>
1
:
self
.
_opt_g
=
g_vs_d
self
.
_opt_g
=
g_vs_d
self
.
_opt_d
=
1
self
.
_opt_d
=
1
...
@@ -22,8 +22,9 @@ class GANTrainer(QueueInputTrainerBase):
...
@@ -22,8 +22,9 @@ class GANTrainer(QueueInputTrainerBase):
self
.
_opt_d
=
int
(
1.0
/
g_vs_d
)
self
.
_opt_d
=
int
(
1.0
/
g_vs_d
)
def
_setup
(
self
):
def
_setup
(
self
):
super
(
GANTrainer
,
self
)
.
_setup
()
with
TowerContext
(
''
):
with
TowerContext
(
''
):
actual_inputs
=
self
.
_get_input_tensors
_noreuse
()
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
self
.
model
.
build_graph
(
actual_inputs
)
self
.
g_min
=
self
.
config
.
optimizer
.
minimize
(
self
.
model
.
g_loss
,
self
.
g_min
=
self
.
config
.
optimizer
.
minimize
(
self
.
model
.
g_loss
,
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
)
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
)
...
...
tensorpack/models/conv2d.py
View file @
8db5bcd3
...
@@ -44,7 +44,7 @@ def Conv2D(x, out_channel, kernel_shape,
...
@@ -44,7 +44,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride
=
shape4d
(
stride
)
stride
=
shape4d
(
stride
)
if
W_init
is
None
:
if
W_init
is
None
:
W_init
=
tf
.
contrib
.
layers
.
xavier_initializer_conv2d
()
W_init
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
()
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
b_init
=
tf
.
constant_initializer
()
...
...
tensorpack/models/fc.py
View file @
8db5bcd3
...
@@ -30,8 +30,8 @@ def FullyConnected(x, out_dim,
...
@@ -30,8 +30,8 @@ def FullyConnected(x, out_dim,
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
if
W_init
is
None
:
if
W_init
is
None
:
#W_init = tf.
truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim))
)
#W_init = tf.
uniform_unit_scaling_initializer(factor=1.43
)
W_init
=
tf
.
uniform_unit_scaling_initializer
(
factor
=
1.43
)
W_init
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
)
if
b_init
is
None
:
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
b_init
=
tf
.
constant_initializer
()
...
...
tensorpack/train/inputmethod.py
0 → 100644
View file @
8db5bcd3
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: inputmethod.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
threading
from
abc
import
ABCMeta
,
abstractmethod
from
..tfutils.summary
import
add_moving_summary
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
]
class InputMethod(object):
    """Base interface for the different ways a trainer obtains its input data."""
    __metaclass__ = ABCMeta  # Python-2 style abstract base class declaration
class FeedInput(InputMethod):
    """Input delivered through placeholder feeds, one datapoint at a time."""

    def __init__(self, ds):
        """:param ds: a DataFlow instance providing the datapoints."""
        self.ds = ds

    def size(self):
        # One epoch covers exactly the underlying DataFlow.
        return self.ds.size()

    def _setup(self, trainer):
        # Cache the model's input placeholders; they receive the feeds.
        self.input_vars = trainer.model.get_input_vars()
class FeedfreeInput(InputMethod):
    """Input read by the graph directly from tensors, without any feed_dict."""

    def get_input_tensors(self):
        """Public entry point; defers to the subclass hook below."""
        return self._get_input_tensors()

    @abstractmethod
    def _get_input_tensors(self):
        """
        always create and return a list of new input tensors
        """
class EnqueueThread(threading.Thread):
    """Daemon thread that pumps datapoints from a DataFlow into a TF queue.

    The thread stops when the trainer's coordinator requests a stop, and
    closes the queue on exit so pending dequeue ops are cancelled.
    """

    def __init__(self, trainer, queue, ds, input_placehdrs):
        """
        :param trainer: trainer object owning ``sess`` and ``coord``.
        :param queue: a TF queue accepting one datapoint per enqueue.
        :param ds: the DataFlow producing datapoints.
        :param input_placehdrs: placeholders matching each datapoint component.
        """
        super(EnqueueThread, self).__init__()
        self.name = 'EnqueueThread'
        self.daemon = True

        self.dataflow = ds
        self.queue = queue

        self.sess = trainer.sess
        self.coord = trainer.coord
        self.placehdrs = input_placehdrs

        # Build the ops once here; running them repeatedly in run() avoids
        # growing the graph per iteration.
        self.op = self.queue.enqueue(self.placehdrs)
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        self.size_op = self.queue.size()
        add_moving_summary(tf.cast(
            self.size_op, tf.float32, name='input_queue_size'))

    def run(self):
        self.dataflow.reset_state()
        with self.sess.as_default():
            try:
                while True:
                    for dp in self.dataflow.get_data():
                        if self.coord.should_stop():
                            return
                        feed = dict(zip(self.placehdrs, dp))
                        self.op.run(feed_dict=feed)
            except tf.errors.CancelledError:
                # Queue was closed (training is ending); exit quietly.
                # NOTE: the original bound the exception as `e` but never
                # used it; the binding is dropped.
                pass
            except Exception:
                logger.exception("Exception in EnqueueThread:")
            finally:
                self.coord.request_stop()
                try:
                    self.sess.run(self.close_op)
                except RuntimeError:
                    # session already closed
                    pass
                logger.info("Enqueue Thread Exited.")
class QueueInput(FeedfreeInput):
    """Feedfree input whose tensors are dequeued from a TF queue that an
    EnqueueThread fills asynchronously from a DataFlow."""

    def __init__(self, ds, queue=None):
        """
        :param ds: the DataFlow providing the data.
        :param queue: an optional TF queue; when None, a 50-element
            FIFOQueue is created in ``_setup``.
        """
        self.queue = queue
        self.ds = ds

    def size(self):
        return self.ds.size()

    def _setup(self, trainer):
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
            "QueueInput can only be used with input placeholders!"
        if self.queue is None:
            dtypes = [p.dtype for p in self.input_placehdrs]
            self.queue = tf.FIFOQueue(50, dtypes, name='input_queue')
        self.thread = EnqueueThread(
            trainer, self.queue, self.ds, self.input_placehdrs)
        # Let the trainer start (and manage) the enqueue thread.
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        dequeued = self.queue.dequeue(name='input_deque')
        if isinstance(dequeued, tf.Tensor):
            # A single-component queue returns a bare tensor; normalize to list.
            dequeued = [dequeued]
        assert len(dequeued) == len(self.input_placehdrs)
        # Dequeued tensors lose their static shape; restore it from the
        # corresponding placeholders.
        for tensor, placehdr in zip(dequeued, self.input_placehdrs):
            tensor.set_shape(placehdr.get_shape())
        return dequeued
class TensorInput(FeedfreeInput):
    """Feedfree input whose tensors come from a user-supplied function."""

    def __init__(self, get_tensor_fn, size=None):
        """
        :param get_tensor_fn: callable returning the list of input tensors.
        :param size: optional number of datapoints per epoch; ``size()``
            raises when this was not given.
        """
        self.get_tensor_fn = get_tensor_fn
        self._size = size

    def size(self):
        if self._size is None:
            raise ValueError("size of TensorInput is None!")
        return self._size

    def _setup(self, trainer):
        # Nothing to prepare: the tensors are built on demand.
        pass

    def _get_input_tensors(self):
        return self.get_tensor_fn()
class SplitTensorInput(FeedfreeInput):
    """Empty stub with no behavior yet.

    NOTE(review): the name suggests it will split input tensors (e.g. across
    towers) -- confirm the intended design before implementing or using it.
    """
tensorpack/train/multigpu.py
View file @
8db5bcd3
...
@@ -15,12 +15,13 @@ from ..tfutils import (backup_collection, restore_collection,
...
@@ -15,12 +15,13 @@ from ..tfutils import (backup_collection, restore_collection,
get_global_step_var
,
TowerContext
)
get_global_step_var
,
TowerContext
)
from
..tfutils.gradproc
import
apply_grad_processors
,
ScaleGradient
from
..tfutils.gradproc
import
apply_grad_processors
,
ScaleGradient
from
.trainer
import
FeedlessTrainer
,
SingleCostFeedlessTrainer
,
MultiPredictorTowerTrainer
from
.trainer
import
FeedfreeTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
from
.queue
import
QueueInputTrainer
,
QueueInputTrainerBase
from
.queue
import
QueueInputTrainer
from
.inputmethod
import
QueueInput
__all__
=
[
'AsyncMultiGPUTrainer'
,
'SyncMultiGPUTrainer'
]
__all__
=
[
'AsyncMultiGPUTrainer'
,
'SyncMultiGPUTrainer'
]
class
MultiGPUTrainer
(
Feed
less
Trainer
):
class
MultiGPUTrainer
(
Feed
free
Trainer
):
""" Base class for multi-gpu training"""
""" Base class for multi-gpu training"""
@
staticmethod
@
staticmethod
def
_multi_tower_grads
(
towers
,
get_tower_grad_func
):
def
_multi_tower_grads
(
towers
,
get_tower_grad_func
):
...
@@ -42,15 +43,14 @@ class MultiGPUTrainer(FeedlessTrainer):
...
@@ -42,15 +43,14 @@ class MultiGPUTrainer(FeedlessTrainer):
restore_collection
(
backup
)
restore_collection
(
backup
)
return
grad_list
return
grad_list
class
SyncMultiGPUTrainer
(
QueueInputTrainerBase
,
class
SyncMultiGPUTrainer
(
MultiGPUTrainer
,
MultiGPUTrainer
,
SingleCostFeedfreeTrainer
,
SingleCostFeedlessTrainer
,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_
build_enque_thread
(
input_queue
)
self
.
_
input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
):
def
_average_grads
(
tower_grads
):
...
@@ -75,6 +75,7 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
...
@@ -75,6 +75,7 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
return
ret
return
ret
def
_setup
(
self
):
def
_setup
(
self
):
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
...
@@ -87,9 +88,8 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
...
@@ -87,9 +88,8 @@ class SyncMultiGPUTrainer(QueueInputTrainerBase,
def
run_step
(
self
):
def
run_step
(
self
):
self
.
sess
.
run
(
self
.
train_op
)
self
.
sess
.
run
(
self
.
train_op
)
class
AsyncMultiGPUTrainer
(
QueueInputTrainerBase
,
class
AsyncMultiGPUTrainer
(
MultiGPUTrainer
,
MultiGPUTrainer
,
SingleCostFeedfreeTrainer
,
SingleCostFeedlessTrainer
,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
def
__init__
(
self
,
config
,
input_queue
=
None
,
input_queue
=
None
,
...
@@ -97,10 +97,11 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
...
@@ -97,10 +97,11 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
average_gradient
=
True
):
average_gradient
=
True
):
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_
build_enque_thread
(
input_queue
)
self
.
_
input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
self
.
average_gradient
=
average_gradient
self
.
average_gradient
=
average_gradient
def
_setup
(
self
):
def
_setup
(
self
):
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
gradprocs
=
self
.
model
.
get_gradient_processor
()
gradprocs
=
self
.
model
.
get_gradient_processor
()
...
...
tensorpack/train/queue.py
View file @
8db5bcd3
...
@@ -12,86 +12,14 @@ from ..tfutils import get_global_step_var, TowerContext
...
@@ -12,86 +12,14 @@ from ..tfutils import get_global_step_var, TowerContext
from
..utils
import
logger
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
.inputmethod
import
QueueInput
from
.trainer
import
(
Feed
less
Trainer
,
MultiPredictorTowerTrainer
,
from
.trainer
import
(
Feed
free
Trainer
,
MultiPredictorTowerTrainer
,
SingleCostFeed
less
Trainer
)
SingleCostFeed
free
Trainer
)
__all__
=
[
'QueueInputTrainer
Base'
,
'QueueInputTrainer
'
]
__all__
=
[
'QueueInputTrainer'
]
class
EnqueueThread
(
threading
.
Thread
):
class
QueueInputTrainer
(
MultiPredictorTowerTrainer
,
SingleCostFeedfreeTrainer
):
def
__init__
(
self
,
trainer
):
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
name
=
'EnqueueThread'
self
.
daemon
=
True
self
.
sess
=
trainer
.
sess
self
.
coord
=
trainer
.
coord
self
.
dataflow
=
RepeatedData
(
trainer
.
config
.
dataset
,
-
1
)
self
.
input_vars
=
trainer
.
input_vars
self
.
queue
=
trainer
.
input_queue
self
.
op
=
self
.
queue
.
enqueue
(
self
.
input_vars
)
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
self
.
size_op
=
self
.
queue
.
size
()
add_moving_summary
(
tf
.
cast
(
self
.
size_op
,
tf
.
float32
,
name
=
'input_queue_size'
))
def
run
(
self
):
self
.
dataflow
.
reset_state
()
with
self
.
sess
.
as_default
():
try
:
while
True
:
for
dp
in
self
.
dataflow
.
get_data
():
if
self
.
coord
.
should_stop
():
return
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
#print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self
.
op
.
run
(
feed_dict
=
feed
)
except
tf
.
errors
.
CancelledError
as
e
:
pass
except
Exception
:
logger
.
exception
(
"Exception in EnqueueThread:"
)
finally
:
self
.
coord
.
request_stop
()
try
:
self
.
sess
.
run
(
self
.
close_op
)
except
RuntimeError
:
# session already closed
pass
logger
.
info
(
"Enqueue Thread Exited."
)
class
QueueInputTrainerBase
(
FeedlessTrainer
):
def
_build_enque_thread
(
self
,
input_queue
=
None
):
""" create a thread that keeps filling the queue """
self
.
input_vars
=
self
.
model
.
get_input_vars
()
assert
len
(
self
.
input_vars
)
>
0
,
"QueueInput can only be used with input placeholders!"
if
input_queue
is
None
:
self
.
input_queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
input_vars
],
name
=
'input_queue'
)
else
:
self
.
input_queue
=
input_queue
input_th
=
EnqueueThread
(
self
)
self
.
config
.
callbacks
.
append
(
StartProcOrThread
(
input_th
))
def
_get_input_tensors_noreuse
(
self
):
""" Dequeue a datapoint from input_queue and return.
Can be called multiple times.
"""
ret
=
self
.
input_queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_vars
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_vars
):
qv
.
set_shape
(
v
.
get_shape
())
# test the overhead of queue
#with tf.device('/gpu:0'):
#ret = [tf.Variable(tf.random_normal([128,224,224,3],
#dtype=tf.float32), trainable=False),
#tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
return
ret
class
QueueInputTrainer
(
MultiPredictorTowerTrainer
,
QueueInputTrainerBase
,
SingleCostFeedlessTrainer
):
""" Single GPU Trainer, takes input from a queue"""
""" Single GPU Trainer, takes input from a queue"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
...
@@ -104,9 +32,10 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
...
@@ -104,9 +32,10 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
"""
"""
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_
build_enque_thread
(
input_queue
)
self
.
_
input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
def
_setup
(
self
):
def
_setup
(
self
):
super
(
QueueInputTrainer
,
self
)
.
_setup
()
assert
len
(
self
.
config
.
tower
)
==
1
,
\
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
with
TowerContext
(
''
):
with
TowerContext
(
''
):
...
@@ -119,18 +48,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
...
@@ -119,18 +48,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, QueueInputTrainerBase, Singl
# skip training
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
#self.train_op = tf.group(*self.dequed_inputs)
def
run_step
(
self
):
""" Simply run self.train_op"""
self
.
sess
.
run
(
self
.
train_op
)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
tensorpack/train/trainer.py
View file @
8db5bcd3
...
@@ -16,9 +16,10 @@ from ..tfutils import (get_tensors_by_names, freeze_collection,
...
@@ -16,9 +16,10 @@ from ..tfutils import (get_tensors_by_names, freeze_collection,
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
.inputmethod
import
FeedfreeInput
__all__
=
[
'SimpleTrainer'
,
'Feed
less
Trainer'
,
'MultiPredictorTowerTrainer'
,
__all__
=
[
'SimpleTrainer'
,
'Feed
free
Trainer'
,
'MultiPredictorTowerTrainer'
,
'SingleCostFeed
less
Trainer'
]
'SingleCostFeed
free
Trainer'
]
class
PredictorFactory
(
object
):
class
PredictorFactory
(
object
):
""" Make predictors for a trainer"""
""" Make predictors for a trainer"""
...
@@ -112,7 +113,7 @@ class MultiPredictorTowerTrainer(Trainer):
...
@@ -112,7 +113,7 @@ class MultiPredictorTowerTrainer(Trainer):
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
return
[
self
.
get_predict_func
(
input_names
,
output_names
,
k
)
for
k
in
range
(
n
)]
return
[
self
.
get_predict_func
(
input_names
,
output_names
,
k
)
for
k
in
range
(
n
)]
class
Feed
less
Trainer
(
Trainer
):
class
Feed
free
Trainer
(
Trainer
):
""" A trainer which runs iteration without feed_dict (therefore faster) """
""" A trainer which runs iteration without feed_dict (therefore faster) """
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
# need to run summary_op every epoch
# need to run summary_op every epoch
...
@@ -121,16 +122,17 @@ class FeedlessTrainer(Trainer):
...
@@ -121,16 +122,17 @@ class FeedlessTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
()
summary_str
=
self
.
summary_op
.
eval
()
self
.
_process_summary
(
summary_str
)
self
.
_process_summary
(
summary_str
)
def
_get_input_tensors_noreuse
(
self
):
def
_get_input_tensors
(
self
):
""" return a list of actual input tensors.
return
self
.
_input_method
.
get_input_tensors
()
Always return new tensors (for multi tower) if called mutliple times.
"""
def
_setup
(
self
):
pass
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
)
self
.
_input_method
.
_setup
(
self
)
class
SingleCostFeed
lessTrainer
(
Feedless
Trainer
):
class
SingleCostFeed
freeTrainer
(
Feedfree
Trainer
):
def
_get_cost_and_grad
(
self
):
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient on a new tower"""
""" get the cost and gradient on a new tower"""
actual_inputs
=
self
.
_get_input_tensors
_noreuse
()
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
self
.
model
.
build_graph
(
actual_inputs
)
cost_var
=
self
.
model
.
get_cost
()
cost_var
=
self
.
model
.
get_cost
()
# GATE_NONE faster?
# GATE_NONE faster?
...
@@ -139,3 +141,17 @@ class SingleCostFeedlessTrainer(FeedlessTrainer):
...
@@ -139,3 +141,17 @@ class SingleCostFeedlessTrainer(FeedlessTrainer):
add_moving_summary
(
cost_var
)
add_moving_summary
(
cost_var
)
return
cost_var
,
grads
return
cost_var
,
grads
def
run_step
(
self
):
""" Simply run self.train_op"""
self
.
sess
.
run
(
self
.
train_op
)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment