Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
21a6984c
Commit
21a6984c
authored
Apr 16, 2020
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[a3c] specify dir to save train logs
parent
963e5100
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+5
-5
No files found.
examples/A3C-Gym/train-atari.py
View file @
21a6984c
...
...
@@ -215,10 +215,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
train
():
assert
tf
.
test
.
is_gpu_available
(),
"Training requires GPUs!"
dirname
=
os
.
path
.
join
(
'train_log'
,
'train-atari-{}'
.
format
(
ENV_NAME
))
logger
.
set_logger_dir
(
dirname
)
# assign GPUs for training & inference
num_gpu
=
get_num_gpu
()
global
PREDICTOR_THREAD
...
...
@@ -275,9 +271,11 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--env'
,
help
=
'env'
,
required
=
True
)
parser
.
add_argument
(
'--task'
,
help
=
'task to perform'
,
choices
=
[
'play'
,
'eval'
,
'train'
,
'dump_video'
],
default
=
'train'
)
parser
.
add_argument
(
'--output'
,
help
=
'output directory for
submission'
,
default
=
'output_dir
'
)
parser
.
add_argument
(
'--output'
,
help
=
'output directory for
logs and videos
'
)
parser
.
add_argument
(
'--episode'
,
help
=
'number of episode to eval'
,
default
=
100
,
type
=
int
)
args
=
parser
.
parse_args
()
if
args
.
output
is
None
:
args
.
output
=
os
.
path
.
join
(
'train_log'
,
'train-atari-{}'
.
format
(
args
.
env
))
ENV_NAME
=
args
.
env
NUM_ACTIONS
=
get_player
()
.
action_space
.
n
...
...
@@ -303,4 +301,6 @@ if __name__ == '__main__':
get_player
(
train
=
False
,
dumpdir
=
args
.
output
),
pred
,
args
.
episode
)
else
:
assert
tf
.
test
.
is_gpu_available
(),
"Training requires GPUs!"
logger
.
set_logger_dir
(
args
.
output
)
train
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment