Commit f43309f0 authored by Yuxin Wu

update docs

parent 47de91da
...@@ -14,7 +14,7 @@ There are two ways to do inference during training. ...@@ -14,7 +14,7 @@ There are two ways to do inference during training.
"evaluate some tensors for each input, and aggregate the results in the end". "evaluate some tensors for each input, and aggregate the results in the end".
You can use the `InferenceRunner` interface with some `Inferencer`. You can use the `InferenceRunner` interface with some `Inferencer`.
This will further support prefetch & data-parallel inference. This will further support prefetch & data-parallel inference.
Currently this lacks documentation, but you can refer to examples Currently this lacks documentation, but you can refer to examples
that uses `InferenceRunner` or custom `Inferencer` to learn more. that uses `InferenceRunner` or custom `Inferencer` to learn more.
...@@ -55,8 +55,8 @@ predictor = OfflinePredictor(pred_config) ...@@ -55,8 +55,8 @@ predictor = OfflinePredictor(pred_config)
output1_array, output2_array = predictor(input1_array, input2_array) output1_array, output2_array = predictor(input1_array, input2_array)
``` ```
It's __common to use a different graph for inference__, It's __common to use a different graph for inference__,
e.g., use NHWC format, support encoded image format, etc. e.g., use NHWC format, support encoded image format, etc.
You can make these changes inside the `model` or `tower_func` in your `PredictConfig`. You can make these changes inside the `model` or `tower_func` in your `PredictConfig`.
The example in [examples/basics/export-model.py](../examples/basics/export-model.py) demonstrates such an altered inference graph. The example in [examples/basics/export-model.py](../examples/basics/export-model.py) demonstrates such an altered inference graph.
...@@ -90,7 +90,7 @@ you can also save your models into other formats after training, so it may be mo ...@@ -90,7 +90,7 @@ you can also save your models into other formats after training, so it may be mo
- Removes all unnecessary operations (training-only ops, e.g., learning-rate) to compress the graph. - Removes all unnecessary operations (training-only ops, e.g., learning-rate) to compress the graph.
This creates a self-contained graph which includes all necessary information to run inference. This creates a self-contained graph which includes all necessary information to run inference.
To load the saved graph, you can simply: To load the saved graph, you can simply:
```python ```python
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
...@@ -116,7 +116,7 @@ training: ...@@ -116,7 +116,7 @@ training:
1. The model (the graph): you've already written it yourself with TF symbolic functions. 1. The model (the graph): you've already written it yourself with TF symbolic functions.
Nothing about it is related to the tensorpack interface. Nothing about it is related to the tensorpack interface.
If you use tensorpack layers, they are mainly just wrappers around `tf.layers`. If you use tensorpack layers, they are not so different from `tf.layers`.
2. The trained parameters: tensorpack saves them in standard TF checkpoint format. 2. The trained parameters: tensorpack saves them in standard TF checkpoint format.
Nothing about the format is related to tensorpack. Nothing about the format is related to tensorpack.
...@@ -139,14 +139,16 @@ with TowerContext('', is_training=False): ...@@ -139,14 +139,16 @@ with TowerContext('', is_training=False):
```eval_rst ```eval_rst
.. note:: **Do not use metagraph for inference!** .. note:: **Do not use metagraph for inference!**
Metagraph is the wrong abstraction for a "model". Tensorpack saves a metagraph during training. Users should not try to load it for inference.
Metagraph is the wrong abstraction for a "model".
It stores the entire graph which contains not only the mathematical model, but also all the It stores the entire graph which contains not only the mathematical model, but also all the
training settings (queues, iterators, summaries, evaluations, multi-gpu replications). training settings (queues, iterators, summaries, evaluations, multi-gpu replications).
Therefore it is usually wrong to import a training metagraph for inference. Therefore it is usually wrong to import a training metagraph for inference.
It's especially error-prone to load a metagraph on top of a non-empty graph. It's especially error-prone to load a metagraph on top of a non-empty graph.
The potential name conflicts between the current graph and the nodes in the The potential name conflicts between the current graph and the nodes in the
metagraph can lead to esoteric bugs or sometimes completely ruin the model. metagraph can lead to esoteric bugs or sometimes completely ruin the model.
It's also very common to change the graph for inference. It's also very common to change the graph for inference.
For example, you may need a different data layout for CPU inference, For example, you may need a different data layout for CPU inference,
...@@ -161,7 +163,7 @@ with TowerContext('', is_training=False): ...@@ -161,7 +163,7 @@ with TowerContext('', is_training=False):
You can just use `tf.train.Saver` for all the work. You can just use `tf.train.Saver` for all the work.
Alternatively, use tensorpack's `get_model_loader(path).init(tf.get_default_session())` Alternatively, use tensorpack's `get_model_loader(path).init(tf.get_default_session())`
Now, you've already built a graph for inference, and the checkpoint is also loaded. Now, you've already built a graph for inference, and the checkpoint is also loaded.
You may now: You may now:
1. use `sess.run` to do inference 1. use `sess.run` to do inference
......
...@@ -28,7 +28,7 @@ Some practicical notes: ...@@ -28,7 +28,7 @@ Some practicical notes:
### To test a model: ### To test a model:
Download models from [model zoo](http://models.tensorpack.com/OpenAIGym/). Download models from [model zoo](http://models.tensorpack.com/#OpenAIGym).
Watch the agent play: Watch the agent play:
`./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npz` `./train-atari.py --task play --env Breakout-v0 --load Breakout-v0.npz`
......
Example code to convert, load and run inference of some Caffe models. Example code to convert, load and run inference of some Caffe models.
Require caffe python bindings to be installed. Require caffe python bindings to be installed.
Converted models can also be found at [tensorpack model zoo](http://models.tensorpack.com). Converted models can also be found at [tensorpack model zoo](http://models.tensorpack.com/#Caffe-Converted).
## AlexNet: ## AlexNet:
Download: https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet Download: https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet
......
...@@ -46,7 +46,7 @@ In this implementation, quantized operations are all performed through `tf.float ...@@ -46,7 +46,7 @@ In this implementation, quantized operations are all performed through `tf.float
+ Look at the docstring in `*-dorefa.py` to see detailed usage and performance. + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
Pretrained model for (1,4,32)-ResNet18 and several AlexNet are available at Pretrained model for (1,4,32)-ResNet18 and several AlexNet are available at
[tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/). [tensorpack model zoo](http://models.tensorpack.com/#DoReFa-Net).
They're provided in the format of numpy dictionary. They're provided in the format of numpy dictionary.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy. The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
......
...@@ -18,8 +18,8 @@ This is likely the best-performing open source TensorFlow reimplementation of th ...@@ -18,8 +18,8 @@ This is likely the best-performing open source TensorFlow reimplementation of th
## Dependencies ## Dependencies
+ Python 3.3+; OpenCV + Python 3.3+; OpenCV
+ TensorFlow ≥ 1.6 + TensorFlow ≥ 1.6
+ pycocotools: `for i in cython 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'; do pip install $i; done` + pycocotools/scipy: `for i in cython 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' scipy; do pip install $i; done`
+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/FasterRCNN/) + Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/#FasterRCNN)
from tensorpack model zoo from tensorpack model zoo
+ [COCO data](http://cocodataset.org/#download). It needs to have the following directory structure: + [COCO data](http://cocodataset.org/#download). It needs to have the following directory structure:
``` ```
...@@ -83,24 +83,24 @@ prediction have to be run with the corresponding configs used in training. ...@@ -83,24 +83,24 @@ prediction have to be run with the corresponding configs used in training.
These models are trained on train2017 and evaluated on val2017 using mAP@IoU=0.50:0.95. These models are trained on train2017 and evaluated on val2017 using mAP@IoU=0.50:0.95.
Unless otherwise noted, all models are fine-tuned from ImageNet pre-trained R50/R101 models in Unless otherwise noted, all models are fine-tuned from ImageNet pre-trained R50/R101 models in
[tensorpack model zoo](http://models.tensorpack.com/FasterRCNN/), [tensorpack model zoo](http://models.tensorpack.com/#FasterRCNN),
using 8 NVIDIA V100s. using 8 NVIDIA V100s.
Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can be reproduced. Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can be reproduced.
| Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/> (box;mask) | Time <br/>(on 8 V100s) | Configurations <br/> (click to expand) | | Backbone | mAP<br/>(box;mask) | Detectron mAP <sup>[1](#ft1)</sup><br/> (box;mask) | Time <br/>(on 8 V100s) | Configurations <br/> (click to expand) |
| - | - | - | - | - | | - | - | - | - | - |
| R50-C4 | 34.1 | | 7.5h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[140000,180000,200000]` </details> | | R50-C4 | 34.1 | | 7h | <details><summary>super quick</summary>`MODE_MASK=False FRCNN.BATCH_PER_IM=64`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=600 PREPROC.MAX_SIZE=1024`<br/>`TRAIN.LR_SCHEDULE=[140000,180000,200000]` </details> |
| R50-C4 | 35.6 | 34.8 | 23h | <details><summary>standard</summary>`MODE_MASK=False` </details> | | R50-C4 | 35.6 | 34.8 | 22.5h | <details><summary>standard</summary>`MODE_MASK=False` </details> |
| R50-FPN | 37.5 | 36.7 | 11h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> | | R50-FPN | 37.5 | 36.7 | 10.5h | <details><summary>standard</summary>`MODE_MASK=False MODE_FPN=True` </details> |
| R50-C4 | 36.2;31.8 [:arrow_down:][R50C41x] | 35.8;31.4 | 23.5h | <details><summary>standard</summary>this is the default, no changes in config needed </details> | | R50-C4 | 36.2;31.8 [:arrow_down:][R50C41x] | 35.8;31.4 | 23h | <details><summary>standard</summary>this is the default, no changes in config needed </details> |
| R50-FPN | 38.2;34.8 | 37.7;33.9 | 13.5h | <details><summary>standard</summary>`MODE_FPN=True` </details> | | R50-FPN | 38.2;34.8 | 37.7;33.9 | 12.5h | <details><summary>standard</summary>`MODE_FPN=True` </details> |
| R50-FPN | 38.9;35.4 [:arrow_down:][R50FPN2x] | 38.6;34.5 | 25h | <details><summary>2x</summary>`MODE_FPN=True`<br/>`TRAIN.LR_SCHEDULE=2x` </details> | | R50-FPN | 38.9;35.4 [:arrow_down:][R50FPN2x] | 38.6;34.5 | 24h | <details><summary>2x</summary>`MODE_FPN=True`<br/>`TRAIN.LR_SCHEDULE=2x` </details> |
| R50-FPN-GN | 40.4;36.3 [:arrow_down:][R50FPN2xGN] | 40.3;35.7 | 31h | <details><summary>2x+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` <br/>`TRAIN.LR_SCHEDULE=2x` | | R50-FPN-GN | 40.4;36.3 [:arrow_down:][R50FPN2xGN] | 40.3;35.7 | 29h | <details><summary>2x+GN</summary>`MODE_FPN=True`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head` <br/>`TRAIN.LR_SCHEDULE=2x` |
| R50-FPN | 41.7;36.2 | | 17h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> | | R50-FPN | 41.7;36.2 | | 16h | <details><summary>+Cascade</summary>`MODE_FPN=True FPN.CASCADE=True` </details> |
| R101-C4 | 40.1;34.6 [:arrow_down:][R101C41x] | | 28h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> | | R101-C4 | 40.1;34.6 [:arrow_down:][R101C41x] | | 27h | <details><summary>standard</summary>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 40.7;36.8 [:arrow_down:][R101FPN1x] | 40.0;35.9 | 18h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> | | R101-FPN | 40.7;36.8 [:arrow_down:][R101FPN1x] | 40.0;35.9 | 17h | <details><summary>standard</summary>`MODE_FPN=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]` </details> |
| R101-FPN | 46.6;40.3 [:arrow_down:][R101FPN3xCasAug] <sup>[2](#ft2)</sup> | | 69h | <details><summary>3x+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=3x` </details> | | R101-FPN | 46.6;40.3 [:arrow_down:][R101FPN3xCasAug] <sup>[2](#ft2)</sup> | | 64h | <details><summary>3x+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`TEST.RESULT_SCORE_THRESH=1e-4`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=3x` </details> |
| R101-FPN-GN<br/>(From Scratch) | 47.7;41.7 [:arrow_down:][R101FPN9xGNCasAugScratch] <sup>[3](#ft3)</sup> | 47.4;40.5 | 28h (on 64 V100s) | <details><summary>9x+GN+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=9x`<br/>`BACKBONE.FREEZE_AT=0`</details> | | R101-FPN-GN<br/>(From Scratch) | 47.7;41.7 [:arrow_down:][R101FPN9xGNCasAugScratch] <sup>[3](#ft3)</sup> | 47.4;40.5 | 28h (on 64 V100s) | <details><summary>9x+GN+Cascade+TrainAug</summary>`MODE_FPN=True FPN.CASCADE=True`<br/>`BACKBONE.RESNET_NUM_BLOCKS=[3,4,23,3]`<br/>`FPN.NORM=GN BACKBONE.NORM=GN`<br/>`FPN.FRCNN_HEAD_FUNC=fastrcnn_4conv1fc_gn_head`<br/>`FPN.MRCNN_HEAD_FUNC=maskrcnn_up4conv_gn_head`<br/>`PREPROC.TRAIN_SHORT_EDGE_SIZE=[640,800]`<br/>`TRAIN.LR_SCHEDULE=9x`<br/>`BACKBONE.FREEZE_AT=0`</details> |
[R50C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50C41x.npz [R50C41x]: http://models.tensorpack.com/FasterRCNN/COCO-MaskRCNN-R50C41x.npz
......
...@@ -140,7 +140,7 @@ _C.TRAIN.STARTING_EPOCH = 1 # the first epoch to start with, useful to continue ...@@ -140,7 +140,7 @@ _C.TRAIN.STARTING_EPOCH = 1 # the first epoch to start with, useful to continue
# Therefore, there is *no need* to modify the config if you only change the number of GPUs. # Therefore, there is *no need* to modify the config if you only change the number of GPUs.
_C.TRAIN.LR_SCHEDULE = "1x" # "1x" schedule in detectron _C.TRAIN.LR_SCHEDULE = "1x" # "1x" schedule in detectron
_C.TRAIN.EVAL_PERIOD = 25 # period (epochs) to run evaluation _C.TRAIN.EVAL_PERIOD = 50 # period (epochs) to run evaluation
_C.TRAIN.CHECKPOINT_PERIOD = 20 # period (epochs) to save model _C.TRAIN.CHECKPOINT_PERIOD = 20 # period (epochs) to save model
# preprocessing -------------------- # preprocessing --------------------
......
...@@ -17,7 +17,7 @@ from GAN import GANModelDesc, GANTrainer ...@@ -17,7 +17,7 @@ from GAN import GANModelDesc, GANTrainer
Boundary Equilibrium GAN. Boundary Equilibrium GAN.
See the docstring in DCGAN.py for usage. See the docstring in DCGAN.py for usage.
A pretrained model on CelebA is at http://models.tensorpack.com/GAN/ A pretrained model on CelebA is at http://models.tensorpack.com/#GAN
""" """
......
...@@ -30,7 +30,7 @@ from GAN import GANModelDesc, GANTrainer, RandomZData ...@@ -30,7 +30,7 @@ from GAN import GANModelDesc, GANTrainer, RandomZData
You can also train on other images (just use any directory of jpg files in You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing. `--data`). But you may need to change the preprocessing.
A pretrained model on CelebA is at http://models.tensorpack.com/GAN/ A pretrained model on CelebA is at http://models.tensorpack.com/#GAN
""" """
......
...@@ -24,7 +24,7 @@ To train: ...@@ -24,7 +24,7 @@ To train:
To visualize: To visualize:
./InfoGAN-mnist.py --sample --load path/to/model ./InfoGAN-mnist.py --sample --load path/to/model
A pretrained model is at http://models.tensorpack.com/GAN/ A pretrained model is at http://models.tensorpack.com/#GAN
""" """
BATCH = 128 BATCH = 128
......
...@@ -33,4 +33,4 @@ To inference (produce a heatmap at each level at out*.png): ...@@ -33,4 +33,4 @@ To inference (produce a heatmap at each level at out*.png):
```bash ```bash
./hed.py --load pretrained.model --run a.jpg ./hed.py --load pretrained.model --run a.jpg
``` ```
Models I trained can be downloaded [here](http://models.tensorpack.com/HED/). Models I trained can be downloaded [here](http://models.tensorpack.com/#HED).
...@@ -4,7 +4,7 @@ ImageNet training code of ResNet, ShuffleNet, DoReFa-Net, AlexNet, Inception, VG ...@@ -4,7 +4,7 @@ ImageNet training code of ResNet, ShuffleNet, DoReFa-Net, AlexNet, Inception, VG
To train any of the models, just do `./{model}.py --data /path/to/ilsvrc`. To train any of the models, just do `./{model}.py --data /path/to/ilsvrc`.
More options are available in `./{model}.py --help`. More options are available in `./{model}.py --help`.
Expected format of data directory is described in [docs](http://tensorpack.readthedocs.io/modules/dataflow.dataset.html#tensorpack.dataflow.dataset.ILSVRC12). Expected format of data directory is described in [docs](http://tensorpack.readthedocs.io/modules/dataflow.dataset.html#tensorpack.dataflow.dataset.ILSVRC12).
Some pretrained models can be downloaded at [tensorpack model zoo](http://models.tensorpack.com/). Some pretrained models can be downloaded at [tensorpack model zoo](http://models.tensorpack.com/#ImageNetModels).
### ShuffleNet ### ShuffleNet
......
...@@ -39,7 +39,7 @@ Usage: ...@@ -39,7 +39,7 @@ Usage:
./CAM-resnet.py --data /path/to/imagenet [--load ImageNet-ResNet18-Preact.npz] [--gpu 0,1,2,3] ./CAM-resnet.py --data /path/to/imagenet [--load ImageNet-ResNet18-Preact.npz] [--gpu 0,1,2,3]
``` ```
Pretrained and fine-tuned ResNet can be downloaded Pretrained and fine-tuned ResNet can be downloaded
in the [model zoo](http://models.tensorpack.com/). in the [model zoo](http://models.tensorpack.com/#Visualization).
2. Generate CAM on ImageNet validation set: 2. Generate CAM on ImageNet validation set:
```bash ```bash
......
...@@ -20,7 +20,7 @@ To train (takes about 300 epochs to reach 8.8% error): ...@@ -20,7 +20,7 @@ To train (takes about 300 epochs to reach 8.8% error):
./mnist-addition.py ./mnist-addition.py
``` ```
To draw the above visualization with [pretrained model](http://models.tensorpack.com/SpatialTransformer/): To draw the above visualization with [pretrained model](http://models.tensorpack.com/#SpatialTransformer):
```bash ```bash
./mnist-addition.py --load mnist-addition.npz --view ./mnist-addition.py --load mnist-addition.npz --view
``` ```
...@@ -35,7 +35,7 @@ python enet-pat.py --vgg19 /path/to/vgg19.npz --data train2017.lmdb ...@@ -35,7 +35,7 @@ python enet-pat.py --vgg19 /path/to/vgg19.npz --data train2017.lmdb
Training is highly unstable and does not often give good results. Training is highly unstable and does not often give good results.
The pretrained model may also fail on different types of images. The pretrained model may also fail on different types of images.
You can download and play with the pretrained model [here](http://models.tensorpack.com/SuperResolution/). You can download and play with the pretrained model [here](http://models.tensorpack.com/#SuperResolution).
3. Inference on an image and output in current directory: 3. Inference on an image and output in current directory:
......
...@@ -304,7 +304,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder): ...@@ -304,7 +304,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
with tf.name_scope('sync_variables'): with tf.name_scope('sync_variables'):
post_init_op = SyncMultiGPUReplicatedBuilder.get_post_init_ops() post_init_op = SyncMultiGPUReplicatedBuilder.get_post_init_ops()
else: else:
post_init_op = tf.no_op(name='empty_sync_variables') post_init_op = None
return train_op, post_init_op return train_op, post_init_op
# Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
......
...@@ -190,13 +190,16 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer): ...@@ -190,13 +190,16 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
grad_list = self._builder.call_for_each_tower(tower_fn) grad_list = self._builder.call_for_each_tower(tower_fn)
self.train_op, post_init_op = self._builder.build(grad_list, get_opt_fn) self.train_op, post_init_op = self._builder.build(grad_list, get_opt_fn)
-        cb = RunOp(
-            post_init_op,
-            run_before=True,
-            run_as_trigger=self.BROADCAST_EVERY_EPOCH,
-            verbose=True)
-        cb.name_scope = "SyncVariables"
-        return [cb]
+        if post_init_op is not None:
+            cb = RunOp(
+                post_init_op,
+                run_before=True,
+                run_as_trigger=self.BROADCAST_EVERY_EPOCH,
+                verbose=True)
+            cb.name_scope = "SyncVariables"
+            return [cb]
+        else:
+            return []
class DistributedTrainerBase(SingleCostTrainer): class DistributedTrainerBase(SingleCostTrainer):
......
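The last two hunks replace a placeholder `tf.no_op` with `None` and let the trainer skip the callback when there is nothing to sync. A minimal sketch of that pattern in plain Python, for illustration only: `FakeRunOp`, `build`, `setup_callbacks`, and the `needs_sync` flag are hypothetical stand-ins, not tensorpack APIs, and the real condition under which the builder creates a sync op is not shown in this diff.

```python
from typing import Callable, List, Optional


class FakeRunOp:
    """Hypothetical stand-in for a callback that runs an op (not tensorpack's RunOp)."""

    def __init__(self, op: Callable[[], None], run_before: bool = True):
        self.op = op
        self.run_before = run_before
        self.name_scope: Optional[str] = None


def build(needs_sync: bool) -> Optional[Callable[[], None]]:
    """Return a sync op when one is needed, otherwise None (no placeholder op)."""
    if needs_sync:
        return lambda: None  # placeholder for a real variable-sync op
    return None


def setup_callbacks(needs_sync: bool) -> List[FakeRunOp]:
    """Create the sync callback only if the builder produced an op."""
    post_init_op = build(needs_sync)
    if post_init_op is None:
        return []  # nothing to sync, so register no callback
    cb = FakeRunOp(post_init_op, run_before=True)
    cb.name_scope = "SyncVariables"
    return [cb]
```

Returning `None` instead of a no-op lets the caller distinguish "nothing to run" from "run this op", so no empty callback is registered in the former case.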