Commit 7877a7f7 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent c712e8dd
......@@ -19,6 +19,7 @@ For any unexpected problems, __PLEASE ALWAYS INCLUDE__:
+ TF version: `python -c 'import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)'`.
+ Tensorpack version: `python -c 'import tensorpack; print(tensorpack.__version__)'`.
You can install Tensorpack master by `pip install -U git+https://github.com/ppwwyyxx/tensorpack.git`.
+ Hardware information, if relevant.
5. About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html
Feature Requests:
......
......@@ -18,8 +18,6 @@ from config import config as cfg
__all__ = ['COCODetection', 'COCOMeta']
COCO_NUM_CATEGORY = 80
class _COCOMeta(object):
INSTANCE_TO_BASEDIR = {
......@@ -39,7 +37,7 @@ class _COCOMeta(object):
cat_names: list of names
"""
assert not self.valid()
assert len(cat_ids) == COCO_NUM_CATEGORY and len(cat_names) == COCO_NUM_CATEGORY
assert len(cat_ids) == cfg.DATA.NUM_CATEGORY and len(cat_names) == cfg.DATA.NUM_CATEGORY
self.cat_names = cat_names
self.class_names = ['BG'] + self.cat_names
......
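The assertion above now reads the category count from the config instead of a hard-coded COCO constant. A hypothetical sketch (the placeholder names below are illustrative, not part of this diff) of how that lets the same check serve a non-COCO dataset:

from config import config as cfg

cfg.DATA.NUM_CATEGORY = 20                               # e.g. a VOC-sized label set
cat_ids = list(range(1, cfg.DATA.NUM_CATEGORY + 1))
cat_names = ['class_{}'.format(i) for i in cat_ids]      # placeholder names
assert len(cat_ids) == cfg.DATA.NUM_CATEGORY and len(cat_names) == cfg.DATA.NUM_CATEGORY
class_names = ['BG'] + cat_names                         # index 0 is reserved for background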
......@@ -63,7 +63,7 @@ _C.MODE_FPN = False
_C.DATA.BASEDIR = '/path/to/your/COCO/DIR'
_C.DATA.TRAIN = ['train2014', 'valminusminival2014'] # i.e., trainval35k
_C.DATA.VAL = 'minival2014'   # For now, only evaluation on a single dataset is supported
_C.DATA.NUM_CATEGORY = 80 # 80 categories
_C.DATA.NUM_CATEGORY = 80 # 80 categories.
_C.DATA.CLASS_NAMES = [] # NUM_CLASS strings. Needs to be populated later by data loader
# basemodel ----------------------
......
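For context, a hedged sketch of how these DATA fields might be filled in before training; the path and the stand-in class names below are assumptions, not part of this diff. The data loader is expected to set CLASS_NAMES to NUM_CATEGORY + 1 strings, with 'BG' at index 0.

from config import config as cfg

cfg.DATA.BASEDIR = '/data/COCO'                          # hypothetical location of the COCO dir
cfg.DATA.TRAIN = ['train2014', 'valminusminival2014']    # i.e. trainval35k
cfg.DATA.VAL = 'minival2014'
# The data loader normally populates CLASS_NAMES; a stand-in is shown here:
cfg.DATA.CLASS_NAMES = ['BG'] + ['category_{}'.format(i) for i in range(cfg.DATA.NUM_CATEGORY)]
assert len(cfg.DATA.CLASS_NAMES) == cfg.DATA.NUM_CATEGORY + 1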
......@@ -399,7 +399,6 @@ class ResNetFPNModel(DetectionModel):
def visualize(model, model_path, nr_visualize=100, output_dir='output'):
"""
Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
Does not support FPN.
"""
df = get_train_dataflow() # we don't visualize mask stuff
df.reset_state()
......
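A minimal usage sketch of visualize() as declared above; building `model` (the detection ModelDesc) and the checkpoint path are assumed to exist and are omitted here.

visualize(model, '/path/to/checkpoint', nr_visualize=100, output_dir='output')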
......@@ -57,7 +57,7 @@ def fbresnet_augmentor(isTrain):
if isTrain:
augmentors = [
GoogleNetResize(),
# It's OK to remove these augs if your CPU is not fast enough.
# It's OK to remove the following augs if your CPU is not fast enough.
# Removing brightness/contrast/saturation does not have a significant effect on accuracy.
# Removing lighting leads to a tiny drop in accuracy.
imgaug.RandomOrderAug(
......
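A hedged sketch of the trimmed training augmentor the comment above permits when CPU preprocessing is the bottleneck: keep the crop/resize and flip, drop the color/lighting augs at a small accuracy cost. The function name is an assumption; GoogleNetResize is the augmentor defined in this same file.

from tensorpack.dataflow import imgaug

def fast_fbresnet_augmentor():
    # Reduced training augmentors for slow CPUs.
    return [
        GoogleNetResize(),          # crop-and-resize augmentor defined in this file
        imgaug.Flip(horiz=True),    # cheap horizontal flip, kept from the full list
    ]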
......@@ -48,7 +48,14 @@ class PlaceholderInput(InputSource):
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
"""
Input by iterating over a DataFlow and feeding datapoints.
Note:
    If `get_input_tensors()` is called more than once, it will return the same
    placeholders (i.e. feed points) as the first time.
    Therefore you cannot use it for data-parallel training.
"""
class _FeedCallback(Callback):
def __init__(self, ds, placeholders):
......
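A hypothetical sketch of the behaviour the new docstring documents: repeated calls to get_input_tensors() return the same placeholder objects, so a second, independent set of feed points for another training tower cannot be built. `my_dataflow` and `inputs_desc` are assumed to exist.

from tensorpack.input_source import FeedInput

feed_input = FeedInput(my_dataflow)          # my_dataflow: any DataFlow, assumed defined
feed_input.setup(inputs_desc)                # inputs_desc: the model's InputDesc list, assumed defined
t1 = feed_input.get_input_tensors()
t2 = feed_input.get_input_tensors()
assert all(a is b for a, b in zip(t1, t2))   # identical placeholders, not fresh copies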
......@@ -14,7 +14,7 @@ from ..utils.develop import HIDE_DOC
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.tower import TrainTowerContext
from ..input_source import QueueInput
from ..input_source import QueueInput, FeedfreeInput
from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
......@@ -59,7 +59,7 @@ class SimpleTrainer(SingleCostTrainer):
# Only exists for type check & back-compatibility
class QueueInputTrainer(SimpleTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, QueueInput)
assert isinstance(input, QueueInput), input
return super(QueueInputTrainer, self)._setup_graph(input, get_cost_fn, get_opt_fn)
......@@ -87,6 +87,8 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
super(SyncMultiGPUTrainerParameterServer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
......@@ -124,6 +126,8 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
super(AsyncMultiGPUTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
......@@ -162,6 +166,8 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
super(SyncMultiGPUTrainerReplicated, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op, post_init_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
......@@ -220,6 +226,7 @@ class DistributedTrainerParameterServer(DistributedTrainerBase):
self.is_chief = self._builder.is_chief
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, FeedfreeInput), input
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
return []
......@@ -255,6 +262,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
return input.setup(inputs_desc)
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, FeedfreeInput), input
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
......
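A hedged sketch of what the new asserts enforce: the multi-GPU and distributed trainers now require a FeedfreeInput, so wrapping the dataflow in QueueInput (a FeedfreeInput subclass) satisfies the isinstance checks, while a FeedInput would trip them. `MyModel` and `my_dataflow` are assumptions.

from tensorpack import QueueInput, TrainConfig, launch_train_with_config
from tensorpack.train import SyncMultiGPUTrainerReplicated

config = TrainConfig(
    model=MyModel(),                 # MyModel: your ModelDesc subclass, assumed defined
    data=QueueInput(my_dataflow),    # feed-free input -> passes the new asserts
    steps_per_epoch=1000,
)
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(2))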