Commit 0a6dd4ae authored by Yuxin Wu

Add `self.training` to ModelDesc

parent cb1419e8
@@ -39,16 +39,16 @@ for epoch_num in range(starting_epoch, max_epoch):
run_step() # do something
```
In other words, the assumptions are:
1. Training is **running some iterations**.
Tensorpack base trainer implements the logic of __running the iteration__.
Users or derived trainers should implement __what the iteration is__.
Tensorpack base trainer implements the logic of __running the iterations__.
Users or derived trainers should implement __what the iterations are__.
2. Trainer assumes the existence of __"epoch"__, i.e. that the iterations run in double for-loops.
`steps_per_epoch` can be any number you set
2. The concept of __"epoch"__, i.e. we assume that the iterations run in nested for-loops.
In fact, the steps per epoch can be any number
and it only affects the [schedule of callbacks](callback.html).
In other words, an "epoch" in tensorpack is the __default period to run
callbacks__ (validation, summary, checkpoint, etc.).
It has nothing to do with your dataset.
callbacks__ (validation, summary, checkpoint, etc.). It has nothing to do with your dataset.
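For example, both knobs live on the config object handed to the trainer. The sketch below assumes the usual `TrainConfig`/`launch_train_with_config` entry point; `MyModel` and `df` stand for a user-defined `ModelDesc` and `DataFlow`, and exact argument names may differ slightly across tensorpack versions.
```python
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config
from tensorpack.callbacks import ModelSaver

# `steps_per_epoch` only controls how often epoch-level callbacks
# (checkpointing, validation, summaries) fire; it need not match the dataset size.
config = TrainConfig(
    model=MyModel(),          # a user-defined ModelDesc (placeholder)
    dataflow=df,              # a user-defined DataFlow (placeholder)
    callbacks=[ModelSaver()],
    steps_per_epoch=500,
    max_epoch=100,
)
launch_train_with_config(config, SimpleTrainer())
```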
### Built-in Trainers
@@ -76,8 +76,8 @@ It takes only one line of code change to use them, e.g. `trainer=SyncMultiGPUTra
Note some __common confusions__ when using these trainers:
1. In each iteration, instead of taking one input tensor for all GPUs and splitting it,
all GPUs take tensors from the `InputSource`.
So the total batch size across all GPUs is ``(batch size of InputSource) * #GPU``.
tensorpack trainers let all GPUs take tensors from the input.
Therefore, the total batch size across all GPUs is ``(batch size of input source) * #GPU``.
You may want to change `steps_per_epoch` or the learning rate appropriately according
to the total batch size.
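A common heuristic for that adjustment (not something tensorpack does for you) is the linear scaling rule; the numbers below are purely illustrative:
```python
NUM_GPUS = 8
BATCH_PER_GPU = 32                        # batch size of the InputSource on each GPU
TOTAL_BATCH = BATCH_PER_GPU * NUM_GPUS    # what one training step effectively consumes

BASE_LR = 0.1                             # reference LR tuned for a total batch of 256
lr = BASE_LR * TOTAL_BATCH / 256          # scale LR linearly with the total batch size

DATASET_SIZE = 1281167                    # e.g. ImageNet-1k training images
steps_per_epoch = DATASET_SIZE // TOTAL_BATCH   # if one "epoch" should cover the data once
```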
@@ -92,11 +92,11 @@ Note some __common confusions__ when using these trainers:
```
2. The tower function (your model code) will get called once on each GPU.
You must follow some [rules of tower function](extend/trainer.html#rules-of-tower-function).
So you must follow some [rules of tower function](extend/trainer.html#rules-of-tower-function).
### Distributed Trainers
Distributed training needs the [horovod](https://github.com/uber/horovod) library, which offers a high-performance allreduce implementation.
Distributed training needs the [horovod](https://github.com/horovod/horovod) library, which offers a high-performance allreduce implementation.
To run distributed training, first install horovod properly, then refer to the
documentation of [HorovodTrainer](../modules/train.html#tensorpack.train.HorovodTrainer).
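As a rough sketch (assuming a script `train.py` that builds a `config` as above), the tensorpack-side change is just the trainer class; the processes themselves are launched by horovod:
```python
# train.py (sketch)
from tensorpack import launch_train_with_config
from tensorpack.train import HorovodTrainer

launch_train_with_config(config, HorovodTrainer())

# launched with horovod's own launcher, e.g.:
#   horovodrun -np 8 python train.py
```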
......
@@ -100,8 +100,7 @@ class Model(ModelDesc):
logits, value = self._get_NN_prediction(state)
value = tf.squeeze(value, [1], name='pred_value') # (B,)
policy = tf.nn.softmax(logits, name='policy')
is_training = get_current_tower_context().is_training
if not is_training:
if not self.training:
return
log_probs = tf.log(policy + 1e-6)
......
@@ -55,8 +55,7 @@ class Model(ModelDesc):
logits = tf.transpose(logits, [1, 0, 2])
isTrain = get_current_tower_context().is_training
if isTrain:
if self.training:
# beam search is too slow to run in training
predictions = tf.cast(
tf.nn.ctc_greedy_decoder(logits, seqlen)[0][0], tf.int32)
......
@@ -6,7 +6,7 @@ import abc
import tensorflow as tf
from tensorpack import ModelDesc
from tensorpack.tfutils import get_current_tower_context, gradproc, optimizer, summary, varreplace
from tensorpack.tfutils import gradproc, optimizer, summary, varreplace
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils import logger
@@ -60,7 +60,7 @@ class Model(ModelDesc):
[-1] * (input_rank - 1) + [self.history], name='state')
self.predict_value = self.get_DQN_prediction(state)
if not get_current_tower_context().is_training:
if not self.training:
return
reward = tf.clip_by_value(reward, -1, 1)
......
@@ -45,8 +45,6 @@ class Model(ModelDesc):
tf.TensorSpec([None], tf.int32, 'label')]
def build_graph(self, image, label):
is_training = get_current_tower_context().is_training
fw, fa, fg = get_dorefa(BITW, BITA, BITG)
# monkey-patch tf.get_variable to apply fw
@@ -100,7 +98,7 @@ class Model(ModelDesc):
.apply(fg)
.BatchNorm('bn5').apply(activate)
# 5
.Dropout(rate=0.5 if is_training else 0.0)
.Dropout(rate=0.5 if self.training else 0.0)
.Conv2D('conv6', 512, 5, padding='VALID')
.apply(fg).BatchNorm('bn6')
.apply(nonlin)
......
@@ -7,7 +7,6 @@ from tensorpack import ModelDesc
from tensorpack.models import GlobalAvgPooling, l2_regularizer, regularize_cost
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import get_current_tower_context
from config import config as cfg
from data import get_all_anchors, get_all_anchors_fpn
@@ -31,10 +30,6 @@ class GeneralizedRCNN(ModelDesc):
image = image_preprocess(image, bgr=True)
return tf.transpose(image, [0, 3, 1, 2])
@property
def training(self):
return get_current_tower_context().is_training
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
tf.summary.scalar('learning_rate-summary', lr)
......
@@ -50,12 +50,11 @@ class Model(ModelDesc):
tf.TensorSpec((None, SEQ_LEN), tf.int32, 'nextinput')]
def build_graph(self, input, nextinput):
is_training = get_current_tower_context().is_training
initializer = tf.random_uniform_initializer(-0.05, 0.05)
def get_basic_cell():
cell = rnn.BasicLSTMCell(num_units=HIDDEN_SIZE, forget_bias=0.0, reuse=tf.get_variable_scope().reuse)
if is_training:
if self.training:
cell = rnn.DropoutWrapper(cell, output_keep_prob=1 - DROPOUT)
return cell
......
@@ -55,7 +55,6 @@ class Model(GANModelDesc):
def build_graph(self, Ilr, Ihr):
Ilr, Ihr = Ilr / 255.0, Ihr / 255.0
ctx = get_current_tower_context()
Ibicubic = tf.image.resize_bicubic(
Ilr, [4 * self.height, 4 * self.width], align_corners=True,
name='bicubic_baseline') # (0,1)
@@ -182,7 +181,7 @@ class Model(GANModelDesc):
tf.multiply(fake_hr, 255.0, name='prediction')
if ctx.is_training:
if self.training:
with tf.variable_scope('discrim'):
real_score = discriminator(real_hr)
fake_score = discriminator(fake_hr)
......
@@ -33,10 +33,9 @@ class Model(ModelDesc):
tf.TensorSpec((None,), tf.int32, 'label')]
def build_graph(self, image, label):
is_training = get_current_tower_context().is_training
drop_rate = tf.constant(0.5 if is_training else 0.0)
drop_rate = tf.constant(0.5 if self.training else 0.0)
if is_training:
if self.training:
tf.summary.image("train_image", image, 10)
if tf.test.is_gpu_available():
image = tf.transpose(image, [0, 3, 1, 2])
......
@@ -6,7 +6,7 @@ import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import get_current_tower_context, summary
from tensorpack.tfutils import summary
"""
MNIST ConvNet example using tf.layers
@@ -50,8 +50,7 @@ class Model(ModelDesc):
l = tf.layers.conv2d(l, 32, 3, name='conv3')
l = tf.layers.flatten(l)
l = tf.layers.dense(l, 512, activation=tf.nn.relu, name='fc0')
l = tf.layers.dropout(l, rate=0.5,
training=get_current_tower_context().is_training)
l = tf.layers.dropout(l, rate=0.5, training=self.training)
logits = tf.layers.dense(l, 10, activation=tf.identity, name='fc1')
# a vector of length B with loss of each sample
......
@@ -30,7 +30,6 @@ class Model(ModelDesc):
image = image * 2 - 1
is_training = get_current_tower_context().is_training
with slim.arg_scope([slim.layers.fully_connected],
weights_regularizer=slim.l2_regularizer(1e-5)):
l = slim.layers.conv2d(image, 32, [3, 3], scope='conv0')
@@ -41,7 +40,7 @@ class Model(ModelDesc):
l = slim.layers.conv2d(l, 32, [3, 3], scope='conv3')
l = slim.layers.flatten(l, scope='flatten')
l = slim.layers.fully_connected(l, 512, scope='fc0')
l = slim.layers.dropout(l, is_training=is_training)
l = slim.layers.dropout(l, is_training=self.training)
logits = slim.layers.fully_connected(l, 10, activation_fn=None, scope='fc1')
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
......
@@ -7,6 +7,7 @@ import tensorflow as tf
from ..utils.argtools import memoized_method
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..compat import backport_tensor_spec, tfv1
TensorSpec = backport_tensor_spec()
@@ -137,6 +138,14 @@ class ModelDescBase(object):
"""
raise NotImplementedError()
@property
def training(self):
"""
Returns:
bool: whether the caller is under a training context or not.
"""
return get_current_tower_context().is_training
class ModelDesc(ModelDescBase):
"""
......
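In user code, the change made throughout the examples above amounts to dropping the `get_current_tower_context().is_training` lookup in favor of the new property. A minimal, hypothetical `ModelDesc` using it (TF1-style layers, names illustrative):
```python
import tensorflow as tf
from tensorpack import ModelDesc


class MyModel(ModelDesc):
    def inputs(self):
        return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'),
                tf.TensorSpec((None,), tf.int32, 'label')]

    def build_graph(self, image, label):
        l = tf.layers.flatten(image)
        l = tf.layers.dense(l, 128, activation=tf.nn.relu, name='fc0')
        # previously: training=get_current_tower_context().is_training
        l = tf.layers.dropout(l, rate=0.5, training=self.training)
        logits = tf.layers.dense(l, 10, name='fc1')
        return tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='total_cost')

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```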