Commit ac9ac2a4 authored by Yuxin Wu's avatar Yuxin Wu

isort -y -sp tox.ini

parent 9c2be2ad
......@@ -32,7 +32,7 @@ It's Yet Another TF high-level API, with __speed__, and __flexibility__ built to
See [tutorials and documentations](http://tensorpack.readthedocs.io/tutorial/index.html#user-tutorials) to know more about these features.
## [Examples](examples):
## Examples:
We refuse toy examples.
Instead of showing you 10 arbitrary networks trained on toy datasets,
......
......@@ -4,20 +4,18 @@
# Author: Yuxin Wu
import multiprocessing as mp
import time
import os
import threading
from abc import abstractmethod, ABCMeta
import time
from abc import ABCMeta, abstractmethod
from collections import defaultdict
import six
from six.moves import queue
import zmq
from six.moves import queue
from tensorpack.utils import logger
from tensorpack.utils.serialize import loads, dumps
from tensorpack.utils.concurrency import (
LoopThread, ensure_proc_terminate, enable_death_signal)
from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure_proc_terminate
from tensorpack.utils.serialize import dumps, loads
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange',
......
......@@ -3,29 +3,26 @@
# File: train-atari.py
# Author: Yuxin Wu
import argparse
import numpy as np
import sys
import os
import sys
import uuid
import argparse
import cv2
import tensorflow as tf
import gym
import six
import tensorflow as tf
from six.moves import queue
from tensorpack import *
from tensorpack.utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from tensorpack.utils.serialize import dumps
from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.serialize import dumps
import gym
from simulator import SimulatorProcess, SimulatorMaster, TransitionExperience
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import MapState, FrameStack, FireResetEnv, LimitLength
from simulator import SimulatorMaster, SimulatorProcess, TransitionExperience
if six.PY3:
from concurrent import futures
......
......@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*-
# File: create-lmdb.py
# Author: Yuxin Wu
import argparse
import numpy as np
import os
import scipy.io.wavfile as wavfile
import string
import numpy as np
import argparse
import bob.ap
import scipy.io.wavfile as wavfile
from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils import fs, logger, serialize
from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils import serialize, fs, logger
from tensorpack.utils.utils import get_tqdm
CHARSET = set(string.ascii_lowercase + ' ')
......
......@@ -2,10 +2,11 @@
# File: timitdata.py
# Author: Yuxin Wu
from tensorpack import ProxyDataFlow
import numpy as np
from six.moves import range
from tensorpack import ProxyDataFlow
__all__ = ['TIMITBatch']
......
......@@ -3,17 +3,17 @@
# File: train-timit.py
# Author: Yuxin Wu
import os
import argparse
import os
import tensorflow as tf
from six.moves import range
from tensorpack import *
from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip
from tensorpack.tfutils.gradproc import GlobalNormClip, SummaryGradient
from tensorpack.utils import serialize
import tensorflow as tf
from timitdata import TIMITBatch
rnn = tf.contrib.rnn
......
......@@ -4,16 +4,16 @@
# Author: Yuxin Wu
from __future__ import print_function
import argparse
import numpy as np
import os
import cv2
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta
import tensorflow as tf
from tensorpack.tfutils.summary import *
from tensorpack.tfutils.symbolic_functions import *
def tower_func(image):
......
......@@ -3,15 +3,16 @@
# File: load-cpm.py
# Author: Yuxin Wu
import argparse
import numpy as np
import cv2
import tensorflow as tf
import numpy as np
import argparse
from tensorpack import *
from tensorpack.utils import viz
from tensorpack.utils.argtools import memoized
"""
15 channels:
0-1 head, neck
......
......@@ -3,12 +3,12 @@
# File: load-vgg16.py
from __future__ import print_function
import cv2
import tensorflow as tf
import argparse
import numpy as np
import os
import cv2
import six
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta
......
......@@ -3,12 +3,12 @@
# File: load-vgg19.py
from __future__ import print_function
import cv2
import tensorflow as tf
import argparse
import numpy as np
import os
import cv2
import six
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta
......
......@@ -3,21 +3,20 @@
# File: char-rnn.py
# Author: Yuxin Wu
import argparse
import numpy as np
import operator
import os
import sys
import argparse
from collections import Counter
import operator
import six
import tensorflow as tf
from six.moves import range
from tensorpack import *
from tensorpack.tfutils import summary, optimizer
from tensorpack.tfutils import optimizer, summary
from tensorpack.tfutils.gradproc import GlobalNormClip
import tensorflow as tf
rnn = tf.contrib.rnn
class _NS: pass # noqa
......
......@@ -3,20 +3,20 @@
# File: DQN.py
# Author: Yuxin Wu
import os
import argparse
import cv2
import numpy as np
import tensorflow as tf
import os
import cv2
import gym
import tensorflow as tf
from tensorpack import *
from DQNModel import Model as DQNModel
from atari import AtariPlayer
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, MapState
from common import Evaluator, eval_model_multithread, play_n_episodes
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from DQNModel import Model as DQNModel
from expreplay import ExpReplay
from atari import AtariPlayer
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
......
......@@ -4,11 +4,11 @@
import abc
import tensorflow as tf
from tensorpack import ModelDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils import get_current_tower_context, gradproc, optimizer, summary, varreplace
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils import logger
class Model(ModelDesc):
......
......@@ -4,19 +4,18 @@
import numpy as np
import os
import cv2
import threading
import six
from six.moves import range
from tensorpack.utils import logger
from tensorpack.utils.utils import get_rng, execute_only_once
from tensorpack.utils.fs import get_dataset_path
import cv2
import gym
import six
from ale_python_interface import ALEInterface
from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING
from six.moves import range
from ale_python_interface import ALEInterface
from tensorpack.utils import logger
from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.utils import execute_only_once, get_rng
__all__ = ['AtariPlayer']
......
......@@ -3,7 +3,6 @@
import numpy as np
from collections import deque
import gym
from gym import spaces
......
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu
import multiprocessing
import random
import time
import multiprocessing
from tqdm import tqdm
from six.moves import queue
from tqdm import tqdm
from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
from tensorpack.callbacks import Callback
from tensorpack.utils import logger
from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_tqdm_kwargs
......
......@@ -2,18 +2,18 @@
# File: expreplay.py
# Author: Yuxin Wu
import numpy as np
import copy
from collections import deque, namedtuple
import numpy as np
import threading
from collections import deque, namedtuple
from six.moves import queue, range
from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger
from tensorpack.utils.utils import get_tqdm, get_rng
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm
__all__ = ['ExpReplay']
......
......@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*-
# File: mnist-disturb.py
import os
import argparse
import imp
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow import dataset
import tensorflow as tf
from tensorpack.utils import logger
from disturb import DisturbLabel
import imp
mnist_example = imp.load_source('mnist_example',
os.path.join(os.path.dirname(__file__), '..', 'basics', 'mnist-convnet.py'))
get_config = mnist_example.get_config
......
......@@ -3,13 +3,12 @@
# File: svhn-disturb.py
import argparse
import os
import imp
import os
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow import dataset
from tensorpack.utils import logger
from disturb import DisturbLabel
......
......@@ -3,24 +3,22 @@
# File: alexnet-dorefa.py
# Author: Yuxin Wu, Yuheng Zou ({wyx,zyh}@megvii.com)
import cv2
import tensorflow as tf
import argparse
import numpy as np
import os
import sys
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import add_param_summary
from tensorpack.dataflow import dataset
from tensorpack.tfutils.sessinit import get_model_loader
from tensorpack.tfutils.summary import add_param_summary
from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel, eval_on_ILSVRC12)
from dorefa import get_dorefa, ternarize
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor, get_imagenet_dataflow
"""
This is a tensorpack script for the ImageNet results in paper:
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu
import tensorflow as tf
from tensorpack.utils.argtools import graph_memoized
......
......@@ -2,18 +2,18 @@
# -*- coding: utf-8 -*-
# File: resnet-dorefa.py
import cv2
import tensorflow as tf
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.varreplace import remap_variables
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
from dorefa import get_dorefa
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
"""
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
......
......@@ -3,13 +3,13 @@
# File: svhn-digit-dorefa.py
# Author: Yuxin Wu
import os
import argparse
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.tfutils.varreplace import remap_variables
from dorefa import get_dorefa
......
......@@ -3,19 +3,18 @@
# File: steering-filter.py
import argparse
import multiprocessing
import numpy as np
import tensorflow as tf
import cv2
import tensorflow as tf
from scipy.signal import convolve2d
from six.moves import range, zip
import multiprocessing
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils import logger
from tensorpack.utils.viz import *
from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.dataflow import dataset
from tensorpack.utils.viz import *
BATCH = 32
SHAPE = 64
......
# -*- coding: utf-8 -*-
# File: basemodel.py
from contextlib import contextmanager, ExitStack
import numpy as np
from contextlib import ExitStack, contextmanager
import tensorflow as tf
from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, layer_register)
from config import config as cfg
......
......@@ -3,17 +3,16 @@
import numpy as np
import os
from termcolor import colored
from tabulate import tabulate
import tqdm
from tabulate import tabulate
from termcolor import colored
from tensorpack.utils import logger
from tensorpack.utils.timer import timed_operation
from tensorpack.utils.argtools import log_once
from tensorpack.utils.timer import timed_operation
from config import config as cfg
__all__ = ['COCODetection', 'COCOMeta']
......
......@@ -4,6 +4,7 @@
import numpy as np
import os
import pprint
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
......
# -*- coding: utf-8 -*-
# File: data.py
import cv2
import numpy as np
import copy
import numpy as np
import cv2
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
imgaug, TestDataSpeed,
MultiProcessMapDataZMQ, MultiThreadMapData,
MapDataComponent, DataFromList)
DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
from tensorpack.utils import logger
# import tensorpack.utils.viz as tpviz
from tensorpack.utils.argtools import log_once, memoized
from coco import COCODetection
from common import (
CustomResize, DataFromListOfDict, box_to_point8, filter_boxes_inside_shape, point8_to_box, segmentation_to_mask)
from config import config as cfg
from utils.generate_anchors import generate_anchors
from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa
from common import (
DataFromListOfDict, CustomResize, filter_boxes_inside_shape,
box_to_point8, point8_to_box, segmentation_to_mask)
from config import config as cfg
# import tensorpack.utils.viz as tpviz
try:
import pycocotools.mask as cocomask
......
# -*- coding: utf-8 -*-
# File: eval.py
import tqdm
import itertools
import numpy as np
import os
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
import itertools
import numpy as np
import cv2
from concurrent.futures import ThreadPoolExecutor
from tensorpack.utils.utils import get_tqdm_kwargs
import pycocotools.mask as cocomask
import tqdm
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import pycocotools.mask as cocomask
from tensorpack.utils.utils import get_tqdm_kwargs
from coco import COCOMeta
from common import CustomResize, clip_boxes
......
......@@ -2,10 +2,10 @@ import tensorflow as tf
from tensorpack.tfutils import get_current_tower_context
from utils.box_ops import pairwise_iou
from model_box import clip_boxes
from model_frcnn import FastRCNNHead, BoxProposals, fastrcnn_outputs
from config import config as cfg
from model_box import clip_boxes
from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
from utils.box_ops import pairwise_iou
class CascadeRCNNHead(object):
......
# -*- coding: utf-8 -*-
import itertools
import numpy as np
import tensorflow as tf
import itertools
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.models import Conv2D, FixedUnPooling, MaxPooling, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.tower import get_current_tower_context
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, layer_register, FixedUnPooling, MaxPooling)
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import get_current_tower_context
from model_rpn import rpn_losses, generate_rpn_proposals
from basemodel import GroupNorm
from config import config as cfg
from model_box import roi_align
from model_rpn import generate_rpn_proposals, rpn_losses
from utils.box_ops import area as tf_area
from config import config as cfg
from basemodel import GroupNorm
@layer_register(log_shape=True)
......
......@@ -3,18 +3,17 @@
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.models import Conv2D, FullyConnected, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, FullyConnected, layer_register)
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized_method
from basemodel import GroupNorm
from utils.box_ops import pairwise_iou
from model_box import encode_bbox_target, decode_bbox_target
from config import config as cfg
from model_box import decode_bbox_target, encode_bbox_target
from utils.box_ops import pairwise_iou
@under_name_scope()
......
......@@ -2,12 +2,11 @@
import tensorflow as tf
from tensorpack.models import (
Conv2D, layer_register, Conv2DTranspose)
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import Conv2D, Conv2DTranspose, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from basemodel import GroupNorm
from config import config as cfg
......
......@@ -2,14 +2,13 @@
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
from tensorpack.models import Conv2D, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils.summary import add_moving_summary
from model_box import clip_boxes
from config import config as cfg
from model_box import clip_boxes
@layer_register(log_shape=True)
......
......@@ -2,58 +2,45 @@
# -*- coding: utf-8 -*-
# File: train.py
import os
import argparse
import cv2
import shutil
import itertools
import tqdm
import numpy as np
import json
import numpy as np
import os
import shutil
import cv2
import six
import tensorflow as tf
try:
import horovod.tensorflow as hvd
except ImportError:
pass
assert six.PY3, "FasterRCNN requires Python 3!"
import tqdm
import tensorpack.utils.viz as tpviz
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.common import get_tf_version_tuple
import tensorpack.utils.viz as tpviz
from coco import COCODetection
from basemodel import (
image_preprocess, resnet_c4_backbone, resnet_conv5,
resnet_fpn_backbone)
from tensorpack.tfutils.summary import add_moving_summary
import model_frcnn
import model_mrcnn
from model_frcnn import (
sample_fast_rcnn_targets, fastrcnn_outputs,
fastrcnn_predictions, BoxProposals, FastRCNNHead)
from model_mrcnn import maskrcnn_upXconv_head, maskrcnn_loss
from model_rpn import rpn_head, rpn_losses, generate_rpn_proposals
from model_fpn import (
fpn_model, multilevel_roi_align,
multilevel_rpn_losses, generate_fpn_proposals)
from basemodel import image_preprocess, resnet_c4_backbone, resnet_conv5, resnet_fpn_backbone
from coco import COCODetection
from config import config as cfg
from config import finalize_configs
from data import get_all_anchors, get_all_anchors_fpn, get_eval_dataflow, get_train_dataflow
from eval import DetectionResult, detect_one_image, eval_coco, multithread_eval_coco, print_coco_metrics
from model_box import RPNAnchors, clip_boxes, crop_and_resize, roi_align
from model_cascade import CascadeRCNNHead
from model_box import (
clip_boxes, crop_and_resize, roi_align, RPNAnchors)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors, get_all_anchors_fpn)
from viz import (
draw_annotation, draw_proposal_recall,
draw_predictions, draw_final_outputs)
from eval import (
eval_coco, multithread_eval_coco,
detect_one_image, print_coco_metrics, DetectionResult)
from config import finalize_configs, config as cfg
from model_fpn import fpn_model, generate_fpn_proposals, multilevel_roi_align, multilevel_rpn_losses
from model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs, fastrcnn_predictions, sample_fast_rcnn_targets
from model_mrcnn import maskrcnn_loss, maskrcnn_upXconv_head
from model_rpn import generate_rpn_proposals, rpn_head, rpn_losses
from viz import draw_annotation, draw_final_outputs, draw_predictions, draw_proposal_recall
try:
import horovod.tensorflow as hvd
except ImportError:
pass
assert six.PY3, "FasterRCNN requires Python 3!"
class DetectionModel(ModelDesc):
......
......@@ -2,8 +2,10 @@
# File: box_ops.py
import tensorflow as tf
from tensorpack.tfutils.scope_utils import under_name_scope
"""
This file is modified from
https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py
......
......@@ -7,8 +7,8 @@
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
from six.moves import range
import numpy as np
from six.moves import range
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
......@@ -27,7 +27,7 @@ import numpy as np
# -79 -167 96 184
# -167 -343 184 360
#array([[ -83., -39., 100., 56.],
# array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
......@@ -37,6 +37,7 @@ import numpy as np
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)):
"""
......@@ -50,6 +51,7 @@ def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
for i in range(ratio_anchors.shape[0])])
return anchors
def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
......@@ -61,6 +63,7 @@ def _whctrs(anchor):
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
......@@ -75,6 +78,7 @@ def _mkanchors(ws, hs, x_ctr, y_ctr):
y_ctr + 0.5 * (hs - 1)))
return anchors
def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
......@@ -88,6 +92,7 @@ def _ratio_enum(anchor, ratios):
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
......@@ -98,17 +103,3 @@ def _scale_enum(anchor, scales):
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
if __name__ == '__main__':
#import time
#t = time.time()
#a = generate_anchors()
#print(time.time() - t)
#print(a)
#from IPython import embed; embed()
anchors = generate_anchors(
16, scales=np.asarray((2, 4, 8, 16, 32), 'float32'),
ratios=[0.5,1,2])
print(anchors)
import IPython as IP; IP.embed()
# -*- coding: utf-8 -*-
# File: viz.py
from six.moves import zip
import numpy as np
from six.moves import zip
from tensorpack.utils import viz
from tensorpack.utils.palette import PALETTE_RGB
from utils.np_box_ops import iou as np_iou
from config import config as cfg
from utils.np_box_ops import iou as np_iou
def draw_annotation(img, boxes, klass, is_crowd=None):
......
......@@ -3,12 +3,14 @@
# File: BEGAN.py
# Author: Yuxin Wu
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
import DCGAN
from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer
"""
......@@ -19,7 +21,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/
"""
import DCGAN
NH = 64
NF = 64
GAMMA = 0.5
......
......@@ -3,18 +3,18 @@
# File: ConditionalGAN-mnist.py
# Author: Yuxin Wu
import argparse
import numpy as np
import tensorflow as tf
import os
import cv2
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.utils.viz import interactive_imshow, stack_patches
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.dataflow import dataset
from GAN import GANTrainer, RandomZData, GANModelDesc
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.utils.viz import interactive_imshow, stack_patches
from GAN import GANModelDesc, GANTrainer, RandomZData
"""
To train:
......
......@@ -3,17 +3,17 @@
# File: CycleGAN.py
# Author: Yuxin Wu
import os
import argparse
import glob
import os
import tensorflow as tf
from six.moves import range
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANTrainer, GANModelDesc
from tensorpack.tfutils.summary import add_moving_summary
from GAN import GANModelDesc, GANTrainer
"""
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
......
......@@ -3,18 +3,18 @@
# File: DCGAN.py
# Author: Yuxin Wu
import argparse
import glob
import numpy as np
import os
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from tensorpack.utils.viz import stack_patches
from GAN import GANModelDesc, GANTrainer, RandomZData
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
1. Download the 'aligned&cropped' version of CelebA dataset
......
......@@ -3,17 +3,17 @@
# File: DiscoGAN-CelebA.py
# Author: Yuxin Wu
import os
import argparse
from six.moves import map, zip
import numpy as np
import os
import tensorflow as tf
from six.moves import map, zip
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import SeparateGANTrainer, GANModelDesc
from tensorpack.tfutils.summary import add_moving_summary
from GAN import GANModelDesc, SeparateGANTrainer
"""
1. Download "aligned&cropped" version of celebA to /path/to/img_align_celeba.
......
......@@ -2,13 +2,13 @@
# File: GAN.py
# Author: Yuxin Wu
import tensorflow as tf
import numpy as np
from tensorpack import (TowerTrainer, StagingInput,
ModelDescBase, DataFlow, argscope, BatchNorm)
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
import tensorflow as tf
from tensorpack import BatchNorm, DataFlow, ModelDescBase, StagingInput, TowerTrainer, argscope
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized_method
from tensorpack.utils.develop import deprecated
......
......@@ -3,20 +3,20 @@
# File: Image2Image.py
# Author: Yuxin Wu
import cv2
import numpy as np
import tensorflow as tf
import argparse
import glob
import numpy as np
import os
import argparse
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, GANModelDesc
from GAN import GANModelDesc, GANTrainer
"""
To train Image-to-Image translation model with image pairs:
......
......@@ -3,12 +3,14 @@
# File: Improved-WGAN.py
# Author: Yuxin Wu
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils import get_tf_version_tuple
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary
import DCGAN
from GAN import SeparateGANTrainer
"""
......@@ -18,7 +20,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN.
import DCGAN
class Model(DCGAN.Model):
......
......@@ -3,19 +3,19 @@
# File: InfoGAN-mnist.py
# Author: Yuxin Wu
import cv2
import argparse
import numpy as np
import tensorflow as tf
import os
import argparse
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.utils import viz
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils import optimizer, summary, gradproc
from tensorpack.dataflow import dataset
from GAN import GANTrainer, GANModelDesc
from tensorpack.tfutils import gradproc, optimizer, summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.utils import viz
from GAN import GANModelDesc, GANTrainer
"""
To train:
......
......@@ -3,9 +3,12 @@
# File: WGAN.py
# Author: Yuxin Wu
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
import tensorflow as tf
import DCGAN
from GAN import SeparateGANTrainer
"""
......@@ -15,7 +18,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN
import DCGAN
class Model(DCGAN.Model):
......
......@@ -3,19 +3,18 @@
# File: hed.py
# Author: Yuxin Wu
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf
import numpy as np
import argparse
from six.moves import zip
import os
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils import optimizer, gradproc
from tensorpack.tfutils import gradproc, optimizer
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.utils.gpu import get_num_gpu
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
......
......@@ -3,10 +3,9 @@
# File: alexnet.py
import argparse
import numpy as np
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorpack import *
......
......@@ -2,24 +2,22 @@
# File: imagenet_utils.py
import cv2
import os
import numpy as np
import tqdm
import multiprocessing
import tensorflow as tf
import numpy as np
import os
from abc import abstractmethod
import cv2
import tensorflow as tf
import tqdm
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.dataflow import (
imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ,
BatchData, MultiThreadMapData)
from tensorpack.predict import PredictConfig, FeedfreePredictor
from tensorpack.utils.stats import RatioCounter
from tensorpack.models import regularize_cost
from tensorpack.predict import FeedfreePredictor, PredictConfig
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.stats import RatioCounter
"""
......
......@@ -7,10 +7,9 @@ import argparse
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow
......
......@@ -3,24 +3,20 @@
# File: shufflenet.py
import argparse
import numpy as np
import math
import numpy as np
import os
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader, model_utils
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
get_imagenet_dataflow,
ImageNetModel, GoogleNetResize, eval_on_ILSVRC12)
from imagenet_utils import GoogleNetResize, ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
@layer_register(log_shape=True)
......
......@@ -4,7 +4,6 @@
import argparse
import os
import tensorflow as tf
from tensorpack import *
......@@ -12,8 +11,7 @@ from tensorpack.tfutils import argscope
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
ImageNetModel, get_imagenet_dataflow, fbresnet_augmentor)
from imagenet_utils import ImageNetModel, fbresnet_augmentor, get_imagenet_dataflow
def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
......
......@@ -2,16 +2,16 @@
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com>
import argparse
import glob
import os
import cv2
import glob
from helper import Flow
import argparse
from tensorpack import *
from tensorpack.utils import viz
import flownet_models as models
from helper import Flow
def apply(model, model_path, left, right, ground_truth=None):
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from tensorpack import ModelDesc, argscope, enable_argscope_for_module
enable_argscope_for_module(tf.layers)
......
......@@ -3,21 +3,20 @@
# File: PTB-LSTM.py
# Author: Yuxin Wu
import argparse
import numpy as np
import os
import argparse
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils import optimizer, summary, gradproc
from tensorpack.tfutils import gradproc, optimizer, summary
from tensorpack.utils import logger
from tensorpack.utils.fs import download, get_dataset_path
from tensorpack.utils.argtools import memoized_ignoreargs
from tensorpack.utils.fs import download, get_dataset_path
import reader as tfreader
from reader import ptb_producer
import tensorflow as tf
rnn = tf.contrib.rnn
SEQ_LEN = 35
......
......@@ -16,13 +16,9 @@
"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import collections
import os
import tensorflow as tf
......
......@@ -3,14 +3,14 @@
# File: cifar10-preact18-mixup.py
# Author: Tao Hu <taohu620@gmail.com>, Yauheni Selivonchyk <y.selivonchyk@gmail.com>
import numpy as np
import argparse
import numpy as np
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import *
BATCH_SIZE = 128
CLASS_NUM = 10
......
......@@ -5,14 +5,12 @@
import argparse
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.dataflow import dataset
import tensorflow as tf
"""
CIFAR10 ResNet example. See:
......
......@@ -5,22 +5,18 @@
import argparse
import os
from tensorpack import logger, QueueInput, TFDatasetInput
from tensorpack.models import *
from tensorpack import QueueInput, TFDatasetInput, logger
from tensorpack.callbacks import *
from tensorpack.train import (
TrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config)
from tensorpack.dataflow import FakeData
from tensorpack.models import *
from tensorpack.tfutils import argscope, get_model_loader
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
get_imagenet_dataflow, get_imagenet_tfdata,
ImageNetModel, eval_on_ILSVRC12)
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow, get_imagenet_tfdata
from resnet_model import (
preresnet_group, preresnet_basicblock, preresnet_bottleneck,
resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
resnet_backbone)
preresnet_basicblock, preresnet_bottleneck, preresnet_group, resnet_backbone, resnet_basicblock, resnet_bottleneck,
resnet_group, se_resnet_bottleneck)
class Model(ImageNetModel):
......
......@@ -4,20 +4,20 @@
# Author: Eric Yujia Huang <yujiah1@andrew.cmu.edu>
# Yuxin Wu
import cv2
import functools
import tensorflow as tf
import argparse
import re
import functools
import numpy as np
import re
import cv2
import six
import tensorflow as tf
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.dataflow.dataset import ILSVRCMeta
from tensorpack.utils import logger
from imagenet_utils import eval_on_ILSVRC12, get_imagenet_dataflow, ImageNetModel
from resnet_model import resnet_group, resnet_bottleneck
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from resnet_model import resnet_bottleneck, resnet_group
DEPTH = None
CFG = {
......
......@@ -3,9 +3,8 @@
import tensorflow as tf
from tensorpack.models import BatchNorm, BNReLU, Conv2D, FullyConnected, GlobalAvgPooling, MaxPooling
from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.models import (
Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected)
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
......
......@@ -2,28 +2,24 @@
# -*- coding: utf-8 -*-
# File: CAM-resnet.py
import cv2
import sys
import argparse
import multiprocessing
import numpy as np
import os
import multiprocessing
import sys
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer, gradproc
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils import gradproc, optimizer
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.utils import viz
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import (
fbresnet_augmentor, ImageNetModel)
from resnet_model import (
preresnet_basicblock, preresnet_group)
from imagenet_utils import ImageNetModel, fbresnet_augmentor
from resnet_model import preresnet_basicblock, preresnet_group
TOTAL_BATCH_SIZE = 256
DEPTH = None
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import cv2
import numpy as np
import sys
from contextlib import contextmanager
import numpy as np
import cv2
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
......
......@@ -3,7 +3,8 @@
# Author: tensorpack contributors
import numpy as np
from tensorpack.dataflow import dataset, BatchData
from tensorpack.dataflow import BatchData, dataset
def get_test_data(batch=128):
......
......@@ -2,17 +2,16 @@
# -*- coding: utf-8 -*-
# File: mnist-embeddings.py
import numpy as np
import argparse
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import change_gpu
from embedding_data import get_test_data, MnistPairs, MnistTriplets
from embedding_data import MnistPairs, MnistTriplets, get_test_data
MATPLOTLIB_AVAIBLABLE = False
try:
......
......@@ -3,16 +3,15 @@
# File: mnist-addition.py
# Author: Yuxin Wu
import cv2
import argparse
import numpy as np
import tensorflow as tf
import os
import argparse
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer, summary, gradproc
from tensorpack.tfutils import gradproc, optimizer, summary
IMAGE_SIZE = 42
WARP_TARGET_SIZE = 28
......
import cv2
import os
import argparse
import numpy as np
import os
import zipfile
from tensorpack import RNGDataFlow, MapDataComponent, LMDBSerializer
import cv2
from tensorpack import LMDBSerializer, MapDataComponent, RNGDataFlow
class ImageDataFromZIPFile(RNGDataFlow):
......
......@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com>
import os
import argparse
import numpy as np
import os
import cv2
import six
import numpy as np
import tensorflow as tf
from tensorpack import *
......@@ -14,10 +14,10 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from data_sampler import (
ImageDecode, ImageDataFromZIPFile,
RejectTooSmallImages, CenterSquareResize)
from GAN import SeparateGANTrainer, GANModelDesc
from data_sampler import CenterSquareResize, ImageDataFromZIPFile, ImageDecode, RejectTooSmallImages
from GAN import GANModelDesc, SeparateGANTrainer
Reduction = tf.losses.Reduction
BATCH_SIZE = 16
......
......@@ -2,15 +2,16 @@
# -*- coding: utf-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu
import tensorflow as tf
import argparse
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_num_gpu
"""
A small convnet model for Cifar10 or Cifar100 dataset.
......
......@@ -4,6 +4,7 @@
import argparse
import cv2
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.export import ModelExporter
......
......@@ -3,17 +3,16 @@
# File: mnist-convnet.py
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import summary
"""
MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28
......
......@@ -3,6 +3,11 @@
# File: mnist-tflayers.py
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils import get_current_tower_context, summary
"""
MNIST ConvNet example using tf.layers
Mostly the same as 'mnist-convnet.py',
......@@ -11,12 +16,6 @@ the only differences are:
2. use tf.layers variable names to summarize weights
"""
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary, get_current_tower_context
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28
# Monkey-patch tf.layers to support argscope.
enable_argscope_for_module(tf.layers)
......
......@@ -11,11 +11,12 @@ the only differences are:
"""
from tensorpack import *
from tensorpack.dataflow import dataset
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorpack import *
from tensorpack.dataflow import dataset
IMAGE_SIZE = 28
......
......@@ -7,6 +7,7 @@ The same MNIST ConvNet example, but with weights/activations visualization.
"""
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
......
......@@ -5,12 +5,12 @@
import argparse
import os
import tensorflow as tf
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import *
import tensorflow as tf
"""
A very small SVHN convnet model (only 0.8m parameters).
......
# -*- coding: utf-8 -*-
# Author: Your Name <your@email.com>
import os
import argparse
import os
import tensorflow as tf
from tensorpack import *
......
......@@ -3,22 +3,21 @@
# File: imagenet-resnet-keras.py
# Author: Yuxin Wu
import argparse
import numpy as np
import os
import tensorflow as tf
import argparse
from tensorflow.python.keras.layers import *
from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated
from tensorpack.callbacks import *
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import FakeData, MapDataComponent
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import *
from tensorflow.python.keras.layers import *
from tensorpack.tfutils.common import get_tf_version_tuple
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow
TOTAL_BATCH_SIZE = 512
BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256)
......
......@@ -5,17 +5,15 @@
import numpy as np
import tensorflow as tf
from tensorflow import keras
KL = keras.layers
from tensorpack import InputDesc, QueueInput
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import ModelSaver
from tensorpack.contrib.keras import KerasModel
from tensorpack.dataflow import BatchData, MapData, dataset
from tensorpack.utils import logger
KL = keras.layers
IMAGE_SIZE = 28
......
......@@ -5,6 +5,12 @@
import tensorflow as tf
from tensorflow import keras
from tensorpack import *
from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
KL = keras.layers
"""
......@@ -14,12 +20,6 @@ This way you can define models in Keras-style, and benefit from the more efficei
Note: this example does not work for replicated-style data-parallel trainers.
"""
from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
from tensorpack.contrib.keras import KerasPhaseCallback
IMAGE_SIZE = 28
......
......@@ -3,11 +3,11 @@
# File: checkpoint-manipulate.py
import argparse
import numpy as np
from tensorpack.tfutils.varmanip import load_chkpt_vars
from tensorpack.utils import logger
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......
......@@ -2,12 +2,13 @@
# -*- coding: utf-8 -*-
# File: checkpoint-prof.py
import tensorflow as tf
import argparse
import numpy as np
import tensorflow as tf
from tensorpack import get_default_sess_config, get_op_tensor_name
from tensorpack.utils import logger
from tensorpack.tfutils.sessinit import get_model_loader
import argparse
from tensorpack.utils import logger
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......
......@@ -2,10 +2,10 @@
# -*- coding: utf-8 -*-
# File: dump-model-params.py
import numpy as np
import six
import argparse
import numpy as np
import os
import six
import tensorflow as tf
from tensorpack.tfutils import varmanip
......
......@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py
import tensorflow as tf
import numpy as np
import six
import sys
import pprint
import sys
import six
import tensorflow as tf
from tensorpack.tfutils.varmanip import get_checkpoint_path
......
import platform
from os import path
import setuptools
from setuptools import setup
from os import path
import platform
version = int(setuptools.__version__.split('.')[0])
assert version > 30, "Tensorpack installation requires setuptools > 30"
......@@ -24,7 +24,7 @@ def add_git_version():
from subprocess import check_output
try:
return check_output("git describe --tags --long --dirty".split()).decode('utf-8').strip()
except:
except Exception:
return __version__ # noqa
newlibinfo_content = [l for l in libinfo_content if not l.startswith('__git_version__')]
......
......@@ -2,9 +2,10 @@
# File: base.py
import tensorflow as tf
from abc import ABCMeta
import six
import tensorflow as tf
from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory']
......
......@@ -3,9 +3,10 @@
import multiprocessing as mp
from .base import Callback
from ..utils.concurrency import start_proc_mask_signal, StoppableThread
from ..utils import logger
from ..utils.concurrency import StoppableThread, start_proc_mask_signal
from .base import Callback
__all__ = ['StartProcOrThread']
......
......@@ -4,14 +4,14 @@
""" Graph related callbacks"""
import tensorflow as tf
import os
import numpy as np
import os
import tensorflow as tf
from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from .base import Callback
from ..tfutils.common import get_op_tensor_name
__all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors',
'DumpTensor', 'DumpTensorAsImage', 'DumpParamAsImage']
......
......@@ -2,16 +2,16 @@
# File: group.py
import tensorflow as tf
import traceback
from contextlib import contextmanager
from time import time as timer
import traceback
import six
import tensorflow as tf
from .base import Callback
from .hooks import CallbackToHook
from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback
from .hooks import CallbackToHook
if six.PY3:
from time import perf_counter as timer # noqa
......
......@@ -5,6 +5,7 @@
""" Compatible layers between tf.train.SessionRunHook and Callback"""
import tensorflow as tf
from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback']
......
......@@ -7,10 +7,10 @@ from abc import ABCMeta
import six
from six.moves import zip
from .base import Callback
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.stats import BinaryStatistics, RatioCounter
from .base import Callback
__all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats']
......
......@@ -2,24 +2,19 @@
# File: inference_runner.py
import sys
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
import itertools
import sys
from contextlib import contextmanager
import tensorflow as tf
import tqdm
from six.moves import range
from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..dataflow.base import DataFlow
from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..tfutils.tower import PredictTowerContext
from ..input_source import (
InputSource, FeedInput, QueueInput, StagingInput)
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from .base import Callback
from .group import Callbacks
from .inference import Inferencer
......
......@@ -2,14 +2,14 @@
# File: misc.py
import numpy as np
import os
import time
from collections import deque
import numpy as np
from .base import Callback
from ..utils.utils import humanize_time_delta
from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback
__all__ = ['SendStat', 'InjectShell', 'EstimatedTimeLeft']
......
......@@ -2,20 +2,20 @@
# File: monitor.py
import os
import json
import numpy as np
import operator
import os
import re
import shutil
import time
from datetime import datetime
import operator
from collections import defaultdict
from datetime import datetime
import six
import json
import re
import tensorflow as tf
from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger
from ..tfutils.summary import create_scalar_summary, create_image_summary
from ..utils.develop import HIDE_DOC
from .base import Callback
......
......@@ -2,16 +2,16 @@
# File: param.py
import tensorflow as tf
from collections import deque
from abc import abstractmethod, ABCMeta
import operator
import six
import os
from abc import ABCMeta, abstractmethod
from collections import deque
import six
import tensorflow as tf
from .base import Callback
from ..utils import logger
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from .base import Callback
__all__ = ['HyperParam', 'GraphVarParam', 'ObjAttrParam',
'HyperParamSetter', 'HumanHyperParamSetter',
......
......@@ -2,20 +2,20 @@
# File: prof.py
import os
import numpy as np
import multiprocessing as mp
import numpy as np
import os
import time
from six.moves import map
import tensorflow as tf
from six.moves import map
from tensorflow.python.client import timeline
from .base import Callback
from ..tfutils.common import gpu_available_in_session
from ..utils import logger
from ..utils.concurrency import ensure_proc_terminate, start_proc_mask_signal
from ..utils.gpu import get_num_gpu
from ..utils.nvml import NVMLContext
from ..tfutils.common import gpu_available_in_session
from .base import Callback
__all__ = ['GPUUtilizationTracker', 'GraphProfiler', 'PeakMemoryTracker']
......
......@@ -2,12 +2,12 @@
# File: saver.py
import tensorflow as tf
from datetime import datetime
import os
from datetime import datetime
import tensorflow as tf
from .base import Callback
from ..utils import logger
from .base import Callback
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......
# -*- coding: utf-8 -*-
# File: stats.py
from .graph import DumpParamAsImage # noqa
# for compatibility only
from .misc import InjectShell, SendStat # noqa
from .graph import DumpParamAsImage # noqa
__all__ = []
......@@ -4,14 +4,13 @@
""" Some common step callbacks. """
import tensorflow as tf
from six.moves import zip
import tqdm
from six.moves import zip
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..tfutils.common import (
get_op_tensor_name, get_global_step_var)
from ..utils.utils import get_tqdm_kwargs
from .base import Callback
__all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout']
......
......@@ -2,9 +2,9 @@
# File: summary.py
import tensorflow as tf
import numpy as np
from collections import deque
import tensorflow as tf
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
......
......@@ -2,8 +2,8 @@
# File: trigger.py
from .base import ProxyCallback, Callback
from ..utils.develop import log_deprecated
from .base import Callback, ProxyCallback
__all__ = ['PeriodicTrigger', 'PeriodicCallback', 'EnableCallbackIf']
......
# -*- coding: utf-8 -*-
# File: keras.py
import tensorflow as tf
from contextlib import contextmanager
import six
from tensorflow import keras
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from contextlib import contextmanager
from ..callbacks import Callback, CallbackToHook, InferenceRunner, InferenceRunnerBase, ScalarStats
from ..models.regularize import regularize_cost_from_collection
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase
from ..train.interface import apply_default_prefetch
from ..callbacks import (
Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook,
ScalarStats)
from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_op_tensor_name
from ..tfutils.scope_utils import cached_name_scope
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..tfutils.tower import get_current_tower_context
from ..train import SimpleTrainer, SyncMultiGPUTrainerParameterServer, Trainer
from ..train.interface import apply_default_prefetch
from ..train.trainers import DistributedTrainerBase
from ..utils import logger
from ..utils.gpu import get_nr_gpu
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
......
......@@ -3,8 +3,9 @@
import threading
from abc import abstractmethod, ABCMeta
from abc import ABCMeta, abstractmethod
import six
from ..utils.utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
......
......@@ -2,20 +2,20 @@
# File: common.py
from __future__ import division
import six
import itertools
import numpy as np
from copy import copy
import pprint
import itertools
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from collections import defaultdict, deque
from copy import copy
import six
import tqdm
from six.moves import map, range
from termcolor import colored
from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm, get_rng, get_tqdm_kwargs
from ..utils.develop import log_deprecated
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData',
......
......@@ -2,9 +2,9 @@
# File: bsds500.py
import os
import glob
import numpy as np
import os
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow
......
......@@ -3,9 +3,9 @@
# Yukun Chen <cykustc@gmail.com>
import numpy as np
import os
import pickle
import numpy as np
import tarfile
import six
from six.moves import range
......
# -*- coding: utf-8 -*-
# File: ilsvrc.py
import numpy as np
import os
import tarfile
import numpy as np
import tqdm
from ...utils import logger
from ...utils.fs import download, get_dataset_path, mkdir_p
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download, get_dataset_path
from ...utils.timer import timed_operation
from ..base import RNGDataFlow
......
......@@ -2,9 +2,9 @@
# File: mnist.py
import os
import gzip
import numpy
import os
from six.moves import range
from ...utils import logger
......
......@@ -2,11 +2,11 @@
# File: svhn.py
import os
import numpy as np
import os
from ...utils import logger
from ...utils.fs import get_dataset_path, download
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow
__all__ = ['SVHNDigit']
......
......@@ -3,15 +3,13 @@
from ..utils.develop import deprecated
from .remote import dump_dataflow_to_process_queue
from .serialize import LMDBSerializer, TFRecordSerializer
__all__ = ['dump_dataflow_to_process_queue',
'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']
from .remote import dump_dataflow_to_process_queue
@deprecated("Use LMDBSerializer.save instead!", "2019-01-31")
def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
LMDBSerializer.save(df, lmdb_path, write_frequency)
......
......@@ -3,18 +3,19 @@
import numpy as np
import os
import six
from six.moves import range
import os
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.compatible_serialize import loads
from ..utils.argtools import log_once
from ..utils.compatible_serialize import loads
from ..utils.develop import create_dummy_class # noqa
from ..utils.develop import log_deprecated
from .base import RNGDataFlow, DataFlow, DataFlowReentrantGuard
from ..utils.loadcaffe import get_caffe_pb
from ..utils.timer import timed_operation
from ..utils.utils import get_tqdm
from .base import DataFlow, DataFlowReentrantGuard, RNGDataFlow
from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
......@@ -258,7 +259,7 @@ class TFRecordData(DataFlow):
for dp in gen:
yield loads(dp)
from ..utils.develop import create_dummy_class # noqa
try:
import h5py
except ImportError:
......
......@@ -2,13 +2,14 @@
# File: image.py
import numpy as np
import copy as copy_mod
import numpy as np
from contextlib import contextmanager
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from ..utils import logger
from ..utils.argtools import shape2d
from .base import RNGDataFlow
from .common import MapData, MapDataComponent
__all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates', 'AugmentImageComponents']
......
......@@ -4,13 +4,13 @@
import sys
import cv2
from . import AugmentorList
from .crop import *
from .imgproc import *
from .noname import *
from .deform import *
from .imgproc import *
from .noise import SaltPepperNoise
from .noname import *
anchors = [(0.2, 0.2), (0.7, 0.2), (0.8, 0.8), (0.5, 0.5), (0.2, 0.5)]
augmentors = AugmentorList([
......
......@@ -4,12 +4,12 @@
import inspect
import pprint
from abc import abstractmethod, ABCMeta
from abc import ABCMeta, abstractmethod
import six
from six.moves import zip
from ...utils.utils import get_rng
from ...utils.argtools import log_once
from ...utils.utils import get_rng
from ..image import check_dtype
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
......
# -*- coding: utf-8 -*-
# File: convert.py
from .base import ImageAugmentor
from .meta import MapImage
import numpy as np
import cv2
from .base import ImageAugmentor
from .meta import MapImage
__all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
......
......@@ -3,8 +3,7 @@
from ...utils.argtools import shape2d
from .transform import TransformAugmentorBase, CropTransform
from .transform import CropTransform, TransformAugmentorBase
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape']
......
......@@ -2,10 +2,11 @@
# File: deform.py
from .base import ImageAugmentor
from ...utils import logger
import numpy as np
from ...utils import logger
from .base import ImageAugmentor
__all__ = []
# Code was temporarily kept here for a future reference in case someone needs it
......
......@@ -4,7 +4,6 @@ import numpy as np
from .base import ImageAugmentor
__all__ = ['IAAugmentor', 'Albumentations']
......
......@@ -3,8 +3,8 @@
import math
import cv2
import numpy as np
import cv2
from .base import ImageAugmentor
from .transform import TransformAugmentorBase, WarpAffineTransform
......
......@@ -2,10 +2,11 @@
# File: imgproc.py
from .base import ImageAugmentor
import numpy as np
import cv2
from .base import ImageAugmentor
__all__ = ['Hue', 'Brightness', 'BrightnessScale', 'Contrast', 'MeanVarianceNormalize',
'GaussianBlur', 'Gamma', 'Clip', 'Saturation', 'Lighting', 'MinMaxNormalize']
......
......@@ -5,9 +5,9 @@
import numpy as np
import cv2
from .base import ImageAugmentor
from ...utils import logger
from ...utils.argtools import shape2d
from .base import ImageAugmentor
from .transform import ResizeTransform, TransformAugmentorBase
__all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge', 'Transpose']
......
......@@ -2,10 +2,11 @@
# File: noise.py
from .base import ImageAugmentor
import numpy as np
import cv2
from .base import ImageAugmentor
__all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
......
......@@ -2,10 +2,10 @@
# File: paste.py
from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np
from abc import abstractmethod
from .base import ImageAugmentor
__all__ = ['CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller',
'RandomPaste']
......
# -*- coding: utf-8 -*-
# File: transform.py
from abc import abstractmethod, ABCMeta
import six
import cv2
import numpy as np
from abc import ABCMeta, abstractmethod
import cv2
import six
from .base import ImageAugmentor
......
# -*- coding: utf-8 -*-
# File: parallel.py
import atexit
import errno
import itertools
import multiprocessing as mp
import os
import sys
import uuid
import weakref
from contextlib import contextmanager
import multiprocessing as mp
import itertools
from six.moves import range, zip, queue
import errno
import uuid
import os
import zmq
import atexit
from six.moves import queue, range, zip
from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal,
enable_death_signal,
StoppableThread)
from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, mask_sigint, start_proc_mask_signal)
from ..utils.develop import log_deprecated
from ..utils.gpu import change_gpu
from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow
__all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'PrefetchDataZMQ', 'PrefetchOnGPUs', 'MultiThreadPrefetchData']
......
# -*- coding: utf-8 -*-
# File: parallel_map.py
import numpy as np
import ctypes
import copy
import threading
import ctypes
import multiprocessing as mp
from six.moves import queue
import numpy as np
import threading
import zmq
from six.moves import queue
from .base import DataFlow, ProxyDataFlow, DataFlowReentrantGuard
from .common import RepeatedData
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import loads, dumps
from .parallel import (
_MultiProcessZMQDataFlow, _repeat_iter, _get_pipe_name,
_bind_guard, _zmq_catch_error)
from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
__all__ = ['ThreadedMapData', 'MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
......
......@@ -2,10 +2,11 @@
# File: raw.py
import numpy as np
import copy
import numpy as np
import six
from six.moves import range
from .base import DataFlow, RNGDataFlow
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList', 'DataFromGenerator', 'DataFromIterable']
......
......@@ -2,17 +2,17 @@
# File: remote.py
import multiprocessing as mp
import time
from collections import deque
import tqdm
import multiprocessing as mp
from six.moves import range
from collections import deque
from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.concurrency import DIE
from ..utils.serialize import dumps, loads
from ..utils.utils import get_tqdm_kwargs
from .base import DataFlow, DataFlowReentrantGuard
try:
import zmq
......
# -*- coding: utf-8 -*-
# File: serialize.py
import os
import numpy as np
import os
from collections import defaultdict
from ..utils.utils import get_tqdm
from ..utils import logger
from ..utils.compatible_serialize import dumps, loads
from ..utils.develop import create_dummy_class # noqa
from ..utils.utils import get_tqdm
from .base import DataFlow
from .format import LMDBData, HDF5Data
from .common import MapData, FixedSizeData
from .raw import DataFromList, DataFromGenerator
from .common import FixedSizeData, MapData
from .format import HDF5Data, LMDBData
from .raw import DataFromGenerator, DataFromList
__all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer']
......@@ -195,7 +195,6 @@ class HDF5Serializer():
return HDF5Data(path, data_paths, shuffle)
from ..utils.develop import create_dummy_class # noqa
try:
import lmdb
except ImportError:
......
# -*- coding: utf-8 -*-
# File: distributed.py
import tensorflow as tf
import re
import tensorflow as tf
from six.moves import range
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger
from ..utils.argtools import memoized
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .training import GraphBuilder, DataParallelBuilder
from .utils import (
override_to_local_variable, aggregate_grads,
OverrideCachingDevice)
from .training import DataParallelBuilder, GraphBuilder
from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
......
......@@ -5,11 +5,11 @@
from collections import namedtuple
import tensorflow as tf
from ..models.regularize import regularize_cost_from_collection
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import memoized_method
from ..utils.develop import log_deprecated
from ..tfutils.tower import get_current_tower_context
from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
......
......@@ -3,9 +3,9 @@
import tensorflow as tf
from ..tfutils.tower import PredictTowerContext
from ..utils import logger
from ..utils.develop import deprecated
from ..tfutils.tower import PredictTowerContext
from .training import GraphBuilder
__all__ = ['SimplePredictBuilder']
......
# -*- coding: utf-8 -*-
# File: training.py
from abc import ABCMeta, abstractmethod
import tensorflow as tf
import copy
import six
import re
import pprint
from six.moves import zip, range
import re
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
import six
import tensorflow as tf
from six.moves import range, zip
from ..utils import logger
from ..tfutils.tower import TrainTowerContext
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.tower import TrainTowerContext
from ..utils import logger
from .utils import (
LeastLoadedDeviceSetter, override_to_local_variable,
allreduce_grads, aggregate_grads, allreduce_grads_hierarchical,
split_grad_list, merge_grad_list, GradientPacker)
GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
merge_grad_list, override_to_local_variable, split_grad_list)
__all__ = ['GraphBuilder',
'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
......
......@@ -2,16 +2,15 @@
# File: utils.py
from contextlib import contextmanager
import operator
from contextlib import contextmanager
import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import call_only_once
from ..tfutils.scope_utils import cached_name_scope, under_name_scope
from ..tfutils.varreplace import custom_getter_scope
from ..utils import logger
from ..utils.argtools import call_only_once
__all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice',
......
......@@ -2,27 +2,28 @@
# File: input_source.py
import tensorflow as tf
try:
from tensorflow.python.ops.data_flow_ops import StagingArea
except ImportError:
pass
import threading
from contextlib import contextmanager
from itertools import chain
import tensorflow as tf
from six.moves import range, zip
import threading
from .input_source_base import InputSource
from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
from ..dataflow import DataFlow, MapData, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.dependency import dependency_of_fetches
from ..tfutils.summary import add_moving_summary
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
from .input_source_base import InputSource
try:
from tensorflow.python.ops.data_flow_ops import StagingArea
except ImportError:
pass
__all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
......
# -*- coding: utf-8 -*-
# File: input_source_base.py
from abc import ABCMeta, abstractmethod
import copy
import six
from six.moves import zip
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
import six
import tensorflow as tf
from six.moves import zip
from ..utils.argtools import memoized_method, call_only_once
from ..callbacks.base import CallbackFactory
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.argtools import call_only_once, memoized_method
__all__ = ['InputSource', 'remap_input_source']
......
......@@ -4,13 +4,15 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args
"""
Old Custom BN Implementation, Kept Here For Future Reference
"""
......
......@@ -3,8 +3,8 @@
import logging
import tensorflow as tf
import unittest
import tensorflow as tf
class TestModel(unittest.TestCase):
......
......@@ -2,17 +2,17 @@
# File: batch_norm.py
import tensorflow as tf
from tensorflow.python.training import moving_averages
import re
import six
import tensorflow as tf
from tensorflow.python.training import moving_averages
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import get_data_format
from ..tfutils.tower import get_current_tower_context
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['BatchNorm', 'BatchRenorm']
......
......@@ -2,7 +2,7 @@
# File: common.py
from .registry import layer_register # noqa
from .utils import VariableHolder # noqa
from .tflayer import rename_tflayer_get_variable
from .utils import VariableHolder # noqa
__all__ = ['layer_register', 'VariableHolder', 'rename_tflayer_get_variable']
......@@ -3,10 +3,11 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import shape2d, shape4d, get_data_format
from .tflayer import rename_get_variable, convert_to_tflayer_args
from ..utils.argtools import get_data_format, shape2d, shape4d
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['Conv2D', 'Deconv2D', 'Conv2DTranspose']
......@@ -50,7 +51,7 @@ def Conv2D(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
if split == 1:
......@@ -158,7 +159,7 @@ def Conv2DTranspose(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
......@@ -2,11 +2,11 @@
# File: fc.py
import tensorflow as tf
import numpy as np
import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple
from .common import layer_register, VariableHolder
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable
__all__ = ['FullyConnected']
......@@ -48,7 +48,7 @@ def FullyConnected(
"""
if kernel_initializer is None:
if get_tf_version_tuple() <= (1, 12):
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0),
kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)
else:
kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
......
......@@ -3,8 +3,9 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from ..utils.argtools import get_data_format
from .common import VariableHolder, layer_register
__all__ = ['LayerNorm', 'InstanceNorm']
......
......@@ -2,8 +2,9 @@
# File: linearwrap.py
import six
from types import ModuleType
import six
from .registry import get_registered_layer
__all__ = ['LinearWrap']
......
......@@ -4,8 +4,8 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from .batch_norm import BatchNorm
from .common import VariableHolder, layer_register
__all__ = ['Maxout', 'PReLU', 'BNReLU']
......
# -*- coding: utf-8 -*-
# File: pool.py
import tensorflow as tf
import numpy as np
import tensorflow as tf
from .shape_utils import StaticDynamicShape
from .common import layer_register
from ..utils.argtools import shape2d, get_data_format
from ..utils.argtools import get_data_format, shape2d
from ..utils.develop import log_deprecated
from ._test import TestModel
from .common import layer_register
from .shape_utils import StaticDynamicShape
from .tflayer import convert_to_tflayer_args
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
......
......@@ -2,11 +2,11 @@
# File: registry.py
import tensorflow as tf
import copy
import re
from functools import wraps
import six
import re
import copy
import tensorflow as tf
from ..tfutils.argscope import get_arg_scope
from ..tfutils.model_utils import get_shape_str
......
......@@ -2,13 +2,13 @@
# File: regularize.py
import tensorflow as tf
import re
import tensorflow as tf
from ..utils import logger
from ..utils.argtools import graph_memoized
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.argtools import graph_memoized
from .common import layer_register
__all__ = ['regularize_cost', 'regularize_cost_from_collection',
......
......@@ -3,6 +3,7 @@
import tensorflow as tf
from .common import layer_register
__all__ = ['ConcatWith']
......
# -*- coding: utf-8 -*-
# File: tflayer.py
import tensorflow as tf
import six
import functools
import six
import tensorflow as tf
from ..utils.argtools import get_data_format
from ..tfutils.common import get_tf_version_tuple
from ..tfutils.varreplace import custom_getter_scope
from ..utils.argtools import get_data_format
__all__ = []
......
......@@ -2,13 +2,13 @@
# File: base.py
from abc import abstractmethod, ABCMeta
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import six
import tensorflow as tf
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
......
......@@ -2,16 +2,16 @@
# File: concurrency.py
import numpy as np
import multiprocessing
import numpy as np
import six
from six.moves import queue, range
import tensorflow as tf
from six.moves import queue, range
from ..utils import logger
from ..utils.concurrency import DIE, StoppableThread, ShareSessionThread
from ..tfutils.model_utils import describe_trainable_vars
from .base import OnlinePredictor, OfflinePredictor, AsyncPredictorBase
from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
......
......@@ -2,13 +2,13 @@
# File: config.py
import tensorflow as tf
import six
import tensorflow as tf
from ..graph_builder import ModelDescBase
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import JustCurrentSession, SessionInit
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..utils import logger
__all__ = ['PredictConfig']
......
......@@ -2,22 +2,21 @@
# File: dataset.py
from six.moves import range, zip
from abc import ABCMeta, abstractmethod
import multiprocessing
import os
from abc import ABCMeta, abstractmethod
import six
from six.moves import range, zip
from ..dataflow import DataFlow
from ..dataflow.remote import dump_dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate
from ..utils.gpu import change_gpu, get_num_gpu
from ..utils.utils import get_tqdm
from .base import OfflinePredictor
from .concurrency import MultiProcessQueuePredictWorker
from .config import PredictConfig
from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor']
......
#!/usr/bin/env python
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from tensorflow.python.training.monitored_session import _HookedSession as HookedSession
from .base import PredictorBase
from ..tfutils.tower import PredictTowerContext
from ..callbacks import Callbacks
from ..tfutils.tower import PredictTowerContext
from .base import PredictorBase
__all__ = ['FeedfreePredictor']
......
......@@ -3,10 +3,11 @@
import tensorflow as tf
from ..utils import logger
from ..graph_builder.model_desc import InputDesc
from ..input_source import PlaceholderInput
from ..tfutils.tower import PredictTowerContext
from ..utils import logger
from .base import OnlinePredictor
__all__ = ['MultiTowerOfflinePredictor',
......
# -*- coding: utf-8 -*-
# File: argscope.py
from contextlib import contextmanager
from collections import defaultdict
import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from inspect import isfunction, getmembers
from inspect import getmembers, isfunction
from .tower import get_current_tower_context
from ..utils import logger
from .tower import get_current_tower_context
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module']
......
......@@ -2,10 +2,10 @@
# File: collection.py
import tensorflow as tf
from contextlib import contextmanager
from copy import copy
import six
from contextlib import contextmanager
import tensorflow as tf
from ..utils import logger
from ..utils.argtools import memoized
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from six.moves import map
from ..utils.argtools import graph_memoized
from ..utils.develop import deprecated
......
import tensorflow as tf
from tensorflow.contrib.graph_editor import get_backward_walk_ops
from ..utils.argtools import graph_memoized
"""
......
......@@ -12,10 +12,10 @@ from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
from ..utils import logger
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext
from ..input_source import PlaceholderInput
from ..utils import logger
__all__ = ['ModelExporter']
......
......@@ -2,14 +2,15 @@
# File: gradproc.py
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import inspect
import re
from abc import ABCMeta, abstractmethod
import six
import inspect
import tensorflow as tf
from ..utils import logger
from .symbolic_functions import rms, print_stat
from .summary import add_moving_summary
from .symbolic_functions import print_stat, rms
__all__ = ['GradientProcessor',
'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient',
......
......@@ -3,8 +3,8 @@
# Author: tensorpack contributors
import tensorflow as tf
from termcolor import colored
from tabulate import tabulate
from termcolor import colored
from ..utils import logger
......
......@@ -2,11 +2,11 @@
# File: optimizer.py
import tensorflow as tf
from contextlib import contextmanager
import tensorflow as tf
from ..utils.develop import HIDE_DOC
from ..tfutils.common import get_tf_version_tuple
from ..utils.develop import HIDE_DOC
from .gradproc import FilterNoneGrad, GradientProcessor
__all__ = ['apply_grad_processors', 'ProxyOptimizer',
......
......@@ -2,9 +2,9 @@
# File: scope_utils.py
import tensorflow as tf
import functools
from contextlib import contextmanager
import tensorflow as tf
from ..utils.argtools import graph_memoized
from .common import get_tf_version_tuple
......
......@@ -3,8 +3,9 @@
import tensorflow as tf
from .common import get_default_sess_config
from ..utils import logger
from .common import get_default_sess_config
__all__ = ['NewSessionCreator', 'ReuseSessionCreator', 'SessionCreatorAdapter']
......
......@@ -3,13 +3,12 @@
import numpy as np
import tensorflow as tf
import six
import tensorflow as tf
from ..utils import logger
from .common import get_op_tensor_name
from .varmanip import (SessionUpdate, get_savename_from_varname,
is_training_name, get_checkpoint_path)
from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name
__all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
......
......@@ -2,20 +2,19 @@
# File: summary.py
import re
from contextlib import contextmanager
import six
import tensorflow as tf
import re
from six.moves import range
from contextlib import contextmanager
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.argtools import graph_memoized
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
from .scope_utils import cached_name_scope
from .symbolic_functions import rms
from .tower import get_current_tower_context
__all__ = ['add_tensor_summary', 'add_param_summary',
'add_activation_summary', 'add_moving_summary',
......
......@@ -2,15 +2,15 @@
# File: tower.py
import tensorflow as tf
from abc import ABCMeta, abstractmethod, abstractproperty
import six
import tensorflow as tf
from six.moves import zip
from abc import abstractproperty, abstractmethod, ABCMeta
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from ..utils.develop import HIDE_DOC
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name
......
# -*- coding: utf-8 -*-
# File: varmanip.py
import six
import numpy as np
import os
import pprint
import six
import tensorflow as tf
import numpy as np
from ..utils import logger
from .common import get_op_tensor_name
......
......@@ -2,8 +2,8 @@
# File: varreplace.py
# Credit: Qinyao He
import tensorflow as tf
from contextlib import contextmanager
import tensorflow as tf
from .common import get_tf_version_tuple
......
# -*- coding: utf-8 -*-
# File: base.py
import tensorflow as tf
import weakref
import copy
import time
from six.moves import range
import weakref
import six
import copy
import tensorflow as tf
from six.moves import range
from ..callbacks import (
Callback, Callbacks, Monitors, TrainingMonitor)
from ..utils import logger
from ..utils.utils import humanize_time_delta
from ..utils.argtools import call_only_once
from ..callbacks import Callback, Callbacks, Monitors, TrainingMonitor
from ..callbacks.steps import MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
from ..callbacks.steps import MaintainStepCounter
from .config import TrainConfig, DEFAULT_MONITORS, DEFAULT_CALLBACKS
from ..tfutils.sesscreate import NewSessionCreator, ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession, SessionInit
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.utils import humanize_time_delta
from .config import DEFAULT_CALLBACKS, DEFAULT_MONITORS, TrainConfig
__all__ = ['StopTraining', 'Trainer']
......
......@@ -5,15 +5,13 @@ import os
import tensorflow as tf
from ..callbacks import (
MovingAverageSummary,
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
JSONWriter, MergeAllSummaries, MovingAverageSummary, ProgressBar, RunUpdateOps, ScalarPrinter, TFEventWriter)
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils.sessinit import SessionInit, SaverRestore
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.sessinit import SaverRestore, SessionInit
from ..utils import logger
__all__ = ['TrainConfig', 'AutoResumeTrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']
......
......@@ -3,11 +3,8 @@
import tensorflow as tf
from ..input_source import (
InputSource, FeedInput, FeedfreeInput,
QueueInput, StagingInput, DummyConstantInput)
from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput
from ..utils import logger
from .config import TrainConfig
from .tower import SingleCostTrainer
from .trainers import SimpleTrainer
......
# -*- coding: utf-8 -*-
# File: tower.py
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import six
from abc import abstractmethod, ABCMeta
import tensorflow as tf
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC
from ..utils import logger
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context, PredictTowerContext
from ..tfutils.gradproc import FilterNoneGrad
from ..tfutils.tower import PredictTowerContext, TowerFuncWrapper, get_current_tower_context
from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import HIDE_DOC
from .base import Trainer
__all__ = ['SingleCostTrainer', 'TowerTrainer']
......
# -*- coding: utf-8 -*-
# File: trainers.py
import sys
import multiprocessing as mp
import os
import sys
import tensorflow as tf
import multiprocessing as mp
from ..callbacks import RunOp, CallbackFactory
from ..callbacks import CallbackFactory, RunOp
from ..graph_builder.distributed import DistributedParameterServerBuilder, DistributedReplicatedBuilder
from ..graph_builder.training import (
AsyncMultiGPUBuilder, SyncMultiGPUParameterServerBuilder, SyncMultiGPUReplicatedBuilder)
from ..graph_builder.utils import override_to_local_variable
from ..input_source import FeedfreeInput, QueueInput
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.tower import TrainTowerContext
from ..utils import logger
from ..utils.argtools import map_arg
from ..utils.develop import HIDE_DOC, log_deprecated
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..tfutils.tower import TrainTowerContext
from ..input_source import QueueInput, FeedfreeInput
from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder)
from ..graph_builder.distributed import DistributedReplicatedBuilder, DistributedParameterServerBuilder
from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer
__all__ = ['NoOpTrainer', 'SimpleTrainer',
......
......@@ -2,6 +2,4 @@
# File: utility.py
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)
from ..graph_builder.utils import LeastLoadedDeviceSetter, OverrideToLocalVariable, override_to_local_variable # noqa
......@@ -4,7 +4,9 @@
import inspect
import six
from . import logger
if six.PY2:
import functools32 as functools
else:
......
import os
from .serialize import loads_msgpack, loads_pyarrow, dumps_msgpack, dumps_pyarrow
from .serialize import dumps_msgpack, dumps_pyarrow, loads_msgpack, loads_pyarrow
"""
Serialization that has compatibility guarantee (therefore is safe to store to disk).
......
......@@ -3,14 +3,14 @@
# Some code taken from zxytim
import threading
import platform
import multiprocessing
import atexit
import bisect
from contextlib import contextmanager
import multiprocessing
import platform
import signal
import threading
import weakref
from contextlib import contextmanager
import six
from six.moves import queue
......
......@@ -6,11 +6,11 @@
""" Utilities for developers only.
These are not visible to users (not automatically imported). And should not
appeared in docs."""
import os
import functools
from datetime import datetime
import importlib
import os
import types
from datetime import datetime
import six
from . import logger
......
......@@ -2,10 +2,11 @@
# File: fs.py
import os
from six.moves import urllib
import errno
import os
import tqdm
from six.moves import urllib
from . import logger
from .utils import execute_only_once
......
......@@ -3,10 +3,11 @@
import os
from .utils import change_env
from . import logger
from .nvml import NVMLContext
from .concurrency import subproc_call
from .nvml import NVMLContext
from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
......
......@@ -2,14 +2,14 @@
# File: loadcaffe.py
import sys
import numpy as np
import os
import sys
from .utils import change_env
from .fs import download, get_dataset_path
from .concurrency import subproc_call
from . import logger
from .concurrency import subproc_call
from .fs import download, get_dataset_path
from .utils import change_env
__all__ = ['load_caffe', 'get_caffe_pb']
......
......@@ -16,12 +16,12 @@ The logger module itself has the common logging functions of Python's
import logging
import os
import shutil
import os.path
from termcolor import colored
import shutil
import sys
from datetime import datetime
from six.moves import input
import sys
from termcolor import colored
__all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir']
......
# -*- coding: utf-8 -*-
# File: nvml.py
from ctypes import (byref, c_uint, c_ulonglong,
CDLL, POINTER, Structure)
import threading
from ctypes import CDLL, POINTER, Structure, byref, c_uint, c_ulonglong
__all__ = ['NVMLContext']
......
......@@ -3,6 +3,7 @@
import numpy as np
from .develop import log_deprecated
__all__ = ['IntBox', 'FloatBox']
......
# -*- coding: utf-8 -*-
# File: serialize.py
import sys
import os
from .develop import create_dummy_func
import sys
from . import logger
from .develop import create_dummy_func
__all__ = ['loads', 'dumps']
......
......@@ -2,14 +2,14 @@
# File: timer.py
from contextlib import contextmanager
from collections import defaultdict
import six
import atexit
from collections import defaultdict
from contextlib import contextmanager
from time import time as timer
import six
from .stats import StatCounter
from . import logger
from .stats import StatCounter
if six.PY3:
from time import perf_counter as timer # noqa
......
......@@ -2,17 +2,16 @@
# File: utils.py
import inspect
import numpy as np
import os
import sys
from contextlib import contextmanager
import inspect
from datetime import datetime, timedelta
from tqdm import tqdm
import numpy as np
from . import logger
__all__ = ['change_env',
'get_rng',
'fix_rng_seed',
......
......@@ -5,8 +5,10 @@
import numpy as np
import os
import sys
from .fs import mkdir_p
from ..utils.develop import create_dummy_func # noqa
from .argtools import shape2d
from .fs import mkdir_p
from .palette import PALETTE_RGB
try:
......@@ -411,7 +413,6 @@ def draw_boxes(im, boxes, labels=None, color=None):
return im
from ..utils.develop import create_dummy_func # noqa
try:
import matplotlib.pyplot as plt
except (ImportError, RuntimeError):
......
......@@ -8,16 +8,11 @@ export TF_CPP_MIN_LOG_LEVEL=2
# test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform'
# python -m unittest discover -v
python -m unittest discover -v
# python -m tensorpack.models._test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py
python test_char_rnn.py
python test_infogan.py
python test_mnist.py
python test_mnist_similarity.py
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py
from case_script import TestPythonScript
import os
from case_script import TestPythonScript
def random_content():
return ('Lorem ipsum dolor sit amet\n'
......
......@@ -10,6 +10,7 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py'
def test(self):
return True # https://github.com/tensorflow/tensorflow/issues/24517
if get_tf_version_tuple() < (1, 4):
return True # requires leaky_relu
self.assertSurvive(self.script, args=None)
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from tensorpack.dataflow.base import DataFlow
from tensorpack.dataflow import LMDBSerializer, TFRecordSerializer, NumpySerializer, HDF5Serializer
import unittest
import os
import numpy as np
import os
import unittest
from tensorpack.dataflow import HDF5Serializer, LMDBSerializer, NumpySerializer, TFRecordSerializer
from tensorpack.dataflow.base import DataFlow
def delete_file_if_exists(fn):
......
......@@ -12,3 +12,13 @@ exclude = .git,
snippet,
examples-old,
_test.py,
[isort]
line_length=100
skip=docs/conf.py
multi_line_output=4
known_tensorpack=tensorpack
known_standard_library=numpy
known_third_party=bob,gym,matplotlib
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment