Skip to content

Commit

Permalink
Require Source methods be implemented in test (rust-lang#406)
Browse files Browse the repository at this point in the history
require Source methods be implemented in test

This just makes sure we don't forget to forward Source methods
on types that wrap other sources
  • Loading branch information
KodrAus authored Aug 4, 2020
1 parent 803a23b commit 469a441
Showing 1 changed file with 67 additions and 27 deletions.
94 changes: 67 additions & 27 deletions src/kv/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,14 @@ pub trait Source {
///
/// A source that can provide a more efficient implementation of this method
/// should override it.
#[cfg(not(test))]
fn get<'v>(&'v self, key: Key) -> Option<Value<'v>> {
struct Get<'k, 'v> {
key: Key<'k>,
found: Option<Value<'v>>,
}

impl<'k, 'kvs> Visitor<'kvs> for Get<'k, 'kvs> {
fn visit_pair(&mut self, key: Key<'kvs>, value: Value<'kvs>) -> Result<(), Error> {
if self.key == key {
self.found = Some(value);
}

Ok(())
}
}

let mut get = Get { key, found: None };

let _ = self.visit(&mut get);
get.found
get_default(self, key)
}

#[cfg(test)]
fn get<'v>(&'v self, key: Key) -> Option<Value<'v>>;

/// Count the number of key-value pairs that can be visited.
///
/// # Implementation notes
Expand All @@ -61,21 +47,53 @@ pub trait Source {
///
/// A subsequent call to `visit` should yield the same number of key-value pairs
/// to the visitor, unless that visitor fails part way through.
#[cfg(not(test))]
fn count(&self) -> usize {
struct Count(usize);
count_default(self)
}

impl<'kvs> Visitor<'kvs> for Count {
fn visit_pair(&mut self, _: Key<'kvs>, _: Value<'kvs>) -> Result<(), Error> {
self.0 += 1;
#[cfg(test)]
fn count(&self) -> usize;
}

Ok(())
/// The default implemention of `Source::get`
pub(crate) fn get_default<'v>(source: &'v (impl Source + ?Sized), key: Key) -> Option<Value<'v>> {
struct Get<'k, 'v> {
key: Key<'k>,
found: Option<Value<'v>>,
}

impl<'k, 'kvs> Visitor<'kvs> for Get<'k, 'kvs> {
fn visit_pair(&mut self, key: Key<'kvs>, value: Value<'kvs>) -> Result<(), Error> {
if self.key == key {
self.found = Some(value);
}

Ok(())
}
}

let mut get = Get { key, found: None };

let _ = source.visit(&mut get);
get.found
}

/// The default implementation of `Source::count`.
pub(crate) fn count_default(source: impl Source) -> usize {
struct Count(usize);

impl<'kvs> Visitor<'kvs> for Count {
fn visit_pair(&mut self, _: Key<'kvs>, _: Value<'kvs>) -> Result<(), Error> {
self.0 += 1;

let mut count = Count(0);
let _ = self.visit(&mut count);
count.0
Ok(())
}
}

let mut count = Count(0);
let _ = source.visit(&mut count);
count.0
}

impl<'a, T> Source for &'a T
Expand Down Expand Up @@ -129,6 +147,16 @@ where
Ok(())
}

fn get<'v>(&'v self, key: Key) -> Option<Value<'v>> {
for source in self {
if let Some(found) = source.get(key.clone()) {
return Some(found);
}
}

None
}

fn count(&self) -> usize {
self.len()
}
Expand All @@ -146,6 +174,10 @@ where
Ok(())
}

fn get<'v>(&'v self, key: Key) -> Option<Value<'v>> {
self.as_ref().and_then(|s| s.get(key))
}

fn count(&self) -> usize {
self.as_ref().map(Source::count).unwrap_or(0)
}
Expand Down Expand Up @@ -366,6 +398,14 @@ mod tests {
fn visit<'kvs>(&'kvs self, visitor: &mut dyn Visitor<'kvs>) -> Result<(), Error> {
visitor.visit_pair(self.key.to_key(), self.value.to_value())
}

fn get<'v>(&'v self, key: Key) -> Option<Value<'v>> {
get_default(self, key)
}

fn count(&self) -> usize {
count_default(self)
}
}

assert_eq!(1, Source::count(&("a", 1)));
Expand Down

0 comments on commit 469a441

Please sign in to comment.