Commit 01094c3f authored by Sushant Mahajan's avatar Sushant Mahajan

completed code for cas functionality

parent da13f06b
......@@ -30,8 +30,9 @@ const (
// DELETED = "DELETED"
//errors
ERR_CMD_ERR = "ERR_CMD_ERR"
ERR_CMD_ERR = "ERR_CMD_ERR"
ERR_NOT_FOUND = "ERR_NOT_FOUND"
ERR_VERSION = "ERR_VERSION"
//logging
LOG = true
......@@ -229,7 +230,7 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
buffer.WriteString(" ")
buffer.WriteString(strconv.FormatUint(data.version, 10))
buffer.WriteString(" ")
buffer.WriteString(strconv.FormatUint(data.expTime - uint64(time.Now().Unix()), 10))
buffer.WriteString(strconv.FormatUint(data.expTime-uint64(time.Now().Unix()), 10))
buffer.WriteString(" ")
buffer.WriteString(strconv.FormatUint(data.numBytes, 10))
write(conn, buffer.String())
......@@ -242,7 +243,35 @@ func parseInput(conn net.Conn, msg string, table *KeyValueStore) {
write(conn, buffer.String())
}
//case CAS: performCas(tokens[1:len(tokens)])
case CAS:
if isValid(CAS, tokens, conn) != 0 {
return
}
if ver, ok, r := performCas(conn, tokens[1:len(tokens)], table); r {
if r {
switch ok {
case 0:
buffer.Reset()
buffer.WriteString(OK)
buffer.WriteString(" ")
buffer.WriteString(strconv.FormatUint(ver, 10))
logger.Println(buffer.String())
write(conn, buffer.String())
case 1:
buffer.Reset()
buffer.WriteString(ERR_CMD_ERR)
write(conn, buffer.String())
case 2:
buffer.Reset()
buffer.WriteString(ERR_VERSION)
write(conn, buffer.String())
case 3:
buffer.Reset()
buffer.WriteString(ERR_NOT_FOUND)
write(conn, buffer.String())
}
}
}
//case DELETE: performDelete(tokens[1:len(tokens)])
default:
logger.Println("Command not found")
......@@ -265,7 +294,7 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, b
v, ok := read(conn, n+2)
if !ok {
//error here
return 0, false, false
return 0, false, r
}
table.Lock()
......@@ -282,7 +311,7 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, b
val.numBytes = n
val.version = ver
if e == 0 {
val.expTime = e
val.expTime = e
} else {
val.expTime = e + uint64(time.Now().Unix())
}
......@@ -296,12 +325,12 @@ func performSet(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, b
func performGet(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, bool) {
k := tokens[0]
defer table.RUnlock()
table.RUnlock()
table.RLock()
//critical section begin
if v, ok := table.dictionary[k]; ok {
if v.expTime != 0 && v.expTime < uint64(time.Now().Unix()) {
table.RUnlock()
return nil, false
}
data := new(Data)
......@@ -327,12 +356,52 @@ func performGetm(conn net.Conn, tokens []string, table *KeyValueStore) (*Data, b
data.expTime = v.expTime
data.numBytes = v.numBytes
data.value = v.value[:]
return data, true
} else {
return nil, false
}
}
func performCas(conn net.Conn, tokens []string, table *KeyValueStore) (uint64, int, bool) {
k := tokens[0]
e, _ := strconv.ParseUint(tokens[1], 10, 64)
ve, _ := strconv.ParseUint(tokens[2], 10, 64)
n, _ := strconv.ParseUint(tokens[3], 10, 64)
r := true
logger.Println(k, e, ve, n, r)
if len(tokens) == 5 && tokens[4] == NOREPLY {
r = false
}
//read value
v, ok := read(conn, n+2)
if !ok {
//error here
return 0, 1, r //malformed
}
defer table.Unlock()
table.Lock()
if val, ok := table.dictionary[k]; ok {
if val.version == ve {
if e == 0 {
val.expTime = e
} else {
val.expTime = e + uint64(time.Now().Unix())
}
val.numBytes = n
ver++
val.version = ver
val.value = v
return val.version, 0, r //key found and changed
}
return 0, 2, r //version mismatch
}
return 0, 3, r //key not found
}
func debug(table *KeyValueStore) {
logger.Println("----start debug----")
for key, val := range (*table).dictionary {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment