Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8abdaf77
Commit
8abdaf77
authored
Dec 08, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bug fix
parent
e592271b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
2 deletions
+6
-2
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+3
-1
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+2
-0
No files found.
examples/OpenAIGym/train-atari.py
View file @
8abdaf77
...
@@ -246,14 +246,16 @@ if __name__ == '__main__':
...
@@ -246,14 +246,16 @@ if __name__ == '__main__':
train_tower
=
range
(
nr_gpu
)[:
-
nr_gpu
//
2
]
or
[
0
]
train_tower
=
range
(
nr_gpu
)[:
-
nr_gpu
//
2
]
or
[
0
]
logger
.
info
(
"[BA3C] Train on gpu {} and infer on gpu {}"
.
format
(
logger
.
info
(
"[BA3C] Train on gpu {} and infer on gpu {}"
.
format
(
','
.
join
(
map
(
str
,
train_tower
)),
','
.
join
(
map
(
str
,
predict_tower
))))
','
.
join
(
map
(
str
,
train_tower
)),
','
.
join
(
map
(
str
,
predict_tower
))))
trainer
=
AsyncMultiGPUTrainer
else
:
else
:
logger
.
warn
(
"Without GPU this model will never learn! CPU is only useful for debug."
)
logger
.
warn
(
"Without GPU this model will never learn! CPU is only useful for debug."
)
nr_gpu
=
0
nr_gpu
=
0
PREDICTOR_THREAD
=
1
PREDICTOR_THREAD
=
1
predict_tower
=
[
0
]
predict_tower
=
[
0
]
train_tower
=
[
0
]
train_tower
=
[
0
]
trainer
=
QueueInputTrainer
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
tower
=
train_tower
config
.
tower
=
train_tower
AsyncMultiGPUT
rainer
(
config
,
predict_tower
=
predict_tower
)
.
train
()
t
rainer
(
config
,
predict_tower
=
predict_tower
)
.
train
()
tensorpack/train/base.py
View file @
8abdaf77
...
@@ -93,7 +93,7 @@ class Trainer(object):
...
@@ -93,7 +93,7 @@ class Trainer(object):
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
suffix
=
'-summary'
# issue#6150
suffix
=
'-summary'
# issue#6150
if
val
.
tag
.
endswith
(
suffix
):
if
val
.
tag
.
endswith
(
suffix
):
val
.
tag
=
va
.
tag
[:
-
len
(
suffix
)]
val
.
tag
=
va
l
.
tag
[:
-
len
(
suffix
)]
self
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
self
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
())
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
())
...
...
tensorpack/train/multigpu.py
View file @
8abdaf77
...
@@ -56,6 +56,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -56,6 +56,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
tf
.
test
.
is_gpu_available
()
@
staticmethod
@
staticmethod
...
@@ -110,6 +111,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -110,6 +111,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_average_gradient
=
average_gradient
self
.
_average_gradient
=
average_gradient
assert
tf
.
test
.
is_gpu_available
()
def
_setup
(
self
):
def
_setup
(
self
):
super
(
AsyncMultiGPUTrainer
,
self
)
.
_setup
()
super
(
AsyncMultiGPUTrainer
,
self
)
.
_setup
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment