Shashank Suhas / seminar-breakout · Commits

Commit 6306da7e, authored Nov 05, 2016 by Yuxin Wu
Parent: 8bfab811

    small refactor in train/base

Showing 13 changed files, with 39 additions and 55 deletions (+39 / -55).
Changed files:

  .gitignore                           +1   -0
  README.md                            +2   -0
  examples/OpenAIGym/train-atari.py    +1   -1
  examples/ResNet/cifar10-resnet.py    +1   -1
  scripts/plot-point.py                +7   -6
  tensorpack/callbacks/base.py         +1   -8
  tensorpack/callbacks/common.py       +2   -1
  tensorpack/dataflow/raw.py           +2   -0
  tensorpack/train/base.py             +16  -32
  tensorpack/train/config.py           +4   -1
  tensorpack/train/multigpu.py         +0   -2
  tensorpack/train/trainer.py          +0   -2
  tensorpack/utils/loadcaffe.py        +2   -1
--- a/.gitignore
+++ b/.gitignore
@@ -73,3 +73,4 @@ model-*
 checkpoint
 *.json
 *.prototxt
+snippet
--- a/README.md
+++ b/README.md
 # tensorpack
 Neural Network Toolbox on TensorFlow
+
+Still in development. Underlying design may change.

 See some [examples](examples) to learn about the framework.
 You can actually train them and reproduce the performance... not just to see how to write code.
--- a/examples/OpenAIGym/train-atari.py
+++ b/examples/OpenAIGym/train-atari.py
@@ -204,9 +204,9 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             HumanHyperParamSetter('explore_factor'),
             master,
+            StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ]),
-        extra_threads_procs=[master],
         session_config=get_default_sess_config(0.5),
         model=M,
         step_per_epoch=STEP_PER_EPOCH,
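The example switches from passing the simulator master through `extra_threads_procs` to starting it via the `StartProcOrThread` callback. A minimal standalone sketch of what such a callback does (toy class, not tensorpack's actual implementation):

    import threading

    class StartProcOrThreadSketch(object):
        """Toy stand-in for tensorpack's StartProcOrThread callback."""
        def __init__(self, proc_or_thread):
            self._target = proc_or_thread

        def before_train(self):
            # starting happens inside the normal callback lifecycle,
            # instead of through a separate extra_threads_procs list
            self._target.start()

    t = threading.Thread(target=lambda: print('simulator master running'))
    StartProcOrThreadSketch(t).before_train()
    t.join()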
--- a/examples/ResNet/cifar10-resnet.py
+++ b/examples/ResNet/cifar10-resnet.py
@@ -22,7 +22,7 @@ Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results on 2 TitanX for
 n=5, about 7.1% val error after 67k steps (8.6 step/s)
-n=18, about 5.8% val error after 80k steps (2.6 step/s)
+n=18, about 5.9% val error after 80k steps (2.6 step/s)
 n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
 This model uses the whole training set instead of a train-val split.
 """
--- a/scripts/plot-point.py
+++ b/scripts/plot-point.py
@@ -22,6 +22,7 @@ import matplotlib.font_manager as fontm
 import argparse, sys
 from collections import defaultdict
 from itertools import chain
+import six
 from matplotlib import rc
 #rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
@@ -52,11 +53,9 @@ def get_args():
                         help='title of the graph',
                         default='')
     parser.add_argument('--xlabel',
-                        help='x label',
-                        default='x')
+                        help='x label', type=six.text_type)
     parser.add_argument('--ylabel',
-                        help='y label',
-                        default='y')
+                        help='y label', type=six.text_type)
     parser.add_argument('-s', '--scale',
                         help='scale of each y, separated by comma')
     parser.add_argument('--annotate-maximum',
@@ -215,8 +214,10 @@ def do_plot(data_xs, data_ys):
     if args.annotate_maximum or args.annotate_minimum:
         annotate_min_max(truncate_data_x, data_y, ax)
-    plt.xlabel(args.xlabel.decode('utf-8'), fontsize='xx-large')
-    plt.ylabel(args.ylabel.decode('utf-8'), fontsize='xx-large')
+    if args.xlabel:
+        plt.xlabel(args.xlabel, fontsize='xx-large')
+    if args.ylabel:
+        plt.ylabel(args.ylabel, fontsize='xx-large')
     plt.legend(loc='best', fontsize='xx-large')
     # adjust maxx
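The `six.text_type` change is the usual Python 2/3 unicode trick: let argparse coerce the label to unicode once, instead of calling `.decode('utf-8')` at the plotting site. A standalone sketch of the pattern (hypothetical argument values):

    import argparse
    import six   # six.text_type is unicode on Python 2, str on Python 3

    parser = argparse.ArgumentParser()
    parser.add_argument('--xlabel', help='x label', type=six.text_type)
    parser.add_argument('--ylabel', help='y label', type=six.text_type)
    args = parser.parse_args(['--xlabel', 'epoch'])

    # with the 'x'/'y' defaults gone, labels are only set when provided
    if args.xlabel:
        print(u'would call plt.xlabel({0!r})'.format(args.xlabel))
    if args.ylabel:
        print(u'would call plt.ylabel({0!r})'.format(args.ylabel))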
--- a/tensorpack/callbacks/base.py
+++ b/tensorpack/callbacks/base.py
@@ -56,13 +56,6 @@ class Callback(object):
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """

-    @property
-    def global_step(self):
-        """
-        Access the global step value of this training.
-        """
-        return self.trainer.global_step
-
     def trigger_epoch(self):
         """
         Triggered after every epoch.
@@ -95,7 +88,7 @@ class ProxyCallback(Callback):
         self.cb.trigger_epoch()

     def __str__(self):
-        return str(self.cb)
+        return "Proxy-" + str(self.cb)

 class PeriodicCallback(ProxyCallback):
     """
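The `__str__` change is cosmetic but useful in logs: a wrapped callback now announces its wrapper. A standalone sketch with toy classes (not tensorpack's full implementations):

    class Callback(object):
        def __str__(self):
            return type(self).__name__

    class ProxyCallback(Callback):
        def __init__(self, cb):
            self.cb = cb

        def __str__(self):
            # after this commit, the proxy is visible in the name
            return "Proxy-" + str(self.cb)

    class ModelSaver(Callback):
        pass

    print(ProxyCallback(ModelSaver()))   # prints "Proxy-ModelSaver"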
--- a/tensorpack/callbacks/common.py
+++ b/tensorpack/callbacks/common.py
@@ -9,6 +9,7 @@ import re
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
+from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
@@ -72,7 +73,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=self.global_step,
+                global_step=get_global_step(),
                 write_meta_graph=False)
             # create a symbolic link for the latest model
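With the `Callback.global_step` property gone (see `callbacks/base.py` above), callbacks read the step from the graph directly, as `ModelSaver` now does. A hedged sketch of the same idiom in a user callback, assuming the tensorpack API at this commit and an active session:

    from tensorpack.callbacks.base import Callback
    from tensorpack.tfutils import get_global_step

    class PrintStep(Callback):
        """Hypothetical callback: log the step at each epoch end."""
        def trigger_epoch(self):
            # reads the global_step variable from the graph, rather than
            # a Python-side counter on the trainer
            print('global step = {}'.format(get_global_step()))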
--- a/tensorpack/dataflow/raw.py
+++ b/tensorpack/dataflow/raw.py
@@ -22,6 +22,8 @@ class FakeData(RNGDataFlow):
         """
         :param shapes: a list of lists/tuples
         :param size: size of this DataFlow
+        :param random: whether to randomly generate data every iteration. note
+            that only generating the data could be time-consuming!
         """
         super(FakeData, self).__init__()
         self.shapes = shapes
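A hedged usage sketch for the newly documented flag, assuming the `FakeData` signature at this commit (`shapes`, `size`, `random`). With `random=False` the arrays are generated once and replayed, avoiding the per-iteration generation cost the note warns about:

    from tensorpack.dataflow.raw import FakeData

    # two components per datapoint: a 28x28 "image" and a scalar "label"
    df = FakeData(shapes=[[28, 28], [1]], size=100, random=False)
    df.reset_state()
    for dp in df.get_data():
        print([x.shape for x in dp])   # one array per shape
        break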
--- a/tensorpack/train/base.py
+++ b/tensorpack/train/base.py
@@ -13,7 +13,6 @@ import tensorflow as tf
 from .config import TrainConfig
 from ..utils import logger, get_tqdm_kwargs
 from ..utils.timer import timed_operation
-from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
 from ..tfutils.summary import create_summary
@@ -32,7 +31,6 @@ class Trainer(object):
         summary_writer: a `tf.SummaryWriter`
         config: a `TrainConfig`
         model: a `ModelDesc`
-        global_step: a `int`
     """
     __metaclass__ = ABCMeta
@@ -44,7 +42,7 @@ class Trainer(object):
         self.config = config
         self.model = config.model
         self.model.get_input_vars()   # ensure they are present
-        self._extra_threads_procs = config.extra_threads_procs
+        self.init_session_and_coord()

     @abstractmethod
     def train(self):
@@ -84,15 +82,6 @@ class Trainer(object):
         """ This is called right after all steps in an epoch are finished"""
         pass

-    def _init_summary(self):
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
-        self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
-        self.summary_op = tf.merge_all_summaries()
-        # create an empty StatHolder
-        self.stat_holder = StatHolder(logger.LOG_DIR)
-
     def _process_summary(self, summary_str):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
@@ -107,31 +96,39 @@ class Trainer(object):
                     get_global_step())
             self.stat_holder.add_stat(name, val)

-    def main_loop(self):
+    def finalize_graph(self):
         # some final operations that might modify the graph
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         logger.info("Setup callbacks ...")
         callbacks = self.config.callbacks
         callbacks.setup_graph(weakref.proxy(self))
-        self._init_summary()
+        if not hasattr(logger, 'LOG_DIR'):
+            raise RuntimeError("logger directory wasn't set!")
+        self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
+        self.summary_op = tf.merge_all_summaries()
+        # create an empty StatHolder
+        self.stat_holder = StatHolder(logger.LOG_DIR)

         logger.info("Initializing graph variables ...")
         self.sess.run(tf.initialize_all_variables())
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
-        self._start_concurrency()
+        tf.train.start_queue_runners(
+            sess=self.sess, coord=self.coord, daemon=True, start=True)
+
+    def main_loop(self):
+        self.finalize_graph()
+        callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                self.global_step = get_global_step()
-                logger.info("Start training with global_step={}".format(self.global_step))
+                logger.info("Start training with global_step={}".format(get_global_step()))
                 callbacks.before_train()
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch+1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, self.global_step + self.config.step_per_epoch)):
+                            self.epoch_num, get_global_step() + self.config.step_per_epoch)):
                         for step in tqdm.trange(
                                 self.config.step_per_epoch,
                                 **get_tqdm_kwargs(leave=True)):
@@ -139,7 +136,6 @@ class Trainer(object):
                             return
                         self.run_step()   # implemented by subclass
                         #callbacks.trigger_step()   # not useful?
-                        self.global_step += 1
                     self.trigger_epoch()
             except StopTraining:
                 logger.info("Training was stopped.")
@@ -155,18 +151,6 @@ class Trainer(object):
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()

-    def _start_concurrency(self):
-        """
-        Run all threads before starting training
-        """
-        logger.info("Starting all threads & procs ...")
-        tf.train.start_queue_runners(
-            sess=self.sess, coord=self.coord, daemon=True, start=True)
-        with self.sess.as_default():
-            # avoid sigint get handled by other processes
-            start_proc_mask_signal(self._extra_threads_procs)
-
     def process_grads(self, grads):
         g = []
         for grad, var in grads:
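Structurally, this is a template-method cleanup: session/coordinator creation moves into the constructor (so subclasses' `train()` methods no longer call `init_session_and_coord()` themselves, see `trainer.py` and `multigpu.py` below), `_init_summary` and `_start_concurrency` are folded into a new `finalize_graph()`, and the Python-side `self.global_step` counter is dropped in favor of reading the graph variable. A standalone sketch of the resulting control flow (toy methods standing in for the real TF calls):

    class TrainerSketch(object):
        def __init__(self):
            # previously each subclass's train() had to call this itself
            self.init_session_and_coord()

        def init_session_and_coord(self):
            print('create session and coordinator')

        def finalize_graph(self):
            # summary writer, variable init, graph.finalize(), queue
            # runners -- all setup that may still touch the graph
            print('finalize graph and start queue runners')

        def main_loop(self):
            self.finalize_graph()
            print('run epochs; global_step is read from the graph, '
                  'not tracked as a Python attribute')

    TrainerSketch().main_loop()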
--- a/tensorpack/train/config.py
+++ b/tensorpack/train/config.py
@@ -32,7 +32,6 @@ class TrainConfig(object):
         :param max_epoch: maximum number of epoch to run training. default to inf
         :param nr_tower: int. number of training towers. default to 1.
         :param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
-        :param extra_threads_procs: list of `Startable` threads or processes
     """
     def assert_type(v, tp):
         assert isinstance(v, tp), v.__class__
@@ -72,6 +71,10 @@ class TrainConfig(object):
             self.tower = [0]

         self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
+        if self.extra_threads_procs:
+            logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
+            from ..callbacks.concurrency import StartProcOrThread
+            self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
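The `kwargs.pop` dance above is a standard deprecation shim: accept the old argument, warn, rewrite it into the new mechanism, then reject any remaining unknown kwargs. A standalone sketch of the pattern (generic names, not tensorpack's classes):

    import warnings

    class ConfigSketch(object):
        def __init__(self, callbacks=None, **kwargs):
            self.callbacks = list(callbacks or [])
            extra = kwargs.pop('extra_threads_procs', [])
            if extra:
                warnings.warn('extra_threads_procs is deprecated; '
                              'use a StartProcOrThread callback instead')
                # translate the old argument into the new mechanism
                self.callbacks.append(('StartProcOrThread', extra))
            assert len(kwargs) == 0, \
                'Unknown arguments: {}'.format(list(kwargs.keys()))

    cfg = ConfigSketch(extra_threads_procs=['worker'])
    print(cfg.callbacks)   # the old kwarg became a callback entry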
--- a/tensorpack/train/multigpu.py
+++ b/tensorpack/train/multigpu.py
@@ -73,7 +73,6 @@ class MultiGPUTrainer(QueueInputTrainer):
 class SyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()
@@ -92,7 +91,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):
 class AsyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()
--- a/tensorpack/train/trainer.py
+++ b/tensorpack/train/trainer.py
@@ -76,7 +76,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)

-        self.init_session_and_coord()
         describe_model()
         # create an infinte data producer
         self.config.dataset.reset_state()
@@ -196,7 +195,6 @@ class QueueInputTrainer(Trainer):
     def train(self):
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        self.init_session_and_coord()
         self._build_enque_thread()

         grads = self._single_tower_grad()
--- a/tensorpack/utils/loadcaffe.py
+++ b/tensorpack/utils/loadcaffe.py
@@ -113,9 +113,10 @@ def get_caffe_pb():
     caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
     if not os.path.isfile(caffe_pb_file):
         proto_path = download(CAFFE_PROTO_URL, dir)
+        assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
         ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
         assert ret == 0, \
-            "caffe proto compilation failed! Did you install protoc?"
+            "Command `protoc caffe.proto --python_out .` failed!"
     import imp
     return imp.load_source('caffepb', caffe_pb_file)
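These two changes tighten failure reporting: assert that the download actually produced `caffe.proto` before shelling out, and quote the exact failed command instead of guessing at the cause. The same pattern, as a standalone sketch (hypothetical helper, stdlib only):

    import os

    def compile_proto(proto_dir):
        # fail early, at the step that actually went wrong
        assert os.path.isfile(os.path.join(proto_dir, 'caffe.proto'))
        cmd = 'cd {} && protoc caffe.proto --python_out .'.format(proto_dir)
        ret = os.system(cmd)
        # report the exact command, so the user can rerun it by hand
        assert ret == 0, 'Command `{}` failed!'.format(cmd)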