Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e94abf66
Commit
e94abf66
authored
Jan 04, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix TF deprecation of concat/split/pack/unpack
parent
e04bb2d5
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
58 additions
and
56 deletions
+58
-56
examples/ConvolutionalPoseMachines/load-cpm.py
examples/ConvolutionalPoseMachines/load-cpm.py
+2
-1
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+2
-2
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+2
-2
examples/HED/hed.py
examples/HED/hed.py
+1
-1
examples/Inception/inception-bn.py
examples/Inception/inception-bn.py
+1
-1
examples/Inception/inceptionv3.py
examples/Inception/inceptionv3.py
+22
-22
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+5
-5
examples/char-rnn/char-rnn.py
examples/char-rnn/char-rnn.py
+1
-1
tensorpack/dataflow/dataset/ptb.py
tensorpack/dataflow/dataset/ptb.py
+1
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+5
-4
tensorpack/models/image_sample.py
tensorpack/models/image_sample.py
+6
-6
tensorpack/models/pool.py
tensorpack/models/pool.py
+6
-6
tensorpack/models/shapes.py
tensorpack/models/shapes.py
+3
-3
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+1
-1
No files found.
examples/ConvolutionalPoseMachines/load-cpm.py
View file @
e94abf66
...
...
@@ -76,7 +76,8 @@ class Model(ModelDesc):
.
Conv2D
(
'conv4_7_CPM'
,
128
)())
def
add_stage
(
stage
,
l
):
l
=
tf
.
concat
(
3
,
[
l
,
shared
,
pool_center
],
name
=
'concat_stage{}'
.
format
(
stage
))
l
=
tf
.
concat_v2
([
l
,
shared
,
pool_center
],
3
,
name
=
'concat_stage{}'
.
format
(
stage
))
for
i
in
range
(
1
,
6
):
l
=
Conv2D
(
'Mconv{}_stage{}'
.
format
(
i
,
stage
),
l
,
128
)
l
=
Conv2D
(
'Mconv6_stage{}'
.
format
(
stage
),
l
,
128
,
kernel_shape
=
1
)
...
...
examples/GAN/Image2Image.py
View file @
e94abf66
...
...
@@ -87,7 +87,7 @@ class Model(ModelDesc):
def
discriminator
(
self
,
inputs
,
outputs
):
""" return a (b, 1) logits"""
l
=
tf
.
concat
(
3
,
[
inputs
,
outputs
]
)
l
=
tf
.
concat
_v2
([
inputs
,
outputs
],
3
)
with
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
kernel_shape
=
4
,
stride
=
2
):
l
=
(
LinearWrap
(
l
)
.
Conv2D
(
'conv0'
,
NF
,
nl
=
LeakyReLU
)
...
...
@@ -125,7 +125,7 @@ class Model(ModelDesc):
if
OUT_CH
==
1
:
output
=
tf
.
image
.
grayscale_to_rgb
(
output
)
fake_output
=
tf
.
image
.
grayscale_to_rgb
(
fake_output
)
viz
=
(
tf
.
concat
(
2
,
[
input
,
output
,
fake_output
]
)
+
1.0
)
*
128.0
viz
=
(
tf
.
concat
_v2
([
input
,
output
,
fake_output
],
2
)
+
1.0
)
*
128.0
viz
=
tf
.
cast
(
tf
.
clip_by_value
(
viz
,
0
,
255
),
tf
.
uint8
,
name
=
'viz'
)
tf
.
summary
.
image
(
'input,output,fake'
,
viz
,
max_outputs
=
max
(
30
,
BATCH
))
...
...
examples/GAN/InfoGAN-mnist.py
View file @
e94abf66
...
...
@@ -61,9 +61,9 @@ class Model(ModelDesc):
zc
=
tf
.
one_hot
(
ids
,
10
,
name
=
'zc_train'
)
zc
=
tf
.
placeholder_with_default
(
zc
,
[
None
,
10
],
name
=
'zc'
)
z
=
tf
.
random_uniform
(
tf
.
p
ack
([
tf
.
shape
(
zc
)[
0
],
90
]),
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
random_uniform
(
tf
.
st
ack
([
tf
.
shape
(
zc
)[
0
],
90
]),
-
1
,
1
,
name
=
'z_train'
)
z
=
tf
.
placeholder_with_default
(
z
,
[
None
,
90
],
name
=
'z'
)
z
=
tf
.
concat
(
1
,
[
zc
,
z
]
,
name
=
'fullz'
)
z
=
tf
.
concat
_v2
([
zc
,
z
],
1
,
name
=
'fullz'
)
with
argscope
([
Conv2D
,
Deconv2D
,
FullyConnected
],
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
...
...
examples/HED/hed.py
View file @
e94abf66
...
...
@@ -67,7 +67,7 @@ class Model(ModelDesc):
b5
=
branch
(
'branch5'
,
l
,
16
)
final_map
=
Conv2D
(
'convfcweight'
,
tf
.
concat
(
3
,
[
b1
,
b2
,
b3
,
b4
,
b5
]
),
1
,
1
,
tf
.
concat
_v2
([
b1
,
b2
,
b3
,
b4
,
b5
],
3
),
1
,
1
,
W_init
=
tf
.
constant_initializer
(
0.2
),
use_bias
=
False
,
nl
=
tf
.
identity
)
costs
=
[]
...
...
examples/Inception/inception-bn.py
View file @
e94abf66
...
...
@@ -59,7 +59,7 @@ class Model(ModelDesc):
if
nrpool
!=
0
:
# pool + passthrough if nrpool == 0
x4
=
Conv2D
(
'poolproj'
,
x4
,
nrpool
,
1
)
outs
.
append
(
x4
)
return
tf
.
concat
(
3
,
outs
,
name
=
'concat'
)
return
tf
.
concat
_v2
(
outs
,
3
,
name
=
'concat'
)
with
argscope
(
Conv2D
,
nl
=
BNReLU
,
use_bias
=
False
):
l
=
Conv2D
(
'conv0'
,
image
,
64
,
7
,
stride
=
2
)
...
...
examples/Inception/inceptionv3.py
View file @
e94abf66
...
...
@@ -88,55 +88,55 @@ class Model(ModelDesc):
.
MaxPooling
(
'pool4'
,
3
,
2
)())
# 35
with
tf
.
variable_scope
(
'incep-35-256a'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
64
,
1
),
proj_kk
(
l
,
5
,
48
,
64
),
proj_233
(
l
,
64
,
96
),
pool_proj
(
l
,
32
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
with
tf
.
variable_scope
(
'incep-35-288a'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
64
,
1
),
proj_kk
(
l
,
5
,
48
,
64
),
proj_233
(
l
,
64
,
96
),
pool_proj
(
l
,
64
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
with
tf
.
variable_scope
(
'incep-35-288b'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
64
,
1
),
proj_kk
(
l
,
5
,
48
,
64
),
proj_233
(
l
,
64
,
96
),
pool_proj
(
l
,
64
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
# 35x35x288
with
tf
.
variable_scope
(
'incep-17-768a'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv3x3'
,
l
,
384
,
3
,
stride
=
2
,
padding
=
'VALID'
),
proj_233
(
l
,
64
,
96
,
stride
=
2
),
MaxPooling
(
'maxpool'
,
l
,
3
,
2
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
with
tf
.
variable_scope
(
'incep-17-768b'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
192
,
1
),
proj_77
(
l
,
128
,
192
),
proj_277
(
l
,
128
,
192
),
pool_proj
(
l
,
192
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
for
x
in
[
'c'
,
'd'
]:
with
tf
.
variable_scope
(
'incep-17-768{}'
.
format
(
x
)):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
192
,
1
),
proj_77
(
l
,
160
,
192
),
proj_277
(
l
,
160
,
192
),
pool_proj
(
l
,
192
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
with
tf
.
variable_scope
(
'incep-17-768e'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv11'
,
l
,
192
,
1
),
proj_77
(
l
,
192
,
192
),
proj_277
(
l
,
192
,
192
),
pool_proj
(
l
,
192
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
# 17x17x768
with
tf
.
variable_scope
(
'br1'
):
...
...
@@ -147,30 +147,30 @@ class Model(ModelDesc):
br1
=
FullyConnected
(
'fc'
,
br1
,
1000
,
nl
=
tf
.
identity
)
with
tf
.
variable_scope
(
'incep-17-1280a'
):
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
proj_kk
(
l
,
3
,
192
,
320
,
stride
=
2
),
Conv2D
(
'conv73'
,
proj_77
(
l
,
192
,
192
),
192
,
3
,
stride
=
2
,
padding
=
'VALID'
),
MaxPooling
(
'maxpool'
,
l
,
3
,
2
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
for
x
in
[
'a'
,
'b'
]:
with
tf
.
variable_scope
(
'incep-8-2048{}'
.
format
(
x
))
as
scope
:
br11
=
Conv2D
(
'conv11'
,
l
,
320
,
1
)
br33
=
Conv2D
(
'conv133r'
,
l
,
384
,
1
)
br33
=
tf
.
concat
(
3
,
[
br33
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv133a'
,
br33
,
384
,
[
1
,
3
]),
Conv2D
(
'conv133b'
,
br33
,
384
,
[
3
,
1
])
],
name
=
'conv133'
)
],
3
,
name
=
'conv133'
)
br233
=
proj_kk
(
l
,
3
,
448
,
384
)
br233
=
tf
.
concat
(
3
,
[
br233
=
tf
.
concat
_v2
(
[
Conv2D
(
'conv233a'
,
br233
,
384
,
[
1
,
3
]),
Conv2D
(
'conv233b'
,
br233
,
384
,
[
3
,
1
]),
],
name
=
'conv233'
)
],
3
,
name
=
'conv233'
)
l
=
tf
.
concat
(
3
,
[
l
=
tf
.
concat
_v2
(
[
br11
,
br33
,
br233
,
pool_proj
(
l
,
192
,
'avg'
)
],
name
=
'concat'
)
],
3
,
name
=
'concat'
)
l
=
GlobalAvgPooling
(
'gap'
,
l
)
# 1x1x2048
...
...
examples/SpatialTransformer/mnist-addition.py
View file @
e94abf66
...
...
@@ -60,14 +60,14 @@ class Model(ModelDesc):
# For visualization in tensorboard
padded1
=
tf
.
pad
(
sampled1
,
[[
0
,
0
],
[
HALF_DIFF
,
HALF_DIFF
],
[
HALF_DIFF
,
HALF_DIFF
],
[
0
,
0
]])
padded2
=
tf
.
pad
(
sampled2
,
[[
0
,
0
],
[
HALF_DIFF
,
HALF_DIFF
],
[
HALF_DIFF
,
HALF_DIFF
],
[
0
,
0
]])
img_orig
=
tf
.
concat
(
1
,
[
image
[:,
:,
:,
0
],
image
[:,
:,
:,
1
]]
)
# b x 2h x w
transform1
=
tf
.
concat
(
1
,
[
padded1
[:,
:,
:,
0
],
padded1
[:,
:,
:,
1
]]
)
transform2
=
tf
.
concat
(
1
,
[
padded2
[:,
:,
:,
0
],
padded2
[:,
:,
:,
1
]]
)
stacked
=
tf
.
concat
(
2
,
[
img_orig
,
transform1
,
transform2
]
,
'viz'
)
img_orig
=
tf
.
concat
_v2
([
image
[:,
:,
:,
0
],
image
[:,
:,
:,
1
]],
1
)
# b x 2h x w
transform1
=
tf
.
concat
_v2
([
padded1
[:,
:,
:,
0
],
padded1
[:,
:,
:,
1
]],
1
)
transform2
=
tf
.
concat
_v2
([
padded2
[:,
:,
:,
0
],
padded2
[:,
:,
:,
1
]],
1
)
stacked
=
tf
.
concat
_v2
([
img_orig
,
transform1
,
transform2
],
2
,
'viz'
)
tf
.
summary
.
image
(
'visualize'
,
tf
.
expand_dims
(
stacked
,
-
1
),
max_images
=
30
)
sampled
=
tf
.
concat
(
3
,
[
sampled1
,
sampled2
]
,
'sampled_concat'
)
sampled
=
tf
.
concat
_v2
([
sampled1
,
sampled2
],
3
,
'sampled_concat'
)
logits
=
(
LinearWrap
(
sampled
)
.
apply
(
symbf
.
batch_flatten
)
.
FullyConnected
(
'fc1'
,
out_dim
=
256
,
nl
=
tf
.
nn
.
relu
)
...
...
examples/char-rnn/char-rnn.py
View file @
e94abf66
...
...
@@ -84,7 +84,7 @@ class Model(ModelDesc):
self
.
last_state
=
tf
.
identity
(
last_state
,
'last_state'
)
# seqlen x (Bxrnnsize)
output
=
tf
.
reshape
(
tf
.
concat
(
1
,
outputs
),
[
-
1
,
param
.
rnn_size
])
# (Bxseqlen) x rnnsize
output
=
tf
.
reshape
(
tf
.
concat
_v2
(
outputs
,
1
),
[
-
1
,
param
.
rnn_size
])
# (Bxseqlen) x rnnsize
logits
=
FullyConnected
(
'fc'
,
output
,
param
.
vocab_size
,
nl
=
tf
.
identity
)
self
.
prob
=
tf
.
nn
.
softmax
(
logits
/
param
.
softmax_temprature
)
...
...
tensorpack/dataflow/dataset/ptb.py
View file @
e94abf66
...
...
@@ -12,7 +12,7 @@ from ...utils.argtools import memoized_ignoreargs
try
:
from
tensorflow.models.rnn.ptb
import
reader
as
tfreader
except
ImportError
:
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow'
)
logger
.
warn_dependency
(
'PennTreeBank'
,
'tensorflow
.models.rnn.ptb.reader
'
)
__all__
=
[]
else
:
__all__
=
[
'get_PennTreeBank'
]
...
...
tensorpack/models/conv2d.py
View file @
e94abf66
...
...
@@ -53,11 +53,12 @@ def Conv2D(x, out_channel, kernel_shape,
if
split
==
1
:
conv
=
tf
.
nn
.
conv2d
(
x
,
W
,
stride
,
padding
)
else
:
inputs
=
tf
.
split
(
3
,
split
,
x
)
kernels
=
tf
.
split
(
3
,
split
,
W
)
# TODO rename to split later
inputs
=
tf
.
split_v
(
x
,
split
,
3
)
kernels
=
tf
.
split_v
(
W
,
split
,
3
)
outputs
=
[
tf
.
nn
.
conv2d
(
i
,
k
,
stride
,
padding
)
for
i
,
k
in
zip
(
inputs
,
kernels
)]
conv
=
tf
.
concat
(
3
,
outputs
)
conv
=
tf
.
concat
_v2
(
outputs
,
3
)
if
nl
is
None
:
logger
.
warn
(
"[DEPRECATED] Default ReLU nonlinearity for Conv2D and FullyConnected will be deprecated. "
...
...
@@ -130,7 +131,7 @@ def Deconv2D(x, out_shape, kernel_shape,
if
use_bias
:
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
b_init
)
out_shape_dyn
=
tf
.
p
ack
([
tf
.
shape
(
x
)[
0
]]
+
shp3_dyn
)
out_shape_dyn
=
tf
.
st
ack
([
tf
.
shape
(
x
)[
0
]]
+
shp3_dyn
)
conv
=
tf
.
nn
.
conv2d_transpose
(
x
,
W
,
out_shape_dyn
,
stride4d
,
padding
=
padding
)
conv
.
set_shape
(
tf
.
TensorShape
([
None
]
+
shp3_static
))
return
nl
(
tf
.
nn
.
bias_add
(
conv
,
b
)
if
use_bias
else
conv
,
name
=
'output'
)
tensorpack/models/image_sample.py
View file @
e94abf66
...
...
@@ -74,14 +74,14 @@ def ImageSample(inputs, borderMode='repeat'):
diff
=
mapping
-
lcoor
neg_diff
=
1.0
-
diff
# bxh2xw2x2
lcoory
,
lcoorx
=
tf
.
split
(
3
,
2
,
lcoor
)
ucoory
,
ucoorx
=
tf
.
split
(
3
,
2
,
ucoor
)
lcoory
,
lcoorx
=
tf
.
split
_v
(
lcoor
,
2
,
3
)
ucoory
,
ucoorx
=
tf
.
split
_v
(
ucoor
,
2
,
3
)
lyux
=
tf
.
concat
(
3
,
[
lcoory
,
ucoorx
]
)
uylx
=
tf
.
concat
(
3
,
[
ucoory
,
lcoorx
]
)
lyux
=
tf
.
concat
_v2
([
lcoory
,
ucoorx
],
3
)
uylx
=
tf
.
concat
_v2
([
ucoory
,
lcoorx
],
3
)
diffy
,
diffx
=
tf
.
split
(
3
,
2
,
diff
)
neg_diffy
,
neg_diffx
=
tf
.
split
(
3
,
2
,
neg_diff
)
diffy
,
diffx
=
tf
.
split
_v
(
diff
,
2
,
3
)
neg_diffy
,
neg_diffx
=
tf
.
split
_v
(
neg_diff
,
2
,
3
)
# prod = tf.reduce_prod(diff, 3, keep_dims=True)
# diff = tf.Print(diff, [tf.is_finite(tf.reduce_sum(diff)), tf.shape(prod),
...
...
tensorpack/models/pool.py
View file @
e94abf66
...
...
@@ -73,8 +73,8 @@ def GlobalAvgPooling(x):
def
UnPooling2x2ZeroFilled
(
x
):
out
=
tf
.
concat
(
3
,
[
x
,
tf
.
zeros_like
(
x
)]
)
out
=
tf
.
concat
(
2
,
[
out
,
tf
.
zeros_like
(
out
)]
)
out
=
tf
.
concat
_v2
([
x
,
tf
.
zeros_like
(
x
)],
3
)
out
=
tf
.
concat
_v2
([
out
,
tf
.
zeros_like
(
out
)],
2
)
sh
=
x
.
get_shape
()
.
as_list
()
if
None
not
in
sh
[
1
:]:
...
...
@@ -82,7 +82,7 @@ def UnPooling2x2ZeroFilled(x):
return
tf
.
reshape
(
out
,
out_size
)
else
:
shv
=
tf
.
shape
(
x
)
ret
=
tf
.
reshape
(
out
,
tf
.
p
ack
([
-
1
,
shv
[
1
]
*
2
,
shv
[
2
]
*
2
,
sh
[
3
]]))
ret
=
tf
.
reshape
(
out
,
tf
.
st
ack
([
-
1
,
shv
[
1
]
*
2
,
shv
[
2
]
*
2
,
sh
[
3
]]))
ret
.
set_shape
([
None
,
None
,
None
,
sh
[
3
]])
return
ret
...
...
@@ -118,10 +118,10 @@ def FixedUnPooling(x, shape, unpool_mat=None):
fx
=
tf
.
expand_dims
(
fx
,
-
1
)
# (bchw)x1
mat
=
tf
.
expand_dims
(
symbf
.
flatten
(
unpool_mat
),
0
)
# 1x(shxsw)
prod
=
tf
.
matmul
(
fx
,
mat
)
# (bchw) x(shxsw)
prod
=
tf
.
reshape
(
prod
,
tf
.
p
ack
(
prod
=
tf
.
reshape
(
prod
,
tf
.
st
ack
(
[
-
1
,
input_shape
[
3
],
input_shape
[
1
],
input_shape
[
2
],
shape
[
0
],
shape
[
1
]]))
prod
=
tf
.
transpose
(
prod
,
[
0
,
2
,
4
,
3
,
5
,
1
])
prod
=
tf
.
reshape
(
prod
,
tf
.
p
ack
(
prod
=
tf
.
reshape
(
prod
,
tf
.
st
ack
(
[
-
1
,
input_shape
[
1
]
*
shape
[
0
],
input_shape
[
2
]
*
shape
[
1
],
input_shape
[
3
]]))
return
prod
...
...
@@ -135,7 +135,7 @@ def BilinearUpSample(x, shape):
"""
# inp_shape = tf.shape(x)
# return tf.image.resize_bilinear(x,
# tf.
p
ack([inp_shape[1]*shape,inp_shape[2]*shape]),
# tf.
st
ack([inp_shape[1]*shape,inp_shape[2]*shape]),
# align_corners=True)
inp_shape
=
x
.
get_shape
()
.
as_list
()
...
...
tensorpack/models/shapes.py
View file @
e94abf66
...
...
@@ -12,13 +12,13 @@ __all__ = ['ConcatWith']
@
layer_register
(
use_scope
=
False
,
log_shape
=
False
)
def
ConcatWith
(
x
,
dim
,
tensor
):
"""
A wrapper around `tf.concat` to support `LinearWrap`
A wrapper around `tf.concat
_v2
` to support `LinearWrap`
:param x: the input tensor
:param dim: the dimension along which to concatenate
:param tensor: a tensor or list of tensor to concatenate with x.
x will be at the beginning
:return: tf.concat
(dim, [x] + [tensor]
)
:return: tf.concat
_v2([x] + [tensor], dim
)
"""
if
type
(
tensor
)
!=
list
:
tensor
=
[
tensor
]
return
tf
.
concat
(
dim
,
[
x
]
+
tensor
)
return
tf
.
concat
_v2
([
x
]
+
tensor
,
dim
)
tensorpack/tfutils/symbolic_functions.py
View file @
e94abf66
...
...
@@ -30,7 +30,7 @@ def batch_flatten(x):
shape
=
x
.
get_shape
()
.
as_list
()[
1
:]
if
None
not
in
shape
:
return
tf
.
reshape
(
x
,
[
-
1
,
int
(
np
.
prod
(
shape
))])
return
tf
.
reshape
(
x
,
tf
.
p
ack
([
tf
.
shape
(
x
)[
0
],
-
1
]))
return
tf
.
reshape
(
x
,
tf
.
st
ack
([
tf
.
shape
(
x
)[
0
],
-
1
]))
def
class_balanced_cross_entropy
(
pred
,
label
,
name
=
'cross_entropy_loss'
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment