Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
175bc41a
Commit
175bc41a
authored
May 18, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some updates to trainer
parent
de6d5502
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
28 additions
and
9 deletions
+28
-9
scripts/dump_model_params.py
scripts/dump_model_params.py
+1
-1
tensorpack/dataflow/dataset/atari.py
tensorpack/dataflow/dataset/atari.py
+5
-1
tensorpack/dataflow/dataset/rlenv.py
tensorpack/dataflow/dataset/rlenv.py
+13
-1
tensorpack/train/base.py
tensorpack/train/base.py
+6
-2
tensorpack/train/config.py
tensorpack/train/config.py
+2
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-4
No files found.
scripts/dump_model_params.py
View file @
175bc41a
...
...
@@ -22,7 +22,7 @@ get_config_func = imp.load_source('config_script', args.config).get_config
with
tf
.
Graph
()
.
as_default
()
as
G
:
config
=
get_config_func
()
config
.
model
.
get_cost
(
config
.
model
.
get_input_vars
(),
is_training
=
False
)
config
.
model
.
build_graph
(
config
.
model
.
get_input_vars
(),
is_training
=
False
)
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
()
init
.
init
(
sess
)
...
...
tensorpack/dataflow/dataset/atari.py
View file @
175bc41a
...
...
@@ -9,7 +9,7 @@ import os
import
cv2
from
collections
import
deque
from
...utils
import
get_rng
from
.
import
RLEnvironment
from
.
rlenv
import
RLEnvironment
__all__
=
[
'AtariDriver'
,
'AtariPlayer'
]
...
...
@@ -32,6 +32,8 @@ class AtariDriver(object):
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
if
isinstance
(
viz
,
int
):
viz
=
float
(
viz
)
self
.
viz
=
viz
self
.
romname
=
os
.
path
.
basename
(
rom_file
)
if
self
.
viz
and
isinstance
(
self
.
viz
,
float
):
...
...
@@ -64,6 +66,7 @@ class AtariDriver(object):
cv2
.
imwrite
(
"{}/{:06d}.jpg"
.
format
(
self
.
viz
,
self
.
framenum
),
ret
)
self
.
framenum
+=
1
ret
=
cv2
.
cvtColor
(
ret
,
cv2
.
COLOR_BGR2YUV
)[:,:,
0
]
ret
=
ret
[
36
:
204
,:]
# several online repos all use this
return
ret
def
get_num_actions
(
self
):
...
...
@@ -109,6 +112,7 @@ class AtariPlayer(RLEnvironment):
"""
self
.
frames
.
clear
()
s
=
self
.
driver
.
grab_image
()
s
=
cv2
.
resize
(
s
,
self
.
image_shape
)
for
_
in
range
(
self
.
hist_len
):
self
.
frames
.
append
(
s
)
...
...
tensorpack/dataflow/dataset/rlenv.py
View file @
175bc41a
...
...
@@ -5,7 +5,7 @@
from
abc
import
abstractmethod
,
ABCMeta
__all__
=
[
'RLEnvironment'
]
__all__
=
[
'RLEnvironment'
,
'NaiveRLEnvironment'
]
class
RLEnvironment
(
object
):
__meta__
=
ABCMeta
...
...
@@ -23,3 +23,15 @@ class RLEnvironment(object):
:params act: the action
:returns: (reward, isOver)
"""
class
NaiveRLEnvironment
(
RLEnvironment
):
def
__init__
(
self
):
self
.
k
=
0
def
current_state
(
self
):
self
.
k
+=
1
return
self
.
k
def
action
(
self
,
act
):
self
.
k
=
act
return
(
self
.
k
,
self
.
k
>
10
)
tensorpack/train/base.py
View file @
175bc41a
...
...
@@ -36,6 +36,7 @@ class Trainer(object):
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
self
.
model
=
config
.
model
self
.
extra_threads_procs
=
config
.
extra_threads_procs
@
abstractmethod
def
train
(
self
):
...
...
@@ -84,7 +85,7 @@ class Trainer(object):
callbacks
.
setup_graph
(
self
)
self
.
config
.
session_init
.
init
(
self
.
sess
)
tf
.
get_default_graph
()
.
finalize
()
self
.
_start_
all_threads
()
self
.
_start_
concurrency
()
with
self
.
sess
.
as_default
():
try
:
...
...
@@ -121,12 +122,15 @@ class Trainer(object):
self
.
sess
=
tf
.
Session
(
config
=
self
.
config
.
session_config
)
self
.
coord
=
tf
.
train
.
Coordinator
()
def
_start_
all_threads
(
self
):
def
_start_
concurrency
(
self
):
"""
Run all threads before starting training
"""
tf
.
train
.
start_queue_runners
(
sess
=
self
.
sess
,
coord
=
self
.
coord
,
daemon
=
True
,
start
=
True
)
for
k
in
self
.
extra_threads_procs
:
k
.
start
()
def
process_grads
(
self
,
grads
):
g
=
[]
...
...
tensorpack/train/config.py
View file @
175bc41a
...
...
@@ -32,6 +32,7 @@ class TrainConfig(object):
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epoch to run training. default to 100
:param nr_tower: int. number of towers. default to 1.
:param extra_threads_procs: list of `Startable` threads or processes
"""
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
...
...
@@ -53,5 +54,6 @@ class TrainConfig(object):
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
100
))
assert
self
.
step_per_epoch
>
0
and
self
.
max_epoch
>
0
self
.
nr_tower
=
int
(
kwargs
.
pop
(
'nr_tower'
,
1
))
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
tensorpack/train/trainer.py
View file @
175bc41a
...
...
@@ -209,12 +209,9 @@ class QueueInputTrainer(Trainer):
self
.
init_session_and_coord
()
# create a thread that keeps filling the queue
self
.
input_th
=
EnqueueThread
(
self
,
self
.
input_queue
,
enqueue_op
,
self
.
input_vars
)
self
.
extra_threads_procs
.
append
(
self
.
input_th
)
self
.
main_loop
()
def
_start_all_threads
(
self
):
super
(
QueueInputTrainer
,
self
)
.
_start_all_threads
()
self
.
input_th
.
start
()
def
run_step
(
self
):
if
self
.
async
:
if
not
self
.
async_running
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment