]> git.bitcoin.ninja Git - rust-lightning/commitdiff
f Simplify and fix `TestStore`
authorElias Rohrer <dev@tnull.de>
Tue, 22 Aug 2023 09:09:33 +0000 (11:09 +0200)
committerElias Rohrer <dev@tnull.de>
Wed, 23 Aug 2023 13:17:09 +0000 (15:17 +0200)
.. as we don't require all that logic anymore now that we don't return
an `FilesystemWriter` anymore etc.

lightning/src/util/test_utils.rs

index 48b0fc017a9cf4d3abd3bcfaeacc3b24a1dd64e3..c692f2d03e515a9450bfe3d3f48a3b2d9a5b5e56 100644 (file)
@@ -57,7 +57,7 @@ use crate::prelude::*;
 use core::cell::RefCell;
 use core::ops::DerefMut;
 use core::time::Duration;
-use crate::sync::{Mutex, Arc, RwLock};
+use crate::sync::{Mutex, Arc};
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use core::mem;
 use bitcoin::bech32::u5;
@@ -318,25 +318,21 @@ impl<Signer: sign::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> fo
 }
 
 pub(crate) struct TestStore {
-       persisted_bytes: RwLock<HashMap<String, HashMap<String, Arc<RwLock<Vec<u8>>>>>>,
+       persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
        did_persist: Arc<AtomicBool>,
+       read_only: bool,
 }
 
 impl TestStore {
-       pub fn new() -> Self {
-               let persisted_bytes = RwLock::new(HashMap::new());
+       pub fn new(read_only: bool) -> Self {
+               let persisted_bytes = Mutex::new(HashMap::new());
                let did_persist = Arc::new(AtomicBool::new(false));
-               Self { persisted_bytes, did_persist }
+               Self { persisted_bytes, did_persist, read_only }
        }
 
        pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option<Vec<u8>> {
-               if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
-                       if let Some(inner_ref) = outer_ref.get(key) {
-                               let locked = inner_ref.read().unwrap();
-                               return Some((*locked).clone());
-                       }
-               }
-               None
+               let persisted_lock = self.persisted_bytes.lock().unwrap();
+               persisted_lock.get(namespace).and_then(|e| e.get(key).cloned())
        }
 
        pub fn get_and_clear_did_persist(&self) -> bool {
@@ -345,12 +341,14 @@ impl TestStore {
 }
 
 impl KVStore for TestStore {
-       type Reader = TestReader;
+       type Reader = io::Cursor<Vec<u8>>;
 
        fn read(&self, namespace: &str, key: &str) -> io::Result<Self::Reader> {
-               if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
+               let persisted_lock = self.persisted_bytes.lock().unwrap();
+               if let Some(outer_ref) = persisted_lock.get(namespace) {
                        if let Some(inner_ref) = outer_ref.get(key) {
-                               Ok(TestReader::new(Arc::clone(inner_ref)))
+                               let bytes = inner_ref.clone();
+                               Ok(io::Cursor::new(bytes))
                        } else {
                                Err(io::Error::new(io::ErrorKind::NotFound, "Key not found"))
                        }
@@ -360,53 +358,47 @@ impl KVStore for TestStore {
        }
 
        fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> {
-               let mut guard = self.persisted_bytes.write().unwrap();
-               let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new());
-               let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new())));
-
-               let mut guard = inner_e.write().unwrap();
-               guard.write_all(buf)?;
+               if self.read_only {
+                       return Err(io::Error::new(
+                               io::ErrorKind::PermissionDenied,
+                               "read only",
+                       ));
+               }
+               let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+               let outer_e = persisted_lock.entry(namespace.to_string()).or_insert(HashMap::new());
+               let mut bytes = Vec::new();
+               bytes.write_all(buf)?;
+               outer_e.insert(key.to_string(), bytes);
                self.did_persist.store(true, Ordering::SeqCst);
                Ok(())
        }
 
        fn remove(&self, namespace: &str, key: &str) -> io::Result<()> {
-               match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
-                       hash_map::Entry::Occupied(mut e) => {
+               if self.read_only {
+                       return Err(io::Error::new(
+                               io::ErrorKind::PermissionDenied,
+                               "read only",
+                       ));
+               }
+
+               let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+               if let Some(outer_ref) = persisted_lock.get_mut(namespace) {
+                               outer_ref.remove(&key.to_string());
                                self.did_persist.store(true, Ordering::SeqCst);
-                               e.get_mut().remove(&key.to_string());
-                               Ok(())
-                       }
-                       hash_map::Entry::Vacant(_) => Ok(()),
                }
+
+               Ok(())
        }
 
        fn list(&self, namespace: &str) -> io::Result<Vec<String>> {
-               match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
+               let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+               match persisted_lock.entry(namespace.to_string()) {
                        hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()),
                        hash_map::Entry::Vacant(_) => Ok(Vec::new()),
                }
        }
 }
 
-pub struct TestReader {
-       entry_ref: Arc<RwLock<Vec<u8>>>,
-}
-
-impl TestReader {
-       pub fn new(entry_ref: Arc<RwLock<Vec<u8>>>) -> Self {
-               Self { entry_ref }
-       }
-}
-
-impl io::Read for TestReader {
-       fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
-               let bytes = self.entry_ref.read().unwrap().clone();
-               let mut reader = io::Cursor::new(bytes);
-               reader.read(buf)
-       }
-}
-
 pub struct TestBroadcaster {
        pub txn_broadcasted: Mutex<Vec<Transaction>>,
        pub blocks: Arc<Mutex<Vec<(Block, u32)>>>,