Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
99c70935
Commit
99c70935
authored
Dec 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move predict_tower into trainconfig
parent
48ef46aa
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
68 additions
and
22 deletions
+68
-22
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+2
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+41
-11
tensorpack/train/config.py
tensorpack/train/config.py
+6
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+6
-3
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+13
-5
No files found.
examples/OpenAIGym/train-atari.py
View file @
99c70935
...
@@ -258,4 +258,5 @@ if __name__ == '__main__':
...
@@ -258,4 +258,5 @@ if __name__ == '__main__':
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
tower
=
train_tower
config
.
tower
=
train_tower
trainer
(
config
,
predict_tower
=
predict_tower
)
.
train
()
config
.
predict_tower
=
predict_tower
trainer
(
config
)
.
train
()
tensorpack/callbacks/inference_runner.py
View file @
99c70935
...
@@ -14,9 +14,21 @@ from .inference import Inferencer
...
@@ -14,9 +14,21 @@ from .inference import Inferencer
from
.dispatcher
import
OutputTensorDispatcer
from
.dispatcher
import
OutputTensorDispatcer
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
,
get_tqdm
from
..utils
import
logger
,
get_tqdm
from
..train.input_data
import
FeedfreeInput
__all__
=
[
'InferenceRunner'
]
__all__
=
[
'InferenceRunner'
]
def
summary_inferencer
(
trainer
,
infs
):
for
inf
in
infs
:
ret
=
inf
.
after_inference
()
for
k
,
v
in
six
.
iteritems
(
ret
):
try
:
v
=
float
(
v
)
except
:
logger
.
warn
(
"{} returns a non-scalar statistics!"
.
format
(
type
(
inf
)
.
__name__
))
continue
trainer
.
write_scalar_summary
(
k
,
v
)
class
InferenceRunner
(
Callback
):
class
InferenceRunner
(
Callback
):
"""
"""
A callback that runs different kinds of inferencer.
A callback that runs different kinds of inferencer.
...
@@ -31,14 +43,14 @@ class InferenceRunner(Callback):
...
@@ -31,14 +43,14 @@ class InferenceRunner(Callback):
:param input_tensor_names: list of tensors to feed the dataflow to.
:param input_tensor_names: list of tensors to feed the dataflow to.
default to all the input placeholders.
default to all the input placeholders.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
self
.
ds
=
ds
if
not
isinstance
(
infs
,
list
):
if
not
isinstance
(
infs
,
list
):
self
.
infs
=
[
infs
]
self
.
infs
=
[
infs
]
else
:
else
:
self
.
infs
=
infs
self
.
infs
=
infs
for
v
in
self
.
infs
:
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
str
(
v
)
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_tensors
=
input_tensors
self
.
input_tensors
=
input_tensors
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
...
@@ -96,12 +108,30 @@ class InferenceRunner(Callback):
...
@@ -96,12 +108,30 @@ class InferenceRunner(Callback):
self
.
_write_summary_after_inference
()
self
.
_write_summary_after_inference
()
def
_write_summary_after_inference
(
self
):
def
_write_summary_after_inference
(
self
):
for
inf
in
self
.
infs
:
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
ret
=
inf
.
after_inference
()
for
k
,
v
in
six
.
iteritems
(
ret
):
class
FeedfreeInferenceRunner
(
Callback
):
try
:
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
v
=
float
(
v
)
except
:
def
__init__
(
self
,
input
,
infs
,
input_tensors
=
None
):
logger
.
warn
(
"{} returns a non-scalar statistics!"
.
format
(
type
(
inf
)
.
__name__
))
assert
isinstance
(
input
,
FeedfreeInput
),
input
continue
self
.
_input_data
=
input
self
.
trainer
.
write_scalar_summary
(
k
,
v
)
if
not
isinstance
(
infs
,
list
):
self
.
infs
=
[
infs
]
else
:
self
.
infs
=
infs
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_tensor_names
=
input_tensors
def
_setup_graph
(
self
):
self
.
_input_data
.
_setup
(
self
.
trainer
)
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
# TODO filter by names
self
.
_find_output_tensors
()
def
_find_output_tensors
(
self
):
pass
tensorpack/train/config.py
View file @
99c70935
...
@@ -4,12 +4,12 @@
...
@@ -4,12 +4,12 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..callbacks
import
Callbacks
from
..callbacks.group
import
Callbacks
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..models
import
ModelDesc
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
(
JustCurrentSession
,
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
get_default_sess_config
,
SessionInit
)
from
..dataflow
import
DataFlow
from
.input_data
import
InputData
from
.input_data
import
InputData
__all__
=
[
'TrainConfig'
]
__all__
=
[
'TrainConfig'
]
...
@@ -35,6 +35,7 @@ class TrainConfig(object):
...
@@ -35,6 +35,7 @@ class TrainConfig(object):
:param max_epoch: maximum number of epoch to run training. default to inf
:param max_epoch: maximum number of epoch to run training. default to inf
:param nr_tower: int. number of training towers. default to 1.
:param nr_tower: int. number of training towers. default to 1.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
:param predict_tower: list of prediction tower in their relative gpu id. Defaults to [0]
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
...
@@ -81,6 +82,9 @@ class TrainConfig(object):
...
@@ -81,6 +82,9 @@ class TrainConfig(object):
self
.
tower
=
kwargs
.
pop
(
'tower'
)
self
.
tower
=
kwargs
.
pop
(
'tower'
)
else
:
else
:
self
.
tower
=
[
0
]
self
.
tower
=
[
0
]
self
.
predict_tower
=
kwargs
.
pop
(
'predict_tower'
,
[
0
])
if
isinstance
(
self
.
predict_tower
,
int
):
self
.
predict_tower
=
[
self
.
predict_tower
]
# TODO deprecated @Dec20
# TODO deprecated @Dec20
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
...
...
tensorpack/train/feedfree.py
View file @
99c70935
...
@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
...
@@ -63,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
class
SimpleFeedfreeTrainer
(
class
SimpleFeedfreeTrainer
(
MultiPredictorTowerTrainer
,
MultiPredictorTowerTrainer
,
SingleCostFeedfreeTrainer
):
SingleCostFeedfreeTrainer
):
def
__init__
(
self
,
config
,
predict_tower
=
None
):
def
__init__
(
self
,
config
):
"""
"""
A trainer with single cost, single training tower and feed-free input
A trainer with single cost, single training tower and feed-free input
config.data must exists
config.data must exists
...
@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer(
...
@@ -71,7 +71,7 @@ class SimpleFeedfreeTrainer(
self
.
_input_method
=
config
.
data
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
assert
len
(
self
.
config
.
tower
)
==
1
,
\
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"SimpleFeedfreeTrainer doesn't support multigpu!"
"SimpleFeedfreeTrainer doesn't support multigpu!"
...
@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
...
@@ -99,6 +99,9 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
Use -1 for cpu.
Use -1 for cpu.
"""
"""
config
.
data
=
QueueInput
(
config
.
dataset
,
input_queue
)
config
.
data
=
QueueInput
(
config
.
dataset
,
input_queue
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
assert
len
(
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
,
predict_tower
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
tensorpack/train/multigpu.py
View file @
99c70935
...
@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -53,9 +53,13 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
else
:
else
:
self
.
_input_method
=
config
.
data
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
tf
.
test
.
is_gpu_available
()
assert
tf
.
test
.
is_gpu_available
()
...
@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -101,8 +105,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
def
__init__
(
self
,
config
,
input_queue
=
None
,
input_queue
=
None
,
predict_tower
=
Non
e
,
average_gradient
=
Tru
e
,
average_gradient
=
Tru
e
):
predict_tower
=
Non
e
):
if
hasattr
(
config
,
'dataset'
):
if
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
else
:
else
:
...
@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -110,7 +114,11 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_average_gradient
=
average_gradient
self
.
_average_gradient
=
average_gradient
assert
tf
.
test
.
is_gpu_available
()
assert
tf
.
test
.
is_gpu_available
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment