Commit 2d539fea authored by Yuxin Wu

update docs/logging/deprecation

parent 7b4bc9cf
@@ -18,14 +18,6 @@ tensorpack.utils.concurrency module
     :show-inheritance:
-tensorpack.utils.discretize module
-----------------------------------
-.. automodule:: tensorpack.utils.discretize
-    :members:
-    :undoc-members:
-    :show-inheritance:
 tensorpack.utils.fs module
 --------------------------
@@ -58,14 +50,6 @@ tensorpack.utils.logger module
     :undoc-members:
     :show-inheritance:
-tensorpack.utils.lut module
---------------------------
-.. automodule:: tensorpack.utils.lut
-    :members:
-    :undoc-members:
-    :show-inheritance:
 tensorpack.utils.rect module
 ----------------------------
...
@@ -14,7 +14,6 @@ from six.moves import map, range
 from tensorpack import *
 from tensorpack.tfutils.gradproc import GlobalNormClip
-from tensorpack.utils.lut import LookUpTable
 from tensorpack.utils.globvars import globalns as param
 import tensorflow as tf
@@ -50,16 +49,16 @@ class CharRNNData(RNGDataFlow):
         print(sorted(self.chars))
         self.vocab_size = len(self.chars)
         param.vocab_size = self.vocab_size
-        self.lut = LookUpTable(self.chars)
-        self.whole_seq = np.array(list(map(self.lut.get_idx, data)), dtype='int32')
+        char2idx = {c: i for i, c in enumerate(self.chars)}
+        self.whole_seq = np.array([char2idx[c] for c in data], dtype='int32')
         logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))

     def size(self):
         return self._size

     def get_data(self):
-        random_starts = self.rng.randint(0,
-            self.whole_seq.shape[0] - self.seq_length - 1, (self._size,))
+        random_starts = self.rng.randint(
+            0, self.whole_seq.shape[0] - self.seq_length - 1, (self._size,))
         for st in random_starts:
             seq = self.whole_seq[st:st + self.seq_length + 1]
             yield [seq[:-1], seq[1:]]
...
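For illustration only (not part of the commit): the change above swaps LookUpTable for a plain dict built with a comprehension. A minimal standalone sketch of that character-to-index mapping, assuming nothing beyond NumPy and a stand-in corpus string:

```python
import numpy as np

corpus = "hello world"                              # stand-in for the loaded text
chars = sorted(set(corpus))                         # vocabulary, sorted for a stable order
char2idx = {c: i for i, c in enumerate(chars)}      # same mapping the new code builds
whole_seq = np.array([char2idx[c] for c in corpus], dtype='int32')

# the inverse mapping recovers the original text, which is handy when sampling
idx2char = {i: c for c, i in char2idx.items()}
assert ''.join(idx2char[i] for i in whole_seq) == corpus
```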
@@ -131,9 +131,7 @@ class Model(ModelDesc):
         return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)

-def get_data(train_or_test, fake=False):
-    if fake:
-        return FakeData([[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
+def get_data(train_or_test):
     isTrain = train_or_test == 'train'
     datadir = args.data
@@ -197,8 +195,12 @@ def get_data(train_or_test, fake=False):
 def get_config(fake=False, data_format='NCHW'):
-    dataset_train = get_data('train', fake=fake)
-    dataset_val = get_data('val', fake=fake)
+    if fake:
+        dataset_train = dataset_val = FakeData(
+            [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
+    else:
+        dataset_train = get_data('train')
+        dataset_val = get_data('val')

     return TrainConfig(
         model=Model(data_format=data_format),
@@ -259,9 +261,11 @@ if __name__ == '__main__':
         NR_GPU = len(args.gpu.split(','))
         BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU

-    logger.auto_set_dir()
+    logger.set_logger_dir(
+        os.path.join('train_log', 'imagenet-resnet-d' + str(DEPTH)))
+    logger.info("Batch size per GPU: " + str(BATCH_SIZE))
     config = get_config(fake=args.fake, data_format=args.data_format)
     if args.load:
         config.session_init = SaverRestore(args.load)
     config.nr_tower = NR_GPU
-    SyncMultiGPUTrainer(config).train()
+    SyncMultiGPUTrainerParameterServer(config).train()
...
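For illustration only (not part of the commit): the new get_config() builds one synthetic source for both splits when fake=True, instead of threading a fake flag through get_data(). A rough standalone sketch of that switch, with a hypothetical generator standing in for tensorpack's FakeData and the real pipeline omitted:

```python
import numpy as np

def fake_imagenet_batches(batch=64, size=1000):
    # constant, correctly-shaped batches; a rough stand-in for FakeData(..., random=False)
    imgs = np.zeros((batch, 224, 224, 3), dtype='uint8')
    labels = np.zeros((batch,), dtype='int32')
    for _ in range(size):
        yield [imgs, labels]

def make_datasets(fake=False):
    if fake:
        # both splits can share the same synthetic source when only benchmarking speed
        return fake_imagenet_batches(), fake_imagenet_batches()
    raise NotImplementedError("real data pipeline omitted in this sketch")

train_ds, val_ds = make_datasets(fake=True)
```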
@@ -61,8 +61,9 @@ def np_sample(img, coords):
 class GaussianDeform(ImageAugmentor):
     """
-    Some kind of slow deformation.
+    Some kind of slow deformation I made up. Don't count on it.
     """
     # TODO input/output with different shape
     def __init__(self, anchors, shape, sigma=0.5, randrange=None):
...
@@ -208,7 +208,8 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
 def SyncMultiGPUTrainer(config):
     """
     Alias for ``SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')``,
-    as this is the most commonly used synchronous multigpu trainer.
+    as this is the most commonly used synchronous multigpu trainer (but may
+    not be more efficient than the other).
     """
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
...
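For illustration only (not part of the commit): how the alias reads at a call site, mirroring the main block of the ResNet example above. This is a hedged sketch, not runnable on its own; it assumes tensorpack of this era is installed and a get_config() like the one in the example.

```python
from tensorpack import SyncMultiGPUTrainer   # assumption: exported at top level in this era

config = get_config(fake=False, data_format='NCHW')   # get_config() as in the ResNet example above
config.nr_tower = 2                                    # number of training GPUs

# per the docstring, this call is equivalent to
#   SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train()
SyncMultiGPUTrainer(config).train()
```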
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
# File: discretize.py # File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .argtools import log_once
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
import six import six
from six.moves import range from six.moves import range
from .argtools import log_once
from .develop import log_deprecated
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND'] __all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
...@@ -40,6 +42,7 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -40,6 +42,7 @@ class UniformDiscretizer1D(Discretizer1D):
self.maxv = float(maxv) self.maxv = float(maxv)
self.spacing = float(spacing) self.spacing = float(spacing)
self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing)) self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing))
log_deprecated("Discretizer", "It's not related to the library and I'd be surprised if you're using it..")
def get_nr_bin(self): def get_nr_bin(self):
""" """
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import six import six
from .develop import log_deprecated
__all__ = ['LookUpTable'] __all__ = ['LookUpTable']
...@@ -18,6 +19,7 @@ class LookUpTable(object): ...@@ -18,6 +19,7 @@ class LookUpTable(object):
""" """
self.idx2obj = dict(enumerate(objlist)) self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)} self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
log_deprecated("LookUpTable", "It's not related to the library and I'd be surprised if you're using it..")
def size(self): def size(self):
return len(self.idx2obj) return len(self.idx2obj)
......
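For illustration only (not part of the commit): the two hunks above tag old utilities with log_deprecated at construction time. A minimal standalone stand-in for that pattern (this is not tensorpack's implementation), using the standard warnings module:

```python
import warnings

def log_deprecated(name, text=""):
    # stand-in: tensorpack routes this message through its own logger instead
    warnings.warn("{} is deprecated. {}".format(name, text),
                  DeprecationWarning, stacklevel=2)

class LookUpTable(object):
    def __init__(self, objlist):
        self.idx2obj = dict(enumerate(objlist))
        self.obj2idx = {v: k for k, v in self.idx2obj.items()}
        # emit the deprecation notice whenever an instance is created
        log_deprecated("LookUpTable", "Use a plain dict instead.")

lut = LookUpTable(['a', 'b', 'c'])   # triggers the DeprecationWarning
```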
@@ -50,10 +50,11 @@ class PythonScript(threading.Thread):
             self.p.kill()  # kill -9
             self.join()
         else:
-            # something unexpected happend here, this script was supposed to survive at leat the timeout
+            # something unexpected happened here, this script was supposed to survive at least the timeout
             if len(self.err) is not 0:
-                stderr = u"\n\n\n\n\n" + self.err.decode('utf-8')
-                raise AssertionError(stderr)
+                output = u"STDOUT: \n\n\n" + self.out.decode('utf-8')
+                output += u"\n\n\n STDERR: \n\n\n" + self.err.decode('utf-8')
+                raise AssertionError(output)

 class TestPythonScript(unittest.TestCase):
...
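For illustration only (not part of the commit): the fix above reports both streams when a monitored script fails. A small standalone sketch of collecting stdout and stderr from a subprocess and surfacing them together, using only the standard library:

```python
import subprocess
import sys

p = subprocess.Popen([sys.executable, '-c', 'print("ok")'],
                     stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = p.communicate()
if p.returncode != 0 or len(err) != 0:
    # include both streams in the assertion, as the updated test harness does
    raise AssertionError(u"STDOUT:\n" + out.decode('utf-8') +
                         u"\nSTDERR:\n" + err.decode('utf-8'))
```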