Commit 59d099cd authored by Yuxin Wu's avatar Yuxin Wu

use trainerv2 for mnist-keras

parent bc4c6044
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: mnist-keras-functional.py # File: mnist-keras.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np import numpy as np
...@@ -20,6 +20,7 @@ This is an mnist example demonstrating how to use Keras symbolic function inside ...@@ -20,6 +20,7 @@ This is an mnist example demonstrating how to use Keras symbolic function inside
This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack. This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
""" """
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
...@@ -84,11 +85,11 @@ def get_data(): ...@@ -84,11 +85,11 @@ def get_data():
return train, test return train, test
def get_config(): if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
return TrainConfig( cfg = TrainConfig(
model=Model(), model=Model(),
dataflow=dataset_train, dataflow=dataset_train,
callbacks=[ callbacks=[
...@@ -101,10 +102,4 @@ def get_config(): ...@@ -101,10 +102,4 @@ def get_config():
max_epoch=100, max_epoch=100,
) )
launch_train_with_config(cfg, QueueInputTrainer())
if __name__ == '__main__':
config = get_config()
QueueInputTrainer(config).train()
# for multigpu training:
# config.nr_tower = 2
# SyncMultiGPUTrainer(config).train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment