Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9227aa8e
Commit
9227aa8e
authored
Feb 11, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix typo
parent
d46d3926
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
4 deletions
+16
-4
tensorpack/callbacks/prof.py
tensorpack/callbacks/prof.py
+1
-1
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+15
-3
No files found.
tensorpack/callbacks/prof.py
View file @
9227aa8e
...
...
@@ -68,7 +68,7 @@ class GPUUtilizationTracker(Callback):
self
.
_evt
.
set
()
stats
=
self
.
_queue
.
get
()
for
idx
,
dev
in
enumerate
(
self
.
_devices
):
self
.
trainer
.
monitors
.
put_scalar
(
'GPUUtil/{
:.2f
}'
.
format
(
dev
),
stats
[
idx
])
self
.
trainer
.
monitors
.
put_scalar
(
'GPUUtil/{}'
.
format
(
dev
),
stats
[
idx
])
def
_after_train
(
self
):
self
.
_stop_evt
.
set
()
...
...
tensorpack/input_source/input_source.py
View file @
9227aa8e
...
...
@@ -9,6 +9,7 @@ try:
except
ImportError
:
pass
from
contextlib
import
contextmanager
from
itertools
import
chain
from
six.moves
import
range
,
zip
import
threading
...
...
@@ -503,14 +504,15 @@ class StagingInput(FeedfreeInput):
self
.
_prefill
()
return
self
.
fetches
def
__init__
(
self
,
input
,
towers
=
None
,
nr_stage
=
1
):
def
__init__
(
self
,
input
,
towers
=
None
,
nr_stage
=
1
,
device
=
None
):
"""
Args:
input (FeedfreeInput):
nr_stage: number of elements to prefetch
on each GPU
.
nr_stage: number of elements to prefetch
into each StagingArea, at the beginning
.
Since enqueue and dequeue are synchronized, prefetching 1
element should be sufficient.
towers: deprecated
device (str or None): if not None, place the StagingArea on a specific device. e.g., '/cpu:0'.
"""
assert
isinstance
(
input
,
FeedfreeInput
),
input
self
.
_input
=
input
...
...
@@ -521,6 +523,7 @@ class StagingInput(FeedfreeInput):
self
.
_areas
=
[]
self
.
_stage_ops
=
[]
self
.
_unstage_ops
=
[]
self
.
_device
=
device
def
_setup
(
self
,
inputs
):
self
.
_input
.
setup
(
inputs
)
...
...
@@ -530,6 +533,7 @@ class StagingInput(FeedfreeInput):
def
_get_callbacks
(
self
):
cbs
=
self
.
_input
.
get_callbacks
()
# this callback has to happen after others, so StagingInput can be stacked together
cbs
.
append
(
StagingInput
.
StagingCallback
(
self
,
self
.
_nr_stage
))
return
cbs
...
...
@@ -537,8 +541,16 @@ class StagingInput(FeedfreeInput):
def
_size
(
self
):
return
self
.
_input
.
size
()
@
contextmanager
def
_device_ctx
(
self
):
if
not
self
.
_device
:
yield
else
:
with
tf
.
device
(
self
.
_device
):
yield
def
_get_input_tensors
(
self
):
with
self
.
cached_name_scope
():
with
self
.
cached_name_scope
()
,
self
.
_device_ctx
()
:
inputs
=
self
.
_input
.
get_input_tensors
()
# Putting variables to stagingarea will cause trouble
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment