diff --git a/rust/cpp.rs b/rust/cpp.rs index ba6d66675e012..4ffbd4b8f39c0 100644 --- a/rust/cpp.rs +++ b/rust/cpp.rs @@ -10,7 +10,8 @@ use crate::ProtoStr; use crate::__internal::{Enum, Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField}; use crate::{ - Mut, Proxied, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, SettableValue, View, + Map, Mut, ProxiedInMapValue, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, + SettableValue, View, }; use core::fmt::Debug; use paste::paste; @@ -327,58 +328,15 @@ pub fn cast_enum_repeated_mut( } } -#[derive(Debug)] -pub struct MapInner<'msg, K: ?Sized, V: ?Sized> { - pub raw: RawMap, - pub _phantom_key: PhantomData<&'msg mut K>, - pub _phantom_value: PhantomData<&'msg mut V>, -} - -impl<'msg, K: ?Sized, V: ?Sized> Copy for MapInner<'msg, K, V> {} -impl<'msg, K: ?Sized, V: ?Sized> Clone for MapInner<'msg, K, V> { - fn clone(&self) -> MapInner<'msg, K, V> { - *self - } -} - -pub trait ProxiedInMapValue: Proxied -where - K: Proxied + ?Sized, -{ - fn new_map() -> RawMap; - fn clear(m: RawMap); - fn size(m: RawMap) -> usize; - fn insert(m: RawMap, key: View<'_, K>, value: View<'_, Self>) -> bool; - fn get<'msg>(m: RawMap, key: View<'_, K>) -> Option>; - fn remove(m: RawMap, key: View<'_, K>) -> bool; -} - -impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> Default for MapInner<'msg, K, V> { - fn default() -> Self { - MapInner { raw: V::new_map(), _phantom_key: PhantomData, _phantom_value: PhantomData } - } +#[derive(Clone, Copy, Debug)] +pub struct InnerMapMut<'msg> { + pub(crate) raw: RawMap, + _phantom: PhantomData<&'msg ()>, } -impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> MapInner<'msg, K, V> { - pub fn size(&self) -> usize { - V::size(self.raw) - } - - pub fn clear(&mut self) { - V::clear(self.raw) - } - - pub fn get<'a>(&self, key: View<'_, K>) -> Option> { - V::get(self.raw, key) - } - - pub fn remove(&mut self, key: View<'_, K>) -> bool { - V::remove(self.raw, key) - } - - pub fn insert(&mut self, key: View<'_, K>, value: View<'_, V>) -> bool { - V::insert(self.raw, key, value); - true +impl<'msg> InnerMapMut<'msg> { + pub fn new(_private: Private, raw: RawMap) -> Self { + InnerMapMut { raw, _phantom: PhantomData } } } @@ -387,6 +345,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { paste! { $( extern "C" { fn [< __pb_rust_Map_ $key_t _ $t _new >]() -> RawMap; + fn [< __pb_rust_Map_ $key_t _ $t _free >](m: RawMap); fn [< __pb_rust_Map_ $key_t _ $t _clear >](m: RawMap); fn [< __pb_rust_Map_ $key_t _ $t _size >](m: RawMap) -> usize; fn [< __pb_rust_Map_ $key_t _ $t _insert >](m: RawMap, key: $ffi_key_t, value: $ffi_t); @@ -395,39 +354,55 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { } impl ProxiedInMapValue<$key_t> for $t { - fn new_map() -> RawMap { - unsafe { [< __pb_rust_Map_ $key_t _ $t _new >]() } + fn map_new(_private: Private) -> Map<$key_t, Self> { + unsafe { + Map::from_inner( + Private, + InnerMapMut { + raw: [< __pb_rust_Map_ $key_t _ $t _new >](), + _phantom: PhantomData + } + ) + } + } + + unsafe fn map_free(_private: Private, map: &mut Map<$key_t, Self>) { + // SAFETY: + // - `map.inner.raw` is a live `RawMap` + // - This function is only called once for `map` in `Drop`. + unsafe { [< __pb_rust_Map_ $key_t _ $t _free >](map.inner.raw); } } - fn clear(m: RawMap) { - unsafe { [< __pb_rust_Map_ $key_t _ $t _clear >](m) } + + fn map_clear(map: Mut<'_, Map<$key_t, Self>>) { + unsafe { [< __pb_rust_Map_ $key_t _ $t _clear >](map.inner.raw); } } - fn size(m: RawMap) -> usize { - unsafe { [< __pb_rust_Map_ $key_t _ $t _size >](m) } + fn map_len(map: View<'_, Map<$key_t, Self>>) -> usize { + unsafe { [< __pb_rust_Map_ $key_t _ $t _size >](map.raw) } } - fn insert(m: RawMap, key: View<'_, $key_t>, value: View<'_, Self>) -> bool { + fn map_insert(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>, value: View<'_, Self>) -> bool { let ffi_key = $to_ffi_key(key); let ffi_value = $to_ffi_value(value); - unsafe { [< __pb_rust_Map_ $key_t _ $t _insert >](m, ffi_key, ffi_value) } + unsafe { [< __pb_rust_Map_ $key_t _ $t _insert >](map.inner.raw, ffi_key, ffi_value) } true } - fn get<'msg>(m: RawMap, key: View<'_, $key_t>) -> Option> { + fn map_get<'a>(map: View<'a, Map<$key_t, Self>>, key: View<'_, $key_t>) -> Option> { let ffi_key = $to_ffi_key(key); let mut ffi_value = $to_ffi_value($zero_val); - let found = unsafe { [< __pb_rust_Map_ $key_t _ $t _get >](m, ffi_key, &mut ffi_value) }; + let found = unsafe { [< __pb_rust_Map_ $key_t _ $t _get >](map.raw, ffi_key, &mut ffi_value) }; if !found { return None; } Some($from_ffi_value(ffi_value)) } - fn remove(m: RawMap, key: View<'_, $key_t>) -> bool { + fn map_remove(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>) -> bool { let ffi_key = $to_ffi_key(key); let mut ffi_value = $to_ffi_value($zero_val); - unsafe { [< __pb_rust_Map_ $key_t _ $t _remove >](m, ffi_key, &mut ffi_value) } + unsafe { [< __pb_rust_Map_ $key_t _ $t _remove >](map.inner.raw, ffi_key, &mut ffi_value) } } } )* } @@ -470,16 +445,6 @@ impl_ProxiedInMapValue_for_key_types!( ProtoStr, PtrAndLen, str_to_ptrlen; ); -#[cfg(test)] -pub(crate) fn new_map_i32_i64() -> MapInner<'static, i32, i64> { - Default::default() -} - -#[cfg(test)] -pub(crate) fn new_map_str_str() -> MapInner<'static, ProtoStr, ProtoStr> { - Default::default() -} - #[cfg(test)] mod tests { use super::*; @@ -500,92 +465,4 @@ mod tests { let serialized_data = SerializedData { data: NonNull::new(ptr).unwrap(), len }; assert_that!(&*serialized_data, eq(b"Hello world")); } - - #[test] - fn i32_i32_map() { - let mut map: MapInner<'_, i32, i32> = Default::default(); - assert_that!(map.size(), eq(0)); - - assert_that!(map.insert(1, 2), eq(true)); - assert_that!(map.get(1), eq(Some(2))); - assert_that!(map.get(3), eq(None)); - assert_that!(map.size(), eq(1)); - - assert_that!(map.remove(1), eq(true)); - assert_that!(map.size(), eq(0)); - assert_that!(map.remove(1), eq(false)); - - assert_that!(map.insert(4, 5), eq(true)); - assert_that!(map.insert(6, 7), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn i64_f64_map() { - let mut map: MapInner<'_, i64, f64> = Default::default(); - assert_that!(map.size(), eq(0)); - - assert_that!(map.insert(1, 2.5), eq(true)); - assert_that!(map.get(1), eq(Some(2.5))); - assert_that!(map.get(3), eq(None)); - assert_that!(map.size(), eq(1)); - - assert_that!(map.remove(1), eq(true)); - assert_that!(map.size(), eq(0)); - assert_that!(map.remove(1), eq(false)); - - assert_that!(map.insert(4, 5.1), eq(true)); - assert_that!(map.insert(6, 7.2), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn str_str_map() { - let mut map = MapInner::<'_, ProtoStr, ProtoStr>::default(); - assert_that!(map.size(), eq(0)); - - map.insert("fizz".into(), "buzz".into()); - assert_that!(map.size(), eq(1)); - assert_that!(map.remove("fizz".into()), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn u64_str_map() { - let mut map = MapInner::<'_, u64, ProtoStr>::default(); - assert_that!(map.size(), eq(0)); - - map.insert(1, "fizz".into()); - map.insert(2, "buzz".into()); - assert_that!(map.size(), eq(2)); - assert_that!(map.remove(1), eq(true)); - assert_that!(map.get(1), eq(None)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn test_all_maps_can_be_constructed() { - macro_rules! gen_proto_values { - ($key_t:ty, $($value_t:ty),*) => { - $( - let map = MapInner::<'_, $key_t, $value_t>::default(); - assert_that!(map.size(), eq(0)); - )* - } - } - - macro_rules! gen_proto_keys { - ($($key_t:ty),*) => { - $( - gen_proto_values!($key_t, f32, f64, i32, u32, i64, bool, ProtoStr); - )* - } - } - - gen_proto_keys!(i32, u32, i64, u64, bool, ProtoStr); - } } diff --git a/rust/cpp_kernel/cpp_api.cc b/rust/cpp_kernel/cpp_api.cc index a71db58e65b6c..47e0b7a385673 100644 --- a/rust/cpp_kernel/cpp_api.cc +++ b/rust/cpp_kernel/cpp_api.cc @@ -59,6 +59,10 @@ expose_repeated_field_methods(int64_t, i64); __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_new() { \ return new google::protobuf::Map(); \ } \ + void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_free( \ + google::protobuf::Map* m) { \ + delete m; \ + } \ void __pb_rust_Map_##rust_key_ty##_##rust_value_ty##_clear( \ google::protobuf::Map* m) { \ m->clear(); \ diff --git a/rust/map.rs b/rust/map.rs index b17ac0bdca71c..c40d82851be3b 100644 --- a/rust/map.rs +++ b/rust/map.rs @@ -7,20 +7,17 @@ use crate::{ Mut, MutProxy, Proxied, SettableValue, View, ViewProxy, - __internal::Private, - __runtime::{MapInner, ProxiedInMapValue}, + __internal::{Private, RawMap}, + __runtime::InnerMapMut, }; use std::marker::PhantomData; +use std::ops::Deref; #[repr(transparent)] pub struct MapView<'msg, K: ?Sized, V: ?Sized> { - inner: MapInner<'msg, K, V>, -} - -impl<'msg, K: ?Sized, V: ?Sized> MapView<'msg, K, V> { - pub fn from_inner(_private: Private, inner: MapInner<'msg, K, V>) -> Self { - Self { inner } - } + pub raw: RawMap, + _phantom_key: PhantomData<&'msg K>, + _phantom_value: PhantomData<&'msg V>, } impl<'msg, K: ?Sized, V: ?Sized> Copy for MapView<'msg, K, V> {} @@ -30,6 +27,15 @@ impl<'msg, K: ?Sized, V: ?Sized> Clone for MapView<'msg, K, V> { } } +impl<'msg, K: ?Sized, V: ?Sized> Deref for MapMut<'msg, K, V> { + type Target = MapView<'msg, K, V>; + fn deref(&self) -> &Self::Target { + // SAFETY: + // - `MapView<'msg, K, V>` is `#[repr(transparent)]` over `RawMap`. + unsafe { &*(&self.inner.raw as *const RawMap as *const MapView<'msg, K, V>) } + } +} + unsafe impl<'msg, K: ?Sized, V: ?Sized> Sync for MapView<'msg, K, V> {} unsafe impl<'msg, K: ?Sized, V: ?Sized> Send for MapView<'msg, K, V> {} @@ -42,15 +48,10 @@ impl<'msg, K: ?Sized, V: ?Sized> std::fmt::Debug for MapView<'msg, K, V> { } } -#[repr(transparent)] pub struct MapMut<'msg, K: ?Sized, V: ?Sized> { - inner: MapInner<'msg, K, V>, -} - -impl<'msg, K: ?Sized, V: ?Sized> MapMut<'msg, K, V> { - pub fn from_inner(_private: Private, inner: MapInner<'msg, K, V>) -> Self { - Self { inner } - } + pub(crate) inner: InnerMapMut<'msg>, + _phantom_key: PhantomData<&'msg K>, + _phantom_value: PhantomData<&'msg V>, } unsafe impl<'msg, K: ?Sized, V: ?Sized> Sync for MapMut<'msg, K, V> {} @@ -64,20 +65,37 @@ impl<'msg, K: ?Sized, V: ?Sized> std::fmt::Debug for MapMut<'msg, K, V> { } } -impl<'msg, K: ?Sized, V: ?Sized> std::ops::Deref for MapMut<'msg, K, V> { - type Target = MapView<'msg, K, V>; - fn deref(&self) -> &Self::Target { +pub struct Map> { + pub(crate) inner: InnerMapMut<'static>, + _phantom_key: PhantomData, + _phantom_value: PhantomData, +} + +impl> Drop for Map { + fn drop(&mut self) { // SAFETY: - // - `Map{View,Mut}<'msg, T>` are both `#[repr(transparent)]` over - // `MapInner<'msg, T>`. - // - `MapInner` is a type alias for `NonNull`. - unsafe { &*(self as *const Self as *const MapView<'msg, K, V>) } + // - `drop` is only called once. + // - 'map_free` is only called here. + unsafe { V::map_free(Private, self) } } } -// This is a ZST type so we can implement `Proxied`. Users will work with -// `MapView` (`View<'_, Map>>) and `MapMut` (Mut<'_, Map>). -pub struct Map(PhantomData, PhantomData); +pub trait ProxiedInMapValue: Proxied +where + K: Proxied + ?Sized, +{ + fn map_new(_private: Private) -> Map; + + /// # Safety + /// - After `map_free`, no other methods on the input are safe to call. + unsafe fn map_free(_private: Private, map: &mut Map); + + fn map_clear(map: Mut<'_, Map>); + fn map_len(map: View<'_, Map>) -> usize; + fn map_insert(map: Mut<'_, Map>, key: View<'_, K>, value: View<'_, Self>) -> bool; + fn map_get<'a>(map: View<'a, Map>, key: View<'_, K>) -> Option>; + fn map_remove(map: Mut<'_, Map>, key: View<'_, K>) -> bool; +} impl + ?Sized> Proxied for Map { type View<'msg> = MapView<'msg, K, V> where K: 'msg, V: 'msg; @@ -108,7 +126,7 @@ impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> ViewProxy<'msg where 'msg: 'shorter, { - MapView { inner: self.inner } + MapView { raw: self.raw, _phantom_key: PhantomData, _phantom_value: PhantomData } } } @@ -118,14 +136,14 @@ impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> ViewProxy<'msg type Proxied = Map; fn as_view(&self) -> View<'_, Self::Proxied> { - **self + MapView { raw: self.inner.raw, _phantom_key: PhantomData, _phantom_value: PhantomData } } fn into_view<'shorter>(self) -> View<'shorter, Self::Proxied> where 'msg: 'shorter, { - *self.into_mut::<'shorter>() + MapView { raw: self.inner.raw, _phantom_key: PhantomData, _phantom_value: PhantomData } } } @@ -133,14 +151,41 @@ impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> MutProxy<'msg> for MapMut<'msg, K, V> { fn as_mut(&mut self) -> Mut<'_, Self::Proxied> { - MapMut { inner: self.inner } + MapMut { inner: self.inner, _phantom_key: PhantomData, _phantom_value: PhantomData } } fn into_mut<'shorter>(self) -> Mut<'shorter, Self::Proxied> where 'msg: 'shorter, { - MapMut { inner: self.inner } + MapMut { inner: self.inner, _phantom_key: PhantomData, _phantom_value: PhantomData } + } +} + +impl Map +where + K: Proxied + ?Sized, + V: ProxiedInMapValue + ?Sized, +{ + #[allow(dead_code)] + pub(crate) fn new() -> Self { + V::map_new(Private) + } + + pub fn as_mut(&mut self) -> MapMut<'_, K, V> { + MapMut { inner: self.inner, _phantom_key: PhantomData, _phantom_value: PhantomData } + } + + pub fn as_view(&self) -> MapView<'_, K, V> { + MapView { raw: self.inner.raw, _phantom_key: PhantomData, _phantom_value: PhantomData } + } + + /// # Safety + /// - `inner` must be valid to read and write from for `'static`. + /// - There must be no aliasing references or mutations on the same + /// underlying object. + pub unsafe fn from_inner(_private: Private, inner: InnerMapMut<'static>) -> Self { + Self { inner, _phantom_key: PhantomData, _phantom_value: PhantomData } } } @@ -149,18 +194,30 @@ where K: Proxied + ?Sized + 'msg, V: ProxiedInMapValue + ?Sized + 'msg, { - pub fn get<'a>(&self, key: impl Into>) -> Option> + #[doc(hidden)] + pub fn as_raw(&self, _private: Private) -> RawMap { + self.raw + } + + /// # Safety + /// - `raw` must be valid to read from for `'msg`. + #[doc(hidden)] + pub unsafe fn from_raw(_private: Private, raw: RawMap) -> Self { + Self { raw, _phantom_key: PhantomData, _phantom_value: PhantomData } + } + + pub fn get<'a>(self, key: impl Into>) -> Option> where K: 'a, { - self.inner.get(key.into()) + V::map_get(self, key.into()) } - pub fn len(&self) -> usize { - self.inner.size() + pub fn len(self) -> usize { + V::map_len(self) } - pub fn is_empty(&self) -> bool { + pub fn is_empty(self) -> bool { self.len() == 0 } } @@ -170,6 +227,12 @@ where K: Proxied + ?Sized + 'msg, V: ProxiedInMapValue + ?Sized + 'msg, { + /// # Safety + /// - `inner` must be valid to read and write from for `'msg`. + pub unsafe fn from_inner(_private: Private, inner: InnerMapMut<'msg>) -> Self { + Self { inner, _phantom_key: PhantomData, _phantom_value: PhantomData } + } + pub fn insert<'a, 'b>( &mut self, key: impl Into>, @@ -179,25 +242,25 @@ where K: 'a, V: 'b, { - self.inner.insert(key.into(), value.into()) + V::map_insert(self.as_mut(), key.into(), value.into()) } pub fn remove<'a>(&mut self, key: impl Into>) -> bool where K: 'a, { - self.inner.remove(key.into()) + V::map_remove(self.as_mut(), key.into()) } pub fn clear(&mut self) { - self.inner.clear() + V::map_clear(self.as_mut()) } pub fn get<'a>(&self, key: impl Into>) -> Option> where K: 'a, { - self.as_view().get(key) + V::map_get(self.as_view(), key.into()) } pub fn copy_from(&mut self, _src: MapView<'_, K, V>) { @@ -208,12 +271,13 @@ where #[cfg(test)] mod tests { use super::*; - use crate::__runtime::{new_map_i32_i64, new_map_str_str}; + use crate::ProtoStr; use googletest::prelude::*; #[test] fn test_proxied_scalar() { - let mut map_mut = MapMut::from_inner(Private, new_map_i32_i64()); + let mut map: Map = Map::new(); + let mut map_mut = map.as_mut(); map_mut.insert(1, 2); assert_that!(map_mut.get(1), eq(Some(2))); @@ -238,7 +302,8 @@ mod tests { #[test] fn test_proxied_str() { - let mut map_mut = MapMut::from_inner(Private, new_map_str_str()); + let mut map: Map = Map::new(); + let mut map_mut = map.as_mut(); map_mut.insert("a", "b"); let map_view_1 = map_mut.as_view(); @@ -260,9 +325,32 @@ mod tests { assert_that!(map_view_4.is_empty(), eq(false)); } + #[test] + fn test_all_maps_can_be_constructed() { + macro_rules! gen_proto_values { + ($key_t:ty, $($value_t:ty),*) => { + $( + let map = Map::<$key_t, $value_t>::new(); + assert_that!(map.as_view().len(), eq(0)); + )* + } + } + + macro_rules! gen_proto_keys { + ($($key_t:ty),*) => { + $( + gen_proto_values!($key_t, f32, f64, i32, u32, i64, bool, ProtoStr); + )* + } + } + + gen_proto_keys!(i32, u32, i64, u64, bool, ProtoStr); + } + #[test] fn test_dbg() { - let map_view = MapView::from_inner(Private, new_map_i32_i64()); - assert_that!(format!("{:?}", map_view), eq("MapView(\"i32\", \"i64\")")); + let mut map = Map::::new(); + assert_that!(format!("{:?}", map.as_view()), eq("MapView(\"i32\", \"f64\")")); + assert_that!(format!("{:?}", map.as_mut()), eq("MapMut(\"i32\", \"f64\")")); } } diff --git a/rust/shared.rs b/rust/shared.rs index 67ece90872c7c..4d15b592afdc5 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -23,7 +23,7 @@ use std::fmt; #[doc(hidden)] pub mod __public { pub use crate::r#enum::UnknownEnumValue; - pub use crate::map::{Map, MapMut, MapView}; + pub use crate::map::{Map, MapMut, MapView, ProxiedInMapValue}; pub use crate::optional::{AbsentField, FieldEntry, Optional, PresentField}; pub use crate::primitive::PrimitiveMut; pub use crate::proxied::{ diff --git a/rust/upb.rs b/rust/upb.rs index 052e17786d565..e757481d0702f 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -9,8 +9,8 @@ use crate::__internal::{Enum, Private, PtrAndLen, RawArena, RawMap, RawMessage, RawRepeatedField}; use crate::{ - Mut, ProtoStr, Proxied, ProxiedInRepeated, Repeated, RepeatedMut, RepeatedView, SettableValue, - View, ViewProxy, + Map, MapView, Mut, ProtoStr, Proxied, ProxiedInMapValue, ProxiedInRepeated, Repeated, + RepeatedMut, RepeatedView, SettableValue, View, ViewProxy, }; use core::fmt::Debug; use paste::paste; @@ -536,92 +536,41 @@ pub fn empty_array() -> RepeatedView<'static, T> } } -/// Returns a static thread-local empty MapInner for use in a -/// MapView. -/// -/// # Safety -/// The returned map must never be mutated. -/// -/// TODO: Split MapInner into mut and const variants to -/// enforce safety. The returned array must never be mutated. -pub unsafe fn empty_map() -> MapInner<'static, K, V> { - fn new_map_inner() -> MapInner<'static, i32, i32> { - // TODO: Consider creating empty map in C. - let arena = Box::leak::<'static>(Box::new(Arena::new())); - // Provide `i32` as a placeholder type. - MapInner::<'static, i32, i32>::new(arena) - } - thread_local! { - static MAP: MapInner<'static, i32, i32> = new_map_inner(); - } - - MAP.with(|inner| MapInner { - raw: inner.raw, - arena: inner.arena, - _phantom_key: PhantomData, - _phantom_value: PhantomData, - }) -} - -#[derive(Debug)] -pub struct MapInner<'msg, K: ?Sized, V: ?Sized> { - pub raw: RawMap, - pub arena: &'msg Arena, - pub _phantom_key: PhantomData<&'msg mut K>, - pub _phantom_value: PhantomData<&'msg mut V>, -} - -impl<'msg, K: ?Sized, V: ?Sized> Copy for MapInner<'msg, K, V> {} -impl<'msg, K: ?Sized, V: ?Sized> Clone for MapInner<'msg, K, V> { - fn clone(&self) -> MapInner<'msg, K, V> { - *self - } -} - -pub trait ProxiedInMapValue: Proxied +/// Returns a static empty MapView. +pub fn empty_map() -> MapView<'static, K, V> where K: Proxied + ?Sized, + V: ProxiedInMapValue + ?Sized, { - fn new_map(a: RawArena) -> RawMap; - fn clear(m: RawMap) { - unsafe { upb_Map_Clear(m) } - } - fn size(m: RawMap) -> usize { - unsafe { upb_Map_Size(m) } - } - fn insert(m: RawMap, a: RawArena, key: View<'_, K>, value: View<'_, Self>) -> bool; - fn get<'a>(m: RawMap, key: View<'_, K>) -> Option>; - fn remove(m: RawMap, key: View<'_, K>) -> bool; -} - -impl<'msg, K: Proxied + ?Sized, V: ProxiedInMapValue + ?Sized> MapInner<'msg, K, V> { - pub fn new(arena: &'msg mut Arena) -> Self { - MapInner { - raw: V::new_map(arena.raw()), - arena, - _phantom_key: PhantomData, - _phantom_value: PhantomData, - } - } + // TODO: Consider creating a static empty map in C. - pub fn size(&self) -> usize { - V::size(self.raw) - } - - pub fn clear(&mut self) { - V::clear(self.raw) - } + // Use `i32` for a shared empty map for all map types. + static EMPTY_MAP_VIEW: OnceLock> = OnceLock::new(); - pub fn get<'a>(&self, key: View<'_, K>) -> Option> { - V::get(self.raw, key) + // SAFETY: + // - Because the map is never mutated, the map type is unused and therefore + // valid for `T`. + // - The view is leaked for `'static`. + unsafe { + MapView::from_raw( + Private, + EMPTY_MAP_VIEW + .get_or_init(|| Box::leak(Box::new(Map::new())).as_mut().into_view()) + .as_raw(Private), + ) } +} - pub fn remove(&mut self, key: View<'_, K>) -> bool { - V::remove(self.raw, key) - } +#[derive(Clone, Copy, Debug)] +pub struct InnerMapMut<'msg> { + pub(crate) raw: RawMap, + raw_arena: RawArena, + _phantom: PhantomData<&'msg Arena>, +} - pub fn insert(&mut self, key: View<'_, K>, value: View<'_, V>) -> bool { - V::insert(self.raw, self.arena.raw(), key, value) +impl<'msg> InnerMapMut<'msg> { + pub fn new(_private: Private, raw: RawMap, raw_arena: RawArena) -> Self { + InnerMapMut { raw, raw_arena, _phantom: PhantomData } } } @@ -629,25 +578,59 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { ($key_t:ty, $key_msg_val:expr, $key_upb_tag:expr, for $($t:ty, $msg_val:expr, $from_msg_val:expr, $upb_tag:expr, $zero_val:literal;)*) => { $( impl ProxiedInMapValue<$key_t> for $t { - fn new_map(a: RawArena) -> RawMap { - unsafe { upb_Map_New(a, $key_upb_tag, $upb_tag) } + fn map_new(_private: Private) -> Map<$key_t, Self> { + let arena = Arena::new(); + let raw_arena = arena.raw(); + std::mem::forget(arena); + + unsafe { + Map::from_inner( + Private, + InnerMapMut { + raw: upb_Map_New(raw_arena, $key_upb_tag, $upb_tag), + raw_arena, + _phantom: PhantomData + } + ) + } + } + + unsafe fn map_free(_private: Private, map: &mut Map<$key_t, Self>) { + // SAFETY: + // - `map.inner.raw_arena` is a live `upb_Arena*` + // - This function is only called once for `map` in `Drop`. + unsafe { + upb_Arena_Free(map.inner.raw_arena); + } + } + + fn map_clear(map: Mut<'_, Map<$key_t, Self>>) { + unsafe { + upb_Map_Clear(map.inner.raw); + } + } + + fn map_len(map: View<'_, Map<$key_t, Self>>) -> usize { + unsafe { + upb_Map_Size(map.raw) + } } - fn insert(m: RawMap, a: RawArena, key: View<'_, $key_t>, value: View<'_, Self>) -> bool { + fn map_insert(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>, value: View<'_, Self>) -> bool { unsafe { upb_Map_Set( - m, + map.inner.raw, $key_msg_val(key), $msg_val(value), - a + map.inner.raw_arena ) } } - fn get<'a>(m: RawMap, key: View<'_, $key_t>) -> Option> { + fn map_get<'a>(map: View<'a, Map<$key_t, Self>>, key: View<'_, $key_t>) -> Option> { let mut val = $msg_val($zero_val); let found = unsafe { - upb_Map_Get(m, $key_msg_val(key), &mut val) + upb_Map_Get(map.raw, ($key_msg_val)(key), &mut val) }; if !found { return None; @@ -655,10 +638,10 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types { Some($from_msg_val(val)) } - fn remove(m: RawMap, key: View<'_, $key_t>) -> bool { + fn map_remove(map: Mut<'_, Map<$key_t, Self>>, key: View<'_, $key_t>) -> bool { let mut val = $msg_val($zero_val); unsafe { - upb_Map_Delete(m, $key_msg_val(key), &mut val) + upb_Map_Delete(map.inner.raw, $key_msg_val(key), &mut val) } } } @@ -732,18 +715,6 @@ extern "C" { fn upb_Map_Clear(map: RawMap); } -#[cfg(test)] -pub(crate) fn new_map_i32_i64() -> MapInner<'static, i32, i64> { - let arena = Box::leak::<'static>(Box::new(Arena::new())); - MapInner::<'static, i32, i64>::new(arena) -} - -#[cfg(test)] -pub(crate) fn new_map_str_str() -> MapInner<'static, ProtoStr, ProtoStr> { - let arena = Box::leak::<'static>(Box::new(Arena::new())); - MapInner::<'static, ProtoStr, ProtoStr>::new(arena) -} - #[cfg(test)] mod tests { use super::*; @@ -770,96 +741,4 @@ mod tests { }; assert_that!(&*serialized_data, eq(b"Hello world")); } - - #[test] - fn i32_i32_map() { - let mut arena = Arena::new(); - let mut map = MapInner::<'_, i32, i32>::new(&mut arena); - assert_that!(map.size(), eq(0)); - - assert_that!(map.insert(1, 2), eq(true)); - assert_that!(map.get(1), eq(Some(2))); - assert_that!(map.get(3), eq(None)); - assert_that!(map.size(), eq(1)); - - assert_that!(map.remove(1), eq(true)); - assert_that!(map.size(), eq(0)); - assert_that!(map.remove(1), eq(false)); - - assert_that!(map.insert(4, 5), eq(true)); - assert_that!(map.insert(6, 7), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn i64_f64_map() { - let mut arena = Arena::new(); - let mut map = MapInner::<'_, i64, f64>::new(&mut arena); - assert_that!(map.size(), eq(0)); - - assert_that!(map.insert(1, 2.5), eq(true)); - assert_that!(map.get(1), eq(Some(2.5))); - assert_that!(map.get(3), eq(None)); - assert_that!(map.size(), eq(1)); - - assert_that!(map.remove(1), eq(true)); - assert_that!(map.size(), eq(0)); - assert_that!(map.remove(1), eq(false)); - - assert_that!(map.insert(4, 5.1), eq(true)); - assert_that!(map.insert(6, 7.2), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn str_str_map() { - let mut arena = Arena::new(); - let mut map = MapInner::<'_, ProtoStr, ProtoStr>::new(&mut arena); - assert_that!(map.size(), eq(0)); - - map.insert("fizz".into(), "buzz".into()); - assert_that!(map.size(), eq(1)); - assert_that!(map.remove("fizz".into()), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn u64_str_map() { - let mut arena = Arena::new(); - let mut map = MapInner::<'_, u64, ProtoStr>::new(&mut arena); - assert_that!(map.size(), eq(0)); - - map.insert(1, "fizz".into()); - map.insert(2, "buzz".into()); - assert_that!(map.size(), eq(2)); - assert_that!(map.remove(1), eq(true)); - map.clear(); - assert_that!(map.size(), eq(0)); - } - - #[test] - fn test_all_maps_can_be_constructed() { - macro_rules! gen_proto_values { - ($key_t:ty, $($value_t:ty),*) => { - let mut arena = Arena::new(); - $( - let map = MapInner::<'_, $key_t, $value_t>::new(&mut arena); - assert_that!(map.size(), eq(0)); - )* - } - } - - macro_rules! gen_proto_keys { - ($($key_t:ty),*) => { - $( - gen_proto_values!($key_t, f32, f64, i32, u32, i64, bool, ProtoStr); - )* - } - } - - gen_proto_keys!(i32, u32, i64, u64, bool, ProtoStr); - } } diff --git a/src/google/protobuf/compiler/rust/accessors/map.cc b/src/google/protobuf/compiler/rust/accessors/map.cc index 5c103a195d4a3..bdf5ca5da88b6 100644 --- a/src/google/protobuf/compiler/rust/accessors/map.cc +++ b/src/google/protobuf/compiler/rust/accessors/map.cc @@ -31,29 +31,23 @@ void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const { if (ctx.is_upb()) { ctx.Emit({}, R"rs( pub fn r#$field$(&self) - -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { - let inner = unsafe { + -> $pb$::MapView<'_, $Key$, $Value$> { + unsafe { $getter_thunk$(self.inner.msg) - }.map_or_else(|| unsafe {$pbr$::empty_map()}, |raw| { - $pbr$::MapInner{ - raw, - arena: &self.inner.arena, - _phantom_key: std::marker::PhantomData, - _phantom_value: std::marker::PhantomData, - } - }); - $pb$::MapView::from_inner($pbi$::Private, inner) + .map_or_else( + $pbr$::empty_map::<$Key$, $Value$>, + |raw| $pb$::MapView::from_raw($pbi$::Private, raw) + ) + } })rs"); } else { ctx.Emit({}, R"rs( pub fn r#$field$(&self) - -> $pb$::View<'_, $pb$::Map<$Key$, $Value$>> { - let inner = $pbr$::MapInner { - raw: unsafe { $getter_thunk$(self.inner.msg) }, - _phantom_key: std::marker::PhantomData, - _phantom_value: std::marker::PhantomData, - }; - $pb$::MapView::from_inner($pbi$::Private, inner) + -> $pb$::MapView<'_, $Key$, $Value$> { + unsafe { + $pb$::MapView::from_raw($pbi$::Private, + $getter_thunk$(self.inner.msg)) + } })rs"); } }}, @@ -62,29 +56,22 @@ void Map::InMsgImpl(Context& ctx, const FieldDescriptor& field) const { if (ctx.is_upb()) { ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) - -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { + -> $pb$::MapMut<'_, $Key$, $Value$> { let raw = unsafe { $getter_mut_thunk$(self.inner.msg, self.inner.arena.raw()) }; - let inner = $pbr$::MapInner{ - raw, - arena: &self.inner.arena, - _phantom_key: std::marker::PhantomData, - _phantom_value: std::marker::PhantomData, - }; - $pb$::MapMut::from_inner($pbi$::Private, inner) + let inner = $pbr$::InnerMapMut::new($pbi$::Private, + raw, self.inner.arena.raw()); + unsafe { $pb$::MapMut::from_inner($pbi$::Private, inner) } })rs"); } else { ctx.Emit({}, R"rs( pub fn r#$field$_mut(&mut self) - -> $pb$::Mut<'_, $pb$::Map<$Key$, $Value$>> { - let inner = $pbr$::MapInner { - raw: unsafe { $getter_mut_thunk$(self.inner.msg) }, - _phantom_key: std::marker::PhantomData, - _phantom_value: std::marker::PhantomData, - }; - $pb$::MapMut::from_inner($pbi$::Private, inner) + -> $pb$::MapMut<'_, $Key$, $Value$> { + let inner = $pbr$::InnerMapMut::new($pbi$::Private, + unsafe { $getter_mut_thunk$(self.inner.msg) }); + unsafe { $pb$::MapMut::from_inner($pbi$::Private, inner) } })rs"); } }}},