Commit f8ec967c authored by Matthew Hausknecht's avatar Matthew Hausknecht

removed most of trainer code.

parent 23696f8a
......@@ -62,7 +62,7 @@ class Communicator(object):
raise TimeoutError
else:
retryCount -= 1
print '[Trainer] waiting for message'
print '[Trainer] waiting for message, retry =', retryCount
time.sleep(0.3)
#raise ValueError('Error while receiving message')
(msg,sep,rest) = msg.partition('\0')
......
......@@ -49,11 +49,12 @@ def main(args, team1='left', team2='right', rng=numpy.random.RandomState()):
'server::game_log_dir=%s server::text_log_dir=%s '\
'server::synch_mode=%i server::hfo=1 ' \
'server::fullstate_l=%i server::fullstate_r=%i ' \
'server::coach_w_referee=on server::record_messages=on' \
'server::coach_w_referee=1' \
%(server_port, coach_port, olcoach_port,
args.logging, args.logging, args.logging,
args.logDir, args.logDir, args.logDir,
args.sync, args.fullstate, args.fullstate)
# server::record_messages=on -- useful for debug
try:
# Launch the Server
server = launch(serverCommand + serverOptions, name='server')
......@@ -67,8 +68,7 @@ def main(args, team1='left', team2='right', rng=numpy.random.RandomState()):
launch(monitorCommand + monitorOptions, name='monitor')
# Launch the Trainer
from Trainer import Trainer
trainer = Trainer(args=args, rng=rng, server_port=server_port,
coach_port=coach_port)
trainer = Trainer(args=args, server_port=server_port, coach_port=coach_port)
trainer.initComm()
# Add Team1
trainer.addTeam(team1)
......
......@@ -14,33 +14,18 @@ class DoneError(Exception):
class Trainer(object):
""" Trainer is responsible for setting up the players and game.
"""
def __init__(self, args, rng=numpy.random.RandomState(), server_port=6001,
coach_port=6002):
self._rng = rng # The Random Number Generator
def __init__(self, args, server_port=6001, coach_port=6002):
self._serverPort = server_port # The port the server is listening on
self._coachPort = coach_port # The coach port to talk with the server
self._logDir = args.logDir # Directory to store logs
self._record = args.record # Record states + actions
self._numOffense = args.offenseAgents + args.offenseNPCs # Number offensive players
self._numDefense = args.defenseAgents + args.defenseNPCs # Number defensive players
self._maxTrials = args.numTrials # Maximum number of trials to play
self._maxFrames = args.numFrames # Maximum number of frames to play
self._maxFramesPerTrial = args.maxFramesPerTrial
# =============== FIELD DIMENSIONS =============== #
self.NUM_FRAMES_TO_HOLD = 2 # Hold ball this many frames to capture
self.HOLD_FACTOR = 1.5 # Gain to calculate ball control
self.PITCH_WIDTH = 68.0 # Width of the field
self.PITCH_LENGTH = 105.0 # Length of field in long-direction
self.UNTOUCHED_LENGTH = 100 # Trial will end if ball untouched for this long
# allowedBallX, allowedBallY defines the usable area of the playfield
self._allowedBallX = numpy.array([-0.1, 0.5 * self.PITCH_LENGTH])
self._allowedBallY = numpy.array([-0.5 * self.PITCH_WIDTH, 0.5 * self.PITCH_WIDTH])
# =============== COUNTERS =============== #
self._numFrames = 0 # Number of frames seen in HFO trials
self._numGoalFrames = 0 # Number of frames in goal-scoring HFO trials
self._frame = 0 # Current frame id
self._lastTrialStart = -1 # Frame Id in which the last trial started
self._lastFrameBallTouched = -1 # Frame Id in which ball was last touched
# =============== TRIAL RESULTS =============== #
self._numTrials = 0 # Total number of HFO trials
self._numGoals = 0 # Trials in which the offense scored a goal
......@@ -59,20 +44,14 @@ class Trainer(object):
# =============== MISC =============== #
self._offenseTeamName = '' # Name of the offensive team
self._defenseTeamName = '' # Name of the defensive team
self._playerPositions = numpy.zeros((11,2,2)) # Positions of the players
self._ballPosition = numpy.zeros(2) # Position of the ball
self._ballHeld = numpy.zeros((11,2)) # Track player holding the ball
self._teams = [] # Team indexes for offensive and defensive teams
self._SP = {} # Sever Parameters. Recieved when connecting to the server.
self._isPlaying = False # Is a game being played?
self._teamHoldingBall = None # Team currently in control of the ball
self._playerHoldingBall = None # Player current in control of ball
self._agentPopen = [] # Agent's processes
self._npcPopen = [] # NPC's processes
self._connectedPlayers = []
self.initMsgHandlers()
def launch_player(self, player_num, play_offense):
def launch_npc(self, player_num, play_offense, wait_until_join=True):
"""Launches a player using sample_player binary
Returns a Popen process object
......@@ -96,9 +75,11 @@ class Trainer(object):
kwargs = {'stdout':open('/dev/null', 'w'),
'stderr':open('/dev/null', 'w')}
p = subprocess.Popen(player_cmd.split(' '), shell = False, **kwargs)
if wait_until_join:
self.waitOnPlayer(player_num, play_offense)
return p
def launch_agent(self, agent_num, play_offense, port):
def launch_agent(self, agent_num, agent_ext_num, play_offense, port, wait_until_join=True):
"""Launches a learning agent using the agent binary
Returns a Popen process object
......@@ -122,7 +103,6 @@ class Trainer(object):
self._agentNumInt.append(internal_player_num)
numTeammates = self._numDefense - 1
numOpponents = self._numOffense
ext_num = self.convertToExtPlayer(team_name, internal_player_num)
binary_dir = os.path.dirname(os.path.realpath(__file__))
config_dir = os.path.join(binary_dir, '../config/formations-dt')
player_conf = os.path.join(binary_dir, '../config/player.conf')
......@@ -133,13 +113,15 @@ class Trainer(object):
%(team_name, self._serverPort, numTeammates,
numOpponents, play_offense, port, self._logDir,
player_conf, config_dir)
if ext_num == 1:
if agent_ext_num == 1:
agent_cmd += ' -g'
if self._record:
agent_cmd += ' --record'
kwargs = {'stdout':open('/dev/null', 'w'),
'stderr':open('/dev/null', 'w')}
p = subprocess.Popen(agent_cmd.split(' '), shell = False, **kwargs)
if wait_until_join:
self.waitOnPlayer(agent_ext_num, play_offense)
return p
def getDefensiveRoster(self, team_name):
......@@ -225,8 +207,32 @@ class Trainer(object):
# self.send('(eye on)')
self.send('(ear on)')
def _hearRef(self, body):
""" Handles hear messages from referee. """
assert body[0] == 'referee', 'Expected referee message.'
_,ts,event = body
self._frame = int(ts)
if event == 'GOAL':
self._numGoals += 1
self._numGoalFrames += self._frame - self._lastTrialStart
elif event == 'OUT_OF_BOUNDS':
self._numBallsOOB += 1
elif event == 'CAPTURED_BY_DEFENSE':
self._numBallsCaptured += 1
elif event == 'OUT_OF_TIME':
self._numOutOfTime += 1
if event in {'GOAL','OUT_OF_BOUNDS','CAPTURED_BY_DEFENSE','OUT_OF_TIME'}:
self._numTrials += 1
print '[Trainer] EndOfTrial: %d / %d %d %s'%\
(self._numGoals, self._numTrials, self._frame, event)
self._numFrames += self._frame - self._lastTrialStart
self._lastTrialStart = self._frame
def _hear(self, body):
""" Handle a hear message. """
if body[0] == 'referee':
self._hearRef(body)
return
timestep,playerInfo,msg = body
try:
_,team,player = playerInfo[:3]
......@@ -242,7 +248,7 @@ class Trainer(object):
elif msg == 'DONE':
raise DoneError
elif msg == 'ready':
print '[Trainer] Agent Ready:', team, player
print '[Trainer] Agent Connected:', team, player
self._agentReady.add((team, player))
else:
print '[Trainer] Unhandled message from agent: %s' % msg
......@@ -257,7 +263,7 @@ class Trainer(object):
self.ignoreMsg('ok','move')
self.ignoreMsg('ok','recover')
self.ignoreMsg('ok','say')
self.registerMsgHandler(self._handleSP,'server_param')
self.ignoreMsg('server_param')
self.registerMsgHandler(self._hear,'hear')
def recv(self, retryCount=None):
......@@ -278,10 +284,6 @@ class Trainer(object):
# print >>sys.stderr,len(expectedMsg),len(msg)
raise ValueError
def extractPoint(self, msg):
""" Extract a point from the provided message. """
return numpy.array(map(float,msg[:2]))
def convertToExtPlayer(self, team, num):
""" Returns the external player number for a given player. """
assert team == self._offenseTeamName or team == self._defenseTeamName,\
......@@ -292,33 +294,6 @@ class Trainer(object):
else:
return self._defenseOrder[num]
def convertFromExtPlayer(self, team, num):
""" Maps external player number to internal player number. """
if team == self._offenseTeamName:
return self._offenseOrder.index(num)
else:
return self._defenseOrder.index(num)
def seeGlobal(self, body):
"""Send a look message to extract global information on ball and
player positions.
"""
self.send('(look)')
self._frame = int(body[0])
for obj in body[1:]:
objType = obj[0]
objData = obj[1:]
if objType[0] == 'g':
continue
elif objType[0] == 'b':
self._ballPosition = self.extractPoint(objData)
elif objType[0] == 'p':
teamName = objType[1]
team = self.teamToInd(teamName)
playerNum = self.convertFromExtPlayer(teamName,int(objType[2]))
self._playerPositions[playerNum,:,team] = self.extractPoint(objData)
def registerMsgHandler(self,handler,*args,**kwargs):
'''Register a message handler.
......@@ -361,18 +336,6 @@ class Trainer(object):
""" Ignore a certain type of message. """
self.registerMsgHandler(lambda x: None,*args,**kwargs)
def _handleSP(self, body):
""" Handler for the sever params message. """
for param in body:
try:
val = int(param[1])
except:
try:
val = float(param[1])
except:
val = param[1]
self._SP[param[0]] = val
def listenAndProcess(self, retry_count=None):
""" Gather messages and process them. """
msg = self.recv(retry_count)
......@@ -380,25 +343,21 @@ class Trainer(object):
msg = self.parseMsg(msg)
self.handleMsg(msg)
def _readTeamNames(self,body):
""" Read the names of each of the teams. """
self._teams = []
for _,_,team in body:
self._teams.append(team)
time.sleep(0.1)
self.send('(team_names)')
def waitOnTeam(self, first):
"""Wait on a given team. First indicates if this is the first team
connected or the second.
"""
self.send('(team_names)')
partial = ['ok','team_names']
self.registerMsgHandler(self._readTeamNames,*partial,quiet=True)
while len(self._teams) < (1 if first else 2):
def waitForDisconnect(self, player_num, on_offense):
"""Wait on a launched player to disconnect from the server. """
self.send('(look)')
partial = ['ok','look']
self._numPlayers = 0
def f(body):
for i in xrange(4, len(body)):
_,team,num = body[i][0][:3]
if (team, num) in self._connectedPlayers:
self._connectedPlayers.remove((team,num))
self.registerMsgHandler(f,*partial,quiet=True)
team_name = self._offenseTeamName if on_offense else self._defenseTeamName
while (team_name, str(player_num)) in self._connectedPlayers:
self.listenAndProcess()
#self.unregisterMsgHandler(*partial)
self.send('(look)')
self.ignoreMsg(*partial,quiet=True)
def waitOnPlayer(self, player_num, on_offense):
......@@ -406,6 +365,7 @@ class Trainer(object):
server.
"""
print 'Wait on player', player_num, on_offense
self.send('(look)')
partial = ['ok','look']
self._numPlayers = 0
......@@ -436,210 +396,8 @@ class Trainer(object):
def startGame(self):
""" Starts a game of HFO. """
self.reset()
self.registerMsgHandler(self.seeGlobal, 'see_global')
self.registerMsgHandler(self.seeGlobal, 'ok', 'look', quiet=True)
#self.registerMsgHandler(self.checkBall,'ok','check_ball')
self.send('(look)')
self._isPlaying = True
def calcBallHolder(self):
'''Calculates the ball holder, returns results in teamInd, playerInd. '''
totalHeld = 0
for team in self._teams:
for i in range(11):
pos = self._playerPositions[i,:,self.teamToInd(team)]
distBound = self._SP['kickable_margin'] + self._SP['player_size'] \
+ self._SP['ball_size']
distBound *= self.HOLD_FACTOR
if numpy.linalg.norm(self._ballPosition - pos) < distBound:
self._ballHeld[i,self.teamToInd(team)] += 1
totalHeld += 1
else:
self._ballHeld[i,self.teamToInd(team)] = 0
# If multiple players are close to the ball, no-one is holding
if totalHeld > 1:
self._ballHeld[:,:] = 0
inds = numpy.transpose((self._ballHeld >= self.NUM_FRAMES_TO_HOLD).nonzero())
assert(len(inds) <= 1)
if len(inds) == 1:
return inds[0,1],inds[0,0]
else:
return None,None
def isGoal(self):
""" Returns true if a goal has been scored. """
return (self._ballPosition[0] > self._allowedBallX[1]) \
and (numpy.abs(self._ballPosition[1]) <= 0.5 * self._SP['goal_width'])
def isOOB(self):
""" Returns true if the ball is out of bounds. """
return self._ballPosition[0] < self._allowedBallX[0] \
or self._ballPosition[0] > self._allowedBallX[1] \
or self._ballPosition[1] < self._allowedBallY[0] \
or self._ballPosition[1] > self._allowedBallY[1]
def isCaptured(self):
""" Returns true if the ball is captured by defense. """
return self._teamHoldingBall not in [None,self._offenseTeamInd]
def isOOT(self):
""" Returns true if the trial has run out of time. """
return self._frame - self._lastFrameBallTouched > self.UNTOUCHED_LENGTH \
or (self._maxFramesPerTrial > 0 and self._frame -
self._lastTrialStart > self._maxFramesPerTrial)
def movePlayer(self, team, internal_num, pos, convertToExt=True):
""" Move a player to a specified position.
Args:
team: the team name of the player
interal_num: the player's internal number
pos: position to move player to
convertToExt: convert interal player num to external
"""
num = self.convertToExtPlayer(team, internal_num) if convertToExt \
else internal_num
self.send('(move (player %s %i) %f %f)' % (team, num, pos[0], pos[1]))
def moveBall(self, pos):
""" Moves the ball to a specified x,y position. """
self.send('(move (ball) %f %f 0.0 0.0 0.0)' % tuple(pos))
def randomPointInBounds(self, xBounds=None, yBounds=None):
"""Returns a random point inside of the box defined by xBounds,
yBounds. Where xBounds=[x_min, x_max] and yBounds=[y_min,
y_max]. Defaults to the xy-bounds of the playable HFO area.
"""
if xBounds is None:
xBounds = self.allowedBallX
if yBounds is None:
yBounds = self.allowedBallY
pos = numpy.zeros(2)
bounds = [xBounds, yBounds]
for i in range(2):
pos[i] = self._rng.rand() * (bounds[i][1] - bounds[i][0]) + bounds[i][0]
return pos
def boundPoint(self, pos):
"""Ensures a point is within the minimum and maximum bounds of the
HFO playing area.
"""
pos[0] = min(max(pos[0], self._allowedBallX[0]), self._allowedBallX[1])
pos[1] = min(max(pos[1], self._allowedBallY[0]), self._allowedBallY[1])
return pos
def reset(self):
""" Resets the HFO domain by moving the ball and players. """
self.resetBallPosition()
self.resetPlayerPositions()
self.send('(recover)')
self.send('(change_mode play_on)')
# self.send('(say RESET)')
def resetBallPosition(self):
"""Reset the position of the ball for a new HFO trial. """
self._ballPosition = self.boundPoint(self.randomPointInBounds(
.2*self._allowedBallX+.05*self.PITCH_LENGTH, .8*self._allowedBallY))
self.moveBall(self._ballPosition)
def getOffensiveResetPosition(self):
""" Returns a random position for an offensive player. """
offsets = [
[-1,-1],
[-1,1],
[1,1],
[1,-1],
[0,2],
[0,-2],
[-2,-2],
[-2,2],
[2,2],
[2,-2],
]
offset = offsets[self._rng.randint(len(offsets))]
offset_from_ball = 0.1 * self.PITCH_LENGTH * self._rng.rand(2) + \
0.1 * self.PITCH_LENGTH * numpy.array(offset)
return self.boundPoint(self._ballPosition + offset_from_ball)
# return self._ballPosition
def getDefensiveResetPosition(self):
""" Returns a random position for a defensive player. """
return self.boundPoint(self.randomPointInBounds(
[0.5 * 0.5 * self.PITCH_LENGTH, 0.75 * 0.5 * self.PITCH_LENGTH],
0.8 * self._allowedBallY))
def resetPlayerPositions(self):
"""Reset the positions of the players. This is called after a trial
ends to setup for the next trial.
"""
# Move the offense
for i in xrange(1, self._numOffense + 1):
self.movePlayer(self._offenseTeamName, i, self.getOffensiveResetPosition())
# Move the agent to the ball
if self._agentOnBall and self._offenseAgents > 0:
self.movePlayer(self._offenseTeamName, self._agentNumInt[0], self._ballPosition)
# Move the defensive goalie
if self._numDefense > 0:
self.movePlayer(self._defenseTeamName, 0, [0.5 * self.PITCH_LENGTH,0])
# Move the rest of the defense
for i in xrange(1, self._numDefense):
self.movePlayer(self._defenseTeamName, i, self.getDefensiveResetPosition())
def step(self):
""" Takes a simulated step. """
self._teamHoldingBall, self._playerHoldingBall = self.calcBallHolder()
if self._teamHoldingBall is not None:
self._lastFrameBallTouched = self._frame
if self.trialOver():
self.updateResults()
self._lastFrameBallTouched = self._frame
self.reset()
def updateResults(self):
""" Updates the various members after a trial has ended. """
if self.isGoal():
self._numGoals += 1
self._numGoalFrames += self._frame - self._lastTrialStart
result = 'Goal'
self.send('(say GOAL)')
elif self.isOOB():
self._numBallsOOB += 1
result = 'Out of Bounds'
self.send('(say OUT_OF_BOUNDS)')
elif self.isCaptured():
self._numBallsCaptured += 1
result = 'Defense Captured'
self.send('(say CAPTURED_BY_DEFENSE)')
elif self.isOOT():
self._numOutOfTime += 1
result = 'Ball untouched for too long'
self.send('(say OUT_OF_TIME)')
else:
print '[Trainer] Error: Unable to detect reason for End of Trial!'
sys.exit(1)
self._numTrials += 1
print '[Trainer] EndOfTrial: %d / %d %d %s'%\
(self._numGoals, self._numTrials, self._frame, result)
self._numFrames += self._frame - self._lastTrialStart
self._lastTrialStart = self._frame
if (self._maxTrials > 0) and (self._numTrials >= self._maxTrials):
raise DoneError
if (self._maxFrames > 0) and (self._numFrames >= self._maxFrames):
raise DoneError
def trialOver(self):
"""Returns true if the trial has ended for one of the following
reasons: Goal scored, out of bounds, captured by defense, or
untouched for too long.
"""
# The trial is still being setup, it cannot be over.
if self._frame - self._lastTrialStart < 5:
return False
return self.isGoal() or self.isOOB() or self.isCaptured() or self.isOOT()
self._isPlaying = True
def printStats(self):
print '[Trainer] TotalFrames = %i, AvgFramesPerTrial = %.1f, AvgFramesPerGoal = %.1f'\
......@@ -675,56 +433,54 @@ class Trainer(object):
# Launch offense
agent_num = 0
for player_num in xrange(1, 12):
if agent_num < self._offenseAgents and \
sorted_offense_agent_unums[agent_num] == player_num:
port = self._agentServerPort + agent_num
agent = self.launch_agent(agent_num, play_offense=True, port=port)
self._agentPopen.append(agent)
necProcesses.append([agent, 'offense_agent_' + str(agent_num)])
agent_num += 1
if agent_num < self._offenseAgents:
agent_ext_num = sorted_offense_agent_unums[agent_num]
if agent_ext_num == player_num:
port = self._agentServerPort + agent_num
agent = self.launch_agent(agent_num, agent_ext_num,
play_offense=True, port=port)
self._agentPopen.append(agent)
necProcesses.append([agent, 'offense_agent_' + str(agent_num)])
agent_num += 1
else:
player = self.launch_player(player_num, play_offense = True)
time.sleep(0.15)
player = self.launch_npc(player_num, play_offense=True)
if player_num in offense_unums:
self._npcPopen.append(player)
necProcesses.append([player, 'offense_npc_' + str(player_num)])
else:
player.terminate()
time.sleep(0.1)
continue
self.waitOnPlayer(player_num, on_offense=True)
self.waitOnTeam(first = False)
player.kill()
self.waitForDisconnect(player_num, on_offense=True)
# Launch defense
agent_num = 0
for player_num in xrange(1, 12):
if agent_num < self._defenseAgents and \
sorted_defense_agent_unums[agent_num] == player_num:
port = self._agentServerPort + agent_num + self._offenseAgents
agent = self.launch_agent(agent_num, play_offense=False, port=port)
self._agentPopen.append(agent)
necProcesses.append([agent, 'defense_agent_' + str(agent_num)])
agent_num += 1
if agent_num < self._defenseAgents:
agent_ext_num = sorted_offense_agent_unums[agent_num]
if agent_ext_num == player_num:
port = self._agentServerPort + agent_num + self._offenseAgents
agent = self.launch_agent(agent_num, agent_ext_num,
play_offense=False, port=port)
self._agentPopen.append(agent)
necProcesses.append([agent, 'defense_agent_' + str(agent_num)])
agent_num += 1
else:
player = self.launch_player(player_num, play_offense = False)
time.sleep(0.15)
player = self.launch_npc(player_num, play_offense=False)
if player_num in defense_unums:
self._npcPopen.append(player)
necProcesses.append([player, 'defense_npc_' + str(player_num)])
else:
player.terminate()
time.sleep(0.1)
continue
self.waitOnPlayer(player_num, on_offense=False)
self.waitOnTeam(first = False)
player.kill()
self.waitForDisconnect(player_num, on_offense=False)
self.checkIfAllPlayersConnected()
print '[Trainer] Agents awaiting your connections'
necOff = set([(self._offenseTeamName,str(x)) for x in sorted_offense_agent_unums])
necDef = set([(self._defenseTeamName,str(x)) for x in sorted_defense_agent_unums])
necAgents = necOff.union(necDef)
while self.checkLive(necProcesses) and self._agentReady != necAgents:
self.listenAndProcess(1000)
if self._numAgents > 0:
print '[Trainer] Agents awaiting your connections'
necOff = set([(self._offenseTeamName,str(x)) for x in sorted_offense_agent_unums])
necDef = set([(self._defenseTeamName,str(x)) for x in sorted_defense_agent_unums])
necAgents = necOff.union(necDef)
while self.checkLive(necProcesses) and self._agentReady != necAgents:
self.listenAndProcess(10)
# Broadcast the HFO configuration
offense_nums = ' '.join([str(self.convertToExtPlayer(self._offenseTeamName, i))
......@@ -736,13 +492,10 @@ class Trainer(object):
%(self._offenseTeamName, self._defenseTeamName,
self._numOffense, self._numDefense,
offense_nums, defense_nums))
print '[Trainer] Starting game'
self.startGame()
while self.checkLive(necProcesses):
prevFrame = self._frame
self.listenAndProcess()
# if self._frame != prevFrame:
# self.step()
except TimeoutError:
print '[Trainer] Haven\'t heard from the server for too long, Exiting'
except (KeyboardInterrupt, DoneError):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment