Commit c6c9a4ba authored by Yuxin Wu's avatar Yuxin Wu

add epoch_num stat before each epoch

parent 5e2c7309
...@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how ...@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
+ [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net) + [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py) + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [ResNet for Cifar10 classification](examples/ResNet) + [ResNet for ImageNet/Cifar10 classification](examples/ResNet)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED) + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer) + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Double DQN plays Atari games](examples/Atari2600) + [Double DQN plays Atari games](examples/Atari2600)
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
## imagenet-resnet.py ## imagenet-resnet.py
ImageNet training code of pre-activation ResNet. It follows the setup in ImageNet training code of pre-activation ResNet. It follows the setup in
[fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and get similar performance (with much fewer lines of code), [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) and gets similar performance (with much fewer lines of code).
More results to come.
| Model (WIP) | Top 5 Error | Top 1 Error | | Model | Top 5 Error | Top 1 Error |
|:-------------------|-------------|------------:| |:-------------------|-------------|------------:|
| ResNet 18 | 10.67% | 29.50% | | ResNet 18 | 10.67% | 29.50% |
| ResNet 50 | 7.13% | 24.12% | | ResNet 50 | 7.13% | 24.12% |
......
...@@ -116,7 +116,6 @@ class Model(ModelDesc): ...@@ -116,7 +116,6 @@ class Model(ModelDesc):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
# weight decay on all W of fc layers
wd_w = 1e-4 wd_w = 1e-4
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss') wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost) add_moving_summary(loss, wd_cost)
......
...@@ -107,11 +107,14 @@ class StatPrinter(Callback): ...@@ -107,11 +107,14 @@ class StatPrinter(Callback):
self._stat_holder.set_print_tag(self.print_tag) self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num']) self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
# just try to add this stat earlier so SendStat can use it
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
def _trigger_epoch(self): def _trigger_epoch(self):
# by default, add this two stat # by default, add this two stat
self._stat_holder.add_stat('global_step', get_global_step()) self._stat_holder.add_stat('global_step', get_global_step())
self._stat_holder.add_stat('epoch_num', self.epoch_num)
self._stat_holder.finalize() self._stat_holder.finalize()
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Callback): class SendStat(Callback):
""" """
......
...@@ -27,7 +27,6 @@ class PredictConfig(object): ...@@ -27,7 +27,6 @@ class PredictConfig(object):
:param session_init: a `utils.sessinit.SessionInit` instance to :param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session. initialize variables of a session.
:param input_var_names: a list of input variable names. :param input_var_names: a list of input variable names.
:param input_data_mapping: deprecated. used to select `input_var_names` from the `InputVars` of the model.
:param model: a `ModelDesc` instance :param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output tensors to predict, the :param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
...@@ -47,13 +46,7 @@ class PredictConfig(object): ...@@ -47,13 +46,7 @@ class PredictConfig(object):
# inputs & outputs # inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None) self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None) if self.input_var_names is None:
if input_mapping:
raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [raw_vars[k].name for k in input_mapping]
logger.warn('The option `input_data_mapping` was deprecated. \
Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
elif self.input_var_names is None:
# neither options is set, assume all inputs # neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc() raw_vars = self.model.get_input_vars_desc()
self.input_var_names = [k.name for k in raw_vars] self.input_var_names = [k.name for k in raw_vars]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment