From c376a6a6532053af9131a2e69c78fc946614dce7 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Tue, 22 Aug 2023 11:09:33 +0200 Subject: [PATCH] f Simplify and fix `TestStore` .. as we don't require all that logic anymore now that we don't return an `FilesystemWriter` anymore etc. --- lightning/src/util/test_utils.rs | 84 +++++++++++++++----------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 48b0fc017..c692f2d03 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -57,7 +57,7 @@ use crate::prelude::*; use core::cell::RefCell; use core::ops::DerefMut; use core::time::Duration; -use crate::sync::{Mutex, Arc, RwLock}; +use crate::sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use core::mem; use bitcoin::bech32::u5; @@ -318,25 +318,21 @@ impl chainmonitor::Persist fo } pub(crate) struct TestStore { - persisted_bytes: RwLock>>>>>, + persisted_bytes: Mutex>>>, did_persist: Arc, + read_only: bool, } impl TestStore { - pub fn new() -> Self { - let persisted_bytes = RwLock::new(HashMap::new()); + pub fn new(read_only: bool) -> Self { + let persisted_bytes = Mutex::new(HashMap::new()); let did_persist = Arc::new(AtomicBool::new(false)); - Self { persisted_bytes, did_persist } + Self { persisted_bytes, did_persist, read_only } } pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option> { - if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { - if let Some(inner_ref) = outer_ref.get(key) { - let locked = inner_ref.read().unwrap(); - return Some((*locked).clone()); - } - } - None + let persisted_lock = self.persisted_bytes.lock().unwrap(); + persisted_lock.get(namespace).and_then(|e| e.get(key).cloned()) } pub fn get_and_clear_did_persist(&self) -> bool { @@ -345,12 +341,14 @@ impl TestStore { } impl KVStore for TestStore { - type Reader = TestReader; + type Reader = io::Cursor>; fn read(&self, namespace: &str, key: &str) -> io::Result { - if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + let persisted_lock = self.persisted_bytes.lock().unwrap(); + if let Some(outer_ref) = persisted_lock.get(namespace) { if let Some(inner_ref) = outer_ref.get(key) { - Ok(TestReader::new(Arc::clone(inner_ref))) + let bytes = inner_ref.clone(); + Ok(io::Cursor::new(bytes)) } else { Err(io::Error::new(io::ErrorKind::NotFound, "Key not found")) } @@ -360,53 +358,47 @@ impl KVStore for TestStore { } fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> { - let mut guard = self.persisted_bytes.write().unwrap(); - let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new()); - let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new()))); - - let mut guard = inner_e.write().unwrap(); - guard.write_all(buf)?; + if self.read_only { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "read only", + )); + } + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + let outer_e = persisted_lock.entry(namespace.to_string()).or_insert(HashMap::new()); + let mut bytes = Vec::new(); + bytes.write_all(buf)?; + outer_e.insert(key.to_string(), bytes); self.did_persist.store(true, Ordering::SeqCst); Ok(()) } fn remove(&self, namespace: &str, key: &str) -> io::Result<()> { - match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { - hash_map::Entry::Occupied(mut e) => { + if self.read_only { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "read only", + )); + } + + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + if let Some(outer_ref) = persisted_lock.get_mut(namespace) { + outer_ref.remove(&key.to_string()); self.did_persist.store(true, Ordering::SeqCst); - e.get_mut().remove(&key.to_string()); - Ok(()) - } - hash_map::Entry::Vacant(_) => Ok(()), } + + Ok(()) } fn list(&self, namespace: &str) -> io::Result> { - match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + let mut persisted_lock = self.persisted_bytes.lock().unwrap(); + match persisted_lock.entry(namespace.to_string()) { hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()), hash_map::Entry::Vacant(_) => Ok(Vec::new()), } } } -pub struct TestReader { - entry_ref: Arc>>, -} - -impl TestReader { - pub fn new(entry_ref: Arc>>) -> Self { - Self { entry_ref } - } -} - -impl io::Read for TestReader { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let bytes = self.entry_ref.read().unwrap().clone(); - let mut reader = io::Cursor::new(bytes); - reader.read(buf) - } -} - pub struct TestBroadcaster { pub txn_broadcasted: Mutex>, pub blocks: Arc>>, -- 2.39.5