Commit 7ce3d7ab authored by Yuxin Wu's avatar Yuxin Wu

update docs. use INTEPR_LINEAR by default.

parent 72c97317
This diff is collapsed.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb-lstm.py
# File: PTB-LSTM.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
......@@ -53,11 +53,10 @@ class Model(ModelDesc):
input, nextinput = input_vars
initializer = tf.random_uniform_initializer(-0.05, 0.05)
with tf.variable_scope('LSTM', initializer=initializer):
cell = rnn.BasicLSTMCell(num_units=HIDDEN_SIZE, forget_bias=0.0)
if is_training:
cell = rnn.DropoutWrapper(cell, output_keep_prob=DROPOUT)
cell = rnn.MultiRNNCell([cell] * NUM_LAYER)
cell = rnn.BasicLSTMCell(num_units=HIDDEN_SIZE, forget_bias=0.0)
if is_training:
cell = rnn.DropoutWrapper(cell, output_keep_prob=DROPOUT)
cell = rnn.MultiRNNCell([cell] * NUM_LAYER)
def get_v(n):
return tf.get_variable(n, [BATCH, HIDDEN_SIZE],
......@@ -71,13 +70,13 @@ class Model(ModelDesc):
input_feature = tf.nn.embedding_lookup(embeddingW, input) # B x seqlen x hiddensize
input_feature = Dropout(input_feature, DROPOUT)
input_list = tf.unstack(input_feature, num=SEQ_LEN, axis=1) # seqlen x (Bxhidden)
outputs, last_state = rnn.static_rnn(cell, input_list, state_var, scope='rnn')
with tf.variable_scope('LSTM', initializer=initializer):
input_list = tf.unstack(input_feature, num=SEQ_LEN, axis=1) # seqlen x (Bxhidden)
outputs, last_state = rnn.static_rnn(cell, input_list, state_var, scope='rnn')
# seqlen x (Bxrnnsize)
output = tf.reshape(tf.concat_v2(outputs, 1), [-1, HIDDEN_SIZE]) # (Bxseqlen) x hidden
logits = FullyConnected('fc', output, VOCAB_SIZE, nl=tf.identity,
W_init=initializer)
logits = FullyConnected('fc', output, VOCAB_SIZE, nl=tf.identity, W_init=initializer, b_init=initializer)
xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=symbolic_functions.flatten(nextinput))
......
......@@ -33,4 +33,4 @@ Note to contributors:
Example needs to satisfy one of the following:
+ Reproduce performance of a published or well-known paper.
+ Get state-of-the-art performance on some task.
+ Illustrate a new way of using the library that are currently not covered.
+ Illustrate a new way of using the library that is currently not covered.
......@@ -499,12 +499,12 @@ class PrintData(ProxyDataFlow):
.. code-block:: none
[0110 09:22:21 @common.py:589] DataFlow Info:
datapoint 0<2 with 4 elements consists of
datapoint 0<2 with 4 components consists of
dp 0: is float of shape () with range [0.0816501893251]
dp 1: is ndarray of shape (64, 64) with range [0.1300, 0.6895]
dp 2: is ndarray of shape (64, 64) with range [-1.2248, 1.2177]
dp 3: is ndarray of shape (9, 9) with range [-0.6045, 0.6045]
datapoint 1<2 with 4 elements consists of
datapoint 1<2 with 4 components consists of
dp 0: is float of shape () with range [5.88252075399]
dp 1: is ndarray of shape (64, 64) with range [0.0072, 0.9371]
dp 2: is ndarray of shape (64, 64) with range [-0.9011, 0.8491]
......@@ -539,7 +539,7 @@ class PrintData(ProxyDataFlow):
string: debug message
"""
if isinstance(el, list):
return "%s is list of %i elements " % (" " * (depth * 2), len(el))
return "%s is list of %i elements" % (" " * (depth * 2), len(el))
else:
el_type = el.__class__.__name__
......@@ -593,7 +593,7 @@ class PrintData(ProxyDataFlow):
msg = [""]
for i, dummy in enumerate(cutoff(ds.get_data(), self.num)):
if isinstance(dummy, list):
msg.append("datapoint %i<%i with %i elements consists of" % (i, self.num, len(dummy)))
msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy)))
for k, entry in enumerate(dummy):
msg.append(self._analyze_input_data(entry, k))
label = "" if self.label is "" else " (" + self.label + ")"
......
......@@ -17,7 +17,7 @@ class ImageFromFile(RNGDataFlow):
"""
Args:
files (list): list of file paths.
channel (int): 1 or 3. Produce RGB images if channel==3.
channel (int): 1 or 3. Will convert grayscale to RGB images if channel==3.
resize (tuple): (h, w). If given, resize the image.
"""
assert len(files), "No image files given to ImageFromFile!"
......
......@@ -14,7 +14,7 @@ class Rotation(ImageAugmentor):
""" Random rotate the image w.r.t a random center"""
def __init__(self, max_deg, center_range=(0, 1),
interp=cv2.INTER_CUBIC,
interp=cv2.INTER_LINEAR,
border=cv2.BORDER_REPLICATE):
"""
Args:
......@@ -43,7 +43,7 @@ class RotationAndCropValid(ImageAugmentor):
Note that this will produce images of different shapes.
"""
def __init__(self, max_deg, interp=cv2.INTER_CUBIC):
def __init__(self, max_deg, interp=cv2.INTER_LINEAR):
"""
Args:
max_deg (float): max abs value of the rotation degree (in angle).
......
......@@ -49,7 +49,7 @@ class Flip(ImageAugmentor):
class Resize(ImageAugmentor):
""" Resize image to a target size"""
def __init__(self, shape, interp=cv2.INTER_CUBIC):
def __init__(self, shape, interp=cv2.INTER_LINEAR):
"""
Args:
shape: (h, w) tuple or a int
......@@ -85,7 +85,7 @@ class ResizeShortestEdge(ImageAugmentor):
h, w = img.shape[:2]
scale = self.size / min(h, w)
desSize = map(int, [scale * w, scale * h])
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_CUBIC)
ret = cv2.resize(img, tuple(desSize), interpolation=cv2.INTER_LINEAR)
if img.ndim == 3 and ret.ndim == 2:
ret = ret[:, :, np.newaxis]
return ret
......@@ -95,7 +95,7 @@ class RandomResize(ImageAugmentor):
""" Randomly rescale w and h of the image"""
def __init__(self, xrange, yrange, minimum=(0, 0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
interp=cv2.INTER_LINEAR):
"""
Args:
xrange (tuple): (min, max) range of scaling ratio for w
......
......@@ -20,6 +20,11 @@ def replace_get_variable(fn):
Returns:
a context where ``tf.get_variable`` and
``variable_scope.get_variable`` are replaced with ``fn``.
Note that originally ``tf.get_variable ==
tensorflow.python.ops.variable_scope.get_variable``. But some code such as
some in `rnn_cell/`, uses the latter one to get variable, therefore both
need to be replaced.
"""
old_getv = tf.get_variable
old_vars_getv = variable_scope.get_variable
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment