Shashank Suhas / seminar-breakout · Commits

Commit e1edc9a4
authored Jan 10, 2017 by Yuxin Wu

    update sphinx docs & add ptb example

parent ff0f2cf7

Showing 4 changed files with 297 additions and 7 deletions (+297 −7):

    docs/conf.py                        +6    −1
    examples/PennTreebank/PTB-LSTM.py   +159  −0
    examples/PennTreebank/reader.py     +122  −0
    tensorpack/dataflow/common.py       +10   −6
docs/conf.py  (+6 −1)

@@ -52,6 +52,7 @@ needs_sphinx = '1.4'
 # ones.
 extensions = [
     'sphinx.ext.autodoc',
+    'sphinx.ext.todo',
     'sphinx.ext.napoleon',
     # 'sphinx.ext.coverage',
     'sphinx.ext.mathjax',
@@ -66,7 +67,11 @@ napoleon_include_special_with_doc = True
 napoleon_numpy_docstring = False
 napoleon_use_rtype = False
-intersphinx_timeout = 0.1
+if os.environ.get('READTHEDOCS') == 'True':
+    intersphinx_timeout = 10
+else:
+    # skip this when building locally
+    intersphinx_timeout = 0.1
 intersphinx_mapping = {'python': ('https://docs.python.org/3.4', None)}

 # -------------------------
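For context, a minimal sketch of the behaviour the new conf.py lines introduce (not part of the diff; pick_intersphinx_timeout is a hypothetical helper, and READTHEDOCS is the environment variable set by the Read the Docs build service):

import os

def pick_intersphinx_timeout(environ):
    # Same decision as the added conf.py lines: a real timeout when building
    # on Read the Docs (which sets READTHEDOCS=True in the environment), and
    # a near-zero one locally so a missing network connection cannot stall
    # the doc build.
    return 10 if environ.get('READTHEDOCS') == 'True' else 0.1

print(pick_intersphinx_timeout({'READTHEDOCS': 'True'}))   # 10
print(pick_intersphinx_timeout(dict(os.environ)))          # 0.1 on a typical local machine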
examples/PennTreebank/PTB-LSTM.py  (new file, mode 100755; +159 −0)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb-lstm.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import numpy as np
import os
import argparse

from tensorpack import *
from tensorpack.tfutils.gradproc import *
from tensorpack.utils import logger, get_dataset_path
from tensorpack.utils.fs import download
from tensorpack.utils.argtools import memoized_ignoreargs

import reader as tfreader
from reader import ptb_producer

rnn = tf.contrib.rnn

SEQ_LEN = 35
HIDDEN_SIZE = 650
NUM_LAYER = 2
BATCH = 20
DROPOUT = 0.5
VOCAB_SIZE = None

TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'


@memoized_ignoreargs
def get_PennTreeBank(data_dir=None):
    if data_dir is None:
        data_dir = get_dataset_path('ptb_data')
    if not os.path.isfile(os.path.join(data_dir, 'ptb.train.txt')):
        download(TRAIN_URL, data_dir)
        download(VALID_URL, data_dir)
        download(TEST_URL, data_dir)
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [np.asarray(tfreader._file_to_word_ids(
                os.path.join(data_dir, fname), word_to_id))
             for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id


class Model(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.int32, (None, SEQ_LEN), 'input'),
                InputVar(tf.int32, (None, SEQ_LEN), 'nextinput')]

    def _build_graph(self, input_vars):
        is_training = get_current_tower_context().is_training
        input, nextinput = input_vars
        initializer = tf.random_uniform_initializer(-0.05, 0.05)

        with tf.variable_scope('LSTM', initializer=initializer):
            cell = rnn.BasicLSTMCell(num_units=HIDDEN_SIZE, forget_bias=0.0)
            if is_training:
                cell = rnn.DropoutWrapper(cell, output_keep_prob=DROPOUT)
            cell = rnn.MultiRNNCell([cell] * NUM_LAYER)

            def get_v(n):
                return tf.get_variable(n, [BATCH, HIDDEN_SIZE],
                                       trainable=False,
                                       initializer=tf.constant_initializer())
            self.state = state_var = \
                (rnn.LSTMStateTuple(get_v('c0'), get_v('h0')),
                 rnn.LSTMStateTuple(get_v('c1'), get_v('h1')))

            embeddingW = tf.get_variable('embedding', [VOCAB_SIZE, HIDDEN_SIZE],
                                         initializer=initializer)
            input_feature = tf.nn.embedding_lookup(embeddingW, input)  # B x seqlen x hiddensize
            input_feature = Dropout(input_feature, DROPOUT)

            input_list = tf.unstack(input_feature, num=SEQ_LEN, axis=1)  # seqlen x (Bxhidden)
            outputs, last_state = rnn.static_rnn(cell, input_list, state_var, scope='rnn')  # seqlen x (Bxrnnsize)

        output = tf.reshape(tf.concat_v2(outputs, 1), [-1, HIDDEN_SIZE])  # (Bxseqlen) x hidden
        logits = FullyConnected('fc', output, VOCAB_SIZE,
                                nl=tf.identity, W_init=initializer)
        xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=symbolic_functions.flatten(nextinput))

        update_state_op = tf.group(
            tf.assign(state_var[0].c, last_state[0].c),
            tf.assign(state_var[0].h, last_state[0].h),
            tf.assign(state_var[1].c, last_state[1].c),
            tf.assign(state_var[1].h, last_state[1].h), name='update_state')
        with tf.control_dependencies([update_state_op]):
            self.cost = tf.truediv(tf.reduce_sum(xent_loss),
                                   tf.cast(BATCH, tf.float32), name='cost')

        # log-perplexity
        perpl = tf.exp(self.cost / SEQ_LEN, name='perplexity')
        summary.add_moving_summary(perpl)

    def reset_lstm_state(self):
        s = self.state
        z = tf.zeros_like(s[0].c)
        return tf.group(s[0].c.assign(z), s[0].h.assign(z),
                        s[1].c.assign(z), s[1].h.assign(z))

    def get_gradient_processor(self):
        return [GlobalNormClip(5)]


def get_config():
    logger.auto_set_dir()

    data3, wd2id = get_PennTreeBank()
    global VOCAB_SIZE
    VOCAB_SIZE = len(wd2id)
    step_per_epoch = (data3[0].shape[0] // BATCH - 1) // SEQ_LEN

    train_data = TensorInput(
        lambda: ptb_producer(data3[0], BATCH, SEQ_LEN), step_per_epoch)
    val_data = TensorInput(
        lambda: ptb_producer(data3[1], BATCH, SEQ_LEN),
        (data3[1].shape[0] // BATCH - 1) // SEQ_LEN)

    M = Model()

    lr = symbolic_functions.get_scalar_var('learning_rate', 1, summary=True)
    return TrainConfig(
        data=train_data,
        model=M,
        optimizer=tf.train.GradientDescentOptimizer(lr),
        callbacks=Callbacks([
            StatPrinter(), ModelSaver(),
            HyperParamSetterWithFunc(
                'learning_rate',
                lambda e, x: x * 0.80 if e > 6 else x),
            RunOp(lambda: M.reset_lstm_state()),
            FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]),
            CallbackFactory(
                trigger_epoch=lambda self: self.trainer.write_scalar_summary(
                    'validation_perplexity',
                    np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))),
            RunOp(lambda: M.reset_lstm_state()),
        ]),
        max_epoch=70,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    SimpleFeedfreeTrainer(config).train()
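The example is run as ./PTB-LSTM.py --gpu 0, optionally with --load to restore a checkpoint. As a quick sanity check on the arithmetic in the script, here is a small sketch that is not part of the commit; the PTB training-set size and the loss value are assumptions used only for illustration:

import math

# step_per_epoch in get_config(): the flattened word-id stream is split into
# BATCH parallel streams, and every step consumes SEQ_LEN tokens per stream.
BATCH, SEQ_LEN = 20, 35
num_train_tokens = 929589                      # rough size of ptb.train.txt (assumption)
step_per_epoch = (num_train_tokens // BATCH - 1) // SEQ_LEN
print(step_per_epoch)                          # 1327

# 'perplexity' in _build_graph(): self.cost is the batch's summed token loss
# divided by BATCH, i.e. SEQ_LEN times the mean per-token loss, so dividing
# by SEQ_LEN before exp() recovers the usual per-word perplexity.
mean_token_loss = 4.6                          # made-up value, for illustration only
cost = mean_token_loss * SEQ_LEN
print(round(math.exp(cost / SEQ_LEN), 1))      # 99.5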
examples/PennTreebank/reader.py  (new file, mode 100644; +122 −0)

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import tensorflow as tf


def _read_words(filename):
    with tf.gfile.GFile(filename, "r") as f:
        return f.read().decode("utf-8").replace("\n", "<eos>").split()


def _build_vocab(filename):
    data = _read_words(filename)

    counter = collections.Counter(data)
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))

    return word_to_id


def _file_to_word_ids(filename, word_to_id):
    data = _read_words(filename)
    return [word_to_id[word] for word in data if word in word_to_id]


def ptb_raw_data(data_path=None):
    """Load PTB raw data from data directory "data_path".

    Reads PTB text files, converts strings to integer ids,
    and performs mini-batching of the inputs.

    The PTB dataset comes from Tomas Mikolov's webpage:
    http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz

    Args:
      data_path: string path to the directory where simple-examples.tgz has
        been extracted.

    Returns:
      tuple (train_data, valid_data, test_data, vocabulary)
      where each of the data objects can be passed to PTBIterator.
    """
    train_path = os.path.join(data_path, "ptb.train.txt")
    valid_path = os.path.join(data_path, "ptb.valid.txt")
    test_path = os.path.join(data_path, "ptb.test.txt")

    word_to_id = _build_vocab(train_path)
    train_data = _file_to_word_ids(train_path, word_to_id)
    valid_data = _file_to_word_ids(valid_path, word_to_id)
    test_data = _file_to_word_ids(test_path, word_to_id)
    vocabulary = len(word_to_id)
    return train_data, valid_data, test_data, vocabulary


def ptb_producer(raw_data, batch_size, num_steps, name=None):
    """Iterate on the raw PTB data.

    This chunks up raw_data into batches of examples and returns Tensors that
    are drawn from these batches.

    Args:
      raw_data: one of the raw data outputs from ptb_raw_data.
      batch_size: int, the batch size.
      num_steps: int, the number of unrolls.
      name: the name of this operation (optional).

    Returns:
      A pair of Tensors, each shaped [batch_size, num_steps]. The second element
      of the tuple is the same data time-shifted to the right by one.

    Raises:
      tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
    """
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

        data_len = tf.size(raw_data)
        batch_len = data_len // batch_size
        data = tf.reshape(raw_data[0:batch_size * batch_len],
                          [batch_size, batch_len])

        epoch_size = (batch_len - 1) // num_steps
        assertion = tf.assert_positive(
            epoch_size,
            message="epoch_size == 0, decrease batch_size or num_steps")
        with tf.control_dependencies([assertion]):
            epoch_size = tf.identity(epoch_size, name="epoch_size")

        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
        x = tf.strided_slice(data, [0, i * num_steps],
                             [batch_size, (i + 1) * num_steps])
        x.set_shape([batch_size, num_steps])
        y = tf.strided_slice(data, [0, i * num_steps + 1],
                             [batch_size, (i + 1) * num_steps + 1])
        y.set_shape([batch_size, num_steps])
        return x, y
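For illustration, a plain-NumPy sketch of the reshaping and time shift that ptb_producer performs symbolically (not part of the commit; the toy word-id stream and the tiny batch settings are made up):

import numpy as np

raw_data = np.arange(20)             # toy word-id stream: 0, 1, ..., 19
batch_size, num_steps = 2, 3

batch_len = raw_data.size // batch_size                                  # 10
data = raw_data[:batch_size * batch_len].reshape(batch_size, batch_len)
epoch_size = (batch_len - 1) // num_steps                                # 3 slices per epoch

i = 0                                                                    # first slice
x = data[:, i * num_steps:(i + 1) * num_steps]
y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
print(x)   # [[ 0  1  2] [10 11 12]]
print(y)   # [[ 1  2  3] [11 12 13]]  -> same data shifted by one word (the targets)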
tensorpack/dataflow/common.py  (+10 −6)

@@ -10,10 +10,10 @@ from six.moves import range, map
 from .base import DataFlow, ProxyDataFlow, RNGDataFlow
 from ..utils import logger, get_tqdm

-__all__ = ['TestDataSpeed', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
+__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
            'MapDataComponent', 'RepeatedData', 'RandomChooseData',
            'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
-           'LocallyShuffleData', 'PrintData']
+           'LocallyShuffleData']


 class TestDataSpeed(ProxyDataFlow):
@@ -479,12 +479,14 @@ class PrintData(ProxyDataFlow):
     Behave like an identity mapping but print shapes of produced datapoints once during construction.

     Attributes:
-        label (str): label to identify the data when using this debugging on multiple places
+        label (str): label to identify the data when using this debugging on multiple places.
         num (int): number of iterations

     Example:
         To enable this debugging output, you should place it somewhere in your dataflow like

+        .. code-block:: python
+
             def get_data():
                 ds = CaffeLMDB('path/to/lmdb')
                 ds = SomeInscrutableMappings(ds)
@@ -494,6 +496,8 @@ class PrintData(ProxyDataFlow):
         The output looks like:

+        .. code-block:: none
+
             [0110 09:22:21 @common.py:589] DataFlow Info:
                 datapoint 0<2 with 4 elements consists of
                    dp 0: is float of shape () with range [0.0816501893251]
@@ -511,7 +515,7 @@ class PrintData(ProxyDataFlow):
         """
         Args:
             ds (DataFlow): input DataFlow.
-            num (int): number of dataflow points.
+            num (int): number of dataflow points to print.
            label (str, optional): label to identify this call, when using multiple times
         """
         super(PrintData, self).__init__(ds)
@@ -519,7 +523,7 @@ class PrintData(ProxyDataFlow):
         self.label = label
         self.print_info()

-    def analyze_input_data(self, el, k, depth=1):
+    def _analyze_input_data(self, el, k, depth=1):
         """
         Gather useful debug information from a datapoint.
@@ -591,7 +595,7 @@ class PrintData(ProxyDataFlow):
         if isinstance(dummy, list):
             msg.append("datapoint %i<%i with %i elements consists of" % (i, self.num, len(dummy)))
             for k, entry in enumerate(dummy):
-                msg.append(self.analyze_input_data(entry, k))
+                msg.append(self._analyze_input_data(entry, k))
         label = "" if self.label is "" else " (" + self.label + ")"
         logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
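For reference, a small usage sketch of the PrintData dataflow documented above; this is not part of the diff, the FakeData stand-in source and its arguments are assumptions, and the num/label keywords follow the docstring:

from tensorpack.dataflow import FakeData, PrintData

def get_data():
    # stand-in source; the docstring example uses CaffeLMDB('path/to/lmdb')
    # followed by further mappings
    ds = FakeData([[3, 2]], 100)
    # identity mapping that logs shape/range info for `num` datapoints once,
    # at construction time; `label` tells several PrintData instances apart
    ds = PrintData(ds, num=2, label='after-mappings')
    return ds

get_data()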