Commit c98f2351 authored by Yuxin Wu

organize imports

parent 06e19907
@@ -5,15 +5,22 @@
 from pkgutil import walk_packages
 import os
 __all__ = []
 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
         __all__.append(k)
+_CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in walk_packages(
-        [os.path.dirname(__file__)]):
+        [_CURR_DIR]):
+    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+    if not os.path.isfile(srcpath):
+        continue
     if not module_name.startswith('_'):
         _global_import(module_name)
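The new `srcpath` guard matters because `walk_packages` also yields subpackages and compiled/extension modules; checking for a plain `.py` file restricts the star-import to flat source modules in the package directory. A minimal self-contained sketch of the whole pattern (the package name is hypothetical):

# mypkg/__init__.py -- auto-import every public name from each flat .py
# module in this directory (a sketch of the pattern used above)
from pkgutil import walk_packages
import os

__all__ = []

def _global_import(name):
    # the relative import binds '.name' as an attribute of this package,
    # which is why globals()[name] exists afterwards
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    del globals()[name]            # hide the module object itself
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)

_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages([_CURR_DIR]):
    # skip subpackages and anything that is not a flat .py source file
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if not module_name.startswith('_'):
        _global_import(module_name)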
@@ -10,11 +10,15 @@ import os
 import os.path
 from ..utils import logger
 __all__ = ['LinearWrap']
 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
         __all__.append(k)
 for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
@@ -85,4 +89,3 @@ class LinearWrap(object):
         print(self._t)
         return self
@@ -5,7 +5,6 @@
 import tensorflow as tf
 import numpy as np
-from . import *
 import unittest
 class TestModel(unittest.TestCase):
@@ -30,6 +29,7 @@ def run_test_case(case):
 if __name__ == '__main__':
     import tensorpack
     from tensorpack.utils import logger
+    from . import *
     logger.disable_logger()
     subs = tensorpack.models._test.TestModel.__subclasses__()
     for cls in subs:
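The driver picks up cases through `__subclasses__()` rather than unittest discovery, so moving `from . import *` into the `__main__` block ensures the subclasses are defined only when the file is run as a script. A standalone sketch of that pattern (class names hypothetical):

import unittest

class BaseTest(unittest.TestCase):
    """ Common base; concrete cases subclass this. """

class TestFoo(BaseTest):
    def test_trivial(self):
        self.assertTrue(True)

def run_test_case(case):
    suite = unittest.TestLoader().loadTestsFromTestCase(case)
    unittest.TextTestRunner(verbosity=2).run(suite)

if __name__ == '__main__':
    # mirror TestModel.__subclasses__() above: every subclass defined
    # by the imports is run as its own suite
    for cls in BaseTest.__subclasses__():
        run_test_case(cls)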
@@ -6,15 +6,23 @@ from pkgutil import walk_packages
 import os
 import os.path
 __all__ = []
 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
         __all__.append(k)
+_CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in walk_packages(
-        [os.path.dirname(__file__)]):
-    if not module_name.startswith('_'):
-        global_import(module_name)
+        [_CURR_DIR]):
+    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+    if not os.path.isfile(srcpath):
+        continue
+    if module_name.startswith('_'):
+        continue
+    global_import(module_name)
@@ -5,17 +5,33 @@
 from pkgutil import walk_packages
 import os
 __all__ = []
 def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
-    del globals()[name]
+    if name in ['common', 'argscope']:
+        del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
         __all__.append(k)
+_TO_IMPORT = set([
+    'sessinit',
+    'common',
+    'gradproc',
+    'argscope',
+    'tower'
+])
-_global_import('sessinit')
-_global_import('common')
-_global_import('gradproc')
-_global_import('argscope')
-_global_import('tower')
+_CURR_DIR = os.path.dirname(__file__)
+for _, module_name, _ in walk_packages(
+        [_CURR_DIR]):
+    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+    if not os.path.isfile(srcpath):
+        continue
+    if module_name.startswith('_'):
+        continue
+    if module_name in _TO_IMPORT:
+        _global_import(module_name)
+    if module_name != 'common':
+        __all__.append(module_name)
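The explicit `_global_import(...)` calls are replaced by a `_TO_IMPORT` whitelist driven by the same directory walk: whitelisted modules have their symbols lifted onto the package, while every other public module is only registered in `__all__` and stays reachable as a submodule. The consumer-side effect (a usage sketch; assumes the tensorpack layout shown in this diff):

# whitelisted module ('common'): its symbols live on the package itself
from tensorpack.tfutils import get_global_step_var

# non-whitelisted module ('summary'): reachable as a submodule only
from tensorpack.tfutils import summary
summary.add_moving_summary  # symbol accessed through the module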
@@ -6,15 +6,23 @@ from pkgutil import walk_packages
 import os
 import os.path
 __all__ = []
 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else []
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
-    del globals()[name]
+        __all__.append(k)
+_CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in walk_packages(
-        [os.path.dirname(__file__)]):
-    if not module_name.startswith('_'):
-        global_import(module_name)
+        [_CURR_DIR]):
+    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+    if not os.path.isfile(srcpath):
+        continue
+    if module_name.startswith('_'):
+        continue
+    global_import(module_name)
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: queue.py
+# File: feedfree.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
@@ -9,12 +9,56 @@ from ..utils import logger
 from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
-from ..tfutils.summary import summary_moving_average
+from ..tfutils.summary import summary_moving_average, add_moving_summary
 from .input_data import QueueInput, FeedfreeInput
-from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer)
+from .base import Trainer
+from .trainer import MultiPredictorTowerTrainer
-__all__ = ['SimpleFeedfreeTrainer', 'QueueInputTrainer']
+__all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
+
+class FeedfreeTrainer(Trainer):
+    """ A trainer which runs iteration without feed_dict (therefore faster) """
+    def _trigger_epoch(self):
+        # need to run summary_op every epoch
+        # note that summary_op will take a data from the queue
+        if self.summary_op is not None:
+            summary_str = self.summary_op.eval()
+            self._process_summary(summary_str)
+
+    def _get_input_tensors(self):
+        return self._input_method.get_input_tensors()
+
+    def _setup(self):
+        assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
+        self._input_method._setup(self)
+
+class SingleCostFeedfreeTrainer(FeedfreeTrainer):
+    def _get_cost_and_grad(self):
+        """ get the cost and gradient on a new tower"""
+        actual_inputs = self._get_input_tensors()
+        self.model.build_graph(actual_inputs)
+        cost_var = self.model.get_cost()
+        # GATE_NONE faster?
+        grads = self.config.optimizer.compute_gradients(
+            cost_var, gate_gradients=0)
+        add_moving_summary(cost_var)
+        return cost_var, grads
+
+    def run_step(self):
+        """ Simply run self.train_op"""
+        self.sess.run(self.train_op)
+        # debug-benchmark code:
+        #run_metadata = tf.RunMetadata()
+        #self.sess.run([self.train_op],
+                #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
+                #run_metadata=run_metadata
+                #)
+        #from tensorflow.python.client import timeline
+        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
+        #trace_file = open('timeline.ctf.json', 'w')
+        #trace_file.write(trace.generate_chrome_trace_format())
+        #import sys; sys.exit()
+
 class SimpleFeedfreeTrainer(
         MultiPredictorTowerTrainer,
@@ -50,8 +94,7 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
     Single tower Trainer, takes input from a queue
     :param config: a `TrainConfig` instance. config.dataset must exist
-    :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
-        Defaults to a FIFO queue of size 100.
+    :param input_queue: a `tf.QueueBase` instance
     :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
         Use -1 for cpu.
     """
@@ -59,6 +102,3 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
         assert len(config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
         super(QueueInputTrainer, self).__init__(config, predict_tower)
-
-    def _setup(self):
-        super(QueueInputTrainer, self)._setup()
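The commented-out block in `run_step` above is the standard TF1 Chrome-trace profiling recipe. A self-contained sketch of the same technique (the profiled op is hypothetical; `self.train_op` would take its place in the trainer):

import tensorflow as tf
from tensorflow.python.client import timeline

# stand-in for train_op
x = tf.random_normal([1000, 1000])
y = tf.matmul(x, x)

with tf.Session() as sess:
    run_metadata = tf.RunMetadata()
    sess.run(y,
             options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
             run_metadata=run_metadata)
    # dump a trace viewable at chrome://tracing
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    with open('timeline.ctf.json', 'w') as f:
        f.write(trace.generate_chrome_trace_format())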
@@ -8,7 +8,7 @@ import threading
 from abc import ABCMeta, abstractmethod
 import six
-from ..dataflow.common import RepeatedData
+from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread
@@ -21,6 +21,7 @@ class InputData(object):
 class FeedInput(InputData):
     def __init__(self, ds):
+        assert isinstance(ds, DataFlow), ds
         self.ds = ds
 
     def size(self):
@@ -91,6 +92,12 @@ class EnqueueThread(threading.Thread):
 class QueueInput(FeedfreeInput):
     def __init__(self, ds, queue=None):
+        """
+        :param ds: a `DataFlow` instance
+        :param queue: a `tf.QueueBase` instance to be used to buffer datapoints.
+            Defaults to a FIFO queue of size 50.
+        """
+        assert isinstance(ds, DataFlow), ds
         self.queue = queue
         self.ds = ds
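Per the docstring just added, `QueueInput` pairs a `DataFlow` with an optional buffering queue. A hedged usage sketch (the dataflow shapes and queue size are made up for illustration):

import tensorflow as tf
from tensorpack.dataflow import FakeData
from tensorpack.train.input_data import QueueInput

# a dummy dataflow yielding [image, label]-shaped float arrays
ds = FakeData([[4, 28, 28], [4]], size=1000)

# default: an internal FIFO queue (size 50) buffers datapoints
inp = QueueInput(ds)

# or hand in a custom queue sized to taste
q = tf.FIFOQueue(200, [tf.float32, tf.float32])
inp = QueueInput(ds, queue=q)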
@@ -15,12 +15,14 @@ from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
-from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer
+from .base import Trainer
+from .trainer import MultiPredictorTowerTrainer
+from .feedfree import SingleCostFeedfreeTrainer
 from .input_data import QueueInput
 
 __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
 
-class MultiGPUTrainer(FeedfreeTrainer):
+class MultiGPUTrainer(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def _multi_tower_grads(towers, get_tower_grad_func):
@@ -16,8 +16,7 @@ from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import FeedInput, FeedfreeInput
 
-__all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer',
-           'SingleCostFeedfreeTrainer']
+__all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
 
 class PredictorFactory(object):
     """ Make predictors for a trainer"""
@@ -110,46 +109,3 @@ class MultiPredictorTowerTrainer(Trainer):
     def get_predict_funcs(self, input_names, output_names, n):
         return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
-
-class FeedfreeTrainer(Trainer):
-    """ A trainer which runs iteration without feed_dict (therefore faster) """
-    def _trigger_epoch(self):
-        # need to run summary_op every epoch
-        # note that summary_op will take a data from the queue
-        if self.summary_op is not None:
-            summary_str = self.summary_op.eval()
-            self._process_summary(summary_str)
-
-    def _get_input_tensors(self):
-        return self._input_method.get_input_tensors()
-
-    def _setup(self):
-        assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
-        self._input_method._setup(self)
-
-class SingleCostFeedfreeTrainer(FeedfreeTrainer):
-    def _get_cost_and_grad(self):
-        """ get the cost and gradient on a new tower"""
-        actual_inputs = self._get_input_tensors()
-        self.model.build_graph(actual_inputs)
-        cost_var = self.model.get_cost()
-        # GATE_NONE faster?
-        grads = self.config.optimizer.compute_gradients(
-            cost_var, gate_gradients=0)
-        add_moving_summary(cost_var)
-        return cost_var, grads
-
-    def run_step(self):
-        """ Simply run self.train_op"""
-        self.sess.run(self.train_op)
-        # debug-benchmark code:
-        #run_metadata = tf.RunMetadata()
-        #self.sess.run([self.train_op],
-                #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
-                #run_metadata=run_metadata
-                #)
-        #from tensorflow.python.client import timeline
-        #trace = timeline.Timeline(step_stats=run_metadata.step_stats)
-        #trace_file = open('timeline.ctf.json', 'w')
-        #trace_file.write(trace.generate_chrome_trace_format())
-        #import sys; sys.exit()
@@ -10,12 +10,31 @@ Common utils.
 These utils should be irrelevant to tensorflow.
 """
 
 __all__ = []
 
 def _global_import(name):
     p = __import__(name, globals(), None, level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
     del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
-_global_import('naming')
-_global_import('utils')
-_global_import('gpu')
+        __all__.append(k)
+
+_TO_IMPORT = set([
+    'naming',
+    'utils',
+    'gpu'
+])
+
+_CURR_DIR = os.path.dirname(__file__)
+for _, module_name, _ in walk_packages(
+        [_CURR_DIR]):
+    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
+    if not os.path.isfile(srcpath):
+        continue
+    if module_name.startswith('_'):
+        continue
+    if module_name in _TO_IMPORT:
+        _global_import(module_name)
+    __all__.append(module_name)