Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
09e99778
Commit
09e99778
authored
Dec 30, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
sessinit paraemter
parent
8c674cc4
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
54 additions
and
17 deletions
+54
-17
example_cifar10.py
example_cifar10.py
+3
-7
example_mnist.py
example_mnist.py
+5
-3
tensorpack/dataflow/__init__.py
tensorpack/dataflow/__init__.py
+1
-0
tensorpack/models/_common.py
tensorpack/models/_common.py
+3
-0
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+7
-3
tensorpack/train.py
tensorpack/train.py
+0
-1
tensorpack/utils/__init__.py
tensorpack/utils/__init__.py
+33
-1
tensorpack/utils/sessinit.py
tensorpack/utils/sessinit.py
+1
-1
tensorpack/utils/summary.py
tensorpack/utils/summary.py
+1
-1
No files found.
example_cifar10.py
View file @
09e99778
...
...
@@ -13,9 +13,7 @@ from tensorpack.utils import *
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.utils.sessinit
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.dataflow.dataset
import
Cifar10
from
tensorpack.dataflow
import
*
BATCH_SIZE
=
128
...
...
@@ -81,10 +79,10 @@ def get_config():
logger
.
set_logger_dir
(
log_dir
)
import
cv2
dataset_train
=
Cifar10
(
'train'
)
dataset_train
=
dataset
.
Cifar10
(
'train'
)
dataset_train
=
MapData
(
dataset_train
,
lambda
img
:
cv2
.
resize
(
img
,
(
24
,
24
)))
dataset_train
=
BatchData
(
dataset_train
,
128
)
dataset_test
=
Cifar10
(
'test'
)
dataset_test
=
dataset
.
Cifar10
(
'test'
)
dataset_test
=
MapData
(
dataset_test
,
lambda
img
:
cv2
.
resize
(
img
,
(
24
,
24
)))
dataset_test
=
BatchData
(
dataset_test
,
128
)
step_per_epoch
=
dataset_train
.
size
()
...
...
@@ -133,7 +131,6 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
global
args
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
...
...
@@ -141,8 +138,7 @@ if __name__ == '__main__':
with
tf
.
Graph
()
.
as_default
():
train
.
prepare
()
config
=
get_config
()
if
args
.
load
:
config
[
'session_init'
]
=
SaverRestore
(
args
.
load
)
sess_init
=
NewSession
()
train
.
start_train
(
config
)
example_mnist.py
View file @
09e99778
...
...
@@ -16,7 +16,6 @@ from tensorpack.utils.symbolic_functions import *
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.dataflow.dataset
import
Mnist
from
tensorpack.dataflow
import
*
BATCH_SIZE
=
128
...
...
@@ -95,8 +94,8 @@ def get_config():
IMAGE_SIZE
=
28
dataset_train
=
BatchData
(
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
Mnist
(
'test'
),
256
,
remainder
=
True
)
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
...
...
@@ -144,6 +143,7 @@ if __name__ == '__main__':
from
tensorpack
import
train
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
...
...
@@ -151,4 +151,6 @@ if __name__ == '__main__':
with
tf
.
Graph
()
.
as_default
():
train
.
prepare
()
config
=
get_config
()
if
args
.
load
:
config
[
'session_init'
]
=
SaverRestore
(
args
.
load
)
train
.
start_train
(
config
)
tensorpack/dataflow/__init__.py
View file @
09e99778
...
...
@@ -6,6 +6,7 @@
from
pkgutil
import
walk_packages
import
os
import
os.path
import
dataset
__SKIP
=
[
'dftools'
,
'dataset'
]
def
global_import
(
name
):
...
...
tensorpack/models/_common.py
View file @
09e99778
...
...
@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
from
functools
import
wraps
from
..utils.modelutils
import
*
from
..utils.summary
import
*
from
..utils
import
logger
...
...
@@ -21,6 +23,7 @@ def layer_register(summary_activation=False):
Can be overriden when creating the layer.
"""
def
wrapper
(
func
):
@
wraps
(
func
)
def
inner
(
*
args
,
**
kwargs
):
name
=
args
[
0
]
assert
isinstance
(
name
,
basestring
)
...
...
tensorpack/models/regularize.py
View file @
09e99778
...
...
@@ -7,9 +7,14 @@ import tensorflow as tf
import
re
from
..utils
import
logger
from
..utils
import
*
__all__
=
[
'regularize_cost'
]
@
memoized
def
_log_regularizer
(
name
):
logger
.
info
(
"Apply regularizer for {}"
.
format
(
name
))
def
regularize_cost
(
regex
,
func
):
"""
Apply a regularizer on every trainable variable matching the regex
...
...
@@ -20,8 +25,7 @@ def regularize_cost(regex, func):
costs
=
[]
for
p
in
params
:
name
=
p
.
name
if
re
.
search
(
regex
,
name
):
logger
.
info
(
"Apply regularizer for {}"
.
format
(
name
))
costs
.
append
(
func
(
p
))
costs
.
append
(
func
(
p
))
_log_regularizer
(
name
)
return
tf
.
add_n
(
costs
)
tensorpack/train.py
View file @
09e99778
...
...
@@ -11,7 +11,6 @@ from utils import *
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.summary
import
summary_moving_average
from
utils.modelutils
import
describe_model
from
utils.sessinit
import
NewSession
from
utils
import
logger
from
dataflow
import
DataFlow
...
...
tensorpack/utils/__init__.py
View file @
09e99778
...
...
@@ -8,8 +8,11 @@ import os
import
time
import
sys
from
contextlib
import
contextmanager
import
logger
import
tensorflow
as
tf
import
collections
import
logger
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
())
...
...
@@ -17,6 +20,7 @@ def global_import(name):
for
k
in
lst
:
globals
()[
k
]
=
p
.
__dict__
[
k
]
global_import
(
'naming'
)
global_import
(
'sessinit'
)
@
contextmanager
def
timed_operation
(
msg
,
log_start
=
False
):
...
...
@@ -66,3 +70,31 @@ def get_default_sess_config():
conf
.
allow_soft_placement
=
True
return
conf
class
memoized
(
object
):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def
__init__
(
self
,
func
):
self
.
func
=
func
self
.
cache
=
{}
def
__call__
(
self
,
*
args
):
if
not
isinstance
(
args
,
collections
.
Hashable
):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return
self
.
func
(
*
args
)
if
args
in
self
.
cache
:
return
self
.
cache
[
args
]
else
:
value
=
self
.
func
(
*
args
)
self
.
cache
[
args
]
=
value
return
value
def
__repr__
(
self
):
'''Return the function's docstring.'''
return
self
.
func
.
__doc__
def
__get__
(
self
,
obj
,
objtype
):
'''Support instance methods.'''
return
functools
.
partial
(
self
.
__call__
,
obj
)
tensorpack/utils/sessinit.py
View file @
09e99778
...
...
@@ -7,7 +7,7 @@ from abc import abstractmethod
import
numpy
as
np
import
tensorflow
as
tf
from
.
import
logger
import
logger
class
SessionInit
(
object
):
@
abstractmethod
def
init
(
self
,
sess
):
...
...
tensorpack/utils/summary.py
View file @
09e99778
...
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
logger
from
.
naming
import
*
from
naming
import
*
def
create_summary
(
name
,
v
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment