Commit dc378b53 authored by Yuxin Wu

global_variables instead of variables

parent 540cdf7c
@@ -17,9 +17,6 @@ from tensorpack.tfutils.gradproc import *
 from tensorpack.utils.lut import LookUpTable
 from tensorpack.utils.globvars import globalns as param
-from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import rnn
 # some model hyperparams to set
 param.batch_size = 128
 param.rnn_size = 256
@@ -30,7 +27,7 @@ param.vocab_size = None
 param.softmax_temprature = 1
 param.corpus = 'input.txt'
-class CharRNNData(DataFlow):
+class CharRNNData(RNGDataFlow):
     def __init__(self, input_file, size):
         self.seq_length = param.seq_len
         self._size = size
@@ -49,9 +46,6 @@ class CharRNNData(DataFlow):
         self.whole_seq = np.array(list(map(self.lut.get_idx, data)), dtype='int32')
         logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))
-    def reset_state(self):
-        self.rng = get_rng(self)
     def size(self):
         return self._size
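The removed `reset_state` is exactly what the new `RNGDataFlow` base class supplies, so subclasses no longer write it themselves. A minimal sketch of such a base class, assuming the `DataFlow` and `get_rng` names from this repo are in scope (the actual tensorpack implementation may differ):

```python
class RNGDataFlow(DataFlow):
    """DataFlow with a per-instance RNG, re-seeded on reset_state()."""
    def reset_state(self):
        # get_rng derives a fresh seed per object/process, so forked
        # worker processes don't share an RNG stream
        self.rng = get_rng(self)
```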
@@ -71,19 +65,18 @@ class Model(ModelDesc):
     def _build_graph(self, input_vars):
         input, nextinput = input_vars
-        cell = rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
-        cell = rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
+        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
+        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
         self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)
         embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
         input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x rnnsize
-        input_list = tf.split(1, param.seq_len, input_feature)  # seqlen x (Bx1xrnnsize)
-        input_list = [tf.squeeze(x, [1]) for x in input_list]
+        input_list = tf.unstack(input_feature, axis=1)  # seqlen x (Bxrnnsize)
         # seqlen is 1 in inference. don't need loop_function
-        outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm')
+        outputs, last_state = tf.nn.rnn(cell, input_list, initial, scope='rnnlm')
         self.last_state = tf.identity(last_state, 'last_state')
         # seqlen x (Bxrnnsize)
...
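The two-step `tf.split` + `tf.squeeze` is collapsed into a single `tf.unstack`, which splits along an axis and removes it in one call. A minimal sketch of the equivalence, written against the TF 1.x argument order (the removed `tf.split(1, seq_len, x)` in the diff uses the older pre-1.0 order):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 5, 8])  # B x seqlen x rnnsize

# old style: seqlen tensors of shape B x 1 x rnnsize, then squeeze away axis 1
old = [tf.squeeze(t, [1]) for t in tf.split(x, 5, axis=1)]

# new style: seqlen tensors of shape B x rnnsize directly
new = tf.unstack(x, axis=1)

assert [t.get_shape().as_list() for t in new] == [[None, 8]] * 5
```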
@@ -18,7 +18,7 @@ class ModelSaver(Callback):
     Save the model to logger directory.
     """
     def __init__(self, keep_recent=10, keep_freq=0.5,
-                 var_collections=tf.GraphKeys.VARIABLES):
+                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
...
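`tf.GraphKeys.VARIABLES` was renamed to `tf.GraphKeys.GLOBAL_VARIABLES` in TensorFlow 0.12, which is what this commit tracks. Code that has to run on both sides of the rename can fall back to the old name; a compatibility sketch, not part of this commit:

```python
import tensorflow as tf

# prefer the new collection key, fall back to the pre-0.12 one
VARS_KEY = getattr(tf.GraphKeys, 'GLOBAL_VARIABLES', None) or tf.GraphKeys.VARIABLES
```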
@@ -15,7 +15,7 @@ from ..tfutils import get_op_var_name
 __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
            'ScheduledHyperParamSetter',
-           'StatMonitorParamSetter',
+           'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
            'HyperParam', 'GraphVarParam', 'ObjAttrParam']
 class HyperParam(object):
@@ -197,15 +197,24 @@ class ScheduledHyperParamSetter(HyperParamSetter):
         v = (self.epoch_num - laste) * 1. / (e - laste) * (v - lastv) + lastv
         return v
+class HyperParamSetterWithFunc(HyperParamSetter):
+    def __init__(self, param, func):
+        """Set hyperparameter by a func:
+        new_value = func(epoch_num, old_value)
+        """
+        super(HyperParamSetterWithFunc, self).__init__(param)
+        self.f = func
+    def _get_value_to_set(self):
+        return self.f(self.epoch_num, self.get_current_value())
 class StatMonitorParamSetter(HyperParamSetter):
+    """
+    Set hyperparameter by a func, when a specific stat wasn't
+    decreasing/increasing enough in the last $k$ epochs
+    """
     def __init__(self, param, stat_name, value_func, threshold,
                  last_k, reverse=False
                  ):
         """
-        Set hyperparameter by a func, when a specific stat wasn't
-        decreasing/increasing enough in the last $k$ epochs.
         Change param by `new_value = value_func(old_value)`,
         if :
         min(stats) >= stats[0] - threshold, where
...
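A hypothetical use of the new `HyperParamSetterWithFunc` next to the existing `StatMonitorParamSetter`; the `'learning_rate'` parameter name and the stat name are placeholders, not from this commit:

```python
# decay the LR to 95% of its current value every epoch
HyperParamSetterWithFunc('learning_rate',
                         lambda epoch, old: old * 0.95)

# or: divide the LR by 10 when 'val-error' hasn't improved
# by more than 1e-3 over the last 5 epochs
StatMonitorParamSetter('learning_rate', 'val-error',
                       lambda old: old / 10., threshold=1e-3, last_k=5)
```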
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
import numpy as np

from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ...utils.argtools import memoized_ignoreargs
from ..base import RNGDataFlow

try:
    import tensorflow
    from tensorflow.models.rnn.ptb import reader as tfreader
except ImportError:
    logger.warn_dependency('PennTreeBank', 'tensorflow')
    __all__ = []
else:
    __all__ = ['PennTreeBank']

TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
@memoized_ignoreargs
def get_raw_data(data_dir):
    if not os.path.isfile(os.path.join(data_dir, 'ptb.train.txt')):
        download(TRAIN_URL, data_dir)
        download(VALID_URL, data_dir)
        download(TEST_URL, data_dir)
    # TODO these functions in TF might not be available in the future
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id)
             for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id
class PennTreeBank(RNGDataFlow):
    def __init__(self, name, step_size, data_dir=None, shuffle=True):
        """
        Generate PTB word sequences.
        :param name: one of 'train', 'val', 'test'
        """
        super(PennTreeBank, self).__init__()
        if data_dir is None:
            data_dir = get_dataset_path('ptb_data')
        assert os.path.isdir(data_dir)
        data3, word_to_id = get_raw_data(data_dir)
        self.word_to_id = word_to_id
        self.data = np.asarray(
            data3[['train', 'val', 'test'].index(name)], dtype='int32')
        self.step_size = step_size
        self.shuffle = shuffle

    def size(self):
        return (self.data.shape[0] - 1) // self.step_size

    def get_data(self):
        sz = self.size()
        if not self.shuffle:
            starts = np.arange(self.data.shape[0] - 1)[::self.step_size]
            assert starts.shape[0] >= sz
            starts = starts[:sz]
        else:
            starts = self.rng.randint(
                0, self.data.shape[0] - 1 - self.step_size, size=(sz,))
        for st in starts:
            seq = self.data[st:st + self.step_size + 1]
            yield [seq[:-1], seq[1:]]
    @staticmethod
    def word_to_id():
        # get_raw_data needs the data directory; use the same default as __init__
        data3, wti = get_raw_data(get_dataset_path('ptb_data'))
        return wti
if __name__ == '__main__':
    D = PennTreeBank('train', 50)
    D.reset_state()
    for k in D.get_data():
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
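Each datapoint from `get_data` is an `[input, target]` pair of length `step_size`, with the target shifted one position ahead, the usual next-word setup. A toy illustration of the slicing (synthetic ids, not real PTB data):

```python
import numpy as np

data = np.arange(10, dtype='int32')  # stand-in word-id sequence
step_size, st = 4, 0
seq = data[st:st + step_size + 1]    # 5 ids: [0 1 2 3 4]
inp, target = seq[:-1], seq[1:]      # inp=[0 1 2 3], target=[1 2 3 4]
assert (inp[1:] == target[:-1]).all()
```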
@@ -4,7 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import os
-import random
 import numpy as np
 from six.moves import range
...
@@ -93,7 +93,7 @@ class ModelFromMetaGraph(ModelDesc):
         tf.train.import_meta_graph(filename)
         all_coll = tf.get_default_graph().get_all_collection_keys()
         for k in [INPUT_VARS_KEY, tf.GraphKeys.TRAINABLE_VARIABLES,
-                  tf.GraphKeys.VARIABLES]:
+                  tf.GraphKeys.GLOBAL_VARIABLES]:
             assert k in all_coll, \
                 "Collection {} not found in metagraph!".format(k)
...
@@ -113,7 +113,7 @@ class SaverRestore(SessionInit):
         :param vars_available: variable names available in the checkpoint, for existence checking
         :returns: a dict of {var_name: [var, var]} to restore
         """
-        vars_to_restore = tf.all_variables()
+        vars_to_restore = tf.global_variables()
         var_dict = defaultdict(list)
         chkpt_vars_used = set()
         for v in vars_to_restore:
@@ -150,7 +150,7 @@ class ParamRestore(SessionInit):
         self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
     def _init(self, sess):
-        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
+        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
         variable_names = set([get_savename_from_varname(k.name) for k in variables])
         param_names = set(six.iterkeys(self.prms))
...
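`tf.all_variables()` was renamed to `tf.global_variables()` at the same time; both are just a lookup of the `GLOBAL_VARIABLES` collection, so the `SaverRestore` and `ParamRestore` changes above are two spellings of the same thing:

```python
import tensorflow as tf

# equivalent on TF >= 0.12
assert tf.global_variables() == tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
```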
@@ -7,7 +7,7 @@
 import inspect, six, functools
 import collections
-__all__ = [ 'map_arg', 'memoized', 'shape2d']
+__all__ = [ 'map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs']
 def map_arg(**maps):
     """
@@ -54,6 +54,16 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)
+_MEMOIZED_NOARGS = {}
+def memoized_ignoreargs(func):
+    h = hash(func)  # make sure it is hashable. is it necessary?
+    def wrapper(*args):
+        if func not in _MEMOIZED_NOARGS:
+            res = func(*args)
+            _MEMOIZED_NOARGS[func] = res
+            return res
+        return _MEMOIZED_NOARGS[func]
+    return wrapper
 #_GLOBAL_MEMOIZED_CACHE = dict()
 #def global_memoized(func):
...
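Note that `memoized_ignoreargs` caches one result per function object: the first call runs `func`, and every later call returns that cached result no matter what arguments it passes. For example (the paths and the `load` helper are placeholders):

```python
@memoized_ignoreargs
def get_raw_data(data_dir):
    print('loading from', data_dir)
    return load(data_dir)  # hypothetical expensive load

get_raw_data('/data/ptb')      # prints, loads, caches
get_raw_data('/other/place')   # silently returns the /data/ptb result
```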