Commit 52447fba authored by Matthew Hausknecht

Cleaned up and commented trainer.

parent 09cc8de1
...@@ -6,25 +6,25 @@ from signal import SIGINT ...@@ -6,25 +6,25 @@ from signal import SIGINT
from Communicator import ClientCommunicator, TimeoutError from Communicator import ClientCommunicator, TimeoutError
class DoneError(Exception): class DoneError(Exception):
""" This exception is thrown when the Trainer is finished. """
def __init__(self,msg='unknown'): def __init__(self,msg='unknown'):
self.msg = msg self.msg = msg
def __str__(self): def __str__(self):
return 'Done due to %s' % self.msg return 'Done due to %s' % self.msg
class DummyPopen(object): class DummyPopen(object):
def __init__(self,pid): """ Emulates a Popen object. """
def __init__(self, pid):
self.pid = pid self.pid = pid
def poll(self): def poll(self):
try: try:
os.kill(self.pid,0) os.kill(self.pid, 0)
return None return None
except OSError: except OSError:
return 0 return 0
def send_signal(self, sig):
def send_signal(self,sig):
try: try:
os.kill(self.pid,sig) os.kill(self.pid, sig)
except OSError: except OSError:
pass pass
...@@ -32,50 +32,55 @@ class Trainer(object): ...@@ -32,50 +32,55 @@ class Trainer(object):
""" Trainer is responsible for setting up the players and game. """ Trainer is responsible for setting up the players and game.
""" """
def __init__(self, args, rng=numpy.random.RandomState()): def __init__(self, args, rng=numpy.random.RandomState()):
self._args = args self._rng = rng # The Random Number Generator
self._numOffense = self._args.numOffense self._numOffense = args.numOffense # Number of offensive players
self._numDefense = self._args.numDefense self._numDefense = args.numDefense # Number of defensive players
self._teams = [] self._maxTrials = args.numTrials # Maximum number of trials to play
self._lastTrialStart = -1 self._maxFrames = args.numFrames # Maximum number of frames to play
self._numFrames = 0 # =============== FIELD DIMENSIONS =============== #
self._lastFrameBallTouched = -1 self.NUM_FRAMES_TO_HOLD = 2 # Hold ball this many frames to capture
self._maxTrials = self._args.numTrials self.HOLD_FACTOR = 1.5 # Gain to calculate ball control
self._maxFrames = self._args.numFrames self.PITCH_WIDTH = 68.0 # Width of the field
self._rng = rng self.PITCH_LENGTH = 105.0 # Length of field in long-direction
self._playerPositions = numpy.zeros((11,2,2)) self.UNTOUCHED_LENGTH = 100 # Trial will end if ball untouched for this long
self._ballPosition = numpy.zeros(2)
self._ballHeld = numpy.zeros((11,2))
self._frame = 0
self._SP = {}
self.NUM_FRAMES_TO_HOLD = 2
self.HOLD_FACTOR = 1.5
self.PITCH_WIDTH = 68.0
self.PITCH_LENGTH = 105.0
# Trial will end if the ball is untouched for this many steps
self.UNTOUCHED_LENGTH = 100
# allowedBallX, allowedBallY defines the usable area of the playfield # allowedBallX, allowedBallY defines the usable area of the playfield
self._allowedBallX = numpy.array([-0.1,0.5 * self.PITCH_LENGTH]) self._allowedBallX = numpy.array([-0.1, 0.5 * self.PITCH_LENGTH])
self._allowedBallY = numpy.array([-0.5 * self.PITCH_WIDTH, self._allowedBallY = numpy.array([-0.5 * self.PITCH_WIDTH, 0.5 * self.PITCH_WIDTH])
0.5 * self.PITCH_WIDTH]) # =============== COUNTERS =============== #
self._numTrials = 0 self._numFrames = 0 # Number of frames seen in HFO trials
self._numGoals = 0 self._frame = 0 # Current frame id
self._numBallsCaptured = 0 self._lastTrialStart = -1 # Frame Id in which the last trial started
self._numBallsOOB = 0 self._lastFrameBallTouched = -1 # Frame Id in which ball was last touched
# Indicates if a learning agent is active # =============== TRIAL RESULTS =============== #
self._agent = not self._args.no_agent self._numTrials = 0 # Total number of HFO trials
self._agentTeam = '' self._numGoals = 0 # Trials in which the offense scored a goal
self._agentNumInt = -1 self._numBallsCaptured = 0 # Trials in which defense captured the ball
self._agentNumExt = -1 self._numBallsOOB = 0 # Trials in which ball went out of bounds
self._isPlaying = False self._numOutOfTime = 0 # Trials that ran out of time
self._agentPopen = None # =============== AGENT =============== #
self._agent = not args.no_agent # Indicates if a learning agent is active
self._agent_play_offense = args.play_offense # Agent's role
self._agentTeam = '' # Name of the team the agent is playing for
self._agentNumInt = -1 # Agent's internal team number
self._agentNumExt = -1 # Agent's external team number
# =============== MISC =============== #
self._offenseTeam = '' # Name of the offensive team
self._defenseTeam = '' # Name of the defensive team
self._playerPositions = numpy.zeros((11,2,2)) # Positions of the players
self._ballPosition = numpy.zeros(2) # Position of the ball
self._ballHeld = numpy.zeros((11,2)) # Track player holding the ball
self._teams = [] # Team indexes for offensive and defensive teams
self._SP = {} # Server Parameters. Received when connecting to the server.
self._isPlaying = False # Is a game being played?
self._agentPopen = None # Agent's process
self.initMsgHandlers() self.initMsgHandlers()
def launch_agent(self): def launch_agent(self):
"""Launch the learning agent using the start_agent.sh script and return a
DummyPopen for the process.
"""
print '[Trainer] Launching Agent' print '[Trainer] Launching Agent'
AGENT_DIR = os.path.dirname(os.path.realpath(__file__)) if self._agent_play_offense:
AGENT_CMD = 'start_agent.sh -t %s -u %i'
os.chdir(AGENT_DIR)
if self._args.play_offense:
assert self._numOffense > 0 assert self._numOffense > 0
self._agentTeam = self._offenseTeam self._agentTeam = self._offenseTeam
self._agentNumInt = 1 if self._numOffense == 1 \ self._agentNumInt = 1 if self._numOffense == 1 \
...@@ -87,7 +92,7 @@ class Trainer(object): ...@@ -87,7 +92,7 @@ class Trainer(object):
else self._rng.randint(0, self._numDefense) else self._rng.randint(0, self._numDefense)
self._agentNumExt = self.convertToExtPlayer(self._agentTeam, self._agentNumExt = self.convertToExtPlayer(self._agentTeam,
self._agentNumInt) self._agentNumInt)
agentCmd = AGENT_CMD % (self._agentTeam, self._agentNumExt) agentCmd = 'start_agent.sh -t %s -u %i'%(self._agentTeam, self._agentNumExt)
agentCmd = agentCmd.split(' ') agentCmd = agentCmd.split(' ')
p = subprocess.Popen(agentCmd) p = subprocess.Popen(agentCmd)
p.wait() p.wait()
...@@ -137,12 +142,14 @@ class Trainer(object): ...@@ -137,12 +142,14 @@ class Trainer(object):
return self._teams.index(team_name) return self._teams.index(team_name)
def parseMsg(self, msg): def parseMsg(self, msg):
""" Parse a message """
assert(msg[0] == '(') assert(msg[0] == '(')
res,ind = self.__parseMsg(msg,1) res, ind = self.__parseMsg(msg,1)
assert(ind == len(msg)),msg assert(ind == len(msg)), msg
return res return res
def __parseMsg(self,msg,ind): def __parseMsg(self, msg, ind):
""" Recursively parse a message. """
res = [] res = []
while True: while True:
if msg[ind] == '"': if msg[ind] == '"':
...@@ -173,7 +180,8 @@ class Trainer(object): ...@@ -173,7 +180,8 @@ class Trainer(object):
# self.send('(eye on)') # self.send('(eye on)')
self.send('(ear on)') self.send('(ear on)')
def _hear(self,body): def _hear(self, body):
""" Handle a hear message. """
timestep,playerInfo,msg = body timestep,playerInfo,msg = body
if len(playerInfo) != 3: if len(playerInfo) != 3:
return return
...@@ -198,6 +206,7 @@ class Trainer(object): ...@@ -198,6 +206,7 @@ class Trainer(object):
print '[Trainer] Unhandled message from agent: %s' % msg print '[Trainer] Unhandled message from agent: %s' % msg
def initMsgHandlers(self): def initMsgHandlers(self):
""" Create handlers for different messages. """
self._msgHandlers = [] self._msgHandlers = []
self.ignoreMsg('player_param') self.ignoreMsg('player_param')
self.ignoreMsg('player_type') self.ignoreMsg('player_type')
...@@ -209,13 +218,16 @@ class Trainer(object): ...@@ -209,13 +218,16 @@ class Trainer(object):
self.registerMsgHandler(self._handleSP,'server_param') self.registerMsgHandler(self._handleSP,'server_param')
self.registerMsgHandler(self._hear,'hear') self.registerMsgHandler(self._hear,'hear')
def recv(self,retryCount=None): def recv(self, retryCount=None):
""" Receive a message. Retry a specified number of times. """
return self._comm.recvMsg(retryCount=retryCount).strip() return self._comm.recvMsg(retryCount=retryCount).strip()
def send(self,msg): def send(self, msg):
""" Send a message. """
self._comm.sendMsg(msg) self._comm.sendMsg(msg)
def checkMsg(self,expectedMsg,retryCount=None): def checkMsg(self, expectedMsg, retryCount=None):
""" Check that the next message is same as expected message. """
msg = self.recv(retryCount) msg = self.recv(retryCount)
if msg != expectedMsg: if msg != expectedMsg:
print >>sys.stderr,'[Trainer] Error with message' print >>sys.stderr,'[Trainer] Error with message'
...@@ -224,7 +236,8 @@ class Trainer(object): ...@@ -224,7 +236,8 @@ class Trainer(object):
print >>sys.stderr,len(expectedMsg),len(msg) print >>sys.stderr,len(expectedMsg),len(msg)
raise ValueError raise ValueError
def extractPoint(self,msg): def extractPoint(self, msg):
""" Extract a point from the provided message. """
return numpy.array(map(float,msg[:2])) return numpy.array(map(float,msg[:2]))
def convertToExtPlayer(self, team, num): def convertToExtPlayer(self, team, num):
...@@ -235,13 +248,18 @@ class Trainer(object): ...@@ -235,13 +248,18 @@ class Trainer(object):
else: else:
return self._defenseOrder[num] return self._defenseOrder[num]
def convertFromExtPlayer(self,team,num): def convertFromExtPlayer(self, team, num):
""" Maps external player number to internal player number. """
if team == self._offenseTeam: if team == self._offenseTeam:
return self._offenseOrder.index(num) return self._offenseOrder.index(num)
else: else:
return self._defenseOrder.index(num) return self._defenseOrder.index(num)
def seeGlobal(self,body): def seeGlobal(self, body):
"""Send a look message to extract global information on ball and
player positions.
"""
self.send('(look)') self.send('(look)')
self._frame = int(body[0]) self._frame = int(body[0])
for obj in body[1:]: for obj in body[1:]:
...@@ -272,12 +290,14 @@ class Trainer(object): ...@@ -272,12 +290,14 @@ class Trainer(object):
print '[Trainer] Updating handler for %s' % (' '.join(args)) print '[Trainer] Updating handler for %s' % (' '.join(args))
self._msgHandlers[i] = [args,handler] self._msgHandlers[i] = [args,handler]
def unregisterMsgHandler(self,*args): def unregisterMsgHandler(self, *args):
""" Delete a message handler. """
i,_,_ = self._findHandlerInd(args) i,_,_ = self._findHandlerInd(args)
assert(i >= 0) assert(i >= 0)
del self._msgHandlers[i] del self._msgHandlers[i]
def _findHandlerInd(self,msg): def _findHandlerInd(self, msg):
""" Find the handler for a particular message. """
msg = list(msg) msg = list(msg)
for i,(partial,handler) in enumerate(self._msgHandlers): for i,(partial,handler) in enumerate(self._msgHandlers):
recPartial = msg[:len(partial)] recPartial = msg[:len(partial)]
...@@ -285,7 +305,8 @@ class Trainer(object): ...@@ -285,7 +305,8 @@ class Trainer(object):
return i,len(partial),handler return i,len(partial),handler
return -1,None,None return -1,None,None
def handleMsg(self,msg): def handleMsg(self, msg):
""" Handle a message using the registered handlers. """
i,prefixLength,handler = self._findHandlerInd(msg) i,prefixLength,handler = self._findHandlerInd(msg)
if i < 0: if i < 0:
print '[Trainer] Unhandled message:',msg[0:2] print '[Trainer] Unhandled message:',msg[0:2]
...@@ -293,9 +314,11 @@ class Trainer(object): ...@@ -293,9 +314,11 @@ class Trainer(object):
handler(msg[prefixLength:]) handler(msg[prefixLength:])
def ignoreMsg(self,*args,**kwargs): def ignoreMsg(self,*args,**kwargs):
""" Ignore a certain type of message. """
self.registerMsgHandler(lambda x: None,*args,**kwargs) self.registerMsgHandler(lambda x: None,*args,**kwargs)
def _handleSP(self,body): def _handleSP(self, body):
""" Handler for the server params message. """
for param in body: for param in body:
try: try:
val = int(param[1]) val = int(param[1])
...@@ -307,19 +330,25 @@ class Trainer(object): ...@@ -307,19 +330,25 @@ class Trainer(object):
self._SP[param[0]] = val self._SP[param[0]] = val
def listenAndProcess(self): def listenAndProcess(self):
""" Gather messages and process them. """
msg = self.recv() msg = self.recv()
assert((msg[0] == '(') and (msg[-1] == ')')),'|%s|' % msg assert((msg[0] == '(') and (msg[-1] == ')')),'|%s|' % msg
msg = self.parseMsg(msg) msg = self.parseMsg(msg)
self.handleMsg(msg) self.handleMsg(msg)
def _readTeamNames(self,body): def _readTeamNames(self,body):
""" Read the names of each of the teams. """
self._teams = [] self._teams = []
for _,_,team in body: for _,_,team in body:
self._teams.append(team) self._teams.append(team)
time.sleep(0.1) time.sleep(0.1)
self.send('(team_names)') self.send('(team_names)')
def waitOnTeam(self,first): def waitOnTeam(self, first):
"""Wait on a given team. First indicates if this is the first team
connected or the second.
"""
self.send('(team_names)') self.send('(team_names)')
partial = ['ok','team_names'] partial = ['ok','team_names']
self.registerMsgHandler(self._readTeamNames,*partial,quiet=True) self.registerMsgHandler(self._readTeamNames,*partial,quiet=True)
...@@ -329,6 +358,7 @@ class Trainer(object): ...@@ -329,6 +358,7 @@ class Trainer(object):
self.ignoreMsg(*partial,quiet=True) self.ignoreMsg(*partial,quiet=True)
def checkIfAllPlayersConnected(self): def checkIfAllPlayersConnected(self):
""" Returns true if all players are connected. """
self.send('(look)') self.send('(look)')
partial = ['ok','look'] partial = ['ok','look']
self._numPlayers = 0 self._numPlayers = 0
...@@ -363,6 +393,7 @@ class Trainer(object): ...@@ -363,6 +393,7 @@ class Trainer(object):
totalHeld += 1 totalHeld += 1
else: else:
self._ballHeld[i,self.teamToInd(team)] = 0 self._ballHeld[i,self.teamToInd(team)] = 0
# If multiple players are close to the ball, no-one is holding
if totalHeld > 1: if totalHeld > 1:
self._ballHeld[:,:] = 0 self._ballHeld[:,:] = 0
inds = numpy.transpose((self._ballHeld >= self.NUM_FRAMES_TO_HOLD).nonzero()) inds = numpy.transpose((self._ballHeld >= self.NUM_FRAMES_TO_HOLD).nonzero())
...@@ -524,6 +555,7 @@ class Trainer(object): ...@@ -524,6 +555,7 @@ class Trainer(object):
result = 'Defense Captured' result = 'Defense Captured'
elif self._frame - self._lastFrameBallTouched > self.UNTOUCHED_LENGTH: elif self._frame - self._lastFrameBallTouched > self.UNTOUCHED_LENGTH:
self._lastFrameBallTouched = self._frame self._lastFrameBallTouched = self._frame
self._numOutOfTime += 1
result = 'Ball untouched for too long' result = 'Ball untouched for too long'
else: else:
print '[Trainer] Error: Unable to detect reason for End of Trial!' print '[Trainer] Error: Unable to detect reason for End of Trial!'
...@@ -558,6 +590,7 @@ class Trainer(object): ...@@ -558,6 +590,7 @@ class Trainer(object):
print 'Goals : %i' % self._numGoals print 'Goals : %i' % self._numGoals
print 'Defense Captured : %i' % self._numBallsCaptured print 'Defense Captured : %i' % self._numBallsCaptured
print 'Balls Out of Bounds: %i' % self._numBallsOOB print 'Balls Out of Bounds: %i' % self._numBallsOOB
print 'Out of Time : %i' % self._numOutOfTime
def checkLive(self, necProcesses): def checkLive(self, necProcesses):
"""Returns true if each of the necessary processes is still alive and """Returns true if each of the necessary processes is still alive and
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment