Commit c98f2351 authored by Yuxin Wu's avatar Yuxin Wu

organize imports

parent 06e19907
...@@ -5,15 +5,22 @@ ...@@ -5,15 +5,22 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
__all__ = []
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name] del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if not module_name.startswith('_'): if not module_name.startswith('_'):
_global_import(module_name) _global_import(module_name)
...@@ -10,11 +10,15 @@ import os ...@@ -10,11 +10,15 @@ import os
import os.path import os.path
from ..utils import logger from ..utils import logger
__all__ = ['LinearWrap']
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
__all__.append(k)
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [os.path.dirname(__file__)]):
...@@ -85,4 +89,3 @@ class LinearWrap(object): ...@@ -85,4 +89,3 @@ class LinearWrap(object):
print(self._t) print(self._t)
return self return self
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from . import *
import unittest import unittest
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
...@@ -30,6 +29,7 @@ def run_test_case(case): ...@@ -30,6 +29,7 @@ def run_test_case(case):
if __name__ == '__main__': if __name__ == '__main__':
import tensorpack import tensorpack
from tensorpack.utils import logger from tensorpack.utils import logger
from . import *
logger.disable_logger() logger.disable_logger()
subs = tensorpack.models._test.TestModel.__subclasses__() subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs: for cls in subs:
......
...@@ -6,15 +6,23 @@ from pkgutil import walk_packages ...@@ -6,15 +6,23 @@ from pkgutil import walk_packages
import os import os
import os.path import os.path
__all__ = []
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name] del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [_CURR_DIR]):
if not module_name.startswith('_'): srcpath = os.path.join(_CURR_DIR, module_name + '.py')
global_import(module_name) if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
global_import(module_name)
...@@ -5,17 +5,33 @@ ...@@ -5,17 +5,33 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
__all__ = []
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), None, level=1) p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
if name in ['common', 'argscope']:
del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
__all__.append(k)
_TO_IMPORT = set([
'sessinit',
'common',
'gradproc',
'argscope',
'tower'
])
_global_import('sessinit') _CURR_DIR = os.path.dirname(__file__)
_global_import('common') for _, module_name, _ in walk_packages(
_global_import('gradproc') [_CURR_DIR]):
_global_import('argscope') srcpath = os.path.join(_CURR_DIR, module_name + '.py')
_global_import('tower') if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name in _TO_IMPORT:
_global_import(module_name)
if module_name != 'common':
__all__.append(module_name)
...@@ -6,15 +6,23 @@ from pkgutil import walk_packages ...@@ -6,15 +6,23 @@ from pkgutil import walk_packages
import os import os
import os.path import os.path
__all__ = []
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else [] lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
del globals()[name] __all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages( for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]): [_CURR_DIR]):
if not module_name.startswith('_'): srcpath = os.path.join(_CURR_DIR, module_name + '.py')
global_import(module_name) if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
global_import(module_name)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: queue.py # File: feedfree.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
...@@ -9,12 +9,56 @@ from ..utils import logger ...@@ -9,12 +9,56 @@ from ..utils import logger
from ..tfutils import get_global_step_var from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from ..tfutils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average, add_moving_summary
from .input_data import QueueInput, FeedfreeInput from .input_data import QueueInput, FeedfreeInput
from .trainer import (MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer) from .base import Trainer
from .trainer import MultiPredictorTowerTrainer
__all__ = ['SimpleFeedfreeTrainer', 'QueueInputTrainer'] __all__ = ['FeedfreeTrainer', 'SingleCostFeedfreeTrainer', 'SimpleFeedfreeTrainer', 'QueueInputTrainer']
class FeedfreeTrainer(Trainer):
""" A trainer which runs iteration without feed_dict (therefore faster) """
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def _get_input_tensors(self):
return self._input_method.get_input_tensors()
def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self)
class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost()
# GATE_NONE faster?
grads = self.config.optimizer.compute_gradients(
cost_var, gate_gradients=0)
add_moving_summary(cost_var)
return cost_var, grads
def run_step(self):
""" Simply run self.train_op"""
self.sess.run(self.train_op)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
class SimpleFeedfreeTrainer( class SimpleFeedfreeTrainer(
MultiPredictorTowerTrainer, MultiPredictorTowerTrainer,
...@@ -50,8 +94,7 @@ class QueueInputTrainer(SimpleFeedfreeTrainer): ...@@ -50,8 +94,7 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
Single tower Trainer, takes input from a queue Single tower Trainer, takes input from a queue
:param config: a `TrainConfig` instance. config.dataset must exist :param config: a `TrainConfig` instance. config.dataset must exist
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints. :param input_queue: a `tf.QueueBase` instance
Defaults to a FIFO queue of size 100.
:param predict_tower: list of gpu relative idx to run prediction. default to be [0]. :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu. Use -1 for cpu.
""" """
...@@ -59,6 +102,3 @@ class QueueInputTrainer(SimpleFeedfreeTrainer): ...@@ -59,6 +102,3 @@ class QueueInputTrainer(SimpleFeedfreeTrainer):
assert len(config.tower) == 1, \ assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config, predict_tower) super(QueueInputTrainer, self).__init__(config, predict_tower)
def _setup(self):
super(QueueInputTrainer, self)._setup()
...@@ -8,7 +8,7 @@ import threading ...@@ -8,7 +8,7 @@ import threading
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import six import six
from ..dataflow.common import RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils import logger from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
...@@ -21,6 +21,7 @@ class InputData(object): ...@@ -21,6 +21,7 @@ class InputData(object):
class FeedInput(InputData): class FeedInput(InputData):
def __init__(self, ds): def __init__(self, ds):
assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
def size(self): def size(self):
...@@ -91,6 +92,12 @@ class EnqueueThread(threading.Thread): ...@@ -91,6 +92,12 @@ class EnqueueThread(threading.Thread):
class QueueInput(FeedfreeInput): class QueueInput(FeedfreeInput):
def __init__(self, ds, queue=None): def __init__(self, ds, queue=None):
"""
:param ds: a `DataFlow` instance
:param queue: a `tf.QueueBase` instance to be used to buffer datapoints.
Defaults to a FIFO queue of size 50.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue self.queue = queue
self.ds = ds self.ds = ds
......
...@@ -15,12 +15,14 @@ from ..tfutils import (backup_collection, restore_collection, ...@@ -15,12 +15,14 @@ from ..tfutils import (backup_collection, restore_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .trainer import FeedfreeTrainer, SingleCostFeedfreeTrainer, MultiPredictorTowerTrainer from .base import Trainer
from .trainer import MultiPredictorTowerTrainer
from .feedfree import SingleCostFeedfreeTrainer
from .input_data import QueueInput from .input_data import QueueInput
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer'] __all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(FeedfreeTrainer): class MultiGPUTrainer(Trainer):
""" Base class for multi-gpu training""" """ Base class for multi-gpu training"""
@staticmethod @staticmethod
def _multi_tower_grads(towers, get_tower_grad_func): def _multi_tower_grads(towers, get_tower_grad_func):
......
...@@ -16,8 +16,7 @@ from ..predict import OnlinePredictor, build_multi_tower_prediction_graph ...@@ -16,8 +16,7 @@ from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput, FeedfreeInput from .input_data import FeedInput, FeedfreeInput
__all__ = ['SimpleTrainer', 'FeedfreeTrainer', 'MultiPredictorTowerTrainer', __all__ = ['SimpleTrainer','MultiPredictorTowerTrainer']
'SingleCostFeedfreeTrainer']
class PredictorFactory(object): class PredictorFactory(object):
""" Make predictors for a trainer""" """ Make predictors for a trainer"""
...@@ -110,46 +109,3 @@ class MultiPredictorTowerTrainer(Trainer): ...@@ -110,46 +109,3 @@ class MultiPredictorTowerTrainer(Trainer):
def get_predict_funcs(self, input_names, output_names, n): def get_predict_funcs(self, input_names, output_names, n):
return [self.get_predict_func(input_names, output_names, k) for k in range(n)] return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
class FeedfreeTrainer(Trainer):
""" A trainer which runs iteration without feed_dict (therefore faster) """
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)
def _get_input_tensors(self):
return self._input_method.get_input_tensors()
def _setup(self):
assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
self._input_method._setup(self)
class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def _get_cost_and_grad(self):
""" get the cost and gradient on a new tower"""
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost()
# GATE_NONE faster?
grads = self.config.optimizer.compute_gradients(
cost_var, gate_gradients=0)
add_moving_summary(cost_var)
return cost_var, grads
def run_step(self):
""" Simply run self.train_op"""
self.sess.run(self.train_op)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
...@@ -10,12 +10,31 @@ Common utils. ...@@ -10,12 +10,31 @@ Common utils.
These utils should be irrelevant to tensorflow. These utils should be irrelevant to tensorflow.
""" """
__all__ = []
def _global_import(name): def _global_import(name):
p = __import__(name, globals(), None, level=1) p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
del globals()[name]
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
_global_import('naming') __all__.append(k)
_global_import('utils')
_global_import('gpu') _TO_IMPORT = set([
'naming',
'utils',
'gpu'
])
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in walk_packages(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name in _TO_IMPORT:
_global_import(module_name)
__all__.append(module_name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment