Commit 9555df12 authored by Anshul

Added working device plugin

parent 64c132ae
@@ -7,5 +7,7 @@ COPY mps-device-plugin.go .
RUN go mod init mps-device-plugin
RUN go get google.golang.org/grpc
RUN go get k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1
RUN go get github.com/golang/glog
RUN go build mps-device-plugin.go
ENTRYPOINT ["./mps-device-plugin"]
\ No newline at end of file
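As an aside, a minimal sketch of how this image could be built and pushed so the DaemonSet below can pull it (illustrative commands, not part of this commit; the tag matches the image referenced in mps-manager.yaml):

# Build the plugin image from this Dockerfile and push it to the registry the DaemonSet pulls from.
docker build -t xzaviourr/mps-device-plugin:v4 .
docker push xzaviourr/mps-device-plugin:v4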

@@ -5,6 +5,8 @@ import (
    "fmt"
    "net"
    "strconv"
    "sync"
    "time"

    "github.com/golang/glog"
    "google.golang.org/grpc"
@@ -16,7 +18,8 @@ const (
    mpsActiveThreadCmd = "get_default_active_thread_percentage"
    mpsMemLimitEnv     = "CUDA_MPS_PINNED_DEVICE_MEM_LIMIT"
    mpsThreadLimitEnv  = "CUDA_MPS_ACTIVE_THREAD_PERCENTAGE"
    pluginEndpoint     = "/device-plugin/mps-device-plugin.sock"
    kubeletEndpoint    = "/device-plugin/kubelet.sock"
)
var (
@@ -27,9 +30,10 @@ var (
type mpsGPUManager struct {
    grpcServer           *grpc.Server
    devices              map[string]pluginapi.Device
    computePartitionSize int // In utilization (e.g. 1%)
    memPartitionSize     int // In MB (e.g. 256MB)
    socket               string
    stop                 chan bool
}

func NewMpsGPUManager(computePartitionSize, memPartitionSize int) *mpsGPUManager {
@@ -37,32 +41,12 @@ func NewMpsGPUManager(computePartitionSize, memPartitionSize int) *mpsGPUManager
        devices:              make(map[string]pluginapi.Device),
        computePartitionSize: computePartitionSize,
        memPartitionSize:     memPartitionSize,
        stop:                 make(chan bool),
    }
}

type pluginService struct {
    mgm *mpsGPUManager
}

// func (mgm *mpsGPUManager) isMpsHealthy() error {
@@ -95,28 +79,119 @@ func (mgm *mpsGPUManager) ListDevices() []*pluginapi.Device {
// return nil
// }

func (mgm *mpsGPUManager) Serve() {
    glog.Infof("Starting MPS GPU Manager")
    lis, err := net.Listen("unix", pluginEndpoint)
    if err != nil {
        glog.Fatalf("starting device plugin server failed: %v", err)
    }
    mgm.socket = pluginEndpoint
    mgm.grpcServer = grpc.NewServer()

    pluginbeta := &pluginService{mgm: mgm}
    pluginbeta.RegisterService()

    registeredWithKubelet := false
    for {
        select {
        case <-mgm.stop:
            close(mgm.stop)
            return
        default:
            {
                var wg sync.WaitGroup
                wg.Add(1)
                // Starts device plugin service.
                go func() {
                    defer wg.Done()
                    // Blocking call to accept incoming connections.
                    err := mgm.grpcServer.Serve(lis)
                    glog.Errorf("device-plugin server stopped serving: %v", err)
                }()

                if !registeredWithKubelet {
                    for len(mgm.grpcServer.GetServiceInfo()) <= 0 {
                        time.Sleep(1 * time.Second)
                    }
                    glog.Infoln("device-plugin server started serving")

                    err = RegisterToKubelet()
                    if err != nil {
                        mgm.grpcServer.Stop()
                        wg.Wait()
                        glog.Fatal(err)
                    }
                    glog.Infoln("device plugin registered with kubelet")
                    registeredWithKubelet = true
                }
            }
        }
    }
}

func RegisterToKubelet() error {
    conn, err := grpc.Dial(kubeletEndpoint, grpc.WithInsecure(),
        grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) {
            return net.DialTimeout("unix", addr, timeout)
        }))
    if err != nil {
        return fmt.Errorf("device plugin cannot connect to the kubelet service: %v", err)
    }
    defer conn.Close()

    client := pluginapi.NewRegistrationClient(conn)
    request := &pluginapi.RegisterRequest{
        Version:      pluginapi.Version,
        Endpoint:     "mps-device-plugin.sock",
        ResourceName: "nvidia.com/gpu",
    }
    if _, err = client.Register(context.Background(), request); err != nil {
        return fmt.Errorf("device plugin cannot register to kubelet service: %v", err)
    }
    return nil
}

func (ps *pluginService) RegisterService() {
    pluginapi.RegisterDevicePluginServer(ps.mgm.grpcServer, ps)
}

func (mgm *mpsGPUManager) Stop() {
    if mgm.grpcServer != nil {
        mgm.grpcServer.Stop()
    }
    mgm.stop <- true
    <-mgm.stop
    glog.Infof("MPS GPU Manager stopped")
}

func (ps *pluginService) ListDevices() []*pluginapi.Device {
    gpuMemoryAvailable := 16384 // Using static value for now
    // With the defaults used in main (1% compute and 256 MB memory partitions),
    // this advertises 100 compute and 64 memory virtual devices.
    computeDevicesCount := 100 / ps.mgm.computePartitionSize
    memoryDevicesCount := gpuMemoryAvailable / ps.mgm.memPartitionSize
    virtualDevices := make([]*pluginapi.Device, computeDevicesCount+memoryDevicesCount)
    for i := 0; i < computeDevicesCount; i++ {
        virtualDeviceID := fmt.Sprintf("%s-%d", computeResourceName, i)
        virtualDevices[i] = &pluginapi.Device{
            ID:     virtualDeviceID,
            Health: pluginapi.Healthy,
        }
    }
    for i := 0; i < memoryDevicesCount; i++ {
        virtualDeviceID := fmt.Sprintf("%s-%d", memResourceName, i)
        virtualDevices[computeDevicesCount+i] = &pluginapi.Device{
            ID:     virtualDeviceID,
            Health: pluginapi.Healthy,
        }
    }
    return virtualDevices
}

func (ps *pluginService) ListAndWatch(empty *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error {
    resp := new(pluginapi.ListAndWatchResponse)
    resp.Devices = ps.ListDevices()
    if err := stream.Send(resp); err != nil {
        glog.Infof("Error sending device list : %v", err)
        return err
@@ -125,7 +200,7 @@ func (mgm *mpsGPUManager) ListAndWatch(empty *pluginapi.Empty, stream pluginapi.
    select {}
}

func (ps *pluginService) Allocate(ctx context.Context, rqt *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
    allocateResponse := &pluginapi.AllocateResponse{}
    for _, req := range rqt.ContainerRequests {
@@ -148,14 +223,14 @@ func (mgm *mpsGPUManager) Allocate(ctx context.Context, rqt *pluginapi.AllocateR
    return allocateResponse, nil
}

func (ps *pluginService) GetDevicePluginOptions(context.Context, *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
    return &pluginapi.DevicePluginOptions{
        PreStartRequired:                false,
        GetPreferredAllocationAvailable: false,
    }, nil
}

func (ps *pluginService) GetPreferredAllocation(ctx context.Context, rqt *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {
    preferredAllocateResponse := &pluginapi.PreferredAllocationResponse{}
    for _, req := range rqt.ContainerRequests {
@@ -170,7 +245,7 @@ func (mgm *mpsGPUManager) GetPreferredAllocation(ctx context.Context, rqt *plugi
    return preferredAllocateResponse, nil
}

func (ps *pluginService) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {
    preStartContainerResponse := pluginapi.PreStartContainerResponse{}
    return &preStartContainerResponse, nil
}
@@ -178,8 +253,8 @@ func (mgm *mpsGPUManager) PreStartContainer(context.Context, *pluginapi.PreStart
func main() {
    mgm := NewMpsGPUManager(1, 256)
    defer mgm.Stop()
    mgm.Serve()
    // if err := mgm.Serve(); err != nil {
    //     glog.Fatalf("Error starting the MPS GPU Manager : %v", err)
    // }
}

apiVersion: apps/v1
kind: DaemonSet
metadata:
  name: mps-device-plugin
  namespace: kube-system
spec:
  selector:
    matchLabels:
      app: mps-device-plugin
  template:
    metadata:
      labels:
        app: mps-device-plugin
    spec:
      hostPID: true
      hostIPC: true
      hostNetwork: true
      serviceAccount: mps-device-plugin-manager
      nodeSelector:
        mps-gpu-enabled: "true"
      containers:
      - name: mps-device-plugin
        image: xzaviourr/mps-device-plugin:v4
        securityContext:
          privileged: true
        volumeMounts:
        - name: device-plugin
          mountPath: /device-plugin
      volumes:
      - name: device-plugin
        hostPath:
          path: /var/lib/kubelet/device-plugins
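Note that the hostPath volume maps the kubelet's /var/lib/kubelet/device-plugins directory to /device-plugin inside the container; this is why the plugin code above listens on /device-plugin/mps-device-plugin.sock and dials /device-plugin/kubelet.sock, while the kubelet sees the same plugin socket under /var/lib/kubelet/device-plugins/.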

@@ -19,6 +19,6 @@ kubectl label node ub-10 mps-gpu-enabled=true # Add device plugin label
# Attach daemonset again
# kubectl create namespace gpu-device-plugin-namespace
kubectl create sa mps-device-plugin-manager -n kube-system
kubectl create clusterrolebinding mps-device-plugin-manager-role --clusterrole=cluster-admin --serviceaccount=kube-system:mps-device-plugin-manager
kubectl apply -f mps-manager.yaml
\ No newline at end of file
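
Assuming the node label and DaemonSet above are in place, a rough sanity check (illustrative commands, not part of this commit):

kubectl -n kube-system get pods -l app=mps-device-plugin   # plugin pod should be Running on the labelled node
kubectl describe node ub-10 | grep nvidia.com/gpu          # registered resource should appear under Capacity and Allocatable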