Skip to content

Commit

Permalink
#rust #protobuf Refactor maps to make ProxiedInMapValue independent o…
Browse files Browse the repository at this point in the history
…f the runtime

 - ProxiedInMapValue is defined in maps.rs, and no longer in the runtime files {upb, cpp}.rs.
 - ProxiedInMapValue's methods accept and return Proxied types.
 - InnerMapMut no longer has any generic type parameters.
 - Through this refactoring the Map type is no longer a ZST. Creating a new map is now as simple as `Map::new()`.

PiperOrigin-RevId: 597765165
  • Loading branch information
buchgr authored and copybara-github committed Jan 12, 2024
1 parent baaaca8 commit 8d9e3e9
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 434 deletions.
199 changes: 38 additions & 161 deletions rust/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
use crate::ProtoStr;
use crate::__internal::{Enum, Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField};
use crate::{
Mut, Proxied, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, SettableValue, View,
Map, Mut, ProxiedInMapValue, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView,
SettableValue, View,
};
use core::fmt::Debug;
use paste::paste;
Expand Down Expand Up @@ -327,58 +328,15 @@ pub fn cast_enum_repeated_mut<E: Enum + ProxiedInRepeated>(
}
}

#[derive(Debug)]
pub struct MapInner<'msg, K: ?Sized, V: ?Sized> {
pub raw: RawMap,
pub _phantom_key: PhantomData<&'msg mut K>,
pub _phantom_value: PhantomData<&'msg mut V>,
}

impl<'msg, K: ?Sized, V: ?Sized> Copy for MapInner<'msg, K, V> {}
impl<'msg, K: ?Sized, V: ?Sized> Clone for MapInner<'msg, K, V> {
fn clone(&self) -> MapInner<'msg, K, V> {
*self
}
}

pub trait ProxiedInMapValue<K>: Proxied
where
K: Proxied + ?Sized,
{
fn new_map() -> RawMap;
fn clear(m: RawMap);
fn size(m: RawMap) -> usize;
fn insert(m: RawMap, key: View<'_, K>, value: View<'_, Self>) -> bool;
fn get<'msg>(m: RawMap, key: View<'_, K>) -> Option<View<'msg, Self>>;
fn remove(m: RawMap, key: View<'_, K>) -> bool;
}

impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue<K> + ?Sized> Default for MapInner<'msg, K, V> {
fn default() -> Self {
MapInner { raw: V::new_map(), _phantom_key: PhantomData, _phantom_value: PhantomData }
}
#[derive(Clone, Copy, Debug)]
pub struct InnerMapMut<'msg> {
pub(crate) raw: RawMap,
_phantom: PhantomData<&'msg ()>,
}

impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue<K> + ?Sized> MapInner<'msg, K, V> {
pub fn size(&self) -> usize {
V::size(self.raw)
}

pub fn clear(&mut self) {
V::clear(self.raw)
}

pub fn get<'a>(&self, key: View<'_, K>) -> Option<View<'a, V>> {
V::get(self.raw, key)
}

pub fn remove(&mut self, key: View<'_, K>) -> bool {
V::remove(self.raw, key)
}

pub fn insert(&mut self, key: View<'_, K>, value: View<'_, V>) -> bool {
V::insert(self.raw, key, value);
true
impl<'msg> InnerMapMut<'msg> {
pub fn new(_private: Private, raw: RawMap) -> Self {
InnerMapMut { raw, _phantom: PhantomData }
}
}

Expand All @@ -387,6 +345,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
paste! { $(
extern "C" {
fn [< __pb_rust_Map_ $key_t _ $t _new >]() -> RawMap;
fn [< __pb_rust_Map_ $key_t _ $t _free >](m: RawMap);
fn [< __pb_rust_Map_ $key_t _ $t _clear >](m: RawMap);
fn [< __pb_rust_Map_ $key_t _ $t _size >](m: RawMap) -> usize;
fn [< __pb_rust_Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_t);
Expand All @@ -395,39 +354,55 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
}

impl ProxiedInMapValue<$key_t> for $t {
fn new_map() -> RawMap {
unsafe { [< __pb_rust_Map_ $key_t _ $t _new >]() }
fn map_new(_private: Private) -> Map<$key_t, Self> {
unsafe {
Map::from_inner(
Private,
InnerMapMut {
raw: [< __pb_rust_Map_ $key_t _ $t _new >](),
_phantom: PhantomData
}
)
}
}

unsafe fn map_free(_private: Private, map: &mut Map<$key_t, Self>) {
// SAFETY:
// - `map.inner.raw` is a live `RawMap`
// - This function is only called once for `map` in `Drop`.
unsafe { [< __pb_rust_Map_ $key_t _ $t _free >](map.inner.raw); }
}

fn clear(m: RawMap) {
unsafe { [< __pb_rust_Map_ $key_t _ $t _clear >](m) }

fn map_clear(map: Mut<'_, Map<$key_t, Self>>) {
unsafe { [< __pb_rust_Map_ $key_t _ $t _clear >](map.inner.raw); }
}

fn size(m: RawMap) -> usize {
unsafe { [< __pb_rust_Map_ $key_t _ $t _size >](m) }
fn map_len(map: View<'_, Map<$key_t, Self>>) -> usize {
unsafe { [< __pb_rust_Map_ $key_t _ $t _size >](map.raw) }
}

fn insert(m: RawMap, key: View<'_, $key_t>, value: View<'_, Self>) -> bool {
fn map_insert(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>, value: View<'_, Self>) -> bool {
let ffi_key = $to_ffi_key(key);
let ffi_value = $to_ffi_value(value);
unsafe { [< __pb_rust_Map_ $key_t _ $t _insert >](m, ffi_key, ffi_value) }
unsafe { [< __pb_rust_Map_ $key_t _ $t _insert >](map.inner.raw, ffi_key, ffi_value) }
true
}

fn get<'msg>(m: RawMap, key: View<'_, $key_t>) -> Option<View<'msg, Self>> {
fn map_get<'a>(map: View<'a, Map<$key_t, Self>>, key: View<'_, $key_t>) -> Option<View<'a, Self>> {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = $to_ffi_value($zero_val);
let found = unsafe { [< __pb_rust_Map_ $key_t _ $t _get >](m, ffi_key, &mut ffi_value) };
let found = unsafe { [< __pb_rust_Map_ $key_t _ $t _get >](map.raw, ffi_key, &mut ffi_value) };
if !found {
return None;
}
Some($from_ffi_value(ffi_value))
}

fn remove(m: RawMap, key: View<'_, $key_t>) -> bool {
fn map_remove(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>) -> bool {
let ffi_key = $to_ffi_key(key);
let mut ffi_value = $to_ffi_value($zero_val);
unsafe { [< __pb_rust_Map_ $key_t _ $t _remove >](m, ffi_key, &mut ffi_value) }
unsafe { [< __pb_rust_Map_ $key_t _ $t _remove >](map.inner.raw, ffi_key, &mut ffi_value) }
}
}
)* }
Expand Down Expand Up @@ -470,16 +445,6 @@ impl_ProxiedInMapValue_for_key_types!(
ProtoStr, PtrAndLen, str_to_ptrlen;
);

#[cfg(test)]
pub(crate) fn new_map_i32_i64() -> MapInner<'static, i32, i64> {
Default::default()
}

#[cfg(test)]
pub(crate) fn new_map_str_str() -> MapInner<'static, ProtoStr, ProtoStr> {
Default::default()
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -500,92 +465,4 @@ mod tests {
let serialized_data = SerializedData { data: NonNull::new(ptr).unwrap(), len };
assert_that!(&*serialized_data, eq(b"Hello world"));
}

#[test]
fn i32_i32_map() {
let mut map: MapInner<'_, i32, i32> = Default::default();
assert_that!(map.size(), eq(0));

assert_that!(map.insert(1, 2), eq(true));
assert_that!(map.get(1), eq(Some(2)));
assert_that!(map.get(3), eq(None));
assert_that!(map.size(), eq(1));

assert_that!(map.remove(1), eq(true));
assert_that!(map.size(), eq(0));
assert_that!(map.remove(1), eq(false));

assert_that!(map.insert(4, 5), eq(true));
assert_that!(map.insert(6, 7), eq(true));
map.clear();
assert_that!(map.size(), eq(0));
}

#[test]
fn i64_f64_map() {
let mut map: MapInner<'_, i64, f64> = Default::default();
assert_that!(map.size(), eq(0));

assert_that!(map.insert(1, 2.5), eq(true));
assert_that!(map.get(1), eq(Some(2.5)));
assert_that!(map.get(3), eq(None));
assert_that!(map.size(), eq(1));

assert_that!(map.remove(1), eq(true));
assert_that!(map.size(), eq(0));
assert_that!(map.remove(1), eq(false));

assert_that!(map.insert(4, 5.1), eq(true));
assert_that!(map.insert(6, 7.2), eq(true));
map.clear();
assert_that!(map.size(), eq(0));
}

#[test]
fn str_str_map() {
let mut map = MapInner::<'_, ProtoStr, ProtoStr>::default();
assert_that!(map.size(), eq(0));

map.insert("fizz".into(), "buzz".into());
assert_that!(map.size(), eq(1));
assert_that!(map.remove("fizz".into()), eq(true));
map.clear();
assert_that!(map.size(), eq(0));
}

#[test]
fn u64_str_map() {
let mut map = MapInner::<'_, u64, ProtoStr>::default();
assert_that!(map.size(), eq(0));

map.insert(1, "fizz".into());
map.insert(2, "buzz".into());
assert_that!(map.size(), eq(2));
assert_that!(map.remove(1), eq(true));
assert_that!(map.get(1), eq(None));
map.clear();
assert_that!(map.size(), eq(0));
}

#[test]
fn test_all_maps_can_be_constructed() {
macro_rules! gen_proto_values {
($key_t:ty, $($value_t:ty),*) => {
$(
let map = MapInner::<'_, $key_t, $value_t>::default();
assert_that!(map.size(), eq(0));
)*
}
}

macro_rules! gen_proto_keys {
($($key_t:ty),*) => {
$(
gen_proto_values!($key_t, f32, f64, i32, u32, i64, bool, ProtoStr);
)*
}
}

gen_proto_keys!(i32, u32, i64, u64, bool, ProtoStr);
}
}
4 changes: 4 additions & 0 deletions rust/cpp_kernel/cpp_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ expose_repeated_field_methods(int64_t, i64);
__pb_rust_Map_##rust_key_ty##_##rust_value_ty##_new() { \
return new google::protobuf::Map<key_ty, value_ty>(); \
} \
void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_free( \
google::protobuf::Map<key_ty, value_ty>* m) { \
delete m; \
} \
void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_clear( \
google::protobuf::Map<key_ty, value_ty>* m) { \
m->clear(); \
Expand Down
Loading

0 comments on commit 8d9e3e9

Please sign in to comment.