Commit 0ada5796 authored by Yuxin Wu

remove duplicated docs in slim example. update readme

parent c8a35615
@@ -28,14 +28,14 @@ config = TrainConfig(
     callbacks=[...]
 )
-# start training:
+# start training (with a slow trainer. See 'tutorials - Input Sources' for details):
 # SimpleTrainer(config).train()
 # start training with queue prefetch:
-# QueueInputTrainer(config).train()
+QueueInputTrainer(config).train()
 # start multi-GPU training with a synchronous update:
-SyncMultiGPUTrainer(config).train()
+# SyncMultiGPUTrainer(config).train()
 ```
 Trainers just run some iterations, so there is no limit to where the data come from
......
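The README snippet above is the whole training entry point. For orientation, here is a minimal sketch of how the pieces fit together, assuming the tensorpack API as of this commit; `MyModel` and `my_dataflow` are placeholder names for user code:

```python
from tensorpack import *

config = TrainConfig(
    model=MyModel(),           # placeholder: a ModelDesc subclass
    dataflow=my_dataflow,      # placeholder: a DataFlow yielding batches
    callbacks=[ModelSaver()],  # e.g. checkpoint after every epoch
)
# The queue-based trainer prefetches batches through a TF queue,
# which is why the slower/optional trainers are commented out above.
QueueInputTrainer(config).train()
```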
@@ -3,8 +3,12 @@
 Training examples with __reproducible__ and meaningful performance.

-## Vision:
+## Getting Started:
++ [An illustrative mnist example with explanation of the framework](mnist-convnet.py)
++ The same mnist example using [tf-slim](mnist-tfslim.py), [Keras](mnist-keras.py), and [with weights visualizations](mnist-visualizations.py)
++ [A boilerplate file to start with, for your own tasks](boilerplate.py)
+
+## Vision:
 + [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
 + [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
 + [Train ResNet for ImageNet/Cifar10/SVHN](ResNet)
......
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # File: mnist-convnet.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import numpy as np
 import os
@@ -82,9 +81,7 @@ class Model(ModelDesc):
         summary.add_moving_summary(cost, wd_cost, self.cost)

         # monitor histogram of all weight (of conv and fc layers) in tensorboard
-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
@@ -141,4 +138,6 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
+    # SimpleTrainer is slow, this is just a demo.
     SimpleTrainer(config).train()
+    # You can use QueueInputTrainer instead
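The two comments added above keep the demo on the slow trainer for clarity; per the second comment, switching is a one-line change. A sketch of that variant, reusing the same `get_config` and argument handling as the example:

```python
config = get_config()
if args.load:
    config.session_init = SaverRestore(args.load)
# Drop-in replacement suggested by the new comment: batches are
# prefetched through a TF queue instead of fed per step via feed_dict.
QueueInputTrainer(config).train()
```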
@@ -6,13 +6,14 @@ import numpy as np
 import os
 import sys
 import argparse

 """
-MNIST ConvNet example.
-about 0.6% validation error after 30 epochs.
+MNIST ConvNet example using TensorFlow-slim.
+Mostly the same as 'mnist-convnet.py',
+the only differences are:
+1. use slim.layers, slim.arg_scope, etc
+2. use slim names to summarize weights
 """

-# Just import everything into current namespace
 from tensorpack import *
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
@@ -22,24 +23,14 @@ IMAGE_SIZE = 28

 class Model(ModelDesc):
     def _get_inputs(self):
-        """
-        Define all the inputs (with type, shape, name) that
-        the graph will need.
-        """
         return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputDesc(tf.int32, (None,), 'label')]

     def _build_graph(self, inputs):
-        """This function should build the model which takes the input variables
-        and define self.cost at the end"""
-        # inputs contains a list of input variables defined above
         image, label = inputs
-        # In tensorflow, inputs to convolution function are assumed to be
-        # NHWC. Add a single channel here.
         image = tf.expand_dims(image, 3)
-        image = image * 2 - 1  # center the pixels values at zero
+        image = image * 2 - 1

         is_training = get_current_tower_context().is_training
         with slim.arg_scope([slim.layers.fully_connected],
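The removed comment "center the pixels values at zero" documented what `image * 2 - 1` does: assuming the MNIST dataflow yields pixels scaled to [0, 1], the expression rescales them to [-1, 1]. A one-line check of the arithmetic:

```python
import numpy as np

x = np.array([0.0, 0.5, 1.0])  # assuming [0, 1]-scaled pixel values
print(x * 2 - 1)               # -> [-1.  0.  1.], i.e. centered at zero
```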
@@ -55,30 +46,19 @@ class Model(ModelDesc):
             l = slim.layers.dropout(l, is_training=is_training)
             logits = slim.layers.fully_connected(l, 10, activation_fn=None, scope='fc1')

-        prob = tf.nn.softmax(logits, name='prob')  # a Bx10 with probabilities
+        prob = tf.nn.softmax(logits, name='prob')

-        # a vector of length B with loss of each sample
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss
+        cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
         wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')

-        # This will monitor training error (in a moving_average fashion):
-        # 1. write the value to tensorboard
-        # 2. write the value to stat.json
-        # 3. print the value after each epoch
         train_error = tf.reduce_mean(wrong, name='train_error')
         summary.add_moving_summary(train_error)

         # slim already adds regularization to a collection, no extra handling
         self.cost = cost
         summary.add_moving_summary(cost)

-        # monitor histogram of all weight (of conv and fc layers) in tensorboard
-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/weights', ['histogram', 'rms']))  # slim uses different variable names

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
@@ -86,8 +66,6 @@ class Model(ModelDesc):
             global_step=get_global_step_var(),
             decay_steps=468 * 10,
             decay_rate=0.3, staircase=True, name='learning_rate')
-        # This will also put the summary in tensorboard, stat.json and print in terminal
-        # but this time without moving average
         tf.summary.scalar('lr', lr)
         return tf.train.AdamOptimizer(lr)
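For intuition about the schedule kept above: MNIST with batch size 128 gives roughly 468 steps per epoch (60000 / 128), so `decay_steps=468 * 10` with `staircase=True` multiplies the rate by 0.3 once every ~10 epochs. A quick sketch (the base rate is hypothetical; it sits outside this hunk):

```python
# Staircase exponential decay as configured above:
#   lr(step) = base_lr * 0.3 ** (step // (468 * 10))
base_lr = 1e-3  # hypothetical starting rate, not shown in this diff
for epoch in (0, 9, 10, 20, 30):
    step = 468 * epoch
    print(epoch, base_lr * 0.3 ** (step // (468 * 10)))
# epochs 0-9: 1e-3; 10-19: 3e-4; 20-29: 9e-5; 30+: 2.7e-5
```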
@@ -99,26 +77,17 @@ def get_data():

 def get_config():
-    # automatically setup the directory train_log/mnist-convnet for logging
     logger.auto_set_dir()
     dataset_train, dataset_test = get_data()

-    # How many iterations you want in each epoch.
-    # This is the default value, don't actually need to set it in the config
-    steps_per_epoch = dataset_train.size()
-
-    # get the config which contains everything necessary in a training
     return TrainConfig(
         model=Model(),
-        dataflow=dataset_train,  # the DataFlow instance for training
+        dataflow=dataset_train,
         callbacks=[
-            ModelSaver(),  # save the model after every epoch
-            InferenceRunner(  # run inference(for validation) after every epoch
-                dataset_test,  # the DataFlow instance used for validation
-                # Calculate both the cost and the error for this DataFlow
+            ModelSaver(),
+            InferenceRunner(
+                dataset_test,
                 [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
-        steps_per_epoch=steps_per_epoch,
         max_epoch=100,
     )
......
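The one inline comment the slim example keeps, "slim uses different variable names", is the crux of the `add_param_summary` changes across this commit: tensorpack's own layers name their kernels like `conv0/W`, while slim layers use `conv0/weights`, so each example now passes only the regex matching its own library. A standalone illustration (the variable names are made up, and full-string anchoring of the pattern is an assumption about how the summary utility matches):

```python
import re

# Illustrative names: tensorpack layers call the kernel 'W', slim 'weights'.
names = ['conv0/W', 'fc0/W', 'conv0/weights', 'fc1/weights']
for pattern in ('.*/W', '.*/weights'):
    matched = [n for n in names if re.fullmatch(pattern, n)]
    print(pattern, '->', matched)
# .*/W       -> ['conv0/W', 'fc0/W']
# .*/weights -> ['conv0/weights', 'fc1/weights']
```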
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: mnist-convnet.py
-# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
+# File: mnist-visualizations.py

 import numpy as np
 import os
@@ -9,11 +8,9 @@ import sys
 import argparse

 """
-MNIST ConvNet example.
-about 0.6% validation error after 30 epochs.
+MNIST ConvNet example with weights/activations visualization.
 """

-# Just import everything into current namespace
 from tensorpack import *
 import tensorflow as tf
 import tensorpack.tfutils.symbolic_functions as symbf
@@ -73,10 +70,6 @@ def visualize_conv_activations(activation, name):

 class Model(ModelDesc):
     def _get_inputs(self):
-        """
-        Define all the inputs (with type, shape, name) that
-        the graph will need.
-        """
         return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                 InputDesc(tf.int32, (None,), 'label')]
@@ -111,7 +104,6 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        # compute the "incorrect vector", for the callback ClassificationError to use at validation time
         wrong = symbf.prediction_incorrect(logits, label, name='incorrect')
         accuracy = symbf.accuracy(logits, label)
@@ -121,9 +113,7 @@ class Model(ModelDesc):
         self.cost = tf.add_n([wd_cost, cost], name='total_cost')
         summary.add_moving_summary(cost, wd_cost, self.cost, accuracy)

-        summary.add_param_summary(('.*/W', ['histogram', 'rms']),
-                                  ('.*/weights', ['histogram', 'rms'])  # to also work with slim
-                                  )
+        summary.add_param_summary(('.*/W', ['histogram', 'rms']))

     def _get_optimizer(self):
         lr = tf.train.exponential_decay(
......
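The body of `visualize_conv_activations` is outside this diff, so the following is only a hypothetical stand-in showing what such a helper typically does: turn the channels of one sample's NHWC activation into TensorBoard image summaries.

```python
import tensorflow as tf

def visualize_conv_activations_sketch(activation, name):
    """Hypothetical helper: render each channel of the first sample's
    NHWC activation as an image summary for TensorBoard."""
    with tf.name_scope('visualize_act_' + name):
        maps = tf.transpose(activation[:1], [3, 1, 2, 0])  # -> C x H x W x 1
        tf.summary.image(name, maps, max_outputs=10)
```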
@@ -4,6 +4,7 @@

 from .base import Trainer
+from ..utils import logger
 from ..tfutils import TowerContext
 from .input_source import FeedInput
@@ -25,6 +26,7 @@ class SimpleTrainer(Trainer):
             assert isinstance(self._input_source, FeedInput), type(self._input_source)
         else:
             self._input_source = FeedInput(config.dataflow)
+        logger.warn("SimpleTrainer is slow! Do you really want to use it?")

     def run_step(self):
         """ Feed data into the graph and run the updates. """
......
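The trainer change follows a simple pattern: surface the performance caveat once, at construction time, instead of letting users discover it from slow epochs. A minimal self-contained sketch of that pattern (class and logger names here are illustrative, not tensorpack's):

```python
import logging

logger = logging.getLogger('trainer-demo')

class SlowTrainerSketch:
    def __init__(self, config):
        self.config = config
        # Warn immediately when the known-slow path is chosen, mirroring
        # the logger.warn() call added to SimpleTrainer above.
        logger.warning("This trainer is slow! Do you really want to use it?")
```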