Commit 0c519bdc authored by Yuxin Wu's avatar Yuxin Wu

Let Keras Model accept a real "tower func" with positional args (#739)

parent 3ef33a34
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
# File: imagenet-resnet.py
import argparse
......
......@@ -85,8 +85,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
return x
def resnet50(inputs):
input = tf.layers.Input(tensor=inputs[0])
def resnet50(image):
input = Input(tensor=image)
def image_preprocess(image):
image = ImageNetModel.image_preprocess(image)
......
......@@ -31,16 +31,16 @@ def get_data():
if __name__ == '__main__':
logger.auto_set_dir()
logger.auto_set_dir('d')
def model_func(inputs):
def model_func(image):
"""
Keras model has to be created inside this function to be used with tensorpack.
"""
M = keras.models.Sequential()
# input_tensor have to be used here for tensorpack trainer to function properly.
# Just use inputs[1], inputs[2] if you have multiple inputs.
M.add(KL.InputLayer(input_tensor=inputs[0]))
M.add(KL.InputLayer(input_tensor=image))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
......
......@@ -47,13 +47,15 @@ class KerasModelCaller(object):
def __call__(self, input_tensors):
"""
Args:
input_tensors ([tf.Tensor])
Returns:
output tensors of this tower, evaluated with the input tensors.
"""
reuse = tf.get_variable_scope().reuse
if self.cached_model is None:
assert not reuse
self.cached_model = self.get_model(input_tensors)
self.cached_model = self.get_model(*input_tensors)
return self.cached_model.outputs
if reuse:
......@@ -63,7 +65,7 @@ class KerasModelCaller(object):
return self.cached_model.call(input_tensors)
else:
# create new Keras model if not reuse
M = self.get_model(input_tensors)
M = self.get_model(*input_tensors)
return M.outputs
......@@ -99,7 +101,8 @@ def setup_keras_trainer(
"""
Args:
trainer (SingleCostTrainer):
get_model ( -> keras.model.Model):
get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
input (InputSource):
optimizer (tf.tarin.Optimizer):
loss, metrics: list of strings
......@@ -175,7 +178,8 @@ class KerasModel(object):
input, trainer=None):
"""
Args:
get_model ( -> keras.model.Model):
get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
......
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
# File: common.py
from __future__ import division
import six
import numpy as np
......
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
# File: config.py
......@@ -35,7 +35,7 @@ class PredictConfig(object):
Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors and construct a tower.
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment