Commit 34c12fc9 authored by ppwwyyxx's avatar ppwwyyxx

saver accept checkpoint path

parent 0bdc9985
...@@ -39,6 +39,7 @@ with tf.Graph().as_default() as G: ...@@ -39,6 +39,7 @@ with tf.Graph().as_default() as G:
ds = ImageFromFile(args.images, 3, resize=(227, 227)) ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128) predictor = DatasetPredictor(config, ds, batch=128)
res = predictor.get_all_result() res = predictor.get_all_result()
res = [k[1] for k in res]
if args.output_type == 'label': if args.output_type == 'label':
for r in res: for r in res:
......
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
from itertools import count from itertools import count
import argparse import argparse
import numpy as np import numpy as np
import tqdm from tqdm import tqdm
from utils import * from utils import *
from utils.modelutils import describe_model from utils.modelutils import describe_model
...@@ -118,7 +118,7 @@ class DatasetPredictor(object): ...@@ -118,7 +118,7 @@ class DatasetPredictor(object):
""" a generator to return prediction for each data""" """ a generator to return prediction for each data"""
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield self.predict_func(dp) yield [dp, self.predict_func(dp)]
pbar.update() pbar.update()
def get_all_result(self): def get_all_result(self):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: sessinit.py # File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -21,6 +22,11 @@ class NewSession(SessionInit): ...@@ -21,6 +22,11 @@ class NewSession(SessionInit):
class SaverRestore(SessionInit): class SaverRestore(SessionInit):
def __init__(self, model_path): def __init__(self, model_path):
assert os.path.isfile(model_path)
if os.path.basename(model_path) == 'checkpoint':
model_path = tf.train.get_checkpoint_state(
os.path.dirname(model_path)).model_checkpoint_path
assert os.path.isfile(model_path)
self.set_path(model_path) self.set_path(model_path)
def init(self, sess): def init(self, sess):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment