Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
890df78f
Commit
890df78f
authored
May 27, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
small fix in RandomApplyAug & A3C
parent
a0247332
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
11 deletions
+22
-11
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+6
-6
examples/DeepQNetwork/common.py
examples/DeepQNetwork/common.py
+1
-1
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+4
-1
tensorpack/dataflow/imgaug/meta.py
tensorpack/dataflow/imgaug/meta.py
+8
-0
tensorpack/models/fc.py
tensorpack/models/fc.py
+1
-1
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+2
-2
No files found.
examples/A3C-Gym/train-atari.py
View file @
890df78f
...
...
@@ -80,12 +80,12 @@ class MySimulatorWorker(SimulatorProcess):
class
Model
(
ModelDesc
):
def
_get_inputs
(
self
):
assert
NUM_ACTIONS
is
not
None
return
[
InputDesc
(
tf
.
float32
,
(
None
,)
+
IMAGE_SHAPE3
,
'state'
),
return
[
InputDesc
(
tf
.
uint8
,
(
None
,)
+
IMAGE_SHAPE3
,
'state'
),
InputDesc
(
tf
.
int64
,
(
None
,),
'action'
),
InputDesc
(
tf
.
float32
,
(
None
,),
'futurereward'
)]
def
_get_NN_prediction
(
self
,
image
):
image
=
image
/
255.0
image
=
tf
.
cast
(
image
,
tf
.
float32
)
/
255.0
with
argscope
(
Conv2D
,
nl
=
tf
.
nn
.
relu
):
l
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
l
=
MaxPooling
(
'pool0'
,
l
,
2
)
...
...
@@ -220,7 +220,7 @@ def get_config():
dataflow
=
dataflow
,
callbacks
=
[
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
8
0
,
0.0003
),
(
120
,
0.0001
)]),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
2
0
,
0.0003
),
(
120
,
0.0001
)]),
ScheduledHyperParamSetter
(
'entropy_beta'
,
[(
80
,
0.005
)]),
ScheduledHyperParamSetter
(
'explore_factor'
,
[(
80
,
2
),
(
100
,
3
),
(
120
,
4
),
(
140
,
5
)]),
...
...
@@ -230,7 +230,7 @@ def get_config():
StartProcOrThread
(
master
),
PeriodicTrigger
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'policy'
],
get_player
),
every_k_epochs
=
2
),
every_k_epochs
=
3
),
],
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
...
...
@@ -264,7 +264,7 @@ if __name__ == '__main__':
if
args
.
task
!=
'train'
:
cfg
=
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
args
.
load
),
session_init
=
get_model_loader
(
args
.
load
),
input_names
=
[
'state'
],
output_names
=
[
'policy'
])
if
args
.
task
==
'play'
:
...
...
@@ -296,7 +296,7 @@ if __name__ == '__main__':
trainer
=
QueueInputTrainer
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
config
.
tower
=
train_tower
config
.
predict_tower
=
predict_tower
trainer
(
config
)
.
train
()
examples/DeepQNetwork/common.py
View file @
890df78f
...
...
@@ -120,4 +120,4 @@ def play_n_episodes(player, predfunc, nr):
if
k
!=
0
:
player
.
restart_episode
()
score
=
play_one_episode
(
player
,
predfunc
)
print
(
"
Score:"
,
score
)
print
(
"
{}/{}, score="
,
k
,
nr
,
score
)
tensorpack/dataflow/image.py
View file @
890df78f
...
...
@@ -109,7 +109,10 @@ class AugmentImageComponents(MapData):
to keep the original images not modified.
Turn it off to save time when you know it's OK.
"""
self
.
augs
=
AugmentorList
(
augmentors
)
if
isinstance
(
augmentors
,
AugmentorList
):
self
.
augs
=
augmentors
else
:
self
.
augs
=
AugmentorList
(
augmentors
)
self
.
ds
=
ds
self
.
_nr_error
=
0
...
...
tensorpack/dataflow/imgaug/meta.py
View file @
890df78f
...
...
@@ -38,6 +38,14 @@ class RandomApplyAug(ImageAugmentor):
else
:
return
(
False
,
None
)
def
_augment_return_params
(
self
,
img
):
p
=
self
.
rng
.
rand
()
if
p
<
self
.
prob
:
img
,
prms
=
self
.
aug
.
_augment_return_params
(
img
)
return
img
,
(
True
,
prms
)
else
:
return
img
,
(
False
,
None
)
def
reset_state
(
self
):
super
(
RandomApplyAug
,
self
)
.
reset_state
()
self
.
aug
.
reset_state
()
...
...
tensorpack/models/fc.py
View file @
890df78f
...
...
@@ -31,7 +31,7 @@ def FullyConnected(x, out_dim,
Variable Names:
* ``W``: weights
* ``W``: weights
of shape [in_dim, out_dim]
* ``b``: bias
"""
x
=
symbf
.
batch_flatten
(
x
)
...
...
tensorpack/predict/concurrency.py
View file @
890df78f
...
...
@@ -99,9 +99,9 @@ class PredictorWorkerThread(StoppableThread, ShareSessionThread):
# self.id, len(futures), self.queue.qsize())
# debug, for speed testing
# if not hasattr(self, 'xxx'):
#
self.xxx = outputs = self.func(batched)
#
self.xxx = outputs = self.func(batched)
# else:
#
outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
#
outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for
idx
,
f
in
enumerate
(
futures
):
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment