From bd8cd5b52a0e2555b235356bc335fcae530f6e71 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Mon, 21 Aug 2023 16:17:35 +0200 Subject: [PATCH] Add `TestStore` implementation of `KVStore` --- lightning/src/util/test_utils.rs | 93 +++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 65c0483a5..48b0fc017 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -32,6 +32,7 @@ use crate::util::config::UserConfig; use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; use crate::util::logger::{Logger, Level, Record}; use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable}; +use crate::util::persist::KVStore; use bitcoin::EcdsaSighashType; use bitcoin::blockdata::constants::ChainHash; @@ -56,7 +57,7 @@ use crate::prelude::*; use core::cell::RefCell; use core::ops::DerefMut; use core::time::Duration; -use crate::sync::{Mutex, Arc}; +use crate::sync::{Mutex, Arc, RwLock}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use core::mem; use bitcoin::bech32::u5; @@ -316,6 +317,96 @@ impl chainmonitor::Persist fo } } +pub(crate) struct TestStore { + persisted_bytes: RwLock>>>>>, + did_persist: Arc, +} + +impl TestStore { + pub fn new() -> Self { + let persisted_bytes = RwLock::new(HashMap::new()); + let did_persist = Arc::new(AtomicBool::new(false)); + Self { persisted_bytes, did_persist } + } + + pub fn get_persisted_bytes(&self, namespace: &str, key: &str) -> Option> { + if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + if let Some(inner_ref) = outer_ref.get(key) { + let locked = inner_ref.read().unwrap(); + return Some((*locked).clone()); + } + } + None + } + + pub fn get_and_clear_did_persist(&self) -> bool { + self.did_persist.swap(false, Ordering::Relaxed) + } +} + +impl KVStore for TestStore { + type Reader = TestReader; + + fn read(&self, namespace: &str, key: &str) -> io::Result { + if let Some(outer_ref) = self.persisted_bytes.read().unwrap().get(namespace) { + if let Some(inner_ref) = outer_ref.get(key) { + Ok(TestReader::new(Arc::clone(inner_ref))) + } else { + Err(io::Error::new(io::ErrorKind::NotFound, "Key not found")) + } + } else { + Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found")) + } + } + + fn write(&self, namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> { + let mut guard = self.persisted_bytes.write().unwrap(); + let outer_e = guard.entry(namespace.to_string()).or_insert(HashMap::new()); + let inner_e = outer_e.entry(key.to_string()).or_insert(Arc::new(RwLock::new(Vec::new()))); + + let mut guard = inner_e.write().unwrap(); + guard.write_all(buf)?; + self.did_persist.store(true, Ordering::SeqCst); + Ok(()) + } + + fn remove(&self, namespace: &str, key: &str) -> io::Result<()> { + match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + hash_map::Entry::Occupied(mut e) => { + self.did_persist.store(true, Ordering::SeqCst); + e.get_mut().remove(&key.to_string()); + Ok(()) + } + hash_map::Entry::Vacant(_) => Ok(()), + } + } + + fn list(&self, namespace: &str) -> io::Result> { + match self.persisted_bytes.write().unwrap().entry(namespace.to_string()) { + hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()), + hash_map::Entry::Vacant(_) => Ok(Vec::new()), + } + } +} + +pub struct TestReader { + entry_ref: Arc>>, +} + +impl TestReader { + pub fn new(entry_ref: Arc>>) -> Self { + Self { entry_ref } + } +} + +impl io::Read for TestReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let bytes = self.entry_ref.read().unwrap().clone(); + let mut reader = io::Cursor::new(bytes); + reader.read(buf) + } +} + pub struct TestBroadcaster { pub txn_broadcasted: Mutex>, pub blocks: Arc>>, -- 2.39.5