Commit 5f87d268 authored by Sushant Mahajan's avatar Sushant Mahajan

fixed command scanning code. Added test case file and targeted tests for set command

parent a0cf72e5
......@@ -38,6 +38,7 @@ const (
//constant
MAX_CMD_ARGS = 6
MIN_CMD_ARGS = 2
READ_TIMEOUT = 5
)
......@@ -142,93 +143,107 @@ func handleClient(conn net.Conn, table *KeyValueStore) {
*return: integer representing error state
*/
func isValid(cmd string, tokens []string, conn net.Conn) int {
var flag int
switch cmd {
case SET:
if len(tokens) > 5 || len(tokens) < 4 {
flag = 1
logger.Println(cmd, ":Invalid no. of tokens")
write(conn, ERR_CMD_ERR)
return 1
}
if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
}
if len(tokens) == 5 && tokens[4] != NOREPLY {
logger.Println(cmd, ":optional arg incorrect")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil {
logger.Println(cmd, ":expiry time invalid")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil {
logger.Println(cmd, ":numBytes invalid")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
case GET:
if len(tokens) != 2 {
flag = 1
logger.Println(cmd, ":Invalid number of arguments")
write(conn, ERR_CMD_ERR)
return 1
}
if len(tokens[1]) > 250 {
flag = 1
logger.Println(cmd, ":Invalid key size")
write(conn, ERR_CMD_ERR)
return 1
}
case GETM:
if len(tokens) != 2 {
flag = 1
logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
}
if len(tokens[1]) > 250 {
flag = 1
logger.Println(cmd, ":Invalid key size")
write(conn, ERR_CMD_ERR)
return 1
}
case CAS:
if len(tokens) > 6 || len(tokens) < 5 {
flag = 1
logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
}
if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
}
if len(tokens) == 6 && tokens[5] != NOREPLY {
logger.Println(cmd, ":optional arg incorrect")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
if _, err := strconv.ParseUint(tokens[2], 10, 64); err != nil {
logger.Println(cmd, ":expiry time invalid")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
if _, err := strconv.ParseUint(tokens[3], 10, 64); err != nil {
logger.Println(cmd, ":version invalid")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
if _, err := strconv.ParseUint(tokens[4], 10, 64); err != nil {
logger.Println(cmd, ":numbytes invalid")
flag = 1
write(conn, ERR_CMD_ERR)
return 1
}
case DELETE:
if len(tokens) != 2 {
flag = 1
logger.Println(cmd, ":Invalid number of tokens")
write(conn, ERR_CMD_ERR)
return 1
}
if len([]byte(tokens[1])) > 250 {
flag = 1
logger.Println(cmd, ":Invalid size of key")
write(conn, ERR_CMD_ERR)
return 1
}
default:
return 0
}
switch flag {
case 1:
write(conn, ERR_CMD_ERR)
}
return flag
//compiler is happy
return 0
}
/*Function parses the command provided by the client and delegates further action to command specific functions.
......@@ -239,7 +254,7 @@ func isValid(cmd string, tokens []string, conn net.Conn) int {
func parseInput(conn net.Conn, msg string, table *KeyValueStore, ch chan []byte) {
tokens := strings.Fields(msg)
//general error, don't check for commands, avoid the pain ;)
if len(tokens) > MAX_CMD_ARGS {
if len(tokens) > MAX_CMD_ARGS || len(tokens) < MIN_CMD_ARGS {
write(conn, ERR_CMD_ERR)
return
}
......@@ -385,6 +400,10 @@ func readValue(ch chan []byte, n uint64) ([]byte, bool) {
case temp := <-ch:
logger.Println("Value chunk read!")
valReadLength += uint64(len(temp))
if valReadLength > n+2 {
err = true
break
}
v = append(v, temp...)
case <-up:
......@@ -424,7 +443,7 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore, ch chan []
logger.Println(r)
if v, err := readValue(ch, n); err {
write(conn, ERR_INTERNAL)
write(conn, ERR_CMD_ERR)
return 0, false, r
} else {
defer table.Unlock()
......@@ -609,7 +628,7 @@ func CustomSplitter(data []byte, atEOF bool) (advance int, token []byte, err err
//here we add omega as we are using the complete data array instead of the slice where we found '\n'
if data[omega+i-1] == '\r' {
//next byte begins at i+1 and data[0:i+1] returned
return i + 1, data[0 : i+1], nil
return omega + i + 1, data[:omega+i+1], nil
} else {
//move the omega index to the byte after \n
omega = i + 1
......@@ -639,8 +658,9 @@ func main() {
toLog = os.Args[1]
}
//toLog = "s"
if toLog != "" {
logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
logf, _ := os.OpenFile("serverlog.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
defer logf.Close()
logger = log.New(logf, "SERVER: ", log.Ltime|log.Lshortfile)
} else {
......
......@@ -2,63 +2,60 @@ package main
import (
"bytes"
"fmt"
"net"
"strconv"
"testing"
"time"
)
func TestSet(t *testing.T) {
type TestCasePair struct {
command []byte
expected []byte
}
//this test function will start tests for various commands, one client at a time
func TestSerial(t *testing.T) {
go main()
conn, err := net.Dial("tcp", "localhost:5000")
if err != nil {
t.Errorf("Error connecting to server")
} else {
time.Sleep(time.Second*2)
conn.Write([]byte("set xyz 200 10\r\n"))
time.Sleep(time.Millisecond)
conn.Write([]byte("abcd\r\n"))
buffer := make([]byte, 1024)
conn.Read(buffer)
msg := string(buffer)
if msg == ERR_CMD_ERR+"\r\n" {
t.Errorf("Expected OK <version>")
}
}
//give some time for server to initialize
time.Sleep(time.Second)
testSetCommand(t)
}
func TestGet(t *testing.T) {
//go main()
conn, err := net.Dial("tcp", "localhost:5000")
func testSetCommand(t *testing.T) {
testSetReplyExpected(t)
}
func testSetReplyExpected(t *testing.T) {
conn, err := net.Dial("tcp", ":5000")
defer conn.Close()
time.Sleep(time.Millisecond)
if err != nil {
t.Errorf("Error connecting to server")
} else {
time.Sleep(time.Second)
conn.Write([]byte("set xyz 200 10\r\n"))
time.Sleep(time.Millisecond)
conn.Write([]byte("abcdefg\r\n"))
buffer := make([]byte, 1024)
conn.Read(buffer)
msg := string(buffer)
if msg == ERR_CMD_ERR+"\r\n" {
t.Errorf("Expected OK <version>")
}
t.Errorf("Connection Error")
}
conn.Write([]byte("get xyz\r\n"))
time.Sleep(time.Millisecond)
conn.Read(buffer)
msg = string(buffer)
if msg == ERR_CMD_ERR+"\r\n" {
t.Errorf("Expected key value")
}
cases := []TestCasePair{
{[]byte("set a 200 10\r\n1234567890\r\n"), []byte("OK 2\r\n")}, //single length
{[]byte("set b 200 10\r\n12345\r\n890\r\n"), []byte("OK 3\r\n")}, //\r\n in middle
{[]byte("set c 0 10\r\n12345\r\n890\r\n"), []byte("OK 4\r\n")}, //perpetual key
{[]byte("set \n 200 10\r\n1234567890\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //newline key (error)
{[]byte("set d 200 10\r\n12345678901\r\n"), []byte("ERR_CMD_ERR\r\n")}, //value length greater (error)
{[]byte("set e 200 10\r\n1234\r6789\r\n"), []byte("ERR_CMD_ERR\r\n")}, //value length less (error)
//key length high (250 bytes)
{[]byte("set 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 200 10\r\n1234\r67890\r\n"), []byte("OK 5\r\n")},
//key length high (251 bytes), error
{[]byte("set 12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901 200 10\r\n1234\r67890\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")},
{[]byte("set f -1 10\r\n1234\r6\r89\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //invalid expiry
{[]byte("set f 200 0\r\n1234\r6\r89\r\n"), []byte("ERR_CMD_ERR\r\nERR_CMD_ERR\r\n")}, //invalid value size
}
conn.Write([]byte("get tuv\r\n"))
time.Sleep(time.Millisecond)
buffer = make([]byte, 1024)
conn.Read(buffer)
n := bytes.Index(buffer, []byte{0})
msg = string(buffer[:n])
if msg != ERR_NOT_FOUND+"\r\n" {
t.Errorf("Expected key value")
for i, e := range cases {
buf := make([]byte, 2048)
conn.Write(e.command)
n, _ := conn.Read(buf)
if !bytes.Equal(buf[:n], e.expected) {
fmt.Println(buf[:n], e.expected, string(buf[:n]), string(e.expected))
t.Errorf("Error occured for case:" + strconv.Itoa(i))
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment