Commit 6dce5c4b authored by Yuxin Wu's avatar Yuxin Wu

update docs. cleanup scripts

parent acae3fe5
......@@ -13,7 +13,7 @@ It's Yet Another TF wrapper, but different in:
1. Focus on __training speed__.
+ Speed comes for free with tensorpack -- it uses TensorFlow in the __efficient way__ with no extra overhead.
On various CNNs, it runs 1.5~1.7x faster than the equivalent Keras code.
On various CNNs, it runs 1.1~2x faster than the equivalent Keras code.
+ Data-parallel multi-GPU training is off-the-shelf to use. It runs as fast as Google's [official benchmark](https://www.tensorflow.org/performance/benchmarks).
......
......@@ -12,10 +12,10 @@ for other FCN tasks such as semantic segmentation and detection.
## Usage
This script only needs the original BSDS dataset and applies augmentation on the fly.
This script needs the original BSDS dataset and applies augmentation on the fly.
It will automatically download the dataset to `$TENSORPACK_DATASET/` if not there.
It requires pretrained vgg16 model. See the docs in [examples/load-vgg16.py](../load-vgg16.py)
It requires pretrained vgg16 model. See the docs in [examples/CaffeModels](../CaffeModels)
for instructions to convert from vgg16 caffe model.
To view augmented training images:
......@@ -34,11 +34,3 @@ To inference (produce a heatmap at each level at out*.png):
./hed.py --load pretrained.model --run a.jpg
```
Models I trained can be downloaded [here](http://models.tensorpack.com/HED/).
To view the loss curve:
```bash
cat train_log/hed/stat.json | jq '.[] |
"\(.xentropy1)\t\(.xentropy2)\t\(.xentropy3)\t\(.xentropy4)\t\(.xentropy5)\t\(.xentropy6)"' -r | \
tpk-plot-point --legend 1,2,3,4,5,final --decay 0.8
```
Or just open tensorboard.
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump-dataflow.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import argparse
import cv2
import imp
import tqdm
import os
from tensorpack.utils import logger
from tensorpack.utils.fs import mkdir_p
from tensorpack.dataflow import RepeatedData
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument('-o', '--output',
help='output directory to dump dataset image. If not given, will not dump images.')
parser.add_argument('-s', '--scale',
help='scale the image data (maybe by 255)', default=1, type=int)
parser.add_argument('--index',
help='index of the image component in datapoint',
default=0, type=int)
parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int)
args = parser.parse_args()
logger.auto_set_dir(action='d')
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()
if args.output:
mkdir_p(args.output)
cnt = 0
index = args.index # TODO: as an argument?
for dp in config.dataset.get_data():
imgbatch = dp[index]
if cnt > args.number:
break
for bi, img in enumerate(imgbatch):
cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img * args.scale)
NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
ds = RepeatedData(config.dataset, -1)
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
for idx, dp in enumerate(ds.get_data()):
del dp
if idx > NR_DP_TEST:
break
pbar.update()
This diff is collapsed.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serve-data.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import argparse
import imp
from tensorpack.dataflow import serve_data
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument('-p', '--port', help='port', type=int, required=True)
args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
ds = config.dataset
serve_data(ds, "tcp://*:{}".format(args.port))
......@@ -515,7 +515,8 @@ class StagingInput(FeedfreeInput):
element should be sufficient.
towers: deprecated
device (str or None): if not None, place the StagingArea on a specific device. e.g., '/cpu:0'.
Otherwise, they are placed under where `get_inputs_tensors` gets called.
Otherwise, they are placed under where `get_inputs_tensors`
gets called, which could be unspecified in case of simple trainers.
"""
assert isinstance(input, FeedfreeInput), input
self._input = input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment