Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0dbfe237
Commit
0dbfe237
authored
Mar 25, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
hyperparam setter
parent
9387c653
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
1 deletion
+74
-1
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+66
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+8
-1
No files found.
tensorpack/callbacks/param.py
0 → 100644
View file @
0dbfe237
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: param.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
abc
import
abstractmethod
,
ABCMeta
from
.base
import
Callback
from
..utils
import
logger
,
get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
]
class
HyperParamSetter
(
Callback
):
__metaclass__
=
ABCMeta
# TODO maybe support InputVar?
def
__init__
(
self
,
var_name
,
shape
=
[]):
self
.
op_name
,
self
.
var_name
=
get_op_var_name
(
var_name
)
self
.
shape
=
shape
self
.
last_value
=
None
def
_before_train
(
self
):
all_vars
=
tf
.
all_variables
()
for
v
in
all_vars
:
print
v
.
name
if
v
.
name
==
self
.
var_name
:
self
.
var
=
v
break
else
:
raise
ValueError
(
"{} is not a VARIABLE in the graph!"
.
format
(
self
.
var_name
))
self
.
val_holder
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
self
.
shape
,
name
=
self
.
op_name
+
'_feed'
)
self
.
assign_op
=
self
.
var
.
assign
(
self
.
val_holder
)
def
get_current_value
(
self
):
ret
=
self
.
_get_current_value
()
if
ret
!=
self
.
last_value
:
logger
.
info
(
"{} at epoch {} is changed to {}"
.
format
(
self
.
var_name
,
self
.
epoch_num
,
ret
))
self
.
last_value
=
ret
return
ret
@
abstractmethod
def
_get_current_value
(
self
):
pass
def
_trigger_epoch
(
self
):
v
=
self
.
get_current_value
()
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
class
HumanHyperParamSetter
(
HyperParamSetter
):
def
__init__
(
self
,
var_name
,
file_name
):
"""
read value from file_name.
file_name: each line in the file is a k:v pair
"""
self
.
file_name
=
file_name
super
(
HumanHyperParamSetter
,
self
)
.
__init__
(
var_name
)
def
_get_current_value
(
self
):
with
open
(
self
.
file_name
)
as
f
:
lines
=
f
.
readlines
()
lines
=
[
s
.
strip
()
.
split
(
':'
)
for
s
in
lines
]
dic
=
{
str
(
k
):
float
(
v
)
for
k
,
v
in
lines
}
return
dic
[
self
.
op_name
]
tensorpack/utils/utils.py
View file @
0dbfe237
...
...
@@ -10,7 +10,8 @@ import numpy as np
from
.
import
logger
__all__
=
[
'timed_operation'
,
'change_env'
,
'get_rng'
,
'memoized'
]
__all__
=
[
'timed_operation'
,
'change_env'
,
'get_rng'
,
'memoized'
,
'get_op_var_name'
]
#def expand_dim_if_necessary(var, dp):
# """
...
...
@@ -77,3 +78,9 @@ class memoized(object):
def
get_rng
(
self
):
seed
=
(
id
(
self
)
+
os
.
getpid
())
%
4294967295
return
np
.
random
.
RandomState
(
seed
)
def
get_op_var_name
(
name
):
if
name
.
endswith
(
':0'
):
return
name
[:
-
2
],
name
else
:
return
name
,
name
+
':0'
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment