Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
48ef46aa
Commit
48ef46aa
authored
Dec 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
delete PTB
parent
b6b1adae
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
48 deletions
+5
-48
tensorpack/dataflow/dataset/ptb.py
tensorpack/dataflow/dataset/ptb.py
+5
-48
No files found.
tensorpack/dataflow/dataset/ptb.py
View file @
48ef46aa
...
...
@@ -17,7 +17,7 @@ except ImportError:
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow'
)
__all__
=
[]
else
:
__all__
=
[
'PennTreeBank'
]
__all__
=
[
'
get_
PennTreeBank'
]
TRAIN_URL
=
'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
...
...
@@ -25,59 +25,16 @@ VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.val
TEST_URL
=
'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
def get_PennTreeBank(data_dir=None):
    """
    Download (if necessary) and parse the Penn Treebank word-level dataset.

    :param data_dir: directory containing ``ptb.{train,valid,test}.txt``.
        When None, defaults to ``get_dataset_path('ptb_data')``.
    :returns: ``(data3, word_to_id)`` where ``data3`` is a list of three
        numpy arrays of word ids (train, valid, test, in that order) and
        ``word_to_id`` is a dict mapping word string -> integer id.
    """
    if data_dir is None:
        data_dir = get_dataset_path('ptb_data')
    # Presence of the train file is used as the marker that the dataset
    # has already been downloaded; all three files are fetched together.
    if not os.path.isfile(os.path.join(data_dir, 'ptb.train.txt')):
        download(TRAIN_URL, data_dir)
        download(VALID_URL, data_dir)
        download(TEST_URL, data_dir)
    # TODO these functions in TF might not be available in the future
    # The vocabulary is intentionally built from the training split only.
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [np.asarray(tfreader._file_to_word_ids(
                 os.path.join(data_dir, fname), word_to_id))
             for fname in ['ptb.train.txt', 'ptb.valid.txt', 'ptb.test.txt']]
    return data3, word_to_id
class PennTreeBank(RNGDataFlow):
    # DataFlow yielding [input_seq, target_seq] pairs of PTB word ids,
    # where target_seq is input_seq shifted left by one position
    # (standard next-word language-modeling setup).

    def __init__(self, name, step_size, data_dir=None, shuffle=True):
        """
        Generate PTB word sequences.
        :param name: one of 'train', 'val', 'test'
        :param step_size: length of each yielded sequence pair
        :param data_dir: dataset directory; defaults to
            ``get_dataset_path('ptb_data')`` when None
        :param shuffle: if True, sequence start positions are drawn at
            random each epoch; otherwise fixed strided positions are used
        """
        super(PennTreeBank, self).__init__()
        if data_dir is None:
            data_dir = get_dataset_path('ptb_data')
        data3, word_to_id = get_raw_data(data_dir)
        # word -> id vocabulary built from the training split
        self.word_to_id = word_to_id
        # select the split for this instance; int32 for TF feeding
        self.data = np.asarray(
            data3[['train', 'val', 'test'].index(name)], dtype='int32')
        self.step_size = step_size
        self.shuffle = shuffle

    def size(self):
        # Number of non-overlapping step_size windows; the trailing -1
        # leaves room for the one-position-shifted target sequence.
        return (self.data.shape[0] - 1) // self.step_size

    def get_data(self):
        sz = self.size()
        if not self.shuffle:
            # Deterministic pass: strided, non-overlapping start indices,
            # truncated to exactly `sz` windows.
            starts = np.arange(self.data.shape[0] - 1)[::self.step_size]
            assert starts.shape[0] >= sz
            starts = starts[:sz]
        else:
            # Random pass: `sz` independent start positions, bounded so
            # each window of step_size+1 ids stays inside the data.
            starts = self.rng.randint(
                0, self.data.shape[0] - 1 - self.step_size, size=(sz,))
        for st in starts:
            # step_size + 1 ids: first step_size are input, last
            # step_size (shifted by one) are the prediction targets.
            seq = self.data[st:st + self.step_size + 1]
            yield [seq[:-1], seq[1:]]

    @staticmethod
    def word_to_id():
        # NOTE(review): this staticmethod shares its name with the
        # instance attribute self.word_to_id set in __init__, and it
        # calls get_raw_data() with no argument although the visible
        # module-level signature is get_raw_data(data_dir) — likely
        # broken unless a default exists elsewhere; verify before use.
        data3, wti = get_raw_data()
        return wti
if __name__ == '__main__':
    # Manual smoke test: build the training DataFlow with sequence
    # length 50 and drop into an IPython shell on the first datapoint
    # so the yielded [input, target] arrays can be inspected.
    D = PennTreeBank('train', 50)
    D.reset_state()
    for k in D.get_data():
        import IPython as IP
        IP.embed(config=IP.terminal.ipapp.load_default_config())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment