Shashank Suhas / seminar-breakout / Commits

Commit 6e562112 authored May 15, 2016 by Yuxin Wu

    param setter object attr

parent 704bee73
Showing 3 changed files with 47 additions and 32 deletions (+47 -32)
tensorpack/callbacks/base.py    +4  -3
tensorpack/callbacks/param.py   +36 -25
tensorpack/predict.py           +7  -4
tensorpack/callbacks/base.py
@@ -46,7 +46,8 @@ class Callback(object):
         """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
-        self.epoch_num = self.trainer.config.starting_epoch
+        self.epoch_num = self.trainer.config.starting_epoch - 1
+        # self.epoch_num is always the number of epochs that finished updating parameters.
         self._setup_graph()

     def _setup_graph(self):
@@ -81,8 +82,8 @@ class Callback(object):
         In this function, self.epoch_num would be the number of epoch finished.
         """
-        self._trigger_epoch()
         self.epoch_num += 1
+        self._trigger_epoch()

     def _trigger_epoch(self):
         pass
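Note: taken together with the starting_epoch - 1 initialization above, epoch_num now always equals the number of finished epochs by the time _trigger_epoch runs. A standalone sketch of that accounting (the Config/Callback classes below are simplified stand-ins, not tensorpack's real ones):

    # Simplified stand-ins illustrating the epoch_num accounting after this commit.
    class Config(object):
        def __init__(self, starting_epoch=1):
            self.starting_epoch = starting_epoch

    class Callback(object):
        def setup(self, config):
            # one below starting_epoch, as in the first hunk above
            self.epoch_num = config.starting_epoch - 1

        def trigger_epoch(self):
            self.epoch_num += 1    # increment first ...
            self._trigger_epoch()  # ... so _trigger_epoch sees the finished-epoch count

        def _trigger_epoch(self):
            pass

    class PrintEpoch(Callback):
        def _trigger_epoch(self):
            print("finished epochs:", self.epoch_num)

    cb = PrintEpoch()
    cb.setup(Config(starting_epoch=1))
    for _ in range(3):
        cb.trigger_epoch()   # prints 1, 2, 3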
@@ -117,7 +118,7 @@ class PeriodicCallback(ProxyCallback):
         self.period = int(period)

     def _trigger_epoch(self):
-        self.cb.epoch_num = self.epoch_num - 1
         if self.epoch_num % self.period == 0:
+            self.cb.epoch_num = self.epoch_num - 1
             self.cb.trigger_epoch()
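Note: in PeriodicCallback, the wrapped callback's epoch_num is now synchronized only on the epochs where it actually fires. A quick trace of the modulo logic (plain Python, values made up):

    # Trace of PeriodicCallback._trigger_epoch after this commit, with period=2.
    period = 2
    for epoch_num in range(1, 7):           # epoch_num as seen by the proxy callback
        if epoch_num % period == 0:
            cb_epoch_num = epoch_num - 1    # wrapped cb.epoch_num, set just before firing
            print("epoch", epoch_num, "-> wrapped callback fires with epoch_num", cb_epoch_num)
    # fires on epochs 2, 4, 6; skips 1, 3, 5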
tensorpack/callbacks/param.py
@@ -20,17 +20,26 @@ class HyperParamSetter(Callback):
     """
     __metaclass__ = ABCMeta

     # TODO maybe support InputVar?
-    def __init__(self, var_name, shape=[]):
+    TF_VAR = 0
+    OBJ_ATTR = 1
+
+    def __init__(self, param, shape=[]):
         """
-        :param var_name: name of the variable
-        :param shape: shape of the variable
+        :param param: either a name of the variable in the graph, or a (object, attribute) tuple
+        :param shape: shape of the param
         """
-        self.op_name, self.var_name = get_op_var_name(var_name)
+        if isinstance(param, tuple):
+            self.param_type = HyperParamSetter.OBJ_ATTR
+            self.obj_attr = param
+            self.readable_name = param[1]
+        else:
+            self.param_type = HyperParamSetter.TF_VAR
+            self.readable_name, self.var_name = get_op_var_name(param)
         self.shape = shape
         self.last_value = None

     def _setup_graph(self):
+        if self.param_type == HyperParamSetter.TF_VAR:
             all_vars = tf.all_variables()
             for v in all_vars:
                 if v.name == self.var_name:
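Note: after this hunk, param may be either a graph-variable name or an (object, attribute) tuple. A hedged sketch of the dispatch (get_op_var_name below is a simplified stand-in for tensorpack's helper, not its real implementation):

    TF_VAR, OBJ_ATTR = 0, 1

    def get_op_var_name(name):
        # stand-in: 'learning_rate' -> ('learning_rate', 'learning_rate:0')
        if name.endswith(':0'):
            return name[:-2], name
        return name, name + ':0'

    def classify_param(param):
        if isinstance(param, tuple):              # (object, 'attribute') form
            return OBJ_ATTR, param[1]             # readable_name is the attribute name
        readable_name, var_name = get_op_var_name(param)  # graph-variable form
        return TF_VAR, readable_name

    class Opt(object):
        lr = 0.1

    print(classify_param('learning_rate'))   # (0, 'learning_rate')
    print(classify_param((Opt(), 'lr')))     # (1, 'lr')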
@@ -40,7 +49,7 @@ class HyperParamSetter(Callback):
                     raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
             self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
-                                             name=self.op_name + '_feed')
+                                             name=self.readable_name + '_feed')
             self.assign_op = self.var.assign(self.val_holder)

     def get_current_value(self):
@@ -50,7 +59,7 @@ class HyperParamSetter(Callback):
         ret = self._get_current_value()
         if ret is not None and ret != self.last_value:
             logger.info("{} at epoch {} will change to {}".format(
-                self.op_name, self.epoch_num + 1, ret))
+                self.readable_name, self.epoch_num + 1, ret))
         self.last_value = ret
         return ret
@@ -67,19 +76,21 @@ class HyperParamSetter(Callback):
     def _set_param(self):
         v = self.get_current_value()
         if v is not None:
-            self.assign_op.eval(feed_dict={self.val_holder: v})
+            if self.param_type == HyperParamSetter.TF_VAR:
+                self.assign_op.eval(feed_dict={self.val_holder: v})
+            else:
+                setattr(self.obj_attr[0], self.obj_attr[1], v)


 class HumanHyperParamSetter(HyperParamSetter):
     """
     Set hyperparameters manually by modifying a file.
     """
-    def __init__(self, var_name, file_name):
+    def __init__(self, param, file_name):
         """
         :param var_name: name of the variable.
         :param file_name: a file containing the value of the variable. Each line in the file is a k:v pair
         """
         self.file_name = file_name
-        super(HumanHyperParamSetter, self).__init__(var_name)
+        super(HumanHyperParamSetter, self).__init__(param)

     def _get_current_value(self):
         try:
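Note: the OBJ_ATTR branch of _set_param is just a setattr on the stored (object, attribute) pair; no graph op is involved. A minimal demonstration (the Augmentor object here is hypothetical):

    class Augmentor(object):           # hypothetical object holding a tunable attribute
        def __init__(self):
            self.noise_sigma = 0.5

    aug = Augmentor()
    obj_attr = (aug, 'noise_sigma')    # the tuple that would be passed as `param`
    setattr(obj_attr[0], obj_attr[1], 0.1)
    print(aug.noise_sigma)             # 0.1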
@@ -87,25 +98,25 @@ class HumanHyperParamSetter(HyperParamSetter):
                 lines = f.readlines()
             lines = [s.strip().split(':') for s in lines]
             dic = {str(k): float(v) for k, v in lines}
-            ret = dic[self.op_name]
+            ret = dic[self.readable_name]
             return ret
         except:
             logger.warn("Failed to parse {} in {}".format(
-                self.op_name, self.file_name))
+                self.readable_name, self.file_name))
             return None


 class ScheduledHyperParamSetter(HyperParamSetter):
     """
     Set hyperparameters by a predefined schedule.
     """
-    def __init__(self, var_name, schedule):
+    def __init__(self, param, schedule):
         """
         :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
         """
         schedule = [(int(a), float(b)) for a, b in schedule]
         self.schedule = sorted(schedule, key=operator.itemgetter(0))
-        super(ScheduledHyperParamSetter, self).__init__(var_name)
+        super(ScheduledHyperParamSetter, self).__init__(param)

     def _get_current_value(self):
         for e, v in self.schedule:
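Note: both subclasses only pass param through to the base constructor. For reference, a self-contained sketch of the parsing each relies on (file contents and schedule values are made up):

    import operator

    # HumanHyperParamSetter: each line of the file is a "name:value" pair,
    # looked up by readable_name.
    lines = ["learning_rate:0.01\n", "noise_sigma:0.3\n"]   # stand-in for f.readlines()
    pairs = [s.strip().split(':') for s in lines]
    dic = {str(k): float(v) for k, v in pairs}
    print(dic['learning_rate'])    # 0.01

    # ScheduledHyperParamSetter: schedule entries are normalized and sorted by epoch.
    schedule = [(30, 1e-4), (1, 1e-2), (10, 1e-3)]
    schedule = sorted([(int(a), float(b)) for a, b in schedule],
                      key=operator.itemgetter(0))
    print(schedule)    # [(1, 0.01), (10, 0.001), (30, 0.0001)]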
tensorpack/predict.py
@@ -102,7 +102,7 @@ class PredictWorker(multiprocessing.Process):
     def __init__(self, idx, gpuid, inqueue, outqueue, config):
         """
         :param idx: index of the worker. the 0th worker will print log.
-        :param gpuid: id of the GPU to be used
+        :param gpuid: id of the GPU to be used. set to -1 to use CPU.
         :param inqueue: input queue to get data point
         :param outqueue: output queue put result
         :param config: a `PredictConfig`
@@ -115,10 +115,13 @@ class PredictWorker(multiprocessing.Process):
         self.config = config

     def run(self):
-        logger.info("Worker {} use GPU {}".format(self.idx, self.gpuid))
-        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        if self.gpuid >= 0:
+            logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
+            os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        else:
+            logger.info("Worker {} uses CPU".format(self.idx))
         G = tf.Graph()     # build a graph for each process, because they don't need to share anything
-        with G.as_default(), tf.device('/gpu:0'):
+        with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
             if self.idx != 0:
                 from tensorpack.models._common import disable_layer_logging
                 disable_layer_logging()
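Note: run() now routes gpuid < 0 to the CPU device. A standalone sketch of the device-selection logic (no TensorFlow session is created here; os.environ values must be strings, so the id is stringified in this sketch):

    import os

    def select_device(gpuid):
        # mirrors PredictWorker.run() after this commit; gpuid < 0 means CPU
        if gpuid >= 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(gpuid)
            return '/gpu:0'
        return '/cpu:0'

    print(select_device(1))     # '/gpu:0', with CUDA_VISIBLE_DEVICES='1'
    print(select_device(-1))    # '/cpu:0'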