Commit 453b7c63 authored by Yuxin Wu

remove ptb, because the ptb data has been removed from tensorflow recently

parent cb53e6c0
...@@ -77,7 +77,7 @@ class InferenceRunner(Callback): ...@@ -77,7 +77,7 @@ class InferenceRunner(Callback):
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
self.input_tensors = input_tensors self.input_tensors = input_tensors # names actually
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # these are all tensor names self._find_input_tensors() # these are all tensor names
...@@ -141,7 +141,7 @@ class InferenceRunner(Callback): ...@@ -141,7 +141,7 @@ class InferenceRunner(Callback):
class FeedfreeInferenceRunner(Callback): class FeedfreeInferenceRunner(Callback):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput']) IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, input, infs, input_tensors=None): def __init__(self, input, infs, input_names=None):
assert isinstance(input, FeedfreeInput), input assert isinstance(input, FeedfreeInput), input
self._input_data = input self._input_data = input
if not isinstance(infs, list): if not isinstance(infs, list):
...@@ -150,7 +150,9 @@ class FeedfreeInferenceRunner(Callback): ...@@ -150,7 +150,9 @@ class FeedfreeInferenceRunner(Callback):
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
self.input_tensor_names = input_tensors if input_names is not None:
assert isinstance(input_names, list)
self._input_names = input_names
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # tensors self._find_input_tensors() # tensors
...@@ -162,17 +164,20 @@ class FeedfreeInferenceRunner(Callback): ...@@ -162,17 +164,20 @@ class FeedfreeInferenceRunner(Callback):
# only 1 prediction tower will be used for inference # only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs() model_placehdrs = self.trainer.model.get_reuse_placehdrs()
if self.input_names is not None:
assert len(self.input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data."
# XXX incorrect
self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
if model_placehdrs[idx].name in self.input_names]
assert len(self._input_tensors) == len(self.input_names), \
"[FeedfreeInferenceRunner] all input_tensors must be defined as InputVar in the Model!"
assert len(self._input_tensors) == len(model_placehdrs), \ assert len(self._input_tensors) == len(model_placehdrs), \
"FeedfreeInput doesn't produce correct number of output tensors" "FeedfreeInput doesn't produce correct number of output tensors"
if self.input_tensor_names is not None:
assert isinstance(self.input_tensor_names, list)
self._input_tensors = [k for idx, k in enumerate(self._input_tensors)
if model_placehdrs[idx].name in self.input_tensor_names]
assert len(self._input_tensors) == len(self.input_tensor_names), \
"names of input tensors are not defined in the Model"
def _find_output_tensors(self): def _find_output_tensors(self):
# doesn't support output an input tensor # TODO doesn't support output an input tensor
dispatcer = OutputTensorDispatcer() dispatcer = OutputTensorDispatcer()
for inf in self.infs: for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors()) dispatcer.add_entry(inf.get_output_tensors())
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ptb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import numpy as np
from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ...utils.argtools import memoized_ignoreargs
# The PTB reader is an optional dependency: it is imported from the
# TensorFlow package and may not be present in every installation.
try:
    from tensorflow.models.rnn.ptb import reader as tfreader
except ImportError:
    # Report the missing dependency and export nothing from this module.
    logger.warn_dependency('PennTreeBank', 'tensorflow.models.rnn.ptb.reader')
    __all__ = []
else:
    # The reader is importable, so the loader is part of the public API.
    __all__ = ['get_PennTreeBank']

# Raw text files of the three PTB splits (train / valid / test).
TRAIN_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt'
VALID_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.valid.txt'
TEST_URL = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.test.txt'
@memoized_ignoreargs
def get_PennTreeBank(data_dir=None):
    """
    Load the Penn Treebank train/valid/test splits as arrays of word ids.

    Any of the three text files that is missing from ``data_dir`` is
    downloaded before the data is read.

    Args:
        data_dir (str): directory that holds (or will receive)
            ``ptb.train.txt``, ``ptb.valid.txt`` and ``ptb.test.txt``.
            Defaults to ``get_dataset_path('ptb_data')``.

    Returns:
        tuple: ``(data3, word_to_id)``. ``data3`` is a list of three numpy
        arrays holding the word ids of the train, valid and test files, in
        that order. ``word_to_id`` is the vocabulary built from the
        training file only.
    """
    # NOTE(review): the decorator name suggests the result is cached
    # regardless of the arguments passed -- confirm against its definition.
    if data_dir is None:
        data_dir = get_dataset_path('ptb_data')

    splits = (
        ('ptb.train.txt', TRAIN_URL),
        ('ptb.valid.txt', VALID_URL),
        ('ptb.test.txt', TEST_URL),
    )
    # Check every split on its own: a partially populated directory (e.g. an
    # interrupted earlier download) must not skip the files still missing.
    for fname, url in splits:
        if not os.path.isfile(os.path.join(data_dir, fname)):
            download(url, data_dir)

    # TODO these functions in TF might not be available in the future
    word_to_id = tfreader._build_vocab(os.path.join(data_dir, 'ptb.train.txt'))
    data3 = [np.asarray(tfreader._file_to_word_ids(os.path.join(data_dir, fname), word_to_id))
             for fname, _ in splits]
    return data3, word_to_id
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment