Skip to content

Commit

Permalink
refactor: allow the same key for different data types
Browse files Browse the repository at this point in the history
Because Redis uses a global key namespace, it's possible to
create a key of one type (e.g., a string) and then try to
work with is as if it were of another type (e.g., a hash).
I call this a "key type mismatch" situation.

Redis' handling of key type mismatches is a mess. Sometimes
it allows them (SET), sometimes it ignores them (MGET),
sometimes it forbids them (HSET).

Starting with this commit, Redka takes a more consistent
approach. Now you can use the same key for different data
types, and everything will work fine. With a small caveat:
the `type` column in the `rkey` table will store the last
modified type (it doesn't affect any operations though).

Having said that. Please don't use the same key for
different data types. It's a VERY bad idea.
  • Loading branch information
nalgeon committed Apr 28, 2024
1 parent 2a0dfa9 commit c6bdaf0
Show file tree
Hide file tree
Showing 24 changed files with 141 additions and 144 deletions.
3 changes: 0 additions & 3 deletions internal/command/a.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ var (
ErrInvalidExpireTime = errors.New("ERR invalid expire time")
ErrInvalidFloat = errors.New("ERR value is not a float")
ErrInvalidInt = errors.New("ERR value is not an integer")
ErrKeyType = errors.New("WRONGTYPE Operation against a key holding the wrong kind of value")
ErrNestedMulti = errors.New("ERR MULTI calls can not be nested")
ErrNotFound = errors.New("ERR no such key")
ErrNotInMulti = errors.New("ERR EXEC without MULTI")
Expand Down Expand Up @@ -200,8 +199,6 @@ func (cmd baseCmd) Error(err error) string {
switch err {
case core.ErrNotFound:
err = ErrNotFound
case core.ErrKeyType:
err = ErrKeyType
}
return fmt.Sprintf("%s (%s)", err, cmd.Name())
}
Expand Down
7 changes: 3 additions & 4 deletions internal/command/zadd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package command
import (
"testing"

"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/testx"
)

Expand Down Expand Up @@ -145,8 +144,8 @@ func TestZAddExec(t *testing.T) {
cmd := mustParse[*ZAdd]("zadd key 11 one")
conn := new(fakeConn)
res, err := cmd.Run(conn, red)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, res, nil)
testx.AssertEqual(t, conn.out(), ErrKeyType.Error()+" (zadd)")
testx.AssertNoErr(t, err)
testx.AssertEqual(t, res, 1)
testx.AssertEqual(t, conn.out(), "1")
})
}
7 changes: 3 additions & 4 deletions internal/command/zincrby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package command
import (
"testing"

"github.com/nalgeon/redka/internal/core"
"github.com/nalgeon/redka/internal/testx"
)

Expand Down Expand Up @@ -122,8 +121,8 @@ func TestZIncrByExec(t *testing.T) {
cmd := mustParse[*ZIncrBy]("zincrby key 25.5 one")
conn := new(fakeConn)
res, err := cmd.Run(conn, red)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, res, nil)
testx.AssertEqual(t, conn.out(), ErrKeyType.Error()+" (zincrby)")
testx.AssertNoErr(t, err)
testx.AssertEqual(t, res, 25.5)
testx.AssertEqual(t, conn.out(), "25.5")
})
}
10 changes: 5 additions & 5 deletions internal/command/zinterstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,11 @@ func TestZInterStoreExec(t *testing.T) {
cmd := mustParse[*ZInterStore]("zinterstore dest 1 key")
conn := new(fakeConn)
res, err := cmd.Run(conn, red)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, res, nil)
testx.AssertEqual(t, conn.out(), ErrKeyType.Error()+" (zinterstore)")
testx.AssertNoErr(t, err)
testx.AssertEqual(t, res, 1)
testx.AssertEqual(t, conn.out(), "1")

dest, _ := db.Str().Get("dest")
testx.AssertEqual(t, dest.String(), "value")
count, _ := db.ZSet().Len("dest")
testx.AssertEqual(t, count, 1)
})
}
10 changes: 5 additions & 5 deletions internal/command/zunionstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ func TestZUnionStoreExec(t *testing.T) {
cmd := mustParse[*ZUnionStore]("zunionstore dest 1 key")
conn := new(fakeConn)
res, err := cmd.Run(conn, red)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, res, nil)
testx.AssertEqual(t, conn.out(), ErrKeyType.Error()+" (zunionstore)")
testx.AssertNoErr(t, err)
testx.AssertEqual(t, res, 1)
testx.AssertEqual(t, conn.out(), "1")

dest, _ := db.Str().Get("dest")
testx.AssertEqual(t, dest.String(), "value")
count, _ := db.ZSet().Len("dest")
testx.AssertEqual(t, count, 1)
})
}
1 change: 0 additions & 1 deletion internal/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ const InitialVersion = 1
// Common errors returned by data structure methods.
var (
ErrNotFound = errors.New("key not found")
ErrKeyType = errors.New("key type mismatch") // the key already exists with a different type.
ErrValueType = errors.New("invalid value type")
ErrNotAllowed = errors.New("operation not allowed")
)
Expand Down
5 changes: 0 additions & 5 deletions internal/rhash/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func (d *DB) GetMany(key string, fields ...string) (map[string]core.Value, error
// If the field does not exist, sets it to 0 before the increment.
// If the field value is not an integer, returns ErrValueType.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (d *DB) Incr(key, field string, delta int) (int, error) {
var val int
err := d.Update(func(tx *Tx) error {
Expand All @@ -90,7 +89,6 @@ func (d *DB) Incr(key, field string, delta int) (int, error) {
// If the field does not exist, sets it to 0 before the increment.
// If the field value is not a float, returns ErrValueType.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (d *DB) IncrFloat(key, field string, delta float64) (float64, error) {
var val float64
err := d.Update(func(tx *Tx) error {
Expand Down Expand Up @@ -139,7 +137,6 @@ func (d *DB) Scanner(key, pattern string, pageSize int) *Scanner {
// Set creates or updates the value of a field in a hash.
// Returns true if the field was created, false if it was updated.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (d *DB) Set(key, field string, value any) (bool, error) {
var created bool
err := d.Update(func(tx *Tx) error {
Expand All @@ -153,7 +150,6 @@ func (d *DB) Set(key, field string, value any) (bool, error) {
// SetMany creates or updates the values of multiple fields in a hash.
// Returns the number of fields created (as opposed to updated).
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (d *DB) SetMany(key string, items map[string]any) (int, error) {
var count int
err := d.Update(func(tx *Tx) error {
Expand All @@ -167,7 +163,6 @@ func (d *DB) SetMany(key string, items map[string]any) (int, error) {
// SetNotExists creates the value of a field in a hash if it does not exist.
// Returns true if the field was created, false if it already exists.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (d *DB) SetNotExists(key, field string, value any) (bool, error) {
var created bool
err := d.Update(func(tx *Tx) error {
Expand Down
38 changes: 28 additions & 10 deletions internal/rhash/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ func TestIncr(t *testing.T) {
defer red.Close()
_ = red.Str().Set("person", "alice")
val, err := db.Incr("person", "age", 25)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, val, 0)
testx.AssertNoErr(t, err)
testx.AssertEqual(t, val, 25)
})
}

Expand Down Expand Up @@ -323,8 +323,8 @@ func TestIncrFloat(t *testing.T) {
defer red.Close()
_ = red.Str().Set("person", "alice")
val, err := db.IncrFloat("person", "age", 25.0)
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, val, 0.0)
testx.AssertNoErr(t, err)
testx.AssertEqual(t, val, 25.0)
})
}

Expand Down Expand Up @@ -576,8 +576,14 @@ func TestSet(t *testing.T) {
defer red.Close()
_ = red.Str().Set("person", "alice")
ok, err := db.Set("person", "name", "alice")
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, ok, false)
testx.AssertNoErr(t, err)
testx.AssertEqual(t, ok, true)

val, _ := db.Get("person", "name")
testx.AssertEqual(t, val.String(), "alice")

sval, _ := red.Str().Get("person")
testx.AssertEqual(t, sval.String(), "alice")
})
}

Expand Down Expand Up @@ -639,11 +645,20 @@ func TestSetMany(t *testing.T) {
red, db := getDB(t)
defer red.Close()
_ = red.Str().Set("person", "alice")

count, err := db.SetMany("person", map[string]any{
"name": "alice", "age": 25,
})
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, count, 0)
testx.AssertNoErr(t, err)
testx.AssertEqual(t, count, 2)

name, _ := db.Get("person", "name")
testx.AssertEqual(t, name.String(), "alice")
age, _ := db.Get("person", "age")
testx.AssertEqual(t, age.String(), "25")

sval, _ := red.Str().Get("person")
testx.AssertEqual(t, sval.String(), "alice")
})
}

Expand Down Expand Up @@ -683,10 +698,13 @@ func TestSetNotExists(t *testing.T) {
t.Run("key type mismatch", func(t *testing.T) {
red, db := getDB(t)
defer red.Close()

_ = red.Str().Set("person", "alice")
ok, err := db.SetNotExists("person", "name", "alice")
testx.AssertErr(t, err, core.ErrKeyType)
testx.AssertEqual(t, ok, false)
testx.AssertNoErr(t, err)
testx.AssertEqual(t, ok, true)
val, _ := db.Get("person", "name")
testx.AssertEqual(t, val.String(), "alice")
})
}

Expand Down
5 changes: 0 additions & 5 deletions internal/rhash/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ func (tx *Tx) GetMany(key string, fields ...string) (map[string]core.Value, erro
// If the field does not exist, sets it to 0 before the increment.
// If the field value is not an integer, returns ErrValueType.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (tx *Tx) Incr(key, field string, delta int) (int, error) {
// get the current value
val, err := tx.Get(key, field)
Expand Down Expand Up @@ -224,7 +223,6 @@ func (tx *Tx) Incr(key, field string, delta int) (int, error) {
// If the field does not exist, sets it to 0 before the increment.
// If the field value is not a float, returns ErrValueType.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (tx *Tx) IncrFloat(key, field string, delta float64) (float64, error) {
// get the current value
val, err := tx.Get(key, field)
Expand Down Expand Up @@ -337,7 +335,6 @@ func (tx *Tx) Scanner(key, pattern string, pageSize int) *Scanner {
// Set creates or updates the value of a field in a hash.
// Returns true if the field was created, false if it was updated.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (tx *Tx) Set(key string, field string, value any) (bool, error) {
if !core.IsValueType(value) {
return false, core.ErrValueType
Expand All @@ -356,7 +353,6 @@ func (tx *Tx) Set(key string, field string, value any) (bool, error) {
// SetMany creates or updates the values of multiple fields in a hash.
// Returns the number of fields created (as opposed to updated).
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (tx *Tx) SetMany(key string, items map[string]any) (int, error) {
for _, val := range items {
if !core.IsValueType(val) {
Expand Down Expand Up @@ -388,7 +384,6 @@ func (tx *Tx) SetMany(key string, items map[string]any) (int, error) {
// SetNotExists creates the value of a field in a hash if it does not exist.
// Returns true if the field was created, false if it already exists.
// If the key does not exist, creates it.
// If the key exists but is not a hash, returns ErrKeyType.
func (tx *Tx) SetNotExists(key, field string, value any) (bool, error) {
if !core.IsValueType(value) {
return false, core.ErrValueType
Expand Down
1 change: 0 additions & 1 deletion internal/rkey/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ func (db *DB) Random() (core.Key, error) {
// Rename changes the key name.
// If there is an existing key with the new name, it is replaced.
// If the old key does not exist, returns ErrNotFound.
// If the new key has a different type, returns ErrKeyType.
func (db *DB) Rename(key, newKey string) error {
err := db.Update(func(tx *Tx) error {
err := tx.Rename(key, newKey)
Expand Down
4 changes: 2 additions & 2 deletions internal/rkey/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@ func TestRename(t *testing.T) {
_, _ = red.Hash().Set("hash", "field", "value")

err := db.Rename("str", "hash")
testx.AssertEqual(t, err, core.ErrKeyType)
testx.AssertNoErr(t, err)

exists, _ := db.Exists("str")
testx.AssertEqual(t, exists, true)
testx.AssertEqual(t, exists, false)

exists, _ = db.Exists("hash")
testx.AssertEqual(t, exists, true)
Expand Down
10 changes: 0 additions & 10 deletions internal/rkey/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ func (tx *Tx) Random() (core.Key, error) {
// Rename changes the key name.
// If there is an existing key with the new name, it is replaced.
// If the old key does not exist, returns ErrNotFound.
// If the new key has a different type, returns ErrKeyType.
func (tx *Tx) Rename(key, newKey string) error {
// Make sure the old key exists.
oldK, err := tx.Get(key)
Expand All @@ -234,15 +233,6 @@ func (tx *Tx) Rename(key, newKey string) error {
return nil
}

// Make sure the new key does not exist or has the same type.
newK, err := tx.Get(newKey)
if err != nil && err != core.ErrNotFound {
return err
}
if newK.Exists() && newK.Type != oldK.Type {
return core.ErrKeyType
}

// Rename the old key to the new key.
now := time.Now().UnixMilli()
args := []any{
Expand Down
5 changes: 0 additions & 5 deletions internal/rstring/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func (d *DB) GetMany(keys ...string) (map[string]core.Value, error) {
// Returns the value after the increment.
// If the key does not exist, sets it to 0 before the increment.
// If the key value is not an integer, returns ErrValueType.
// If the key exists but is not a string, returns ErrKeyType.
func (d *DB) Incr(key string, delta int) (int, error) {
var val int
err := d.Update(func(tx *Tx) error {
Expand All @@ -58,7 +57,6 @@ func (d *DB) Incr(key string, delta int) (int, error) {
// Returns the value after the increment.
// If the key does not exist, sets it to 0 before the increment.
// If the key value is not an float, returns ErrValueType.
// If the key exists but is not a string, returns ErrKeyType.
func (d *DB) IncrFloat(key string, delta float64) (float64, error) {
var val float64
err := d.Update(func(tx *Tx) error {
Expand All @@ -71,7 +69,6 @@ func (d *DB) IncrFloat(key string, delta float64) (float64, error) {

// Set sets the key value that will not expire.
// Overwrites the value if the key already exists.
// If the key exists but is not a string, returns ErrKeyType.
func (d *DB) Set(key string, value any) error {
err := d.Update(func(tx *Tx) error {
return tx.Set(key, value)
Expand All @@ -81,7 +78,6 @@ func (d *DB) Set(key string, value any) error {

// SetExpires sets the key value with an optional expiration time (if ttl > 0).
// Overwrites the value and ttl if the key already exists.
// If the key exists but is not a string, returns ErrKeyType.
func (d *DB) SetExpires(key string, value any, ttl time.Duration) error {
err := d.Update(func(tx *Tx) error {
return tx.SetExpires(key, value, ttl)
Expand All @@ -93,7 +89,6 @@ func (d *DB) SetExpires(key string, value any, ttl time.Duration) error {
// Overwrites values for keys that already exist and
// creates new keys/values for keys that do not exist.
// Removes the TTL for existing keys.
// If any of the keys exists but is not a string, returns ErrKeyType.
func (d *DB) SetMany(items map[string]any) error {
err := d.Update(func(tx *Tx) error {
return tx.SetMany(items)
Expand Down
Loading

0 comments on commit c6bdaf0

Please sign in to comment.