Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e6493857
Commit
e6493857
authored
Oct 18, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
StagingInputWrapper takes a list of int
parent
4c5cdf9b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
15 deletions
+22
-15
examples/GAN/GAN.py
examples/GAN/GAN.py
+1
-1
tensorpack/__init__.py
tensorpack/__init__.py
+6
-1
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+14
-6
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-2
tensorpack/train/simple.py
tensorpack/train/simple.py
+0
-5
No files found.
examples/GAN/GAN.py
View file @
e6493857
...
...
@@ -136,7 +136,7 @@ class MultiGPUGANTrainer(Trainer):
raw_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
config
.
tower
]
# setup input
input
=
StagingInputWrapper
(
QueueInput
(
config
.
dataflow
),
raw_devices
)
input
=
StagingInputWrapper
(
QueueInput
(
config
.
dataflow
),
config
.
tower
)
model
=
config
.
model
cbs
=
input
.
setup
(
model
.
get_inputs_desc
())
config
.
callbacks
.
extend
(
cbs
)
...
...
tensorpack/__init__.py
View file @
e6493857
...
...
@@ -2,6 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
os
as
_os
from
tensorpack.libinfo
import
__version__
,
_HAS_TF
...
...
@@ -15,7 +16,11 @@ if _HAS_TF:
from
tensorpack.callbacks
import
*
from
tensorpack.tfutils
import
*
from
tensorpack.train
import
*
# In development. Default to v1
if
_os
.
environ
.
get
(
'TENSORPACK_TRAIN_API'
,
'v1'
)
==
'v2'
:
from
tensorpack.trainv2
import
*
else
:
from
tensorpack.train
import
*
from
tensorpack.graph_builder
import
*
from
tensorpack.input_source
import
*
from
tensorpack.predict
import
*
tensorpack/input_source/input_source.py
View file @
e6493857
...
...
@@ -19,6 +19,7 @@ from ..tfutils.common import get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.concurrency
import
ShareSessionThread
from
..utils.develop
import
log_deprecated
from
..callbacks.base
import
Callback
from
..callbacks.graph
import
RunOp
...
...
@@ -457,7 +458,8 @@ class TFDatasetInput(FeedfreeInput):
class
StagingInputWrapper
(
FeedfreeInput
):
"""
A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
A wrapper around a feedfree input,
to prefetch the input in StagingArea (on GPUs).
"""
class
StagingCallback
(
Callback
):
"""
...
...
@@ -478,16 +480,22 @@ class StagingInputWrapper(FeedfreeInput):
def
_before_run
(
self
,
ctx
):
return
self
.
fetches
def
__init__
(
self
,
input
,
device
s
,
nr_stage
=
5
):
def
__init__
(
self
,
input
,
tower
s
,
nr_stage
=
5
):
"""
Args:
input
: a :class:`FeedfreeInput`
devices: list of devices to be used for each training tower
nr_stage: number of elements to prefetch
input
(FeedfreeInput):
towers ([int]): list of GPU ids to prefetch on.
nr_stage: number of elements to prefetch
on each GPU.
"""
assert
isinstance
(
input
,
FeedfreeInput
),
input
self
.
_input
=
input
self
.
_devices
=
devices
if
not
isinstance
(
towers
[
0
],
int
):
# API changed
log_deprecated
(
"StagingInputWrapper(devices=)"
,
"Use (towers=) instead!"
,
"2018-01-31"
)
self
.
_devices
=
towers
else
:
self
.
_devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
towers
]
self
.
_nr_stage
=
nr_stage
self
.
_areas
=
[]
self
.
_stage_ops
=
[]
...
...
tensorpack/train/multigpu.py
View file @
e6493857
...
...
@@ -44,8 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
# seem to only improve on >1 GPUs
if
not
isinstance
(
config
.
data
,
(
StagingInputWrapper
,
DummyConstantInput
)):
devices
=
[
'/gpu:{}'
.
format
(
k
)
for
k
in
config
.
tower
]
config
.
data
=
StagingInputWrapper
(
config
.
data
,
devices
)
config
.
data
=
StagingInputWrapper
(
config
.
data
,
config
.
tower
)
class
SyncMultiGPUTrainerParameterServer
(
Trainer
):
...
...
tensorpack/train/simple.py
View file @
e6493857
...
...
@@ -62,9 +62,4 @@ def QueueInputTrainer(config, input_queue=None):
else
:
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
config
.
dataflow
=
None
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return
SimpleTrainer
(
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment