Commit c6c9a4ba authored by Yuxin Wu's avatar Yuxin Wu

add epoch_num stat before each epoch

parent 5e2c7309
......@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
+ [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [ResNet for Cifar10 classification](examples/ResNet)
+ [ResNet for ImageNet/Cifar10 classification](examples/ResNet)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Double DQN plays Atari games](examples/Atari2600)
......
......@@ -2,9 +2,10 @@
## imagenet-resnet.py
ImageNet training code of pre-activation ResNet. It follows the setup in
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and get similar performance (with much fewer lines of code),
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code).
More results to come.
| Model (WIP) | Top 5 Error | Top 1 Error |
| Model | Top 5 Error | Top 1 Error |
|:-------------------|-------------|------------:|
| ResNet 18 | 10.67% | 29.50% |
| ResNet 50 | 7.13% | 24.12% |
......
......@@ -116,7 +116,6 @@ class Model(ModelDesc):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
# weight decay on all W of fc layers
wd_w = 1e-4
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost)
......
......@@ -107,11 +107,14 @@ class StatPrinter(Callback):
self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
# just try to add this stat earlier so SendStat can use it
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
def _trigger_epoch(self):
# by default, add these two stats
self._stat_holder.add_stat('global_step', get_global_step())
self._stat_holder.add_stat('epoch_num', self.epoch_num)
self._stat_holder.finalize()
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Callback):
"""
......
......@@ -27,7 +27,6 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_var_names: a list of input variable names.
:param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
......@@ -47,13 +46,7 @@ class PredictConfig(object):
# inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None)
if input_mapping:
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [raw_vars[k].name for k in input_mapping]
logger.warn('The option `input_data_mapping` was deprecated. \
Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
elif self.input_var_names is None:
if self.input_var_names is None:
# the option is not set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment