Skip to content

Commit

Permalink
Add ZINTERSTORE command
Browse files Browse the repository at this point in the history
  • Loading branch information
seppo0010 committed Jul 13, 2015
1 parent 9e8a859 commit a2104b2
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 5 deletions.
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
- [x] zremrangebyrank
- [ ] zremrangebylex
- [x] zunionstore
- [ ] zinterstore
- [x] zinterstore
- [x] zrange
- [x] zrangebyscore
- [x] zrevrangebyscore
Expand Down
52 changes: 50 additions & 2 deletions command/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ fn zrevrank(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
generic_zrank(db, dbindex, &key, member, true)
}

fn zunionstore(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
fn zinter_union_store(parser: &Parser, db: &mut Database, dbindex: usize, union: bool) -> Response {
validate!(parser.argv.len() >= 4, "Wrong number of parameters");
let key = try_validate!(parser.get_vec(1), "Invalid key");
let value = {
Expand Down Expand Up @@ -1324,7 +1324,7 @@ fn zunionstore(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
};
validate!(pos == parser.argv.len(), "Syntax error");
let n = Value::Nil;
match n.zunion(&zsets, weights, aggregate) {
match if union { n.zunion(&zsets, weights, aggregate) } else { n.zinter(&zsets, weights, aggregate) } {
Ok(v) => v,
Err(err) => return Response::Error(err.to_string()),
}
Expand All @@ -1338,6 +1338,14 @@ fn zunionstore(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
r
}

fn zunionstore(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
zinter_union_store(parser, db, dbindex, true)
}

fn zinterstore(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
zinter_union_store(parser, db, dbindex, false)
}

fn ping(parser: &Parser, db: &mut Database, dbindex: usize) -> Response {
#![allow(unused_variables)]
validate!(parser.argv.len() <= 2, "Wrong number of parameters");
Expand Down Expand Up @@ -1558,6 +1566,7 @@ pub fn command(
"zrank" => zrank(parser, db, dbindex),
"zrevrank" => zrevrank(parser, db, dbindex),
"zunionstore" => zunionstore(parser, db, dbindex),
"zinterstore" => zinterstore(parser, db, dbindex),
"dump" => dump(parser, db, dbindex),
"subscribe" => return subscribe(parser, db, subscriptions.unwrap(), pattern_subscriptions.unwrap().len(), sender.unwrap()),
"unsubscribe" => return unsubscribe(parser, db, subscriptions.unwrap(), pattern_subscriptions.unwrap().len(), sender.unwrap()),
Expand Down Expand Up @@ -2628,6 +2637,45 @@ mod test_command {
assert_eq!(command(&parser!(b"zscore key c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Data(b"6".to_vec()));
}

#[test]
fn zinterstore_command_short() {
let mut db = Database::new(Config::new(Logger::new(Level::Warning)));
assert_eq!(command(&parser!(b"zadd key1 1 a 2 b 3 c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zadd key2 3 c 4 d 5 e"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zinterstore key 3 key1 key2 key3"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
}

#[test]
fn zinterstore_command_weights() {
let mut db = Database::new(Config::new(Logger::new(Level::Warning)));
assert_eq!(command(&parser!(b"zadd key1 1 a 2 b 3 c 4 d"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(4));
assert_eq!(command(&parser!(b"zadd key2 4 d 5 e"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(2));
assert_eq!(command(&parser!(b"zadd key3 0 d"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
assert_eq!(command(&parser!(b"zinterstore key 3 key1 key2 key3 Weights 1 2 3"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
assert_eq!(command(&parser!(b"zscore key d"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Data(b"12".to_vec()));
}

#[test]
fn zinterstore_command_aggregate() {
let mut db = Database::new(Config::new(Logger::new(Level::Warning)));
assert_eq!(command(&parser!(b"zadd key1 1 a 2 b 3 c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zadd key2 9 c 4 d 5 e"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zinterstore key 3 key1 key2 key3 aggregate max"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
assert_eq!(command(&parser!(b"zscore key c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Data(b"9".to_vec()));
assert_eq!(command(&parser!(b"zinterstore key 3 key1 key2 key3 aggregate min"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
assert_eq!(command(&parser!(b"zscore key c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Data(b"3".to_vec()));
}

#[test]
fn zinterstore_command_weights_aggregate() {
let mut db = Database::new(Config::new(Logger::new(Level::Warning)));
assert_eq!(command(&parser!(b"zadd key1 1 a 2 b 3 c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zadd key2 3 c 4 d 5 e"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(3));
assert_eq!(command(&parser!(b"zinterstore key 3 key1 key2 key3 weights 1 2 3 aggregate max"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Integer(1));
assert_eq!(command(&parser!(b"zscore key c"), &mut db, &mut 0, &mut true, None, None, None).unwrap(), Response::Data(b"6".to_vec()));
}


#[test]
fn select_command() {
let mut db = Database::new(Config::new(Logger::new(Level::Warning)));
Expand Down
131 changes: 129 additions & 2 deletions database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1239,8 +1239,7 @@ impl Value {
}
}

/// Add multiple sorted sets and store the resulting sorted set.
/// Returns the number of resulting items.
/// Creates a sorted set by merging existing sorted sets.
///
/// # Examples
/// ```
Expand Down Expand Up @@ -1292,6 +1291,54 @@ impl Value {
Ok(Value::SortedSet(value))
}

/// Creates a new sorted set with the intersection of existing sorted sets.
///
/// # Examples
/// ```
/// use database::Value;
/// use database::zset;
///
/// let mut val1 = Value::Nil;
/// val1.zadd(1.1, vec![1], false, false, false, false).unwrap();
/// let mut val2 = Value::Nil;
/// val2.zadd(1.2, vec![1], false, false, false, false).unwrap();
/// val2.zadd(2.2, vec![2], false, false, false, false).unwrap();
/// let mut val3 = Value::Nil;
/// val3.zadd(1.3, vec![1], false, false, false, false).unwrap();
/// val3.zadd(3.3, vec![3], false, false, false, false).unwrap();
/// let mut val4 = Value::Nil;
/// val4 = val4.zinter(&vec![&val1, &val2, &val3], None, zset::Aggregate::Sum).unwrap();
/// assert_eq!(val4.zcard().unwrap(), 1);
/// assert!(((val4.zscore(vec![1]).unwrap().unwrap() - 3.6)).abs() < 0.01);
/// ```
///
/// # Examples
/// ```
/// use database::Value;
/// use database::zset;
///
/// let mut val1 = Value::Nil;
/// val1.zadd(1.1, vec![1], false, false, false, false).unwrap();
/// let mut val2 = Value::Nil;
/// val2.zadd(1.2, vec![1], false, false, false, false).unwrap();
/// val2.zadd(2.2, vec![2], false, false, false, false).unwrap();
/// let mut val3 = Value::Nil;
/// val3.zadd(1.3, vec![1], false, false, false, false).unwrap();
/// val3.zadd(3.3, vec![3], false, false, false, false).unwrap();
/// let mut val4 = Value::Nil;
/// val4 = val4.zinter(&vec![&val1, &val2, &val3], Some(vec![1.0, 2.0, 3.0]), zset::Aggregate::Min).unwrap();
/// assert_eq!(val4.zcard().unwrap(), 1);
/// assert!(((val4.zscore(vec![1]).unwrap().unwrap() - 1.1)).abs() < 0.01);
/// ```
pub fn zinter(&self, zset_values: &Vec<&Value>, weights: Option<Vec<f64>>, aggregate: zset::Aggregate) -> Result<Value, OperationError> {
let emptyzset = ValueSortedSet::new();
let zsets = try!(get_zset_list(zset_values, &emptyzset));

let mut value = ValueSortedSet::new();
value.zinter(zsets, weights, aggregate);
Ok(Value::SortedSet(value))
}

/// Serializes and writes into `writer` the object current value.
/// The serialized version also includes the type, the version and a crc.
///
Expand Down Expand Up @@ -2775,6 +2822,86 @@ mod test_command {
assert!((value3.zscore(v3.clone()).unwrap().unwrap() - 640.0).abs() < 0.01);
}

#[test]
fn zinterstore_sum() {
let mut value1 = Value::Nil;
let mut value2 = Value::Nil;
let mut value3 = Value::Nil;
let v1 = vec![1, 2, 3, 4];
let v2 = vec![5, 6, 7, 8];
let v3 = vec![0, 9, 1, 2];

assert_eq!(value1.zadd(1.1, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value1.zadd(2.1, v2.clone(), false, false, false, false).unwrap(), true);

assert_eq!(value2.zadd(1.2, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value2.zadd(3.2, v3.clone(), false, false, false, false).unwrap(), true);

value3 = value3.zinter(&vec![&value1, &value2], None, zset::Aggregate::Sum).unwrap();
assert_eq!(value3.zcard().unwrap(), 1);
assert!((value3.zscore(v1.clone()).unwrap().unwrap() - 2.3).abs() < 0.01);
}

#[test]
fn zinterstore_min() {
let mut value1 = Value::Nil;
let mut value2 = Value::Nil;
let mut value3 = Value::Nil;
let v1 = vec![1, 2, 3, 4];
let v2 = vec![5, 6, 7, 8];
let v3 = vec![0, 9, 1, 2];

assert_eq!(value1.zadd(1.1, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value1.zadd(2.1, v2.clone(), false, false, false, false).unwrap(), true);

assert_eq!(value2.zadd(1.2, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value2.zadd(3.2, v3.clone(), false, false, false, false).unwrap(), true);

value3 = value3.zinter(&vec![&value1, &value2], None, zset::Aggregate::Min).unwrap();
assert_eq!(value3.zcard().unwrap(), 1);
assert!((value3.zscore(v1.clone()).unwrap().unwrap() - 1.1).abs() < 0.01);
}

#[test]
fn zinterstore_max() {
let mut value1 = Value::Nil;
let mut value2 = Value::Nil;
let mut value3 = Value::Nil;
let v1 = vec![1, 2, 3, 4];
let v2 = vec![5, 6, 7, 8];
let v3 = vec![0, 9, 1, 2];

assert_eq!(value1.zadd(1.1, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value1.zadd(2.1, v2.clone(), false, false, false, false).unwrap(), true);

assert_eq!(value2.zadd(1.2, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value2.zadd(3.2, v3.clone(), false, false, false, false).unwrap(), true);

value3 = value3.zinter(&vec![&value1, &value2], None, zset::Aggregate::Max).unwrap();
assert_eq!(value3.zcard().unwrap(), 1);
assert!((value3.zscore(v1.clone()).unwrap().unwrap() - 1.2).abs() < 0.01);
}

#[test]
fn zinterstore_weights() {
let mut value1 = Value::Nil;
let mut value2 = Value::Nil;
let mut value3 = Value::Nil;
let v1 = vec![1, 2, 3, 4];
let v2 = vec![5, 6, 7, 8];
let v3 = vec![0, 9, 1, 2];

assert_eq!(value1.zadd(1.1, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value1.zadd(2.1, v2.clone(), false, false, false, false).unwrap(), true);

assert_eq!(value2.zadd(1.2, v1.clone(), false, false, false, false).unwrap(), true);
assert_eq!(value2.zadd(3.2, v3.clone(), false, false, false, false).unwrap(), true);

value3 = value3.zinter(&vec![&value1, &value2], Some(vec![100.0, 200.0]), zset::Aggregate::Max).unwrap();
assert_eq!(value3.zcard().unwrap(), 1);
assert!((value3.zscore(v1.clone()).unwrap().unwrap() - 240.0).abs() < 0.01);
}

#[test]
fn pubsub_basic() {
let config = Config::new(Logger::new(Level::Warning));
Expand Down
46 changes: 46 additions & 0 deletions database/src/zset.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cmp::Ordering;
use std::collections::Bound;
use std::collections::HashMap;
use std::collections::HashSet;
use std::io;
use std::io::Write;
use std::f64::{INFINITY, NEG_INFINITY, NAN};
Expand Down Expand Up @@ -397,6 +398,51 @@ impl ValueSortedSet {
}
}

pub fn zinter(&mut self, zsets: Vec<&ValueSortedSet>, weights: Option<Vec<f64>>, aggregate: Aggregate) {
if zsets.len() == 0 {
return;
}
let mut intersected_keys = {
match *zsets[0] {
ValueSortedSet::Data(_, ref hm) => hm.keys().collect::<HashSet<_>>(),
}
};

for i in 1..zsets.len() {
let zset = zsets[i];
let keys = match *zset {
ValueSortedSet::Data(_, ref hm) => hm.keys().collect::<HashSet<_>>(),
};
intersected_keys = intersected_keys.intersection(&keys).cloned().collect::<HashSet<_>>();
}

for k in intersected_keys {
let hm = match *zsets[0] {
ValueSortedSet::Data(_, ref hm) => hm,
};
let mut score = hm.get(k).unwrap() * (match weights {
Some(ref ws) => ws[0],
None => 1.0,
});
for i in 1..zsets.len() {
let hm = match *zsets[i] {
ValueSortedSet::Data(_, ref hm) => hm,
};
let s2 = hm.get(k).unwrap() * (match weights {
Some(ref ws) => ws[i],
None => 1.0,
});
match aggregate {
Aggregate::Sum => score += s2,
Aggregate::Min => if score > s2 { score = s2; },
Aggregate::Max => if score < s2 { score = s2; },
}
}

self.zadd(score, k.clone(), false, false, false, false);
}
}

pub fn dump<T: Write>(&self, writer: &mut T) -> io::Result<usize> {
let mut v = vec![];
let settype;
Expand Down

0 comments on commit a2104b2

Please sign in to comment.