Commit 266ce353 authored by Yuxin Wu's avatar Yuxin Wu

update setup.py

parent 880b20e9
...@@ -150,7 +150,7 @@ class ImageNetModel(ModelDesc): ...@@ -150,7 +150,7 @@ class ImageNetModel(ModelDesc):
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
image = self.image_preprocess(image, bgr=True) image = ImageNetModel.image_preprocess(image, bgr=True)
if self.data_format == 'NCHW': if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2]) image = tf.transpose(image, [0, 3, 1, 2])
...@@ -181,7 +181,8 @@ class ImageNetModel(ModelDesc): ...@@ -181,7 +181,8 @@ class ImageNetModel(ModelDesc):
tf.summary.scalar('learning_rate', lr) tf.summary.scalar('learning_rate', lr)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def image_preprocess(self, image, bgr=True): @staticmethod
def image_preprocess(image, bgr=True):
with tf.name_scope('image_preprocess'): with tf.name_scope('image_preprocess'):
if image.dtype.base_dtype != tf.float32: if image.dtype.base_dtype != tf.float32:
image = tf.cast(image, tf.float32) image = tf.cast(image, tf.float32)
......
...@@ -35,6 +35,6 @@ setup( ...@@ -35,6 +35,6 @@ setup(
'all: python_version < "3.0"': ['tornado'] 'all: python_version < "3.0"': ['tornado']
}, },
include_package_data=True, #include_package_data=True,
package_data={'tensorpack': ['user_ops/Makefile', 'user_ops/*.cc', 'user_ops/*.h']}, #package_data={'tensorpack': []},
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment