Commit 8d1ad775 authored by Yuxin Wu's avatar Yuxin Wu

py3 compatibilty & remove shebang

parent 4916f703
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: example_cifar10.py # File: example_cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -108,6 +108,7 @@ def get_config(): ...@@ -108,6 +108,7 @@ def get_config():
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
dataset_train = PrefetchData(dataset_train, 3, 2) dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size() / 2 step_per_epoch = dataset_train.size() / 2
step_per_epoch = 10
augmentors = [ augmentors = [
imgaug.CenterCrop((30, 30)), imgaug.CenterCrop((30, 30)),
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: example_mnist.py # File: example_mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: common.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dump.py # File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: group.py # File: group.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: summary.py # File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: validation_callback.py # File: validation_callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -7,6 +6,7 @@ import tensorflow as tf ...@@ -7,6 +6,7 @@ import tensorflow as tf
import itertools import itertools
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six.moves import zip
from ..utils import * from ..utils import *
from ..utils.stat import * from ..utils.stat import *
...@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
sess = tf.get_default_session() sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar: with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
outputs = sess.run(output_vars, feed_dict=feed) outputs = sess.run(output_vars, feed_dict=feed)
yield (dp, outputs) yield (dp, outputs)
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: common.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import copy import copy
from six.moves import range
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
...@@ -47,9 +47,9 @@ class BatchData(ProxyDataFlow): ...@@ -47,9 +47,9 @@ class BatchData(ProxyDataFlow):
def aggregate_batch(data_holder): def aggregate_batch(data_holder):
size = len(data_holder[0]) size = len(data_holder[0])
result = [] result = []
for k in xrange(size): for k in range(size):
dt = data_holder[0][k] dt = data_holder[0][k]
if type(dt) in [int, bool, long]: if type(dt) in [int, bool]:
tp = 'int32' tp = 'int32'
elif type(dt) == float: elif type(dt) == float:
tp = 'float32' tp = 'float32'
...@@ -104,7 +104,7 @@ class RepeatedData(ProxyDataFlow): ...@@ -104,7 +104,7 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
else: else:
for _ in xrange(self.nr): for _ in range(self.nr):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
...@@ -125,7 +125,7 @@ class FakeData(DataFlow): ...@@ -125,7 +125,7 @@ class FakeData(DataFlow):
self.rng = get_rng(self) self.rng = get_rng(self)
def get_data(self): def get_data(self):
for _ in xrange(self._size): for _ in range(self._size):
yield [self.rng.random_sample(k) for k in self.shapes] yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import os, sys import os, sys
import pickle import pickle
import numpy as np import numpy as np
from six.moves import urllib from six.moves import urllib, range
import copy import copy
import tarfile import tarfile
import logging import logging
...@@ -43,11 +43,11 @@ def read_cifar10(filenames): ...@@ -43,11 +43,11 @@ def read_cifar10(filenames):
ret = [] ret = []
for fname in filenames: for fname in filenames:
fo = open(fname, 'rb') fo = open(fname, 'rb')
dic = pickle.load(fo) dic = pickle.load(fo, encoding='bytes')
data = dic['data'] data = dic[b'data']
label = dic['labels'] label = dic[b'labels']
fo.close() fo.close()
for k in xrange(10000): for k in range(10000):
img = data[k].reshape(3, 32, 32) img = data[k].reshape(3, 32, 32)
img = np.transpose(img, [1, 2, 0]) img = np.transpose(img, [1, 2, 0])
ret.append([img, label[k]]) ret.append([img, label[k]])
...@@ -55,7 +55,7 @@ def read_cifar10(filenames): ...@@ -55,7 +55,7 @@ def read_cifar10(filenames):
def get_filenames(dir): def get_filenames(dir):
filenames = [os.path.join( filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)] dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join( filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')) dir, 'cifar-10-batches-py', 'test_batch'))
return filenames return filenames
...@@ -115,7 +115,7 @@ if __name__ == '__main__': ...@@ -115,7 +115,7 @@ if __name__ == '__main__':
ds = Cifar10('train') ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images from tensorpack.dataflow.dftools import dump_dataset_images
mean = ds.get_per_channel_mean() mean = ds.get_per_channel_mean()
print mean print(mean)
dump_dataset_images(ds, '/tmp/cifar', 100) dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data(): #for (img, label) in ds.get_data():
......
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import gzip import gzip
import numpy import numpy
from six.moves import urllib from six.moves import urllib, range
from ...utils import logger from ...utils import logger
from ..base import DataFlow from ..base import DataFlow
...@@ -136,7 +136,7 @@ class Mnist(DataFlow): ...@@ -136,7 +136,7 @@ class Mnist(DataFlow):
def get_data(self): def get_data(self):
ds = self.train if self.train_or_test == 'train' else self.test ds = self.train if self.train_or_test == 'train' else self.test
for k in xrange(ds.num_examples): for k in range(ds.num_examples):
img = ds.images[k].reshape((28, 28)) img = ds.images[k].reshape((28, 28))
label = ds.labels[k] label = ds.labels[k]
yield [img, label] yield [img, label]
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dftools.py # File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: image.py # File: image.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -25,7 +24,7 @@ class ImageAugmentor(object): ...@@ -25,7 +24,7 @@ class ImageAugmentor(object):
def _init(self, params=None): def _init(self, params=None):
self.reset_state() self.reset_state()
if params: if params:
for k, v in params.iteritems(): for k, v in params.items():
if k != 'self': if k != 'self':
setattr(self, k, v) setattr(self, k, v)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: crop.py # File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: deform.py # File: deform.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: imgproc.py # File: imgproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: noname.py # File: noname.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: prefetch.py # File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: _common.py # File: _common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from functools import wraps from functools import wraps
import six
from ..utils.modelutils import * from ..utils.modelutils import *
from ..utils.summary import * from ..utils.summary import *
...@@ -30,7 +30,7 @@ def layer_register(summary_activation=False): ...@@ -30,7 +30,7 @@ def layer_register(summary_activation=False):
@wraps(func) @wraps(func)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
name = args[0] name = args[0]
assert isinstance(name, basestring), \ assert isinstance(name, six.string_types), \
'name must be either the first argument. Args: {}'.format(str(args)) 'name must be either the first argument. Args: {}'.format(str(args))
args = args[1:] args = args[1:]
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: regularize.py # File: regularize.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: predict.py # File: predict.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -9,6 +8,7 @@ import argparse ...@@ -9,6 +8,7 @@ import argparse
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from six.moves import zip
from .utils import * from .utils import *
from .utils.modelutils import describe_model from .utils.modelutils import describe_model
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta from abc import ABCMeta
from six.moves import range
import tqdm import tqdm
import re import re
...@@ -76,7 +76,7 @@ class Trainer(object): ...@@ -76,7 +76,7 @@ class Trainer(object):
self.global_step = get_global_step() self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step)) logger.info("Start training with global_step={}".format(self.global_step))
for epoch in xrange(1, self.config.max_epoch): for epoch in range(1, self.config.max_epoch):
with timed_operation( with timed_operation(
'Epoch {}, global_step={}'.format( 'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)): epoch, self.global_step + self.config.step_per_epoch)):
......
#!/usr/bin/env python2
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: config.py # File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: trainer.py # File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -6,6 +5,7 @@ ...@@ -6,6 +5,7 @@
import tensorflow as tf import tensorflow as tf
import copy import copy
import re import re
from six.moves import zip
from .base import Trainer from .base import Trainer
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: concurrency.py # File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -6,6 +5,7 @@ ...@@ -6,6 +5,7 @@
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from six.moves import zip
from .naming import * from .naming import *
from . import logger from . import logger
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: logger.py # File: logger.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -8,9 +7,8 @@ import os, shutil ...@@ -8,9 +7,8 @@ import os, shutil
import os.path import os.path
from termcolor import colored from termcolor import colored
from datetime import datetime from datetime import datetime
from six.moves import input
import sys import sys
if not sys.version_info >= (3, 0):
input = raw_input # for compatibility
from .utils import mkdir_p from .utils import mkdir_p
...@@ -63,7 +61,10 @@ def set_logger_dir(dirname): ...@@ -63,7 +61,10 @@ def set_logger_dir(dirname):
Directory {} exists! Please either backup/delete it, or use a new directory \ Directory {} exists! Please either backup/delete it, or use a new directory \
unless you're resuming from a previous task.""".format(dirname)) unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):") logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
act = input().lower() while True:
act = input().lower()
if act:
break
timestr = datetime.now().strftime('%m%d-%H%M%S') timestr = datetime.now().strftime('%m%d-%H%M%S')
if act == 'b': if act == 'b':
backup_name = dirname + timestr backup_name = dirname + timestr
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: modelutils.py # File: modelutils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: naming.py # File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: sessinit.py # File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -7,6 +6,7 @@ import os ...@@ -7,6 +6,7 @@ import os
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import six
from . import logger from . import logger
class SessionInit(object): class SessionInit(object):
...@@ -49,7 +49,7 @@ class ParamRestore(SessionInit): ...@@ -49,7 +49,7 @@ class ParamRestore(SessionInit):
sess.run(tf.initialize_all_variables()) sess.run(tf.initialize_all_variables())
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_dict = dict([v.name, v] for v in variables) var_dict = dict([v.name, v] for v in variables)
for name, value in self.prms.iteritems(): for name, value in six.iteritems(self.prms):
try: try:
var = var_dict[name] var = var_dict[name]
except (ValueError, KeyError): except (ValueError, KeyError):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: stat.py # File: stat.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: summary.py # File: summary.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import six
import tensorflow as tf import tensorflow as tf
from . import logger, get_global_step_var from . import logger, get_global_step_var
...@@ -14,7 +14,7 @@ def create_summary(name, v): ...@@ -14,7 +14,7 @@ def create_summary(name, v):
Args: v: a value Args: v: a value
""" """
assert isinstance(name, basestring), type(name) assert isinstance(name, six.string_types), type(name)
v = float(v) v = float(v)
s = tf.Summary() s = tf.Summary()
s.value.add(tag=name, simple_value=v) s.value.add(tag=name, simple_value=v)
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: symbolic_functions.py # File: symbolic_functions.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: utils.py # File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment