Commit b7aa50ea authored by Matthew Hausknecht's avatar Matthew Hausknecht Committed by GitHub

Merge pull request #51 from DurgeshSamant/sarsa_py_wrapper

python wrapper over sarsa libraries
parents 6ce247b5 f7b0086a
......@@ -11,6 +11,7 @@ addons:
- libboost-filesystem-dev
install:
- if [ "${TRAVIS_OS_NAME}" = "osx" ]; then
brew update ;
brew install cartr/qt4/qt
;
fi
......
#Flags
CXXFLAGS = -g -O3 -Wall
CXXFLAGS = -shared -g -O3 -Wall -fPIC
#Compiler
CXX = g++
#Sources
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
SRCS = FuncApprox.cpp tiles2.cpp CMAC.cpp
#Objects
OBJS = $(SRCS:.cpp=.o)
......@@ -18,11 +18,11 @@ TARGET = libfuncapprox.a
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
$(CXX) $(CXXFLAGS) -c -o $@ $(@F:%.o=%.cpp)
$(TARGET): $(OBJS)
ar cq $@ $(OBJS)
ar cq $@ $(OBJS);
clean:
rm -f $(TARGET) $(OBJS) *~
rm -f $(TARGET) $(OBJS) *~;
......@@ -5,7 +5,7 @@ FA_DIR = ../funcapprox
INCLUDES = -I$(FA_DIR)
#Flags
CXXFLAGS = -g -O3 -Wall
CXXFLAGS = -shared -g -O3 -Wall -fPIC
#Compiler
CXX = g++
......
#ifndef __FA_C_WRAPPER_H__
#define __FA_C_WRAPPER_H__
#include "CMAC.h"
#include<iostream>
extern "C" {
void* CMAC_new(int numF, int numA, double r[], double m[], double res[])
{
// std::cout<<"FA_C_WRAPPER: CMAC_new"<<std::endl;
CMAC *ca = new CMAC(numF, numA, r, m, res);
void *ptr = reinterpret_cast<void *>(ca);
return ptr;
}
}
#endif
#Directories
FA_DIR = ../../sarsa_libraries/funcapprox
POLICY_DIR = ../../sarsa_libraries/policy
HFO_SRC_DIR = ../../../src
HFO_LIB_DIR = ../../../lib
#Includes
INCLUDES = -I$(FA_DIR) -I$(POLICY_DIR) -I$(HFO_SRC_DIR)
#Libs
FA_LIB = funcapprox
POLICY_LIB = policyagent
#Flags
CXXFLAGS = -shared -g -Wall -std=c++11 -fPIC
LDFLAGS = -l$(FA_LIB) -l$(POLICY_LIB) -lhfo -pthread
LDLIBS = -L$(FA_DIR) -L$(POLICY_DIR) -L$(HFO_LIB_DIR)
LINKEROPTIONS = -Wl,-rpath,$(HFO_LIB_DIR)
#Compiler
CXX = g++
#Sources
SRC = FA_C_wrapper.cpp Policy_C_wrapper.cpp
#Objects
OBJ = $(SRC:.cpp=.o)
#Target
TARGET = C_wrappers.so
#Rules
.PHONY: $(FA_LIB)
all: $(TARGET)
.cpp.o:
$(CXX) $(CXXFLAGS) $(INCLUDES) -c -o $@ $(@F:%.o=%.cpp)
$(FA_LIB):
$(MAKE) -C $(FA_DIR)
$(POLICY_LIB):
$(MAKE) -C $(POLICY_DIR)
$(TARGET): $(FA_LIB) $(POLICY_LIB) $(OBJ)
$(CXX) $(OBJ) $(CXXFLAGS) $(LDLIBS) $(LDFLAGS) -o $(TARGET) $(LINKEROPTIONS)
cleanfa:
$(MAKE) clean -C $(FA_DIR)
cleanpolicy:
$(MAKE) clean -C $(POLICY_DIR)
clean: cleanfa cleanpolicy
rm -f $(TARGET) $(OBJ) *~
#ifndef __POLICY_C_WRAPPER_H__
#define __POLICY_C_WRAPPER_H__
#include "SarsaAgent.h"
#include "FuncApprox.h"
#include "CMAC.h"
#include<iostream>
extern "C" {
void* SarsaAgent_new(int numFeatures, int numActions, double learningRate, double epsilon, double lambda, void *FA, char *loadWeightsFile, char *saveWeightsFile)
{
CMAC *fa = reinterpret_cast<CMAC *>(FA);
SarsaAgent *sa=new SarsaAgent(numFeatures, numActions, learningRate, epsilon, lambda, fa, loadWeightsFile, saveWeightsFile);
void *ptr = reinterpret_cast<void *>(sa);
return ptr;
}
void SarsaAgent_update(void *ptr, double state[], int action, double reward, double discountFactor)
{
SarsaAgent *p = reinterpret_cast<SarsaAgent *>(ptr);
p->update(state,action,reward,discountFactor);
}
int SarsaAgent_selectAction(void *ptr, double state[])
{
SarsaAgent *p = reinterpret_cast<SarsaAgent *>(ptr);
int action=p->selectAction(state);
return action;
}
void SarsaAgent_endEpisode(void *ptr)
{
SarsaAgent *p = reinterpret_cast<SarsaAgent *>(ptr);
p->endEpisode();
}
}
#endif
from ctypes import *
import numpy as np
import getpass
import sys, os
isPy3=False
if sys.version_info[0] == 3:
isPy3=True
username=getpass.getuser()
libs = cdll.LoadLibrary(os.path.join(os.path.dirname(__file__),'C_wrappers.so'))
libs.CMAC_new.argtypes=[c_int,c_int,POINTER(c_double),POINTER(c_double),POINTER(c_double)]
libs.CMAC_new.restype=c_void_p
libs.SarsaAgent_new.argtypes=[c_int,c_int,c_double,c_double,c_double,c_void_p,c_char_p,c_char_p]
libs.SarsaAgent_new.restype=c_void_p
libs.SarsaAgent_update.argtypes=[c_void_p,POINTER(c_double),c_int,c_double,c_double]
libs.SarsaAgent_update.restype=None
libs.SarsaAgent_selectAction.argtypes=[c_void_p,POINTER(c_double)]
libs.SarsaAgent_selectAction.restype=c_int
libs.SarsaAgent_endEpisode.argtypes=[c_void_p]
libs.SarsaAgent_endEpisode.restype=None
class CMAC(object):
def __init__(self,numF,numA,r,m,res):
arr1 = (c_double * len(r))(*r)
arr2 = (c_double * len(m))(*m)
arr3 = (c_double * len(res))(*res)
self.obj = libs.CMAC_new(c_int(numF),c_int(numA),arr1,arr2,arr3)
#print(self.obj)
class SarsaAgent(object):
def __init__(self,numFeatures, numActions, learningRate, epsilon, Lambda, FA, loadWeightsFile, saveWeightsFile):
p1=c_int(numFeatures)
p2=c_int(numActions)
p3=c_double(learningRate)
p4=c_double(epsilon)
p5=c_double(Lambda)
p6=c_void_p(FA.obj)
if isPy3:
#utf-8 encoding required for python3
p7=c_char_p(loadWeightsFile.encode('utf-8'))
p8=c_char_p(saveWeightsFile.encode('utf-8'))
else:
#non encoded will do for python2
p7=c_char_p(loadWeightsFile)
p8=c_char_p(saveWeightsFile)
self.obj = libs.SarsaAgent_new(p1,p2,p3,p4,p5,p6,p7,p8)
#print(format(self.obj,'02x'))
def update(self,state,action,reward,discountFactor):
s = (c_double * len(state))(*state)
a = c_int(action)
r = c_double(reward)
df = c_double(discountFactor)
libs.SarsaAgent_update(c_void_p(self.obj),s,a,r,df)
#print(format(self.obj,'02x'))
def selectAction(self,state):
s = (c_double * len(state))(*state)
action=libs.SarsaAgent_selectAction(c_void_p(self.obj),s)
#print(action)
#print(format(self.obj,'02x'))
return int(action)
def endEpisode(self):
libs.SarsaAgent_endEpisode(c_void_p(self.obj))
#print(format(self.obj,'02x'))
#!/usr/bin/env python3
# encoding: utf-8
from hfo import *
import argparse
import numpy as np
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'sarsa_libraries','python_wrapper'))
from py_wrapper import *
NA=0 #Number of actions
NOT=0 #Number of teammates
NF=0 #Number of features
def getReward(s):
reward=0
#---------------------------
if s==GOAL:
reward=1
#---------------------------
elif s==CAPTURED_BY_DEFENSE:
reward=-1
#---------------------------
elif s==OUT_OF_BOUNDS:
reward=-1
#---------------------------
#Cause Unknown Do Nothing
elif s==OUT_OF_TIME:
reward=0
#---------------------------
elif s==IN_GAME:
reward=0
#---------------------------
elif s==SERVER_DOWN:
reward=0
#---------------------------
else:
print("Error: Unknown GameState", s)
return reward
def purge_features(state):
st=np.empty(NF,dtype=np.float64)
stateIndex=0
tmpIndex= 9 + 3*NOT
for i in range(len(state)):
# Ignore first six features and teammate proximity to opponent(when opponent is absent)and opponent features
if(i < 6 or i>9+6*NOT or (args.numOpponents==0 and ((i>9+numTMates and i<=9+2*numTMates) or i==9)) ):
continue;
#Ignore Angle and Uniform Number of Teammates
temp = i-tmpIndex;
if(temp > 0 and (temp % 3 == 2 or temp % 3 == 0)):
continue;
if (i > 9+6*NOT):
continue;
st[stateIndex] = state[i];
stateIndex+=1;
return st
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=6000)
parser.add_argument('--numTeammates', type=int, default=0)
parser.add_argument('--numOpponents', type=int, default=1)
parser.add_argument('--numEpisodes', type=int, default=1)
parser.add_argument('--learnRate', type=float, default=0.1)
parser.add_argument('--suffix', type=int, default=0)
args=parser.parse_args()
# Create the HFO Environment
hfo = HFOEnvironment()
#now connect to the server
hfo.connectToServer(HIGH_LEVEL_FEATURE_SET,'bin/teams/base/config/formations-dt',args.port,'localhost','base_left',False)
global NF,NA, NOT
if args.numOpponents >0:
NF=4+4*args.numTeammates
else:
NF=3+3*args.numTeammates
NOT=args.numTeammates
NA=NOT+2 #PASS to each teammate, SHOOT, DRIBBLE
learnR=args.learnRate
#CMAC parameters
resolution=0.1
Range=[2]*NF
Min=[-1]*NF
Res=[resolution]*NF
#Sarsa Agent Parameters
wt_filename="weights_"+str(NOT+1)+"v"+str(args.numOpponents)+'_'+str(args.suffix)
discFac=1
Lambda=0
eps=0.01
#initialize the function approximator and the sarsa agent
FA=CMAC(NF, NA, Range, Min, Res)
SA=SarsaAgent(NF, NA, learnR, eps, Lambda, FA, wt_filename, wt_filename)
#episode rollouts
st = np.empty(NF,dtype=np.float64)
action = -1
reward = 0
for episode in range(1,args.numEpisodes+1):
count=0
status=IN_GAME
action=-1
while status==IN_GAME:
count=count+1
# Grab the state features from the environment
state = hfo.getState()
if int(state[5])==1:
if action != -1:
#print(st)
reward=getReward(status)
#fb.SA.update(state,action,reward,discFac)
SA.update(st,action,reward,discFac)
st=purge_features(state)
#take an action
#action = fb.SA.selectAction(state)
action = SA.selectAction(st)
#print("Action:", action)
if action == 0:
hfo.act(SHOOT)
elif action == 1:
hfo.act(DRIBBLE)
else:
hfo.act(PASS,state[(9+6*NOT)-(action-2)*3])
else:
hfo.act(MOVE)
status = hfo.step()
#--------------- end of while loop ------------------------------------------------------
############# EPISODE ENDS ###################################################################################
# Check the outcome of the episode
if action != -1:
reward=getReward(status)
SA.update(st, action, reward, discFac)
SA.endEpisode()
############################################################################################################
# Quit if the server goes down
if status == SERVER_DOWN:
hfo.act(QUIT)
break
#! /bin/sh
#This script calls the python implementation of the high_level_sarsa_agent
#In essence it calls the relevant functions from a thin python wrapper written over the C++ sarsa_libraries
# HOW TO RUN
#takes in the number of trails as first argument
#takes in the number of offense agents as second argument
#takes in the number of defense agents as the third argument
# eg. if one needs to run 200 episodes of 2v2 then execute
# ./simulate_python_sarsa_agents.sh 200 2 2
port=6000
trials=10000
oa=2 #number of offense agents
da=1 #number of defense agents
if [ "$#" -lt 1 ]
then
:
else
trials=$1
oa=$2
da=$3
fi
#kill any other simulations that may be running
killall -9 rcssserver
sleep 2
cd .. #cd to HFO directory
rm weights* #remove weights from old runs
python="/usr/bin/python3" #which python?
agent_path="./example/sarsa_offense"
log_dir="."
output_path=$agent_path
agent_filename="high_level_sarsa_agent.py"
#start the server
stdbuf -oL ./bin/HFO --port=$port --no-logging --offense-agents=$oa --defense-npcs=$da --trials=$trials --defense-team=base --headless --fullstate > $log_dir/"$oa"v"$da""_sarsa_py_agents.log" &
#each agent is a seperate process
for n in $(seq 1 $oa)
do
sleep 5
fname="agent"
fname+=$n
fname+=".txt"
logfile=$log_dir/$fname
rm $logfile
$python $agent_path/$agent_filename --port=$port --numTeammates=`expr $oa - 1` --numOpponents=$da --numEpisodes=$trials &> $log_dir/$fname &
done
# The magic line
# $$ holds the PID for this script
# Negation means kill by process group id instead of PID
trap "kill -TERM -$$" SIGINT
wait
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment