Shashank Suhas / seminar-breakout

Commit dc378b53
authored Nov 26, 2016 by Yuxin Wu

    global_variables instead of variables

parent 540cdf7c
Showing 8 changed files with 118 additions and 23 deletions (+118 −23):

    examples/char-rnn/char-rnn.py        +5   −12
    tensorpack/callbacks/common.py       +1   −1
    tensorpack/callbacks/param.py        +14  −5
    tensorpack/dataflow/dataset/ptb.py   +84  −0
    tensorpack/dataflow/dataset/svhn.py  +0   −1
    tensorpack/models/model_desc.py      +1   −1
    tensorpack/tfutils/sessinit.py       +2   −2
    tensorpack/utils/argtools.py         +11  −1
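The commit title refers to TensorFlow's 0.12-era rename of the default variable collection: `tf.GraphKeys.VARIABLES` became `tf.GraphKeys.GLOBAL_VARIABLES`, and `tf.all_variables()` became `tf.global_variables()`. That rename drives the small changes in `common.py`, `model_desc.py`, and `sessinit.py`; the char-rnn changes track related API moves (`rnn_cell` → `tf.nn.rnn_cell`, `tf.split` + `tf.squeeze` → `tf.unstack`), and the remaining files add a PTB dataset and a `memoized_ignoreargs` helper. A minimal compatibility sketch for the rename (not part of this commit; assumes a TF 1.x-era API where `tf.GraphKeys` is available):

```python
import tensorflow as tf

def global_variables_compat():
    """Return the graph's global variables under either API name.

    Prefers the post-0.12 names and falls back to the deprecated ones.
    """
    if hasattr(tf, 'global_variables'):   # TF >= 0.12
        return tf.global_variables()
    return tf.all_variables()             # TF < 0.12 (deprecated name)

# The collection key follows the same rename.
GLOBAL_VARS_KEY = getattr(tf.GraphKeys, 'GLOBAL_VARIABLES',
                          getattr(tf.GraphKeys, 'VARIABLES', None))
```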
examples/char-rnn/char-rnn.py  +5 −12

```diff
@@ -17,9 +17,6 @@ from tensorpack.tfutils.gradproc import *
 from tensorpack.utils.lut import LookUpTable
 from tensorpack.utils.globvars import globalns as param
-from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import rnn

 # some model hyperparams to set
 param.batch_size = 128
 param.rnn_size = 256
@@ -30,7 +27,7 @@ param.vocab_size = None
 param.softmax_temprature = 1
 param.corpus = 'input.txt'

-class CharRNNData(DataFlow):
+class CharRNNData(RNGDataFlow):
     def __init__(self, input_file, size):
         self.seq_length = param.seq_len
         self._size = size
@@ -49,9 +46,6 @@ class CharRNNData(DataFlow):
         self.whole_seq = np.array(list(map(self.lut.get_idx, data)), dtype='int32')
         logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))

-    def reset_state(self):
-        self.rng = get_rng(self)
-
     def size(self):
         return self._size
@@ -71,19 +65,18 @@ class Model(ModelDesc):
     def _build_graph(self, input_vars):
         input, nextinput = input_vars
-        cell = rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
-        cell = rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
+        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
+        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
         self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)

         embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x rnnsize
-        input_list = tf.split(1, param.seq_len, input_feature)  # seqlen x (Bx1xrnnsize)
-        input_list = [tf.squeeze(x, [1]) for x in input_list]
+        input_list = tf.unstack(input_feature, axis=1)  # seqlen x (Bxrnnsize)

         # seqlen is 1 in inference. don't need loop_function
-        outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm')
+        outputs, last_state = tf.nn.rnn(cell, input_list, initial, scope='rnnlm')
         self.last_state = tf.identity(last_state, 'last_state')
         # seqlen x (Bxrnnsize)
```
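Besides the `rnn_cell`/`rnn` module moves, the interesting change here is replacing the `tf.split` + `tf.squeeze` pair with a single `tf.unstack` along the time axis; both produce a list of `seq_len` tensors of shape `(B, rnn_size)`. A small sketch of the equivalence (using the modern `tf.split(value, num, axis)` argument order; the 2016-era call in the diff used `tf.split(axis, num, value)`):

```python
import numpy as np
import tensorflow as tf

# Toy batch: B=2 sequences, seqlen=3 steps, feature size 4.
x = tf.constant(np.arange(24, dtype='float32').reshape(2, 3, 4))

# Old pattern: split on the time axis, then squeeze the leftover length-1 dim.
old_style = [tf.squeeze(t, [1]) for t in tf.split(x, 3, axis=1)]

# New pattern: unstack along the time axis in one call.
new_style = tf.unstack(x, axis=1)

# Both yield a list of 3 tensors of shape (2, 4).
for a, b in zip(old_style, new_style):
    assert a.shape == b.shape == (2, 4)
```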
tensorpack/callbacks/common.py  +1 −1

```diff
@@ -18,7 +18,7 @@ class ModelSaver(Callback):
     Save the model to logger directory.
     """
     def __init__(self, keep_recent=10, keep_freq=0.5,
-                 var_collections=tf.GraphKeys.VARIABLES):
+                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
```
tensorpack/callbacks/param.py  +14 −5

```diff
@@ -15,7 +15,7 @@ from ..tfutils import get_op_var_name
 __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
            'ScheduledHyperParamSetter',
-           'StatMonitorParamSetter',
+           'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
            'HyperParam', 'GraphVarParam', 'ObjAttrParam']

 class HyperParam(object):
@@ -197,15 +197,24 @@ class ScheduledHyperParamSetter(HyperParamSetter):
             v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
         return v

+class HyperParamSetterWithFunc(HyperParamSetter):
+    def __init__(self, param, func):
+        """Set hyperparameter by a func
+        new_value = f(epoch_num, old_value)
+        """
+        super(StatMonitorParamSetter, self).__init__(param)
+        self.f = func
+
+    def _get_value_to_set(self):
+        return self.f(self.epoch_num, self.get_current_value())
+
 class StatMonitorParamSetter(HyperParamSetter):
-    """
-    Set hyperparameter by a func, when a specific stat wasn't
-    decreasing/increasing enough in the last $k$ epochs
-    """
     def __init__(self, param, stat_name, value_func, threshold,
                  last_k, reverse=False
                  ):
         """
+        Set hyperparameter by a func, when a specific stat wasn't
+        decreasing/increasing enough in the last $k$ epochs.
         Change param by `new_value = value_func(old_value)`,
         if :
             min(stats) >= stats[0] - threshold, where
```
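The new `HyperParamSetterWithFunc` computes `new_value = func(epoch_num, old_value)` each epoch. One caveat: `super(StatMonitorParamSetter, self).__init__(param)` in the added class reads like a copy-paste slip for `super(HyperParamSetterWithFunc, self).__init__(param)`, and as written would raise a `TypeError` on instantiation, since `HyperParamSetterWithFunc` is not a subclass of `StatMonitorParamSetter`. A hypothetical usage sketch (the `'learning_rate'` name and the decay schedule are illustrative, not from this commit):

```python
from tensorpack.callbacks.param import HyperParamSetterWithFunc

# Decay a graph hyperparameter by 2% per epoch via a user-supplied function.
# func receives (epoch_num, old_value) and returns the new value.
lr_decay = HyperParamSetterWithFunc(
    'learning_rate',
    lambda epoch_num, old_value: old_value * 0.98)
```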
tensorpack/dataflow/dataset/ptb.py  0 → 100644 (new file)  +84 −0

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
import numpy as np

from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ...utils.argtools import memoized_ignoreargs
from ..base import RNGDataFlow

try:
    import tensorflow
    from tensorflow.models.rnn.ptb import reader as tfreader
except ImportError:
    logger.warn_dependency('PennTreeBank', 'tensorflow')
    __all__ = []
else:
    __all__ = ['PennTreeBank']

TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'

@memoized_ignoreargs
def get_raw_data(data_dir):
    if not os.path.isfile(os.path.join(data_dir, 'ptb.train.txt')):
        download(TRAIN_URL, data_dir)
        download(VALID_URL, data_dir)
        download(TEST_URL, data_dir)
    # TODO these functions in TF might not be available in the future
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id)
             for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id

class PennTreeBank(RNGDataFlow):
    def __init__(self, name, step_size, data_dir=None, shuffle=True):
        """
        Generate PTB word sequences.
        :param name: one of 'train', 'val', 'test'
        """
        super(PennTreeBank, self).__init__()
        if data_dir is None:
            data_dir = get_dataset_path('ptb_data')
        assert os.path.isdir(data_dir)
        data3, word_to_id = get_raw_data(data_dir)
        self.word_to_id = word_to_id
        self.data = np.asarray(
            data3[['train', 'val', 'test'].index(name)], dtype='int32')
        self.step_size = step_size
        self.shuffle = shuffle

    def size(self):
        return (self.data.shape[0] - 1) // self.step_size

    def get_data(self):
        sz = self.size()
        if not self.shuffle:
            starts = np.arange(self.data.shape[0] - 1)[::self.step_size]
            assert starts.shape[0] >= sz
            starts = starts[:sz]
        else:
            starts = self.rng.randint(
                0, self.data.shape[0] - 1 - self.step_size, size=(sz,))
        for st in starts:
            seq = self.data[st:st + self.step_size + 1]
            yield [seq[:-1], seq[1:]]

    @staticmethod
    def word_to_id():
        data3, wti = get_raw_data()
        return wti

if __name__ == '__main__':
    D = PennTreeBank('train', 50)
    D.reset_state()
    for k in D.get_data():
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
```
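A minimal usage sketch: each datapoint is `[seq[:-1], seq[1:]]`, i.e. `step_size` input word ids paired with the next-word targets. (Two quirks worth noting: the `word_to_id` staticmethod shadows the instance attribute of the same name, and its zero-argument `get_raw_data()` call only works because `memoized_ignoreargs` returns the cached result once the data has been loaded elsewhere.)

```python
# Iterate over PTB next-word-prediction pairs (a sketch; the data files are
# downloaded into get_dataset_path('ptb_data') on first use).
ptb = PennTreeBank('train', step_size=50)
ptb.reset_state()                       # seeds self.rng (from RNGDataFlow)
for inputs, targets in ptb.get_data():
    assert inputs.shape == targets.shape == (50,)
    break                               # one datapoint suffices for the demo
```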
tensorpack/dataflow/dataset/svhn.py  +0 −1

```diff
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import os
-import random
 import numpy as np
 from six.moves import range
```
tensorpack/models/model_desc.py  +1 −1

```diff
@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
         tf.train.import_meta_graph(filename)
         all_coll = tf.get_default_graph().get_all_collection_keys()
         for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
-                  tf.GraphKeys.VARIABLES]:
+                  tf.GraphKeys.GLOBAL_VARIABLES]:
             assert k in all_coll, \
                 "Collection {} not found in metagraph!".format(k)
```
tensorpack/tfutils/sessinit.py  +2 −2

```diff
@@ -113,7 +113,7 @@ class SaverRestore(SessionInit):
         :param vars_available: varaible names available in the checkpoint, for existence checking
         :returns: a dict of {var_name: [var, var]} to restore
         """
-        vars_to_restore = tf.all_variables()
+        vars_to_restore = tf.global_variables()
         var_dict = defaultdict(list)
         chkpt_vars_used = set()
         for v in vars_to_restore:
@@ -150,7 +150,7 @@ class ParamRestore(SessionInit):
         self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}

     def _init(self, sess):
-        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
+        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         variable_names = set([get_savename_from_varname(k.name) for k in variables])
         param_names = set(six.iterkeys(self.prms))
```
tensorpack/utils/argtools.py  +11 −1

```diff
@@ -7,7 +7,7 @@
 import inspect, six, functools
 import collections

-__all__ = ['map_arg', 'memoized', 'shape2d']
+__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']

 def map_arg(**maps):
     """
@@ -54,6 +54,16 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)

+_MEMOIZED_NOARGS = {}
+def memoized_ignoreargs(func):
+    h = hash(func)  # make sure it is hashable. is it necessary?
+    def wrapper(*args):
+        if func not in _MEMOIZED_NOARGS:
+            res = func(*args)
+            _MEMOIZED_NOARGS[func] = res
+            return res
+        return _MEMOIZED_NOARGS[func]
+    return wrapper

 #_GLOBAL_MEMOIZED_CACHE = dict()
 #def global_memoized(func):
```
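`memoized_ignoreargs` caches the first call's result keyed by the function object and ignores the arguments of every later call, which is what makes the zero-argument `get_raw_data()` call in `ptb.py` above legal once the data has already been loaded. A quick illustration with a hypothetical loader:

```python
@memoized_ignoreargs
def load_corpus(path):
    print('loading from', path)      # executes only on the first call
    return {'source': path}

first = load_corpus('/data/ptb')     # runs the function and caches the result
second = load_corpus()               # arguments ignored; cached result returned
assert first is second
```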