Shashank Suhas / seminar-breakout / Commits

Commit fd21c3b1
authored Apr 24, 2016 by Yuxin Wu

fix async training late-binding bug

parent b6a775f4

Showing 2 changed files with 5 additions and 3 deletions (+5 -3):

docs/conf.py (+2 -1)
tensorpack/train/trainer.py (+3 -2)
docs/conf.py

@@ -21,7 +21,8 @@ import os
 sys.path.insert(0, os.path.abspath('../'))
 import mock
-MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk', 'cv2']
+MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
+                'cv2', 'scipy.io']
 for mod_name in MOCK_MODULES:
     sys.modules[mod_name] = mock.Mock()
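The conf.py change adds 'scipy.io' to the list of modules mocked out during the Sphinx docs build, so autodoc can import the package without its heavy dependencies installed. A minimal standalone sketch of the same trick (the module names here are only illustrative):

```python
import sys
from unittest import mock  # the standalone `mock` package used in conf.py works the same way

# Any name registered in sys.modules is returned by `import` directly,
# without being looked up on disk, so a missing dependency never breaks
# the docs build.
MOCK_MODULES = ['scipy', 'scipy.io', 'cv2']
for mod_name in MOCK_MODULES:
    sys.modules[mod_name] = mock.Mock()

import cv2            # succeeds even if OpenCV is not installed
print(type(cv2))      # <class 'unittest.mock.Mock'>
```

Attribute access on a Mock returns further Mocks, so module-level references like `cv2.imread` in the documented code also resolve without error.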
tensorpack/train/trainer.py

@@ -6,6 +6,7 @@ import tensorflow as tf
 import threading
 import copy
 import re
+import functools
 from six.moves import zip
 from .base import Trainer

@@ -175,7 +176,7 @@ class QueueInputTrainer(Trainer):
         else:
             grad_list = [self.process_grads(g) for g in grad_list]
         # pretend to average the grads, in order to make async and
-        # sync have consistent semantics
+        # sync have consistent effective learning rate
         def scale(grads):
             return [(grad / self.config.nr_tower, var) for grad, var in grads]
         grad_list = map(scale, grad_list)

@@ -192,7 +193,7 @@ class QueueInputTrainer(Trainer):
         self.threads = []
         for k in range(1, self.config.nr_tower):
             train_op = self.config.optimizer.apply_gradients(grad_list[k])
-            f = lambda: self.sess.run([train_op])
+            f = lambda op=train_op: self.sess.run([op])  # avoid late-binding
             th = LoopThread(f)
             th.pause()
             th.start()
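The trainer fix addresses Python's late binding of closures: every lambda created in the loop captures the variable `train_op` itself, not its value at that iteration, so before the fix all training threads would have run the last tower's op. A minimal sketch of the bug and the default-argument fix, independent of tensorpack:

```python
# Late binding: each closure looks up `op` when it is *called*, by which
# point the loop has finished and `op` holds its final value.
late = [lambda: op for op in range(3)]
print([f() for f in late])       # [2, 2, 2]

# Fix: a default argument is evaluated when the lambda is *created*,
# freezing the current value of `op` into each closure.
bound = [lambda op=op: op for op in range(3)]
print([f() for f in bound])      # [0, 1, 2]
```

This is why the commit rewrites `lambda: self.sess.run([train_op])` as `lambda op=train_op: self.sess.run([op])`: each `LoopThread` now runs the op built for its own tower.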