Commit d1cfdd4d authored by Yuxin Wu's avatar Yuxin Wu

move to tfutils/

parent f8b54d8e
...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer ...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import * from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.tfutils import *
from tensorpack.utils.summary import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
......
...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer ...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import * from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.tfutils import *
from tensorpack.utils.summary import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: example_alexnet.py # File: load_alexnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
...@@ -13,8 +13,9 @@ from tensorpack.train import TrainConfig, start_train ...@@ -13,8 +13,9 @@ from tensorpack.train import TrainConfig, start_train
from tensorpack.predict import PredictConfig, get_predict_func from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.tfutils import *
from tensorpack.utils.summary import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
......
...@@ -13,9 +13,9 @@ import argparse ...@@ -13,9 +13,9 @@ import argparse
from tensorpack.train import * from tensorpack.train import *
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.utils.gradproc import * from tensorpack.tfutils.summary import *
from tensorpack.utils.summary import * from tensorpack.tfutils import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
......
...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer ...@@ -12,8 +12,9 @@ from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import * from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.tfutils import *
from tensorpack.utils.summary import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
import time
from .base import Callback, TrainCallback, TestCallback from .base import Callback, TrainCallback, TestCallback
from .summary import * from .summary import *
......
...@@ -9,7 +9,7 @@ from six.moves import zip ...@@ -9,7 +9,7 @@ from six.moves import zip
from ..utils import * from ..utils import *
from ..utils.stat import * from ..utils.stat import *
from ..utils.summary import * from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallback from .base import PeriodicCallback, Callback, TestCallback
__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter'] __all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
......
...@@ -6,8 +6,8 @@ import tensorflow as tf ...@@ -6,8 +6,8 @@ import tensorflow as tf
from functools import wraps from functools import wraps
import six import six
from ..utils.modelutils import * from ..tfutils.modelutils import *
from ..utils.summary import * from ..tfutils.summary import *
from ..utils import logger from ..utils import logger
# make sure each layer is only logged once # make sure each layer is only logged once
......
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
import math import math
from ._common import layer_register from ._common import layer_register
from ..utils.symbolic_functions import * from ..tfutils.symbolic_functions import *
__all__ = ['FullyConnected'] __all__ = ['FullyConnected']
......
...@@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod ...@@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod
import tensorflow as tf import tensorflow as tf
from collections import namedtuple from collections import namedtuple
from ..utils.gradproc import * from ..tfutils import *
__all__ = ['ModelDesc', 'InputVar'] __all__ = ['ModelDesc', 'InputVar']
......
...@@ -6,7 +6,7 @@ import tensorflow as tf ...@@ -6,7 +6,7 @@ import tensorflow as tf
import numpy import numpy
from ._common import * from ._common import *
from ..utils.symbolic_functions import * from ..tfutils.symbolic_functions import *
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling'] __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling']
......
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import walk_packages
import os
def global_import(name):
p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
global_import('sessinit')
global_import('common')
global_import('gradproc')
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from ..utils.naming import *
import tensorflow as tf
def get_default_sess_config(mem_fraction=0.5):
"""
Return a better config to use as default.
Tensorflow default session config consume too much resources
"""
conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
conf.gpu_options.allocator_type = 'BFC'
conf.allow_soft_placement = True
return conf
def get_global_step_var():
""" get global_step variable in the current graph"""
try:
return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
except KeyError:
var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
return var
def get_global_step():
""" get global_step value with current graph and session"""
return tf.train.global_step(
tf.get_default_session(),
get_global_step_var())
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import re import re
from . import logger from ..utils import logger
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient', __all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient'] 'ScaleGradient']
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import tensorflow as tf import tensorflow as tf
from . import logger from ..utils import logger
def describe_model(): def describe_model():
""" describe the current model parameters""" """ describe the current model parameters"""
......
...@@ -8,7 +8,11 @@ import numpy as np ...@@ -8,7 +8,11 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import six import six
from . import logger from ..utils import logger
__all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'ParamRestore',
'dump_session_params']
class SessionInit(object): class SessionInit(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
import six import six
import tensorflow as tf import tensorflow as tf
from . import logger, get_global_step_var from ..utils import *
from .naming import * from . import get_global_step_var
def create_summary(name, v): def create_summary(name, v):
""" """
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta from abc import ABCMeta, abstractmethod
from six.moves import range from six.moves import range
import tqdm import tqdm
import re import re
...@@ -11,7 +11,8 @@ import re ...@@ -11,7 +11,8 @@ import re
from .config import TrainConfig from .config import TrainConfig
from ..utils import * from ..utils import *
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..utils.modelutils import describe_model from ..tfutils import *
from ..tfutils.modelutils import describe_model
__all__ = ['Trainer'] __all__ = ['Trainer']
......
...@@ -7,6 +7,7 @@ import tensorflow as tf ...@@ -7,6 +7,7 @@ import tensorflow as tf
from ..callbacks import Callbacks from ..callbacks import Callbacks
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import * from ..utils import *
from ..tfutils import *
from ..dataflow import DataFlow from ..dataflow import DataFlow
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
......
...@@ -11,7 +11,8 @@ from six.moves import zip ...@@ -11,7 +11,8 @@ from six.moves import zip
from .base import Trainer from .base import Trainer
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
from ..utils import * from ..utils import *
from ..utils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average
from ..tfutils import *
__all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train'] __all__ = ['SimpleTrainer', 'QueueInputTrainer', 'start_train']
......
...@@ -13,32 +13,32 @@ def global_import(name): ...@@ -13,32 +13,32 @@ def global_import(name):
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
global_import('naming') global_import('naming')
global_import('sessinit') #global_import('sessinit')
global_import('utils') global_import('utils')
# TODO move this utils to another file # TODO move this utils to another file
def get_default_sess_config(mem_fraction=0.5): #def get_default_sess_config(mem_fraction=0.5):
""" #"""
Return a better config to use as default. #Return a better config to use as default.
Tensorflow default session config consume too much resources #Tensorflow default session config consume too much resources
""" #"""
conf = tf.ConfigProto() #conf = tf.ConfigProto()
conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction #conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
conf.gpu_options.allocator_type = 'BFC' #conf.gpu_options.allocator_type = 'BFC'
conf.allow_soft_placement = True #conf.allow_soft_placement = True
return conf #return conf
def get_global_step_var(): #def get_global_step_var():
""" get global_step variable in the current graph""" #""" get global_step variable in the current graph"""
try: #try:
return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME) #return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
except KeyError: #except KeyError:
var = tf.Variable( #var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME) #0, trainable=False, name=GLOBAL_STEP_OP_NAME)
return var #return var
def get_global_step(): #def get_global_step():
""" get global_step value with current graph and session""" #""" get global_step value with current graph and session"""
return tf.train.global_step( #return tf.train.global_step(
tf.get_default_session(), #tf.get_default_session(),
get_global_step_var()) #get_global_step_var())
...@@ -6,10 +6,12 @@ import os, sys ...@@ -6,10 +6,12 @@ import os, sys
from contextlib import contextmanager from contextlib import contextmanager
import time import time
import collections import collections
import numpy as np
from . import logger from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized'] __all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized']
#def expand_dim_if_necessary(var, dp): #def expand_dim_if_necessary(var, dp):
# """ # """
# Args: # Args:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment