Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8d1ad775
Commit
8d1ad775
authored
Feb 27, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
py3 compatibilty & remove shebang
parent
4916f703
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
41 changed files
with
36 additions
and
61 deletions
+36
-61
example_cifar10.py
example_cifar10.py
+2
-1
example_mnist.py
example_mnist.py
+1
-1
tensorpack/callbacks/__init__.py
tensorpack/callbacks/__init__.py
+0
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+0
-1
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+0
-1
tensorpack/callbacks/dump.py
tensorpack/callbacks/dump.py
+0
-1
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+0
-1
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+0
-1
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+2
-2
tensorpack/dataflow/__init__.py
tensorpack/dataflow/__init__.py
+0
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+5
-5
tensorpack/dataflow/dataset/__init__.py
tensorpack/dataflow/dataset/__init__.py
+0
-1
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+7
-7
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+2
-2
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+0
-1
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+0
-1
tensorpack/dataflow/imgaug/__init__.py
tensorpack/dataflow/imgaug/__init__.py
+0
-1
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+1
-2
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+0
-1
tensorpack/dataflow/imgaug/deform.py
tensorpack/dataflow/imgaug/deform.py
+0
-1
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+0
-1
tensorpack/dataflow/imgaug/noname.py
tensorpack/dataflow/imgaug/noname.py
+0
-1
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+0
-1
tensorpack/models/__init__.py
tensorpack/models/__init__.py
+0
-1
tensorpack/models/_common.py
tensorpack/models/_common.py
+2
-2
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+0
-1
tensorpack/predict.py
tensorpack/predict.py
+1
-1
tensorpack/train/__init__.py
tensorpack/train/__init__.py
+0
-1
tensorpack/train/base.py
tensorpack/train/base.py
+2
-2
tensorpack/train/config.py
tensorpack/train/config.py
+0
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
tensorpack/utils/__init__.py
tensorpack/utils/__init__.py
+0
-1
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+1
-1
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+5
-4
tensorpack/utils/modelutils.py
tensorpack/utils/modelutils.py
+0
-1
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+0
-1
tensorpack/utils/sessinit.py
tensorpack/utils/sessinit.py
+2
-2
tensorpack/utils/stat.py
tensorpack/utils/stat.py
+0
-1
tensorpack/utils/summary.py
tensorpack/utils/summary.py
+2
-2
tensorpack/utils/symbolic_functions.py
tensorpack/utils/symbolic_functions.py
+0
-1
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+0
-1
No files found.
example_cifar10.py
View file @
8d1ad775
#!/usr/bin/env python
2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: example_cifar10.py
# File: example_cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -108,6 +108,7 @@ def get_config():
...
@@ -108,6 +108,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset_train
,
128
)
dataset_train
=
BatchData
(
dataset_train
,
128
)
dataset_train
=
PrefetchData
(
dataset_train
,
3
,
2
)
dataset_train
=
PrefetchData
(
dataset_train
,
3
,
2
)
step_per_epoch
=
dataset_train
.
size
()
/
2
step_per_epoch
=
dataset_train
.
size
()
/
2
step_per_epoch
=
10
augmentors
=
[
augmentors
=
[
imgaug
.
CenterCrop
((
30
,
30
)),
imgaug
.
CenterCrop
((
30
,
30
)),
...
...
example_mnist.py
View file @
8d1ad775
#!/usr/bin/env python
2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: example_mnist.py
# File: example_mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/base.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: base.py
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/common.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/dump.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: dump.py
# File: dump.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/group.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: group.py
# File: group.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/callbacks/summary.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: summary.py
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...
...
tensorpack/callbacks/validation_callback.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: validation_callback.py
# File: validation_callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -7,6 +6,7 @@ import tensorflow as tf
...
@@ -7,6 +6,7 @@ import tensorflow as tf
import
itertools
import
itertools
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
six.moves
import
zip
from
..utils
import
*
from
..utils
import
*
from
..utils.stat
import
*
from
..utils.stat
import
*
...
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
...
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
sess
=
tf
.
get_default_session
()
sess
=
tf
.
get_default_session
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
itertools
.
i
zip
(
self
.
input_vars
,
dp
))
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
outputs
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
outputs
=
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
yield
(
dp
,
outputs
)
yield
(
dp
,
outputs
)
...
...
tensorpack/dataflow/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/common.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
numpy
as
np
import
copy
import
copy
from
six.moves
import
range
from
.base
import
DataFlow
,
ProxyDataFlow
from
.base
import
DataFlow
,
ProxyDataFlow
from
..utils
import
*
from
..utils
import
*
...
@@ -47,9 +47,9 @@ class BatchData(ProxyDataFlow):
...
@@ -47,9 +47,9 @@ class BatchData(ProxyDataFlow):
def
aggregate_batch
(
data_holder
):
def
aggregate_batch
(
data_holder
):
size
=
len
(
data_holder
[
0
])
size
=
len
(
data_holder
[
0
])
result
=
[]
result
=
[]
for
k
in
x
range
(
size
):
for
k
in
range
(
size
):
dt
=
data_holder
[
0
][
k
]
dt
=
data_holder
[
0
][
k
]
if
type
(
dt
)
in
[
int
,
bool
,
long
]:
if
type
(
dt
)
in
[
int
,
bool
]:
tp
=
'int32'
tp
=
'int32'
elif
type
(
dt
)
==
float
:
elif
type
(
dt
)
==
float
:
tp
=
'float32'
tp
=
'float32'
...
@@ -104,7 +104,7 @@ class RepeatedData(ProxyDataFlow):
...
@@ -104,7 +104,7 @@ class RepeatedData(ProxyDataFlow):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
yield
dp
yield
dp
else
:
else
:
for
_
in
x
range
(
self
.
nr
):
for
_
in
range
(
self
.
nr
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
yield
dp
yield
dp
...
@@ -125,7 +125,7 @@ class FakeData(DataFlow):
...
@@ -125,7 +125,7 @@ class FakeData(DataFlow):
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
def
get_data
(
self
):
for
_
in
x
range
(
self
.
_size
):
for
_
in
range
(
self
.
_size
):
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
class
MapData
(
ProxyDataFlow
):
class
MapData
(
ProxyDataFlow
):
...
...
tensorpack/dataflow/dataset/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
8d1ad775
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
os
,
sys
import
os
,
sys
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
urllib
from
six.moves
import
urllib
,
range
import
copy
import
copy
import
tarfile
import
tarfile
import
logging
import
logging
...
@@ -43,11 +43,11 @@ def read_cifar10(filenames):
...
@@ -43,11 +43,11 @@ def read_cifar10(filenames):
ret
=
[]
ret
=
[]
for
fname
in
filenames
:
for
fname
in
filenames
:
fo
=
open
(
fname
,
'rb'
)
fo
=
open
(
fname
,
'rb'
)
dic
=
pickle
.
load
(
fo
)
dic
=
pickle
.
load
(
fo
,
encoding
=
'bytes'
)
data
=
dic
[
'data'
]
data
=
dic
[
b
'data'
]
label
=
dic
[
'labels'
]
label
=
dic
[
b
'labels'
]
fo
.
close
()
fo
.
close
()
for
k
in
x
range
(
10000
):
for
k
in
range
(
10000
):
img
=
data
[
k
]
.
reshape
(
3
,
32
,
32
)
img
=
data
[
k
]
.
reshape
(
3
,
32
,
32
)
img
=
np
.
transpose
(
img
,
[
1
,
2
,
0
])
img
=
np
.
transpose
(
img
,
[
1
,
2
,
0
])
ret
.
append
([
img
,
label
[
k
]])
ret
.
append
([
img
,
label
[
k
]])
...
@@ -55,7 +55,7 @@ def read_cifar10(filenames):
...
@@ -55,7 +55,7 @@ def read_cifar10(filenames):
def
get_filenames
(
dir
):
def
get_filenames
(
dir
):
filenames
=
[
os
.
path
.
join
(
filenames
=
[
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
x
range
(
1
,
6
)]
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
range
(
1
,
6
)]
filenames
.
append
(
os
.
path
.
join
(
filenames
.
append
(
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
return
filenames
return
filenames
...
@@ -115,7 +115,7 @@ if __name__ == '__main__':
...
@@ -115,7 +115,7 @@ if __name__ == '__main__':
ds
=
Cifar10
(
'train'
)
ds
=
Cifar10
(
'train'
)
from
tensorpack.dataflow.dftools
import
dump_dataset_images
from
tensorpack.dataflow.dftools
import
dump_dataset_images
mean
=
ds
.
get_per_channel_mean
()
mean
=
ds
.
get_per_channel_mean
()
print
mean
print
(
mean
)
dump_dataset_images
(
ds
,
'/tmp/cifar'
,
100
)
dump_dataset_images
(
ds
,
'/tmp/cifar'
,
100
)
#for (img, label) in ds.get_data():
#for (img, label) in ds.get_data():
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
8d1ad775
...
@@ -7,7 +7,7 @@ import os
...
@@ -7,7 +7,7 @@ import os
import
gzip
import
gzip
import
numpy
import
numpy
from
six.moves
import
urllib
from
six.moves
import
urllib
,
range
from
...utils
import
logger
from
...utils
import
logger
from
..base
import
DataFlow
from
..base
import
DataFlow
...
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
...
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
def
get_data
(
self
):
def
get_data
(
self
):
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
for
k
in
x
range
(
ds
.
num_examples
):
for
k
in
range
(
ds
.
num_examples
):
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
label
=
ds
.
labels
[
k
]
label
=
ds
.
labels
[
k
]
yield
[
img
,
label
]
yield
[
img
,
label
]
...
...
tensorpack/dataflow/dftools.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: dftools.py
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/image.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: image.py
# File: image.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/imgaug/__init__.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/imgaug/base.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: base.py
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -25,7 +24,7 @@ class ImageAugmentor(object):
...
@@ -25,7 +24,7 @@ class ImageAugmentor(object):
def
_init
(
self
,
params
=
None
):
def
_init
(
self
,
params
=
None
):
self
.
reset_state
()
self
.
reset_state
()
if
params
:
if
params
:
for
k
,
v
in
params
.
ite
rite
ms
():
for
k
,
v
in
params
.
items
():
if
k
!=
'self'
:
if
k
!=
'self'
:
setattr
(
self
,
k
,
v
)
setattr
(
self
,
k
,
v
)
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: crop.py
# File: crop.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/imgaug/deform.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: deform.py
# File: deform.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: imgproc.py
# File: imgproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: noname.py
# File: noname.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/dataflow/prefetch.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: prefetch.py
# File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/models/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/models/_common.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: _common.py
# File: _common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
functools
import
wraps
from
functools
import
wraps
import
six
from
..utils.modelutils
import
*
from
..utils.modelutils
import
*
from
..utils.summary
import
*
from
..utils.summary
import
*
...
@@ -30,7 +30,7 @@ def layer_register(summary_activation=False):
...
@@ -30,7 +30,7 @@ def layer_register(summary_activation=False):
@
wraps
(
func
)
@
wraps
(
func
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
name
=
args
[
0
]
name
=
args
[
0
]
assert
isinstance
(
name
,
basestring
),
\
assert
isinstance
(
name
,
six
.
string_types
),
\
'name must be either the first argument. Args: {}'
.
format
(
str
(
args
))
'name must be either the first argument. Args: {}'
.
format
(
str
(
args
))
args
=
args
[
1
:]
args
=
args
[
1
:]
...
...
tensorpack/models/regularize.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: regularize.py
# File: regularize.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/predict.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: predict.py
# File: predict.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -9,6 +8,7 @@ import argparse
...
@@ -9,6 +8,7 @@ import argparse
from
collections
import
namedtuple
from
collections
import
namedtuple
import
numpy
as
np
import
numpy
as
np
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
six.moves
import
zip
from
.utils
import
*
from
.utils
import
*
from
.utils.modelutils
import
describe_model
from
.utils.modelutils
import
describe_model
...
...
tensorpack/train/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/train/base.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: base.py
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
abc
import
ABCMeta
from
abc
import
ABCMeta
from
six.moves
import
range
import
tqdm
import
tqdm
import
re
import
re
...
@@ -76,7 +76,7 @@ class Trainer(object):
...
@@ -76,7 +76,7 @@ class Trainer(object):
self
.
global_step
=
get_global_step
()
self
.
global_step
=
get_global_step
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
for
epoch
in
x
range
(
1
,
self
.
config
.
max_epoch
):
for
epoch
in
range
(
1
,
self
.
config
.
max_epoch
):
with
timed_operation
(
with
timed_operation
(
'Epoch {}, global_step={}'
.
format
(
'Epoch {}, global_step={}'
.
format
(
epoch
,
self
.
global_step
+
self
.
config
.
step_per_epoch
)):
epoch
,
self
.
global_step
+
self
.
config
.
step_per_epoch
)):
...
...
tensorpack/train/config.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: config.py
# File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...
...
tensorpack/train/trainer.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: trainer.py
# File: trainer.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -6,6 +5,7 @@
...
@@ -6,6 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
copy
import
copy
import
re
import
re
from
six.moves
import
zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..dataflow.common
import
RepeatedData
from
..dataflow.common
import
RepeatedData
...
...
tensorpack/utils/__init__.py
View file @
8d1ad775
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: __init__.py
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/utils/concurrency.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: concurrency.py
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -6,6 +5,7 @@
...
@@ -6,6 +5,7 @@
import
threading
import
threading
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
zip
from
.naming
import
*
from
.naming
import
*
from
.
import
logger
from
.
import
logger
...
...
tensorpack/utils/logger.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: logger.py
# File: logger.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -8,9 +7,8 @@ import os, shutil
...
@@ -8,9 +7,8 @@ import os, shutil
import
os.path
import
os.path
from
termcolor
import
colored
from
termcolor
import
colored
from
datetime
import
datetime
from
datetime
import
datetime
from
six.moves
import
input
import
sys
import
sys
if
not
sys
.
version_info
>=
(
3
,
0
):
input
=
raw_input
# for compatibility
from
.utils
import
mkdir_p
from
.utils
import
mkdir_p
...
@@ -63,7 +61,10 @@ def set_logger_dir(dirname):
...
@@ -63,7 +61,10 @@ def set_logger_dir(dirname):
Directory {} exists! Please either backup/delete it, or use a new directory
\
Directory {} exists! Please either backup/delete it, or use a new directory
\
unless you're resuming from a previous task."""
.
format
(
dirname
))
unless you're resuming from a previous task."""
.
format
(
dirname
))
logger
.
info
(
"Select Action: k (keep) / b (backup) / d (delete) / n (new):"
)
logger
.
info
(
"Select Action: k (keep) / b (backup) / d (delete) / n (new):"
)
act
=
input
()
.
lower
()
while
True
:
act
=
input
()
.
lower
()
if
act
:
break
timestr
=
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
timestr
=
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
if
act
==
'b'
:
if
act
==
'b'
:
backup_name
=
dirname
+
timestr
backup_name
=
dirname
+
timestr
...
...
tensorpack/utils/modelutils.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: modelutils.py
# File: modelutils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/utils/naming.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: naming.py
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/utils/sessinit.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: sessinit.py
# File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -7,6 +6,7 @@ import os
...
@@ -7,6 +6,7 @@ import os
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
from
.
import
logger
from
.
import
logger
class
SessionInit
(
object
):
class
SessionInit
(
object
):
...
@@ -49,7 +49,7 @@ class ParamRestore(SessionInit):
...
@@ -49,7 +49,7 @@ class ParamRestore(SessionInit):
sess
.
run
(
tf
.
initialize_all_variables
())
sess
.
run
(
tf
.
initialize_all_variables
())
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
for
name
,
value
in
s
elf
.
prms
.
iteritems
(
):
for
name
,
value
in
s
ix
.
iteritems
(
self
.
prms
):
try
:
try
:
var
=
var_dict
[
name
]
var
=
var_dict
[
name
]
except
(
ValueError
,
KeyError
):
except
(
ValueError
,
KeyError
):
...
...
tensorpack/utils/stat.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: stat.py
# File: stat.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/utils/summary.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: summary.py
# File: summary.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.
import
logger
,
get_global_step_var
from
.
import
logger
,
get_global_step_var
...
@@ -14,7 +14,7 @@ def create_summary(name, v):
...
@@ -14,7 +14,7 @@ def create_summary(name, v):
Args: v: a value
Args: v: a value
"""
"""
assert
isinstance
(
name
,
basestring
),
type
(
name
)
assert
isinstance
(
name
,
six
.
string_types
),
type
(
name
)
v
=
float
(
v
)
v
=
float
(
v
)
s
=
tf
.
Summary
()
s
=
tf
.
Summary
()
s
.
value
.
add
(
tag
=
name
,
simple_value
=
v
)
s
.
value
.
add
(
tag
=
name
,
simple_value
=
v
)
...
...
tensorpack/utils/symbolic_functions.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: symbolic_functions.py
# File: symbolic_functions.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
tensorpack/utils/utils.py
View file @
8d1ad775
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: utils.py
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment