Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b47c184e
Commit
b47c184e
authored
Aug 17, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
huber loss
parent
212cfd90
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
9 additions
and
12 deletions
+9
-12
examples/Atari2600/common.py
examples/Atari2600/common.py
+1
-1
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+0
-7
tensorpack/callbacks/stat.py
tensorpack/callbacks/stat.py
+1
-1
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+7
-3
No files found.
examples/Atari2600/common.py
View file @
b47c184e
...
...
@@ -90,7 +90,7 @@ class Evaluator(Callback):
self
.
output_names
=
output_names
def
_setup_graph
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
20
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
...
...
tensorpack/RL/simulator.py
View file @
b47c184e
...
...
@@ -126,7 +126,6 @@ class SimulatorMaster(threading.Thread):
def
f
():
msg
=
self
.
send_queue
.
get
()
# slow
self
.
s2c_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
send_thread
=
LoopThread
(
f
)
self
.
send_thread
.
daemon
=
True
...
...
@@ -142,11 +141,7 @@ class SimulatorMaster(threading.Thread):
def
run
(
self
):
self
.
clients
=
defaultdict
(
self
.
ClientState
)
#cnt = 0
while
True
:
#cnt += 1
#if cnt % 3000 == 0:
#print_total_timer()
msg
=
loads
(
self
.
c2s_socket
.
recv
(
copy
=
False
)
.
bytes
)
ident
,
state
,
reward
,
isOver
=
msg
client
=
self
.
clients
[
ident
]
...
...
@@ -179,7 +174,6 @@ class SimulatorMaster(threading.Thread):
def
__del__
(
self
):
self
.
context
.
destroy
(
linger
=
0
)
class
SimulatorProcessDF
(
SimulatorProcessBase
):
""" A simulator which contains a forward model itself, allowing
it to produce data points directly """
...
...
@@ -208,7 +202,6 @@ class SimulatorProcessDF(SimulatorProcessBase):
def
get_data
(
self
):
pass
class
SimulatorProcessSharedWeight
(
SimulatorProcessDF
):
""" A simulator process with an extra thread waiting for event,
and take shared weight from shm.
...
...
tensorpack/callbacks/stat.py
View file @
b47c184e
# -*- coding: utf-8 -*-
# File: s
ummary
.py
# File: s
tat
.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
b47c184e
...
...
@@ -78,12 +78,16 @@ def rms(x, name=None):
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
def
clipped_l2_loss
(
x
,
name
=
None
):
def
huber_loss
(
x
,
delta
=
1
,
name
=
None
):
if
name
is
None
:
name
=
'
clipped_l2
_loss'
name
=
'
huber
_loss'
sqrcost
=
tf
.
square
(
x
)
abscost
=
tf
.
abs
(
x
)
return
tf
.
select
(
abscost
<
1
,
sqrcost
,
abscost
,
name
=
name
)
return
tf
.
reduce_sum
(
tf
.
select
(
abscost
<
delta
,
sqrcost
*
0.5
,
abscost
*
delta
-
0.5
*
delta
**
2
),
name
=
name
)
def
get_scalar_var
(
name
,
init_value
):
return
tf
.
get_variable
(
name
,
shape
=
[],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment