Commit 52447fba authored by Matthew Hausknecht's avatar Matthew Hausknecht

Cleaned up and commented trainer.

parent 09cc8de1
......@@ -6,25 +6,25 @@ from signal import SIGINT
from Communicator import ClientCommunicator, TimeoutError
class DoneError(Exception):
""" This exception is thrown when the Trainer is finished. """
def __init__(self,msg='unknown'):
self.msg = msg
def __str__(self):
return 'Done due to %s' % self.msg
class DummyPopen(object):
def __init__(self,pid):
""" Emulates a Popen object. """
def __init__(self, pid):
self.pid = pid
def poll(self):
try:
os.kill(self.pid,0)
os.kill(self.pid, 0)
return None
except OSError:
return 0
def send_signal(self,sig):
def send_signal(self, sig):
try:
os.kill(self.pid,sig)
os.kill(self.pid, sig)
except OSError:
pass
......@@ -32,50 +32,55 @@ class Trainer(object):
""" Trainer is responsible for setting up the players and game.
"""
def __init__(self, args, rng=numpy.random.RandomState()):
self._args = args
self._numOffense = self._args.numOffense
self._numDefense = self._args.numDefense
self._teams = []
self._lastTrialStart = -1
self._numFrames = 0
self._lastFrameBallTouched = -1
self._maxTrials = self._args.numTrials
self._maxFrames = self._args.numFrames
self._rng = rng
self._playerPositions = numpy.zeros((11,2,2))
self._ballPosition = numpy.zeros(2)
self._ballHeld = numpy.zeros((11,2))
self._frame = 0
self._SP = {}
self.NUM_FRAMES_TO_HOLD = 2
self.HOLD_FACTOR = 1.5
self.PITCH_WIDTH = 68.0
self.PITCH_LENGTH = 105.0
# Trial will end if the ball is untouched for this many steps
self.UNTOUCHED_LENGTH = 100
self._rng = rng # The Random Number Generator
self._numOffense = args.numOffense # Number offensive players
self._numDefense = args.numDefense # Number defensive players
self._maxTrials = args.numTrials # Maximum number of trials to play
self._maxFrames = args.numFrames # Maximum number of frames to play
# =============== FIELD DIMENSIONS =============== #
self.NUM_FRAMES_TO_HOLD = 2 # Hold ball this many frames to capture
self.HOLD_FACTOR = 1.5 # Gain to calculate ball control
self.PITCH_WIDTH = 68.0 # Width of the field
self.PITCH_LENGTH = 105.0 # Length of field in long-direction
self.UNTOUCHED_LENGTH = 100 # Trial will end if ball untouched for this long
# allowedBallX, allowedBallY defines the usable area of the playfield
self._allowedBallX = numpy.array([-0.1,0.5 * self.PITCH_LENGTH])
self._allowedBallY = numpy.array([-0.5 * self.PITCH_WIDTH,
0.5 * self.PITCH_WIDTH])
self._numTrials = 0
self._numGoals = 0
self._numBallsCaptured = 0
self._numBallsOOB = 0
# Indicates if a learning agent is active
self._agent = not self._args.no_agent
self._agentTeam = ''
self._agentNumInt = -1
self._agentNumExt = -1
self._isPlaying = False
self._agentPopen = None
self._allowedBallX = numpy.array([-0.1, 0.5 * self.PITCH_LENGTH])
self._allowedBallY = numpy.array([-0.5 * self.PITCH_WIDTH, 0.5 * self.PITCH_WIDTH])
# =============== COUNTERS =============== #
self._numFrames = 0 # Number of frames seen in HFO trials
self._frame = 0 # Current frame id
self._lastTrialStart = -1 # Frame Id in which the last trial started
self._lastFrameBallTouched = -1 # Frame Id in which ball was last touched
# =============== TRIAL RESULTS =============== #
self._numTrials = 0 # Total number of HFO trials
self._numGoals = 0 # Trials in which the offense scored a goal
self._numBallsCaptured = 0 # Trials in which defense captured the ball
self._numBallsOOB = 0 # Trials in which ball went out of bounds
self._numOutOfTime = 0 # Trials that ran out of time
# =============== AGENT =============== #
self._agent = not args.no_agent # Indicates if a learning agent is active
self._agent_play_offense = args.play_offense # Agent's role
self._agentTeam = '' # Name of the team the agent is playing for
self._agentNumInt = -1 # Agent's internal team number
self._agentNumExt = -1 # Agent's external team number
# =============== MISC =============== #
self._offenseTeam = '' # Name of the offensive team
self._defenseTeam = '' # Name of the defensive team
self._playerPositions = numpy.zeros((11,2,2)) # Positions of the players
self._ballPosition = numpy.zeros(2) # Position of the ball
self._ballHeld = numpy.zeros((11,2)) # Track player holding the ball
self._teams = [] # Team indexes for offensive and defensive teams
self._SP = {} # Sever Parameters. Recieved when connecting to the server.
self._isPlaying = False # Is a game being played?
self._agentPopen = None # Agent's process
self.initMsgHandlers()
def launch_agent(self):
"""Launch the learning agent using the start.sh script and return a
DummyPopen for the process.
"""
print '[Trainer] Launching Agent'
AGENT_DIR = os.path.dirname(os.path.realpath(__file__))
AGENT_CMD = 'start_agent.sh -t %s -u %i'
os.chdir(AGENT_DIR)
if self._args.play_offense:
if self._agent_play_offense:
assert self._numOffense > 0
self._agentTeam = self._offenseTeam
self._agentNumInt = 1 if self._numOffense == 1 \
......@@ -87,7 +92,7 @@ class Trainer(object):
else self._rng.randint(0, self._numDefense)
self._agentNumExt = self.convertToExtPlayer(self._agentTeam,
self._agentNumInt)
agentCmd = AGENT_CMD % (self._agentTeam, self._agentNumExt)
agentCmd = 'start_agent.sh -t %s -u %i'%(self._agentTeam, self._agentNumExt)
agentCmd = agentCmd.split(' ')
p = subprocess.Popen(agentCmd)
p.wait()
......@@ -137,12 +142,14 @@ class Trainer(object):
return self._teams.index(team_name)
def parseMsg(self, msg):
""" Parse a message """
assert(msg[0] == '(')
res,ind = self.__parseMsg(msg,1)
assert(ind == len(msg)),msg
res, ind = self.__parseMsg(msg,1)
assert(ind == len(msg)), msg
return res
def __parseMsg(self,msg,ind):
def __parseMsg(self, msg, ind):
""" Recursively parse a message. """
res = []
while True:
if msg[ind] == '"':
......@@ -173,7 +180,8 @@ class Trainer(object):
# self.send('(eye on)')
self.send('(ear on)')
def _hear(self,body):
def _hear(self, body):
""" Handle a hear message. """
timestep,playerInfo,msg = body
if len(playerInfo) != 3:
return
......@@ -198,6 +206,7 @@ class Trainer(object):
print '[Trainer] Unhandled message from agent: %s' % msg
def initMsgHandlers(self):
""" Create handlers for different messages. """
self._msgHandlers = []
self.ignoreMsg('player_param')
self.ignoreMsg('player_type')
......@@ -209,13 +218,16 @@ class Trainer(object):
self.registerMsgHandler(self._handleSP,'server_param')
self.registerMsgHandler(self._hear,'hear')
def recv(self,retryCount=None):
def recv(self, retryCount=None):
""" Recieve a message. Retry a specified number of times. """
return self._comm.recvMsg(retryCount=retryCount).strip()
def send(self,msg):
def send(self, msg):
""" Send a message. """
self._comm.sendMsg(msg)
def checkMsg(self,expectedMsg,retryCount=None):
def checkMsg(self, expectedMsg, retryCount=None):
""" Check that the next message is same as expected message. """
msg = self.recv(retryCount)
if msg != expectedMsg:
print >>sys.stderr,'[Trainer] Error with message'
......@@ -224,7 +236,8 @@ class Trainer(object):
print >>sys.stderr,len(expectedMsg),len(msg)
raise ValueError
def extractPoint(self,msg):
def extractPoint(self, msg):
""" Extract a point from the provided message. """
return numpy.array(map(float,msg[:2]))
def convertToExtPlayer(self, team, num):
......@@ -235,13 +248,18 @@ class Trainer(object):
else:
return self._defenseOrder[num]
def convertFromExtPlayer(self,team,num):
def convertFromExtPlayer(self, team, num):
""" Maps external player number to internal player number. """
if team == self._offenseTeam:
return self._offenseOrder.index(num)
else:
return self._defenseOrder.index(num)
def seeGlobal(self,body):
def seeGlobal(self, body):
"""Send a look message to extract global information on ball and
player positions.
"""
self.send('(look)')
self._frame = int(body[0])
for obj in body[1:]:
......@@ -272,12 +290,14 @@ class Trainer(object):
print '[Trainer] Updating handler for %s' % (' '.join(args))
self._msgHandlers[i] = [args,handler]
def unregisterMsgHandler(self,*args):
def unregisterMsgHandler(self, *args):
""" Delete a message handler. """
i,_,_ = self._findHandlerInd(args)
assert(i >= 0)
del self._msgHandlers[i]
def _findHandlerInd(self,msg):
def _findHandlerInd(self, msg):
""" Find the handler for a particular message. """
msg = list(msg)
for i,(partial,handler) in enumerate(self._msgHandlers):
recPartial = msg[:len(partial)]
......@@ -285,7 +305,8 @@ class Trainer(object):
return i,len(partial),handler
return -1,None,None
def handleMsg(self,msg):
def handleMsg(self, msg):
""" Handle a message using the registered handlers. """
i,prefixLength,handler = self._findHandlerInd(msg)
if i < 0:
print '[Trainer] Unhandled message:',msg[0:2]
......@@ -293,9 +314,11 @@ class Trainer(object):
handler(msg[prefixLength:])
def ignoreMsg(self,*args,**kwargs):
""" Ignore a certain type of message. """
self.registerMsgHandler(lambda x: None,*args,**kwargs)
def _handleSP(self,body):
def _handleSP(self, body):
""" Handler for the sever params message. """
for param in body:
try:
val = int(param[1])
......@@ -307,19 +330,25 @@ class Trainer(object):
self._SP[param[0]] = val
def listenAndProcess(self):
""" Gather messages and process them. """
msg = self.recv()
assert((msg[0] == '(') and (msg[-1] == ')')),'|%s|' % msg
msg = self.parseMsg(msg)
self.handleMsg(msg)
def _readTeamNames(self,body):
""" Read the names of each of the teams. """
self._teams = []
for _,_,team in body:
self._teams.append(team)
time.sleep(0.1)
self.send('(team_names)')
def waitOnTeam(self,first):
def waitOnTeam(self, first):
"""Wait on a given team. First indicates if this is the first team
connected or the second.
"""
self.send('(team_names)')
partial = ['ok','team_names']
self.registerMsgHandler(self._readTeamNames,*partial,quiet=True)
......@@ -329,6 +358,7 @@ class Trainer(object):
self.ignoreMsg(*partial,quiet=True)
def checkIfAllPlayersConnected(self):
""" Returns true if all players are connected. """
self.send('(look)')
partial = ['ok','look']
self._numPlayers = 0
......@@ -363,6 +393,7 @@ class Trainer(object):
totalHeld += 1
else:
self._ballHeld[i,self.teamToInd(team)] = 0
# If multiple players are close to the ball, no-one is holding
if totalHeld > 1:
self._ballHeld[:,:] = 0
inds = numpy.transpose((self._ballHeld >= self.NUM_FRAMES_TO_HOLD).nonzero())
......@@ -524,6 +555,7 @@ class Trainer(object):
result = 'Defense Captured'
elif self._frame - self._lastFrameBallTouched > self.UNTOUCHED_LENGTH:
self._lastFrameBallTouched = self._frame
self._numOutOfTime += 1
result = 'Ball untouched for too long'
else:
print '[Trainer] Error: Unable to detect reason for End of Trial!'
......@@ -558,6 +590,7 @@ class Trainer(object):
print 'Goals : %i' % self._numGoals
print 'Defense Captured : %i' % self._numBallsCaptured
print 'Balls Out of Bounds: %i' % self._numBallsOOB
print 'Out of Time : %i' % self._numOutOfTime
def checkLive(self, necProcesses):
"""Returns true if each of the necessary processes is still alive and
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment