Commit 5f87d268 authored by Sushant Mahajan's avatar Sushant Mahajan

fixed command scanning code. Added test case file and targeted tests for set command

parent a0cf72e5
...@@ -38,6 +38,7 @@ const ( ...@@ -38,6 +38,7 @@ const (
//constant //constant
MAX_CMD_ARGS = 6 MAX_CMD_ARGS = 6
MIN_CMD_ARGS = 2
READ_TIMEOUT = 5 READ_TIMEOUT = 5
) )
...@@ -142,93 +143,107 @@ func handleClient(conn net.Conn, table *KeyValueStore) { ...@@ -142,93 +143,107 @@ func handleClient(conn net.Conn, table *KeyValueStore) {
*return: integer representing error state *return: integer representing error state
*/ */
func isValid(cmd string, tokens []string, conn net.Conn) int { func isValid(cmd string, tokens []string, conn net.Conn) int {
var flag int
switch cmd { switch cmd {
case SET: case SET:
if len(tokens) > 5 || len(tokens) < 4 { if len(tokens) > 5 || len(tokens) < 4 {
flag = 1
logger.Println(cmd, ":Invalid no. of tokens") logger.Println(cmd, ":Invalid no. of tokens")
write(conn, ERR_CMD_ERR)
return 1
} }
if len([]byte(tokens[1])) > 250 { if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key") logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
} }
if len(tokens) == 5 && tokens[4] != NOREPLY { if len(tokens) == 5 && tokens[4] != NOREPLY {
logger.Println(cmd, ":optional arg incorrect") logger.Println(cmd, ":optional arg incorrect")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil { if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil {
logger.Println(cmd, ":expiry time invalid") logger.Println(cmd, ":expiry time invalid")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil { if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil {
logger.Println(cmd, ":numBytes invalid") logger.Println(cmd, ":numBytes invalid")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
case GET: case GET:
if len(tokens) != 2 { if len(tokens) != 2 {
flag = 1
logger.Println(cmd, ":Invalid number of arguments") logger.Println(cmd, ":Invalid number of arguments")
write(conn, ERR_CMD_ERR)
return 1
} }
if len(tokens[1]) > 250 { if len(tokens[1]) > 250 {
flag = 1
logger.Println(cmd, ":Invalid key size") logger.Println(cmd, ":Invalid key size")
write(conn, ERR_CMD_ERR)
return 1
} }
case GETM: case GETM:
if len(tokens) != 2 { if len(tokens) != 2 {
flag = 1 logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
} }
if len(tokens[1]) > 250 { if len(tokens[1]) > 250 {
flag = 1
logger.Println(cmd, ":Invalid key size") logger.Println(cmd, ":Invalid key size")
write(conn, ERR_CMD_ERR)
return 1
} }
case CAS: case CAS:
if len(tokens) > 6 || len(tokens) < 5 { if len(tokens) > 6 || len(tokens) < 5 {
flag = 1 logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
} }
if len([]byte(tokens[1])) > 250 { if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key") logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
} }
if len(tokens) == 6 && tokens[5] != NOREPLY { if len(tokens) == 6 && tokens[5] != NOREPLY {
logger.Println(cmd, ":optional arg incorrect") logger.Println(cmd, ":optional arg incorrect")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil { if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil {
logger.Println(cmd, ":expiry time invalid") logger.Println(cmd, ":expiry time invalid")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil { if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil {
logger.Println(cmd, ":version invalid") logger.Println(cmd, ":version invalid")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
if _, err := strconv.ParseUint(tokens[4], 10, 64); err != nil { if _, err := strconv.ParseUint(tokens[4], 10, 64); err != nil {
logger.Println(cmd, ":numbytes invalid") logger.Println(cmd, ":numbytes invalid")
flag = 1 write(conn, ERR_CMD_ERR)
return 1
} }
case DELETE: case DELETE:
if len(tokens) != 2 { if len(tokens) != 2 {
flag = 1 logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
} }
if len([]byte(tokens[1])) > 250 { if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key") logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
} }
default: default:
return 0 return 0
} }
//compiler is happy
switch flag { return 0
case 1:
write(conn, ERR_CMD_ERR)
}
return flag
} }
/*Function parses the command provided by the client and delegates further action to command specific functions. /*Function parses the command provided by the client and delegates further action to command specific functions.
...@@ -239,7 +254,7 @@ func isValid(cmd string, tokens []string, conn net.Conn) int { ...@@ -239,7 +254,7 @@ func isValid(cmd string, tokens []string, conn net.Conn) int {
func parseInput(conn net.Conn, msg string, table *KeyValueStore, ch chan []byte) { func parseInput(conn net.Conn, msg string, table *KeyValueStore, ch chan []byte) {
tokens := strings.Fields(msg) tokens := strings.Fields(msg)
//general error, don't check for commands, avoid the pain ;) //general error, don't check for commands, avoid the pain ;)
if len(tokens) > MAX_CMD_ARGS { if len(tokens) > MAX_CMD_ARGS || len(tokens) < MIN_CMD_ARGS {
write(conn, ERR_CMD_ERR) write(conn, ERR_CMD_ERR)
return return
} }
...@@ -385,6 +400,10 @@ func readValue(ch chan []byte, n uint64) ([]byte, bool) { ...@@ -385,6 +400,10 @@ func readValue(ch chan []byte, n uint64) ([]byte, bool) {
case temp := <-ch: case temp := <-ch:
logger.Println("Value chunk read!") logger.Println("Value chunk read!")
valReadLength += uint64(len(temp)) valReadLength += uint64(len(temp))
if valReadLength > n+2 {
err = true
break
}
v = append(v, temp...) v = append(v, temp...)
case <-up: case <-up:
...@@ -424,7 +443,7 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore, ch chan [] ...@@ -424,7 +443,7 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore, ch chan []
logger.Println(r) logger.Println(r)
if v, err := readValue(ch, n); err { if v, err := readValue(ch, n); err {
write(conn, ERR_INTERNAL) write(conn, ERR_CMD_ERR)
return 0, false, r return 0, false, r
} else { } else {
defer table.Unlock() defer table.Unlock()
...@@ -609,7 +628,7 @@ func CustomSplitter(data []byte, atEOF bool) (advance int, token []byte, err err ...@@ -609,7 +628,7 @@ func CustomSplitter(data []byte, atEOF bool) (advance int, token []byte, err err
//here we add omega as we are using the complete data array instead of the slice where we found '\n' //here we add omega as we are using the complete data array instead of the slice where we found '\n'
if data[omega+i-1] == '\r' { if data[omega+i-1] == '\r' {
//next byte begins at i+1 and data[0:i+1] returned //next byte begins at i+1 and data[0:i+1] returned
return i + 1, data[0 : i+1], nil return omega + i + 1, data[:omega+i+1], nil
} else { } else {
//move the omega index to the byte after \n //move the omega index to the byte after \n
omega = i + 1 omega = i + 1
...@@ -639,8 +658,9 @@ func main() { ...@@ -639,8 +658,9 @@ func main() {
toLog = os.Args[1] toLog = os.Args[1]
} }
//toLog = "s"
if toLog != "" { if toLog != "" {
logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
defer logf.Close() defer logf.Close()
logger = log.New(logf, "SERVER: ", log.Ltime|log.Lshortfile) logger = log.New(logf, "SERVER: ", log.Ltime|log.Lshortfile)
} else { } else {
......
...@@ -2,63 +2,60 @@ package main ...@@ -2,63 +2,60 @@ package main
import ( import (
"bytes" "bytes"
"fmt"
"net" "net"
"strconv"
"testing" "testing"
"time" "time"
) )
func TestSet(t *testing.T) { type TestCasePair struct {
command []byte
expected []byte
}
//this test function will start tests for various commands, one client at a time
func TestSerial(t *testing.T) {
go main() go main()
conn, err := net.Dial("tcp", "localhost:5000") //give some time for server to initialize
if err != nil { time.Sleep(time.Second)
t.Errorf("Error connecting to server") testSetCommand(t)
} else {
time.Sleep(time.Second*2)
conn.Write([]byte("set xyz 200 10\r\n"))
time.Sleep(time.Millisecond)
conn.Write([]byte("abcd\r\n"))
buffer := make([]byte, 1024)
conn.Read(buffer)
msg := string(buffer)
if msg == ERR_CMD_ERR+"\r\n" {
t.Errorf("Expected OK <version>")
}
}
} }
func TestGet(t *testing.T) { func testSetCommand(t *testing.T) {
//go main() testSetReplyExpected(t)
conn, err := net.Dial("tcp", "localhost:5000") }
func testSetReplyExpected(t *testing.T) {
conn, err := net.Dial("tcp", ":5000")
defer conn.Close()
time.Sleep(time.Millisecond)
if err != nil { if err != nil {
t.Errorf("Error connecting to server") t.Errorf("Connection Error")
} else { }
time.Sleep(time.Second)
conn.Write([]byte("set xyz 200 10\r\n"))
time.Sleep(time.Millisecond)
conn.Write([]byte("abcdefg\r\n"))
buffer := make([]byte, 1024)
conn.Read(buffer)
msg := string(buffer)
if msg == ERR_CMD_ERR+"\r\n" {
t.Errorf("Expected OK <version>")
}
conn.Write([]byte("get xyz\r\n")) cases := []TestCasePair{
time.Sleep(time.Millisecond) {[]byte("set a 200 10\r\n1234567890\r\n"), []byte("OK 2\r\n")}, //single length
conn.Read(buffer) {[]byte("set b 200 10\r\n12345\r\n890\r\n"), []byte("OK 3\r\n")}, //\r\n in middle
msg = string(buffer) {[]byte("set c 0 10\r\n12345\r\n890\r\n"), []byte("OK 4\r\n")}, //perpetual key
if msg == ERR_CMD_ERR+"\r\n" { {[]byte("set \n 200 10\r\n1234567890\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //newline key (error)
t.Errorf("Expected key value") {[]byte("set d 200 10\r\n12345678901\r\n"), []byte("ERR_CMD_ERR\r\n")}, //value length greater (error)
} {[]byte("set e 200 10\r\n1234\r6789\r\n"), []byte("ERR_CMD_ERR\r\n")}, //value length less (error)
//key length high (250 bytes)
{[]byte("set 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 200 10\r\n1234\r67890\r\n"), []byte("OK 5\r\n")},
//key length high (251 bytes), error
{[]byte("set 12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901 200 10\r\n1234\r67890\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")},
{[]byte("set f -1 10\r\n1234\r6\r89\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //invalid expiry
{[]byte("set f 200 0\r\n1234\r6\r89\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //invalid value size
}
conn.Write([]byte("get tuv\r\n")) for i, e := range cases {
time.Sleep(time.Millisecond) buf := make([]byte, 2048)
buffer = make([]byte, 1024) conn.Write(e.command)
conn.Read(buffer) n, _ := conn.Read(buf)
n := bytes.Index(buffer, []byte{0}) if !bytes.Equal(buf[:n], e.expected) {
msg = string(buffer[:n]) fmt.Println(buf[:n], e.expected, string(buf[:n]), string(e.expected))
if msg != ERR_NOT_FOUND+"\r\n" { t.Errorf("Error occured for case:" + strconv.Itoa(i))
t.Errorf("Expected key value")
} }
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment