Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
27d73303
Commit
27d73303
authored
Mar 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix 'logits' naming in A3C (fix #197)
parent
7e2be137
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
15 deletions
+15
-15
examples/A3C-Gym/run-atari.py
examples/A3C-Gym/run-atari.py
+2
-2
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+10
-10
examples/GAN/GAN.py
examples/GAN/GAN.py
+3
-3
No files found.
examples/A3C-Gym/run-atari.py
View file @
27d73303
...
...
@@ -64,7 +64,7 @@ class Model(ModelDesc):
def
_build_graph
(
self
,
inputs
):
state
,
action
,
futurereward
=
inputs
policy
=
self
.
_get_NN_prediction
(
state
)
self
.
logits
=
tf
.
nn
.
softmax
(
policy
,
name
=
'logits
'
)
policy
=
tf
.
nn
.
softmax
(
policy
,
name
=
'policy
'
)
def
run_submission
(
cfg
,
output
,
nr
):
...
...
@@ -105,5 +105,5 @@ if __name__ == '__main__':
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
input_names
=
[
'state'
],
output_names
=
[
'
logits
'
])
output_names
=
[
'
policy
'
])
run_submission
(
cfg
,
args
.
output
,
args
.
episode
)
examples/A3C-Gym/train-atari.py
View file @
27d73303
...
...
@@ -96,30 +96,30 @@ class Model(ModelDesc):
l
=
FullyConnected
(
'fc0'
,
l
,
512
,
nl
=
tf
.
identity
)
l
=
PReLU
(
'prelu'
,
l
)
policy
=
FullyConnected
(
'fc-pi'
,
l
,
out_dim
=
NUM_ACTIONS
,
nl
=
tf
.
identity
)
logits
=
FullyConnected
(
'fc-pi'
,
l
,
out_dim
=
NUM_ACTIONS
,
nl
=
tf
.
identity
)
# unnormalized policy
value
=
FullyConnected
(
'fc-v'
,
l
,
1
,
nl
=
tf
.
identity
)
return
policy
,
value
return
logits
,
value
def
_build_graph
(
self
,
inputs
):
state
,
action
,
futurereward
=
inputs
policy
,
self
.
value
=
self
.
_get_NN_prediction
(
state
)
logits
,
self
.
value
=
self
.
_get_NN_prediction
(
state
)
self
.
value
=
tf
.
squeeze
(
self
.
value
,
[
1
],
name
=
'pred_value'
)
# (B,)
self
.
logits
=
tf
.
nn
.
softmax
(
policy
,
name
=
'logits
'
)
self
.
policy
=
tf
.
nn
.
softmax
(
logits
,
name
=
'policy
'
)
expf
=
tf
.
get_variable
(
'explore_factor'
,
shape
=
[],
initializer
=
tf
.
constant_initializer
(
1
),
trainable
=
False
)
logitsT
=
tf
.
nn
.
softmax
(
policy
*
expf
,
name
=
'logitsT
'
)
policy_explore
=
tf
.
nn
.
softmax
(
logits
*
expf
,
name
=
'policy_explore
'
)
is_training
=
get_current_tower_context
()
.
is_training
if
not
is_training
:
return
log_probs
=
tf
.
log
(
self
.
logits
+
1e-6
)
log_probs
=
tf
.
log
(
self
.
policy
+
1e-6
)
log_pi_a_given_s
=
tf
.
reduce_sum
(
log_probs
*
tf
.
one_hot
(
action
,
NUM_ACTIONS
),
1
)
advantage
=
tf
.
subtract
(
tf
.
stop_gradient
(
self
.
value
),
futurereward
,
name
=
'advantage'
)
policy_loss
=
tf
.
reduce_sum
(
log_pi_a_given_s
*
advantage
,
name
=
'policy_loss'
)
xentropy_loss
=
tf
.
reduce_sum
(
self
.
logits
*
log_probs
,
name
=
'xentropy_loss'
)
self
.
policy
*
log_probs
,
name
=
'xentropy_loss'
)
value_loss
=
tf
.
nn
.
l2_loss
(
self
.
value
-
futurereward
,
name
=
'value_loss'
)
pred_reward
=
tf
.
reduce_mean
(
self
.
value
,
name
=
'predict_reward'
)
...
...
@@ -151,7 +151,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
_setup_graph
(
self
):
self
.
async_predictor
=
MultiThreadAsyncPredictor
(
self
.
trainer
.
get_predictors
([
'state'
],
[
'
logitsT
'
,
'pred_value'
],
self
.
trainer
.
get_predictors
([
'state'
],
[
'
policy_explore
'
,
'pred_value'
],
PREDICTOR_THREAD
),
batch_size
=
15
)
def
_before_train
(
self
):
...
...
@@ -220,7 +220,7 @@ def get_config():
[(
80
,
2
),
(
100
,
3
),
(
120
,
4
),
(
140
,
5
)]),
master
,
StartProcOrThread
(
master
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'
logits
'
]),
2
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'
policy
'
]),
2
),
],
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
...
...
@@ -254,7 +254,7 @@ if __name__ == '__main__':
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
input_names
=
[
'state'
],
output_names
=
[
'
logits
'
])
output_names
=
[
'
policy
'
])
if
args
.
task
==
'play'
:
play_model
(
cfg
)
elif
args
.
task
==
'eval'
:
...
...
examples/GAN/GAN.py
View file @
27d73303
...
...
@@ -43,15 +43,15 @@ class GANModelDesc(ModelDesc):
d_pos_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_real
>
0.5
,
tf
.
float32
),
name
=
'accuracy_real'
)
d_neg_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
<
0.5
,
tf
.
float32
),
name
=
'accuracy_fake'
)
self
.
d_accuracy
=
tf
.
add
(
.5
*
d_pos_acc
,
.5
*
d_neg_acc
,
name
=
'accuracy'
)
d_accuracy
=
tf
.
add
(
.5
*
d_pos_acc
,
.5
*
d_neg_acc
,
name
=
'accuracy'
)
self
.
d_loss
=
tf
.
add
(
.5
*
d_loss_pos
,
.5
*
d_loss_neg
,
name
=
'loss'
)
with
tf
.
name_scope
(
"gen"
):
self
.
g_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
logits
=
logits_fake
,
labels
=
tf
.
ones_like
(
logits_fake
)),
name
=
'loss'
)
self
.
g_accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
>
0.5
,
tf
.
float32
),
name
=
'accuracy'
)
g_accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
score_fake
>
0.5
,
tf
.
float32
),
name
=
'accuracy'
)
add_moving_summary
(
self
.
g_loss
,
self
.
d_loss
,
self
.
d_accuracy
,
self
.
g_accuracy
)
add_moving_summary
(
self
.
g_loss
,
self
.
d_loss
,
d_accuracy
,
g_accuracy
)
class
GANTrainer
(
FeedfreeTrainerBase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment