Shashank Suhas / seminar-breakout · Commits

Commit 38d26977
authored May 29, 2017 by Yuxin Wu

backward-compat with tf1.0

parent 3898d354

Showing 2 changed files with 9 additions and 1 deletion:

    tensorpack/tfutils/common.py    +7  -1
    tensorpack/train/feedfree.py    +2  -0
tensorpack/tfutils/common.py

@@ -59,7 +59,13 @@ def get_global_step_var():
 ...
         "The global_step variable should be created under the root variable scope!"
     assert not scope.reuse, \
         "The global_step variable shouldn't be called under a reuse variable scope!"
-    var = training_util.get_or_create_global_step()
+    if get_tf_version_number() <= 1.0:
+        var = tf.get_variable('global_step',
+                              initializer=tf.constant(0, dtype=tf.int64),
+                              trainable=False, dtype=tf.int64)
+        tf.add_to_collection(tf.GraphKeys.GLOBAL_STEP, var)
+    else:
+        var = training_util.get_or_create_global_step()
     return var
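The new branch is gated on get_tf_version_number() <= 1.0, which only works because the helper collapses TensorFlow's version string into one comparable number. Below is a minimal standalone sketch of that idea; the function name and exact parsing are assumptions for illustration, not the library's code.

    # Sketch (assumed name/parsing): reduce 'MAJOR.MINOR.PATCH' to a float
    # so a version gate becomes a plain numeric comparison.
    def tf_version_number(version_string):
        major, minor = version_string.split('.')[:2]
        return float('{}.{}'.format(major, minor))

    assert tf_version_number('1.0.1') == 1.0  # would take the manual tf.get_variable branch
    assert tf_version_number('1.1.0') == 1.1  # would take the training_util branch

On TF 1.0 the fallback builds the step counter by hand as a non-trainable int64 and registers it in the tf.GraphKeys.GLOBAL_STEP collection, which is where later global-step lookups find it; on newer versions training_util.get_or_create_global_step() handles that bookkeeping itself.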
tensorpack/train/feedfree.py

@@ -64,12 +64,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
+        # opt may be created under first-tower variable scope (which is '')
         opt = self.model.get_optimizer()
         # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
+            # TODO assumption on the first-tower empty variable scope
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
         grads = opt.compute_gradients(
             cost,
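The variable filter in this hunk keys off variable-scope prefixes. A dependency-free toy version of the same list comprehension (all names invented here) shows why matching vs_name + '/' rather than the bare scope name matters:

    # Minimal stand-ins for tf variables; real code gets these from
    # tf.trainable_variables().
    class FakeOp(object):
        def __init__(self, name):
            self.name = name

    class FakeVar(object):
        def __init__(self, name):
            self.op = FakeOp(name)

    varlist = [FakeVar('tower0/conv1/W'),
               FakeVar('tower0_aux/fc/W'),
               FakeVar('tower1/conv1/W')]
    vs_name = 'tower0'

    # The trailing '/' keeps 'tower0/...' but excludes the sibling
    # scope 'tower0_aux/...'.
    tower_vars = [v for v in varlist if v.op.name.startswith(vs_name + '/')]
    assert [v.op.name for v in tower_vars] == ['tower0/conv1/W']

Note the guard also requires ctx.vs_name to be truthy: as the two new comments point out, the first tower lives in the empty root scope '', where prefix filtering would be meaningless, so in that case every trainable variable is kept.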