Commit 1f02847d authored by Yuxin Wu's avatar Yuxin Wu

fix inputvar problems with py3, and lmdb bug

parent 88ed2c24
......@@ -33,13 +33,13 @@ Describe your training task with three components:
+ Print some variables of interest
+ Run inference on a test dataset
+ Run some operations once in a while
+ Send loss to your phone
With the above components defined, tensorpack trainer will run the training iterations for you.
Multi-GPU training is off-the-shelf by simply switching the trainer.
You can also define your own trainer for non-standard training (e.g. GAN).
The components are designed to be independent. You can use Model or DataFlow in other projects as well.
## Dependencies:
......
......@@ -20,8 +20,9 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain 512x256 images formed by A and B
# training visualization will appear be in tensorboard
To visualize on test set:
./Image2Image.py --sample --data /path/to/test/datadir --mode {AtoB,BtoA} --load pretrained.model
"""
......
......@@ -128,7 +128,7 @@ class LMDBDataDecoder(LMDBData):
class LMDBDataPoint(LMDBDataDecoder):
    """Read an LMDB file where each value is a serialized datapoint.

    Each value in the database is deserialized with ``loads`` and yielded
    as one datapoint; keys are ignored by the decoder.
    """

    def __init__(self, lmdb_path, shuffle=True):
        """
        :param lmdb_path: path to the lmdb file or directory.
        :param shuffle: whether to iterate the keys in random order.
        """
        # Must reference the current class here: the stale name
        # `SimpleLMDBLoader` (left over from a rename) raised NameError.
        super(LMDBDataPoint, self).__init__(
            lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle)
class CaffeLMDB(LMDBDataDecoder):
......
......@@ -149,7 +149,7 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
if use_local_stat:
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
epsilon=epsilon, is_training=ctx.is_training)
epsilon=epsilon, is_training=True)
if ctx.is_training:
# maintain EMA if training
update_op1 = moving_averages.assign_moving_average(
......
......@@ -17,12 +17,13 @@ from ..tfutils.tower import get_current_tower_context
__all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph' ]
class InputVar(object):
    """Metadata describing one input of the model: its dtype, shape, name,
    and whether the corresponding placeholder should be sparse.

    Implemented as a plain class (instead of a ``namedtuple`` subclass with
    a custom ``__new__``) so that instances pickle correctly under Python 3;
    the namedtuple-based version broke (un)pickling.
    """

    def __init__(self, type, shape, name, sparse=False):
        # NOTE(review): `type` presumably is a TF dtype and `shape` a list of
        # ints (None for unknown dims) — not verifiable from this hunk.
        self.type = type
        self.shape = shape
        self.name = name
        self.sparse = sparse  # build a sparse placeholder if True

    def dumps(self):
        """Serialize this InputVar to a pickle byte string."""
        return pickle.dumps(self)
@staticmethod
......
......@@ -97,7 +97,7 @@ class CaffeLayerProcessor(object):
def load_caffe(model_desc, model_file):
"""
return a dict of params
:return: a dict of params
"""
with change_env('GLOG_minloglevel', '2'):
import caffe
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment