Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
70a64e5f
Commit
70a64e5f
authored
Jul 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
auto_set_dir
parent
236c78e0
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
85 additions
and
108 deletions
+85
-108
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-3
examples/DisturbLabel/README.md
examples/DisturbLabel/README.md
+2
-1
examples/DisturbLabel/disturb.py
examples/DisturbLabel/disturb.py
+20
-0
examples/DisturbLabel/mnist-disturb.py
examples/DisturbLabel/mnist-disturb.py
+15
-54
examples/DoReFa-Net/svhn-digit-dorefa.py
examples/DoReFa-Net/svhn-digit-dorefa.py
+1
-3
examples/Inception/inception-bn.py
examples/Inception/inception-bn.py
+2
-7
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+1
-3
examples/ResNet/svhn-resnet.py
examples/ResNet/svhn-resnet.py
+1
-3
examples/char-rnn/char-rnn.py
examples/char-rnn/char-rnn.py
+1
-3
examples/cifar-convnet.py
examples/cifar-convnet.py
+2
-4
examples/mnist-convnet.py
examples/mnist-convnet.py
+7
-7
examples/svhn-digit-convnet.py
examples/svhn-digit-convnet.py
+9
-8
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+23
-12
No files found.
examples/Atari2600/DQN.py
View file @
70a64e5f
...
@@ -137,9 +137,7 @@ class Model(ModelDesc):
...
@@ -137,9 +137,7 @@ class Model(ModelDesc):
return
self
.
predict_value
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})[
0
]
return
self
.
predict_value
.
eval
(
feed_dict
=
{
'state:0'
:
[
state
]})[
0
]
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
M
=
Model
()
M
=
Model
()
dataset_train
=
ExpReplay
(
dataset_train
=
ExpReplay
(
...
...
examples/DisturbLabel/README.md
View file @
70a64e5f
...
@@ -14,4 +14,5 @@ Experiements are repeated 15 times for p=0, 10 times for p=0.02 & 0.05, and 5 ti
...
@@ -14,4 +14,5 @@ Experiements are repeated 15 times for p=0, 10 times for p=0.02 & 0.05, and 5 ti
of p. All experiements run for 100 epochs, with lr decay, which are enough for them to converge.
of p. All experiements run for 100 epochs, with lr decay, which are enough for them to converge.
I suppose the disturb method works as a random noise to prevent SGD from getting stuck.
I suppose the disturb method works as a random noise to prevent SGD from getting stuck.
Despite the positive results here, I still doubt whether the method works for ImageNet.
It doesn't work for harder problems such as SVHN (details to follow). And I don't believe
it will work for ImageNet.
examples/DisturbLabel/disturb.py
0 → 100644
View file @
70a64e5f
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: disturb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
tensorpack
import
ProxyDataFlow
,
get_rng
class
DisturbLabel
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
prob
):
super
(
DisturbLabel
,
self
)
.
__init__
(
ds
)
self
.
prob
=
prob
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
img
,
l
=
dp
if
self
.
rng
.
rand
()
<
self
.
prob
:
l
=
self
.
rng
.
choice
(
10
)
yield
[
img
,
l
]
examples/DisturbLabel/mnist-disturb.py
View file @
70a64e5f
...
@@ -9,15 +9,22 @@ import os, sys
...
@@ -9,15 +9,22 @@ import os, sys
import
argparse
import
argparse
from
tensorpack
import
*
from
tensorpack
import
*
from
disturb
import
DisturbLabel
BATCH_SIZE
=
128
import
imp
IMAGE_SIZE
=
28
mnist_example
=
imp
.
load_source
(
'mnist_example'
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'mnist-convnet.py'
))
get_config
=
mnist_example
.
get_config
def
get_data
():
dataset_train
=
BatchData
(
DisturbLabel
(
dataset
.
Mnist
(
'train'
),
args
.
prob
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
return
dataset_train
,
dataset_test
mnist_example
.
get_data
=
get_data
class
Model
(
ModelDesc
):
IMAGE_SIZE
=
28
def
_get_input_vars
(
self
):
return
[
InputVar
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
InputVar
(
tf
.
int32
,
(
None
,),
'label'
)
]
class
Model
(
mnist_example
.
Model
):
def
_build_graph
(
self
,
input_vars
,
is_training
):
def
_build_graph
(
self
,
input_vars
,
is_training
):
is_training
=
bool
(
is_training
)
is_training
=
bool
(
is_training
)
keep_prob
=
tf
.
constant
(
0.5
if
is_training
else
1.0
)
keep_prob
=
tf
.
constant
(
0.5
if
is_training
else
1.0
)
...
@@ -54,62 +61,16 @@ class Model(ModelDesc):
...
@@ -54,62 +61,16 @@ class Model(ModelDesc):
self
.
cost
=
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
self
.
cost
=
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
class
DisturbLabel
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
prob
):
super
(
DisturbLabel
,
self
)
.
__init__
(
ds
)
self
.
prob
=
prob
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
img
,
l
=
dp
if
self
.
rng
.
rand
()
<
self
.
prob
:
l
=
self
.
rng
.
choice
(
10
)
yield
[
img
,
l
]
def
get_config
(
disturb_prob
):
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
# prepare dataset
dataset_train
=
BatchData
(
DisturbLabel
(
dataset
.
Mnist
(
'train'
),
disturb_prob
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-3
,
global_step
=
get_global_step_var
(),
decay_steps
=
dataset_train
.
size
()
*
10
,
decay_rate
=
0.3
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
return
TrainConfig
(
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()
])
]),
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
step_per_epoch
=
step_per_epoch
,
max_epoch
=
100
,
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--prob'
,
help
=
'disturb prob'
,
type
=
float
)
parser
.
add_argument
(
'--prob'
,
help
=
'disturb prob'
,
type
=
float
,
required
=
True
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
(
args
.
prob
)
config
=
get_config
()
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
()
QueueInputTrainer
(
config
)
.
train
()
...
...
examples/DoReFa-Net/svhn-digit-dorefa.py
View file @
70a64e5f
...
@@ -164,9 +164,7 @@ class Model(ModelDesc):
...
@@ -164,9 +164,7 @@ class Model(ModelDesc):
self
.
cost
=
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
self
.
cost
=
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
# prepare dataset
# prepare dataset
d1
=
dataset
.
SVHNDigit
(
'train'
)
d1
=
dataset
.
SVHNDigit
(
'train'
)
...
...
examples/Inception/inception-bn.py
View file @
70a64e5f
...
@@ -149,6 +149,7 @@ def get_data(train_or_test):
...
@@ -149,6 +149,7 @@ def get_data(train_or_test):
def
get_config
():
def
get_config
():
logger
.
auto_set_dir
()
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
)
dataset_train
=
get_data
(
'train'
)
step_per_epoch
=
5000
step_per_epoch
=
5000
...
@@ -184,15 +185,9 @@ if __name__ == '__main__':
...
@@ -184,15 +185,9 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--data'
,
help
=
'ImageNet data root directory'
,
parser
.
add_argument
(
'--data'
,
help
=
'ImageNet data root directory'
,
required
=
True
)
required
=
True
)
global
args
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
...
...
examples/ResNet/cifar10-resnet.py
View file @
70a64e5f
...
@@ -141,9 +141,7 @@ def get_data(train_or_test):
...
@@ -141,9 +141,7 @@ def get_data(train_or_test):
return
ds
return
ds
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
)
dataset_train
=
get_data
(
'train'
)
...
...
examples/ResNet/svhn-resnet.py
View file @
70a64e5f
...
@@ -144,9 +144,7 @@ def get_data(train_or_test):
...
@@ -144,9 +144,7 @@ def get_data(train_or_test):
return
ds
return
ds
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
)
dataset_train
=
get_data
(
'train'
)
...
...
examples/char-rnn/char-rnn.py
View file @
70a64e5f
...
@@ -106,9 +106,7 @@ class Model(ModelDesc):
...
@@ -106,9 +106,7 @@ class Model(ModelDesc):
[
grad
],
param
.
grad_clip
)[
0
][
0
])]
[
grad
],
param
.
grad_clip
)[
0
][
0
])]
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
ds
=
CharRNNData
(
param
.
corpus
,
100000
)
ds
=
CharRNNData
(
param
.
corpus
,
100000
)
ds
=
BatchData
(
ds
,
param
.
batch_size
)
ds
=
BatchData
(
ds
,
param
.
batch_size
)
...
...
examples/cifar-convnet.py
View file @
70a64e5f
...
@@ -107,6 +107,8 @@ def get_data(train_or_test, cifar_classnum):
...
@@ -107,6 +107,8 @@ def get_data(train_or_test, cifar_classnum):
return
ds
return
ds
def
get_config
(
cifar_classnum
):
def
get_config
(
cifar_classnum
):
logger
.
auto_set_dir
()
# prepare dataset
# prepare dataset
dataset_train
=
get_data
(
'train'
,
cifar_classnum
)
dataset_train
=
get_data
(
'train'
,
cifar_classnum
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
dataset_train
.
size
()
...
@@ -144,10 +146,6 @@ if __name__ == '__main__':
...
@@ -144,10 +146,6 @@ if __name__ == '__main__':
type
=
int
,
default
=
10
)
type
=
int
,
default
=
10
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
else
:
else
:
...
...
examples/mnist-convnet.py
View file @
70a64e5f
...
@@ -15,7 +15,6 @@ MNIST ConvNet example.
...
@@ -15,7 +15,6 @@ MNIST ConvNet example.
about 0.6
%
validation error after 30 epochs.
about 0.6
%
validation error after 30 epochs.
"""
"""
BATCH_SIZE
=
128
IMAGE_SIZE
=
28
IMAGE_SIZE
=
28
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
...
@@ -65,14 +64,15 @@ class Model(ModelDesc):
...
@@ -65,14 +64,15 @@ class Model(ModelDesc):
summary
.
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor histogram of all W
summary
.
add_param_summary
([(
'.*/W'
,
[
'histogram'
])])
# monitor histogram of all W
self
.
cost
=
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
self
.
cost
=
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'cost'
)
def
get_data
():
train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
return
train
,
test
def
get_config
():
def
get_config
():
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
auto_set_dir
()
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
# prepare dataset
dataset_train
,
dataset_test
=
get_data
()
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
dataset_train
.
size
()
lr
=
tf
.
train
.
exponential_decay
(
lr
=
tf
.
train
.
exponential_decay
(
...
...
examples/svhn-digit-convnet.py
View file @
70a64e5f
...
@@ -56,11 +56,10 @@ class Model(ModelDesc):
...
@@ -56,11 +56,10 @@ class Model(ModelDesc):
wd_cost
=
regularize_cost
(
'fc.*/W'
,
l2_regularizer
(
0.00001
))
wd_cost
=
regularize_cost
(
'fc.*/W'
,
l2_regularizer
(
0.00001
))
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
wd_cost
)
add_param_summary
([(
'.*/W'
,
[
'histogram'
,
'
sparsity
'
])])
# monitor W
add_param_summary
([(
'.*/W'
,
[
'histogram'
,
'
rms
'
])])
# monitor W
self
.
cost
=
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
self
.
cost
=
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
def
get_config
():
def
get_data
():
# prepare dataset
d1
=
dataset
.
SVHNDigit
(
'train'
)
d1
=
dataset
.
SVHNDigit
(
'train'
)
d2
=
dataset
.
SVHNDigit
(
'extra'
)
d2
=
dataset
.
SVHNDigit
(
'extra'
)
data_train
=
RandomMixData
([
d1
,
d2
])
data_train
=
RandomMixData
([
d1
,
d2
])
...
@@ -77,11 +76,17 @@ def get_config():
...
@@ -77,11 +76,17 @@ def get_config():
data_train
=
AugmentImageComponent
(
data_train
,
augmentors
)
data_train
=
AugmentImageComponent
(
data_train
,
augmentors
)
data_train
=
BatchData
(
data_train
,
128
)
data_train
=
BatchData
(
data_train
,
128
)
data_train
=
PrefetchData
(
data_train
,
5
,
5
)
data_train
=
PrefetchData
(
data_train
,
5
,
5
)
step_per_epoch
=
data_train
.
size
()
augmentors
=
[
imgaug
.
Resize
((
40
,
40
))
]
augmentors
=
[
imgaug
.
Resize
((
40
,
40
))
]
data_test
=
AugmentImageComponent
(
data_test
,
augmentors
)
data_test
=
AugmentImageComponent
(
data_test
,
augmentors
)
data_test
=
BatchData
(
data_test
,
128
,
remainder
=
True
)
data_test
=
BatchData
(
data_test
,
128
,
remainder
=
True
)
return
data_train
,
data_test
def
get_config
():
logger
.
auto_set_dir
()
data_train
,
data_test
=
get_data
()
step_per_epoch
=
data_train
.
size
()
lr
=
tf
.
train
.
exponential_decay
(
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-3
,
learning_rate
=
1e-3
,
...
@@ -110,10 +115,6 @@ if __name__ == '__main__':
...
@@ -110,10 +115,6 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
basename
=
os
.
path
.
basename
(
__file__
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]))
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
else
:
else
:
...
...
tensorpack/utils/logger.py
View file @
70a64e5f
...
@@ -14,7 +14,7 @@ from .fs import mkdir_p
...
@@ -14,7 +14,7 @@ from .fs import mkdir_p
__all__
=
[]
__all__
=
[]
class
MyFormatter
(
logging
.
Formatter
):
class
_
MyFormatter
(
logging
.
Formatter
):
def
format
(
self
,
record
):
def
format
(
self
,
record
):
date
=
colored
(
'[
%(asctime)
s @
%(filename)
s:
%(lineno)
d]'
,
'green'
)
date
=
colored
(
'[
%(asctime)
s @
%(filename)
s:
%(lineno)
d]'
,
'green'
)
msg
=
'
%(message)
s'
msg
=
'
%(message)
s'
...
@@ -28,17 +28,17 @@ class MyFormatter(logging.Formatter):
...
@@ -28,17 +28,17 @@ class MyFormatter(logging.Formatter):
# Python3 compatibilty
# Python3 compatibilty
self
.
_style
.
_fmt
=
fmt
self
.
_style
.
_fmt
=
fmt
self
.
_fmt
=
fmt
self
.
_fmt
=
fmt
return
super
(
MyFormatter
,
self
)
.
format
(
record
)
return
super
(
_
MyFormatter
,
self
)
.
format
(
record
)
def
getlogger
():
def
_
getlogger
():
logger
=
logging
.
getLogger
(
'tensorpack'
)
logger
=
logging
.
getLogger
(
'tensorpack'
)
logger
.
propagate
=
False
logger
.
propagate
=
False
logger
.
setLevel
(
logging
.
INFO
)
logger
.
setLevel
(
logging
.
INFO
)
handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
handler
.
setFormatter
(
MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
handler
.
setFormatter
(
_
MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
logger
.
addHandler
(
handler
)
logger
.
addHandler
(
handler
)
return
logger
return
logger
logger
=
getlogger
()
_logger
=
_
getlogger
()
def
get_time_str
():
def
get_time_str
():
...
@@ -52,8 +52,8 @@ def _set_file(path):
...
@@ -52,8 +52,8 @@ def _set_file(path):
info
(
"Log file '{}' backuped to '{}'"
.
format
(
path
,
backup_name
))
info
(
"Log file '{}' backuped to '{}'"
.
format
(
path
,
backup_name
))
hdl
=
logging
.
FileHandler
(
hdl
=
logging
.
FileHandler
(
filename
=
path
,
encoding
=
'utf-8'
,
mode
=
'w'
)
filename
=
path
,
encoding
=
'utf-8'
,
mode
=
'w'
)
hdl
.
setFormatter
(
MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
hdl
.
setFormatter
(
_
MyFormatter
(
datefmt
=
'
%
m
%
d
%
H:
%
M:
%
S'
))
logger
.
addHandler
(
hdl
)
_
logger
.
addHandler
(
hdl
)
def
set_logger_dir
(
dirname
,
action
=
None
):
def
set_logger_dir
(
dirname
,
action
=
None
):
"""
"""
...
@@ -63,10 +63,10 @@ def set_logger_dir(dirname, action=None):
...
@@ -63,10 +63,10 @@ def set_logger_dir(dirname, action=None):
"""
"""
global
LOG_FILE
,
LOG_DIR
global
LOG_FILE
,
LOG_DIR
if
os
.
path
.
isdir
(
dirname
):
if
os
.
path
.
isdir
(
dirname
):
logger
.
warn
(
"""
\
_
logger
.
warn
(
"""
\
Directory {} exists! Please either backup/delete it, or use a new directory
\
Directory {} exists! Please either backup/delete it, or use a new directory
\
unless you're resuming from a previous task."""
.
format
(
dirname
))
unless you're resuming from a previous task."""
.
format
(
dirname
))
logger
.
info
(
"Select Action: k (keep) / b (backup) / d (delete) / n (new):"
)
_
logger
.
info
(
"Select Action: k (keep) / b (backup) / d (delete) / n (new):"
)
if
not
action
:
if
not
action
:
while
True
:
while
True
:
act
=
input
()
.
lower
()
.
strip
()
act
=
input
()
.
lower
()
.
strip
()
...
@@ -92,10 +92,21 @@ unless you're resuming from a previous task.""".format(dirname))
...
@@ -92,10 +92,21 @@ unless you're resuming from a previous task.""".format(dirname))
LOG_FILE
=
os
.
path
.
join
(
dirname
,
'log.log'
)
LOG_FILE
=
os
.
path
.
join
(
dirname
,
'log.log'
)
_set_file
(
LOG_FILE
)
_set_file
(
LOG_FILE
)
# export logger functions
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
locals
()[
func
]
=
getattr
(
_logger
,
func
)
def
disable_logger
():
def
disable_logger
():
""" disable all logging ability from this moment"""
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
globals
()[
func
]
=
lambda
x
:
None
globals
()[
func
]
=
lambda
x
:
None
# export logger functions
def
auto_set_dir
(
action
=
None
):
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
""" set log directory to a subdir inside 'train_log', with the name being
locals
()[
func
]
=
getattr
(
logger
,
func
)
the main python file currently running"""
mod
=
sys
.
modules
[
'__main__'
]
basename
=
os
.
path
.
basename
(
mod
.
__file__
)
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]),
action
=
action
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment