Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
adb45736
Commit
adb45736
authored
Jul 25, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix trainconfig
parent
acd7f798
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
6 deletions
+10
-6
tensorpack/train/config.py
tensorpack/train/config.py
+8
-4
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+2
-2
No files found.
tensorpack/train/config.py
View file @
adb45736
...
...
@@ -55,6 +55,13 @@ class TrainConfig(object):
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
99999
))
assert
self
.
step_per_epoch
>
0
and
self
.
max_epoch
>
0
if
'nr_tower'
in
kwargs
or
'tower'
in
kwargs
:
self
.
set_tower
(
**
kwargs
)
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
set_tower
(
self
,
**
kwargs
):
nr_tower
=
kwargs
.
pop
(
'nr_tower'
,
None
)
tower
=
kwargs
.
pop
(
'tower'
,
None
)
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
...
...
@@ -64,7 +71,4 @@ class TrainConfig(object):
if
isinstance
(
tower
,
int
):
tower
=
list
(
range
(
tower
))
self
.
tower
=
tower
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
isinstance
(
self
.
tower
,
list
)
tensorpack/train/multigpu.py
View file @
adb45736
...
...
@@ -93,7 +93,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
# sync have consistent effective learning rate
def
scale
(
grads
):
with
tf
.
name_scope
(
'async_scale_grad'
):
return
[(
grad
/
self
.
config
.
nr_tower
if
grad
is
not
None
else
None
,
var
)
return
[(
grad
/
len
(
self
.
config
.
tower
)
if
grad
is
not
None
else
None
,
var
)
for
grad
,
var
in
grads
]
grad_list
=
map
(
scale
,
grad_list
)
grad_list
=
[
self
.
process_grads
(
g
)
for
g
in
grad_list
]
...
...
@@ -113,7 +113,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
# itertools.count is atomic w.r.t. python threads
self
.
async_step_counter
=
itertools
.
count
()
self
.
training_threads
=
[]
for
k
in
range
(
1
,
self
.
config
.
nr_tower
):
for
k
in
range
(
1
,
len
(
self
.
config
.
tower
)
):
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
k
])
def
f
(
op
=
train_op
):
# avoid late-binding
self
.
sess
.
run
([
op
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment