1use std::sync::{Arc, Weak};
9
10use slotmap_careful::{Key as _, KeyData, SlotMap};
11use tor_rpcbase as rpc;
12
13pub(crate) mod methods;
14
slotmap_careful::new_key_type! {
    /// Key type for entries in an [`ObjMap`]'s arena.
    ///
    /// Externally, a `GenIdx` is exposed as an [`rpc::ObjectId`] via
    /// `GenIdx::encode`, and recovered via `GenIdx::try_decode`.
    pub(crate) struct GenIdx;

}
19
/// A single entry in an [`ObjMap`]: a reference to an RPC object that is
/// either strong (keeps the object alive) or weak (does not).
#[derive(Clone, derive_more::From)]
enum ObjectRef {
    /// A strong reference: the object stays alive while this entry holds it.
    Strong(Arc<dyn rpc::Object>),
    /// A weak reference: the object may be dropped while this entry exists,
    /// after which [`ObjectRef::get`] returns `None`.
    Weak(Weak<dyn rpc::Object>),
}
33
34impl ObjectRef {
35 fn get(&self) -> Option<Arc<dyn rpc::Object>> {
37 match self {
38 ObjectRef::Strong(s) => Some(Arc::clone(s)),
39 ObjectRef::Weak(w) => w.upgrade(),
40 }
41 }
42}
43
/// A collection of RPC objects, each addressed by a [`GenIdx`].
///
/// Each entry holds either a strong or a weak reference to its object.
#[derive(Default)]
pub(crate) struct ObjMap {
    /// Storage for the entries, keyed by [`GenIdx`].
    arena: SlotMap<GenIdx, ObjectRef>,
}
50
51impl GenIdx {
65 pub(crate) const BYTE_LEN: usize = 16;
67
68 pub(crate) fn encode(self) -> rpc::ObjectId {
70 self.encode_with_rng(&mut rand::rng())
71 }
72
73 fn encode_with_rng<R: rand::RngCore>(self, rng: &mut R) -> rpc::ObjectId {
75 use base64ct::Encoding;
76 let bytes = self.to_bytes(rng);
77 rpc::ObjectId::from(base64ct::Base64UrlUnpadded::encode_string(&bytes[..]))
78 }
79
80 pub(crate) fn to_bytes<R: rand::RngCore>(self, rng: &mut R) -> [u8; Self::BYTE_LEN] {
82 use rand::Rng;
83 use tor_bytes::Writer;
84 let ffi_idx = self.data().as_ffi();
85 let x = rng.random::<u64>();
86 let mut bytes = Vec::with_capacity(Self::BYTE_LEN);
87 bytes.write_u64(x);
88 bytes.write_u64(ffi_idx.wrapping_add(x));
89
90 bytes.try_into().expect("Length was wrong!")
91 }
92
93 pub(crate) fn try_decode(id: &rpc::ObjectId) -> Result<Self, rpc::LookupError> {
95 use base64ct::Encoding;
96
97 let bytes = base64ct::Base64UrlUnpadded::decode_vec(id.as_ref())
98 .map_err(|_| rpc::LookupError::NoObject(id.clone()))?;
99 Self::from_bytes(&bytes).ok_or_else(|| rpc::LookupError::NoObject(id.clone()))
100 }
101
102 pub(crate) fn from_bytes(bytes: &[u8]) -> Option<Self> {
104 use tor_bytes::Reader;
105 let mut r = Reader::from_slice(bytes);
106 let x = r.take_u64().ok()?;
107 let ffi_idx = r.take_u64().ok()?;
108 r.should_be_exhausted().ok()?;
109
110 let ffi_idx = ffi_idx.wrapping_sub(x);
111 Some(GenIdx::from(KeyData::from_ffi(ffi_idx)))
112 }
113}
114
115impl ObjMap {
116 pub(crate) fn new() -> Self {
118 Self::default()
119 }
120
121 pub(crate) fn insert_strong(&mut self, value: Arc<dyn rpc::Object>) -> GenIdx {
123 self.arena.insert(ObjectRef::Strong(value))
124 }
125
126 pub(crate) fn insert_weak(&mut self, value: &Arc<dyn rpc::Object>) -> GenIdx {
128 self.arena.insert(ObjectRef::Weak(Arc::downgrade(value)))
129 }
130
131 pub(crate) fn lookup(&self, idx: GenIdx) -> Result<Arc<dyn rpc::Object>, LookupError> {
133 self.arena
134 .get(idx)
135 .ok_or(LookupError::NoObject)?
136 .get()
137 .ok_or(LookupError::Expired)
138 }
139
140 pub(crate) fn remove(&mut self, idx: GenIdx) -> bool {
144 self.arena.remove(idx).is_some()
145 }
146
147 #[cfg(test)]
149 fn assert_okay(&self) {}
150}
151
/// Reason why [`ObjMap::lookup`] could not return an object.
#[derive(Clone, Debug, thiserror::Error)]
pub(crate) enum LookupError {
    /// The map has no entry for the requested index (for example, because it
    /// was removed).
    #[error("Object not found")]
    NoObject,

    /// The map has an entry, but it is a weak reference whose object has
    /// already been dropped.
    #[error("Object expired")]
    Expired,
}
165
166impl LookupError {
167 pub(crate) fn to_rpc_lookup_error(&self, id: rpc::ObjectId) -> rpc::LookupError {
169 match self {
170 LookupError::NoObject => rpc::LookupError::NoObject(id),
171 LookupError::Expired => rpc::LookupError::Expired(id),
172 }
173 }
174}
175
#[cfg(test)]
mod test {
    // Standard lint allowances for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use derive_deftly::Deftly;
    use tor_rpcbase::templates::*;

    /// Minimal RPC object used to populate the map in these tests.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(Object)]
    struct ExampleObject(#[allow(unused)] String);

    #[test]
    fn map_basics() {
        let obj1 = Arc::new(ExampleObject("abcdef".to_string()));

        let mut map = ObjMap::new();
        map.assert_okay();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_strong(obj1.clone());
        assert_ne!(id1, id2);
        let obj1: Arc<dyn rpc::Object> = obj1;
        let obj_out1 = map.lookup(id1).unwrap();
        let obj_out2 = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj1, &obj_out1));
        assert!(Arc::ptr_eq(&obj1, &obj_out2));
        map.assert_okay();

        assert!(map.remove(id1));
        assert!(matches!(map.lookup(id1), Err(LookupError::NoObject)));
        // A second removal of the same index finds nothing.
        assert!(!map.remove(id1));
        let obj_out2b = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj_out2, &obj_out2b));

        map.assert_okay();
    }

    #[test]
    fn strong_and_weak() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);

        {
            let out1 = map.lookup(id1).unwrap();
            let out2 = map.lookup(id2).unwrap();
            assert!(Arc::ptr_eq(&obj1, &out1));
            assert!(Arc::ptr_eq(&obj2, &out2));
        }
        map.assert_okay();

        // The strong entry keeps obj1 alive; nothing but us keeps obj2 alive.
        drop(obj1);
        drop(obj2);
        {
            let out1 = map.lookup(id1);
            let out2 = map.lookup(id2);

            assert!(out1.is_ok());
            // The weak entry is still present, but its object is gone.
            assert!(matches!(out2, Err(LookupError::Expired)));
        }
        map.assert_okay();
    }

    #[test]
    fn remove() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);
        map.assert_okay();

        assert!(map.remove(id1));
        map.assert_okay();
        assert!(matches!(map.lookup(id1), Err(LookupError::NoObject)));
        assert!(map.lookup(id2).is_ok());

        assert!(map.remove(id2));
        map.assert_okay();
        assert!(matches!(map.lookup(id1), Err(LookupError::NoObject)));
        // obj2 is still alive, so a removed weak entry is "not found", not "expired".
        assert!(matches!(map.lookup(id2), Err(LookupError::NoObject)));
    }

    #[test]
    fn duplicates() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);

        {
            assert_ne!(id2, map.insert_weak(&obj1));
            assert_ne!(id2, map.insert_weak(&obj2));
        }

        {
            assert_ne!(id1, map.insert_strong(obj1.clone()));
            assert_ne!(id2, map.insert_strong(obj2.clone()));
        }
    }

    #[test]
    fn objid_encoding() {
        use rand::Rng;
        fn test_roundtrip(a: u32, b: u32, rng: &mut tor_basic_utils::test_rng::TestingRng) {
            let a: u64 = a.into();
            let b: u64 = b.into();
            let data = KeyData::from_ffi((a << 33) | (1_u64 << 32) | b);
            let idx = GenIdx::from(data);
            let s1 = idx.encode_with_rng(rng);
            let s2 = idx.encode_with_rng(rng);
            assert_ne!(s1, s2);
            assert_eq!(idx, GenIdx::try_decode(&s1).unwrap());
            assert_eq!(idx, GenIdx::try_decode(&s2).unwrap());
        }
        let mut rng = tor_basic_utils::test_rng::testing_rng();

        test_roundtrip(0, 1, &mut rng);
        test_roundtrip(0, 2, &mut rng);
        test_roundtrip(1, 1, &mut rng);
        test_roundtrip(0xffffffff, 0xffffffff, &mut rng);

        for _ in 0..256 {
            test_roundtrip(rng.random(), rng.random(), &mut rng);
        }
    }

    #[test]
    fn objid_byte_layout() {
        let mut rng = tor_basic_utils::test_rng::testing_rng();
        // Odd version in the high half, so `from_ffi` preserves the value exactly.
        let ffi: u64 = (7_u64 << 32) | 42;
        let idx = GenIdx::from(KeyData::from_ffi(ffi));

        let bytes = idx.to_bytes(&mut rng);
        assert_eq!(bytes.len(), GenIdx::BYTE_LEN);
        let (mask, masked) = bytes.split_at(8);
        let mask = u64::from_be_bytes(mask.try_into().unwrap());
        let masked = u64::from_be_bytes(masked.try_into().unwrap());
        assert_eq!(masked.wrapping_sub(mask), ffi);
        assert_eq!(GenIdx::from_bytes(&bytes), Some(idx));
    }

    #[test]
    fn objid_decode_rejects_malformed() {
        assert!(GenIdx::from_bytes(&[]).is_none());
        assert!(GenIdx::from_bytes(&[0_u8; 15]).is_none());
        assert!(GenIdx::from_bytes(&[0_u8; 17]).is_none());

        // Not valid base64 at all.
        let bad = rpc::ObjectId::from("not base64!".to_string());
        assert!(GenIdx::try_decode(&bad).is_err());
        // Valid base64, but only 3 bytes long.
        let short = rpc::ObjectId::from("AAAA".to_string());
        assert!(GenIdx::try_decode(&short).is_err());
    }
}