Commit 0c519bdc authored by Yuxin Wu's avatar Yuxin Wu

Let Keras Model accept a real "tower func" with positional args (#739)

parent 3ef33a34
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
# File: imagenet-resnet.py # File: imagenet-resnet.py
import argparse import argparse
......
...@@ -85,8 +85,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)) ...@@ -85,8 +85,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
return x return x
def resnet50(inputs): def resnet50(image):
input = tf.layers.Input(tensor=inputs[0]) input = Input(tensor=image)
def image_preprocess(image): def image_preprocess(image):
image = ImageNetModel.image_preprocess(image) image = ImageNetModel.image_preprocess(image)
......
...@@ -31,16 +31,16 @@ def get_data(): ...@@ -31,16 +31,16 @@ def get_data():
if __name__ == '__main__': if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir('d')
def model_func(inputs): def model_func(image):
""" """
Keras model has to be created inside this function to be used with tensorpack. Keras model has to be created inside this function to be used with tensorpack.
""" """
M = keras.models.Sequential() M = keras.models.Sequential()
# input_tensor have to be used here for tensorpack trainer to function properly. # input_tensor have to be used here for tensorpack trainer to function properly.
# Just use inputs[1], inputs[2] if you have multiple inputs. # Just use inputs[1], inputs[2] if you have multiple inputs.
M.add(KL.InputLayer(input_tensor=inputs[0])) M.add(KL.InputLayer(input_tensor=image))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D()) M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
......
...@@ -47,13 +47,15 @@ class KerasModelCaller(object): ...@@ -47,13 +47,15 @@ class KerasModelCaller(object):
def __call__(self, input_tensors): def __call__(self, input_tensors):
""" """
Args:
input_tensors ([tf.Tensor])
Returns: Returns:
output tensors of this tower, evaluated with the input tensors. output tensors of this tower, evaluated with the input tensors.
""" """
reuse = tf.get_variable_scope().reuse reuse = tf.get_variable_scope().reuse
if self.cached_model is None: if self.cached_model is None:
assert not reuse assert not reuse
self.cached_model = self.get_model(input_tensors) self.cached_model = self.get_model(*input_tensors)
return self.cached_model.outputs return self.cached_model.outputs
if reuse: if reuse:
...@@ -63,7 +65,7 @@ class KerasModelCaller(object): ...@@ -63,7 +65,7 @@ class KerasModelCaller(object):
return self.cached_model.call(input_tensors) return self.cached_model.call(input_tensors)
else: else:
# create new Keras model if not reuse # create new Keras model if not reuse
M = self.get_model(input_tensors) M = self.get_model(*input_tensors)
return M.outputs return M.outputs
...@@ -99,7 +101,8 @@ def setup_keras_trainer( ...@@ -99,7 +101,8 @@ def setup_keras_trainer(
""" """
Args: Args:
trainer (SingleCostTrainer): trainer (SingleCostTrainer):
get_model ( -> keras.model.Model): get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
input (InputSource): input (InputSource):
optimizer (tf.tarin.Optimizer): optimizer (tf.tarin.Optimizer):
loss, metrics: list of strings loss, metrics: list of strings
...@@ -175,7 +178,8 @@ class KerasModel(object): ...@@ -175,7 +178,8 @@ class KerasModel(object):
input, trainer=None): input, trainer=None):
""" """
Args: Args:
get_model ( -> keras.model.Model): get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
targets_desc ([InputDesc]): targets_desc ([InputDesc]):
input (InputSource | DataFlow): input (InputSource | DataFlow):
......
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
# File: common.py # File: common.py
from __future__ import division from __future__ import division
import six import six
import numpy as np import numpy as np
......
# -*- coding: UTF-8 -*- # -*- coding: utf-8 -*-
# File: config.py # File: config.py
...@@ -35,7 +35,7 @@ class PredictConfig(object): ...@@ -35,7 +35,7 @@ class PredictConfig(object):
Args: Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func. model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors and construct a tower. tower_func: a callable which takes input tensors (by positional args) and construct a tower.
input_names (list): a list of input tensor names. Defaults to match inputs_desc. input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the output_names (list): a list of names of the output tensors to predict, the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment