Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d10cb1af
Commit
d10cb1af
authored
Nov 30, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
small updates
parent
34a5a809
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
87 additions
and
13 deletions
+87
-13
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+4
-2
examples/GAN/README.md
examples/GAN/README.md
+1
-1
examples/OpenAIGym/README.md
examples/OpenAIGym/README.md
+2
-2
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+1
-0
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+1
-1
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+31
-3
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+8
-4
tensorpack/utils/debug.py
tensorpack/utils/debug.py
+39
-0
No files found.
examples/GAN/InfoGAN-mnist.py
View file @
d10cb1af
...
@@ -117,8 +117,10 @@ def sample(model_path):
...
@@ -117,8 +117,10 @@ def sample(model_path):
input_names
=
[
'zc'
],
input_names
=
[
'zc'
],
output_names
=
[
'gen/gen'
]))
output_names
=
[
'gen/gen'
]))
eye
=
[
k
for
k
in
np
.
eye
(
10
)]
eye
=
[]
inputs
=
np
.
asarray
(
eye
*
10
)
for
k
in
np
.
eye
(
10
):
eye
=
eye
+
[
k
]
*
10
inputs
=
np
.
asarray
(
eye
)
while
True
:
while
True
:
o
=
pred
([
inputs
])
o
=
pred
([
inputs
])
o
=
(
o
[
0
]
+
1
)
*
128.0
o
=
(
o
[
0
]
+
1
)
*
128.0
...
...
examples/GAN/README.md
View file @
d10cb1af
...
@@ -33,4 +33,4 @@ It requires the datasets released by the original authors.
...
@@ -33,4 +33,4 @@ It requires the datasets released by the original authors.
Reproduce a mnist experiement in InfoGAN.
Reproduce a mnist experiement in InfoGAN.
By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
By assuming 10 latent variables corresponding to a categorical distribution and maximizing mutual information,
the
GAN learns to map the 10 variables to 10 digits in an unsupervised fashion
.
the
network unsupervisedly learns to map the 10 variables to 10 digits
.
examples/OpenAIGym/README.md
View file @
d10cb1af
###
c
ode and models for my Gym submissions on Atari games
###
C
ode and models for my Gym submissions on Atari games
Use
A3C in
[
Asynchronous Methods for Deep Reinforcement Learning
](
http://arxiv.org/abs/1602.01783
)
.
Implemented
A3C in
[
Asynchronous Methods for Deep Reinforcement Learning
](
http://arxiv.org/abs/1602.01783
)
.
### To train on an Atari game:
### To train on an Atari game:
...
...
tensorpack/RL/simulator.py
View file @
d10cb1af
...
@@ -146,6 +146,7 @@ class SimulatorMaster(threading.Thread):
...
@@ -146,6 +146,7 @@ class SimulatorMaster(threading.Thread):
while
True
:
while
True
:
msg
=
loads
(
self
.
c2s_socket
.
recv
(
copy
=
False
)
.
bytes
)
msg
=
loads
(
self
.
c2s_socket
.
recv
(
copy
=
False
)
.
bytes
)
ident
,
state
,
reward
,
isOver
=
msg
ident
,
state
,
reward
,
isOver
=
msg
# TODO check history and warn about dead client
client
=
self
.
clients
[
ident
]
client
=
self
.
clients
[
ident
]
# check if reward&isOver is valid
# check if reward&isOver is valid
...
...
tensorpack/callbacks/param.py
View file @
d10cb1af
...
@@ -207,7 +207,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
...
@@ -207,7 +207,7 @@ class HyperParamSetterWithFunc(HyperParamSetter):
"""Set hyperparameter by a func
"""Set hyperparameter by a func
new_value = f(epoch_num, old_value)
new_value = f(epoch_num, old_value)
"""
"""
super
(
StatMonitorParamSetter
,
self
)
.
__init__
(
param
)
super
(
HyperParamSetterWithFunc
,
self
)
.
__init__
(
param
)
self
.
f
=
func
self
.
f
=
func
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
d10cb1af
...
@@ -10,7 +10,7 @@ from six.moves import range
...
@@ -10,7 +10,7 @@ from six.moves import range
import
numpy
as
np
import
numpy
as
np
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'FixedCrop'
,
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'FixedCrop'
,
'RandomCropRandomShape'
,
'perturb_BB'
]
'RandomCropRandomShape'
,
'perturb_BB'
,
'RandomCropAroundBox'
]
class
RandomCrop
(
ImageAugmentor
):
class
RandomCrop
(
ImageAugmentor
):
""" Randomly crop the image into a smaller one """
""" Randomly crop the image into a smaller one """
...
@@ -109,7 +109,7 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
...
@@ -109,7 +109,7 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
return
bb
return
bb
class
RandomCrop
RandomShape
(
ImageAugmentor
):
class
RandomCrop
AroundBox
(
ImageAugmentor
):
"""
"""
Crop a box around a bounding box
Crop a box around a bounding box
"""
"""
...
@@ -118,7 +118,7 @@ class RandomCropRandomShape(ImageAugmentor):
...
@@ -118,7 +118,7 @@ class RandomCropRandomShape(ImageAugmentor):
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param max_aspect_ratio_diff: keep aspect ratio within the range
:param max_aspect_ratio_diff: keep aspect ratio within the range
"""
"""
super
(
RandomCrop
RandomShape
,
self
)
.
__init__
()
super
(
RandomCrop
AroundBox
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
...
@@ -135,5 +135,33 @@ class RandomCropRandomShape(ImageAugmentor):
...
@@ -135,5 +135,33 @@ class RandomCropRandomShape(ImageAugmentor):
def
_fprop_coord
(
self
,
coord
,
param
):
def
_fprop_coord
(
self
,
coord
,
param
):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
RandomCropRandomShape
(
ImageAugmentor
):
def
__init__
(
self
,
wmin
,
hmin
,
wmax
=
None
,
hmax
=
None
,
max_aspect_ratio
=
None
):
"""
Randomly crop a box of shape (h, w), sampled from [min, max](inclusive).
If max is None, will use the input image shape.
max_aspect_ratio is the upper bound of max(w,h)/min(w,h)
"""
if
max_aspect_ratio
is
None
:
max_aspect_ratio
=
9999999
self
.
_init
(
locals
())
def
_get_augment_params
(
self
,
img
):
hmax
=
self
.
hmax
or
img
.
shape
[
0
]
wmax
=
self
.
wmax
or
img
.
shape
[
1
]
h
=
self
.
rng
.
randint
(
self
.
hmin
,
hmax
+
1
)
w
=
self
.
rng
.
randint
(
self
.
wmin
,
wmax
+
1
)
diffh
=
img
.
shape
[
0
]
-
h
y0
=
0
if
diffh
==
0
else
self
.
rng
.
randint
(
diffh
)
diffw
=
img
.
shape
[
1
]
-
w
x0
=
0
if
diffw
==
0
else
self
.
rng
.
randint
(
diffw
)
return
(
y0
,
x0
,
h
,
w
)
def
_augment
(
self
,
img
,
param
):
y0
,
x0
,
h
,
w
=
param
return
img
[
y0
:
y0
+
h
,
x0
:
x0
+
w
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
print
(
perturb_BB
([
100
,
100
],
Rect
(
3
,
3
,
50
,
50
),
50
))
print
(
perturb_BB
([
100
,
100
],
Rect
(
3
,
3
,
50
,
50
),
50
))
tensorpack/train/multigpu.py
View file @
d10cb1af
...
@@ -91,18 +91,22 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
...
@@ -91,18 +91,22 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
MultiGPUTrainer
,
MultiGPUTrainer
,
SingleCostFeedlessTrainer
,
SingleCostFeedlessTrainer
,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
,
average_gradient
=
True
):
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_build_enque_thread
(
input_queue
)
self
.
_build_enque_thread
(
input_queue
)
self
.
average_gradient
=
average_gradient
def
_setup
(
self
):
def
_setup
(
self
):
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
gradprocs
=
self
.
model
.
get_gradient_processor
()
gradprocs
=
self
.
model
.
get_gradient_processor
()
# pretend to average the grads, in order to make async and
if
self
.
average_gradient
and
self
.
config
.
nr_tower
>
1
:
# sync have consistent effective learning rate
# pretend to average the grads, in order to make async and
if
self
.
config
.
nr_tower
>
1
:
# sync have consistent effective learning rate
gradprocs
.
insert
(
0
,
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
))
gradprocs
.
insert
(
0
,
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
))
grad_list
=
[
apply_grad_processors
(
g
,
gradprocs
)
for
g
in
grad_list
]
grad_list
=
[
apply_grad_processors
(
g
,
gradprocs
)
for
g
in
grad_list
]
...
...
tensorpack/utils/debug.py
0 → 100644
View file @
d10cb1af
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: debug.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
sys
__all__
=
[
'enable_call_trace'
]
def
enable_call_trace
():
def
tracer
(
frame
,
event
,
arg
):
if
event
==
'call'
:
co
=
frame
.
f_code
func_name
=
co
.
co_name
if
func_name
==
'write'
or
func_name
==
'print'
:
# ignore write() calls from print statements
return
func_line_no
=
frame
.
f_lineno
func_filename
=
co
.
co_filename
caller
=
frame
.
f_back
if
caller
:
caller_line_no
=
caller
.
f_lineno
caller_filename
=
caller
.
f_code
.
co_filename
print
'Call to `
%
s` on line
%
s:
%
s from
%
s:
%
s'
%
\
(
func_name
,
func_filename
,
func_line_no
,
caller_filename
,
caller_line_no
)
return
sys
.
settrace
(
tracer
)
if
__name__
==
'__main__'
:
enable_call_trace
()
def
b
(
a
):
print
2
def
a
():
print
1
b
(
1
)
a
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment