Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0be066fe
Commit
0be066fe
authored
May 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
small fix in staging
parent
3f48ed30
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
12 deletions
+10
-12
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+5
-5
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+5
-7
No files found.
tensorpack/train/feedfree.py
View file @
0be066fe
...
@@ -24,8 +24,8 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -24,8 +24,8 @@ class FeedfreeTrainerBase(Trainer):
Get input tensors from `self.input_method` and build the graph.
Get input tensors from `self.input_method` and build the graph.
"""
"""
def
f
():
def
f
():
input
s
=
self
.
_input_method
.
get_input_tensors
()
self
.
_input_tensor
s
=
self
.
_input_method
.
get_input_tensors
()
self
.
model
.
build_graph
(
input
s
)
self
.
model
.
build_graph
(
self
.
_input_tensor
s
)
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
if
ctx
is
None
:
if
ctx
is
None
:
with
TowerContext
(
''
):
with
TowerContext
(
''
):
...
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
...
@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
cost
,
grads
=
self
.
_get_cost_and_grad
()
cost
,
grads
=
self
.
_get_cost_and_grad
()
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
# skip training
# skip training
# self.train_op = tf.group(*self.
dequed_input
s)
# self.train_op = tf.group(*self.
_input_tensor
s)
def
QueueInputTrainer
(
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
QueueInputTrainer
(
config
,
input_queue
=
None
,
predict_tower
=
None
):
...
@@ -117,9 +117,9 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
...
@@ -117,9 +117,9 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
else
:
else
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
# from tensorpack.train.input_data import
QueueInput, FeedfreeInput,
StagingInputWrapper, DummyConstantInput
# from tensorpack.train.input_data import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[
64,224,224,3], [64
]])
# config.data = DummyConstantInput([[
128,224,224,3], [128
]])
if
predict_tower
is
not
None
:
if
predict_tower
is
not
None
:
log_deprecated
(
"Argument `predict_tower` in trainer"
,
"Use TrainConfig(predict_tower=...) instead!"
)
log_deprecated
(
"Argument `predict_tower` in trainer"
,
"Use TrainConfig(predict_tower=...) instead!"
)
...
...
tensorpack/train/input_data.py
View file @
0be066fe
...
@@ -169,7 +169,6 @@ class QueueInput(FeedfreeInput):
...
@@ -169,7 +169,6 @@ class QueueInput(FeedfreeInput):
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
#ret[0]= tf.Print(ret[0], [tf.reduce_mean(ret[0])], "asdf")
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
...
@@ -326,7 +325,7 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -326,7 +325,7 @@ class StagingInputWrapper(FeedfreeInput):
self
.
stage_op
=
stage_op
self
.
stage_op
=
stage_op
# TODO make sure both stage/unstage are run, to avoid OOM
# TODO make sure both stage/unstage are run, to avoid OOM
self
.
fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
fetches
=
tf
.
train
.
SessionRunArgs
(
fetches
=
[
stage_op
])
fetches
=
[
stage_op
,
unstage_op
])
def
_before_train
(
self
):
def
_before_train
(
self
):
# pre-fill the staging area
# pre-fill the staging area
...
@@ -350,8 +349,8 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -350,8 +349,8 @@ class StagingInputWrapper(FeedfreeInput):
self
.
setup_staging_areas
()
self
.
setup_staging_areas
()
def
setup_training
(
self
,
trainer
):
def
setup_training
(
self
,
trainer
):
super
(
StagingInputWrapper
,
self
)
.
setup_training
(
trainer
)
self
.
_input
.
setup_training
(
trainer
)
self
.
_input
.
setup_training
(
trainer
)
self
.
setup_staging_areas
()
trainer
.
register_callback
(
trainer
.
register_callback
(
StagingInputWrapper
.
StagingCallback
(
StagingInputWrapper
.
StagingCallback
(
...
@@ -359,11 +358,10 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -359,11 +358,10 @@ class StagingInputWrapper(FeedfreeInput):
def
setup_staging_areas
(
self
):
def
setup_staging_areas
(
self
):
for
idx
,
device
in
enumerate
(
self
.
_devices
):
for
idx
,
device
in
enumerate
(
self
.
_devices
):
inputs
=
self
.
_input
.
get_input_tensors
()
dtypes
=
[
x
.
dtype
for
x
in
inputs
]
with
tf
.
device
(
device
):
with
tf
.
device
(
device
):
stage
=
StagingArea
(
inputs
=
self
.
_input
.
get_input_tensors
()
dtypes
,
shapes
=
None
)
dtypes
=
[
x
.
dtype
for
x
in
inputs
]
stage
=
StagingArea
(
dtypes
,
shapes
=
None
)
self
.
_stage_ops
.
append
(
stage
.
put
(
inputs
))
self
.
_stage_ops
.
append
(
stage
.
put
(
inputs
))
self
.
_areas
.
append
(
stage
)
self
.
_areas
.
append
(
stage
)
outputs
=
stage
.
get
()
outputs
=
stage
.
get
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment