Shashank Suhas / seminar-breakout / Commits

Commit 3e9de2ae, authored May 05, 2017 by Yuxin Wu
Parent: 2b4f7b14

    a faster DummyInput

Showing 3 changed files, with 34 additions and 11 deletions:

    examples/PennTreebank/reader.py    +1   -1
    tensorpack/tfutils/tower.py        +6   -0
    tensorpack/train/input_data.py     +27  -10
examples/PennTreebank/reader.py
@@ -27,7 +27,7 @@ import tensorflow as tf

 def _read_words(filename):
-    with tf.gfile.GFile(filename, "r") as f:
+    with tf.gfile.GFile(filename, "rb") as f:
         return f.read().decode("utf-8").replace("\n", "<eos>").split()
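The only change here is the read mode: opening the file in binary ("rb") and decoding explicitly makes the result independent of whatever default text decoding tf.gfile.GFile would otherwise apply, and keeps Python 2 and 3 behaviour identical. A small sketch of the same read-bytes-then-decode pattern using the standard library (illustration only, not part of the commit):

    # Illustration only: read raw bytes, then decode once, explicitly.
    def read_words(filename):
        with open(filename, "rb") as f:              # binary mode always yields bytes
            text = f.read().decode("utf-8")          # explicit, platform-independent decode
        return text.replace("\n", "<eos>").split()   # newline -> end-of-sentence token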
tensorpack/tfutils/tower.py
@@ -42,6 +42,12 @@ class TowerContext(object):
     def name(self):
         return self._name

+    @property
+    def index(self):
+        if self._name == '':
+            return 0
+        return int(self._name[-1])
+
     def get_variable_on_tower(self, *args, **kwargs):
         """
         Get a variable for this tower specifically, without reusing, even if
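The new index property turns the tower's name into a numeric index: an empty name (the default tower) maps to 0, and otherwise the last character of the name is parsed as an integer, which implicitly assumes names ending in a single-digit index. A standalone restatement of the same logic, with hypothetical names, just to make the behaviour concrete:

    # Standalone re-statement of TowerContext.index, for illustration only.
    def tower_index(name):
        if name == '':           # the default / unnamed tower counts as index 0
            return 0
        return int(name[-1])     # only the last character: works for towers 0-9

    assert tower_index('') == 0
    assert tower_index('tower3') == 3   # hypothetical name ending in its index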
tensorpack/train/input_data.py
@@ -254,18 +254,35 @@ class DummyConstantInput(FeedfreeInput):
         """
         self.shapes = shapes
         logger.warn("Using dummy input for debug!")
+        self._cnt = 0

     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()

-    def get_input_tensors(self):
+    def setup_training(self, trainer):
+        super(DummyConstantInput, self).setup_training(trainer)
+        nr_tower = trainer.config.nr_tower
         placehdrs = self.input_placehdrs
         assert len(self.shapes) == len(placehdrs)
-        ret = []
-        for idx, p in enumerate(placehdrs):
-            ret.append(tf.get_variable(
-                'dummy-' + p.op.name, shape=self.shapes[idx],
-                dtype=p.dtype, trainable=False))
+        self.tensors = []
+        # don't share variables
+        for tower in range(nr_tower):
+            tlist = []
+            # TODO. keep device info in tower
+            with tf.device('/gpu:{}'.format(tower)):
+                for idx, p in enumerate(placehdrs):
+                    tlist.append(tf.get_variable(
+                        'dummy-{}-{}'.format(p.op.name, tower),
+                        shape=self.shapes[idx],
+                        dtype=p.dtype, trainable=False))
+            self.tensors.append(tlist)
+
+    def get_input_tensors(self):
+        # TODO XXX call with tower index
+        ret = self.tensors[self._cnt]
+        self._cnt += 1
         return ret
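This hunk is what the commit message refers to. Previously, DummyConstantInput built its dummy variables inside get_input_tensors, with a single shared set on the default device; now setup_training builds one set of non-trainable dummy variables per tower up front, pinned to that tower's GPU via tf.device, and get_input_tensors simply hands out the pre-built list for the next tower, cycling with self._cnt. Each tower therefore reads constants that already live on its own GPU rather than pulling data from one device, which is presumably what makes the dummy input faster for multi-GPU benchmarking. A standalone sketch of the pattern (TF 1.x style, hypothetical shapes; illustration only, not the library's API):

    import tensorflow as tf

    shapes = [[64, 224, 224, 3], [64]]      # hypothetical per-input shapes
    nr_tower = 2

    # Build every tower's dummy inputs once, each on its own GPU.
    per_tower = []
    for tower in range(nr_tower):
        with tf.device('/gpu:{}'.format(tower)):
            per_tower.append([
                tf.get_variable('dummy-{}-{}'.format(i, tower),
                                shape=s, dtype=tf.float32, trainable=False)
                for i, s in enumerate(shapes)])

    # Later, each call hands the next tower its pre-built tensors (mirrors self._cnt).
    cnt = [0]
    def get_input_tensors():
        ret = per_tower[cnt[0]]
        cnt[0] += 1
        return ret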
@@ -318,12 +335,10 @@ class ZMQInput(FeedfreeInput):
 class StagingInputWrapper(FeedfreeInput):
     class StagingCallback(Callback):
         def __init__(self, stage_op, unstage_op, nr_stage):
             self.nr_stage = nr_stage
             self.stage_op = stage_op
             # TODO make sure both stage/unstage are run, to avoid OOM
             self.fetches = tf.train.SessionRunArgs(
                 fetches=[stage_op, unstage_op])
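The nested StagingCallback bundles the stage and unstage ops into one tf.train.SessionRunArgs, so every training step both pushes a batch into the staging area and pops one out; the TODO comment explains why: if staging ever ran without unstaging, batches would accumulate on the GPU and eventually run out of memory. For intuition, a minimal TF 1.x StagingArea sketch (illustration only; tensorpack wires the real ops up through get_stage_op()/get_unstage_op() rather than literally like this):

    import tensorflow as tf
    from tensorflow.python.ops.data_flow_ops import StagingArea

    x = tf.random_normal([64, 32])                   # stand-in for the real input tensor
    area = StagingArea(dtypes=[tf.float32], shapes=[[64, 32]])
    stage_op = area.put([x])                         # producer: copy the next batch onto the device
    out = area.get()                                 # consumer: whatever was staged on an earlier step

    # Running both ops on every step keeps put/get balanced, so the area never grows:
    #   sess.run([stage_op, train_op_built_from(out)])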
@@ -335,13 +350,15 @@ class StagingInputWrapper(FeedfreeInput):
         def _before_run(self, ctx):
             return self.fetches

-    def __init__(self, input, devices):
+    def __init__(self, input, devices, nr_stage=5):
         self._input = input
         assert isinstance(input, FeedfreeInput)
         self._devices = devices
+        self._nr_stage = nr_stage
         self._areas = []
         self._stage_ops = []
         self._unstage_ops = []
         self._cnt_unstage = 0

     def setup(self, model):
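With this change nr_stage becomes a constructor argument (defaulting to 5) stored as self._nr_stage, rather than a value hard-coded at the call site; the next hunk switches the StagingCallback registration over to it. A hypothetical usage sketch (the wrapped input object and device list are placeholders, not taken from this commit):

    # Hypothetical usage: wrap an existing feed-free input for two GPUs and
    # keep up to 5 batches staged ahead per device.
    wrapped = StagingInputWrapper(
        some_feedfree_input,            # placeholder for e.g. a QueueInput instance
        devices=['/gpu:0', '/gpu:1'],   # placeholder device list
        nr_stage=5)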
@@ -354,7 +371,7 @@ class StagingInputWrapper(FeedfreeInput):
         trainer.register_callback(
             StagingInputWrapper.StagingCallback(
-                self.get_stage_op(), self.get_unstage_op(), 5))
+                self.get_stage_op(), self.get_unstage_op(), self._nr_stage))

     def setup_staging_areas(self):
         for idx, device in enumerate(self._devices):