Commit e2be301b authored by Yuxin Wu's avatar Yuxin Wu

fix import

parent e0e29779
......@@ -8,9 +8,14 @@ import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
import tensorpack as tp
from tensorpack import imgaug
from tensorpack.tfutils import argscope
from tensorpack.models import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.models import (
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU,
LinearWrap)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
class GoogleNetResize(imgaug.ImageAugmentor):
......@@ -137,12 +142,12 @@ def resnet_backbone(image, num_blocks, block_func):
def eval_on_ILSVRC12(model, model_file, dataflow):
pred_config = PredictConfig(
model=model,
session_init=get_model_loader(model_file),
session_init=tp.get_model_loader(model_file),
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, dataflow)
acc1, acc5 = RatioCounter(), RatioCounter()
acc1, acc5 = tp.RatioCounter(), tp.RatioCounter()
for o in pred.get_result():
batch_size = o[0].shape[0]
acc1.feed(o[0].sum(), batch_size)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment