Commit 0ada5796 authored by Yuxin Wu

remove duplicated docs in slim example. update readme

parent c8a35615
README.md
@@ -28,14 +28,14 @@ config = TrainConfig(
     callbacks=[...]
 )
-# start training:
+# start training (with a slow trainer. See 'tutorials - Input Sources' for details):
 # SimpleTrainer(config).train()
 # start training with queue prefetch:
-# QueueInputTrainer(config).train()
+QueueInputTrainer(config).train()
 # start multi-GPU training with a synchronous update:
-SyncMultiGPUTrainer(config).train()
+# SyncMultiGPUTrainer(config).train()
 ```
 Trainers just run some iterations, so there is no limit to where the data come from
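The README fragment above elides its setup. As a hedged sketch of the complete script it abbreviates, under the tensorpack API of this commit: `Model` is assumed to be a `ModelDesc` subclass like the one in mnist-convnet.py further down, and `dataset.Mnist`/`BatchData` are tensorpack's own dataflow utilities.

```python
# A minimal sketch (not part of this commit) of the full script the
# README fragment abbreviates, using the tensorpack API of this era.
from tensorpack import *

class Model(ModelDesc):   # assumed: defined as in mnist-convnet.py below
    ...

def get_data():
    # MNIST dataflows, batched; remainder=True keeps the last partial batch
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return train, test

if __name__ == '__main__':
    logger.auto_set_dir()            # log under train_log/<script name>
    dataset_train, dataset_test = get_data()
    config = TrainConfig(
        model=Model(),
        dataflow=dataset_train,
        callbacks=[
            ModelSaver(),            # save the model after every epoch
            InferenceRunner(dataset_test,
                            [ScalarStats('cross_entropy_loss'),
                             ClassificationError('incorrect')]),
        ],
        max_epoch=100,
    )
    # the trainer the updated README text now uncomments:
    QueueInputTrainer(config).train()
```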
examples/README.md
@@ -3,8 +3,12 @@
 Training examples with __reproducible__ and meaningful performance.
-## Vision:
+## Getting Started:
 + [An illustrative mnist example with explanation of the framework](mnist-convnet.py)
++ The same mnist example using [tf-slim](mnist-tfslim.py), [Keras](mnist-keras.py), and [with weights visualizations](mnist-visualizations.py)
++ [A boilerplate file to start with, for your own tasks](boilerplate.py)
+
+## Vision:
 + [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
 + [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
 + [Train ResNet for ImageNet/Cifar10/SVHN](ResNet)
examples/mnist-convnet.py
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # File: mnist-convnet.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import numpy as np
 import os
@@ -82,9 +81,7 @@ class Model(ModelDesc):
         summary.add_moving_summary(cost, wd_cost, self.cost)
         # monitor histogram of all weight (of conv and fc layers) in tensorboard
-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
@@ -141,4 +138,6 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
+    # SimpleTrainer is slow, this is just a demo.
     SimpleTrainer(config).train()
+    # You can use QueueInputTrainer instead
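The `add_param_summary` change above is the "remove duplicated docs in slim example" half of the commit message in action: each example now registers only the weight-name pattern its own layer library actually produces. A hedged sketch of the naming difference (the scope names are illustrative assumptions, not copied from tensorpack internals):

```python
# tensorpack's own layers (Conv2D, FullyConnected) create kernel
# variables named '<scope>/W', while tf.slim layers name theirs
# '<scope>/weights'. One regex per file thus replaces the old
# two-pattern call that appeared in every example.
from tensorpack.tfutils import summary

# in mnist-convnet.py / mnist-visualizations.py (tensorpack layers):
summary.add_param_summary(('.*/W', ['histogram', 'rms']))

# in mnist-tfslim.py (slim layers):
summary.add_param_summary(('.*/weights', ['histogram', 'rms']))
```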
examples/mnist-tfslim.py
@@ -6,13 +6,14 @@ import numpy as np
 import os
 import sys
 import argparse
 """
-MNIST ConvNet example.
-about 0.6% validation error after 30 epochs.
+MNIST ConvNet example using TensorFlow-slim.
+Mostly the same as 'mnist-convnet.py',
+the only differences are:
+1. use slim.layers, slim.arg_scope, etc
+2. use slim names to summarize weights
 """
-# Just import everything into current namespace
 from tensorpack import *
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
@@ -22,24 +23,14 @@ IMAGE_SIZE = 28
 class Model(ModelDesc):
     def _get_inputs(self):
-        """
-        Define all the inputs (with type, shape, name) that
-        the graph will need.
-        """
         return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputDesc(tf.int32, (None,), 'label')]

     def _build_graph(self, inputs):
-        """This function should build the model which takes the input variables
-        and define self.cost at the end"""
-        # inputs contains a list of input variables defined above
         image, label = inputs
-        # In tensorflow, inputs to convolution function are assumed to be
-        # NHWC. Add a single channel here.
         image = tf.expand_dims(image, 3)
-        image = image * 2 - 1  # center the pixels values at zero
+        image = image * 2 - 1
         is_training = get_current_tower_context().is_training
         with slim.arg_scope([slim.layers.fully_connected],
@@ -55,30 +46,19 @@ class Model(ModelDesc):
             l = slim.layers.dropout(l, is_training=is_training)
             logits = slim.layers.fully_connected(l, 10, activation_fn=None, scope='fc1')
-        prob = tf.nn.softmax(logits, name='prob')  # a Bx10 with probabilities
+        prob = tf.nn.softmax(logits, name='prob')
-        # a vector of length B with loss of each sample
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss
+        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
         wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')
-        # This will monitor training error (in a moving_average fashion):
-        # 1. write the value to tensosrboard
-        # 2. write the value to stat.json
-        # 3. print the value after each epoch
         train_error = tf.reduce_mean(wrong, name='train_error')
         summary.add_moving_summary(train_error)
+        # slim already adds regularization to a collection, no extra handling
         self.cost = cost
         summary.add_moving_summary(cost)
-        # monitor histogram of all weight (of conv and fc layers) in tensorboard
-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/weights', ['histogram', 'rms']))  # slim uses different variable names

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
@@ -86,8 +66,6 @@ class Model(ModelDesc):
             global_step=get_global_step_var(),
             decay_steps=468 * 10,
             decay_rate=0.3, staircase=True, name='learning_rate')
-        # This will also put the summary in tensorboard, stat.json and print in terminal
-        # but this time without moving average
         tf.summary.scalar('lr', lr)
         return tf.train.AdamOptimizer(lr)
@@ -99,26 +77,17 @@ def get_data():
 def get_config():
-    # automatically setup the directory train_log/mnist-convnet for logging
     logger.auto_set_dir()
     dataset_train, dataset_test = get_data()
-    # How many iterations you want in each epoch.
-    # This is the default value, don't actually need to set it in the config
-    steps_per_epoch = dataset_train.size()
-    # get the config which contains everything necessary in a training
     return TrainConfig(
         model=Model(),
-        dataflow=dataset_train,  # the DataFlow instance for training
+        dataflow=dataset_train,
         callbacks=[
-            ModelSaver(),  # save the model after every epoch
+            ModelSaver(),
-            InferenceRunner(  # run inference(for validation) after every epoch
+            InferenceRunner(
-                dataset_test,  # the DataFlow instance used for validation
+                dataset_test,
-                # Calculate both the cost and the error for this DataFlow
                 [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
-        steps_per_epoch=steps_per_epoch,
         max_epoch=100,
     )
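The convolution layers themselves are elided between the hunks above (the `@@ -55,30 +46,19 @@` hunk starts at the dropout line). A hedged sketch of the slim-style construction that point 1 of the new docstring refers to, with illustrative layer sizes rather than the file's actual ones:

```python
# Sketch of slim-style graph building (layer names and sizes are
# illustrative assumptions, not copied from mnist-tfslim.py).
import tensorflow as tf
import tensorflow.contrib.slim as slim

def build_logits(image, is_training):
    # arg_scope injects the shared regularizer into every
    # fully_connected call inside the block
    with slim.arg_scope([slim.layers.fully_connected],
                        weights_regularizer=slim.l2_regularizer(1e-5)):
        l = slim.layers.conv2d(image, 32, [3, 3], scope='conv0')
        l = slim.layers.max_pool2d(l, [2, 2], scope='pool0')
        l = slim.layers.flatten(l, scope='flatten')
        l = slim.layers.fully_connected(l, 512, scope='fc0')
        l = slim.layers.dropout(l, is_training=is_training)
        return slim.layers.fully_connected(l, 10, activation_fn=None,
                                           scope='fc1')
```

Because `weights_regularizer` makes slim add the regularization losses to a graph collection itself, the example needs no explicit `wd_cost` handling, which is what the retained "no extra handling" comment in the diff points out.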
examples/mnist-visualizations.py
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: mnist-convnet.py
+# File: mnist-visualizations.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import numpy as np
 import os
@@ -9,11 +8,9 @@ import sys
 import argparse
 """
-MNIST ConvNet example.
-about 0.6% validation error after 30 epochs.
+MNIST ConvNet example with weights/activations visualization.
 """
-# Just import everything into current namespace
 from tensorpack import *
 import tensorflow as tf
 import tensorpack.tfutils.symbolic_functions as symbf
@@ -73,10 +70,6 @@ def visualize_conv_activations(activation, name):
 class Model(ModelDesc):
     def _get_inputs(self):
-        """
-        Define all the inputs (with type, shape, name) that
-        the graph will need.
-        """
         return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputDesc(tf.int32, (None,), 'label')]
@@ -111,7 +104,6 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
         wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
         accuracy = symbf.accuracy(logits, label)
@@ -121,9 +113,7 @@ class Model(ModelDesc):
         self.cost = tf.add_n([wd_cost, cost], name='total_cost')
         summary.add_moving_summary(cost, wd_cost, self.cost, accuracy)
-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
tensorpack/train/trainer.py
@@ -4,6 +4,7 @@
 from .base import Trainer
+from ..utils import logger
 from ..tfutils import TowerContext
 from .input_source import FeedInput
@@ -25,6 +26,7 @@ class SimpleTrainer(Trainer):
             assert isinstance(self._input_source, FeedInput), type(self._input_source)
         else:
             self._input_source = FeedInput(config.dataflow)
+            logger.warn("SimpleTrainer is slow! Do you really want to use it?")

     def run_step(self):
         """ Feed data into the graph and run the updates. """
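The new warning exists because SimpleTrainer's `FeedInput` pushes every batch through `feed_dict`, so the session idles while Python prepares data; QueueInputTrainer instead keeps a TF-side queue filled from a background thread. A rough standalone TF 1.x contrast of the two step styles (an illustrative assumption-level sketch, not tensorpack's actual implementation):

```python
# Illustrative TF 1.x contrast between feed_dict input and
# queue-prefetched input; not tensorpack code.
import numpy as np
import tensorflow as tf

data = np.random.rand(100, 4).astype('float32')

# Feed style: Python hands the session each batch synchronously.
x = tf.placeholder(tf.float32, [None, 4])
step_feed = tf.reduce_mean(x)

# Queue style: a background thread keeps batches ready on the TF side.
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32], shapes=[[4]])
enqueue_op = queue.enqueue_many(data)
runner = tf.train.QueueRunner(queue, [enqueue_op])
step_queue = tf.reduce_mean(queue.dequeue_many(32))

with tf.Session() as sess:
    sess.run(step_feed, feed_dict={x: data[:32]})   # SimpleTrainer-style step
    coord = tf.train.Coordinator()
    threads = runner.create_threads(sess, coord=coord, start=True)
    sess.run(step_queue)                            # QueueInputTrainer-style step
    coord.request_stop()
    sess.run(queue.close(cancel_pending_enqueues=True))
    coord.join(threads)
```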