Commit 4e644290 authored Oct 19, 2017 by Yuxin Wu
Parent: fff3f2d3

Add a function to apply general prefetch policies

Showing 2 changed files with 33 additions and 27 deletions:

  tensorpack/trainv2/base.py       +4  -2
  tensorpack/trainv2/interface.py  +29 -25
tensorpack/trainv2/base.py

@@ -239,8 +239,10 @@ class SingleCostTrainer(Trainer):
         Args:
             inputs_desc ([InputDesc]):
             input (InputSource):
-            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tenosrs and return a cost tensor
-            get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
+            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
+                Might get called multiple times for data-parallel training or inference.
+            get_opt_fn (-> tf.train.Optimizer): callable which returns an
+                optimizer. Will only be called once.
 
         Returns:
             [Callback]: a (possibly empty) list of callbacks needed for training.
...
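The clarified docstring draws a real contract distinction: get_cost_fn may be called once per tower under data-parallel training, so it must tolerate repeated graph construction, while get_opt_fn is called exactly once and its optimizer is shared by all towers. A minimal sketch of such a pair (the network, tensor names, and learning rate are illustrative, not from this commit; the TF 1.x API current in late 2017 is assumed):

    import tensorflow as tf

    def get_cost_fn(image, label):
        # [tf.Tensor] -> tf.Tensor: may run once per tower, so it only
        # builds graph ops; variable sharing across towers is the
        # trainer's responsibility.
        flat = tf.reshape(image, [-1, 28 * 28])
        logits = tf.layers.dense(flat, 10, name='fc')
        return tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)

    def get_opt_fn():
        # -> tf.train.Optimizer: called exactly once; all towers share it.
        return tf.train.AdamOptimizer(1e-3)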
tensorpack/trainv2/interface.py

@@ -5,18 +5,35 @@
 import tensorflow as tf
 
 from ..input_source import (
-    FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
+    InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
 from ..train.config import TrainConfig
 from .base import SingleCostTrainer
 from .trainers import SimpleTrainer, DistributedTrainerReplicated
 
-__all__ = ['launch_train_with_config', 'TrainConfig']
+__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
 
 
-def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
-    # seem to only improve on >1 GPUs
-    if len(towers) > 1 and gpu_prefetch:
+def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
+    """
+    Apply a set of default rules to make a fast :class:`InputSource`.
+
+    Args:
+        input_source_or_dataflow (InputSource | DataFlow):
+        trainer (Trainer):
+        towers ([int]): list of GPU ids.
+    """
+    if not isinstance(input_source_or_dataflow, InputSource):
+        # to mimic same behavior of the old trainer interface
+        if type(trainer) == SimpleTrainer:
+            input = FeedInput(input_source_or_dataflow)
+        else:
+            input = QueueInput(input_source_or_dataflow)
+    else:
+        input = input_source_or_dataflow
+    if len(towers) > 1:
+        # seem to only improve on >1 GPUs
+        assert not isinstance(trainer, SimpleTrainer)
         assert tf.test.is_gpu_available()
 
         if not isinstance(input, (StagingInputWrapper, DummyConstantInput)):
...
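Since apply_default_prefetch is now exported through __all__, the same wrapping rules can be applied outside launch_train_with_config. A hedged usage sketch, assuming the trainv2 package layout of this commit and a no-argument SimpleTrainer constructor; FakeData merely stands in for a real DataFlow:

    from tensorpack.dataflow import FakeData
    from tensorpack.trainv2.interface import apply_default_prefetch
    from tensorpack.trainv2.trainers import SimpleTrainer

    # A bare DataFlow is wrapped in FeedInput under SimpleTrainer
    # (QueueInput for any other trainer); with len(towers) > 1 it would
    # additionally be staged onto the GPUs.
    df = FakeData([[32, 32, 3]], size=100)
    input_source = apply_default_prefetch(df, SimpleTrainer(), towers=[0])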
@@ -26,7 +43,8 @@ def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
 def launch_train_with_config(config, trainer):
     """
-    To mimic the old training interface, with a trainer and a config.
+    Train with a :class:`TrainConfig` and a new version of :class:`Trainer`, to
+    mimic the old training interface.
 
     Args:
         config (TrainConfig):
...
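The reworded docstring frames launch_train_with_config as the bridge from the old TrainConfig-based interface onto the new Trainer classes. An end-to-end sketch under that reading; MyModel is a hypothetical ModelDesc subclass (the source of the build_graph_get_cost and get_optimizer methods used later in this file), and the epoch numbers are placeholders:

    from tensorpack.dataflow import FakeData
    from tensorpack.trainv2.interface import launch_train_with_config, TrainConfig
    from tensorpack.trainv2.trainers import SimpleTrainer

    config = TrainConfig(
        model=MyModel(),               # hypothetical ModelDesc subclass
        dataflow=FakeData([[32, 32, 3]], size=100),
        steps_per_epoch=100,
        max_epoch=10)
    launch_train_with_config(config, SimpleTrainer())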
@@ -49,18 +67,8 @@ def launch_train_with_config(config, trainer):
     model = config.model
     inputs_desc = model.get_inputs_desc()
-    input = config.data
-
-    # some check & input wrappers to mimic same behavior of the old trainer interface
-    if input is None:
-        if type(trainer) == SimpleTrainer:
-            input = FeedInput(config.dataflow)
-        else:
-            input = QueueInput(config.dataflow)
-    if config.nr_tower > 1:
-        assert not isinstance(trainer, SimpleTrainer)
-
-    input = _maybe_gpu_prefetch(input, config.tower, True)
+    input = config.data or config.dataflow
+    input = apply_default_prefetch(input, trainer, config.tower)
 
     if isinstance(trainer, DistributedTrainerReplicated) and \
             config.session_config is not None:
...
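The two replacement lines are meant to be behavior-preserving: the branching that previously lived inline is now exercised inside apply_default_prefetch. Traced by hand for a hypothetical two-GPU, dataflow-only configuration (values are illustrative; the staging step corresponds to the collapsed lines above):

    # Suppose config.data is None and config.tower == [0, 1]:
    input = config.data or config.dataflow            # -> config.dataflow
    input = apply_default_prefetch(input, trainer, config.tower)
    # -> QueueInput(config.dataflow), then presumably wrapped in a
    #    StagingInputWrapper because len(config.tower) > 1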
@@ -72,10 +80,6 @@ def launch_train_with_config(config, trainer):
         inputs_desc, input, model.build_graph_get_cost, model.get_optimizer)
     trainer.train(
-        config.callbacks,
-        config.monitors,
-        config.session_creator,
-        config.session_init,
-        config.steps_per_epoch,
-        config.starting_epoch,
-        config.max_epoch)
+        config.callbacks, config.monitors,
+        config.session_creator, config.session_init,
+        config.steps_per_epoch, config.starting_epoch, config.max_epoch)