Commit 7b232c98 authored by Yuxin Wu

update docs

parent 62d54f68
@@ -30,13 +30,13 @@ Examples are not only for demonstration of the framework -- you can train them a

 ## Features:

 It's Yet Another TF wrapper, but different in:

-1. Not focus on models.
+1. It's not a model wrapper.
   + There are already too many symbolic function wrappers.
     Tensorpack includes only a few common models,
-    but you can use any other wrappers within tensorpack, such as sonnet/Keras/slim/tflearn/tensorlayer/....
+    but you can use any other model wrappers within tensorpack, such as sonnet/Keras/slim/tflearn/tensorlayer/....
 2. Focus on __training speed__.
-  + Speed comes for free with tensorpack -- it uses TensorFlow in the correct way.
+  + Speed comes for free with tensorpack -- it uses TensorFlow in the __correct way__.
    Even on a tiny CNN example, the training runs [1.6x faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than the equivalent Keras code.
  + Data-parallel multi-GPU training is off-the-shelf to use. It is as fast as Google's [official benchmark](https://www.tensorflow.org/performance/benchmarks).
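As a rough illustration of the multi-GPU claim in the last bullet, a minimal training script in the API of this era might look like the sketch below. MyModel and my_dataflow are hypothetical placeholders, and the TrainConfig/SyncMultiGPUTrainer names may differ in later tensorpack versions:

from tensorpack import TrainConfig, SyncMultiGPUTrainer

config = TrainConfig(
    model=MyModel(),       # hypothetical ModelDesc subclass (graph + cost)
    dataflow=my_dataflow,  # hypothetical DataFlow yielding training batches
    max_epoch=100,
)
config.nr_tower = 2  # replicate the model on 2 GPUs
# Each tower computes gradients on its share of the data; gradients are
# averaged synchronously before each update.
SyncMultiGPUTrainer(config).train()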
@@ -9,7 +9,6 @@
 import tensorflow as tf
 from tensorflow.contrib.layers import variance_scaling_initializer
-import tensorpack as tp
 from tensorpack import imgaug, dataset
 from tensorpack.dataflow import (
     AugmentImageComponent, PrefetchDataZMQ,
@@ -134,6 +133,9 @@ def apply_preactivation(l, preact):
 def get_bn(zero_init=False):
+    """
+    Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
+    """
     if zero_init:
         return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
     else:
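For context, the hunk cuts off before the else branch; the complete helper presumably just returns an un-tweaked BatchNorm there. A sketch under that assumption:

def get_bn(zero_init=False):
    """
    Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
    """
    if zero_init:
        # Zero gamma makes each residual block start out as an identity
        # mapping, which stabilizes large-batch training (arXiv:1706.02677).
        return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
    else:
        return lambda x, name: BatchNorm('bn', x)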
@@ -220,10 +222,10 @@ def eval_on_ILSVRC12(model, sessinit, dataflow):
     )
     pred = SimpleDatasetPredictor(pred_config, dataflow)
     acc1, acc5 = RatioCounter(), RatioCounter()
-    for o in pred.get_result():
-        batch_size = o[0].shape[0]
-        acc1.feed(o[0].sum(), batch_size)
-        acc5.feed(o[1].sum(), batch_size)
+    for top1, top5 in pred.get_result():
+        batch_size = top1.shape[0]
+        acc1.feed(top1.sum(), batch_size)
+        acc5.feed(top5.sum(), batch_size)
     print("Top1 Error: {}".format(acc1.ratio))
     print("Top5 Error: {}".format(acc5.ratio))
@@ -78,9 +78,10 @@ class Model(ModelDesc):
         return tf.train.AdamOptimizer(lr)

-# Keras needs an extra input if learning_phase is needed
+# Keras needs an extra input if learning_phase is used by the model
 class KerasCallback(Callback):
     def __init__(self, isTrain):
+        assert isinstance(isTrain, bool), isTrain
         self._isTrain = isTrain
         self._learning_phase = KB.learning_phase()
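The corrected comment is worth a word: Keras layers such as Dropout branch on KB.learning_phase(), a scalar placeholder, so a graph driven by raw TensorFlow must feed it 1 for training and 0 for inference. A standalone illustration, independent of tensorpack:

import numpy as np
import tensorflow as tf
import keras.backend as KB
from keras.layers import Dropout

x = tf.placeholder(tf.float32, [None, 10])
y = Dropout(0.5)(x)  # internally consults KB.learning_phase()

with tf.Session() as sess:
    data = np.ones((2, 10), dtype='float32')
    # learning_phase = 1: training behavior, dropout is active
    train_out = sess.run(y, feed_dict={x: data, KB.learning_phase(): 1})
    # learning_phase = 0: inference behavior, dropout is a no-op
    test_out = sess.run(y, feed_dict={x: data, KB.learning_phase(): 0})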
@@ -11,8 +11,7 @@ from ..graph_builder.input_source import QueueInput, FeedfreeInput
 from .simple import SimpleTrainer
 from .base import Trainer

-__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
-           'SimpleFeedfreeTrainer', 'QueueInputTrainer']
+__all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer', 'QueueInputTrainer']

 # TODO deprecate it some time
@@ -32,9 +31,9 @@ class FeedfreeTrainerBase(Trainer):
         self.config.callbacks.extend(cbs)

 # deprecated
 class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """
     @deprecated("", "2017-11-21")
     def __init__(self, *args, **kwargs):
         super(SingleCostFeedfreeTrainer, self).__init__(*args, **kwargs)
         logger.warn("SingleCostFeedfreeTrainer was deprecated!")
@@ -45,12 +44,6 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         return self.model.get_cost_and_grad()

-@deprecated("Use SimpleTrainer with config.data is the same!", "2017-09-13")
-def SimpleFeedfreeTrainer(config):
-    assert isinstance(config.data, FeedfreeInput), config.data
-    return SimpleTrainer(config)

 def QueueInputTrainer(config, input_queue=None):
     """
     A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
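The deprecation message of the removed function spells out the migration path; a hedged sketch of the replacement (MyModel and my_dataflow are hypothetical; the import path for QueueInput follows the module layout visible in the hunk above):

from tensorpack import TrainConfig, SimpleTrainer
from tensorpack.graph_builder.input_source import QueueInput

# Before: SimpleFeedfreeTrainer(config), with config.data a FeedfreeInput.
# After: plain SimpleTrainer; the feedfree input goes through config.data.
config = TrainConfig(
    model=MyModel(),               # hypothetical ModelDesc subclass
    data=QueueInput(my_dataflow),  # QueueInput is one FeedfreeInput implementation
)
SimpleTrainer(config).train()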