Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b7ee409b
Commit
b7ee409b
authored
Jun 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
small changes in __main__
parent
bc551406
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
25 additions
and
32 deletions
+25
-32
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+10
-14
examples/DeepQNetwork/DQN.py
examples/DeepQNetwork/DQN.py
+5
-8
examples/cifar-convnet.py
examples/cifar-convnet.py
+2
-5
examples/mnist-convnet.py
examples/mnist-convnet.py
+4
-4
examples/svhn-digit-convnet.py
examples/svhn-digit-convnet.py
+1
-1
tensorpack/RL/history.py
tensorpack/RL/history.py
+3
-0
No files found.
examples/A3C-Gym/train-atari.py
View file @
b7ee409b
...
@@ -60,10 +60,6 @@ ENV_NAME = None
...
@@ -60,10 +60,6 @@ ENV_NAME = None
def
get_player
(
viz
=
False
,
train
=
False
,
dumpdir
=
None
):
def
get_player
(
viz
=
False
,
train
=
False
,
dumpdir
=
None
):
pl
=
GymEnv
(
ENV_NAME
,
viz
=
viz
,
dumpdir
=
dumpdir
)
pl
=
GymEnv
(
ENV_NAME
,
viz
=
viz
,
dumpdir
=
dumpdir
)
pl
=
MapPlayerState
(
pl
,
lambda
img
:
cv2
.
resize
(
img
,
IMAGE_SIZE
[::
-
1
]))
pl
=
MapPlayerState
(
pl
,
lambda
img
:
cv2
.
resize
(
img
,
IMAGE_SIZE
[::
-
1
]))
global
NUM_ACTIONS
NUM_ACTIONS
=
pl
.
get_action_space
()
.
num_actions
()
pl
=
HistoryFramePlayer
(
pl
,
FRAME_HISTORY
)
pl
=
HistoryFramePlayer
(
pl
,
FRAME_HISTORY
)
if
not
train
:
if
not
train
:
pl
=
PreventStuckPlayer
(
pl
,
30
,
1
)
pl
=
PreventStuckPlayer
(
pl
,
30
,
1
)
...
@@ -201,8 +197,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
...
@@ -201,8 +197,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
get_config
():
def
get_config
():
dirname
=
os
.
path
.
join
(
'train_log'
,
'train-atari-{}'
.
format
(
ENV_NAME
))
logger
.
set_logger_dir
(
dirname
)
M
=
Model
()
M
=
Model
()
name_base
=
str
(
uuid
.
uuid1
())[:
6
]
name_base
=
str
(
uuid
.
uuid1
())[:
6
]
...
@@ -251,17 +245,15 @@ if __name__ == '__main__':
...
@@ -251,17 +245,15 @@ if __name__ == '__main__':
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
ENV_NAME
=
args
.
env
ENV_NAME
=
args
.
env
assert
ENV_NAME
logger
.
info
(
"Environment Name: {}"
.
format
(
ENV_NAME
))
logger
.
info
(
"Environment Name: {}"
.
format
(
ENV_NAME
))
p
=
get_player
()
NUM_ACTIONS
=
get_player
()
.
get_action_space
()
.
num_actions
()
del
p
# set NUM_ACTIONS
logger
.
info
(
"Number of actions: {}"
.
format
(
NUM_ACTIONS
))
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
if
args
.
task
!=
'train'
:
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
cfg
=
PredictConfig
(
cfg
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
args
.
load
),
session_init
=
get_model_loader
(
args
.
load
),
...
@@ -277,7 +269,11 @@ if __name__ == '__main__':
...
@@ -277,7 +269,11 @@ if __name__ == '__main__':
OfflinePredictor
(
cfg
),
args
.
episode
)
OfflinePredictor
(
cfg
),
args
.
episode
)
# gym.upload(output, api_key='xxx')
# gym.upload(output, api_key='xxx')
else
:
else
:
dirname
=
os
.
path
.
join
(
'train_log'
,
'train-atari-{}'
.
format
(
ENV_NAME
))
logger
.
set_logger_dir
(
dirname
)
nr_gpu
=
get_nr_gpu
()
nr_gpu
=
get_nr_gpu
()
trainer
=
QueueInputTrainer
if
nr_gpu
>
0
:
if
nr_gpu
>
0
:
if
nr_gpu
>
1
:
if
nr_gpu
>
1
:
predict_tower
=
list
(
range
(
nr_gpu
))[
-
nr_gpu
//
2
:]
predict_tower
=
list
(
range
(
nr_gpu
))[
-
nr_gpu
//
2
:]
...
@@ -285,12 +281,12 @@ if __name__ == '__main__':
...
@@ -285,12 +281,12 @@ if __name__ == '__main__':
predict_tower
=
[
0
]
predict_tower
=
[
0
]
PREDICTOR_THREAD
=
len
(
predict_tower
)
*
PREDICTOR_THREAD_PER_GPU
PREDICTOR_THREAD
=
len
(
predict_tower
)
*
PREDICTOR_THREAD_PER_GPU
train_tower
=
list
(
range
(
nr_gpu
))[:
-
nr_gpu
//
2
]
or
[
0
]
train_tower
=
list
(
range
(
nr_gpu
))[:
-
nr_gpu
//
2
]
or
[
0
]
logger
.
info
(
"[BA3C] Train on gpu {} and infer on gpu {}"
.
format
(
logger
.
info
(
"[B
atch-
A3C] Train on gpu {} and infer on gpu {}"
.
format
(
','
.
join
(
map
(
str
,
train_tower
)),
','
.
join
(
map
(
str
,
predict_tower
))))
','
.
join
(
map
(
str
,
train_tower
)),
','
.
join
(
map
(
str
,
predict_tower
))))
trainer
=
AsyncMultiGPUTrainer
if
len
(
train_tower
)
>
1
:
trainer
=
AsyncMultiGPUTrainer
else
:
else
:
logger
.
warn
(
"Without GPU this model will never learn! CPU is only useful for debug."
)
logger
.
warn
(
"Without GPU this model will never learn! CPU is only useful for debug."
)
nr_gpu
=
0
PREDICTOR_THREAD
=
1
PREDICTOR_THREAD
=
1
predict_tower
,
train_tower
=
[
0
],
[
0
]
predict_tower
,
train_tower
=
[
0
],
[
0
]
trainer
=
QueueInputTrainer
trainer
=
QueueInputTrainer
...
...
examples/DeepQNetwork/DQN.py
View file @
b7ee409b
...
@@ -149,17 +149,14 @@ if __name__ == '__main__':
...
@@ -149,17 +149,14 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
ROM_FILE
=
args
.
rom
ROM_FILE
=
args
.
rom
METHOD
=
args
.
algo
METHOD
=
args
.
algo
# set num_actions
# set num_actions
pl
=
AtariPlayer
(
ROM_FILE
,
viz
=
False
)
NUM_ACTIONS
=
AtariPlayer
(
ROM_FILE
)
.
get_action_space
()
.
num_actions
()
NUM_ACTIONS
=
pl
.
get_action_space
()
.
num_actions
()
logger
.
info
(
"ROM: {}, Num Actions: {}"
.
format
(
ROM_FILE
,
NUM_ACTIONS
))
del
pl
if
args
.
task
!=
'train'
:
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
cfg
=
PredictConfig
(
cfg
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
args
.
load
),
session_init
=
get_model_loader
(
args
.
load
),
...
@@ -171,8 +168,8 @@ if __name__ == '__main__':
...
@@ -171,8 +168,8 @@ if __name__ == '__main__':
eval_model_multithread
(
cfg
,
EVAL_EPISODE
,
get_player
)
eval_model_multithread
(
cfg
,
EVAL_EPISODE
,
get_player
)
else
:
else
:
logger
.
set_logger_dir
(
logger
.
set_logger_dir
(
'train_log/
DQN-{}'
.
format
(
os
.
path
.
join
(
'train_log'
,
'
DQN-{}'
.
format
(
os
.
path
.
basename
(
ROM_FILE
)
.
split
(
'.'
)[
0
]))
os
.
path
.
basename
(
ROM_FILE
)
.
split
(
'.'
)[
0
]))
)
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
...
...
examples/cifar-convnet.py
View file @
b7ee409b
...
@@ -112,8 +112,6 @@ def get_data(train_or_test, cifar_classnum):
...
@@ -112,8 +112,6 @@ def get_data(train_or_test, cifar_classnum):
def
get_config
(
cifar_classnum
):
def
get_config
(
cifar_classnum
):
logger
.
auto_set_dir
()
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
,
cifar_classnum
)
dataset_train
=
get_data
(
'train'
,
cifar_classnum
)
dataset_test
=
get_data
(
'test'
,
cifar_classnum
)
dataset_test
=
get_data
(
'test'
,
cifar_classnum
)
...
@@ -145,10 +143,9 @@ if __name__ == '__main__':
...
@@ -145,10 +143,9 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
else
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'cifar'
+
str
(
args
.
classnum
)))
config
=
get_config
(
args
.
classnum
)
config
=
get_config
(
args
.
classnum
)
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
...
@@ -156,7 +153,7 @@ if __name__ == '__main__':
...
@@ -156,7 +153,7 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
nr_gpu
=
get_nr_gpu
()
nr_gpu
=
get_nr_gpu
()
if
nr_gpu
=
=
1
:
if
nr_gpu
<
=
1
:
QueueInputTrainer
(
config
)
.
train
()
QueueInputTrainer
(
config
)
.
train
()
else
:
else
:
SyncMultiGPUTrainer
(
config
)
.
train
()
SyncMultiGPUTrainer
(
config
)
.
train
()
examples/mnist-convnet.py
View file @
b7ee409b
...
@@ -102,9 +102,6 @@ def get_data():
...
@@ -102,9 +102,6 @@ def get_data():
def
get_config
():
def
get_config
():
# automatically setup the directory train_log/mnist-convnet for logging
logger
.
auto_set_dir
()
dataset_train
,
dataset_test
=
get_data
()
dataset_train
,
dataset_test
=
get_data
()
# How many iterations you want in each epoch.
# How many iterations you want in each epoch.
# This is the default value, don't actually need to set it in the config
# This is the default value, don't actually need to set it in the config
...
@@ -136,9 +133,12 @@ if __name__ == '__main__':
...
@@ -136,9 +133,12 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
# automatically setup the directory train_log/mnist-convnet for logging
logger
.
auto_set_dir
()
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
# SimpleTrainer is slow, this is just a demo.
# SimpleTrainer is slow, this is just a demo.
SimpleTrainer
(
config
)
.
train
()
# You can use QueueInputTrainer instead
# You can use QueueInputTrainer instead
SimpleTrainer
(
config
)
.
train
()
examples/svhn-digit-convnet.py
View file @
b7ee409b
...
@@ -94,7 +94,6 @@ def get_data():
...
@@ -94,7 +94,6 @@ def get_data():
def
get_config
():
def
get_config
():
logger
.
auto_set_dir
()
data_train
,
data_test
=
get_data
()
data_train
,
data_test
=
get_data
()
return
TrainConfig
(
return
TrainConfig
(
...
@@ -120,6 +119,7 @@ if __name__ == '__main__':
...
@@ -120,6 +119,7 @@ if __name__ == '__main__':
else
:
else
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
logger
.
auto_set_dir
()
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
...
...
tensorpack/RL/history.py
View file @
b7ee409b
...
@@ -40,6 +40,9 @@ class HistoryBuffer(object):
...
@@ -40,6 +40,9 @@ class HistoryBuffer(object):
class
HistoryFramePlayer
(
ProxyPlayer
):
class
HistoryFramePlayer
(
ProxyPlayer
):
""" Include history frames in state, or use black images.
""" Include history frames in state, or use black images.
It assumes the underlying player will do auto-restart.
It assumes the underlying player will do auto-restart.
Map the original frames into (H, W, HIST x channels).
Oldest frames first.
"""
"""
def
__init__
(
self
,
player
,
hist_len
):
def
__init__
(
self
,
player
,
hist_len
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment