Commit 3e876599 authored by Yuxin Wu

bring some sense to import

parent bd0ca738
@@ -21,7 +21,7 @@ from tensorpack.dataflow import imgaug
 """
 ResNet-110 for SVHN Digit Classification.
-Reach 1.9% validation error after 90 epochs, with 2 TitanX. 2it/s.
+Reach 1.8% validation error after 70 epochs, with 2 TitanX. 2it/s.
 You might need to adjust the learning rate schedule when running with 1 GPU.
 """
@@ -8,14 +8,10 @@ import numpy as np
 import os, sys
 import argparse
-
-from tensorpack.train import *
-from tensorpack.models import *
-from tensorpack.utils import *
-from tensorpack.tfutils.symbolic_functions import *
-from tensorpack.tfutils.summary import *
-from tensorpack.tfutils import *
+import tensorpack as tp
+from tensorpack.models import *
+from tensorpack.utils import *
 from tensorpack.callbacks import *
 from tensorpack.dataflow import *

 """
 MNIST ConvNet example.
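The import block above is the heart of the commit: instead of pulling every helper into the example's namespace through six wildcard imports, the example now imports the package once as tp and reaches the tfutils helpers (prediction_incorrect, add_param_summary, get_default_sess_config, ...) through that prefix, while the remaining wildcard imports keep names such as Callbacks and SaverRestore directly usable further down.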
@@ -60,7 +56,7 @@ class Model(ModelDesc):
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

         # compute the number of failed samples, for ClassificationError to use at test time
-        wrong = prediction_incorrect(logits, label)
+        wrong = tp.symbolic_functions.prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
         tf.add_to_collection(
@@ -72,7 +68,7 @@ class Model(ModelDesc):
             name='regularize_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

-        add_param_summary([('.*/W', ['histogram'])])  # monitor histogram of all W
+        tp.summary.add_param_summary([('.*/W', ['histogram'])])  # monitor histogram of all W
         return tf.add_n([wd_cost, cost], name='cost')

 def get_config():
@@ -81,22 +77,22 @@ def get_config():
         os.path.join('train_log', basename[:basename.rfind('.')]))

     # prepare dataset
-    dataset_train = BatchData(dataset.Mnist('train'), 128)
-    dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
+    dataset_train = tp.BatchData(tp.dataset.Mnist('train'), 128)
+    dataset_test = tp.BatchData(tp.dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()

     # prepare session
-    sess_config = get_default_sess_config()
+    sess_config = tp.get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5

     lr = tf.train.exponential_decay(
         learning_rate=1e-3,
-        global_step=get_global_step_var(),
+        global_step=tp.get_global_step_var(),
         decay_steps=dataset_train.size() * 10,
         decay_rate=0.3, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

-    return TrainConfig(
+    return tp.TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
@@ -125,5 +121,5 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    SimpleTrainer(config).train()
+    tp.SimpleTrainer(config).train()
 # -*- coding: utf-8 -*-
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import models
-import train
-import utils
-import tfutils
-import callbacks
-import dataflow
+from .train import *
+from .models import *
+from .utils import *
+from .tfutils import *
+from .callbacks import *
+from .dataflow import *
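This top-level __init__.py rewrite is what makes the tp.-style access above work. "import models" is a Python 2 implicit relative import and binds only the submodule object, so tensorpack.BatchData would not resolve; "from .dataflow import *" lifts each submodule's public names into the package namespace. A minimal sketch of the difference, using a hypothetical package pkg rather than tensorpack itself:

    # pkg/mod.py -- a submodule with one public name
    __all__ = ['Thing']

    class Thing(object):
        pass

    # pkg/__init__.py, old style: binds only the submodule object, so
    # pkg.Thing raises AttributeError and callers must write pkg.mod.Thing
    import mod            # Python 2 implicit relative import

    # pkg/__init__.py, new style: re-exports the submodule's public names
    # at the top level, so pkg.Thing resolves -- the same mechanism that
    # lets the example call tp.BatchData(...) directly
    from .mod import *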
@@ -8,6 +8,7 @@ import os

 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
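For reference, _global_import re-exports a sibling module's public names into the package namespace by hand; the newly added del is what removes the leftover submodule binding. A commented sketch of the same pattern in a hypothetical package (not tensorpack code):

    # pkg/helpers.py
    __all__ = ['useful']

    def useful():
        return 42

    # pkg/__init__.py
    def _global_import(name):
        # level=1 makes this a relative import, i.e. "from . import <name>"
        p = __import__(name, globals(), locals(), level=1)
        # prefer the module's declared public API; fall back to dir()
        lst = p.__all__ if '__all__' in dir(p) else dir(p)
        # the import above also bound the module itself as globals()['helpers'];
        # drop that binding so only the re-exported names remain visible
        del globals()[name]
        for k in lst:
            globals()[k] = p.__dict__[k]

    _global_import('helpers')
    # pkg.useful() now works, and pkg.helpers is no longer in the namespace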
@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from pkgutil import walk_packages
+import importlib
 import os
 import os.path
@@ -12,10 +13,11 @@ from . import imgaug

 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]

-__SKIP = ['dftools', 'dataset']
+__SKIP = ['dftools', 'dataset', 'imgaug']
 for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
     if not module_name.startswith('_') and \
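__SKIP lists the submodules that this walk_packages loop must not flatten into the dataflow namespace. imgaug joins dftools and dataset because it is kept as an explicit sub-namespace (from . import imgaug, visible in the hunk context above) rather than being wildcard-imported.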
@@ -10,6 +10,7 @@ __all__ = []

 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
@@ -63,8 +63,8 @@ class CenterPaste(ImageAugmentor):
         background = self.background_filler.fill(
             self.background_shape, img.arr)
-        h0 = (self.background_shape[0] - img_shape[0]) * 0.5
-        w0 = (self.background_shape[1] - img_shape[1]) * 0.5
+        h0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
+        w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
         background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
         img.arr = background
         if img.coords:
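The int() casts fix a real bug: the centering offsets are computed with * 0.5 and therefore come out as Python floats, which NumPy does not accept as slice indices. At the time this was written, float indices only raised a DeprecationWarning (with silent truncation); later NumPy releases reject them with an error. A quick illustration with hypothetical shapes, not tensorpack code:

    import numpy as np

    background_shape = (40, 40)    # hypothetical paste canvas
    img_shape = (28, 28)           # hypothetical image to center

    h0 = (background_shape[0] - img_shape[0]) * 0.5   # 6.0 -- a float
    background = np.zeros(background_shape)

    try:
        # float slice bound: DeprecationWarning on old NumPy,
        # TypeError/IndexError on newer releases
        background[h0:h0 + img_shape[0], :] = 1.0
    except (TypeError, IndexError) as e:
        print('float index rejected:', e)

    h0 = int(h0)                                      # 6 -- the fix
    background[h0:h0 + img_shape[0], :] = 1.0         # always fine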
@@ -8,6 +8,7 @@ import os

 def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
@@ -8,6 +8,9 @@ import tensorflow as tf
 from ..utils import *
 from . import get_global_step_var

+__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
+           'summary_moving_average']
+
 def create_summary(name, v):
     """
     Return a tf.Summary object with name and simple scalar value v
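Declaring __all__ matters for both plain wildcard imports and the _global_import helpers above, which prefer p.__all__ over the dir(p) fallback. Without it, the fallback would re-export everything visible in the module, including its own imports and dunders. A sketch with a hypothetical module:

    # pkg/summary.py (hypothetical, not the real tfutils/summary.py)
    import math                    # an incidental import

    __all__ = ['create_summary']   # the declared public API

    def create_summary(name, v):
        return (name, v)

    # "from pkg.summary import *" now binds create_summary only; without
    # __all__ it would also pull in math, and _global_import's dir(p)
    # fallback would even copy dunders such as __name__ into the package.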
@@ -11,6 +11,7 @@ def global_import(name):
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     for k in lst:
         globals()[k] = p.__dict__[k]
+    del globals()[name]

 for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
@@ -13,6 +13,7 @@ These utils should be irrelevant to tensorflow.
 def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
 _global_import('naming')