Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c280473d
Commit
c280473d
authored
May 03, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InferenceRunner select tower from TrainConfig (#249)
parent
9f056711
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
7 deletions
+18
-7
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+14
-5
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+4
-2
No files found.
examples/A3C-Gym/train-atari.py
View file @
c280473d
...
@@ -6,18 +6,23 @@
...
@@ -6,18 +6,23 @@
import
numpy
as
np
import
numpy
as
np
import
os
import
os
import
sys
import
sys
import
re
import
time
import
time
import
random
import
random
import
uuid
import
uuid
import
argparse
import
argparse
import
multiprocessing
import
multiprocessing
import
threading
import
threading
import
cv2
import
cv2
from
collections
import
deque
import
tensorflow
as
tf
import
six
import
six
from
six.moves
import
queue
from
six.moves
import
queue
import
tensorflow
as
tf
if
six
.
PY3
:
from
concurrent
import
futures
# py3
CancelledError
=
futures
.
CancelledError
else
:
CancelledError
=
Exception
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.utils.concurrency
import
*
from
tensorpack.utils.concurrency
import
*
...
@@ -42,7 +47,7 @@ STEPS_PER_EPOCH = 6000
...
@@ -42,7 +47,7 @@ STEPS_PER_EPOCH = 6000
EVAL_EPISODE
=
50
EVAL_EPISODE
=
50
BATCH_SIZE
=
128
BATCH_SIZE
=
128
SIMULATOR_PROC
=
50
SIMULATOR_PROC
=
50
PREDICTOR_THREAD_PER_GPU
=
2
PREDICTOR_THREAD_PER_GPU
=
3
PREDICTOR_THREAD
=
None
PREDICTOR_THREAD
=
None
EVALUATE_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
20
)
EVALUATE_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
20
)
...
@@ -156,7 +161,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
...
@@ -156,7 +161,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
_on_state
(
self
,
state
,
ident
):
def
_on_state
(
self
,
state
,
ident
):
def
cb
(
outputs
):
def
cb
(
outputs
):
distrib
,
value
=
outputs
.
result
()
try
:
distrib
,
value
=
outputs
.
result
()
except
CancelledError
:
logger
.
info
(
"Client {} cancelled."
.
format
(
ident
))
return
assert
np
.
all
(
np
.
isfinite
(
distrib
)),
distrib
assert
np
.
all
(
np
.
isfinite
(
distrib
)),
distrib
action
=
np
.
random
.
choice
(
len
(
distrib
),
p
=
distrib
)
action
=
np
.
random
.
choice
(
len
(
distrib
),
p
=
distrib
)
client
=
self
.
clients
[
ident
]
client
=
self
.
clients
[
ident
]
...
...
tensorpack/callbacks/inference_runner.py
View file @
c280473d
...
@@ -104,13 +104,15 @@ class InferenceRunnerBase(Callback):
...
@@ -104,13 +104,15 @@ class InferenceRunnerBase(Callback):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_input_data
.
setup
(
self
.
trainer
.
model
)
self
.
_input_data
.
setup
(
self
.
trainer
.
model
)
self
.
_setup_input_names
()
self
.
_setup_input_names
()
# Use predict_tower in train config. either gpuid or -1
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
in_tensors
=
self
.
_find_input_tensors
()
in_tensors
=
self
.
_find_input_tensors
()
assert
isinstance
(
in_tensors
,
list
),
in_tensors
assert
isinstance
(
in_tensors
,
list
),
in_tensors
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
def
fn
(
_
):
def
fn
(
_
):
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
0
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
self
.
_feed_tensors
=
self
.
_find_feed_tensors
()
self
.
_feed_tensors
=
self
.
_find_feed_tensors
()
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
...
@@ -122,7 +124,7 @@ class InferenceRunnerBase(Callback):
...
@@ -122,7 +124,7 @@ class InferenceRunnerBase(Callback):
def
_get_tensors_maybe_in_tower
(
self
,
names
):
def
_get_tensors_maybe_in_tower
(
self
,
names
):
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()])
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()])
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
return
get_tensor_fn
(
placeholder_names
,
names
,
0
,
prefix
=
self
.
_prefix
)
return
get_tensor_fn
(
placeholder_names
,
names
,
self
.
_predict_tower_id
,
prefix
=
self
.
_prefix
)
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment