Commit 59d099cd authored by Yuxin Wu's avatar Yuxin Wu

use trainerv2 for mnist-keras

parent bc4c6044
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-keras-functional.py
# File: mnist-keras.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
......@@ -20,6 +20,7 @@ This is an mnist example demonstrating how to use Keras symbolic function inside
This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
"""
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
......@@ -84,11 +85,11 @@ def get_data():
return train, test
def get_config():
if __name__ == '__main__':
logger.auto_set_dir()
dataset_train, dataset_test = get_data()
return TrainConfig(
cfg = TrainConfig(
model=Model(),
dataflow=dataset_train,
callbacks=[
......@@ -101,10 +102,4 @@ def get_config():
max_epoch=100,
)
if __name__ == '__main__':
config = get_config()
QueueInputTrainer(config).train()
# for multigpu training:
# config.nr_tower = 2
# SyncMultiGPUTrainer(config).train()
launch_train_with_config(cfg, QueueInputTrainer())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment