Commit 5394bdce authored by Matthew Hausknecht, committed by GitHub

Merge pull request #57 from DurgeshSamant/master

HLFS + Sarsa Agent
parents fe899441 bca264cc
...@@ -8,10 +8,6 @@ import sys, os ...@@ -8,10 +8,6 @@ import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'sarsa_libraries','python_wrapper')) sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'sarsa_libraries','python_wrapper'))
from py_wrapper import * from py_wrapper import *
NA=0 #Number of actions
NOT=0 #Number of teammates
NF=0 #Number of features
def getReward(s): def getReward(s):
reward=0 reward=0
#--------------------------- #---------------------------
...@@ -44,7 +40,7 @@ def purge_features(state): ...@@ -44,7 +40,7 @@ def purge_features(state):
tmpIndex= 9 + 3*NOT tmpIndex= 9 + 3*NOT
for i in range(len(state)): for i in range(len(state)):
# Ignore first six features and teammate proximity to opponent(when opponent is absent)and opponent features # Ignore first six features and teammate proximity to opponent(when opponent is absent)and opponent features
if(i < 6 or i>9+6*NOT or (args.numOpponents==0 and ((i>9+numTMates and i<=9+2*numTMates) or i==9)) ): if(i < 6 or i>9+6*NOT or (NOO==0 and ((i>9+NOT and i<=9+2*NOT) or i==9)) ):
continue; continue;
#Ignore Angle and Uniform Number of Teammates #Ignore Angle and Uniform Number of Teammates
temp = i-tmpIndex; temp = i-tmpIndex;
...@@ -71,7 +67,8 @@ if __name__ == '__main__': ...@@ -71,7 +67,8 @@ if __name__ == '__main__':
hfo = HFOEnvironment() hfo = HFOEnvironment()
#now connect to the server #now connect to the server
hfo.connectToServer(HIGH_LEVEL_FEATURE_SET,'bin/teams/base/config/formations-dt',args.port,'localhost','base_left',False) hfo.connectToServer(HIGH_LEVEL_FEATURE_SET,'bin/teams/base/config/formations-dt',args.port,'localhost','base_left',False)
global NF,NA, NOT global NF,NA,NOT,NOO
NOO=args.numOpponents
if args.numOpponents >0: if args.numOpponents >0:
NF=4+4*args.numTeammates NF=4+4*args.numTeammates
else: else:
...@@ -85,7 +82,7 @@ if __name__ == '__main__': ...@@ -85,7 +82,7 @@ if __name__ == '__main__':
Min=[-1]*NF Min=[-1]*NF
Res=[resolution]*NF Res=[resolution]*NF
#Sarsa Agent Parameters #Sarsa Agent Parameters
wt_filename="weights_"+str(NOT+1)+"v"+str(args.numOpponents)+'_'+str(args.suffix) wt_filename="weights_"+str(NOT+1)+"v"+str(NOO)+'_'+str(args.suffix)
discFac=1 discFac=1
Lambda=0 Lambda=0
eps=0.01 eps=0.01
......
...@@ -37,7 +37,7 @@ output_path=$agent_path ...@@ -37,7 +37,7 @@ output_path=$agent_path
agent_filename="high_level_sarsa_agent.py" agent_filename="high_level_sarsa_agent.py"
#start the server #start the server
stdbuf -oL ./bin/HFO --port=$port --no-logging --offense-agents=$oa --defense-npcs=$da --trials=$trials --defense-team=base --headless --fullstate > $log_dir/"$oa"v"$da""_sarsa_py_agents.log" & stdbuf -oL ./bin/HFO --port=$port --no-logging --offense-agents=$oa --defense-npcs=$da --trials=$trials --defense-team=base --fullstate --headless > $log_dir/"$oa"v"$da""_sarsa_py_agents.log" &
#each agent is a separate process #each agent is a separate process
for n in $(seq 1 $oa) for n in $(seq 1 $oa)
...@@ -48,7 +48,7 @@ do ...@@ -48,7 +48,7 @@ do
fname+=".txt" fname+=".txt"
logfile=$log_dir/$fname logfile=$log_dir/$fname
rm $logfile rm $logfile
$python $agent_path/$agent_filename --port=$port --numTeammates=`expr $oa - 1` --numOpponents=$da --numEpisodes=$trials &> $log_dir/$fname & $python $agent_path/$agent_filename --port=$port --numTeammates=`expr $oa - 1` --numOpponents=$da --numEpisodes=$trials --suffix=$n &> $log_dir/$fname &
done done
# The magic line # The magic line
......
...@@ -30,8 +30,8 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -30,8 +30,8 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
const SelfObject& self = wm.self(); const SelfObject& self = wm.self();
const Vector2D& self_pos = self.pos(); const Vector2D& self_pos = self.pos();
const float self_ang = self.body().radian(); const float self_ang = self.body().radian();
const PlayerCont& teammates = wm.teammates(); const PlayerPtrCont& teammates = wm.teammatesFromSelf();
const PlayerCont& opponents = wm.opponents(); const PlayerPtrCont& opponents = wm.opponentsFromSelf();
float maxR = sqrtf(SP.pitchHalfLength() * SP.pitchHalfLength() float maxR = sqrtf(SP.pitchHalfLength() * SP.pitchHalfLength()
+ SP.pitchHalfWidth() * SP.pitchHalfWidth()); + SP.pitchHalfWidth() * SP.pitchHalfWidth());
// features about self pos // features about self pos
...@@ -89,10 +89,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -89,10 +89,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// Features[9 - 9+T]: teammate's open angle to goal // Features[9 - 9+T]: teammate's open angle to goal
int detected_teammates = 0; int detected_teammates = 0;
for (PlayerCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) { for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
const PlayerObject& teammate = *it; const PlayerObject* teammate = *it;
if (valid(teammate) && detected_teammates < numTeammates) { if (valid(teammate) && detected_teammates < numTeammates) {
addNormFeature(calcLargestGoalAngle(wm, teammate.pos()), 0, M_PI); addNormFeature(calcLargestGoalAngle(wm, teammate->pos()), 0, M_PI);
detected_teammates++; detected_teammates++;
} }
} }
...@@ -104,10 +104,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -104,10 +104,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// Features[9+T - 9+2T]: teammates' dists to closest opps // Features[9+T - 9+2T]: teammates' dists to closest opps
if (numOpponents > 0) { if (numOpponents > 0) {
detected_teammates = 0; detected_teammates = 0;
for (PlayerCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) { for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
const PlayerObject& teammate = *it; const PlayerObject* teammate = *it;
if (valid(teammate) && detected_teammates < numTeammates) { if (valid(teammate) && detected_teammates < numTeammates) {
calcClosestOpp(wm, teammate.pos(), th, r); calcClosestOpp(wm, teammate->pos(), th, r);
addNormFeature(r, 0, maxR); addNormFeature(r, 0, maxR);
detected_teammates++; detected_teammates++;
} }
...@@ -124,10 +124,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -124,10 +124,10 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// Features [9+2T - 9+3T]: open angle to teammates // Features [9+2T - 9+3T]: open angle to teammates
detected_teammates = 0; detected_teammates = 0;
for (PlayerCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) { for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
const PlayerObject& teammate = *it; const PlayerObject* teammate = *it;
if (valid(teammate) && detected_teammates < numTeammates) { if (valid(teammate) && detected_teammates < numTeammates) {
addNormFeature(calcLargestTeammateAngle(wm, self_pos, teammate.pos()),0,M_PI); addNormFeature(calcLargestTeammateAngle(wm, self_pos, teammate->pos()),0,M_PI);
detected_teammates++; detected_teammates++;
} }
} }
...@@ -138,16 +138,16 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -138,16 +138,16 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// Features [9+3T - 9+6T]: x, y, unum of teammates // Features [9+3T - 9+6T]: x, y, unum of teammates
detected_teammates = 0; detected_teammates = 0;
for (PlayerCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) { for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
const PlayerObject& teammate = *it; const PlayerObject* teammate = *it;
if (valid(teammate) && detected_teammates < numTeammates) { if (valid(teammate) && detected_teammates < numTeammates) {
if (playingOffense) { if (playingOffense) {
addNormFeature(teammate.pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x); addNormFeature(teammate->pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
} else { } else {
addNormFeature(teammate.pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x); addNormFeature(teammate->pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
} }
addNormFeature(teammate.pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y); addNormFeature(teammate->pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y);
addFeature(teammate.unum()); addFeature(teammate->unum());
detected_teammates++; detected_teammates++;
} }
} }
...@@ -160,16 +160,16 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -160,16 +160,16 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// Features [9+6T - 9+6T+3O]: x, y, unum of opponents // Features [9+6T - 9+6T+3O]: x, y, unum of opponents
int detected_opponents = 0; int detected_opponents = 0;
for (PlayerCont::const_iterator it = opponents.begin(); it != opponents.end(); ++it) { for (PlayerPtrCont::const_iterator it = opponents.begin(); it != opponents.end(); ++it) {
const PlayerObject& opponent = *it; const PlayerObject* opponent = *it;
if (valid(opponent) && detected_opponents < numOpponents) { if (valid(opponent) && detected_opponents < numOpponents) {
if (playingOffense) { if (playingOffense) {
addNormFeature(opponent.pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x); addNormFeature(opponent->pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
} else { } else {
addNormFeature(opponent.pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x); addNormFeature(opponent->pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
} }
addNormFeature(opponent.pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y); addNormFeature(opponent->pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y);
addFeature(opponent.unum()); addFeature(opponent->unum());
detected_opponents++; detected_opponents++;
} }
} }
...@@ -190,3 +190,13 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm, ...@@ -190,3 +190,13 @@ HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
// checkFeatures(); // checkFeatures();
return feature_vec; return feature_vec;
} }
// Validity check for a player handed over as a pointer.
// Returns true only when the pointer is non-null, the player reports
// posValid(), and the player's position vector reports isValid().
bool HighLevelFeatureExtractor::valid(const rcsc::PlayerObject* player) {
  // The null test comes first so that the later dereferences cannot segfault;
  // short-circuit evaluation stops at the first failing condition.
  return player
      && player->posValid()
      && player->pos().isValid();
}
...@@ -23,6 +23,10 @@ public: ...@@ -23,6 +23,10 @@ public:
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm, virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status); bool last_action_status);
// Pointer-based counterpart to FeatureExtractor::valid: this version takes a
// pointer instead of a reference. It is static, so it is not a virtual
// override of the base-class method.
// Returns false for a null pointer or when the player's position is not valid.
static bool valid(const rcsc::PlayerObject* player);
protected: protected:
// Number of features for non-player objects. // Number of features for non-player objects.
const static int num_basic_features = 10; const static int num_basic_features = 10;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment