Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d7a85f44
Commit
d7a85f44
authored
May 27, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
misc update on framework
parent
11c46a71
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
172 additions
and
124 deletions
+172
-124
.gitignore
.gitignore
+2
-0
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+47
-78
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+38
-0
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+80
-44
tensorpack/dataflow/RL.py
tensorpack/dataflow/RL.py
+0
-0
tensorpack/dataflow/dataset/.gitignore
tensorpack/dataflow/dataset/.gitignore
+1
-0
tensorpack/train/config.py
tensorpack/train/config.py
+2
-2
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+2
-0
No files found.
.gitignore
View file @
d7a85f44
...
@@ -60,3 +60,5 @@ docs/_build/
...
@@ -60,3 +60,5 @@ docs/_build/
# PyBuilder
# PyBuilder
target/
target/
*.dat
*.dat
*.bin
examples/Atari2600/DQN.py
View file @
d7a85f44
...
@@ -16,7 +16,7 @@ from collections import deque
...
@@ -16,7 +16,7 @@ from collections import deque
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.models
import
*
from
tensorpack.models
import
*
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tensorpack.utils.concurrency
import
ensure_proc_terminate
from
tensorpack.utils.concurrency
import
ensure_proc_terminate
,
subproc_call
from
tensorpack.utils.stat
import
*
from
tensorpack.utils.stat
import
*
from
tensorpack.predict
import
PredictConfig
,
get_predict_func
,
ParallelPredictWorker
from
tensorpack.predict
import
PredictConfig
,
get_predict_func
,
ParallelPredictWorker
from
tensorpack.tfutils
import
symbolic_functions
as
symbf
from
tensorpack.tfutils
import
symbolic_functions
as
symbf
...
@@ -33,7 +33,6 @@ for atari games
...
@@ -33,7 +33,6 @@ for atari games
BATCH_SIZE
=
32
BATCH_SIZE
=
32
IMAGE_SIZE
=
84
IMAGE_SIZE
=
84
NUM_ACTIONS
=
None
FRAME_HISTORY
=
4
FRAME_HISTORY
=
4
ACTION_REPEAT
=
4
ACTION_REPEAT
=
4
HEIGHT_RANGE
=
(
36
,
204
)
# for breakout
HEIGHT_RANGE
=
(
36
,
204
)
# for breakout
...
@@ -49,6 +48,15 @@ INIT_MEMORY_SIZE = 50000
...
@@ -49,6 +48,15 @@ INIT_MEMORY_SIZE = 50000
STEP_PER_EPOCH
=
10000
STEP_PER_EPOCH
=
10000
EVAL_EPISODE
=
100
EVAL_EPISODE
=
100
NUM_ACTIONS
=
None
ROM_FILE
=
None
def
get_player
(
viz
=
False
):
pl
=
AtariPlayer
(
ROM_FILE
,
viz
=
viz
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
)
global
NUM_ACTIONS
NUM_ACTIONS
=
pl
.
get_num_actions
()
return
pl
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
def
_get_input_vars
(
self
):
assert
NUM_ACTIONS
is
not
None
assert
NUM_ACTIONS
is
not
None
...
@@ -56,8 +64,7 @@ class Model(ModelDesc):
...
@@ -56,8 +64,7 @@ class Model(ModelDesc):
InputVar
(
tf
.
int32
,
(
None
,),
'action'
),
InputVar
(
tf
.
int32
,
(
None
,),
'action'
),
InputVar
(
tf
.
float32
,
(
None
,),
'reward'
),
InputVar
(
tf
.
float32
,
(
None
,),
'reward'
),
InputVar
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
,
FRAME_HISTORY
),
'next_state'
),
InputVar
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
,
FRAME_HISTORY
),
'next_state'
),
InputVar
(
tf
.
bool
,
(
None
,),
'isOver'
)
InputVar
(
tf
.
bool
,
(
None
,),
'isOver'
)
]
]
def
_get_DQN_prediction
(
self
,
image
,
is_training
):
def
_get_DQN_prediction
(
self
,
image
,
is_training
):
""" image: [0,255]"""
""" image: [0,255]"""
...
@@ -89,7 +96,7 @@ class Model(ModelDesc):
...
@@ -89,7 +96,7 @@ class Model(ModelDesc):
with
tf
.
variable_scope
(
'target'
):
with
tf
.
variable_scope
(
'target'
):
targetQ_predict_value
=
tf
.
stop_gradient
(
targetQ_predict_value
=
tf
.
stop_gradient
(
self
.
_get_DQN_prediction
(
next_state
,
False
))
# NxA
self
.
_get_DQN_prediction
(
next_state
,
False
))
# NxA
target
=
reward
+
(
1
-
tf
.
cast
(
isOver
,
tf
.
int32
))
*
target
=
reward
+
(
1
.0
-
tf
.
cast
(
isOver
,
tf
.
float32
))
*
\
GAMMA
*
tf
.
reduce_max
(
targetQ_predict_value
,
1
)
# Nx1
GAMMA
*
tf
.
reduce_max
(
targetQ_predict_value
,
1
)
# Nx1
sqrcost
=
tf
.
square
(
target
-
pred_action_value
)
sqrcost
=
tf
.
square
(
target
-
pred_action_value
)
...
@@ -108,7 +115,7 @@ class Model(ModelDesc):
...
@@ -108,7 +115,7 @@ class Model(ModelDesc):
new_name
=
target_name
.
replace
(
'target/'
,
''
)
new_name
=
target_name
.
replace
(
'target/'
,
''
)
logger
.
info
(
"{} <- {}"
.
format
(
target_name
,
new_name
))
logger
.
info
(
"{} <- {}"
.
format
(
target_name
,
new_name
))
ops
.
append
(
v
.
assign
(
tf
.
get_default_graph
()
.
get_tensor_by_name
(
new_name
+
':0'
)))
ops
.
append
(
v
.
assign
(
tf
.
get_default_graph
()
.
get_tensor_by_name
(
new_name
+
':0'
)))
return
tf
.
group
(
*
ops
)
return
tf
.
group
(
*
ops
,
name
=
'update_target_network'
)
def
get_gradient_processor
(
self
):
def
get_gradient_processor
(
self
):
return
[
MapGradient
(
lambda
grad
:
\
return
[
MapGradient
(
lambda
grad
:
\
...
@@ -120,28 +127,11 @@ def current_predictor(state):
...
@@ -120,28 +127,11 @@ def current_predictor(state):
pred
=
pred_var
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})
pred
=
pred_var
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})
return
pred
[
0
]
return
pred
[
0
]
class
TargetNetworkUpdator
(
Callback
):
def
__init__
(
self
,
M
):
self
.
M
=
M
def
_setup_graph
(
self
):
self
.
update_op
=
self
.
M
.
update_target_param
()
def
_update
(
self
):
logger
.
info
(
"Delayed Predictor updating..."
)
self
.
update_op
.
run
()
def
_before_train
(
self
):
self
.
_update
()
def
_trigger_epoch
(
self
):
self
.
_update
()
def
play_one_episode
(
player
,
func
,
verbose
=
False
):
def
play_one_episode
(
player
,
func
,
verbose
=
False
):
tot_reward
=
0
tot_reward
=
0
que
=
deque
(
maxlen
=
30
)
que
=
deque
(
maxlen
=
30
)
while
True
:
while
True
:
s
=
player
.
current_state
()
# XXX
s
=
player
.
current_state
()
outputs
=
func
([[
s
]])
outputs
=
func
([[
s
]])
action_value
=
outputs
[
0
][
0
]
action_value
=
outputs
[
0
][
0
]
act
=
action_value
.
argmax
()
act
=
action_value
.
argmax
()
...
@@ -160,16 +150,10 @@ def play_one_episode(player, func, verbose=False):
...
@@ -160,16 +150,10 @@ def play_one_episode(player, func, verbose=False):
if
isOver
:
if
isOver
:
return
tot_reward
return
tot_reward
def
play_model
(
model_path
,
romfile
):
def
play_model
(
model_path
):
player
=
HistoryFramePlayer
(
AtariPlayer
(
player
=
HistoryFramePlayer
(
get_player
(
0.01
),
FRAME_HISTORY
)
romfile
,
viz
=
0.01
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
),
FRAME_HISTORY
)
global
NUM_ACTIONS
NUM_ACTIONS
=
player
.
player
.
get_num_actions
()
M
=
Model
()
cfg
=
PredictConfig
(
cfg
=
PredictConfig
(
model
=
M
,
model
=
M
odel
()
,
input_data_mapping
=
[
0
],
input_data_mapping
=
[
0
],
session_init
=
SaverRestore
(
model_path
),
session_init
=
SaverRestore
(
model_path
),
output_var_names
=
[
'fct/output:0'
])
output_var_names
=
[
'fct/output:0'
])
...
@@ -178,7 +162,7 @@ def play_model(model_path, romfile):
...
@@ -178,7 +162,7 @@ def play_model(model_path, romfile):
score
=
play_one_episode
(
player
,
predfunc
)
score
=
play_one_episode
(
player
,
predfunc
)
print
(
"Total:"
,
score
)
print
(
"Total:"
,
score
)
def
eval_model_multiprocess
(
model_path
,
romfile
):
def
eval_model_multiprocess
(
model_path
):
M
=
Model
()
M
=
Model
()
cfg
=
PredictConfig
(
cfg
=
PredictConfig
(
model
=
M
,
model
=
M
,
...
@@ -192,11 +176,7 @@ def eval_model_multiprocess(model_path, romfile):
...
@@ -192,11 +176,7 @@ def eval_model_multiprocess(model_path, romfile):
self
.
outq
=
outqueue
self
.
outq
=
outqueue
def
run
(
self
):
def
run
(
self
):
player
=
HistoryFramePlayer
(
AtariPlayer
(
player
=
HistoryFramePlayer
(
get_player
(),
FRAME_HISTORY
)
romfile
,
viz
=
0
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
),
FRAME_HISTORY
)
global
NUM_ACTIONS
NUM_ACTIONS
=
player
.
player
.
get_num_actions
()
self
.
_init_runtime
()
self
.
_init_runtime
()
while
True
:
while
True
:
score
=
play_one_episode
(
player
,
self
.
func
)
score
=
play_one_episode
(
player
,
self
.
func
)
...
@@ -216,32 +196,32 @@ def eval_model_multiprocess(model_path, romfile):
...
@@ -216,32 +196,32 @@ def eval_model_multiprocess(model_path, romfile):
r
=
q
.
get
()
r
=
q
.
get
()
stat
.
feed
(
r
)
stat
.
feed
(
r
)
finally
:
finally
:
for
p
in
procs
:
p
.
terminate
()
p
.
join
()
if
stat
.
count
()
>
0
:
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
stat
.
average
,
stat
.
max
))
stat
.
average
,
stat
.
max
))
return
(
stat
.
average
,
stat
.
max
)
else
:
return
(
0
,
0
)
class
Evaluator
(
Callback
):
def
_trigger_epoch
(
self
):
logger
.
info
(
"Evaluating..."
)
output
=
subproc_call
(
"CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {}"
.
format
(
sys
.
argv
[
0
],
romfile
,
os
.
path
.
join
(
logger
.
LOG_DIR
,
'checkpoint'
)),
timeout
=
10
*
60
)
if
output
:
last
=
output
.
strip
()
.
split
(
'
\n
'
)[
-
1
]
last
=
last
[
last
.
find
(
']'
)
+
1
:]
mean
,
maximum
=
re
.
findall
(
'[0-9
\
.
\
-]+'
,
last
)[
-
2
:]
self
.
trainer
.
write_scalar_summary
(
'mean_score'
,
mean
)
self
.
trainer
.
write_scalar_summary
(
'max_score'
,
maximum
)
def
get_config
(
romfile
):
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
set_logger_dir
(
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
M
=
Model
()
player
=
AtariPlayer
(
romfile
,
height_range
=
HEIGHT_RANGE
,
frame_skip
=
ACTION_REPEAT
)
global
NUM_ACTIONS
NUM_ACTIONS
=
player
.
get_num_actions
()
M
=
Model
()
dataset_train
=
ExpReplay
(
dataset_train
=
ExpReplay
(
predictor
=
current_predictor
,
predictor
=
current_predictor
,
player
=
player
,
player
=
get_player
()
,
num_actions
=
NUM_ACTIONS
,
num_actions
=
NUM_ACTIONS
,
memory_size
=
MEMORY_SIZE
,
memory_size
=
MEMORY_SIZE
,
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
...
@@ -255,18 +235,6 @@ def get_config(romfile):
...
@@ -255,18 +235,6 @@ def get_config(romfile):
lr
=
tf
.
Variable
(
0.00025
,
trainable
=
False
,
name
=
'learning_rate'
)
lr
=
tf
.
Variable
(
0.00025
,
trainable
=
False
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
class
Evaluator
(
Callback
):
def
_trigger_epoch
(
self
):
logger
.
info
(
"Evaluating..."
)
output
=
subprocess
.
check_output
(
"""CUDA_VISIBLE_DEVICES= {} --task eval --rom {} --load {} 2>&1 | grep Average"""
.
format
(
sys
.
argv
[
0
],
romfile
,
os
.
path
.
join
(
logger
.
LOG_DIR
,
'checkpoint'
)),
shell
=
True
)
output
=
output
.
strip
()
output
=
output
[
output
.
find
(
']'
)
+
1
:]
mean
,
maximum
=
re
.
findall
(
'[0-9
\
.
\
-]+'
,
output
)[
-
2
:]
self
.
trainer
.
write_scalar_summary
(
'mean_score'
,
mean
)
self
.
trainer
.
write_scalar_summary
(
'max_score'
,
maximum
)
return
TrainConfig
(
return
TrainConfig
(
dataset
=
dataset_train
,
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
...
@@ -274,15 +242,13 @@ def get_config(romfile):
...
@@ -274,15 +242,13 @@ def get_config(romfile):
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
((
dataset_train
,
'exploration'
),
'hyper.txt'
),
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
TargetNetworkUpdator
(
M
),
RunOp
(
lambda
:
M
.
update_target_param
()
),
dataset_train
,
dataset_train
,
PeriodicCallback
(
Evaluator
(),
2
),
PeriodicCallback
(
Evaluator
(),
2
),
]),
]),
session_config
=
get_default_sess_config
(
0.5
),
model
=
M
,
model
=
M
,
step_per_epoch
=
STEP_PER_EPOCH
,
step_per_epoch
=
STEP_PER_EPOCH
,
max_epoch
=
10000
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
@@ -300,15 +266,18 @@ if __name__ == '__main__':
...
@@ -300,15 +266,18 @@ if __name__ == '__main__':
if
args
.
task
!=
'train'
:
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
assert
args
.
load
is
not
None
global
ROM_FILE
ROM_FILE
=
args
.
rom
if
args
.
task
==
'play'
:
if
args
.
task
==
'play'
:
play_model
(
args
.
load
,
args
.
rom
)
play_model
(
args
.
load
)
sys
.
exit
()
sys
.
exit
()
if
args
.
task
==
'eval'
:
if
args
.
task
==
'eval'
:
eval_model_multiprocess
(
args
.
load
,
args
.
rom
)
eval_model_multiprocess
(
args
.
load
)
sys
.
exit
()
sys
.
exit
()
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
config
=
get_config
(
args
.
rom
)
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
SimpleTrainer
(
config
)
.
train
()
SimpleTrainer
(
config
)
.
train
()
...
...
tensorpack/callbacks/graph.py
0 → 100644
View file @
d7a85f44
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: graph.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Graph related callbacks"""
from
.base
import
Callback
from
..utils
import
logger
__all__
=
[
'RunOp'
]
class
RunOp
(
Callback
):
""" Run an op periodically"""
def
__init__
(
self
,
setup_func
,
run_before
=
True
,
run_epoch
=
True
):
"""
:param setup_func: a function that returns the op in the graph
:param run_before: run the op before training
:param run_epoch: run the op on every epoch trigger
"""
self
.
setup_func
=
setup_func
self
.
run_before
=
run_before
self
.
run_epoch
=
run_epoch
def
_setup_graph
(
self
):
self
.
_op
=
self
.
setup_func
()
#self._op_name = self._op.name
def
_before_train
(
self
):
if
self
.
run_before
:
self
.
_op
.
run
()
def
_trigger_epoch
(
self
):
if
self
.
run_epoch
:
self
.
_op
.
run
()
#def _log(self):
#logger.info("Running op {} ...".format(self._op_name))
tensorpack/callbacks/param.py
View file @
d7a85f44
...
@@ -4,42 +4,43 @@
...
@@ -4,42 +4,43 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
import
operator
import
operator
import
six
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
from
..tfutils
import
get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
]
'ScheduledHyperParamSetter'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
class
HyperParamSetter
(
Callback
):
class
HyperParam
(
object
):
"""
""" Base class for a hyper param"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
TF_VAR
=
0
def
setup_graph
(
self
):
OBJ_ATTR
=
1
""" setup the graph in `setup_graph` callback stage, if necessary"""
pass
def
__init__
(
self
,
param
,
shape
=
[]):
@
abstractmethod
"""
def
set_value
(
self
,
v
):
:param param: either a name of the variable in the graph, or a (object, attribute) tuple
""" define how the value of the param will be set"""
:param shape: shape of the param
pass
"""
if
isinstance
(
param
,
tuple
):
@
abstractproperty
self
.
param_type
=
HyperParamSetter
.
OBJ_ATTR
def
readable_name
(
self
):
self
.
obj_attr
=
param
pass
self
.
readable_name
=
param
[
1
]
else
:
class
GraphVarParam
(
HyperParam
):
self
.
param_type
=
HyperParamSetter
.
TF_VAR
""" a variable in the graph"""
self
.
readable_name
,
self
.
var_name
=
get_op_var_name
(
param
)
def
__init__
(
self
,
name
,
shape
=
[]):
self
.
name
=
name
self
.
shape
=
shape
self
.
shape
=
shape
self
.
last_value
=
None
self
.
_readable_name
,
self
.
var_name
=
get_op_var_name
(
name
)
def
_setup_graph
(
self
):
def
setup_graph
(
self
):
if
self
.
param_type
==
HyperParamSetter
.
TF_VAR
:
all_vars
=
tf
.
all_variables
()
all_vars
=
tf
.
all_variables
()
for
v
in
all_vars
:
for
v
in
all_vars
:
if
v
.
name
==
self
.
var_name
:
if
v
.
name
==
self
.
var_name
:
...
@@ -49,22 +50,62 @@ class HyperParamSetter(Callback):
...
@@ -49,22 +50,62 @@ class HyperParamSetter(Callback):
raise
ValueError
(
"{} is not a VARIABLE in the graph!"
.
format
(
self
.
var_name
))
raise
ValueError
(
"{} is not a VARIABLE in the graph!"
.
format
(
self
.
var_name
))
self
.
val_holder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
self
.
shape
,
self
.
val_holder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
self
.
shape
,
name
=
self
.
readable_name
+
'_feed'
)
name
=
self
.
_
readable_name
+
'_feed'
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
get_current_value
(
self
):
def
set_value
(
self
,
v
):
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
@
property
def
readable_name
(
self
):
return
self
.
_readable_name
class
ObjAttrParam
(
HyperParam
):
""" an attribute of an object"""
def
__init__
(
self
,
obj
,
attrname
):
self
.
obj
=
obj
self
.
attrname
=
attrname
def
set_value
(
self
,
v
):
setattr
(
self
.
obj
,
self
.
attrname
,
v
)
@
property
def
readable_name
(
self
):
return
self
.
attrname
class
HyperParamSetter
(
Callback
):
"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__
=
ABCMeta
def
__init__
(
self
,
param
):
"""
:param param: a `HyperParam` instance, or a string (assumed to be a scalar `GraphVarParam`)
"""
# if a string, assumed to be a scalar graph variable
if
isinstance
(
param
,
six
.
string_types
):
param
=
GraphVarParam
(
param
)
assert
isinstance
(
param
,
HyperParam
),
type
(
param
)
self
.
param
=
param
self
.
last_value
=
None
def
_setup_graph
(
self
):
self
.
param
.
setup_graph
()
def
get_value_to_set
(
self
):
"""
"""
:returns: the value to assign to the variable now.
:returns: the value to assign to the variable now.
"""
"""
ret
=
self
.
_get_
current_value
()
ret
=
self
.
_get_
value_to_set
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
logger
.
info
(
"{} at epoch {} will change to {}"
.
format
(
logger
.
info
(
"{} at epoch {} will change to {}"
.
format
(
self
.
readable_name
,
self
.
epoch_num
+
1
,
ret
))
self
.
param
.
readable_name
,
self
.
epoch_num
+
1
,
ret
))
self
.
last_value
=
ret
self
.
last_value
=
ret
return
ret
return
ret
@
abstractmethod
@
abstractmethod
def
_get_
current_value
(
self
):
def
_get_
value_to_set
(
self
):
pass
pass
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
...
@@ -74,12 +115,9 @@ class HyperParamSetter(Callback):
...
@@ -74,12 +115,9 @@ class HyperParamSetter(Callback):
self
.
_set_param
()
self
.
_set_param
()
def
_set_param
(
self
):
def
_set_param
(
self
):
v
=
self
.
get_
current_value
()
v
=
self
.
get_
value_to_set
()
if
v
is
not
None
:
if
v
is
not
None
:
if
self
.
param_type
==
HyperParamSetter
.
TF_VAR
:
self
.
param
.
set_value
(
v
)
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
else
:
setattr
(
self
.
obj_attr
[
0
],
self
.
obj_attr
[
1
],
v
)
class
HumanHyperParamSetter
(
HyperParamSetter
):
class
HumanHyperParamSetter
(
HyperParamSetter
):
"""
"""
...
@@ -92,18 +130,18 @@ class HumanHyperParamSetter(HyperParamSetter):
...
@@ -92,18 +130,18 @@ class HumanHyperParamSetter(HyperParamSetter):
self
.
file_name
=
file_name
self
.
file_name
=
file_name
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
param
)
def
_get_
current_value
(
self
):
def
_get_
value_to_set
(
self
):
try
:
try
:
with
open
(
self
.
file_name
)
as
f
:
with
open
(
self
.
file_name
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
lines
=
[
s
.
strip
()
.
split
(
':'
)
for
s
in
lines
]
lines
=
[
s
.
strip
()
.
split
(
':'
)
for
s
in
lines
]
dic
=
{
str
(
k
):
float
(
v
)
for
k
,
v
in
lines
}
dic
=
{
str
(
k
):
float
(
v
)
for
k
,
v
in
lines
}
ret
=
dic
[
self
.
readable_name
]
ret
=
dic
[
self
.
param
.
readable_name
]
return
ret
return
ret
except
:
except
:
logger
.
warn
(
logger
.
warn
(
"Failed to parse {} in {}"
.
format
(
"Failed to parse {} in {}"
.
format
(
self
.
readable_name
,
self
.
file_name
))
self
.
param
.
readable_name
,
self
.
file_name
))
return
None
return
None
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
class
ScheduledHyperParamSetter
(
HyperParamSetter
):
...
@@ -118,11 +156,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
...
@@ -118,11 +156,9 @@ class ScheduledHyperParamSetter(HyperParamSetter):
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
self
.
schedule
=
sorted
(
schedule
,
key
=
operator
.
itemgetter
(
0
))
super
(
ScheduledHyperParamSetter
,
self
)
.
__init__
(
param
)
super
(
ScheduledHyperParamSetter
,
self
)
.
__init__
(
param
)
def
_get_
current_value
(
self
):
def
_get_
value_to_set
(
self
):
for
e
,
v
in
self
.
schedule
:
for
e
,
v
in
self
.
schedule
:
if
e
==
self
.
epoch_num
:
if
e
==
self
.
epoch_num
:
return
v
return
v
return
None
return
None
tensorpack/dataflow/RL.py
100755 → 100644
View file @
d7a85f44
File mode changed from 100755 to 100644
tensorpack/dataflow/dataset/.gitignore
View file @
d7a85f44
mnist_data
mnist_data
cifar10_data
cifar10_data
cifar100_data
svhn_data
svhn_data
ilsvrc_metadata
ilsvrc_metadata
bsds500_data
bsds500_data
tensorpack/train/config.py
View file @
d7a85f44
...
@@ -30,7 +30,7 @@ class TrainConfig(object):
...
@@ -30,7 +30,7 @@ class TrainConfig(object):
:param model: a `ModelDesc` instance.j
:param model: a `ModelDesc` instance.j
:param starting_epoch: int. default to be 1.
:param starting_epoch: int. default to be 1.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epoch to run training. default to
100
:param max_epoch: maximum number of epoch to run training. default to
inf
:param nr_tower: int. number of towers. default to 1.
:param nr_tower: int. number of towers. default to 1.
:param extra_threads_procs: list of `Startable` threads or processes
:param extra_threads_procs: list of `Startable` threads or processes
"""
"""
...
@@ -51,7 +51,7 @@ class TrainConfig(object):
...
@@ -51,7 +51,7 @@ class TrainConfig(object):
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
))
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
))
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
))
self
.
starting_epoch
=
int
(
kwargs
.
pop
(
'starting_epoch'
,
1
))
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
100
))
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
99999
))
assert
self
.
step_per_epoch
>
0
and
self
.
max_epoch
>
0
assert
self
.
step_per_epoch
>
0
and
self
.
max_epoch
>
0
self
.
nr_tower
=
int
(
kwargs
.
pop
(
'nr_tower'
,
1
))
self
.
nr_tower
=
int
(
kwargs
.
pop
(
'nr_tower'
,
1
))
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
...
...
tensorpack/utils/concurrency.py
View file @
d7a85f44
...
@@ -14,6 +14,8 @@ if six.PY2:
...
@@ -14,6 +14,8 @@ if six.PY2:
else
:
else
:
import
subprocess
import
subprocess
from
.
import
logger
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ensure_proc_terminate'
,
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ensure_proc_terminate'
,
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
]
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment