Commit 34c12fc9 authored by ppwwyyxx's avatar ppwwyyxx

saver accept checkpoint path

parent 0bdc9985
......@@ -39,6 +39,7 @@ with tf.Graph().as_default() as G:
ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128)
res = predictor.get_all_result()
res = [k[1] for k in res]
if args.output_type == 'label':
for r in res:
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
from itertools import count
import argparse
import numpy as np
import tqdm
from tqdm import tqdm
from utils import *
from utils.modelutils import describe_model
......@@ -118,7 +118,7 @@ class DatasetPredictor(object):
""" a generator to return prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
yield self.predict_func(dp)
yield [dp, self.predict_func(dp)]
pbar.update()
def get_all_result(self):
......
......@@ -3,6 +3,7 @@
# File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf
......@@ -21,6 +22,11 @@ class NewSession(SessionInit):
class SaverRestore(SessionInit):
def __init__(self, model_path):
assert os.path.isfile(model_path)
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.get_checkpoint_state(
os.path.dirname(model_path)).model_checkpoint_path
assert os.path.isfile(model_path)
self.set_path(model_path)
def init(self, sess):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment