Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
Seminar-HFO
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
Seminar-HFO
Commits
52447fba
Commit
52447fba
authored
Feb 27, 2015
by
Matthew Hausknecht
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Cleaned up and commented trainer.
parent
09cc8de1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
94 additions
and
61 deletions
+94
-61
bin/Trainer.py
bin/Trainer.py
+94
-61
No files found.
bin/Trainer.py
View file @
52447fba
...
...
@@ -6,25 +6,25 @@ from signal import SIGINT
from
Communicator
import
ClientCommunicator
,
TimeoutError
class
DoneError
(
Exception
):
""" This exception is thrown when the Trainer is finished. """
def
__init__
(
self
,
msg
=
'unknown'
):
self
.
msg
=
msg
def
__str__
(
self
):
return
'Done due to
%
s'
%
self
.
msg
class
DummyPopen
(
object
):
def
__init__
(
self
,
pid
):
""" Emulates a Popen object. """
def
__init__
(
self
,
pid
):
self
.
pid
=
pid
def
poll
(
self
):
try
:
os
.
kill
(
self
.
pid
,
0
)
os
.
kill
(
self
.
pid
,
0
)
return
None
except
OSError
:
return
0
def
send_signal
(
self
,
sig
):
def
send_signal
(
self
,
sig
):
try
:
os
.
kill
(
self
.
pid
,
sig
)
os
.
kill
(
self
.
pid
,
sig
)
except
OSError
:
pass
...
...
@@ -32,50 +32,55 @@ class Trainer(object):
""" Trainer is responsible for setting up the players and game.
"""
def
__init__
(
self
,
args
,
rng
=
numpy
.
random
.
RandomState
()):
self
.
_args
=
args
self
.
_numOffense
=
self
.
_args
.
numOffense
self
.
_numDefense
=
self
.
_args
.
numDefense
self
.
_teams
=
[]
self
.
_lastTrialStart
=
-
1
self
.
_numFrames
=
0
self
.
_lastFrameBallTouched
=
-
1
self
.
_maxTrials
=
self
.
_args
.
numTrials
self
.
_maxFrames
=
self
.
_args
.
numFrames
self
.
_rng
=
rng
self
.
_playerPositions
=
numpy
.
zeros
((
11
,
2
,
2
))
self
.
_ballPosition
=
numpy
.
zeros
(
2
)
self
.
_ballHeld
=
numpy
.
zeros
((
11
,
2
))
self
.
_frame
=
0
self
.
_SP
=
{}
self
.
NUM_FRAMES_TO_HOLD
=
2
self
.
HOLD_FACTOR
=
1.5
self
.
PITCH_WIDTH
=
68.0
self
.
PITCH_LENGTH
=
105.0
# Trial will end if the ball is untouched for this many steps
self
.
UNTOUCHED_LENGTH
=
100
self
.
_rng
=
rng
# The Random Number Generator
self
.
_numOffense
=
args
.
numOffense
# Number offensive players
self
.
_numDefense
=
args
.
numDefense
# Number defensive players
self
.
_maxTrials
=
args
.
numTrials
# Maximum number of trials to play
self
.
_maxFrames
=
args
.
numFrames
# Maximum number of frames to play
# =============== FIELD DIMENSIONS =============== #
self
.
NUM_FRAMES_TO_HOLD
=
2
# Hold ball this many frames to capture
self
.
HOLD_FACTOR
=
1.5
# Gain to calculate ball control
self
.
PITCH_WIDTH
=
68.0
# Width of the field
self
.
PITCH_LENGTH
=
105.0
# Length of field in long-direction
self
.
UNTOUCHED_LENGTH
=
100
# Trial will end if ball untouched for this long
# allowedBallX, allowedBallY defines the usable area of the playfield
self
.
_allowedBallX
=
numpy
.
array
([
-
0.1
,
0.5
*
self
.
PITCH_LENGTH
])
self
.
_allowedBallY
=
numpy
.
array
([
-
0.5
*
self
.
PITCH_WIDTH
,
0.5
*
self
.
PITCH_WIDTH
])
self
.
_numTrials
=
0
self
.
_numGoals
=
0
self
.
_numBallsCaptured
=
0
self
.
_numBallsOOB
=
0
# Indicates if a learning agent is active
self
.
_agent
=
not
self
.
_args
.
no_agent
self
.
_agentTeam
=
''
self
.
_agentNumInt
=
-
1
self
.
_agentNumExt
=
-
1
self
.
_isPlaying
=
False
self
.
_agentPopen
=
None
self
.
_allowedBallX
=
numpy
.
array
([
-
0.1
,
0.5
*
self
.
PITCH_LENGTH
])
self
.
_allowedBallY
=
numpy
.
array
([
-
0.5
*
self
.
PITCH_WIDTH
,
0.5
*
self
.
PITCH_WIDTH
])
# =============== COUNTERS =============== #
self
.
_numFrames
=
0
# Number of frames seen in HFO trials
self
.
_frame
=
0
# Current frame id
self
.
_lastTrialStart
=
-
1
# Frame Id in which the last trial started
self
.
_lastFrameBallTouched
=
-
1
# Frame Id in which ball was last touched
# =============== TRIAL RESULTS =============== #
self
.
_numTrials
=
0
# Total number of HFO trials
self
.
_numGoals
=
0
# Trials in which the offense scored a goal
self
.
_numBallsCaptured
=
0
# Trials in which defense captured the ball
self
.
_numBallsOOB
=
0
# Trials in which ball went out of bounds
self
.
_numOutOfTime
=
0
# Trials that ran out of time
# =============== AGENT =============== #
self
.
_agent
=
not
args
.
no_agent
# Indicates if a learning agent is active
self
.
_agent_play_offense
=
args
.
play_offense
# Agent's role
self
.
_agentTeam
=
''
# Name of the team the agent is playing for
self
.
_agentNumInt
=
-
1
# Agent's internal team number
self
.
_agentNumExt
=
-
1
# Agent's external team number
# =============== MISC =============== #
self
.
_offenseTeam
=
''
# Name of the offensive team
self
.
_defenseTeam
=
''
# Name of the defensive team
self
.
_playerPositions
=
numpy
.
zeros
((
11
,
2
,
2
))
# Positions of the players
self
.
_ballPosition
=
numpy
.
zeros
(
2
)
# Position of the ball
self
.
_ballHeld
=
numpy
.
zeros
((
11
,
2
))
# Track player holding the ball
self
.
_teams
=
[]
# Team indexes for offensive and defensive teams
self
.
_SP
=
{}
# Sever Parameters. Recieved when connecting to the server.
self
.
_isPlaying
=
False
# Is a game being played?
self
.
_agentPopen
=
None
# Agent's process
self
.
initMsgHandlers
()
def
launch_agent
(
self
):
"""Launch the learning agent using the start.sh script and return a
DummyPopen for the process.
"""
print
'[Trainer] Launching Agent'
AGENT_DIR
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
AGENT_CMD
=
'start_agent.sh -t
%
s -u
%
i'
os
.
chdir
(
AGENT_DIR
)
if
self
.
_args
.
play_offense
:
if
self
.
_agent_play_offense
:
assert
self
.
_numOffense
>
0
self
.
_agentTeam
=
self
.
_offenseTeam
self
.
_agentNumInt
=
1
if
self
.
_numOffense
==
1
\
...
...
@@ -87,7 +92,7 @@ class Trainer(object):
else
self
.
_rng
.
randint
(
0
,
self
.
_numDefense
)
self
.
_agentNumExt
=
self
.
convertToExtPlayer
(
self
.
_agentTeam
,
self
.
_agentNumInt
)
agentCmd
=
AGENT_CMD
%
(
self
.
_agentTeam
,
self
.
_agentNumExt
)
agentCmd
=
'start_agent.sh -t
%
s -u
%
i'
%
(
self
.
_agentTeam
,
self
.
_agentNumExt
)
agentCmd
=
agentCmd
.
split
(
' '
)
p
=
subprocess
.
Popen
(
agentCmd
)
p
.
wait
()
...
...
@@ -137,12 +142,14 @@ class Trainer(object):
return
self
.
_teams
.
index
(
team_name
)
def
parseMsg
(
self
,
msg
):
""" Parse a message """
assert
(
msg
[
0
]
==
'('
)
res
,
ind
=
self
.
__parseMsg
(
msg
,
1
)
assert
(
ind
==
len
(
msg
)),
msg
res
,
ind
=
self
.
__parseMsg
(
msg
,
1
)
assert
(
ind
==
len
(
msg
)),
msg
return
res
def
__parseMsg
(
self
,
msg
,
ind
):
def
__parseMsg
(
self
,
msg
,
ind
):
""" Recursively parse a message. """
res
=
[]
while
True
:
if
msg
[
ind
]
==
'"'
:
...
...
@@ -173,7 +180,8 @@ class Trainer(object):
# self.send('(eye on)')
self
.
send
(
'(ear on)'
)
def
_hear
(
self
,
body
):
def
_hear
(
self
,
body
):
""" Handle a hear message. """
timestep
,
playerInfo
,
msg
=
body
if
len
(
playerInfo
)
!=
3
:
return
...
...
@@ -198,6 +206,7 @@ class Trainer(object):
print
'[Trainer] Unhandled message from agent:
%
s'
%
msg
def
initMsgHandlers
(
self
):
""" Create handlers for different messages. """
self
.
_msgHandlers
=
[]
self
.
ignoreMsg
(
'player_param'
)
self
.
ignoreMsg
(
'player_type'
)
...
...
@@ -209,13 +218,16 @@ class Trainer(object):
self
.
registerMsgHandler
(
self
.
_handleSP
,
'server_param'
)
self
.
registerMsgHandler
(
self
.
_hear
,
'hear'
)
def
recv
(
self
,
retryCount
=
None
):
def
recv
(
self
,
retryCount
=
None
):
""" Recieve a message. Retry a specified number of times. """
return
self
.
_comm
.
recvMsg
(
retryCount
=
retryCount
)
.
strip
()
def
send
(
self
,
msg
):
def
send
(
self
,
msg
):
""" Send a message. """
self
.
_comm
.
sendMsg
(
msg
)
def
checkMsg
(
self
,
expectedMsg
,
retryCount
=
None
):
def
checkMsg
(
self
,
expectedMsg
,
retryCount
=
None
):
""" Check that the next message is same as expected message. """
msg
=
self
.
recv
(
retryCount
)
if
msg
!=
expectedMsg
:
print
>>
sys
.
stderr
,
'[Trainer] Error with message'
...
...
@@ -224,7 +236,8 @@ class Trainer(object):
print
>>
sys
.
stderr
,
len
(
expectedMsg
),
len
(
msg
)
raise
ValueError
def
extractPoint
(
self
,
msg
):
def
extractPoint
(
self
,
msg
):
""" Extract a point from the provided message. """
return
numpy
.
array
(
map
(
float
,
msg
[:
2
]))
def
convertToExtPlayer
(
self
,
team
,
num
):
...
...
@@ -235,13 +248,18 @@ class Trainer(object):
else
:
return
self
.
_defenseOrder
[
num
]
def
convertFromExtPlayer
(
self
,
team
,
num
):
def
convertFromExtPlayer
(
self
,
team
,
num
):
""" Maps external player number to internal player number. """
if
team
==
self
.
_offenseTeam
:
return
self
.
_offenseOrder
.
index
(
num
)
else
:
return
self
.
_defenseOrder
.
index
(
num
)
def
seeGlobal
(
self
,
body
):
def
seeGlobal
(
self
,
body
):
"""Send a look message to extract global information on ball and
player positions.
"""
self
.
send
(
'(look)'
)
self
.
_frame
=
int
(
body
[
0
])
for
obj
in
body
[
1
:]:
...
...
@@ -272,12 +290,14 @@ class Trainer(object):
print
'[Trainer] Updating handler for
%
s'
%
(
' '
.
join
(
args
))
self
.
_msgHandlers
[
i
]
=
[
args
,
handler
]
def
unregisterMsgHandler
(
self
,
*
args
):
def
unregisterMsgHandler
(
self
,
*
args
):
""" Delete a message handler. """
i
,
_
,
_
=
self
.
_findHandlerInd
(
args
)
assert
(
i
>=
0
)
del
self
.
_msgHandlers
[
i
]
def
_findHandlerInd
(
self
,
msg
):
def
_findHandlerInd
(
self
,
msg
):
""" Find the handler for a particular message. """
msg
=
list
(
msg
)
for
i
,(
partial
,
handler
)
in
enumerate
(
self
.
_msgHandlers
):
recPartial
=
msg
[:
len
(
partial
)]
...
...
@@ -285,7 +305,8 @@ class Trainer(object):
return
i
,
len
(
partial
),
handler
return
-
1
,
None
,
None
def
handleMsg
(
self
,
msg
):
def
handleMsg
(
self
,
msg
):
""" Handle a message using the registered handlers. """
i
,
prefixLength
,
handler
=
self
.
_findHandlerInd
(
msg
)
if
i
<
0
:
print
'[Trainer] Unhandled message:'
,
msg
[
0
:
2
]
...
...
@@ -293,9 +314,11 @@ class Trainer(object):
handler
(
msg
[
prefixLength
:])
def
ignoreMsg
(
self
,
*
args
,
**
kwargs
):
""" Ignore a certain type of message. """
self
.
registerMsgHandler
(
lambda
x
:
None
,
*
args
,
**
kwargs
)
def
_handleSP
(
self
,
body
):
def
_handleSP
(
self
,
body
):
""" Handler for the sever params message. """
for
param
in
body
:
try
:
val
=
int
(
param
[
1
])
...
...
@@ -307,19 +330,25 @@ class Trainer(object):
self
.
_SP
[
param
[
0
]]
=
val
def
listenAndProcess
(
self
):
""" Gather messages and process them. """
msg
=
self
.
recv
()
assert
((
msg
[
0
]
==
'('
)
and
(
msg
[
-
1
]
==
')'
)),
'|
%
s|'
%
msg
msg
=
self
.
parseMsg
(
msg
)
self
.
handleMsg
(
msg
)
def
_readTeamNames
(
self
,
body
):
""" Read the names of each of the teams. """
self
.
_teams
=
[]
for
_
,
_
,
team
in
body
:
self
.
_teams
.
append
(
team
)
time
.
sleep
(
0.1
)
self
.
send
(
'(team_names)'
)
def
waitOnTeam
(
self
,
first
):
def
waitOnTeam
(
self
,
first
):
"""Wait on a given team. First indicates if this is the first team
connected or the second.
"""
self
.
send
(
'(team_names)'
)
partial
=
[
'ok'
,
'team_names'
]
self
.
registerMsgHandler
(
self
.
_readTeamNames
,
*
partial
,
quiet
=
True
)
...
...
@@ -329,6 +358,7 @@ class Trainer(object):
self
.
ignoreMsg
(
*
partial
,
quiet
=
True
)
def
checkIfAllPlayersConnected
(
self
):
""" Returns true if all players are connected. """
self
.
send
(
'(look)'
)
partial
=
[
'ok'
,
'look'
]
self
.
_numPlayers
=
0
...
...
@@ -363,6 +393,7 @@ class Trainer(object):
totalHeld
+=
1
else
:
self
.
_ballHeld
[
i
,
self
.
teamToInd
(
team
)]
=
0
# If multiple players are close to the ball, no-one is holding
if
totalHeld
>
1
:
self
.
_ballHeld
[:,:]
=
0
inds
=
numpy
.
transpose
((
self
.
_ballHeld
>=
self
.
NUM_FRAMES_TO_HOLD
)
.
nonzero
())
...
...
@@ -524,6 +555,7 @@ class Trainer(object):
result
=
'Defense Captured'
elif
self
.
_frame
-
self
.
_lastFrameBallTouched
>
self
.
UNTOUCHED_LENGTH
:
self
.
_lastFrameBallTouched
=
self
.
_frame
self
.
_numOutOfTime
+=
1
result
=
'Ball untouched for too long'
else
:
print
'[Trainer] Error: Unable to detect reason for End of Trial!'
...
...
@@ -558,6 +590,7 @@ class Trainer(object):
print
'Goals :
%
i'
%
self
.
_numGoals
print
'Defense Captured :
%
i'
%
self
.
_numBallsCaptured
print
'Balls Out of Bounds:
%
i'
%
self
.
_numBallsOOB
print
'Out of Time :
%
i'
%
self
.
_numOutOfTime
def
checkLive
(
self
,
necProcesses
):
"""Returns true if each of the necessary processes is still alive and
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment