Commit b4502c1f authored by Sushant Mahajan's avatar Sushant Mahajan

changed the server io to use channels for synchronized communication. Also...

changed the server io to use channels for synchronized communication. Also added timeout for idle connection
parent 14b0a125
package main package main
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
...@@ -33,9 +33,7 @@ const ( ...@@ -33,9 +33,7 @@ const (
ERR_CMD_ERR = "ERR_CMD_ERR" ERR_CMD_ERR = "ERR_CMD_ERR"
ERR_NOT_FOUND = "ERR_NOT_FOUND" ERR_NOT_FOUND = "ERR_NOT_FOUND"
ERR_VERSION = "ERR_VERSION" ERR_VERSION = "ERR_VERSION"
ERR_INTERNAL = "ERR_INTERNAL"
//logging
LOG = true
) )
type Data struct { type Data struct {
...@@ -75,24 +73,17 @@ func startServer() { ...@@ -75,24 +73,17 @@ func startServer() {
} }
} }
func read(conn net.Conn, toRead uint64) ([]byte, bool) { func myRead(ch chan string, conn net.Conn) {
buf := make([]byte, toRead) scanner := bufio.NewScanner(conn)
_, err := conn.Read(buf) for {
if ok := scanner.Scan(); !ok {
if err != nil { break
if err == io.EOF { } else {
logger.Println("Client disconnected!") temp := scanner.Text()
return []byte{0}, false ch <- temp
logger.Println(temp, "$$")
} }
} }
n := bytes.Index(buf, []byte{0})
if n != 0 {
logger.Println("Received: ", buf[:n], string(buf[:n]))
return buf[:n-2], true
}
return []byte{0}, false
} }
func write(conn net.Conn, msg string) { func write(conn net.Conn, msg string) {
...@@ -104,15 +95,17 @@ func write(conn net.Conn, msg string) { ...@@ -104,15 +95,17 @@ func write(conn net.Conn, msg string) {
func handleClient(conn net.Conn, table *KeyValueStore) { func handleClient(conn net.Conn, table *KeyValueStore) {
defer conn.Close() defer conn.Close()
//channel for every connection for every client
ch := make(chan string)
go myRead(ch, conn)
for { for {
if msg, ok := read(conn, 1024); ok { msg := <-ch
if len(msg) == 0 { logger.Println("Channel: ", msg)
continue if len(msg) == 0 {
} continue
parseInput(conn, string(msg), table)
} else {
break
} }
parseInput(conn, string(msg), table, ch)
} }
} }
...@@ -207,8 +200,14 @@ func isValid(cmd string, tokens []string, conn net.Conn) int { ...@@ -207,8 +200,14 @@ func isValid(cmd string, tokens []string, conn net.Conn) int {
return flag return flag
} }
func parseInput(conn net.Conn, msg string, table *KeyValueStore) { func parseInput(conn net.Conn, msg string, table *KeyValueStore, ch chan string) {
tokens := strings.Fields(msg) tokens := strings.Fields(msg)
//general error, don't check for commands, avoid the pain
if len(tokens) > 6 {
write(conn, ERR_CMD_ERR)
return
}
var buffer bytes.Buffer var buffer bytes.Buffer
//logger.Println(tokens) //logger.Println(tokens)
switch tokens[0] { switch tokens[0] {
...@@ -216,7 +215,8 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) { ...@@ -216,7 +215,8 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
if isValid(SET, tokens, conn) != 0 { if isValid(SET, tokens, conn) != 0 {
return return
} }
if ver, ok, r := performSet(conn, tokens[1:len(tokens)], table); ok { if ver, ok, r := performSet(conn, tokens[1:len(tokens)], table, ch); ok {
debug(table)
logger.Println(ver) logger.Println(ver)
if r { if r {
buffer.Reset() buffer.Reset()
...@@ -246,6 +246,7 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) { ...@@ -246,6 +246,7 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
buffer.WriteString(ERR_NOT_FOUND) buffer.WriteString(ERR_NOT_FOUND)
write(conn, buffer.String()) write(conn, buffer.String())
} }
debug(table)
case GETM: case GETM:
if isValid(GETM, tokens, conn) != 0 { if isValid(GETM, tokens, conn) != 0 {
...@@ -270,12 +271,13 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) { ...@@ -270,12 +271,13 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
buffer.WriteString(ERR_NOT_FOUND) buffer.WriteString(ERR_NOT_FOUND)
write(conn, buffer.String()) write(conn, buffer.String())
} }
debug(table)
case CAS: case CAS:
if isValid(CAS, tokens, conn) != 0 { if isValid(CAS, tokens, conn) != 0 {
return return
} }
if ver, ok, r := performCas(conn, tokens[1:len(tokens)], table); r { if ver, ok, r := performCas(conn, tokens[1:len(tokens)], table, ch); r {
if r { if r {
switch ok { switch ok {
case 0: case 0:
...@@ -300,24 +302,74 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) { ...@@ -300,24 +302,74 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
} }
} }
} }
debug(table)
case DELETE: case DELETE:
if isValid(DELETE, tokens, conn) != 0 { if isValid(DELETE, tokens, conn) != 0 {
return return
} }
if ok := performDelete(conn, tokens[1:len(tokens)], table); ok { if ok := performDelete(conn, tokens[1:len(tokens)], table); ok == 0 {
buffer.Reset() write(conn, DELETED)
buffer.WriteString(DELETED) } else {
write(conn, buffer.String()) write(conn, ERR_NOT_FOUND)
} }
debug(table)
default: default:
logger.Println("Command not found") buffer.Reset()
buffer.WriteString(ERR_CMD_ERR)
write(conn, buffer.String())
}
}
/*
*Helper function to read value or cause timeout after 5 seconds
*parameters: channel to read data from, threshold number of bytes to read
*returns: the value string and error state
*/
func readValue(ch chan string, n uint64) ([]byte, bool) {
//now we need to read the value which should have been sent
valReadLength := uint64(0)
var v string
err := false
up := make(chan bool, 1)
//after 5 seconds passed reading value, we'll just send err to client
go func() {
time.Sleep(5 * time.Second)
up <- true
}()
//use select for the data channel and the timeout channel
for valReadLength < n {
select {
case temp := <-ch:
logger.Println("Value chunk read!")
valReadLength += uint64(len(temp))
v += temp
case <-up:
err = true
logger.Println("Oh, Oh timeout")
//write(conn, ERR_INTERNAL)
break
}
//will be true if timeout occurs
if err {
break
}
}
if err {
return []byte{0}, err
} }
return []byte(v), err
} }
func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, bool, bool) { func performSet(conn net.Conn, tokens []string, table *KeyValueStore, ch chan string) (uint64, bool, bool) {
k := tokens[0] k := tokens[0]
e, _ := strconv.ParseUint(tokens[1], 10, 64) e, _ := strconv.ParseUint(tokens[1], 10, 64) //expiry time offset
n, _ := strconv.ParseUint(tokens[2], 10, 64) n, _ := strconv.ParseUint(tokens[2], 10, 64) //numbytes
r := true r := true
if len(tokens) == 4 && tokens[3] == NOREPLY { if len(tokens) == 4 && tokens[3] == NOREPLY {
...@@ -326,42 +378,36 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, b ...@@ -326,42 +378,36 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, b
logger.Println(r) logger.Println(r)
//read value if v, err := readValue(ch, n); err {
v, ok := read(conn, n+2) write(conn, ERR_INTERNAL)
if !ok {
//error here
return 0, false, r return 0, false, r
}
table.Lock()
logger.Println("Table locked")
//critical section start
var val *Data
if _, ok := table.dictionary[k]; ok {
val = table.dictionary[k]
} else {
val = new(Data)
table.dictionary[k] = val
}
ver++
val.numBytes = n
val.version = ver
if e == 0 {
val.expTime = e
} else { } else {
val.expTime = e + uint64(time.Now().Unix()) defer table.Unlock()
table.Lock()
//critical section start
var val *Data
if _, ok := table.dictionary[k]; ok {
val = table.dictionary[k]
} else {
val = new(Data)
table.dictionary[k] = val
}
ver++
val.numBytes = n
val.version = ver
if e == 0 {
val.expTime = e
} else {
val.expTime = e + uint64(time.Now().Unix())
}
val.value = v
return val.version, true, r
} }
val.value = v
table.Unlock()
logger.Println("Table unlocked")
debug(table)
return val.version, true, r
} }
func performGet(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, bool) { func performGet(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, bool) {
k := tokens[0] k := tokens[0]
table.RUnlock() defer table.RUnlock()
table.RLock() table.RLock()
//critical section begin //critical section begin
if v, ok := table.dictionary[k]; ok { if v, ok := table.dictionary[k]; ok {
...@@ -399,7 +445,7 @@ func performGetm(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, b ...@@ -399,7 +445,7 @@ func performGetm(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, b
} }
} }
func performCas(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, int, bool) { func performCas(conn net.Conn, tokens []string, table *KeyValueStore, ch chan string) (uint64, int, bool) {
k := tokens[0] k := tokens[0]
e, _ := strconv.ParseUint(tokens[1], 10, 64) e, _ := strconv.ParseUint(tokens[1], 10, 64)
ve, _ := strconv.ParseUint(tokens[2], 10, 64) ve, _ := strconv.ParseUint(tokens[2], 10, 64)
...@@ -412,43 +458,47 @@ func performCas(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, i ...@@ -412,43 +458,47 @@ func performCas(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, i
} }
//read value //read value
v, ok := read(conn, n+2) if v, err := readValue(ch, n); err {
if !ok { return 0, 1, r
//error here } else {
return 0, 1, r //malformed defer table.Unlock()
} table.Lock()
if val, ok := table.dictionary[k]; ok {
defer table.Unlock() if val.version == ve {
table.Lock() if e == 0 {
if val, ok := table.dictionary[k]; ok { val.expTime = e
if val.version == ve { } else {
if e == 0 { val.expTime = e + uint64(time.Now().Unix())
val.expTime = e }
} else { val.numBytes = n
val.expTime = e + uint64(time.Now().Unix()) ver++
val.version = ver
val.value = v
return val.version, 0, r //key found and changed
} }
val.numBytes = n return 0, 2, r //version mismatch
ver++
val.version = ver
val.value = v
return val.version, 0, r //key found and changed
} }
return 0, 2, r //version mismatch return 0, 3, r //key not found
} }
return 0, 3, r //key not found
} }
func performDelete(conn net.Conn, tokens []string, table *KeyValueStore) bool { func performDelete(conn net.Conn, tokens []string, table *KeyValueStore) int {
k := tokens[0] k := tokens[0]
logger.Println(tokens)
flag := 1
defer table.Unlock() defer table.Unlock()
table.Lock() table.Lock()
//begin critical section //begin critical section
if _, ok := table.dictionary[k]; ok { if v, ok := table.dictionary[k]; ok {
delete(table.dictionary, k) if v.expTime < uint64(time.Now().Unix()) {
return true flag = 1 //found but expired
} else {
flag = 0 //found not expired
}
delete(table.dictionary, k) //delete anyway as expired or needs to be deleted
return flag
} }
return false return 2 //key not found
} }
func debug(table *KeyValueStore) { func debug(table *KeyValueStore) {
...@@ -462,7 +512,12 @@ func debug(table *KeyValueStore) { ...@@ -462,7 +512,12 @@ func debug(table *KeyValueStore) {
func main() { func main() {
ver = 1 ver = 1
if LOG { toLog := ""
if len(os.Args) > 1 {
toLog = os.Args[1]
}
if toLog != "" {
logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
defer logf.Close() defer logf.Close()
logger = log.New(logf, "SERVER: ", log.Ltime|log.Lshortfile) logger = log.New(logf, "SERVER: ", log.Ltime|log.Lshortfile)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment