seminar-breakout · Commit 35833026
Authored Feb 20, 2017 by Yuxin Wu
make MultiTowerPredictorTrainer a default + some doc change. (#156)
Parent: 2b076314

Showing 11 changed files, with 119 additions and 120 deletions:
docs/tutorial/efficient-dataflow.md    +4   -2
examples/PennTreebank/PTB-LSTM.py      +10  -8
examples/PennTreebank/README.md        +3   -1
examples/ResNet/README.md              +2   -0
setup.cfg                              +0   -1
setup.py                               +1   -0
tensorpack/train/base.py               +22  -13
tensorpack/train/feedfree.py           +7   -7
tensorpack/train/multigpu.py           +2   -7
tensorpack/train/predict.py            +65  -0
tensorpack/train/trainer.py            +3   -81
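In short: `PredictorFactory` moves out of `tensorpack/train/trainer.py` into the new `tensorpack/train/predict.py`, the base `Trainer` gains default `get_predict_func` / `get_predict_funcs` implementations built on it (assigning predictors to the configured predict towers round-robin), and the now-redundant `MultiPredictorTowerTrainer` mixin is deleted from `trainer.py` and dropped from the class hierarchies in `feedfree.py` and `multigpu.py`. The remaining files carry documentation and packaging tweaks.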
docs/tutorial/efficient-dataflow.md

@@ -19,6 +19,9 @@ memory (for caching) and CPU (for data processing).
 You'll definitely need to tune the parameters (#processes, #threads, size of buffer, etc.)
 or change the pipeline for new tasks and new machines to achieve best performance.
+This tutorial is quite complicated, because you do need this knowledge of hardware & systems to run fast on an ImageNet-sized dataset.
+However, for small datasets (e.g., several GBs), a proper prefetch should work well enough.
+
 ## Random Read
 We start from a simple DataFlow:
...
@@ -36,7 +39,6 @@ will concatenate the data into an `numpy.ndarray`, but since images are original
 On an SSD you probably can already observe good speed here (e.g. 5 it/s, that is 1280 samples/s), but on HDD the speed may be just 1 it/s,
 because we're doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
-Note that for smaller datasets, random read + prefetching is usually enough.
 We'll now add the cheapest pre-processing to get an ndarray in the end instead of a list
 (because TensorFlow will need ndarray eventually):
...
@@ -187,7 +189,7 @@ Let me summarize what the above DataFlow does:
 how the `Trainer` is implemented.
 The above DataFlow can run at a speed of 5~10 batches per second, if you have good CPUs, RAM, disks and augmentors.
-As a reference, tensorpack can train ResNet-18 (a shallow ResNet) at 4.4 batches (of 256 samples) per second on 4 old TitanX.
+As a reference, tensorpack can train ResNet-18 (a shallow ResNet) at 4.5 batches (of 256 samples) per second on 4 old TitanX.
 So DataFlow won't be a serious bottleneck if configured properly.

 ## Larger Datasets?
...
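As a rough illustration of the "proper prefetch" the added text recommends, a minimal pipeline in the DataFlow API of that era might look like the sketch below; the dataset path and parameter values are made up, and `nr_proc` should be tuned per machine:

    from tensorpack.dataflow import dataset, imgaug, AugmentImageComponent, BatchData, PrefetchDataZMQ

    # Random-read the dataset, apply cheap augmentation, batch,
    # then prefetch with several worker processes.
    ds = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)  # hypothetical path
    ds = AugmentImageComponent(ds, [imgaug.Resize((224, 224))])
    ds = BatchData(ds, 256)
    ds = PrefetchDataZMQ(ds, nr_proc=12)  # 12 workers is an assumption; tune it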
examples/PennTreebank/PTB-LSTM.py

@@ -74,23 +74,25 @@ class Model(ModelDesc):
         input_list = tf.unstack(input_feature, num=SEQ_LEN, axis=1)  # seqlen x (Bxhidden)
         outputs, last_state = rnn.static_rnn(cell, input_list, state_var, scope='rnn')
+        # update the hidden state after a rnn loop completes
+        update_state_ops = [
+            tf.assign(state_var[0].c, last_state[0].c),
+            tf.assign(state_var[0].h, last_state[0].h),
+            tf.assign(state_var[1].c, last_state[1].c),
+            tf.assign(state_var[1].h, last_state[1].h)]
+
         # seqlen x (Bxrnnsize)
         output = tf.reshape(tf.concat(outputs, 1), [-1, HIDDEN_SIZE])  # (Bxseqlen) x hidden
         logits = FullyConnected('fc', output, VOCAB_SIZE,
                                 nl=tf.identity, W_init=initializer, b_init=initializer)
         xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
             logits=logits, labels=tf.reshape(nextinput, [-1]))

-        update_state_op = tf.group(
-            tf.assign(state_var[0].c, last_state[0].c),
-            tf.assign(state_var[0].h, last_state[0].h),
-            tf.assign(state_var[1].c, last_state[1].c),
-            tf.assign(state_var[1].h, last_state[1].h), name='update_state')
-        with tf.control_dependencies([update_state_op]):
+        with tf.control_dependencies(update_state_ops):
             self.cost = tf.truediv(tf.reduce_sum(xent_loss),
                                    tf.cast(BATCH, tf.float32), name='cost')  # log-perplexity
         perpl = tf.exp(self.cost / SEQ_LEN, name='perplexity')
-        summary.add_moving_summary(perpl)
+        summary.add_moving_summary(perpl, self.cost)

     def reset_lstm_state(self):
         s = self.state
...
@@ -98,7 +100,7 @@ class Model(ModelDesc):
         return tf.group(s[0].c.assign(z),
                         s[0].h.assign(z),
                         s[1].c.assign(z),
-                        s[1].h.assign(z))
+                        s[1].h.assign(z), name='reset_lstm_state')

     def _get_optimizer(self):
         lr = symbolic_functions.get_scalar_var('learning_rate', 1, summary=True)
...
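The refactor above drops the named `tf.group` in favor of passing the list of assign ops straight to `tf.control_dependencies`, which accepts any list of ops or tensors; the intermediate group op bought nothing. A self-contained sketch of the pattern in the same TF1-era API (variable names invented):

    import tensorflow as tf

    v = tf.Variable(0.0)
    update_ops = [tf.assign(v, v + 1.0)]        # any list of ops works here
    with tf.control_dependencies(update_ops):   # update_ops run before `out`
        out = tf.identity(v)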
examples/PennTreebank/README.md

@@ -4,10 +4,12 @@
 This example is mainly to demonstrate:

 1. How to train an RNN with persistent state between iterations.
+   Here it simply manages the state inside the graph. `state_saving_rnn` can be used for more complicated use cases.
 2. How to use a TF reader pipeline instead of a DataFlow, for both training & inference.

 It trains a language model on the PTB dataset, basically an equivalent of the PTB example
-in [tensorflow/models](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb).
+in [tensorflow/models](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb) with its "medium" config.
 It has the same performance & speed as the original example as well.

 Note that the data pipeline is completely copied from the tensorflow example.
...
examples/ResNet/README.md

@@ -16,6 +16,8 @@ To train, just run:
 ```bash
 ./imagenet-resnet.py --data /path/to/original/ILSVRC --gpu 0,1,2,3 -d 18
 ```
+The speed is 1860 samples/s on 4 TitanX Pascal, and 1160 it/s on 4 old TitanX, provided that your data is fast
+enough. See the [tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html) on how to speed up your data.
...
setup.cfg

 [metadata]
-name = tensorpack
 author = TensorPack contributors
 author-email = ppwwyyxxc@gmail.com
 url = https://github.com/ppwwyyxx/tensorpack
...
setup.py

@@ -51,6 +51,7 @@ for s in scripts:
     scripts_to_install.append(newname)

 setup(
+    name='tensorpack',
     version=__version__,
     description='Neural Network Toolbox on TensorFlow',
     long_description=long_description,
...
tensorpack/train/base.py

@@ -10,8 +10,9 @@ import six
 from six.moves import range
 import tensorflow as tf

+from .predict import PredictorFactory
 from .config import TrainConfig
-from ..utils import logger
+from ..utils import logger, deprecated
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step_value
 from ..tfutils.modelutils import describe_model
...
@@ -62,7 +63,8 @@ class Trainer(object):
     @abstractmethod
     def run_step(self):
-        """ Abstract method. Run one iteration. """
+        """ Abstract method: run one iteration. Subclass should define what is "iteration".
+        """

     def trigger_epoch(self):
         """
...
@@ -102,7 +104,7 @@ class Trainer(object):
     def add_scalar_summary(self, name, val):
         """
-        Add a scalar sumary to both TF events file and StatHolder.
+        Add a scalar summary to both TF events file and StatHolder.
         """
         self.add_summary(create_scalar_summary(name, val))
...
@@ -187,22 +189,29 @@ class Trainer(object):
         self.summary_writer.close()
         self.monitored_sess.close()

-    def get_predict_func(self, input_names, output_names):
+    def get_predict_func(self, input_names, output_names, tower=0):
         """
         Args:
             input_names (list), output_names(list): list of names
+            tower (int): return the predictor on the kth tower, defined by ``config.predict_tower``.
         Returns:
-            an OnlinePredictor
+            an :class:`OnlinePredictor`.
         """
-        raise NotImplementedError()
+        if not hasattr(self, '_predictor_factory'):
+            self._predictor_factory = PredictorFactory(
+                self.model, self.config.predict_tower)
+        return self._predictor_factory.get_predictor(input_names, output_names, tower)

     def get_predict_funcs(self, input_names, output_names, n):
-        """ Return n predictors.
-            Can be overwritten by subclasses to exploit more
-            parallelism among predictors.
-        """
-        if len(self.config.predict_tower) > 1:
-            logger.warn(
-                "[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
-        return [self.get_predict_func(input_names, output_names) for k in range(n)]
+        """ Return n predictors. """
+        nr_tower = len(self.config.predict_tower)
+        if nr_tower < n:
+            logger.warn(
+                "Requested {} predictor but only have {} predict towers! "
+                "Predictors will be assigned to GPUs in round-robin.".format(n, nr_tower))
+        return [self.get_predict_func(input_names, output_names, k % nr_tower) for k in range(n)]
+
+    @deprecated("Don't need to call it any more!", "2017-03-20")
+    def _setup_predictor_factory(self):
+        pass
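With this default in place, any trainer can hand out predictors without the mixin. A usage fragment under assumed names (`'input'` / `'prob'` are hypothetical tensor names; `trainer` is any constructed `Trainer` whose config sets `predict_tower`):

    # kth predict tower, as defined by config.predict_tower:
    pred = trainer.get_predict_func(['input'], ['prob'], tower=0)
    # Requesting more predictors than towers now round-robins them
    # across the available towers (and logs a warning):
    preds = trainer.get_predict_funcs(['input'], ['prob'], n=4)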
tensorpack/train/feedfree.py

@@ -10,7 +10,6 @@ from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_data import QueueInput, FeedfreeInput
 from .base import Trainer
-from .trainer import MultiPredictorTowerTrainer

 __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
            'SimpleFeedfreeTrainer', 'QueueInputTrainer']
...
@@ -40,7 +39,7 @@ class FeedfreeTrainerBase(Trainer):
         self._input_method._setup(self)

     def run_step(self):
-        """ Simply run ``self.train_op``, which minimizes the cost."""
+        """ Simply run ``self.train_op``."""
         self.hooked_sess.run(self.train_op)
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
...
@@ -75,9 +74,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         return cost, grads

-class SimpleFeedfreeTrainer(
-        SingleCostFeedfreeTrainer,
-        MultiPredictorTowerTrainer):
+class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
     """
     A trainer with single cost, single training tower, any number of
     prediction towers, and feed-free input.
...
@@ -92,7 +89,6 @@ class SimpleFeedfreeTrainer(
         self._input_method = config.data
         assert isinstance(self._input_method, FeedfreeInput), self._input_method
         super(SimpleFeedfreeTrainer, self).__init__(config)
-        self._setup_predictor_factory()
         assert len(self.config.tower) == 1, \
             "SimpleFeedfreeTrainer doesn't support multigpu!"
...
@@ -116,7 +112,11 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
         input_queue (tf.QueueBase): an input queue. Defaults to the
             :class:`QueueInput` default.
     """
-    config.data = QueueInput(config.dataflow, input_queue)
+    if config.dataflow is not None:
+        config.data = QueueInput(config.dataflow, input_queue)
+    else:
+        assert isinstance(config.data, QueueInput), config.data
     if predict_tower is not None:
         log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
         config.predict_tower = predict_tower
...
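The new `if config.dataflow is not None` branch means `QueueInputTrainer` also accepts a config whose `data` is already a `QueueInput`. A sketch, assuming these import paths and a pre-built `my_dataflow`:

    from tensorpack.train import QueueInputTrainer
    from tensorpack.train.input_data import QueueInput  # path assumed from this diff

    # Either let the trainer wrap config.dataflow in a QueueInput...
    trainer = QueueInputTrainer(config)
    # ...or pre-wrap it yourself and leave config.dataflow unset:
    config.data = QueueInput(my_dataflow)
    config.dataflow = None
    trainer = QueueInputTrainer(config)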
tensorpack/train/multigpu.py

@@ -16,7 +16,6 @@ from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .base import Trainer
-from .trainer import MultiPredictorTowerTrainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_data import QueueInput
...
@@ -68,8 +67,7 @@ class MultiGPUTrainer(Trainer):
 class SyncMultiGPUTrainer(MultiGPUTrainer,
-                          SingleCostFeedfreeTrainer,
-                          MultiPredictorTowerTrainer):
+                          SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer which synchronizes the gradients computed
     from each tower and averages them.
...
@@ -97,7 +95,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             config.predict_tower = predict_tower
         super(SyncMultiGPUTrainer, self).__init__(config)
-        self._setup_predictor_factory()
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         assert tf.test.is_gpu_available()
         self.average_cost = average_cost
...
@@ -158,8 +155,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
 class AsyncMultiGPUTrainer(MultiGPUTrainer,
-                           SingleCostFeedfreeTrainer,
-                           MultiPredictorTowerTrainer):
+                           SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
     asynchronously updates the model without locking.
...
@@ -187,7 +183,6 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
             log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower
-        self._setup_predictor_factory()
         self._scale_gradient = scale_gradient
         assert tf.test.is_gpu_available()
...
tensorpack/train/predict.py (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils.collection import freeze_collection
from ..tfutils import get_tensors_by_names, get_op_tensor_name
from ..predict import OnlinePredictor, build_prediction_graph

__all__ = ['PredictorFactory']


class PredictorFactory(object):
    """ Make predictors for a trainer """

    def __init__(self, model, towers):
        """
        Args:
            towers (list[int]): list of gpu id
        """
        self.model = model
        self.towers = towers
        self.tower_built = False

    def get_predictor(self, input_names, output_names, tower):
        """
        Args:
            tower (int): need the kth tower (not the gpu id)
        Returns:
            an online predictor (which has to be used under a default session)
        """
        if not self.tower_built:
            self._build_predict_tower()
        tower = self.towers[tower]

        placeholder_names = set([k.name for k in self.model.get_inputs_desc()])

        def get_name_in_tower(name):
            return PREDICT_TOWER + str(tower) + '/' + name

        def maybe_inside_tower(name):
            name = get_op_tensor_name(name)[0]
            if name in placeholder_names:
                return name
            else:
                return get_name_in_tower(name)

        input_names = map(maybe_inside_tower, input_names)
        raw_input_vars = get_tensors_by_names(input_names)
        output_names = map(get_name_in_tower, output_names)
        output_vars = get_tensors_by_names(output_names)
        return OnlinePredictor(raw_input_vars, output_vars)

    def _build_predict_tower(self):
        # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
        with tf.name_scope(None), \
                freeze_collection(SUMMARY_BACKUP_KEYS), \
                tf.variable_scope(tf.get_variable_scope(), reuse=True):
            def fn(_):
                self.model.build_graph(self.model.get_reused_placehdrs())
            build_prediction_graph(fn, self.towers)
        self.tower_built = True
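A short sketch of how the factory is used (tensor names hypothetical; per the docstring, the returned `OnlinePredictor` must be called under a default session):

    # `model` is a ModelDesc; `towers` are gpu ids, e.g. from config.predict_tower.
    factory = PredictorFactory(model, towers=[0, 1])
    pred = factory.get_predictor(['input'], ['prob'], tower=0)  # kth tower, not gpu id
    prob = pred(input_batch)  # call under a default session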
tensorpack/train/trainer.py

@@ -2,69 +2,13 @@
 # File: trainer.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import tensorflow as tf
-
 from .base import Trainer
-from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
-from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
-from ..tfutils.collection import freeze_collection
-from ..predict import OnlinePredictor, build_prediction_graph
+from ..tfutils import TowerContext
 from .input_data import FeedInput
+from .predict import PredictorFactory

-__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
+__all__ = ['SimpleTrainer']

-class PredictorFactory(object):
-    """ Make predictors for a trainer """
-
-    def __init__(self, model, towers):
-        """
-        :param towers: list of gpu relative id
-        """
-        self.model = model
-        self.towers = towers
-        self.tower_built = False
-
-    def get_predictor(self, input_names, output_names, tower):
-        """
-        Args:
-            tower: need the kth tower (not the gpu id)
-        Returns:
-            an online predictor (which has to be used under a default session)
-        """
-        if not self.tower_built:
-            self._build_predict_tower()
-        tower = self.towers[tower % len(self.towers)]
-
-        placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
-
-        def get_name_in_tower(name):
-            return PREDICT_TOWER + str(tower) + '/' + name
-
-        def maybe_inside_tower(name):
-            name = get_op_tensor_name(name)[0]
-            if name in placeholder_names:
-                return name
-            else:
-                return get_name_in_tower(name)
-
-        input_names = map(maybe_inside_tower, input_names)
-        raw_input_vars = get_tensors_by_names(input_names)
-        output_names = map(get_name_in_tower, output_names)
-        output_vars = get_tensors_by_names(output_names)
-        return OnlinePredictor(raw_input_vars, output_vars)
-
-    def _build_predict_tower(self):
-        # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
-        with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS), \
-                tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            def fn(_):
-                self.model.build_graph(self.model.get_reused_placehdrs())
-            build_prediction_graph(fn, self.towers)
-        self.tower_built = True

 class SimpleTrainer(Trainer):
...
@@ -103,25 +47,3 @@ class SimpleTrainer(Trainer):
     def get_predict_func(self, input_names, output_names):
         return self._predictor_factory.get_predictor(input_names, output_names, 0)

-
-class MultiPredictorTowerTrainer(Trainer):
-    """ A trainer with possibly multiple prediction towers """
-
-    def _setup_predictor_factory(self):
-        # by default, use the first training gpu for prediction
-        self._predictor_factory = PredictorFactory(
-            self.model, self.config.predict_tower)
-
-    def get_predict_func(self, input_names, output_names, tower=0):
-        """
-        Args:
-            tower (int): return the kth predict_func
-        Returns:
-            an OnlinePredictor instance
-        """
-        return self._predictor_factory.get_predictor(input_names, output_names, tower)
-
-    def get_predict_funcs(self, input_names, output_names, n):
-        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]