Commit f0273bee authored by Yuxin Wu

use pickle to dump inputvars

parent 6b85a1f1
...@@ -160,7 +160,7 @@ def get_config(): ...@@ -160,7 +160,7 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
]), ]),
model=Model(), model=Model(),
step_per_epoch=300, step_per_epoch=dataset.size(),
max_epoch=300, max_epoch=300,
) )
......
...@@ -66,7 +66,6 @@ class Model(ModelDesc): ...@@ -66,7 +66,6 @@ class Model(ModelDesc):
tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1, tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
W_init=tf.constant_initializer(0.2), W_init=tf.constant_initializer(0.2),
use_bias=False, nl=tf.identity) use_bias=False, nl=tf.identity)
final_map = tf.squeeze(final_map, [3], name='predmap')
costs = [] costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]): for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx+1)) output = tf.nn.sigmoid(b, name='output{}'.format(idx+1))
......
...@@ -93,7 +93,7 @@ class InferenceRunner(Callback): ...@@ -93,7 +93,7 @@ class InferenceRunner(Callback):
def _find_input_tensors(self): def _find_input_tensors(self):
if self.input_tensors is None: if self.input_tensors is None:
input_vars = self.trainer.model.reuse_input_vars() input_vars = self.trainer.model.get_input_vars()
self.input_tensors = [x.name for x in input_vars] self.input_tensors = [x.name for x in input_vars]
def _find_output_tensors(self): def _find_output_tensors(self):
......
...@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow): ...@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow):
:param channel: 1 or 3 channel :param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize :param resize: a (h, w) tuple. If given, will force a resize
""" """
assert len(files) assert len(files), "No Image Files!"
self.files = files self.files = files
self.channel = int(channel) self.channel = int(channel)
self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
......
...@@ -154,8 +154,9 @@ class RandomCropRandomShape(ImageAugmentor): ...@@ -154,8 +154,9 @@ class RandomCropRandomShape(ImageAugmentor):
h = self.rng.randint(self.hmin, hmax+1) h = self.rng.randint(self.hmin, hmax+1)
w = self.rng.randint(self.wmin, wmax+1) w = self.rng.randint(self.wmin, wmax+1)
diffh = img.shape[0] - h diffh = img.shape[0] - h
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
diffw = img.shape[1] - w diffw = img.shape[1] - w
assert diffh >= 0 and diffw >= 0
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw) x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return (y0,x0,h,w) return (y0,x0,h,w)
......
...@@ -8,6 +8,7 @@ import re ...@@ -8,6 +8,7 @@ import re
import tensorflow as tf import tensorflow as tf
from collections import namedtuple from collections import namedtuple
import inspect import inspect
import pickle
from ..utils import logger, INPUT_VARS_KEY from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
...@@ -16,7 +17,13 @@ from ..tfutils.tower import get_current_tower_context ...@@ -16,7 +17,13 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ] __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
InputVar = namedtuple('InputVar', ['type', 'shape', 'name']) _InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
class InputVar(_InputVar):
def dumps(self):
return pickle.dumps(self)
@staticmethod
def loads(buf):
return pickle.loads(buf)
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description """ """ Base class for a model description """
...@@ -29,17 +36,17 @@ class ModelDesc(object): ...@@ -29,17 +36,17 @@ class ModelDesc(object):
:returns: the list of raw input vars in the graph :returns: the list of raw input vars in the graph
""" """
try: try:
return self.reuse_input_vars() return self._reuse_input_vars()
except KeyError: except KeyError:
pass pass
ret = self.get_placeholders() return self.get_placeholders()
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
return ret
def get_placeholders(self, prefix=''): def get_placeholders(self, prefix=''):
""" build placeholders with optional prefix, for each InputVar""" """ build placeholders with optional prefix, for each InputVar
"""
input_vars = self._get_input_vars() input_vars = self._get_input_vars()
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v.dumps())
ret = [] ret = []
for v in input_vars: for v in input_vars:
ret.append(tf.placeholder( ret.append(tf.placeholder(
...@@ -47,7 +54,7 @@ class ModelDesc(object): ...@@ -47,7 +54,7 @@ class ModelDesc(object):
name=prefix + v.name)) name=prefix + v.name))
return ret return ret
def reuse_input_vars(self): def _reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph""" """ Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()] input_var_names = [k.name for k in self._get_input_vars()]
return get_tensors_by_names(input_var_names) return get_tensors_by_names(input_var_names)
...@@ -104,11 +111,10 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -104,11 +111,10 @@ class ModelFromMetaGraph(ModelDesc):
assert k in all_coll, \ assert k in all_coll, \
"Collection {} not found in metagraph!".format(k) "Collection {} not found in metagraph!".format(k)
def get_input_vars(self):
return tf.get_collection(INPUT_VARS_KEY)
def _get_input_vars(self): def _get_input_vars(self):
raise NotImplementedError("Shouldn't call here") col = tf.get_collection(INPUT_VARS_KEY)
col = [InputVar.loads(v) for v in col]
return col
def _build_graph(self, _, __): def _build_graph(self, _, __):
""" Do nothing. Graph was imported already """ """ Do nothing. Graph was imported already """
......
...@@ -105,7 +105,7 @@ def add_moving_summary(v, *args): ...@@ -105,7 +105,7 @@ def add_moving_summary(v, *args):
@memoized @memoized
def summary_moving_average(tensors=None): def summary_moving_average(tensors=None):
""" """
Create a MovingAverage op and summary for tensors Create a MovingAverage op and add summary for tensors
:param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY :param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
:returns: a op to maintain these average. :returns: a op to maintain these average.
""" """
......
...@@ -8,7 +8,7 @@ import os.path ...@@ -8,7 +8,7 @@ import os.path
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else []
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
del globals()[name] del globals()[name]
......
...@@ -24,20 +24,25 @@ class StopTraining(BaseException): ...@@ -24,20 +24,25 @@ class StopTraining(BaseException):
pass pass
class Trainer(object): class Trainer(object):
""" """ Base class for a trainer."""
Base class for a trainer.
Available Attritbutes:
stat_holder: a `StatHolder` instance
summary_writer: a `tf.SummaryWriter`
summary_op: a `tf.Operation` which returns summary string
config: a `TrainConfig`
model: a `ModelDesc`
sess: a `tf.Session`
coord: a `tf.train.Coordinator`
"""
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
"""a `StatHolder` instance"""
stat_holder = None
"""`tf.SummaryWriter`"""
summary_writer = None
"""a tf.Tensor which returns summary string"""
summary_op = None
""" TrainConfig """
config = None
""" a ModelDesc"""
model = None
""" the current session"""
sess = None
""" the `tf.train.Coordinator` """
coord = None
def __init__(self, config): def __init__(self, config):
""" """
:param config: a `TrainConfig` instance :param config: a `TrainConfig` instance
......
...@@ -147,10 +147,11 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase, ...@@ -147,10 +147,11 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
for th in self.training_threads: for th in self.training_threads:
th.pause() th.pause()
try: try:
async_step_total_cnt = int(re.findall( if self.config.tower > 1:
'[0-9]+', self.async_step_counter.__str__())[0]) async_step_total_cnt = int(re.findall(
self.write_scalar_summary( '[0-9]+', self.async_step_counter.__str__())[0])
'async_global_step', async_step_total_cnt) self.write_scalar_summary(
'async_global_step', async_step_total_cnt)
except: except:
logger.exception("Cannot log async_global_step") logger.exception("Cannot log async_global_step")
super(AsyncMultiGPUTrainer, self)._trigger_epoch() super(AsyncMultiGPUTrainer, self)._trigger_epoch()
...@@ -63,9 +63,11 @@ class QueueInputTrainerBase(FeedlessTrainer): ...@@ -63,9 +63,11 @@ class QueueInputTrainerBase(FeedlessTrainer):
def _build_enque_thread(self, input_queue=None): def _build_enque_thread(self, input_queue=None):
""" create a thread that keeps filling the queue """ """ create a thread that keeps filling the queue """
self.input_vars = self.model.get_input_vars() self.input_vars = self.model.get_input_vars()
assert len(self.input_vars) > 0, "QueueInput can only be used with input placeholders!"
if input_queue is None: if input_queue is None:
self.input_queue = tf.FIFOQueue( self.input_queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_vars], name='input_queue') 50, [x.dtype for x in self.input_vars],
name='input_queue')
else: else:
self.input_queue = input_queue self.input_queue = input_queue
input_th = EnqueueThread(self) input_th = EnqueueThread(self)
......
...@@ -125,6 +125,7 @@ class FeedlessTrainer(Trainer): ...@@ -125,6 +125,7 @@ class FeedlessTrainer(Trainer):
""" return a list of actual input tensors. """ return a list of actual input tensors.
Always return new tensors (for multi tower) if called mutliple times. Always return new tensors (for multi tower) if called mutliple times.
""" """
pass
class SingleCostFeedlessTrainer(FeedlessTrainer): class SingleCostFeedlessTrainer(FeedlessTrainer):
def _get_cost_and_grad(self): def _get_cost_and_grad(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment