Commit ac9ac2a4 authored by Yuxin Wu's avatar Yuxin Wu

isort -y -sp tox.ini

parent 9c2be2ad
...@@ -32,7 +32,7 @@ It's Yet Another TF high-level API, with __speed__, and __flexibility__ built to ...@@ -32,7 +32,7 @@ It's Yet Another TF high-level API, with __speed__, and __flexibility__ built to
See [tutorials and documentations](http://tensorpack.readthedocs.io/tutorial/index.html#user-tutorials) to know more about these features. See [tutorials and documentations](http://tensorpack.readthedocs.io/tutorial/index.html#user-tutorials) to know more about these features.
## [Examples](examples): ## Examples:
We refuse toy examples. We refuse toy examples.
Instead of showing you 10 arbitrary networks trained on toy datasets, Instead of showing you 10 arbitrary networks trained on toy datasets,
......
...@@ -4,20 +4,18 @@ ...@@ -4,20 +4,18 @@
# Author: Yuxin Wu # Author: Yuxin Wu
import multiprocessing as mp import multiprocessing as mp
import time
import os import os
import threading import threading
from abc import abstractmethod, ABCMeta import time
from abc import ABCMeta, abstractmethod
from collections import defaultdict from collections import defaultdict
import six import six
from six.moves import queue
import zmq import zmq
from six.moves import queue
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.serialize import loads, dumps from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure_proc_terminate
from tensorpack.utils.concurrency import ( from tensorpack.utils.serialize import dumps, loads
LoopThread, ensure_proc_terminate, enable_death_signal)
__all__ = ['SimulatorProcess', 'SimulatorMaster', __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessStateExchange',
......
...@@ -3,29 +3,26 @@ ...@@ -3,29 +3,26 @@
# File: train-atari.py # File: train-atari.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np import numpy as np
import sys
import os import os
import sys
import uuid import uuid
import argparse
import cv2 import cv2
import tensorflow as tf import gym
import six import six
import tensorflow as tf
from six.moves import queue from six.moves import queue
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from tensorpack.utils.serialize import dumps
from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.serialize import dumps
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
import gym
from simulator import SimulatorProcess, SimulatorMaster, TransitionExperience
from common import Evaluator, eval_model_multithread, play_n_episodes from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import MapState, FrameStack, FireResetEnv, LimitLength from simulator import SimulatorMaster, SimulatorProcess, TransitionExperience
if six.PY3: if six.PY3:
from concurrent import futures from concurrent import futures
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: create-lmdb.py # File: create-lmdb.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np
import os import os
import scipy.io.wavfile as wavfile
import string import string
import numpy as np
import argparse
import bob.ap import bob.ap
import scipy.io.wavfile as wavfile
from tensorpack.dataflow import DataFlow, LMDBSerializer from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils import fs, logger, serialize
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils import serialize, fs, logger
from tensorpack.utils.utils import get_tqdm from tensorpack.utils.utils import get_tqdm
CHARSET = set(string.ascii_lowercase + ' ') CHARSET = set(string.ascii_lowercase + ' ')
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: timitdata.py # File: timitdata.py
# Author: Yuxin Wu # Author: Yuxin Wu
from tensorpack import ProxyDataFlow
import numpy as np import numpy as np
from six.moves import range from six.moves import range
from tensorpack import ProxyDataFlow
__all__ = ['TIMITBatch'] __all__ = ['TIMITBatch']
......
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
# File: train-timit.py # File: train-timit.py
# Author: Yuxin Wu # Author: Yuxin Wu
import os
import argparse import argparse
import os
import tensorflow as tf
from six.moves import range from six.moves import range
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip from tensorpack.tfutils.gradproc import GlobalNormClip, SummaryGradient
from tensorpack.utils import serialize from tensorpack.utils import serialize
import tensorflow as tf
from timitdata import TIMITBatch from timitdata import TIMITBatch
rnn = tf.contrib.rnn rnn = tf.contrib.rnn
......
...@@ -4,16 +4,16 @@ ...@@ -4,16 +4,16 @@
# Author: Yuxin Wu # Author: Yuxin Wu
from __future__ import print_function from __future__ import print_function
import argparse
import numpy as np import numpy as np
import os import os
import cv2 import cv2
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
import tensorflow as tf from tensorpack.tfutils.summary import *
from tensorpack.tfutils.symbolic_functions import *
def tower_func(image): def tower_func(image):
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
# File: load-cpm.py # File: load-cpm.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
import numpy as np
import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils import viz from tensorpack.utils import viz
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
""" """
15 channels: 15 channels:
0-1 head, neck 0-1 head, neck
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
# File: load-vgg16.py # File: load-vgg16.py
from __future__ import print_function from __future__ import print_function
import cv2 import argparse
import tensorflow as tf
import numpy as np import numpy as np
import os import os
import cv2
import six import six
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
......
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
# File: load-vgg19.py # File: load-vgg19.py
from __future__ import print_function from __future__ import print_function
import cv2 import argparse
import tensorflow as tf
import numpy as np import numpy as np
import os import os
import cv2
import six import six
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
......
...@@ -3,21 +3,20 @@ ...@@ -3,21 +3,20 @@
# File: char-rnn.py # File: char-rnn.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np import numpy as np
import operator
import os import os
import sys import sys
import argparse
from collections import Counter from collections import Counter
import operator
import six import six
import tensorflow as tf
from six.moves import range from six.moves import range
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import summary, optimizer from tensorpack.tfutils import optimizer, summary
from tensorpack.tfutils.gradproc import GlobalNormClip from tensorpack.tfutils.gradproc import GlobalNormClip
import tensorflow as tf
rnn = tf.contrib.rnn rnn = tf.contrib.rnn
class _NS: pass # noqa class _NS: pass # noqa
......
...@@ -3,20 +3,20 @@ ...@@ -3,20 +3,20 @@
# File: DQN.py # File: DQN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import os
import argparse import argparse
import cv2
import numpy as np import numpy as np
import tensorflow as tf import os
import cv2
import gym import gym
import tensorflow as tf
from tensorpack import * from tensorpack import *
from DQNModel import Model as DQNModel from atari import AtariPlayer
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
from common import Evaluator, eval_model_multithread, play_n_episodes from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength from DQNModel import Model as DQNModel
from expreplay import ExpReplay from expreplay import ExpReplay
from atari import AtariPlayer
BATCH_SIZE = 64 BATCH_SIZE = 64
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
import abc import abc
import tensorflow as tf import tensorflow as tf
from tensorpack import ModelDesc from tensorpack import ModelDesc
from tensorpack.utils import logger from tensorpack.tfutils import get_current_tower_context, gradproc, optimizer, summary, varreplace
from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils import logger
class Model(ModelDesc): class Model(ModelDesc):
......
...@@ -4,19 +4,18 @@ ...@@ -4,19 +4,18 @@
import numpy as np import numpy as np
import os import os
import cv2
import threading import threading
import six import cv2
from six.moves import range
from tensorpack.utils import logger
from tensorpack.utils.utils import get_rng, execute_only_once
from tensorpack.utils.fs import get_dataset_path
import gym import gym
import six
from ale_python_interface import ALEInterface
from gym import spaces from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING from gym.envs.atari.atari_env import ACTION_MEANING
from six.moves import range
from ale_python_interface import ALEInterface from tensorpack.utils import logger
from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.utils import execute_only_once, get_rng
__all__ = ['AtariPlayer'] __all__ = ['AtariPlayer']
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import numpy as np import numpy as np
from collections import deque from collections import deque
import gym import gym
from gym import spaces from gym import spaces
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: common.py # File: common.py
# Author: Yuxin Wu # Author: Yuxin Wu
import multiprocessing
import random import random
import time import time
import multiprocessing
from tqdm import tqdm
from six.moves import queue from six.moves import queue
from tqdm import tqdm
from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
from tensorpack.callbacks import Callback from tensorpack.callbacks import Callback
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_tqdm_kwargs from tensorpack.utils.utils import get_tqdm_kwargs
......
...@@ -2,18 +2,18 @@ ...@@ -2,18 +2,18 @@
# File: expreplay.py # File: expreplay.py
# Author: Yuxin Wu # Author: Yuxin Wu
import numpy as np
import copy import copy
from collections import deque, namedtuple import numpy as np
import threading import threading
from collections import deque, namedtuple
from six.moves import queue, range from six.moves import queue, range
from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.utils import get_tqdm, get_rng
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm
__all__ = ['ExpReplay'] __all__ = ['ExpReplay']
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: mnist-disturb.py # File: mnist-disturb.py
import os
import argparse import argparse
import imp
import os
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
import tensorflow as tf from tensorpack.utils import logger
from disturb import DisturbLabel from disturb import DisturbLabel
import imp
mnist_example = imp.load_source('mnist_example', mnist_example = imp.load_source('mnist_example',
os.path.join(os.path.dirname(__file__), '..', 'basics', 'mnist-convnet.py')) os.path.join(os.path.dirname(__file__), '..', 'basics', 'mnist-convnet.py'))
get_config = mnist_example.get_config get_config = mnist_example.get_config
......
...@@ -3,13 +3,12 @@ ...@@ -3,13 +3,12 @@
# File: svhn-disturb.py # File: svhn-disturb.py
import argparse import argparse
import os
import imp import imp
import os
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils import logger
from disturb import DisturbLabel from disturb import DisturbLabel
......
...@@ -3,24 +3,22 @@ ...@@ -3,24 +3,22 @@
# File: alexnet-dorefa.py # File: alexnet-dorefa.py
# Author: Yuxin Wu, Yuheng Zou ({wyx,zyh}@megvii.com) # Author: Yuxin Wu, Yuheng Zou ({wyx,zyh}@megvii.com)
import cv2
import tensorflow as tf
import argparse import argparse
import numpy as np import numpy as np
import os import os
import sys import sys
import cv2
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_param_summary from tensorpack.dataflow import dataset
from tensorpack.tfutils.sessinit import get_model_loader from tensorpack.tfutils.sessinit import get_model_loader
from tensorpack.tfutils.summary import add_param_summary
from tensorpack.tfutils.varreplace import remap_variables from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel, eval_on_ILSVRC12)
from dorefa import get_dorefa, ternarize from dorefa import get_dorefa, ternarize
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor, get_imagenet_dataflow
""" """
This is a tensorpack script for the ImageNet results in paper: This is a tensorpack script for the ImageNet results in paper:
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf import tensorflow as tf
from tensorpack.utils.argtools import graph_memoized from tensorpack.utils.argtools import graph_memoized
......
...@@ -2,18 +2,18 @@ ...@@ -2,18 +2,18 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: resnet-dorefa.py # File: resnet-dorefa.py
import cv2
import tensorflow as tf
import argparse import argparse
import numpy as np import numpy as np
import os import os
import cv2
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.varreplace import remap_variables from tensorpack.tfutils.varreplace import remap_variables
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
from dorefa import get_dorefa from dorefa import get_dorefa
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
""" """
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32) This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
# File: svhn-digit-dorefa.py # File: svhn-digit-dorefa.py
# Author: Yuxin Wu # Author: Yuxin Wu
import os
import argparse import argparse
import os
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.tfutils.varreplace import remap_variables from tensorpack.tfutils.varreplace import remap_variables
from dorefa import get_dorefa from dorefa import get_dorefa
......
...@@ -3,19 +3,18 @@ ...@@ -3,19 +3,18 @@
# File: steering-filter.py # File: steering-filter.py
import argparse import argparse
import multiprocessing
import numpy as np import numpy as np
import tensorflow as tf
import cv2 import cv2
import tensorflow as tf
from scipy.signal import convolve2d from scipy.signal import convolve2d
from six.moves import range, zip from six.moves import range, zip
import multiprocessing
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.viz import *
from tensorpack.utils.argtools import shape2d, shape4d from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.dataflow import dataset from tensorpack.utils.viz import *
BATCH = 32 BATCH = 32
SHAPE = 64 SHAPE = 64
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: basemodel.py # File: basemodel.py
from contextlib import contextmanager, ExitStack
import numpy as np import numpy as np
from contextlib import ExitStack, contextmanager
import tensorflow as tf import tensorflow as tf
from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
from tensorpack.tfutils import argscope from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, layer_register)
from config import config as cfg from config import config as cfg
......
...@@ -3,17 +3,16 @@ ...@@ -3,17 +3,16 @@
import numpy as np import numpy as np
import os import os
from termcolor import colored
from tabulate import tabulate
import tqdm import tqdm
from tabulate import tabulate
from termcolor import colored
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.timer import timed_operation
from tensorpack.utils.argtools import log_once from tensorpack.utils.argtools import log_once
from tensorpack.utils.timer import timed_operation
from config import config as cfg from config import config as cfg
__all__ = ['COCODetection', 'COCOMeta'] __all__ = ['COCODetection', 'COCOMeta']
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import numpy as np import numpy as np
import os import os
import pprint import pprint
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: data.py # File: data.py
import cv2
import numpy as np
import copy import copy
import numpy as np
import cv2
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import ( from tensorpack.dataflow import (
imgaug, TestDataSpeed, DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
MultiProcessMapDataZMQ, MultiThreadMapData,
MapDataComponent, DataFromList)
from tensorpack.utils import logger from tensorpack.utils import logger
# import tensorpack.utils.viz as tpviz from tensorpack.utils.argtools import log_once, memoized
from coco import COCODetection from coco import COCODetection
from common import (
CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, point8_to_box, segmentation_to_mask)
from config import config as cfg
from utils.generate_anchors import generate_anchors from utils.generate_anchors import generate_anchors
from utils.np_box_ops import area as np_area from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa from utils.np_box_ops import ioa as np_ioa
from common import (
DataFromListOfDict, CustomResize, filter_boxes_inside_shape, # import tensorpack.utils.viz as tpviz
box_to_point8, point8_to_box, segmentation_to_mask)
from config import config as cfg
try: try:
import pycocotools.mask as cocomask import pycocotools.mask as cocomask
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: eval.py # File: eval.py
import tqdm import itertools
import numpy as np
import os import os
from collections import namedtuple from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack from contextlib import ExitStack
import itertools
import numpy as np
import cv2 import cv2
from concurrent.futures import ThreadPoolExecutor import pycocotools.mask as cocomask
import tqdm
from tensorpack.utils.utils import get_tqdm_kwargs
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
import pycocotools.mask as cocomask
from tensorpack.utils.utils import get_tqdm_kwargs
from coco import COCOMeta from coco import COCOMeta
from common import CustomResize, clip_boxes from common import CustomResize, clip_boxes
......
...@@ -2,10 +2,10 @@ import tensorflow as tf ...@@ -2,10 +2,10 @@ import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context from tensorpack.tfutils import get_current_tower_context
from utils.box_ops import pairwise_iou
from model_box import clip_boxes
from model_frcnn import FastRCNNHead, BoxProposals, fastrcnn_outputs
from config import config as cfg from config import config as cfg
from model_box import clip_boxes
from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
from utils.box_ops import pairwise_iou
class CascadeRCNNHead(object): class CascadeRCNNHead(object):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import itertools
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import itertools
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.models import Conv2D, FixedUnPooling, MaxPooling, layer_register
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.tower import get_current_tower_context
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import ( from tensorpack.tfutils.summary import add_moving_summary
Conv2D, layer_register, FixedUnPooling, MaxPooling) from tensorpack.tfutils.tower import get_current_tower_context
from model_rpn import rpn_losses, generate_rpn_proposals from basemodel import GroupNorm
from config import config as cfg
from model_box import roi_align from model_box import roi_align
from model_rpn import generate_rpn_proposals, rpn_losses
from utils.box_ops import area as tf_area from utils.box_ops import area as tf_area
from config import config as cfg
from basemodel import GroupNorm
@layer_register(log_shape=True) @layer_register(log_shape=True)
......
...@@ -3,18 +3,17 @@ ...@@ -3,18 +3,17 @@
import tensorflow as tf import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.models import Conv2D, FullyConnected, layer_register
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.common import get_tf_version_tuple from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import ( from tensorpack.tfutils.summary import add_moving_summary
Conv2D, FullyConnected, layer_register)
from tensorpack.utils.argtools import memoized_method from tensorpack.utils.argtools import memoized_method
from basemodel import GroupNorm from basemodel import GroupNorm
from utils.box_ops import pairwise_iou
from model_box import encode_bbox_target, decode_bbox_target
from config import config as cfg from config import config as cfg
from model_box import decode_bbox_target, encode_bbox_target
from utils.box_ops import pairwise_iou
@under_name_scope() @under_name_scope()
......
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
import tensorflow as tf import tensorflow as tf
from tensorpack.models import ( from tensorpack.models import Conv2D, Conv2DTranspose, layer_register
Conv2D, layer_register, Conv2DTranspose)
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.common import get_tf_version_tuple from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from basemodel import GroupNorm from basemodel import GroupNorm
from config import config as cfg from config import config as cfg
......
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
import tensorflow as tf import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
from tensorpack.models import Conv2D, layer_register from tensorpack.models import Conv2D, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from model_box import clip_boxes
from config import config as cfg from config import config as cfg
from model_box import clip_boxes
@layer_register(log_shape=True) @layer_register(log_shape=True)
......
...@@ -2,58 +2,45 @@ ...@@ -2,58 +2,45 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: train.py # File: train.py
import os
import argparse import argparse
import cv2
import shutil
import itertools import itertools
import tqdm
import numpy as np
import json import json
import numpy as np
import os
import shutil
import cv2
import six import six
import tensorflow as tf import tensorflow as tf
try: import tqdm
import horovod.tensorflow as hvd
except ImportError:
pass
assert six.PY3, "FasterRCNN requires Python 3!"
import tensorpack.utils.viz as tpviz
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_tuple from tensorpack.tfutils.common import get_tf_version_tuple
import tensorpack.utils.viz as tpviz from tensorpack.tfutils.summary import add_moving_summary
from coco import COCODetection
from basemodel import (
image_preprocess, resnet_c4_backbone, resnet_conv5,
resnet_fpn_backbone)
import model_frcnn import model_frcnn
import model_mrcnn import model_mrcnn
from model_frcnn import ( from basemodel import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
sample_fast_rcnn_targets, fastrcnn_outputs, from coco import COCODetection
fastrcnn_predictions, BoxProposals, FastRCNNHead) from config import config as cfg
from model_mrcnn import maskrcnn_upXconv_head, maskrcnn_loss from config import finalize_configs
from model_rpn import rpn_head, rpn_losses, generate_rpn_proposals from data import get_all_anchors, get_all_anchors_fpn, get_eval_dataflow, get_train_dataflow
from model_fpn import ( from eval import DetectionResult, detect_one_image, eval_coco, multithread_eval_coco, print_coco_metrics
fpn_model, multilevel_roi_align, from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
multilevel_rpn_losses, generate_fpn_proposals)
from model_cascade import CascadeRCNNHead from model_cascade import CascadeRCNNHead
from model_box import ( from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
clip_boxes, crop_and_resize, roi_align, RPNAnchors) from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets
from model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
from data import ( from model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
get_train_dataflow, get_eval_dataflow, from viz import draw_annotation, draw_final_outputs, draw_predictions, draw_proposal_recall
get_all_anchors, get_all_anchors_fpn)
from viz import ( try:
draw_annotation, draw_proposal_recall, import horovod.tensorflow as hvd
draw_predictions, draw_final_outputs) except ImportError:
from eval import ( pass
eval_coco, multithread_eval_coco,
detect_one_image, print_coco_metrics, DetectionResult) assert six.PY3, "FasterRCNN requires Python 3!"
from config import finalize_configs, config as cfg
class DetectionModel(ModelDesc): class DetectionModel(ModelDesc):
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
# File: box_ops.py # File: box_ops.py
import tensorflow as tf import tensorflow as tf
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
""" """
This file is modified from This file is modified from
https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
# Written by Ross Girshick and Sean Bell # Written by Ross Girshick and Sean Bell
# -------------------------------------------------------- # --------------------------------------------------------
from six.moves import range
import numpy as np import numpy as np
from six.moves import range
# Verify that we compute the same anchors as Shaoqing's matlab implementation: # Verify that we compute the same anchors as Shaoqing's matlab implementation:
# #
...@@ -27,7 +27,7 @@ import numpy as np ...@@ -27,7 +27,7 @@ import numpy as np
# -79 -167 96 184 # -79 -167 96 184
# -167 -343 184 360 # -167 -343 184 360
#array([[ -83., -39., 100., 56.], # array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.], # [-175., -87., 192., 104.],
# [-359., -183., 376., 200.], # [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.], # [ -55., -55., 72., 72.],
...@@ -37,6 +37,7 @@ import numpy as np ...@@ -37,6 +37,7 @@ import numpy as np
# [ -79., -167., 96., 184.], # [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]]) # [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2], def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)): scales=2**np.arange(3, 6)):
""" """
...@@ -50,6 +51,7 @@ def generate_anchors(base_size=16, ratios=[0.5, 1, 2], ...@@ -50,6 +51,7 @@ def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
for i in range(ratio_anchors.shape[0])]) for i in range(ratio_anchors.shape[0])])
return anchors return anchors
def _whctrs(anchor): def _whctrs(anchor):
""" """
Return width, height, x center, and y center for an anchor (window). Return width, height, x center, and y center for an anchor (window).
...@@ -61,6 +63,7 @@ def _whctrs(anchor): ...@@ -61,6 +63,7 @@ def _whctrs(anchor):
y_ctr = anchor[1] + 0.5 * (h - 1) y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr): def _mkanchors(ws, hs, x_ctr, y_ctr):
""" """
Given a vector of widths (ws) and heights (hs) around a center Given a vector of widths (ws) and heights (hs) around a center
...@@ -75,6 +78,7 @@ def _mkanchors(ws, hs, x_ctr, y_ctr): ...@@ -75,6 +78,7 @@ def _mkanchors(ws, hs, x_ctr, y_ctr):
y_ctr + 0.5 * (hs - 1))) y_ctr + 0.5 * (hs - 1)))
return anchors return anchors
def _ratio_enum(anchor, ratios): def _ratio_enum(anchor, ratios):
""" """
Enumerate a set of anchors for each aspect ratio wrt an anchor. Enumerate a set of anchors for each aspect ratio wrt an anchor.
...@@ -88,6 +92,7 @@ def _ratio_enum(anchor, ratios): ...@@ -88,6 +92,7 @@ def _ratio_enum(anchor, ratios):
anchors = _mkanchors(ws, hs, x_ctr, y_ctr) anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors return anchors
def _scale_enum(anchor, scales): def _scale_enum(anchor, scales):
""" """
Enumerate a set of anchors for each scale wrt an anchor. Enumerate a set of anchors for each scale wrt an anchor.
...@@ -98,17 +103,3 @@ def _scale_enum(anchor, scales): ...@@ -98,17 +103,3 @@ def _scale_enum(anchor, scales):
hs = h * scales hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr) anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors return anchors
if __name__ == '__main__':
#import time
#t = time.time()
#a = generate_anchors()
#print(time.time() - t)
#print(a)
#from IPython import embed; embed()
anchors = generate_anchors(
16, scales=np.asarray((2, 4, 8, 16, 32), 'float32'),
ratios=[0.5,1,2])
print(anchors)
import IPython as IP; IP.embed()
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: viz.py # File: viz.py
from six.moves import zip
import numpy as np import numpy as np
from six.moves import zip
from tensorpack.utils import viz from tensorpack.utils import viz
from tensorpack.utils.palette import PALETTE_RGB from tensorpack.utils.palette import PALETTE_RGB
from utils.np_box_ops import iou as np_iou
from config import config as cfg from config import config as cfg
from utils.np_box_ops import iou as np_iou
def draw_annotation(img, boxes, klass, is_crowd=None): def draw_annotation(img, boxes, klass, is_crowd=None):
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
# File: BEGAN.py # File: BEGAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
import DCGAN
from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer
""" """
...@@ -19,7 +21,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/ ...@@ -19,7 +21,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/
""" """
import DCGAN
NH = 64 NH = 64
NF = 64 NF = 64
GAMMA = 0.5 GAMMA = 0.5
......
...@@ -3,18 +3,18 @@ ...@@ -3,18 +3,18 @@
# File: ConditionalGAN-mnist.py # File: ConditionalGAN-mnist.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np import numpy as np
import tensorflow as tf
import os import os
import cv2 import cv2
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import interactive_imshow, stack_patches
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from GAN import GANTrainer, RandomZData, GANModelDesc from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.viz import interactive_imshow, stack_patches
from GAN import GANModelDesc, GANTrainer, RandomZData
""" """
To train: To train:
......
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
# File: CycleGAN.py # File: CycleGAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import os
import argparse import argparse
import glob import glob
import os
import tensorflow as tf
from six.moves import range from six.moves import range
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf from tensorpack.tfutils.summary import add_moving_summary
from GAN import GANTrainer, GANModelDesc
from GAN import GANModelDesc, GANTrainer
""" """
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train 1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
......
...@@ -3,18 +3,18 @@ ...@@ -3,18 +3,18 @@
# File: DCGAN.py # File: DCGAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import glob import glob
import numpy as np import numpy as np
import os import os
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf from tensorpack.utils.viz import stack_patches
from GAN import GANModelDesc, GANTrainer, RandomZData
from GAN import GANTrainer, RandomZData, GANModelDesc
""" """
1. Download the 'aligned&cropped' version of CelebA dataset 1. Download the 'aligned&cropped' version of CelebA dataset
......
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
# File: DiscoGAN-CelebA.py # File: DiscoGAN-CelebA.py
# Author: Yuxin Wu # Author: Yuxin Wu
import os
import argparse import argparse
from six.moves import map, zip
import numpy as np import numpy as np
import os
import tensorflow as tf
from six.moves import map, zip
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf from tensorpack.tfutils.summary import add_moving_summary
from GAN import SeparateGANTrainer, GANModelDesc
from GAN import GANModelDesc, SeparateGANTrainer
""" """
1. Download "aligned&cropped" version of celebA to /path/to/img_align_celeba. 1. Download "aligned&cropped" version of celebA to /path/to/img_align_celeba.
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# File: GAN.py # File: GAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf
import numpy as np import numpy as np
from tensorpack import (TowerTrainer, StagingInput, import tensorflow as tf
ModelDescBase, DataFlow, argscope, BatchNorm)
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper from tensorpack import BatchNorm, DataFlow, ModelDescBase, StagingInput, TowerTrainer, argscope
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized_method from tensorpack.utils.argtools import memoized_method
from tensorpack.utils.develop import deprecated from tensorpack.utils.develop import deprecated
......
...@@ -3,20 +3,20 @@ ...@@ -3,20 +3,20 @@
# File: Image2Image.py # File: Image2Image.py
# Author: Yuxin Wu # Author: Yuxin Wu
import cv2 import argparse
import numpy as np
import tensorflow as tf
import glob import glob
import numpy as np
import os import os
import argparse import cv2
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.viz import stack_patches from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from GAN import GANModelDesc, GANTrainer
from GAN import GANTrainer, GANModelDesc
""" """
To train Image-to-Image translation model with image pairs: To train Image-to-Image translation model with image pairs:
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
# File: Improved-WGAN.py # File: Improved-WGAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import get_tf_version_tuple from tensorpack.tfutils import get_tf_version_tuple
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf from tensorpack.tfutils.summary import add_moving_summary
import DCGAN
from GAN import SeparateGANTrainer from GAN import SeparateGANTrainer
""" """
...@@ -18,7 +20,6 @@ See the docstring in DCGAN.py for usage. ...@@ -18,7 +20,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code. # Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN. # So here just import stuff from DCGAN.
import DCGAN
class Model(DCGAN.Model): class Model(DCGAN.Model):
......
...@@ -3,19 +3,19 @@ ...@@ -3,19 +3,19 @@
# File: InfoGAN-mnist.py # File: InfoGAN-mnist.py
# Author: Yuxin Wu # Author: Yuxin Wu
import cv2 import argparse
import numpy as np import numpy as np
import tensorflow as tf
import os import os
import argparse import cv2
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.utils import viz
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils import optimizer, summary, gradproc
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from GAN import GANTrainer, GANModelDesc from tensorpack.tfutils import gradproc, optimizer, summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.utils import viz
from GAN import GANModelDesc, GANTrainer
""" """
To train: To train:
......
...@@ -3,9 +3,12 @@ ...@@ -3,9 +3,12 @@
# File: WGAN.py # File: WGAN.py
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
import tensorflow as tf
import DCGAN
from GAN import SeparateGANTrainer from GAN import SeparateGANTrainer
""" """
...@@ -15,7 +18,6 @@ See the docstring in DCGAN.py for usage. ...@@ -15,7 +18,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code. # Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN # So here just import stuff from DCGAN
import DCGAN
class Model(DCGAN.Model): class Model(DCGAN.Model):
......
...@@ -3,19 +3,18 @@ ...@@ -3,19 +3,18 @@
# File: hed.py # File: hed.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np
import os
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
import numpy as np
import argparse
from six.moves import zip from six.moves import zip
import os
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_num_gpu from tensorpack.tfutils import gradproc, optimizer
from tensorpack.tfutils import optimizer, gradproc
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.utils.gpu import get_num_gpu
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'): def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
......
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
# File: alexnet.py # File: alexnet.py
import argparse import argparse
import numpy as np
import os import os
import cv2 import cv2
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
......
...@@ -2,24 +2,22 @@ ...@@ -2,24 +2,22 @@
# File: imagenet_utils.py # File: imagenet_utils.py
import cv2
import os
import numpy as np
import tqdm
import multiprocessing import multiprocessing
import tensorflow as tf import numpy as np
import os
from abc import abstractmethod from abc import abstractmethod
import cv2
import tensorflow as tf
import tqdm
from tensorpack import ModelDesc from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.dataflow import (
imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ,
BatchData, MultiThreadMapData)
from tensorpack.predict import PredictConfig, FeedfreePredictor
from tensorpack.utils.stats import RatioCounter
from tensorpack.models import regularize_cost from tensorpack.models import regularize_cost
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.stats import RatioCounter
""" """
......
...@@ -7,10 +7,9 @@ import argparse ...@@ -7,10 +7,9 @@ import argparse
import os import os
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow
......
...@@ -3,24 +3,20 @@ ...@@ -3,24 +3,20 @@
# File: shufflenet.py # File: shufflenet.py
import argparse import argparse
import numpy as np
import math import math
import numpy as np
import os import os
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader, model_utils from tensorpack.tfutils import argscope, get_model_loader, model_utils
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ( from imagenet_utils import GoogleNetResize, ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
get_imagenet_dataflow,
ImageNetModel, GoogleNetResize, eval_on_ILSVRC12)
@layer_register(log_shape=True) @layer_register(log_shape=True)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import argparse import argparse
import os import os
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
...@@ -12,8 +11,7 @@ from tensorpack.tfutils import argscope ...@@ -12,8 +11,7 @@ from tensorpack.tfutils import argscope
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ( from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataflow
ImageNetModel, get_imagenet_dataflow, fbresnet_augmentor)
def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)): def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com> # Author: Patrick Wieschollek <mail@patwie.com>
import argparse
import glob
import os import os
import cv2 import cv2
import glob
from helper import Flow
import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils import viz from tensorpack.utils import viz
import flownet_models as models import flownet_models as models
from helper import Flow
def apply(model, model_path, left, right, ground_truth=None): def apply(model, model_path, left, right, ground_truth=None):
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import tensorflow as tf import tensorflow as tf
from tensorpack import ModelDesc, argscope, enable_argscope_for_module from tensorpack import ModelDesc, argscope, enable_argscope_for_module
enable_argscope_for_module(tf.layers) enable_argscope_for_module(tf.layers)
......
...@@ -3,21 +3,20 @@ ...@@ -3,21 +3,20 @@
# File: PTB-LSTM.py # File: PTB-LSTM.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np import numpy as np
import os import os
import argparse import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils import optimizer, summary, gradproc from tensorpack.tfutils import gradproc, optimizer, summary
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.fs import download, get_dataset_path
from tensorpack.utils.argtools import memoized_ignoreargs from tensorpack.utils.argtools import memoized_ignoreargs
from tensorpack.utils.fs import download, get_dataset_path
import reader as tfreader import reader as tfreader
from reader import ptb_producer from reader import ptb_producer
import tensorflow as tf
rnn = tf.contrib.rnn rnn = tf.contrib.rnn
SEQ_LEN = 35 SEQ_LEN = 35
......
...@@ -16,13 +16,9 @@ ...@@ -16,13 +16,9 @@
"""Utilities for parsing PTB text files.""" """Utilities for parsing PTB text files."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import collections import collections
import os import os
import tensorflow as tf import tensorflow as tf
......
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
# File: cifar10-preact18-mixup.py # File: cifar10-preact18-mixup.py
# Author: Tao Hu <taohu620@gmail.com>, Yauheni Selivonchyk <y.selivonchyk@gmail.com> # Author: Tao Hu <taohu620@gmail.com>, Yauheni Selivonchyk <y.selivonchyk@gmail.com>
import numpy as np
import argparse import argparse
import numpy as np
import os import os
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import *
BATCH_SIZE = 128 BATCH_SIZE = 128
CLASS_NUM = 10 CLASS_NUM = 10
......
...@@ -5,14 +5,12 @@ ...@@ -5,14 +5,12 @@
import argparse import argparse
import os import os
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.dataflow import dataset
import tensorflow as tf
""" """
CIFAR10 ResNet example. See: CIFAR10 ResNet example. See:
......
...@@ -5,22 +5,18 @@ ...@@ -5,22 +5,18 @@
import argparse import argparse
import os import os
from tensorpack import logger, QueueInput, TFDatasetInput from tensorpack import QueueInput, TFDatasetInput, logger
from tensorpack.models import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.train import (
TrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config)
from tensorpack.dataflow import FakeData from tensorpack.dataflow import FakeData
from tensorpack.models import *
from tensorpack.tfutils import argscope, get_model_loader from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ( from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow, get_imagenet_tfdata
get_imagenet_dataflow, get_imagenet_tfdata,
ImageNetModel, eval_on_ILSVRC12)
from resnet_model import ( from resnet_model import (
preresnet_group, preresnet_basicblock, preresnet_bottleneck, preresnet_basicblock, preresnet_bottleneck, preresnet_group, resnet_backbone, resnet_basicblock, resnet_bottleneck,
resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck, resnet_group, se_resnet_bottleneck)
resnet_backbone)
class Model(ImageNetModel): class Model(ImageNetModel):
......
...@@ -4,20 +4,20 @@ ...@@ -4,20 +4,20 @@
# Author: Eric Yujia Huang <yujiah1@andrew.cmu.edu> # Author: Eric Yujia Huang <yujiah1@andrew.cmu.edu>
# Yuxin Wu # Yuxin Wu
import cv2
import functools
import tensorflow as tf
import argparse import argparse
import re import functools
import numpy as np import numpy as np
import re
import cv2
import six import six
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
from tensorpack.utils import logger
from imagenet_utils import eval_on_ILSVRC12, get_imagenet_dataflow, ImageNetModel from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from resnet_model import resnet_group, resnet_bottleneck from resnet_model import resnet_bottleneck, resnet_group
DEPTH = None DEPTH = None
CFG = { CFG = {
......
...@@ -3,9 +3,8 @@ ...@@ -3,9 +3,8 @@
import tensorflow as tf import tensorflow as tf
from tensorpack.models import BatchNorm, BNReLU, Conv2D, FullyConnected, GlobalAvgPooling, MaxPooling
from tensorpack.tfutils.argscope import argscope, get_arg_scope from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.models import (
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected)
def resnet_shortcut(l, n_out, stride, activation=tf.identity): def resnet_shortcut(l, n_out, stride, activation=tf.identity):
......
...@@ -2,28 +2,24 @@ ...@@ -2,28 +2,24 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: CAM-resnet.py # File: CAM-resnet.py
import cv2
import sys
import argparse import argparse
import multiprocessing
import numpy as np import numpy as np
import os import os
import multiprocessing import sys
import cv2
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer, gradproc from tensorpack.tfutils import gradproc, optimizer
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu from tensorpack.tfutils.symbolic_functions import *
from tensorpack.utils import viz from tensorpack.utils import viz
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import ( from imagenet_utils import ImageNetModel, fbresnet_augmentor
fbresnet_augmentor, ImageNetModel) from resnet_model import preresnet_basicblock, preresnet_group
from resnet_model import (
preresnet_basicblock, preresnet_group)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
DEPTH = None DEPTH = None
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import cv2 import numpy as np
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np import cv2
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1 from tensorflow.contrib.slim.nets import resnet_v1
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# Author: tensorpack contributors # Author: tensorpack contributors
import numpy as np import numpy as np
from tensorpack.dataflow import dataset, BatchData
from tensorpack.dataflow import BatchData, dataset
def get_test_data(batch=128): def get_test_data(batch=128):
......
...@@ -2,17 +2,16 @@ ...@@ -2,17 +2,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: mnist-embeddings.py # File: mnist-embeddings.py
import numpy as np
import argparse import argparse
import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import change_gpu from tensorpack.utils.gpu import change_gpu
from embedding_data import get_test_data, MnistPairs, MnistTriplets from embedding_data import MnistPairs, MnistTriplets, get_test_data
MATPLOTLIB_AVAIBLABLE = False MATPLOTLIB_AVAIBLABLE = False
try: try:
......
...@@ -3,16 +3,15 @@ ...@@ -3,16 +3,15 @@
# File: mnist-addition.py # File: mnist-addition.py
# Author: Yuxin Wu # Author: Yuxin Wu
import cv2 import argparse
import numpy as np import numpy as np
import tensorflow as tf
import os import os
import argparse import cv2
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer, summary, gradproc from tensorpack.tfutils import gradproc, optimizer, summary
IMAGE_SIZE = 42 IMAGE_SIZE = 42
WARP_TARGET_SIZE = 28 WARP_TARGET_SIZE = 28
......
import cv2
import os
import argparse import argparse
import numpy as np import numpy as np
import os
import zipfile import zipfile
from tensorpack import RNGDataFlow, MapDataComponent, LMDBSerializer import cv2
from tensorpack import LMDBSerializer, MapDataComponent, RNGDataFlow
class ImageDataFromZIPFile(RNGDataFlow): class ImageDataFromZIPFile(RNGDataFlow):
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com> # Author: Patrick Wieschollek <mail@patwie.com>
import os
import argparse import argparse
import numpy as np
import os
import cv2 import cv2
import six import six
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
...@@ -14,10 +14,10 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope ...@@ -14,10 +14,10 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from data_sampler import (
ImageDecode, ImageDataFromZIPFile, from data_sampler import CenterSquareResize, ImageDataFromZIPFile, ImageDecode, RejectTooSmallImages
RejectTooSmallImages, CenterSquareResize) from GAN import GANModelDesc, SeparateGANTrainer
from GAN import SeparateGANTrainer, GANModelDesc
Reduction = tf.losses.Reduction Reduction = tf.losses.Reduction
BATCH_SIZE = 16 BATCH_SIZE = 16
......
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: cifar-convnet.py # File: cifar-convnet.py
# Author: Yuxin Wu # Author: Yuxin Wu
import tensorflow as tf
import argparse import argparse
import os import os
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
""" """
A small convnet model for Cifar10 or Cifar100 dataset. A small convnet model for Cifar10 or Cifar100 dataset.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import argparse import argparse
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.export import ModelExporter from tensorpack.tfutils.export import ModelExporter
......
...@@ -3,17 +3,16 @@ ...@@ -3,17 +3,16 @@
# File: mnist-convnet.py # File: mnist-convnet.py
import tensorflow as tf import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import summary
""" """
MNIST ConvNet example. MNIST ConvNet example.
about 0.6% validation error after 30 epochs. about 0.6% validation error after 30 epochs.
""" """
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28 IMAGE_SIZE = 28
......
...@@ -3,6 +3,11 @@ ...@@ -3,6 +3,11 @@
# File: mnist-tflayers.py # File: mnist-tflayers.py
import tensorflow as tf import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import get_current_tower_context, summary
""" """
MNIST ConvNet example using tf.layers MNIST ConvNet example using tf.layers
Mostly the same as 'mnist-convnet.py', Mostly the same as 'mnist-convnet.py',
...@@ -11,12 +16,6 @@ the only differences are: ...@@ -11,12 +16,6 @@ the only differences are:
2. use tf.layers variable names to summarize weights 2. use tf.layers variable names to summarize weights
""" """
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary, get_current_tower_context
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28 IMAGE_SIZE = 28
# Monkey-patch tf.layers to support argscope. # Monkey-patch tf.layers to support argscope.
enable_argscope_for_module(tf.layers) enable_argscope_for_module(tf.layers)
......
...@@ -11,11 +11,12 @@ the only differences are: ...@@ -11,11 +11,12 @@ the only differences are:
""" """
from tensorpack import *
from tensorpack.dataflow import dataset
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
from tensorpack import *
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28 IMAGE_SIZE = 28
......
...@@ -7,6 +7,7 @@ The same MNIST ConvNet example, but with weights/activations visualization. ...@@ -7,6 +7,7 @@ The same MNIST ConvNet example, but with weights/activations visualization.
""" """
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
import argparse import argparse
import os import os
import tensorflow as tf
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
import tensorflow as tf
""" """
A very small SVHN convnet model (only 0.8m parameters). A very small SVHN convnet model (only 0.8m parameters).
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Author: Your Name <your@email.com> # Author: Your Name <your@email.com>
import os
import argparse import argparse
import os
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import *
......
...@@ -3,22 +3,21 @@ ...@@ -3,22 +3,21 @@
# File: imagenet-resnet-keras.py # File: imagenet-resnet-keras.py
# Author: Yuxin Wu # Author: Yuxin Wu
import argparse
import numpy as np import numpy as np
import os import os
import tensorflow as tf import tensorflow as tf
import argparse from tensorflow.python.keras.layers import *
from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated
from tensorpack.callbacks import *
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import FakeData, MapDataComponent from tensorpack.dataflow import FakeData, MapDataComponent
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu from tensorpack.utils.gpu import get_num_gpu
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import *
from tensorflow.python.keras.layers import *
from tensorpack.tfutils.common import get_tf_version_tuple
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow
TOTAL_BATCH_SIZE = 512 TOTAL_BATCH_SIZE = 512
BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256) BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256)
......
...@@ -5,17 +5,15 @@ ...@@ -5,17 +5,15 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow import keras from tensorflow import keras
KL = keras.layers
from tensorpack import InputDesc, QueueInput from tensorpack import InputDesc, QueueInput
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import ModelSaver from tensorpack.callbacks import ModelSaver
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import BatchData, MapData, dataset
from tensorpack.utils import logger
KL = keras.layers
IMAGE_SIZE = 28 IMAGE_SIZE = 28
......
...@@ -5,6 +5,12 @@ ...@@ -5,6 +5,12 @@
import tensorflow as tf import tensorflow as tf
from tensorflow import keras from tensorflow import keras
from tensorpack import *
from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
KL = keras.layers KL = keras.layers
""" """
...@@ -14,12 +20,6 @@ This way you can define models in Keras-style, and benefit from the more efficei ...@@ -14,12 +20,6 @@ This way you can define models in Keras-style, and benefit from the more efficei
Note: this example does not work for replicated-style data-parallel trainers. Note: this example does not work for replicated-style data-parallel trainers.
""" """
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
from tensorpack.contrib.keras import KerasPhaseCallback
IMAGE_SIZE = 28 IMAGE_SIZE = 28
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
# File: checkpoint-manipulate.py # File: checkpoint-manipulate.py
import argparse
import numpy as np import numpy as np
from tensorpack.tfutils.varmanip import load_chkpt_vars from tensorpack.tfutils.varmanip import load_chkpt_vars
from tensorpack.utils import logger from tensorpack.utils import logger
import argparse
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: checkpoint-prof.py # File: checkpoint-prof.py
import tensorflow as tf import argparse
import numpy as np import numpy as np
import tensorflow as tf
from tensorpack import get_default_sess_config, get_op_tensor_name from tensorpack import get_default_sess_config, get_op_tensor_name
from tensorpack.utils import logger
from tensorpack.tfutils.sessinit import get_model_loader from tensorpack.tfutils.sessinit import get_model_loader
import argparse from tensorpack.utils import logger
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: dump-model-params.py # File: dump-model-params.py
import numpy as np
import six
import argparse import argparse
import numpy as np
import os import os
import six
import tensorflow as tf import tensorflow as tf
from tensorpack.tfutils import varmanip from tensorpack.tfutils import varmanip
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: ls-checkpoint.py # File: ls-checkpoint.py
import tensorflow as tf
import numpy as np import numpy as np
import six
import sys
import pprint import pprint
import sys
import six
import tensorflow as tf
from tensorpack.tfutils.varmanip import get_checkpoint_path from tensorpack.tfutils.varmanip import get_checkpoint_path
......
import platform
from os import path
import setuptools import setuptools
from setuptools import setup from setuptools import setup
from os import path
import platform
version = int(setuptools.__version__.split('.')[0]) version = int(setuptools.__version__.split('.')[0])
assert version > 30, "Tensorpack installation requires setuptools > 30" assert version > 30, "Tensorpack installation requires setuptools > 30"
...@@ -24,7 +24,7 @@ def add_git_version(): ...@@ -24,7 +24,7 @@ def add_git_version():
from subprocess import check_output from subprocess import check_output
try: try:
return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip() return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip()
except: except Exception:
return __version__ # noqa return __version__ # noqa
newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')] newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')]
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
# File: base.py # File: base.py
import tensorflow as tf
from abc import ABCMeta from abc import ABCMeta
import six import six
import tensorflow as tf
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
import multiprocessing as mp import multiprocessing as mp
from .base import Callback
from ..utils.concurrency import start_proc_mask_signal, StoppableThread
from ..utils import logger from ..utils import logger
from ..utils.concurrency import StoppableThread, start_proc_mask_signal
from .base import Callback
__all__ = ['StartProcOrThread'] __all__ = ['StartProcOrThread']
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
""" Graph related callbacks""" """ Graph related callbacks"""
import tensorflow as tf
import os
import numpy as np import numpy as np
import os
import tensorflow as tf
from six.moves import zip from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
from .base import Callback from .base import Callback
from ..tfutils.common import get_op_tensor_name
__all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors',
'DumpTensor', 'DumpTensorAsImage', 'DumpParamAsImage'] 'DumpTensor', 'DumpTensorAsImage', 'DumpParamAsImage']
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# File: group.py # File: group.py
import tensorflow as tf import traceback
from contextlib import contextmanager from contextlib import contextmanager
from time import time as timer from time import time as timer
import traceback
import six import six
import tensorflow as tf
from .base import Callback
from .hooks import CallbackToHook
from ..utils import logger from ..utils import logger
from ..utils.utils import humanize_time_delta from ..utils.utils import humanize_time_delta
from .base import Callback
from .hooks import CallbackToHook
if six.PY3: if six.PY3:
from time import perf_counter as timer # noqa from time import perf_counter as timer # noqa
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
""" Compatible layers between tf.train.SessionRunHook and Callback""" """ Compatible layers between tf.train.SessionRunHook and Callback"""
import tensorflow as tf import tensorflow as tf
from .base import Callback from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback'] __all__ = ['CallbackToHook', 'HookToCallback']
......
...@@ -7,10 +7,10 @@ from abc import ABCMeta ...@@ -7,10 +7,10 @@ from abc import ABCMeta
import six import six
from six.moves import zip from six.moves import zip
from .base import Callback
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.stats import BinaryStatistics, RatioCounter
from .base import Callback
__all__ = ['ScalarStats', 'Inferencer', __all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats'] 'ClassificationError', 'BinaryClassificationStats']
......
...@@ -2,24 +2,19 @@ ...@@ -2,24 +2,19 @@
# File: inference_runner.py # File: inference_runner.py
import sys
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
import itertools import itertools
import sys
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf
import tqdm import tqdm
from six.moves import range from six.moves import range
from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..utils import logger
from ..input_source import ( from ..utils.utils import get_tqdm_kwargs
InputSource, FeedInput, QueueInput, StagingInput)
from .base import Callback from .base import Callback
from .group import Callbacks from .group import Callbacks
from .inference import Inferencer from .inference import Inferencer
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# File: misc.py # File: misc.py
import numpy as np
import os import os
import time import time
from collections import deque from collections import deque
import numpy as np
from .base import Callback
from ..utils.utils import humanize_time_delta
from ..utils import logger from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback
__all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft'] __all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft']
......
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
# File: monitor.py # File: monitor.py
import os import json
import numpy as np import numpy as np
import operator
import os
import re
import shutil import shutil
import time import time
from datetime import datetime
import operator
from collections import defaultdict from collections import defaultdict
from datetime import datetime
import six import six
import json
import re
import tensorflow as tf import tensorflow as tf
from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger from ..utils import logger
from ..tfutils.summary import create_scalar_summary, create_image_summary
from ..utils.develop import HIDE_DOC from ..utils.develop import HIDE_DOC
from .base import Callback from .base import Callback
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# File: param.py # File: param.py
import tensorflow as tf
from collections import deque
from abc import abstractmethod, ABCMeta
import operator import operator
import six
import os import os
from abc import ABCMeta, abstractmethod
from collections import deque
import six
import tensorflow as tf
from .base import Callback
from ..utils import logger
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from .base import Callback
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam', __all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter', 'HyperParamSetter', 'HumanHyperParamSetter',
......
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
# File: prof.py # File: prof.py
import os
import numpy as np
import multiprocessing as mp import multiprocessing as mp
import numpy as np
import os
import time import time
from six.moves import map
import tensorflow as tf import tensorflow as tf
from six.moves import map
from tensorflow.python.client import timeline from tensorflow.python.client import timeline
from .base import Callback from ..tfutils.common import gpu_available_in_session
from ..utils import logger from ..utils import logger
from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from ..utils.gpu import get_num_gpu from ..utils.gpu import get_num_gpu
from ..utils.nvml import NVMLContext from ..utils.nvml import NVMLContext
from ..tfutils.common import gpu_available_in_session from .base import Callback
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker'] __all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker']
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
# File: saver.py # File: saver.py
import tensorflow as tf
from datetime import datetime
import os import os
from datetime import datetime
import tensorflow as tf
from .base import Callback
from ..utils import logger from ..utils import logger
from .base import Callback
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: stats.py # File: stats.py
from .graph import DumpParamAsImage # noqa
# for compatibility only # for compatibility only
from .misc import InjectShell, SendStat # noqa from .misc import InjectShell, SendStat # noqa
from .graph import DumpParamAsImage # noqa
__all__ = [] __all__ = []
...@@ -4,14 +4,13 @@ ...@@ -4,14 +4,13 @@
""" Some common step callbacks. """ """ Some common step callbacks. """
import tensorflow as tf import tensorflow as tf
from six.moves import zip
import tqdm import tqdm
from six.moves import zip
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import ( from ..utils.utils import get_tqdm_kwargs
get_op_tensor_name, get_global_step_var)
from .base import Callback from .base import Callback
__all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout'] __all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout']
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# File: summary.py # File: summary.py
import tensorflow as tf
import numpy as np import numpy as np
from collections import deque from collections import deque
import tensorflow as tf
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# File: trigger.py # File: trigger.py
from .base import ProxyCallback, Callback
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from .base import Callback, ProxyCallback
__all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf'] __all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: keras.py # File: keras.py
import tensorflow as tf from contextlib import contextmanager
import six import six
from tensorflow import keras import tensorflow as tf
import tensorflow.keras.backend as K import tensorflow.keras.backend as K
from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import metrics as metrics_module
from contextlib import contextmanager
from ..callbacks import Callback, CallbackToHook, InferenceRunner, InferenceRunnerBase, ScalarStats
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase
from ..train.interface import apply_default_prefetch
from ..callbacks import (
Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook,
ScalarStats)
from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import backup_collection, restore_collection from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.tower import get_current_tower_context from ..tfutils.common import get_op_tensor_name
from ..tfutils.scope_utils import cached_name_scope from ..tfutils.scope_utils import cached_name_scope
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu from ..tfutils.tower import get_current_tower_context
from ..train import SimpleTrainer, SyncMultiGPUTrainerParameterServer, Trainer
from ..train.interface import apply_default_prefetch
from ..train.trainers import DistributedTrainerBase
from ..utils import logger from ..utils import logger
from ..utils.gpu import get_nr_gpu
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel'] __all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import threading import threading
from abc import abstractmethod, ABCMeta from abc import ABCMeta, abstractmethod
import six import six
from ..utils.utils import get_rng from ..utils.utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
......
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
# File: common.py # File: common.py
from __future__ import division from __future__ import division
import six import itertools
import numpy as np import numpy as np
from copy import copy
import pprint import pprint
import itertools from collections import defaultdict, deque
from termcolor import colored from copy import copy
from collections import deque, defaultdict import six
from six.moves import range, map
import tqdm import tqdm
from six.moves import map, range
from termcolor import colored
from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm, get_rng, get_tqdm_kwargs
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData', __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData', 'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData',
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# File: bsds500.py # File: bsds500.py
import os
import glob import glob
import numpy as np import numpy as np
import os
from ...utils.fs import download, get_dataset_path from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# Yukun Chen <cykustc@gmail.com> # Yukun Chen <cykustc@gmail.com>
import numpy as np
import os import os
import pickle import pickle
import numpy as np
import tarfile import tarfile
import six import six
from six.moves import range from six.moves import range
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: ilsvrc.py # File: ilsvrc.py
import numpy as np
import os import os
import tarfile import tarfile
import numpy as np
import tqdm import tqdm
from ...utils import logger from ...utils import logger
from ...utils.fs import download, get_dataset_path, mkdir_p
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download, get_dataset_path
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
from ..base import RNGDataFlow from ..base import RNGDataFlow
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# File: mnist.py # File: mnist.py
import os
import gzip import gzip
import numpy import numpy
import os
from six.moves import range from six.moves import range
from ...utils import logger from ...utils import logger
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# File: svhn.py # File: svhn.py
import os
import numpy as np import numpy as np
import os
from ...utils import logger from ...utils import logger
from ...utils.fs import get_dataset_path, download from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['SVHNDigit'] __all__ = ['SVHNDigit']
......
...@@ -3,15 +3,13 @@ ...@@ -3,15 +3,13 @@
from ..utils.develop import deprecated from ..utils.develop import deprecated
from .remote import dump_dataflow_to_process_queue
from .serialize import LMDBSerializer, TFRecordSerializer from .serialize import LMDBSerializer, TFRecordSerializer
__all__ = ['dump_dataflow_to_process_queue', __all__ = ['dump_dataflow_to_process_queue',
'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord'] 'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']
from .remote import dump_dataflow_to_process_queue
@deprecated("Use LMDBSerializer.save instead!", "2019-01-31") @deprecated("Use LMDBSerializer.save instead!", "2019-01-31")
def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000): def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
LMDBSerializer.save(df, lmdb_path, write_frequency) LMDBSerializer.save(df, lmdb_path, write_frequency)
......
...@@ -3,18 +3,19 @@ ...@@ -3,18 +3,19 @@
import numpy as np import numpy as np
import os
import six import six
from six.moves import range from six.moves import range
import os
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.compatible_serialize import loads
from ..utils.argtools import log_once from ..utils.argtools import log_once
from ..utils.compatible_serialize import loads
from ..utils.develop import create_dummy_class # noqa
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from .base import RNGDataFlow, DataFlow, DataFlowReentrantGuard from ..utils.loadcaffe import get_caffe_pb
from ..utils.timer import timed_operation
from ..utils.utils import get_tqdm
from .base import DataFlow, DataFlowReentrantGuard, RNGDataFlow
from .common import MapData from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
...@@ -258,7 +259,7 @@ class TFRecordData(DataFlow): ...@@ -258,7 +259,7 @@ class TFRecordData(DataFlow):
for dp in gen: for dp in gen:
yield loads(dp) yield loads(dp)
from ..utils.develop import create_dummy_class # noqa
try: try:
import h5py import h5py
except ImportError: except ImportError:
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# File: image.py # File: image.py
import numpy as np
import copy as copy_mod import copy as copy_mod
import numpy as np
from contextlib import contextmanager from contextlib import contextmanager
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from ..utils import logger from ..utils import logger
from ..utils.argtools import shape2d from ..utils.argtools import shape2d
from .base import RNGDataFlow
from .common import MapData, MapDataComponent
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents'] __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
import sys import sys
import cv2 import cv2
from . import AugmentorList from . import AugmentorList
from .crop import * from .crop import *
from .imgproc import *
from .noname import *
from .deform import * from .deform import *
from .imgproc import *
from .noise import SaltPepperNoise from .noise import SaltPepperNoise
from .noname import *
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)] anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([ augmentors = AugmentorList([
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
import inspect import inspect
import pprint import pprint
from abc import abstractmethod, ABCMeta from abc import ABCMeta, abstractmethod
import six import six
from six.moves import zip from six.moves import zip
from ...utils.utils import get_rng
from ...utils.argtools import log_once from ...utils.argtools import log_once
from ...utils.utils import get_rng
from ..image import check_dtype from ..image import check_dtype
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList'] __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: convert.py # File: convert.py
from .base import ImageAugmentor
from .meta import MapImage
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor
from .meta import MapImage
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32'] __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
......
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from .transform import TransformAugmentorBase, CropTransform from .transform import CropTransform, TransformAugmentorBase
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape'] __all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape']
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: deform.py # File: deform.py
from .base import ImageAugmentor
from ...utils import logger
import numpy as np import numpy as np
from ...utils import logger
from .base import ImageAugmentor
__all__ = [] __all__ = []
# Code was temporarily kept here for a future reference in case someone needs it # Code was temporarily kept here for a future reference in case someone needs it
......
...@@ -4,7 +4,6 @@ import numpy as np ...@@ -4,7 +4,6 @@ import numpy as np
from .base import ImageAugmentor from .base import ImageAugmentor
__all__ = ['IAAugmentor', 'Albumentations'] __all__ = ['IAAugmentor', 'Albumentations']
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import math import math
import cv2
import numpy as np import numpy as np
import cv2
from .base import ImageAugmentor from .base import ImageAugmentor
from .transform import TransformAugmentorBase, WarpAffineTransform from .transform import TransformAugmentorBase, WarpAffineTransform
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: imgproc.py # File: imgproc.py
from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize', __all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize'] 'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize']
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor
from ...utils import logger from ...utils import logger
from ...utils.argtools import shape2d from ...utils.argtools import shape2d
from .base import ImageAugmentor
from .transform import ResizeTransform, TransformAugmentorBase from .transform import ResizeTransform, TransformAugmentorBase
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose'] __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: noise.py # File: noise.py
from .base import ImageAugmentor
import numpy as np import numpy as np
import cv2 import cv2
from .base import ImageAugmentor
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise'] __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
# File: paste.py # File: paste.py
from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np import numpy as np
from abc import abstractmethod
from .base import ImageAugmentor
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller', __all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste'] 'RandomPaste']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: transform.py # File: transform.py
from abc import abstractmethod, ABCMeta
import six
import cv2
import numpy as np import numpy as np
from abc import ABCMeta, abstractmethod
import cv2
import six
from .base import ImageAugmentor from .base import ImageAugmentor
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: parallel.py # File: parallel.py
import atexit
import errno
import itertools
import multiprocessing as mp
import os
import sys import sys
import uuid
import weakref import weakref
from contextlib import contextmanager from contextlib import contextmanager
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
import errno
import uuid
import os
import zmq import zmq
import atexit from six.moves import queue, range, zip
from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
enable_death_signal,
StoppableThread)
from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger
from ..utils.gpu import change_gpu from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, mask_sigint, start_proc_mask_signal)
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..utils.gpu import change_gpu
from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow
__all__ = ['PrefetchData', 'MultiProcessPrefetchData', __all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData'] 'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: parallel_map.py # File: parallel_map.py
import numpy as np
import ctypes
import copy import copy
import threading import ctypes
import multiprocessing as mp import multiprocessing as mp
from six.moves import queue import numpy as np
import threading
import zmq import zmq
from six.moves import queue
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from .common import RepeatedData
from ..utils.concurrency import StoppableThread, enable_death_signal from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import loads, dumps from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .parallel import ( from .common import RepeatedData
_MultiProcessZMQDataFlow, _repeat_iter, _get_pipe_name, from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
_bind_guard, _zmq_catch_error)
__all__ = ['ThreadedMapData', 'MultiThreadMapData', __all__ = ['ThreadedMapData', 'MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ'] 'MultiProcessMapData', 'MultiProcessMapDataZMQ']
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: raw.py # File: raw.py
import numpy as np
import copy import copy
import numpy as np
import six import six
from six.moves import range from six.moves import range
from .base import DataFlow, RNGDataFlow from .base import DataFlow, RNGDataFlow
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList', 'DataFromGenerator', 'DataFromIterable'] __all__ = ['FakeData', 'DataFromQueue', 'DataFromList', 'DataFromGenerator', 'DataFromIterable']
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
# File: remote.py # File: remote.py
import multiprocessing as mp
import time import time
from collections import deque
import tqdm import tqdm
import multiprocessing as mp
from six.moves import range from six.moves import range
from collections import deque
from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.concurrency import DIE from ..utils.concurrency import DIE
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
from ..utils.utils import get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard
try: try:
import zmq import zmq
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: serialize.py # File: serialize.py
import os
import numpy as np import numpy as np
import os
from collections import defaultdict from collections import defaultdict
from ..utils.utils import get_tqdm
from ..utils import logger from ..utils import logger
from ..utils.compatible_serialize import dumps, loads from ..utils.compatible_serialize import dumps, loads
from ..utils.develop import create_dummy_class # noqa
from ..utils.utils import get_tqdm
from .base import DataFlow from .base import DataFlow
from .format import LMDBData, HDF5Data from .common import FixedSizeData, MapData
from .common import MapData, FixedSizeData from .format import HDF5Data, LMDBData
from .raw import DataFromList, DataFromGenerator from .raw import DataFromGenerator, DataFromList
__all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer'] __all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer']
...@@ -195,7 +195,6 @@ class HDF5Serializer(): ...@@ -195,7 +195,6 @@ class HDF5Serializer():
return HDF5Data(path, data_paths, shuffle) return HDF5Data(path, data_paths, shuffle)
from ..utils.develop import create_dummy_class # noqa
try: try:
import lmdb import lmdb
except ImportError: except ImportError:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: distributed.py # File: distributed.py
import tensorflow as tf
import re import re
import tensorflow as tf
from six.moves import range from six.moves import range
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..tfutils.common import get_op_tensor_name, get_global_step_var from .training import DataParallelBuilder, GraphBuilder
from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
from .training import GraphBuilder, DataParallelBuilder
from .utils import (
override_to_local_variable, aggregate_grads,
OverrideCachingDevice)
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder'] __all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
......
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
from collections import namedtuple from collections import namedtuple
import tensorflow as tf import tensorflow as tf
from ..models.regularize import regularize_cost_from_collection
from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized_method from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils.tower import get_current_tower_context
from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase'] __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
import tensorflow as tf import tensorflow as tf
from ..tfutils.tower import PredictTowerContext
from ..utils import logger from ..utils import logger
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..tfutils.tower import PredictTowerContext
from .training import GraphBuilder from .training import GraphBuilder
__all__ = ['SimplePredictBuilder'] __all__ = ['SimplePredictBuilder']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: training.py # File: training.py
from abc import ABCMeta, abstractmethod
import tensorflow as tf
import copy import copy
import six
import re
import pprint import pprint
from six.moves import zip, range import re
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
import six
import tensorflow as tf
from six.moves import range, zip
from ..utils import logger
from ..tfutils.tower import TrainTowerContext
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.tower import TrainTowerContext
from ..utils import logger
from .utils import ( from .utils import (
LeastLoadedDeviceSetter, override_to_local_variable, GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
allreduce_grads, aggregate_grads, allreduce_grads_hierarchical, merge_grad_list, override_to_local_variable, split_grad_list)
split_grad_list, merge_grad_list, GradientPacker)
__all__ = ['GraphBuilder', __all__ = ['GraphBuilder',
'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder', 'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
......
...@@ -2,16 +2,15 @@ ...@@ -2,16 +2,15 @@
# File: utils.py # File: utils.py
from contextlib import contextmanager
import operator import operator
from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import call_only_once from ..tfutils.scope_utils import cached_name_scope, under_name_scope
from ..tfutils.varreplace import custom_getter_scope
from ..utils import logger from ..utils import logger
from ..utils.argtools import call_only_once
__all__ = ['LeastLoadedDeviceSetter', __all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice', 'OverrideCachingDevice',
......
...@@ -2,27 +2,28 @@ ...@@ -2,27 +2,28 @@
# File: input_source.py # File: input_source.py
import tensorflow as tf import threading
try:
from tensorflow.python.ops.data_flow_ops import StagingArea
except ImportError:
pass
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain from itertools import chain
import tensorflow as tf
from six.moves import range, zip from six.moves import range, zip
import threading
from .input_source_base import InputSource from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
from ..dataflow import DataFlow, MapData, RepeatedData from ..dataflow import DataFlow, MapData, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.dependency import dependency_of_fetches from ..tfutils.dependency import dependency_of_fetches
from ..tfutils.summary import add_moving_summary
from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.concurrency import ShareSessionThread from ..utils.concurrency import ShareSessionThread
from ..callbacks.base import Callback, CallbackFactory from .input_source_base import InputSource
from ..callbacks.graph import RunOp
try:
from tensorflow.python.ops.data_flow_ops import StagingArea
except ImportError:
pass
__all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput', __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: input_source_base.py # File: input_source_base.py
from abc import ABCMeta, abstractmethod
import copy import copy
import six from abc import ABCMeta, abstractmethod
from six.moves import zip
from contextlib import contextmanager from contextlib import contextmanager
import six
import tensorflow as tf import tensorflow as tf
from six.moves import zip
from ..utils.argtools import memoized_method, call_only_once
from ..callbacks.base import CallbackFactory from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
from ..utils.argtools import call_only_once, memoized_method
__all__ = ['InputSource', 'remap_input_source'] __all__ = ['InputSource', 'remap_input_source']
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages from tensorflow.python.training import moving_averages
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context from .common import VariableHolder, layer_register
from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args from .tflayer import convert_to_tflayer_args
""" """
Old Custom BN Implementation, Kept Here For Future Reference Old Custom BN Implementation, Kept Here For Future Reference
""" """
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import logging import logging
import tensorflow as tf
import unittest import unittest
import tensorflow as tf
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
......
...@@ -2,17 +2,17 @@ ...@@ -2,17 +2,17 @@
# File: batch_norm.py # File: batch_norm.py
import tensorflow as tf
from tensorflow.python.training import moving_averages
import re import re
import six import six
import tensorflow as tf
from tensorflow.python.training import moving_averages
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context from .common import VariableHolder, layer_register
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
from .tflayer import convert_to_tflayer_args, rename_get_variable from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['BatchNorm', 'BatchRenorm'] __all__ = ['BatchNorm', 'BatchRenorm']
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# File: common.py # File: common.py
from .registry import layer_register # noqa from .registry import layer_register # noqa
from .utils import VariableHolder # noqa
from .tflayer import rename_tflayer_get_variable from .tflayer import rename_tflayer_get_variable
from .utils import VariableHolder # noqa
__all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable'] __all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable']
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
import tensorflow as tf import tensorflow as tf
from .common import layer_register, VariableHolder
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import shape2d, shape4d, get_data_format from ..utils.argtools import get_data_format, shape2d, shape4d
from .tflayer import rename_get_variable, convert_to_tflayer_args from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['Conv2D', 'Deconv2D', 'Conv2DTranspose'] __all__ = ['Conv2D', 'Deconv2D', 'Conv2DTranspose']
...@@ -50,7 +51,7 @@ def Conv2D( ...@@ -50,7 +51,7 @@ def Conv2D(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0), kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
if split == 1: if split == 1:
...@@ -158,7 +159,7 @@ def Conv2DTranspose( ...@@ -158,7 +159,7 @@ def Conv2DTranspose(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0), kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# File: fc.py # File: fc.py
import tensorflow as tf
import numpy as np import numpy as np
import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['FullyConnected'] __all__ = ['FullyConnected']
...@@ -48,7 +48,7 @@ def FullyConnected( ...@@ -48,7 +48,7 @@ def FullyConnected(
""" """
if kernel_initializer is None: if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12): if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0), kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else: else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal') kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import tensorflow as tf import tensorflow as tf
from .common import layer_register, VariableHolder
from ..utils.argtools import get_data_format from ..utils.argtools import get_data_format
from .common import VariableHolder, layer_register
__all__ = ['LayerNorm', 'InstanceNorm'] __all__ = ['LayerNorm', 'InstanceNorm']
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# File: linearwrap.py # File: linearwrap.py
import six
from types import ModuleType from types import ModuleType
import six
from .registry import get_registered_layer from .registry import get_registered_layer
__all__ = ['LinearWrap'] __all__ = ['LinearWrap']
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import tensorflow as tf import tensorflow as tf
from .common import layer_register, VariableHolder
from .batch_norm import BatchNorm from .batch_norm import BatchNorm
from .common import VariableHolder, layer_register
__all__ = ['Maxout', 'PReLU', 'BNReLU'] __all__ = ['Maxout', 'PReLU', 'BNReLU']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: pool.py # File: pool.py
import tensorflow as tf
import numpy as np import numpy as np
import tensorflow as tf
from .shape_utils import StaticDynamicShape from ..utils.argtools import get_data_format, shape2d
from .common import layer_register
from ..utils.argtools import shape2d, get_data_format
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ._test import TestModel from ._test import TestModel
from .common import layer_register
from .shape_utils import StaticDynamicShape
from .tflayer import convert_to_tflayer_args from .tflayer import convert_to_tflayer_args
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling', __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample'] 'BilinearUpSample']
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# File: registry.py # File: registry.py
import tensorflow as tf import copy
import re
from functools import wraps from functools import wraps
import six import six
import re import tensorflow as tf
import copy
from ..tfutils.argscope import get_arg_scope from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str from ..tfutils.model_utils import get_shape_str
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# File: regularize.py # File: regularize.py
import tensorflow as tf
import re import re
import tensorflow as tf
from ..utils import logger
from ..utils.argtools import graph_memoized
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import graph_memoized
from .common import layer_register from .common import layer_register
__all__ = ['regularize_cost', 'regularize_cost_from_collection', __all__ = ['regularize_cost', 'regularize_cost_from_collection',
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import tensorflow as tf import tensorflow as tf
from .common import layer_register from .common import layer_register
__all__ = ['ConcatWith'] __all__ = ['ConcatWith']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: tflayer.py # File: tflayer.py
import tensorflow as tf
import six
import functools import functools
import six
import tensorflow as tf
from ..utils.argtools import get_data_format
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..tfutils.varreplace import custom_getter_scope from ..tfutils.varreplace import custom_getter_scope
from ..utils.argtools import get_data_format
__all__ = [] __all__ = []
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# File: base.py # File: base.py
from abc import abstractmethod, ABCMeta from abc import ABCMeta, abstractmethod
import tensorflow as tf
import six import six
import tensorflow as tf
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor', 'OnlinePredictor', 'OfflinePredictor',
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# File: concurrency.py # File: concurrency.py
import numpy as np
import multiprocessing import multiprocessing
import numpy as np
import six import six
from six.moves import queue, range
import tensorflow as tf import tensorflow as tf
from six.moves import queue, range
from ..utils import logger
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.model_utils import describe_trainable_vars from ..tfutils.model_utils import describe_trainable_vars
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker', __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor'] 'MultiThreadAsyncPredictor']
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# File: config.py # File: config.py
import tensorflow as tf
import six import six
import tensorflow as tf
from ..graph_builder import ModelDescBase from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import JustCurrentSession, SessionInit
from ..tfutils.tower import TowerFuncWrapper from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..utils import logger from ..utils import logger
__all__ = ['PredictConfig'] __all__ = ['PredictConfig']
......
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
# File: dataset.py # File: dataset.py
from six.moves import range, zip
from abc import ABCMeta, abstractmethod
import multiprocessing import multiprocessing
import os import os
from abc import ABCMeta, abstractmethod
import six import six
from six.moves import range, zip
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..dataflow.remote import dump_dataflow_to_process_queue from ..dataflow.remote import dump_dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate
from ..utils.gpu import change_gpu, get_num_gpu from ..utils.gpu import change_gpu, get_num_gpu
from ..utils.utils import get_tqdm
from .base import OfflinePredictor
from .concurrency import MultiProcessQueuePredictWorker from .concurrency import MultiProcessQueuePredictWorker
from .config import PredictConfig from .config import PredictConfig
from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor', __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor'] 'MultiProcessDatasetPredictor']
......
#!/usr/bin/env python #!/usr/bin/env python
from tensorflow.python.training.monitored_session \ from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
import _HookedSession as HookedSession
from .base import PredictorBase
from ..tfutils.tower import PredictTowerContext
from ..callbacks import Callbacks from ..callbacks import Callbacks
from ..tfutils.tower import PredictTowerContext
from .base import PredictorBase
__all__ = ['FeedfreePredictor'] __all__ = ['FeedfreePredictor']
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
import tensorflow as tf import tensorflow as tf
from ..utils import logger
from ..graph_builder.model_desc import InputDesc from ..graph_builder.model_desc import InputDesc
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..utils import logger
from .base import OnlinePredictor from .base import OnlinePredictor
__all__ = ['MultiTowerOfflinePredictor', __all__ = ['MultiTowerOfflinePredictor',
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: argscope.py # File: argscope.py
from contextlib import contextmanager
from collections import defaultdict
import copy import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps from functools import wraps
from inspect import isfunction, getmembers from inspect import getmembers, isfunction
from .tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from .tower import get_current_tower_context
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module'] __all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module']
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
# File: collection.py # File: collection.py
import tensorflow as tf from contextlib import contextmanager
from copy import copy from copy import copy
import six import six
from contextlib import contextmanager import tensorflow as tf
from ..utils import logger from ..utils import logger
from ..utils.argtools import memoized from ..utils.argtools import memoized
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import tensorflow as tf import tensorflow as tf
from six.moves import map from six.moves import map
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from ..utils.develop import deprecated from ..utils.develop import deprecated
......
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.graph_editor import get_backward_walk_ops from tensorflow.contrib.graph_editor import get_backward_walk_ops
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
""" """
......
...@@ -12,10 +12,10 @@ from tensorflow.python.framework import graph_util ...@@ -12,10 +12,10 @@ from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib from tensorflow.python.tools import optimize_for_inference_lib
from ..utils import logger from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput from ..utils import logger
__all__ = ['ModelExporter'] __all__ = ['ModelExporter']
......
...@@ -2,14 +2,15 @@ ...@@ -2,14 +2,15 @@
# File: gradproc.py # File: gradproc.py
import tensorflow as tf import inspect
from abc import ABCMeta, abstractmethod
import re import re
from abc import ABCMeta, abstractmethod
import six import six
import inspect import tensorflow as tf
from ..utils import logger from ..utils import logger
from .symbolic_functions import rms, print_stat
from .summary import add_moving_summary from .summary import add_moving_summary
from .symbolic_functions import print_stat, rms
__all__ = ['GradientProcessor', __all__ = ['GradientProcessor',
'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient', 'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient',
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
# Author: tensorpack contributors # Author: tensorpack contributors
import tensorflow as tf import tensorflow as tf
from termcolor import colored
from tabulate import tabulate from tabulate import tabulate
from termcolor import colored
from ..utils import logger from ..utils import logger
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# File: optimizer.py # File: optimizer.py
import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf
from ..utils.develop import HIDE_DOC
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..utils.develop import HIDE_DOC
from .gradproc import FilterNoneGrad, GradientProcessor from .gradproc import FilterNoneGrad, GradientProcessor
__all__ = ['apply_grad_processors', 'ProxyOptimizer', __all__ = ['apply_grad_processors', 'ProxyOptimizer',
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
# File: scope_utils.py # File: scope_utils.py
import tensorflow as tf
import functools import functools
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from .common import get_tf_version_tuple from .common import get_tf_version_tuple
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
import tensorflow as tf import tensorflow as tf
from .common import get_default_sess_config
from ..utils import logger from ..utils import logger
from .common import get_default_sess_config
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter'] __all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
......
...@@ -3,13 +3,12 @@ ...@@ -3,13 +3,12 @@
import numpy as np import numpy as np
import tensorflow as tf
import six import six
import tensorflow as tf
from ..utils import logger from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname, from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name
is_training_name, get_checkpoint_path)
__all__ = ['SessionInit', 'ChainInit', __all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore', 'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
......
...@@ -2,20 +2,19 @@ ...@@ -2,20 +2,19 @@
# File: summary.py # File: summary.py
import re
from contextlib import contextmanager
import six import six
import tensorflow as tf import tensorflow as tf
import re
from six.moves import range from six.moves import range
from contextlib import contextmanager
from tensorflow.python.training import moving_averages from tensorflow.python.training import moving_averages
from ..utils import logger from ..utils import logger
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from ..utils.naming import MOVING_SUMMARY_OPS_KEY from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
from .scope_utils import cached_name_scope from .scope_utils import cached_name_scope
from .symbolic_functions import rms
from .tower import get_current_tower_context
__all__ = ['add_tensor_summary', 'add_param_summary', __all__ = ['add_tensor_summary', 'add_param_summary',
'add_activation_summary', 'add_moving_summary', 'add_activation_summary', 'add_moving_summary',
......
...@@ -2,15 +2,15 @@ ...@@ -2,15 +2,15 @@
# File: tower.py # File: tower.py
import tensorflow as tf from abc import ABCMeta, abstractmethod, abstractproperty
import six import six
import tensorflow as tf
from six.moves import zip from six.moves import zip
from abc import abstractproperty, abstractmethod, ABCMeta
from ..utils import logger from ..utils import logger
from ..utils.argtools import call_only_once from ..utils.argtools import call_only_once
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from ..utils.develop import HIDE_DOC from ..utils.develop import HIDE_DOC
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .collection import CollectionGuard from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name from .common import get_op_or_tensor_by_name, get_op_tensor_name
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: varmanip.py # File: varmanip.py
import six import numpy as np
import os import os
import pprint import pprint
import six
import tensorflow as tf import tensorflow as tf
import numpy as np
from ..utils import logger from ..utils import logger
from .common import get_op_tensor_name from .common import get_op_tensor_name
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# File: varreplace.py # File: varreplace.py
# Credit: Qinyao He # Credit: Qinyao He
import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf
from .common import get_tf_version_tuple from .common import get_tf_version_tuple
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: base.py # File: base.py
import tensorflow as tf import copy
import weakref
import time import time
from six.moves import range import weakref
import six import six
import copy import tensorflow as tf
from six.moves import range
from ..callbacks import ( from ..callbacks import Callback, Callbacks, Monitors, TrainingMonitor
Callback, Callbacks, Monitors, TrainingMonitor) from ..callbacks.steps import MaintainStepCounter
from ..utils import logger
from ..utils.utils import humanize_time_delta
from ..utils.argtools import call_only_once
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sesscreate import NewSessionCreator, ReuseSessionCreator
from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator from ..tfutils.sessinit import JustCurrentSession, SessionInit
from ..callbacks.steps import MaintainStepCounter from ..utils import logger
from ..utils.argtools import call_only_once
from .config import TrainConfig, DEFAULT_MONITORS, DEFAULT_CALLBACKS from ..utils.utils import humanize_time_delta
from .config import DEFAULT_CALLBACKS, DEFAULT_MONITORS, TrainConfig
__all__ = ['StopTraining', 'Trainer'] __all__ = ['StopTraining', 'Trainer']
......
...@@ -5,15 +5,13 @@ import os ...@@ -5,15 +5,13 @@ import os
import tensorflow as tf import tensorflow as tf
from ..callbacks import ( from ..callbacks import (
MovingAverageSummary, JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter)
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils.sessinit import SessionInit, SaverRestore
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource from ..input_source import InputSource
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.sessinit import SaverRestore, SessionInit
from ..utils import logger
__all__ = ['TrainConfig', 'AutoResumeTrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS'] __all__ = ['TrainConfig', 'AutoResumeTrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']
......
...@@ -3,11 +3,8 @@ ...@@ -3,11 +3,8 @@
import tensorflow as tf import tensorflow as tf
from ..input_source import ( from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput
InputSource, FeedInput, FeedfreeInput,
QueueInput, StagingInput, DummyConstantInput)
from ..utils import logger from ..utils import logger
from .config import TrainConfig from .config import TrainConfig
from .tower import SingleCostTrainer from .tower import SingleCostTrainer
from .trainers import SimpleTrainer from .trainers import SimpleTrainer
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: tower.py # File: tower.py
import tensorflow as tf from abc import ABCMeta, abstractmethod
import six import six
from abc import abstractmethod, ABCMeta import tensorflow as tf
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC
from ..utils import logger
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor from ..predict.base import OnlinePredictor
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context, PredictTowerContext
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import PredictTowerContext, TowerFuncWrapper, get_current_tower_context
from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC
from .base import Trainer from .base import Trainer
__all__ = ['SingleCostTrainer', 'TowerTrainer'] __all__ = ['SingleCostTrainer', 'TowerTrainer']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: trainers.py # File: trainers.py
import sys import multiprocessing as mp
import os import os
import sys
import tensorflow as tf import tensorflow as tf
import multiprocessing as mp
from ..callbacks import RunOp, CallbackFactory from ..callbacks import CallbackFactory, RunOp
from ..graph_builder.distributed import DistributedParameterServerBuilder, DistributedReplicatedBuilder
from ..graph_builder.training import (
AsyncMultiGPUBuilder, SyncMultiGPUParameterServerBuilder, SyncMultiGPUReplicatedBuilder)
from ..graph_builder.utils import override_to_local_variable
from ..input_source import FeedfreeInput, QueueInput
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.tower import TrainTowerContext
from ..utils import logger from ..utils import logger
from ..utils.argtools import map_arg from ..utils.argtools import map_arg
from ..utils.develop import HIDE_DOC, log_deprecated from ..utils.develop import HIDE_DOC, log_deprecated
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.tower import TrainTowerContext
from ..input_source import QueueInput, FeedfreeInput
from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder)
from ..graph_builder.distributed import DistributedReplicatedBuilder, DistributedParameterServerBuilder
from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer from .tower import SingleCostTrainer
__all__ = ['NoOpTrainer', 'SimpleTrainer', __all__ = ['NoOpTrainer', 'SimpleTrainer',
......
...@@ -2,6 +2,4 @@ ...@@ -2,6 +2,4 @@
# File: utility.py # File: utility.py
# for backwards-compatibility # for backwards-compatibility
from ..graph_builder.utils import ( # noqa from ..graph_builder.utils import LeastLoadedDeviceSetter, OverrideToLocalVariable, override_to_local_variable # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import inspect import inspect
import six import six
from . import logger from . import logger
if six.PY2: if six.PY2:
import functools32 as functools import functools32 as functools
else: else:
......
import os import os
from .serialize import loads_msgpack, loads_pyarrow, dumps_msgpack, dumps_pyarrow
from .serialize import dumps_msgpack, dumps_pyarrow, loads_msgpack, loads_pyarrow
""" """
Serialization that has compatibility guarantee (therefore is safe to store to disk). Serialization that has compatibility guarantee (therefore is safe to store to disk).
......
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
# Some code taken from zxytim # Some code taken from zxytim
import threading
import platform
import multiprocessing
import atexit import atexit
import bisect import bisect
from contextlib import contextmanager import multiprocessing
import platform
import signal import signal
import threading
import weakref import weakref
from contextlib import contextmanager
import six import six
from six.moves import queue from six.moves import queue
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
""" Utilities for developers only. """ Utilities for developers only.
These are not visible to users (not automatically imported). And should not These are not visible to users (not automatically imported). And should not
appeared in docs.""" appeared in docs."""
import os
import functools import functools
from datetime import datetime
import importlib import importlib
import os
import types import types
from datetime import datetime
import six import six
from . import logger from . import logger
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# File: fs.py # File: fs.py
import os
from six.moves import urllib
import errno import errno
import os
import tqdm import tqdm
from six.moves import urllib
from . import logger from . import logger
from .utils import execute_only_once from .utils import execute_only_once
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
import os import os
from .utils import change_env
from . import logger from . import logger
from .nvml import NVMLContext
from .concurrency import subproc_call from .concurrency import subproc_call
from .nvml import NVMLContext
from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu'] __all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# File: loadcaffe.py # File: loadcaffe.py
import sys
import numpy as np import numpy as np
import os import os
import sys
from .utils import change_env
from .fs import download, get_dataset_path
from .concurrency import subproc_call
from . import logger from . import logger
from .concurrency import subproc_call
from .fs import download, get_dataset_path
from .utils import change_env
__all__ = ['load_caffe', 'get_caffe_pb'] __all__ = ['load_caffe', 'get_caffe_pb']
......
...@@ -16,12 +16,12 @@ The logger module itself has the common logging functions of Python's ...@@ -16,12 +16,12 @@ The logger module itself has the common logging functions of Python's
import logging import logging
import os import os
import shutil
import os.path import os.path
from termcolor import colored import shutil
import sys
from datetime import datetime from datetime import datetime
from six.moves import input from six.moves import input
import sys from termcolor import colored
__all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir'] __all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: nvml.py # File: nvml.py
from ctypes import (byref, c_uint, c_ulonglong,
CDLL, POINTER, Structure)
import threading import threading
from ctypes import CDLL, POINTER, Structure, byref, c_uint, c_ulonglong
__all__ = ['NVMLContext'] __all__ = ['NVMLContext']
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import numpy as np import numpy as np
from .develop import log_deprecated from .develop import log_deprecated
__all__ = ['IntBox', 'FloatBox'] __all__ = ['IntBox', 'FloatBox']
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: serialize.py # File: serialize.py
import sys
import os import os
from .develop import create_dummy_func import sys
from . import logger from . import logger
from .develop import create_dummy_func
__all__ = ['loads', 'dumps'] __all__ = ['loads', 'dumps']
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# File: timer.py # File: timer.py
from contextlib import contextmanager
from collections import defaultdict
import six
import atexit import atexit
from collections import defaultdict
from contextlib import contextmanager
from time import time as timer from time import time as timer
import six
from .stats import StatCounter
from . import logger from . import logger
from .stats import StatCounter
if six.PY3: if six.PY3:
from time import perf_counter as timer # noqa from time import perf_counter as timer # noqa
......
...@@ -2,17 +2,16 @@ ...@@ -2,17 +2,16 @@
# File: utils.py # File: utils.py
import inspect
import numpy as np
import os import os
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
import inspect
from datetime import datetime, timedelta from datetime import datetime, timedelta
from tqdm import tqdm from tqdm import tqdm
import numpy as np
from . import logger from . import logger
__all__ = ['change_env', __all__ = ['change_env',
'get_rng', 'get_rng',
'fix_rng_seed', 'fix_rng_seed',
......
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
import numpy as np import numpy as np
import os import os
import sys import sys
from .fs import mkdir_p
from ..utils.develop import create_dummy_func # noqa
from .argtools import shape2d from .argtools import shape2d
from .fs import mkdir_p
from .palette import PALETTE_RGB from .palette import PALETTE_RGB
try: try:
...@@ -411,7 +413,6 @@ def draw_boxes(im, boxes, labels=None, color=None): ...@@ -411,7 +413,6 @@ def draw_boxes(im, boxes, labels=None, color=None):
return im return im
from ..utils.develop import create_dummy_func # noqa
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except (ImportError, RuntimeError): except (ImportError, RuntimeError):
......
...@@ -8,16 +8,11 @@ export TF_CPP_MIN_LOG_LEVEL=2 ...@@ -8,16 +8,11 @@ export TF_CPP_MIN_LOG_LEVEL=2
# test import (#471) # test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform' python -c 'from tensorpack.dataflow.imgaug import transform'
# python -m unittest discover -v python -m unittest discover -v
# python -m tensorpack.models._test # python -m tensorpack.models._test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985) # segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py # python ../tensorpack/user_ops/test-recv-op.py
python test_char_rnn.py
python test_infogan.py
python test_mnist.py
python test_mnist_similarity.py
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py TENSORPACK_SERIALIZE=msgpack python test_serializer.py
from case_script import TestPythonScript
import os import os
from case_script import TestPythonScript
def random_content(): def random_content():
return ('Lorem ipsum dolor sit amet\n' return ('Lorem ipsum dolor sit amet\n'
......
...@@ -10,6 +10,7 @@ class InfoGANTest(TestPythonScript): ...@@ -10,6 +10,7 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py' return '../examples/GAN/InfoGAN-mnist.py'
def test(self): def test(self):
return True # https://github.com/tensorflow/tensorflow/issues/24517
if get_tf_version_tuple() < (1, 4): if get_tf_version_tuple() < (1, 4):
return True # requires leaky_relu return True # requires leaky_relu
self.assertSurvive(self.script, args=None) self.assertSurvive(self.script, args=None)
#! /usr/bin/env python #! /usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from tensorpack.dataflow.base import DataFlow
from tensorpack.dataflow import LMDBSerializer, TFRecordSerializer, NumpySerializer, HDF5Serializer
import unittest
import os
import numpy as np import numpy as np
import os
import unittest
from tensorpack.dataflow import HDF5Serializer, LMDBSerializer, NumpySerializer, TFRecordSerializer
from tensorpack.dataflow.base import DataFlow
def delete_file_if_exists(fn): def delete_file_if_exists(fn):
......
...@@ -12,3 +12,13 @@ exclude = .git, ...@@ -12,3 +12,13 @@ exclude = .git,
snippet, snippet,
examples-old, examples-old,
_test.py, _test.py,
[isort]
line_length=100
skip=docs/conf.py
multi_line_output=4
known_tensorpack=tensorpack
known_standard_library=numpy
known_third_party=bob,gym,matplotlib
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment