]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Merge pull request #861 from lightning-signer/degenerify
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Fri, 9 Apr 2021 23:57:20 +0000 (23:57 +0000)
committerGitHub <noreply@github.com>
Fri, 9 Apr 2021 23:57:20 +0000 (23:57 +0000)
De-generify Sign methods

1  2 
lightning-persister/src/lib.rs

index 3138087997e7e2295ca4d3cdcef3264a16a5c75f,47789f2ef2ca6250aeea05679316ce07f900a449..368945d1cfb489db52752a1e4ec6b4afedb0497d
@@@ -12,8 -12,7 +12,8 @@@ extern crate lightning
  extern crate bitcoin;
  extern crate libc;
  
 -use bitcoin::hashes::hex::ToHex;
 +use bitcoin::{BlockHash, Txid};
 +use bitcoin::hashes::hex::{FromHex, ToHex};
  use crate::util::DiskWriteable;
  use lightning::chain;
  use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
@@@ -23,14 -22,21 +23,14 @@@ use lightning::chain::keysinterface::{S
  use lightning::chain::transaction::OutPoint;
  use lightning::ln::channelmanager::ChannelManager;
  use lightning::util::logger::Logger;
 -use lightning::util::ser::Writeable;
 +use lightning::util::ser::{ReadableArgs, Writeable};
 +use std::collections::HashMap;
  use std::fs;
 -use std::io::Error;
 -use std::path::PathBuf;
 +use std::io::{Cursor, Error};
 +use std::ops::Deref;
 +use std::path::{Path, PathBuf};
  use std::sync::Arc;
  
 -#[cfg(test)]
 -use {
 -      lightning::util::ser::ReadableArgs,
 -      bitcoin::{BlockHash, Txid},
 -      bitcoin::hashes::hex::FromHex,
 -      std::collections::HashMap,
 -      std::io::Cursor
 -};
 -
  /// FilesystemPersister persists channel data on disk, where each channel's
  /// data is stored in a file named after its funding outpoint.
  ///
@@@ -102,64 -108,42 +102,64 @@@ impl FilesystemPersister 
                util::write_to_file(path, "manager".to_string(), manager)
        }
  
 -      #[cfg(test)]
 -      fn load_channel_data<Keys: KeysInterface>(&self, keys: &Keys) ->
 -              Result<HashMap<OutPoint, ChannelMonitor<Keys::Signer>>, ChannelMonitorUpdateErr> {
 -                      if let Err(_) = fs::create_dir_all(self.path_to_monitor_data()) {
 -                              return Err(ChannelMonitorUpdateErr::PermanentFailure);
 +      /// Read `ChannelMonitor`s from disk.
 +      pub fn read_channelmonitors<Signer: Sign, K: Deref> (
 +              &self, keys_manager: K
 +      ) -> Result<HashMap<OutPoint, (BlockHash, ChannelMonitor<Signer>)>, std::io::Error>
 +           where K::Target: KeysInterface<Signer=Signer> + Sized
 +      {
 +              let path = self.path_to_monitor_data();
 +              if !Path::new(&path).exists() {
 +                      return Ok(HashMap::new());
 +              }
 +              let mut outpoint_to_channelmonitor = HashMap::new();
 +              for file_option in fs::read_dir(path).unwrap() {
 +                      let file = file_option.unwrap();
 +                      let owned_file_name = file.file_name();
 +                      let filename = owned_file_name.to_str();
 +                      if !filename.is_some() || !filename.unwrap().is_ascii() || filename.unwrap().len() < 65 {
 +                              return Err(std::io::Error::new(
 +                                      std::io::ErrorKind::InvalidData,
 +                                      "Invalid ChannelMonitor file name",
 +                              ));
                        }
 -                      let mut res = HashMap::new();
 -                      for file_option in fs::read_dir(self.path_to_monitor_data()).unwrap() {
 -                              let file = file_option.unwrap();
 -                              let owned_file_name = file.file_name();
 -                              let filename = owned_file_name.to_str();
 -                              if !filename.is_some() || !filename.unwrap().is_ascii() || filename.unwrap().len() < 65 {
 -                                      return Err(ChannelMonitorUpdateErr::PermanentFailure);
 -                              }
  
 -                              let txid = Txid::from_hex(filename.unwrap().split_at(64).0);
 -                              if txid.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
 -
 -                              let index = filename.unwrap().split_at(65).1.split('.').next().unwrap().parse();
 -                              if index.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
 +                      let txid = Txid::from_hex(filename.unwrap().split_at(64).0);
 +                      if txid.is_err() {
 +                              return Err(std::io::Error::new(
 +                                      std::io::ErrorKind::InvalidData,
 +                                      "Invalid tx ID in filename",
 +                              ));
 +                      }
  
 -                              let contents = fs::read(&file.path());
 -                              if contents.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
 +                      let index = filename.unwrap().split_at(65).1.parse();
 +                      if index.is_err() {
 +                              return Err(std::io::Error::new(
 +                                      std::io::ErrorKind::InvalidData,
 +                                      "Invalid tx index in filename",
 +                              ));
 +                      }
  
 -                              if let Ok((_, loaded_monitor)) =
 -                                      <(BlockHash, ChannelMonitor<Keys::Signer>)>::read(&mut Cursor::new(&contents.unwrap()), keys) {
 -                                              res.insert(OutPoint { txid: txid.unwrap(), index: index.unwrap() }, loaded_monitor);
 -                                      } else {
 -                                              return Err(ChannelMonitorUpdateErr::PermanentFailure);
 -                                      }
 +                      let contents = fs::read(&file.path())?;
 +                      let mut buffer = Cursor::new(&contents);
 +                      match <(BlockHash, ChannelMonitor<Signer>)>::read(&mut buffer, &*keys_manager) {
 +                              Ok((blockhash, channel_monitor)) => {
 +                                      outpoint_to_channelmonitor.insert(
 +                                              OutPoint { txid: txid.unwrap(), index: index.unwrap() },
 +                                              (blockhash, channel_monitor),
 +                                      );
 +                              }
 +                              Err(e) => return Err(std::io::Error::new(
 +                                      std::io::ErrorKind::InvalidData,
 +                                      format!("Failed to deserialize ChannelMonitor: {}", e),
 +                              ))
                        }
 -                      Ok(res)
                }
 +              Ok(outpoint_to_channelmonitor)
 +      }
  }
  
- impl<ChannelSigner: Sign + Send + Sync> channelmonitor::Persist<ChannelSigner> for FilesystemPersister {
+ impl<ChannelSigner: Sign> channelmonitor::Persist<ChannelSigner> for FilesystemPersister {
        fn persist_new_channel(&self, funding_txo: OutPoint, monitor: &ChannelMonitor<ChannelSigner>) -> Result<(), ChannelMonitorUpdateErr> {
                let filename = format!("{}_{}", funding_txo.txid.to_hex(), funding_txo.index);
                util::write_to_file(self.path_to_monitor_data(), filename, monitor)
@@@ -226,22 -210,22 +226,22 @@@ mod tests 
  
                // Check that the persisted channel data is empty before any channels are
                // open.
 -              let mut persisted_chan_data_0 = persister_0.load_channel_data(nodes[0].keys_manager).unwrap();
 +              let mut persisted_chan_data_0 = persister_0.read_channelmonitors(nodes[0].keys_manager).unwrap();
                assert_eq!(persisted_chan_data_0.keys().len(), 0);
 -              let mut persisted_chan_data_1 = persister_1.load_channel_data(nodes[1].keys_manager).unwrap();
 +              let mut persisted_chan_data_1 = persister_1.read_channelmonitors(nodes[1].keys_manager).unwrap();
                assert_eq!(persisted_chan_data_1.keys().len(), 0);
  
                // Helper to make sure the channel is on the expected update ID.
                macro_rules! check_persisted_data {
                        ($expected_update_id: expr) => {
 -                              persisted_chan_data_0 = persister_0.load_channel_data(nodes[0].keys_manager).unwrap();
 +                              persisted_chan_data_0 = persister_0.read_channelmonitors(nodes[0].keys_manager).unwrap();
                                assert_eq!(persisted_chan_data_0.keys().len(), 1);
 -                              for mon in persisted_chan_data_0.values() {
 +                              for (_, mon) in persisted_chan_data_0.values() {
                                        assert_eq!(mon.get_latest_update_id(), $expected_update_id);
                                }
 -                              persisted_chan_data_1 = persister_1.load_channel_data(nodes[1].keys_manager).unwrap();
 +                              persisted_chan_data_1 = persister_1.read_channelmonitors(nodes[1].keys_manager).unwrap();
                                assert_eq!(persisted_chan_data_1.keys().len(), 1);
 -                              for mon in persisted_chan_data_1.values() {
 +                              for (_, mon) in persisted_chan_data_1.values() {
                                        assert_eq!(mon.get_latest_update_id(), $expected_update_id);
                                }
                        }