Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
453b7c63
Commit
453b7c63
authored
Jan 08, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove ptb, because ptb data been removed from tensorflow recently
parent
cb53e6c0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
48 deletions
+15
-48
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+15
-10
tensorpack/dataflow/dataset/ptb.py
tensorpack/dataflow/dataset/ptb.py
+0
-38
No files found.
tensorpack/callbacks/inference_runner.py
View file @
453b7c63
...
@@ -77,7 +77,7 @@ class InferenceRunner(Callback):
...
@@ -77,7 +77,7 @@ class InferenceRunner(Callback):
self
.
infs
=
infs
self
.
infs
=
infs
for
v
in
self
.
infs
:
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
v
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_tensors
=
input_tensors
self
.
input_tensors
=
input_tensors
# names actually
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# these are all tensor names
self
.
_find_input_tensors
()
# these are all tensor names
...
@@ -141,7 +141,7 @@ class InferenceRunner(Callback):
...
@@ -141,7 +141,7 @@ class InferenceRunner(Callback):
class
FeedfreeInferenceRunner
(
Callback
):
class
FeedfreeInferenceRunner
(
Callback
):
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
def
__init__
(
self
,
input
,
infs
,
input_
tensor
s
=
None
):
def
__init__
(
self
,
input
,
infs
,
input_
name
s
=
None
):
assert
isinstance
(
input
,
FeedfreeInput
),
input
assert
isinstance
(
input
,
FeedfreeInput
),
input
self
.
_input_data
=
input
self
.
_input_data
=
input
if
not
isinstance
(
infs
,
list
):
if
not
isinstance
(
infs
,
list
):
...
@@ -150,7 +150,9 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -150,7 +150,9 @@ class FeedfreeInferenceRunner(Callback):
self
.
infs
=
infs
self
.
infs
=
infs
for
v
in
self
.
infs
:
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
v
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_tensor_names
=
input_tensors
if
input_names
is
not
None
:
assert
isinstance
(
input_names
,
list
)
self
.
_input_names
=
input_names
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# tensors
self
.
_find_input_tensors
()
# tensors
...
@@ -162,17 +164,20 @@ class FeedfreeInferenceRunner(Callback):
...
@@ -162,17 +164,20 @@ class FeedfreeInferenceRunner(Callback):
# only 1 prediction tower will be used for inference
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reuse_placehdrs
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reuse_placehdrs
()
if
self
.
input_names
is
not
None
:
assert
len
(
self
.
input_names
)
==
len
(
self
.
_input_tensors
),
\
"[FeedfreeInferenceRunner] input_names must have the same length as the input data."
# XXX incorrect
self
.
_input_tensors
=
[
k
for
idx
,
k
in
enumerate
(
self
.
_input_tensors
)
if
model_placehdrs
[
idx
]
.
name
in
self
.
input_names
]
assert
len
(
self
.
_input_tensors
)
==
len
(
self
.
input_names
),
\
"[FeedfreeInferenceRunner] all input_tensors must be defined as InputVar in the Model!"
assert
len
(
self
.
_input_tensors
)
==
len
(
model_placehdrs
),
\
assert
len
(
self
.
_input_tensors
)
==
len
(
model_placehdrs
),
\
"FeedfreeInput doesn't produce correct number of output tensors"
"FeedfreeInput doesn't produce correct number of output tensors"
if
self
.
input_tensor_names
is
not
None
:
assert
isinstance
(
self
.
input_tensor_names
,
list
)
self
.
_input_tensors
=
[
k
for
idx
,
k
in
enumerate
(
self
.
_input_tensors
)
if
model_placehdrs
[
idx
]
.
name
in
self
.
input_tensor_names
]
assert
len
(
self
.
_input_tensors
)
==
len
(
self
.
input_tensor_names
),
\
"names of input tensors are not defined in the Model"
def
_find_output_tensors
(
self
):
def
_find_output_tensors
(
self
):
# doesn't support output an input tensor
#
TODO
doesn't support output an input tensor
dispatcer
=
OutputTensorDispatcer
()
dispatcer
=
OutputTensorDispatcer
()
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
...
...
tensorpack/dataflow/dataset/ptb.py
deleted
100644 → 0
View file @
cb53e6c0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
import
numpy
as
np
from
...utils
import
logger
,
get_dataset_path
from
...utils.fs
import
download
from
...utils.argtools
import
memoized_ignoreargs
try
:
from
tensorflow.models.rnn.ptb
import
reader
as
tfreader
except
ImportError
:
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow.models.rnn.ptb.reader'
)
__all__
=
[]
else
:
__all__
=
[
'get_PennTreeBank'
]
TRAIN_URL
=
'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
VALID_URL
=
'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL
=
'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
@
memoized_ignoreargs
def
get_PennTreeBank
(
data_dir
=
None
):
if
data_dir
is
None
:
data_dir
=
get_dataset_path
(
'ptb_data'
)
if
not
os
.
path
.
isfile
(
os
.
path
.
join
(
data_dir
,
'ptb.train.txt'
)):
download
(
TRAIN_URL
,
data_dir
)
download
(
VALID_URL
,
data_dir
)
download
(
TEST_URL
,
data_dir
)
# TODO these functions in TF might not be available in the future
word_to_id
=
tfreader
.
_build_vocab
(
os
.
path
.
join
(
data_dir
,
'ptb.train.txt'
))
data3
=
[
np
.
asarray
(
tfreader
.
_file_to_word_ids
(
os
.
path
.
join
(
data_dir
,
fname
),
word_to_id
))
for
fname
in
[
'ptb.train.txt'
,
'ptb.valid.txt'
,
'ptb.test.txt'
]]
return
data3
,
word_to_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment