Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
99c70935
Commit
99c70935
authored
Dec 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move predict_tower into trainconfig
parent
48ef46aa
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
22 deletions
+68
-22
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+2
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+41
-11
tensorpack/train/config.py
tensorpack/train/config.py
+6
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+6
-3
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+13
-5
No files found.
examples/OpenAIGym/train-atari.py
View file @
99c70935
...
...
@@ -258,4 +258,5 @@ if __name__ == '__main__':
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
tower
=
train_tower
trainer
(
config
,
predict_tower
=
predict_tower
)
.
train
()
config
.
predict_tower
=
predict_tower
trainer
(
config
)
.
train
()
tensorpack/callbacks/inference_runner.py
View file @
99c70935
...
...
@@ -14,9 +14,21 @@ from .inference import Inferencer
from
.dispatcher
import
OutputTensorDispatcer
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
,
get_tqdm
from
..train.input_data
import
FeedfreeInput
__all__
=
[
'InferenceRunner'
]
def summary_inferencer(trainer, infs):
    """Write the scalar statistics produced by each inferencer as summaries.

    For every inferencer in ``infs``, ``after_inference()`` is called and each
    ``(name, value)`` pair of the returned mapping is converted to ``float``
    and passed to ``trainer.write_scalar_summary``.

    :param trainer: object providing ``write_scalar_summary(name, value)``.
    :param infs: iterable of inferencers providing ``after_inference()``.
    """
    for inf in infs:
        ret = inf.after_inference()
        if ret is None:
            # Nothing to report for this inferencer.
            continue
        for k, v in ret.items():
            try:
                v = float(v)
            except Exception:
                # Only conversion failures are treated as "non-scalar";
                # KeyboardInterrupt / SystemExit must still propagate.
                logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                continue
            trainer.write_scalar_summary(k, v)
class
InferenceRunner
(
Callback
):
"""
A callback that runs different kinds of inferencer.
...
...
@@ -31,14 +43,14 @@ class InferenceRunner(Callback):
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
"""
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
if
not
isinstance
(
infs
,
list
):
self
.
infs
=
[
infs
]
else
:
self
.
infs
=
infs
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
str
(
v
)
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_tensors
=
input_tensors
def
_setup_graph
(
self
):
...
...
@@ -96,12 +108,30 @@ class InferenceRunner(Callback):
self
.
_write_summary_after_inference
()
def _write_summary_after_inference(self):
    """Write the statistics of all inferencers as scalar summaries.

    Delegates to ``summary_inferencer`` with this callback's trainer and
    inferencers, so that ``after_inference()`` is invoked exactly once per
    inferencer and each statistic is written exactly once.
    """
    summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Callback):
    """Callback holding a feed-free input and a list of inferencers.

    Output-tensor lookup is not implemented yet: ``_find_output_tensors``
    is currently a no-op.
    """

    # Pair of (index, isOutput) fields.
    IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])

    def __init__(self, input, infs, input_tensors=None):
        """
        :param input: a ``FeedfreeInput`` instance.
        :param infs: an ``Inferencer`` or a list of ``Inferencer``.
        :param input_tensors: stored as ``input_tensor_names``; not used
            for filtering yet.
        """
        assert isinstance(input, FeedfreeInput), input
        self._input_data = input

        # Normalize a single inferencer into a one-element list.
        self.infs = infs if isinstance(infs, list) else [infs]
        for inferencer in self.infs:
            assert isinstance(inferencer, Inferencer), inferencer

        self.input_tensor_names = input_tensors

    def _setup_graph(self):
        """Set up the input with the trainer and fetch its input tensors."""
        input_data = self._input_data
        input_data._setup(self.trainer)
        # only 1 prediction tower will be used for inference
        self._input_tensors = input_data.get_input_tensors()
        # TODO filter by names
        self._find_output_tensors()

    def _find_output_tensors(self):
        """Placeholder; does nothing for now."""
        pass
tensorpack/train/config.py
View file @
99c70935
...
...
@@ -4,12 +4,12 @@
import
tensorflow
as
tf
from
..callbacks
import
Callbacks
from
..callbacks.group
import
Callbacks
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..utils
import
logger
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
from
..dataflow
import
DataFlow
from
.input_data
import
InputData
__all__
=
[
'TrainConfig'
]
...
...
@@ -35,6 +35,7 @@ class TrainConfig(object):
:param max_epoch: maximum number of epochs to run training. Defaults to inf.
:param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param predict_tower: list of prediction towers, given by their relative GPU ids (a single int is also accepted and wrapped into a list). Defaults to [0].
"""
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
...
...
@@ -81,6 +82,9 @@ class TrainConfig(object):
self
.
tower
=
kwargs
.
pop
(
'tower'
)
else
:
self
.
tower
=
[
0
]
self
.
predict_tower
=
kwargs
.
pop
(
'predict_tower'
,
[
0
])
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
# TODO deprecated @Dec20
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
...
...
tensorpack/train/feedfree.py
View file @
99c70935
...
...
@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
class
SimpleFeedfreeTrainer
(
MultiPredictorTowerTrainer
,
SingleCostFeedfreeTrainer
):
def
__init__
(
self
,
config
,
predict_tower
=
None
):
def
__init__
(
self
,
config
):
"""
A trainer with single cost, single training tower and feed-free input
config.data must exist
...
...
@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer(
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"SimpleFeedfreeTrainer doesn't support multigpu!"
...
...
@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
Use -1 for cpu.
"""
config
.
data
=
QueueInput
(
config
.
dataset
,
input_queue
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
,
predict_tower
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
tensorpack/train/multigpu.py
View file @
99c70935
...
...
@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
tf
.
test
.
is_gpu_available
()
...
...
@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
Non
e
,
average_gradient
=
Tru
e
):
average_gradient
=
Tru
e
,
predict_tower
=
Non
e
):
if
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
else
:
...
...
@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_average_gradient
=
average_gradient
assert
tf
.
test
.
is_gpu_available
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment