Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
49675590
Commit
49675590
authored
Feb 11, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix use of nr_tower (fix #1077)
parent
1097672b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6 additions
and
7 deletions
+6
-7
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+2
-1
examples/boilerplate.py
examples/boilerplate.py
+0
-2
tensorpack/train/base.py
tensorpack/train/base.py
+3
-3
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+1
-1
No files found.
examples/A3C-Gym/train-atari.py
View file @
49675590
...
...
@@ -206,6 +206,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
train
():
assert
tf
.
test
.
is_gpu_available
(),
"Training requires GPUs!"
dirname
=
os
.
path
.
join
(
'train_log'
,
'train-atari-{}'
.
format
(
ENV_NAME
))
logger
.
set_logger_dir
(
dirname
)
...
...
@@ -259,7 +260,7 @@ def train():
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
max_epoch
=
1000
,
)
trainer
=
SimpleTrainer
()
if
config
.
nr_tower
==
1
else
AsyncMultiGPUTrainer
(
train_tower
)
trainer
=
SimpleTrainer
()
if
num_gpu
==
1
else
AsyncMultiGPUTrainer
(
train_tower
)
launch_train_with_config
(
config
,
trainer
)
...
...
examples/boilerplate.py
View file @
49675590
...
...
@@ -71,8 +71,6 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
gpu
:
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
...
...
tensorpack/train/base.py
View file @
49675590
...
...
@@ -104,12 +104,12 @@ class Trainer(object):
The ``tf.Session`` object the trainer is using.
Available after :meth:`initialize()`.
Using ``trainer.sess.run`` to evaluate tensors that depend on the
inputs
can lead to
unexpected effect:
Using ``trainer.sess.run`` to evaluate tensors that depend on the
training
``InputSource`` may have
unexpected effect:
For example, if you use ``trainer.sess.run`` to evaluate a tensor that depends on the
inputs coming from a ``StagingArea``,
this
will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
it
will take a datapoint from the ``StagingArea``, making the ``StagingArea`` empty, and as a result
make the training hang.
"""
...
...
tensorpack/train/trainers.py
View file @
49675590
...
...
@@ -137,7 +137,7 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
if
len
(
self
.
devices
)
>
1
:
assert
isinstance
(
input
,
FeedfreeInput
),
input
tower_fn
=
self
.
_make_get_grad_fn
(
input
,
get_cost_fn
,
get_opt_fn
)
,
tower_fn
=
self
.
_make_get_grad_fn
(
input
,
get_cost_fn
,
get_opt_fn
)
grad_list
=
self
.
_builder
.
call_for_each_tower
(
tower_fn
)
self
.
train_op
=
self
.
_builder
.
build
(
grad_list
,
get_opt_fn
)
return
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment