]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Add `TestStore` implementation of `KVStore`
authorElias Rohrer <dev@tnull.de>
Mon, 21 Aug 2023 14:17:35 +0000 (16:17 +0200)
committerElias Rohrer <dev@tnull.de>
Wed, 23 Aug 2023 13:17:09 +0000 (15:17 +0200)
lightning/src/util/test_utils.rs

index 65c0483a59c9906e36b4a42587698362bd77702c..48b0fc017a9cf4d3abd3bcfaeacc3b24a1dd64e3 100644 (file)
@@ -32,6 +32,7 @@ use crate::util::config::UserConfig;
 use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
 use crate::util::logger::{Logger, Level, Record};
 use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable};
+use crate::util::persist::KVStore;
 
 use bitcoin::EcdsaSighashType;
 use bitcoin::blockdata::constants::ChainHash;
@@ -56,7 +57,7 @@ use crate::prelude::*;
 use core::cell::RefCell;
 use core::ops::DerefMut;
 use core::time::Duration;
-use crate::sync::{Mutex, Arc};
+use crate::sync::{Mutex, Arc, RwLock};
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use core::mem;
 use bitcoin::bech32::u5;
@@ -316,6 +317,96 @@ impl<Signer: sign::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> fo
        }
 }
 
+pub(crate) struct TestStore {
+       persisted_bytes: RwLock<HashMap<String, HashMap<String, Arc<RwLock<Vec<u8>>>>>>,
+       did_persist: Arc<AtomicBool>,
+}
+
+impl TestStore {
+       pub fn new() -> Self {
+               let persisted_bytes = RwLock::new(HashMap::new());
+               let did_persist = Arc::new(AtomicBool::new(false));
+               Self { persisted_bytes, did_persist }
+       }
+
+       pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option<Vec<u8>> {
+               if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
+                       if let Some(inner_ref) = outer_ref.get(key) {
+                               let locked = inner_ref.read().unwrap();
+                               return Some((*locked).clone());
+                       }
+               }
+               None
+       }
+
+       pub fn get_and_clear_did_persist(&self) -> bool {
+               self.did_persist.swap(false, Ordering::Relaxed)
+       }
+}
+
+impl KVStore for TestStore {
+       type Reader = TestReader;
+
+       fn read(&self, namespace: &str, key: &str) -> io::Result<Self::Reader> {
+               if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) {
+                       if let Some(inner_ref) = outer_ref.get(key) {
+                               Ok(TestReader::new(Arc::clone(inner_ref)))
+                       } else {
+                               Err(io::Error::new(io::ErrorKind::NotFound, "Key not found"))
+                       }
+               } else {
+                       Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found"))
+               }
+       }
+
+       fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> {
+               let mut guard = self.persisted_bytes.write().unwrap();
+               let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new());
+               let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new())));
+
+               let mut guard = inner_e.write().unwrap();
+               guard.write_all(buf)?;
+               self.did_persist.store(true, Ordering::SeqCst);
+               Ok(())
+       }
+
+       fn remove(&self, namespace: &str, key: &str) -> io::Result<()> {
+               match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
+                       hash_map::Entry::Occupied(mut e) => {
+                               self.did_persist.store(true, Ordering::SeqCst);
+                               e.get_mut().remove(&key.to_string());
+                               Ok(())
+                       }
+                       hash_map::Entry::Vacant(_) => Ok(()),
+               }
+       }
+
+       fn list(&self, namespace: &str) -> io::Result<Vec<String>> {
+               match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) {
+                       hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()),
+                       hash_map::Entry::Vacant(_) => Ok(Vec::new()),
+               }
+       }
+}
+
+pub struct TestReader {
+       entry_ref: Arc<RwLock<Vec<u8>>>,
+}
+
+impl TestReader {
+       pub fn new(entry_ref: Arc<RwLock<Vec<u8>>>) -> Self {
+               Self { entry_ref }
+       }
+}
+
+impl io::Read for TestReader {
+       fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+               let bytes = self.entry_ref.read().unwrap().clone();
+               let mut reader = io::Cursor::new(bytes);
+               reader.read(buf)
+       }
+}
+
 pub struct TestBroadcaster {
        pub txn_broadcasted: Mutex<Vec<Transaction>>,
        pub blocks: Arc<Mutex<Vec<(Block, u32)>>>,