Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2cd41e99
Commit
2cd41e99
authored
Jul 25, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix towers
parent
adb45736
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
1 deletion
+11
-1
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+0
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+1
-0
tensorpack/train/config.py
tensorpack/train/config.py
+10
-0
No files found.
examples/Atari2600/DQN.py
View file @
2cd41e99
...
@@ -111,7 +111,6 @@ class Model(ModelDesc):
...
@@ -111,7 +111,6 @@ class Model(ModelDesc):
predict_onehot
=
tf
.
one_hot
(
self
.
greedy_choice
,
NUM_ACTIONS
,
1.0
,
0.0
)
predict_onehot
=
tf
.
one_hot
(
self
.
greedy_choice
,
NUM_ACTIONS
,
1.0
,
0.0
)
best_v
=
tf
.
reduce_sum
(
targetQ_predict_value
*
predict_onehot
,
1
)
best_v
=
tf
.
reduce_sum
(
targetQ_predict_value
*
predict_onehot
,
1
)
target
=
reward
+
(
1.0
-
tf
.
cast
(
isOver
,
tf
.
float32
))
*
GAMMA
*
tf
.
stop_gradient
(
best_v
)
target
=
reward
+
(
1.0
-
tf
.
cast
(
isOver
,
tf
.
float32
))
*
GAMMA
*
tf
.
stop_gradient
(
best_v
)
sqrcost
=
tf
.
square
(
target
-
pred_action_value
)
sqrcost
=
tf
.
square
(
target
-
pred_action_value
)
...
...
tensorpack/tfutils/sessinit.py
View file @
2cd41e99
...
@@ -54,6 +54,7 @@ class SaverRestore(SessionInit):
...
@@ -54,6 +54,7 @@ class SaverRestore(SessionInit):
def
__init__
(
self
,
model_path
,
prefix
=
None
):
def
__init__
(
self
,
model_path
,
prefix
=
None
):
"""
"""
:param model_path: a model file or a ``checkpoint`` file.
:param model_path: a model file or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
"""
"""
assert
os
.
path
.
isfile
(
model_path
)
assert
os
.
path
.
isfile
(
model_path
)
if
os
.
path
.
basename
(
model_path
)
==
'checkpoint'
:
if
os
.
path
.
basename
(
model_path
)
==
'checkpoint'
:
...
...
tensorpack/train/config.py
View file @
2cd41e99
...
@@ -57,6 +57,8 @@ class TrainConfig(object):
...
@@ -57,6 +57,8 @@ class TrainConfig(object):
if
'nr_tower'
in
kwargs
or
'tower'
in
kwargs
:
if
'nr_tower'
in
kwargs
or
'tower'
in
kwargs
:
self
.
set_tower
(
**
kwargs
)
self
.
set_tower
(
**
kwargs
)
else
:
self
.
tower
=
[
0
]
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
...
@@ -72,3 +74,11 @@ class TrainConfig(object):
...
@@ -72,3 +74,11 @@ class TrainConfig(object):
tower
=
list
(
range
(
tower
))
tower
=
list
(
range
(
tower
))
self
.
tower
=
tower
self
.
tower
=
tower
assert
isinstance
(
self
.
tower
,
list
)
assert
isinstance
(
self
.
tower
,
list
)
@
property
def
nr_tower
(
self
):
return
len
(
self
.
tower
)
@
nr_tower
.
setter
def
nr_tower
(
self
,
value
):
self
.
tower
=
list
(
range
(
value
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment