Commit f0273bee authored by Yuxin Wu's avatar Yuxin Wu

use pickle to dump inputvars

parent 6b85a1f1
......@@ -160,7 +160,7 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
]),
model=Model(),
step_per_epoch=300,
step_per_epoch=dataset.size(),
max_epoch=300,
)
......
......@@ -66,7 +66,6 @@ class Model(ModelDesc):
tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
W_init=tf.constant_initializer(0.2),
use_bias=False, nl=tf.identity)
final_map = tf.squeeze(final_map, [3], name='predmap')
costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx+1))
......
......@@ -93,7 +93,7 @@ class InferenceRunner(Callback):
def _find_input_tensors(self):
if self.input_tensors is None:
input_vars = self.trainer.model.reuse_input_vars()
input_vars = self.trainer.model.get_input_vars()
self.input_tensors = [x.name for x in input_vars]
def _find_output_tensors(self):
......
......@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow):
:param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize
"""
assert len(files)
assert len(files), "No Image Files!"
self.files = files
self.channel = int(channel)
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
......
......@@ -154,8 +154,9 @@ class RandomCropRandomShape(ImageAugmentor):
h = self.rng.randint(self.hmin, hmax+1)
w = self.rng.randint(self.wmin, wmax+1)
diffh = img.shape[0] - h
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0,x0,h,w)
......
......@@ -8,6 +8,7 @@ import re
import tensorflow as tf
from collections import namedtuple
import inspect
import pickle
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_tensors_by_names
......@@ -16,7 +17,13 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
_InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
class InputVar(_InputVar):
def dumps(self):
return pickle.dumps(self)
@staticmethod
def loads(buf):
return pickle.loads(buf)
class ModelDesc(object):
""" Base class for a model description """
......@@ -29,17 +36,17 @@ class ModelDesc(object):
:returns: the list of raw input vars in the graph
"""
try:
return self.reuse_input_vars()
return self._reuse_input_vars()
except KeyError:
pass
ret = self.get_placeholders()
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
return ret
return self.get_placeholders()
def get_placeholders(self, prefix=''):
""" build placeholders with optional prefix, for each InputVar"""
""" build placeholders with optional prefix, for each InputVar
"""
input_vars = self._get_input_vars()
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v.dumps())
ret = []
for v in input_vars:
ret.append(tf.placeholder(
......@@ -47,7 +54,7 @@ class ModelDesc(object):
name=prefix + v.name))
return ret
def reuse_input_vars(self):
def _reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()]
return get_tensors_by_names(input_var_names)
......@@ -104,11 +111,10 @@ class ModelFromMetaGraph(ModelDesc):
assert k in all_coll, \
"Collection {} not found in metagraph!".format(k)
def get_input_vars(self):
return tf.get_collection(INPUT_VARS_KEY)
def _get_input_vars(self):
raise NotImplementedError("Shouldn't call here")
col = tf.get_collection(INPUT_VARS_KEY)
col = [InputVar.loads(v) for v in col]
return col
def _build_graph(self, _, __):
""" Do nothing. Graph was imported already """
......
......@@ -105,7 +105,7 @@ def add_moving_summary(v, *args):
@memoized
def summary_moving_average(tensors=None):
"""
Create a MovingAverage op and summary for tensors
Create a MovingAverage op and add summary for tensors
:param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
:returns: a op to maintain these average.
"""
......
......@@ -8,7 +8,7 @@ import os.path
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
lst = p.__all__ if '__all__' in dir(p) else []
for k in lst:
globals()[k] = p.__dict__[k]
del globals()[name]
......
......@@ -24,20 +24,25 @@ class StopTraining(BaseException):
pass
class Trainer(object):
"""
Base class for a trainer.
Available Attritbutes:
stat_holder: a `StatHolder` instance
summary_writer: a `tf.SummaryWriter`
summary_op: a `tf.Operation` which returns summary string
config: a `TrainConfig`
model: a `ModelDesc`
sess: a `tf.Session`
coord: a `tf.train.Coordinator`
"""
""" Base class for a trainer."""
__metaclass__ = ABCMeta
"""a `StatHolder` instance"""
stat_holder = None
"""`tf.SummaryWriter`"""
summary_writer = None
"""a tf.Tensor which returns summary string"""
summary_op = None
""" TrainConfig """
config = None
""" a ModelDesc"""
model = None
""" the current session"""
sess = None
""" the `tf.train.Coordinator` """
coord = None
def __init__(self, config):
"""
:param config: a `TrainConfig` instance
......
......@@ -147,6 +147,7 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
for th in self.training_threads:
th.pause()
try:
if self.config.tower > 1:
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
self.write_scalar_summary(
......
......@@ -63,9 +63,11 @@ class QueueInputTrainerBase(FeedlessTrainer):
def _build_enque_thread(self, input_queue=None):
""" create a thread that keeps filling the queue """
self.input_vars = self.model.get_input_vars()
assert len(self.input_vars) > 0, "QueueInput can only be used with input placeholders!"
if input_queue is None:
self.input_queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_vars], name='input_queue')
50, [x.dtype for x in self.input_vars],
name='input_queue')
else:
self.input_queue = input_queue
input_th = EnqueueThread(self)
......
......@@ -125,6 +125,7 @@ class FeedlessTrainer(Trainer):
""" return a list of actual input tensors.
Always return new tensors (for multi tower) if called mutliple times.
"""
pass
class SingleCostFeedlessTrainer(FeedlessTrainer):
def _get_cost_and_grad(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment