Commit 8d1ad775 authored by Yuxin Wu's avatar Yuxin Wu

py3 compatibilty & remove shebang

parent 4916f703
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: example_cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -108,6 +108,7 @@ def get_config():
dataset_train = BatchData(dataset_train, 128)
dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size() / 2
step_per_epoch = 10
augmentors = [
imgaug.CenterCrop((30, 30)),
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: example_mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: group.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: validation_callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -7,6 +6,7 @@ import tensorflow as tf
import itertools
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip
from ..utils import *
from ..utils.stat import *
......@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp))
feed = dict(zip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input
outputs = sess.run(output_vars, feed_dict=feed)
yield (dp, outputs)
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import copy
from six.moves import range
from .base import DataFlow, ProxyDataFlow
from ..utils import *
......@@ -47,9 +47,9 @@ class BatchData(ProxyDataFlow):
def aggregate_batch(data_holder):
size = len(data_holder[0])
result = []
for k in xrange(size):
for k in range(size):
dt = data_holder[0][k]
if type(dt) in [int, bool, long]:
if type(dt) in [int, bool]:
tp = 'int32'
elif type(dt) == float:
tp = 'float32'
......@@ -104,7 +104,7 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data():
yield dp
else:
for _ in xrange(self.nr):
for _ in range(self.nr):
for dp in self.ds.get_data():
yield dp
......@@ -125,7 +125,7 @@ class FakeData(DataFlow):
self.rng = get_rng(self)
def get_data(self):
for _ in xrange(self._size):
for _ in range(self._size):
yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow):
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
......@@ -5,7 +5,7 @@
import os, sys
import pickle
import numpy as np
from six.moves import urllib
from six.moves import urllib, range
import copy
import tarfile
import logging
......@@ -43,11 +43,11 @@ def read_cifar10(filenames):
ret = []
for fname in filenames:
fo = open(fname, 'rb')
dic = pickle.load(fo)
data = dic['data']
label = dic['labels']
dic = pickle.load(fo, encoding='bytes')
data = dic[b'data']
label = dic[b'labels']
fo.close()
for k in xrange(10000):
for k in range(10000):
img = data[k].reshape(3, 32, 32)
img = np.transpose(img, [1, 2, 0])
ret.append([img, label[k]])
......@@ -55,7 +55,7 @@ def read_cifar10(filenames):
def get_filenames(dir):
filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)]
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
return filenames
......@@ -115,7 +115,7 @@ if __name__ == '__main__':
ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images
mean = ds.get_per_channel_mean()
print mean
print(mean)
dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data():
......
......@@ -7,7 +7,7 @@ import os
import gzip
import numpy
from six.moves import urllib
from six.moves import urllib, range
from ...utils import logger
from ..base import DataFlow
......@@ -136,7 +136,7 @@ class Mnist(DataFlow):
def get_data(self):
ds = self.train if self.train_or_test == 'train' else self.test
for k in xrange(ds.num_examples):
for k in range(ds.num_examples):
img = ds.images[k].reshape((28, 28))
label = ds.labels[k]
yield [img, label]
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: image.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -25,7 +24,7 @@ class ImageAugmentor(object):
def _init(self, params=None):
self.reset_state()
if params:
for k, v in params.iteritems():
for k, v in params.items():
if k != 'self':
setattr(self, k, v)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: deform.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: imgproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: noname.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: _common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from functools import wraps
import six
from ..utils.modelutils import *
from ..utils.summary import *
......@@ -30,7 +30,7 @@ def layer_register(summary_activation=False):
@wraps(func)
def __call__(self, *args, **kwargs):
name = args[0]
assert isinstance(name, basestring), \
assert isinstance(name, six.string_types), \
'name must be either the first argument. Args: {}'.format(str(args))
args = args[1:]
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: regularize.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: predict.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -9,6 +8,7 @@ import argparse
from collections import namedtuple
import numpy as np
from tqdm import tqdm
from six.moves import zip
from .utils import *
from .utils.modelutils import describe_model
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from abc import ABCMeta
from six.moves import range
import tqdm
import re
......@@ -76,7 +76,7 @@ class Trainer(object):
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))
for epoch in xrange(1, self.config.max_epoch):
for epoch in range(1, self.config.max_epoch):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)):
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -6,6 +5,7 @@
import tensorflow as tf
import copy
import re
from six.moves import zip
from .base import Trainer
from ..dataflow.common import RepeatedData
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -6,6 +5,7 @@
import threading
from contextlib import contextmanager
import tensorflow as tf
from six.moves import zip
from .naming import *
from . import logger
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: logger.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -8,9 +7,8 @@ import os, shutil
import os.path
from termcolor import colored
from datetime import datetime
from six.moves import input
import sys
if not sys.version_info >= (3, 0):
input = raw_input # for compatibility
from .utils import mkdir_p
......@@ -63,7 +61,10 @@ def set_logger_dir(dirname):
Directory {} exists! Please either backup/delete it, or use a new directory \
unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
while True:
act = input().lower()
if act:
break
timestr = datetime.now().strftime('%m%d-%H%M%S')
if act == 'b':
backup_name = dirname + timestr
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: modelutils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -7,6 +6,7 @@ import os
from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf
import six
from . import logger
class SessionInit(object):
......@@ -49,7 +49,7 @@ class ParamRestore(SessionInit):
sess.run(tf.initialize_all_variables())
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_dict = dict([v.name, v] for v in variables)
for name, value in self.prms.iteritems():
for name, value in six.iteritems(self.prms):
try:
var = var_dict[name]
except (ValueError, KeyError):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: stat.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import six
import tensorflow as tf
from . import logger, get_global_step_var
......@@ -14,7 +14,7 @@ def create_summary(name, v):
Args: v: a value
"""
assert isinstance(name, basestring), type(name)
assert isinstance(name, six.string_types), type(name)
v = float(v)
s = tf.Summary()
s.value.add(tag=name, simple_value=v)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: symbolic_functions.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment