**/*.rs.bk
Cargo.lock
.idea
-
+lightning/target
use lightning::ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
use lightning::ln::msgs::{CommitmentUpdate, ChannelMessageHandler, DecodeError, UpdateAddHTLC, Init};
use lightning::ln::script::ShutdownScript;
-use lightning::util::enforcing_trait_impls::{EnforcingSigner, INITIAL_REVOKED_COMMITMENT_NUMBER};
+use lightning::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
use lightning::util::errors::APIError;
use lightning::util::events;
use lightning::util::logger::Logger;
struct KeyProvider {
node_id: u8,
rand_bytes_id: atomic::AtomicU32,
- revoked_commitments: Mutex<HashMap<[u8;32], Arc<Mutex<u64>>>>,
+ enforcement_states: Mutex<HashMap<[u8;32], Arc<Mutex<EnforcementState>>>>,
}
impl KeysInterface for KeyProvider {
type Signer = EnforcingSigner;
channel_value_satoshis,
[0; 32],
);
- let revoked_commitment = self.make_revoked_commitment_cell(keys.commitment_seed);
+ let revoked_commitment = self.make_enforcement_state_cell(keys.commitment_seed);
EnforcingSigner::new_with_revoked(keys, revoked_commitment, false)
}
let mut reader = std::io::Cursor::new(buffer);
let inner: InMemorySigner = Readable::read(&mut reader)?;
- let revoked_commitment = self.make_revoked_commitment_cell(inner.commitment_seed);
-
- let last_commitment_number = Readable::read(&mut reader)?;
+ let state = self.make_enforcement_state_cell(inner.commitment_seed);
Ok(EnforcingSigner {
inner,
- last_commitment_number: Arc::new(Mutex::new(last_commitment_number)),
- revoked_commitment,
+ state,
disable_revocation_policy_check: false,
})
}
}
impl KeyProvider {
- fn make_revoked_commitment_cell(&self, commitment_seed: [u8; 32]) -> Arc<Mutex<u64>> {
- let mut revoked_commitments = self.revoked_commitments.lock().unwrap();
+ fn make_enforcement_state_cell(&self, commitment_seed: [u8; 32]) -> Arc<Mutex<EnforcementState>> {
+ let mut revoked_commitments = self.enforcement_states.lock().unwrap();
if !revoked_commitments.contains_key(&commitment_seed) {
- revoked_commitments.insert(commitment_seed, Arc::new(Mutex::new(INITIAL_REVOKED_COMMITMENT_NUMBER)));
+ revoked_commitments.insert(commitment_seed, Arc::new(Mutex::new(EnforcementState::new())));
}
let cell = revoked_commitments.get(&commitment_seed).unwrap();
Arc::clone(cell)
macro_rules! make_node {
($node_id: expr, $fee_estimator: expr) => { {
let logger: Arc<dyn Logger> = Arc::new(test_logger::TestLogger::new($node_id.to_string(), out.clone()));
- let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU32::new(0), revoked_commitments: Mutex::new(HashMap::new()) });
+ let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU32::new(0), enforcement_states: Mutex::new(HashMap::new()) });
let monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), $fee_estimator.clone(), Arc::new(TestPersister{}), Arc::clone(&keys_manager)));
let mut config = UserConfig::default();
use lightning::chain::keysinterface::{InMemorySigner, KeysInterface};
use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
use lightning::ln::channelmanager::{ChainParameters, ChannelManager};
-use lightning::ln::peer_handler::{MessageHandler,PeerManager,SocketDescriptor};
+use lightning::ln::peer_handler::{MessageHandler,PeerManager,SocketDescriptor,IgnoringMessageHandler};
use lightning::ln::msgs::DecodeError;
use lightning::ln::script::ShutdownScript;
use lightning::routing::router::get_route;
use lightning::util::config::UserConfig;
use lightning::util::errors::APIError;
use lightning::util::events::Event;
-use lightning::util::enforcing_trait_impls::EnforcingSigner;
+use lightning::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
use lightning::util::logger::Logger;
use lightning::util::ser::Readable;
EnforcingSigner,
Arc<chainmonitor::ChainMonitor<EnforcingSigner, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
Arc<TestBroadcaster>, Arc<KeyProvider>, Arc<FuzzEstimator>, Arc<dyn Logger>>;
-type PeerMan<'a> = PeerManager<Peer<'a>, Arc<ChannelMan>, Arc<NetGraphMsgHandler<Arc<dyn chain::Access>, Arc<dyn Logger>>>, Arc<dyn Logger>>;
+type PeerMan<'a> = PeerManager<Peer<'a>, Arc<ChannelMan>, Arc<NetGraphMsgHandler<Arc<dyn chain::Access>, Arc<dyn Logger>>>, Arc<dyn Logger>, IgnoringMessageHandler>;
struct MoneyLossDetector<'a> {
manager: Arc<ChannelMan>,
(ctr >> 8*7) as u8, (ctr >> 8*6) as u8, (ctr >> 8*5) as u8, (ctr >> 8*4) as u8, (ctr >> 8*3) as u8, (ctr >> 8*2) as u8, (ctr >> 8*1) as u8, 14, (ctr >> 8*0) as u8]
}
- fn read_chan_signer(&self, data: &[u8]) -> Result<EnforcingSigner, DecodeError> {
- EnforcingSigner::read(&mut std::io::Cursor::new(data))
+ fn read_chan_signer(&self, mut data: &[u8]) -> Result<EnforcingSigner, DecodeError> {
+ let inner: InMemorySigner = Readable::read(&mut data)?;
+ let state = Arc::new(Mutex::new(EnforcementState::new()));
+
+ Ok(EnforcingSigner::new_with_revoked(
+ inner,
+ state,
+ false
+ ))
}
fn sign_invoice(&self, _invoice_preimage: Vec<u8>) -> Result<RecoverableSignature, ()> {
let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), PeerManager::new(MessageHandler {
chan_handler: channelmanager.clone(),
route_handler: net_graph_msg_handler.clone(),
- }, our_network_key, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0], Arc::clone(&logger)));
+ }, our_network_key, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0], Arc::clone(&logger), IgnoringMessageHandler{}));
let mut should_forward = false;
let mut payments_received: Vec<PaymentHash> = Vec::new();
use lightning::ln::channelmanager::ChannelManager;
use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
use lightning::ln::peer_handler::{PeerManager, SocketDescriptor};
+use lightning::ln::peer_handler::CustomMessageHandler;
use lightning::util::events::{EventHandler, EventsProvider};
use lightning::util::logger::Logger;
use std::sync::Arc;
CMP: 'static + Send + ChannelManagerPersister<Signer, CW, T, K, F, L>,
M: 'static + Deref<Target = ChainMonitor<Signer, CF, T, F, L, P>> + Send + Sync,
CM: 'static + Deref<Target = ChannelManager<Signer, CW, T, K, F, L>> + Send + Sync,
- PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L>> + Send + Sync,
+ UMH: 'static + Deref + Send + Sync,
+ PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
>
(persister: CMP, event_handler: EH, chain_monitor: M, channel_manager: CM, peer_manager: PM, logger: L) -> Self
where
P::Target: 'static + channelmonitor::Persist<Signer>,
CMH::Target: 'static + ChannelMessageHandler,
RMH::Target: 'static + RoutingMessageHandler,
+ UMH::Target: 'static + CustomMessageHandler,
{
let stop_thread = Arc::new(AtomicBool::new(false));
let stop_thread_clone = stop_thread.clone();
use lightning::ln::channelmanager::{BREAKDOWN_TIMEOUT, ChainParameters, ChannelManager, SimpleArcChannelManager};
use lightning::ln::features::InitFeatures;
use lightning::ln::msgs::{ChannelMessageHandler, Init};
- use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor};
+ use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
use lightning::util::config::UserConfig;
use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
use lightning::util::ser::Writeable;
struct Node {
node: Arc<SimpleArcChannelManager<ChainMonitor, test_utils::TestBroadcaster, test_utils::TestFeeEstimator, test_utils::TestLogger>>,
- peer_manager: Arc<PeerManager<TestDescriptor, Arc<test_utils::TestChannelMessageHandler>, Arc<test_utils::TestRoutingMessageHandler>, Arc<test_utils::TestLogger>>>,
+ peer_manager: Arc<PeerManager<TestDescriptor, Arc<test_utils::TestChannelMessageHandler>, Arc<test_utils::TestRoutingMessageHandler>, Arc<test_utils::TestLogger>, IgnoringMessageHandler>>,
chain_monitor: Arc<ChainMonitor>,
persister: Arc<FilesystemPersister>,
tx_broadcaster: Arc<test_utils::TestBroadcaster>,
let params = ChainParameters { network, best_block };
let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), logger.clone(), keys_manager.clone(), UserConfig::default(), params));
let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new() )};
- let peer_manager = Arc::new(PeerManager::new(msg_handler, keys_manager.get_node_secret(), &seed, logger.clone()));
+ let peer_manager = Arc::new(PeerManager::new(msg_handler, keys_manager.get_node_secret(), &seed, logger.clone(), IgnoringMessageHandler{}));
let node = Node { node: manager, peer_manager, chain_monitor, persister, tx_broadcaster, logger, best_block };
nodes.push(node);
}
use bitcoin::blockdata::block::{Block, BlockHeader};
use bitcoin::consensus::encode;
-use bitcoin::hash_types::{BlockHash, TxMerkleNode};
+use bitcoin::hash_types::{BlockHash, TxMerkleNode, Txid};
use bitcoin::hashes::hex::{ToHex, FromHex};
use serde::Deserialize;
}
}
+impl TryInto<Txid> for JsonResponse {
+ type Error = std::io::Error;
+ fn try_into(self) -> std::io::Result<Txid> {
+ match self.0.as_str() {
+ None => Err(std::io::Error::new(
+ std::io::ErrorKind::InvalidData,
+ "expected JSON string",
+ )),
+ Some(hex_data) => match Vec::<u8>::from_hex(hex_data) {
+ Err(_) => Err(std::io::Error::new(
+ std::io::ErrorKind::InvalidData,
+ "invalid hex data",
+ )),
+ Ok(txid_data) => match encode::deserialize(&txid_data) {
+ Err(_) => Err(std::io::Error::new(
+ std::io::ErrorKind::InvalidData,
+ "invalid txid",
+ )),
+ Ok(txid) => Ok(txid),
+ },
+ },
+ }
+ }
+}
+
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use bitcoin::blockdata::constants::genesis_block;
use bitcoin::consensus::encode;
+ use bitcoin::hashes::Hash;
use bitcoin::network::constants::Network;
/// Converts from `BlockHeaderData` into a `GetHeaderResponse` JSON value.
},
}
}
+
+ #[test]
+ fn into_txid_from_json_response_with_unexpected_type() {
+ let response = JsonResponse(serde_json::json!({ "result": "foo" }));
+ match TryInto::<Txid>::try_into(response) {
+ Err(e) => {
+ assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
+ assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON string");
+ }
+ Ok(_) => panic!("Expected error"),
+ }
+ }
+
+ #[test]
+ fn into_txid_from_json_response_with_invalid_hex_data() {
+ let response = JsonResponse(serde_json::json!("foobar"));
+ match TryInto::<Txid>::try_into(response) {
+ Err(e) => {
+ assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
+ assert_eq!(e.get_ref().unwrap().to_string(), "invalid hex data");
+ }
+ Ok(_) => panic!("Expected error"),
+ }
+ }
+
+ #[test]
+ fn into_txid_from_json_response_with_invalid_txid_data() {
+ let response = JsonResponse(serde_json::json!("abcd"));
+ match TryInto::<Txid>::try_into(response) {
+ Err(e) => {
+ assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
+ assert_eq!(e.get_ref().unwrap().to_string(), "invalid txid");
+ }
+ Ok(_) => panic!("Expected error"),
+ }
+ }
+
+ #[test]
+ fn into_txid_from_json_response_with_valid_txid_data() {
+ let target_txid = Txid::from_slice(&[1; 32]).unwrap();
+ let response = JsonResponse(serde_json::json!(encode::serialize_hex(&target_txid)));
+ match TryInto::<Txid>::try_into(response) {
+ Err(e) => panic!("Unexpected error: {:?}", e),
+ Ok(txid) => assert_eq!(txid, target_txid),
+ }
+ }
}
use lightning::ln::peer_handler;
use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
+use lightning::ln::peer_handler::CustomMessageHandler;
use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
use lightning::util::logger::Logger;
id: u64,
}
impl Connection {
- async fn schedule_read<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
+ async fn schedule_read<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
- L: Logger + 'static + ?Sized {
+ L: Logger + 'static + ?Sized,
+ UMH: CustomMessageHandler + 'static {
// 8KB is nice and big but also should never cause any issues with stack overflowing.
let mut buf = [0; 8192];
/// The returned future will complete when the peer is disconnected and associated handling
/// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
/// not need to poll the provided future in order to make progress.
-pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static + Send + Sync,
RMH: RoutingMessageHandler + 'static + Send + Sync,
- L: Logger + 'static + ?Sized + Send + Sync {
+ L: Logger + 'static + ?Sized + Send + Sync,
+ UMH: CustomMessageHandler + 'static + Send + Sync {
let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
#[cfg(debug_assertions)]
let last_us = Arc::clone(&us);
/// The returned future will complete when the peer is disconnected and associated handling
/// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
/// not need to poll the provided future in order to make progress.
-pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static + Send + Sync,
RMH: RoutingMessageHandler + 'static + Send + Sync,
- L: Logger + 'static + ?Sized + Send + Sync {
+ L: Logger + 'static + ?Sized + Send + Sync,
+ UMH: CustomMessageHandler + 'static + Send + Sync {
let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
#[cfg(debug_assertions)]
let last_us = Arc::clone(&us);
/// disconnected and associated handling futures are freed, though, because all processing in said
/// futures are spawned with tokio::spawn, you do not need to poll the second future in order to
/// make progress.
-pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> where
+pub async fn connect_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> where
CMH: ChannelMessageHandler + 'static + Send + Sync,
RMH: RoutingMessageHandler + 'static + Send + Sync,
- L: Logger + 'static + ?Sized + Send + Sync {
+ L: Logger + 'static + ?Sized + Send + Sync,
+ UMH: CustomMessageHandler + 'static + Send + Sync {
if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
Some(setup_outbound(peer_manager, their_node_id, stream))
} else { None }
let a_manager = Arc::new(PeerManager::new(MessageHandler {
chan_handler: Arc::clone(&a_handler),
route_handler: Arc::clone(&a_handler),
- }, a_key.clone(), &[1; 32], Arc::new(TestLogger())));
+ }, a_key.clone(), &[1; 32], Arc::new(TestLogger()), Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{})));
let (b_connected_sender, mut b_connected) = mpsc::channel(1);
let (b_disconnected_sender, mut b_disconnected) = mpsc::channel(1);
let b_manager = Arc::new(PeerManager::new(MessageHandler {
chan_handler: Arc::clone(&b_handler),
route_handler: Arc::clone(&b_handler),
- }, b_key.clone(), &[2; 32], Arc::new(TestLogger())));
+ }, b_key.clone(), &[2; 32], Arc::new(TestLogger()), Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler{})));
// We bind on localhost, hoping the environment is properly configured with a local
// address. This may not always be the case in containers and the like, so if this test is
/// Note that the commitment number starts at (1 << 48) - 1 and counts backwards.
// TODO: return a Result so we can signal a validation error
fn release_commitment_secret(&self, idx: u64) -> [u8; 32];
+ /// Validate the counterparty's signatures on the holder commitment transaction and HTLCs.
+ ///
+ /// This is required in order for the signer to make sure that releasing a commitment
+ /// secret won't leave us without a broadcastable holder transaction.
+ /// Policy checks should be implemented in this function, including checking the amount
+ /// sent to us and checking the HTLCs.
+ fn validate_holder_commitment(&self, holder_tx: &HolderCommitmentTransaction) -> Result<(), ()>;
/// Gets the holder's channel public keys and basepoints
fn pubkeys(&self) -> &ChannelPublicKeys;
/// Gets an arbitrary identifier describing the set of keys which are provided back to you in
/// Create a signature for a counterparty's commitment transaction and associated HTLC transactions.
///
/// Note that if signing fails or is rejected, the channel will be force-closed.
+ ///
+ /// Policy checks should be implemented in this function, including checking the amount
+ /// sent to us and checking the HTLCs.
//
// TODO: Document the things someone using this interface should enforce before signing.
fn sign_counterparty_commitment(&self, commitment_tx: &CommitmentTransaction, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<(Signature, Vec<Signature>), ()>;
+ /// Validate the counterparty's revocation.
+ ///
+ /// This is required in order for the signer to make sure that the state has moved
+ /// forward and it is safe to sign the next counterparty commitment.
+ fn validate_counterparty_revocation(&self, idx: u64, secret: &SecretKey) -> Result<(), ()>;
/// Create a signatures for a holder's commitment transaction and its claiming HTLC transactions.
/// This will only ever be called with a non-revoked commitment_tx. This will be called with the
chan_utils::build_commitment_secret(&self.commitment_seed, idx)
}
+ fn validate_holder_commitment(&self, _holder_tx: &HolderCommitmentTransaction) -> Result<(), ()> {
+ Ok(())
+ }
+
fn pubkeys(&self) -> &ChannelPublicKeys { &self.holder_channel_pubkeys }
fn channel_keys_id(&self) -> [u8; 32] { self.channel_keys_id }
Ok((commitment_sig, htlc_sigs))
}
+ fn validate_counterparty_revocation(&self, _idx: u64, _secret: &SecretKey) -> Result<(), ()> {
+ Ok(())
+ }
+
fn sign_holder_commitment_and_htlcs(&self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<(Signature, Vec<Signature>), ()> {
let funding_pubkey = PublicKey::from_secret_key(secp_ctx, &self.funding_key);
let funding_redeemscript = make_funding_redeemscript(&funding_pubkey, &self.counterparty_pubkeys().funding_pubkey);
self.counterparty_funding_pubkey()
);
+ self.holder_signer.validate_holder_commitment(&holder_commitment_tx)
+ .map_err(|_| ChannelError::Close("Failed to validate our commitment".to_owned()))?;
+
// Now that we're past error-generating stuff, update our local state:
let funding_redeemscript = self.get_funding_redeemscript();
self.counterparty_funding_pubkey()
);
+ self.holder_signer.validate_holder_commitment(&holder_commitment_tx)
+ .map_err(|_| ChannelError::Close("Failed to validate our commitment".to_owned()))?;
+
let funding_redeemscript = self.get_funding_redeemscript();
let funding_txo = self.get_funding_txo().unwrap();
);
let next_per_commitment_point = self.holder_signer.get_per_commitment_point(self.cur_holder_commitment_transaction_number - 1, &self.secp_ctx);
+ self.holder_signer.validate_holder_commitment(&holder_commitment_tx)
+ .map_err(|_| (None, ChannelError::Close("Failed to validate our commitment".to_owned())))?;
let per_commitment_secret = self.holder_signer.release_commitment_secret(self.cur_holder_commitment_transaction_number + 1);
// Update state now that we've passed all the can-fail calls...
return Err(ChannelError::Close("Peer sent revoke_and_ack after we'd started exchanging closing_signeds".to_owned()));
}
+ let secret = secp_check!(SecretKey::from_slice(&msg.per_commitment_secret), "Peer provided an invalid per_commitment_secret".to_owned());
+
if let Some(counterparty_prev_commitment_point) = self.counterparty_prev_commitment_point {
- if PublicKey::from_secret_key(&self.secp_ctx, &secp_check!(SecretKey::from_slice(&msg.per_commitment_secret), "Peer provided an invalid per_commitment_secret".to_owned())) != counterparty_prev_commitment_point {
+ if PublicKey::from_secret_key(&self.secp_ctx, &secret) != counterparty_prev_commitment_point {
return Err(ChannelError::Close("Got a revoke commitment secret which didn't correspond to their current pubkey".to_owned()));
}
}
*self.next_remote_commitment_tx_fee_info_cached.lock().unwrap() = None;
}
+ self.holder_signer.validate_counterparty_revocation(
+ self.cur_counterparty_commitment_transaction_number + 1,
+ &secret
+ ).map_err(|_| ChannelError::Close("Failed to validate revocation from peer".to_owned()))?;
+
self.commitment_secrets.provide_secret(self.cur_counterparty_commitment_transaction_number + 1, msg.per_commitment_secret)
.map_err(|_| ChannelError::Close("Previous secrets did not match new one".to_owned()))?;
self.latest_monitor_update_id += 1;
txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone()),
blocks: Arc::new(Mutex::new(self.tx_broadcaster.blocks.lock().unwrap().clone())),
},
- logger: &test_utils::TestLogger::new(),
+ logger: &self.logger,
channel_monitors,
}).unwrap();
}
pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
- let logger = test_utils::TestLogger::new();
let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(),
&expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()),
Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()), &[],
- recv_value, TEST_FINAL_CLTV, &logger).unwrap();
+ recv_value, TEST_FINAL_CLTV, origin_node.logger).unwrap();
assert_eq!(route.paths.len(), 1);
assert_eq!(route.paths[0].len(), expected_route.len());
for (node, hop) in expected_route.iter().zip(route.paths[0].iter()) {
}
pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) {
- let logger = test_utils::TestLogger::new();
let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
- let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), recv_value, TEST_FINAL_CLTV, &logger).unwrap();
+ let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), recv_value, TEST_FINAL_CLTV, origin_node.logger).unwrap();
assert_eq!(route.paths.len(), 1);
assert_eq!(route.paths[0].len(), expected_route.len());
for (node, hop) in expected_route.iter().zip(route.paths[0].iter()) {
let chan_lock = nodes[0].node.channel_state.lock().unwrap();
let local_chan = chan_lock.by_id.get(&chan.2).unwrap();
let chan_signer = local_chan.get_signer();
+ // Make the signer believe we validated another commitment, so we can release the secret
+ chan_signer.get_enforcement_state().last_holder_commitment -= 1;
+
let pubkeys = chan_signer.pubkeys();
(pubkeys.revocation_basepoint, pubkeys.htlc_basepoint,
chan_signer.release_commitment_secret(INITIAL_COMMITMENT_NUMBER),
// commitment transaction, we would have happily carried on and provided them the next
// commitment transaction based on one RAA forward. This would probably eventually have led to
// channel closure, but it would not have resulted in funds loss. Still, our
- // EnforcingSigner would have paniced as it doesn't like jumps into the future. Here, we
+ // EnforcingSigner would have panicked as it doesn't like jumps into the future. Here, we
// check simply that the channel is closed in response to such an RAA, but don't check whether
// we decide to punish our counterparty for revoking their funds (as we don't currently
// implement that).
let channel_id = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()).2;
let mut guard = nodes[0].node.channel_state.lock().unwrap();
- let keys = &guard.by_id.get_mut(&channel_id).unwrap().get_signer();
+ let keys = guard.by_id.get_mut(&channel_id).unwrap().get_signer();
+
const INITIAL_COMMITMENT_NUMBER: u64 = (1 << 48) - 1;
+
+ // Make signer believe we got a counterparty signature, so that it allows the revocation
+ keys.get_enforcement_state().last_holder_commitment -= 1;
let per_commitment_secret = keys.release_commitment_secret(INITIAL_COMMITMENT_NUMBER);
+
// Must revoke without gaps
+ keys.get_enforcement_state().last_holder_commitment -= 1;
keys.release_commitment_secret(INITIAL_COMMITMENT_NUMBER - 1);
+
+ keys.get_enforcement_state().last_holder_commitment -= 1;
let next_per_commitment_point = PublicKey::from_secret_key(&Secp256k1::new(),
&SecretKey::from_slice(&keys.release_commitment_secret(INITIAL_COMMITMENT_NUMBER - 2)).unwrap());
mod channel;
mod onion_utils;
-mod wire;
+pub mod wire;
// Older rustc (which we support) refuses to let us call the get_payment_preimage_hash!() macro
// without the node parameter being mut. This is incorrect, and thus newer rustcs will complain
pub funding_txid: Txid,
/// The specific output index funding this channel
pub funding_output_index: u16,
- /// The signature of the channel initiator (funder) on the funding transaction
+ /// The signature of the channel initiator (funder) on the initial commitment transaction
pub signature: Signature,
}
pub struct FundingSigned {
/// The channel ID
pub channel_id: [u8; 32],
- /// The signature of the channel acceptor (fundee) on the funding transaction
+ /// The signature of the channel acceptor (fundee) on the initial commitment transaction
pub signature: Signature,
}
use ln::msgs;
use ln::msgs::{ChannelMessageHandler, LightningError, RoutingMessageHandler};
use ln::channelmanager::{SimpleArcChannelManager, SimpleRefChannelManager};
-use util::ser::{VecWriter, Writeable};
+use util::ser::{VecWriter, Writeable, Writer};
use ln::peer_channel_encryptor::{PeerChannelEncryptor,NextNoiseStep};
use ln::wire;
-use ln::wire::Encode;
use util::byte_utils;
use util::events::{MessageSendEvent, MessageSendEventsProvider};
use util::logger::Logger;
use bitcoin::hashes::sha256::HashEngine as Sha256Engine;
use bitcoin::hashes::{HashEngine, Hash};
+/// Handler for BOLT1-compliant messages.
+pub trait CustomMessageHandler: wire::CustomMessageReader {
+ /// Called with the message type that was received and the buffer to be read.
+ /// Can return a `MessageHandlingError` if the message could not be handled.
+ fn handle_custom_message(&self, msg: Self::CustomMessage) -> Result<(), LightningError>;
+
+ /// Gets the list of pending messages which were generated by the custom message
+ /// handler, clearing the list in the process. The first tuple element must
+ /// correspond to the intended recipients node ids. If no connection to one of the
+ /// specified node does not exist, the message is simply not sent to it.
+ fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)>;
+}
+
/// A dummy struct which implements `RoutingMessageHandler` without storing any routing information
/// or doing any processing. You can provide one of these as the route_handler in a MessageHandler.
pub struct IgnoringMessageHandler{}
fn deref(&self) -> &Self { self }
}
+impl wire::Type for () {
+ fn type_id(&self) -> u16 {
+ // We should never call this for `DummyCustomType`
+ unreachable!();
+ }
+}
+
+impl Writeable for () {
+ fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
+ unreachable!();
+ }
+}
+
+impl wire::CustomMessageReader for IgnoringMessageHandler {
+ type CustomMessage = ();
+ fn read<R: io::Read>(&self, _message_type: u16, _buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
+ Ok(None)
+ }
+}
+
+impl CustomMessageHandler for IgnoringMessageHandler {
+ fn handle_custom_message(&self, _msg: Self::CustomMessage) -> Result<(), LightningError> {
+ // Since we always return `None` in the read the handle method should never be called.
+ unreachable!();
+ }
+
+ fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }
+}
+
/// A dummy struct which implements `ChannelMessageHandler` without having any channels.
/// You can provide one of these as the route_handler in a MessageHandler.
pub struct ErroringMessageHandler {
/// lifetimes). Other times you can afford a reference, which is more efficient, in which case
/// SimpleRefPeerManager is the more appropriate type. Defining these type aliases prevents
/// issues such as overly long function definitions.
-pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArcChannelManager<M, T, F, L>>, Arc<NetGraphMsgHandler<Arc<C>, Arc<L>>>, Arc<L>>;
+pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArcChannelManager<M, T, F, L>>, Arc<NetGraphMsgHandler<Arc<C>, Arc<L>>>, Arc<L>, Arc<IgnoringMessageHandler>>;
/// SimpleRefPeerManager is a type alias for a PeerManager reference, and is the reference
/// counterpart to the SimpleArcPeerManager type alias. Use this type by default when you don't
/// usage of lightning-net-tokio (since tokio::spawn requires parameters with static lifetimes).
/// But if this is not necessary, using a reference is more efficient. Defining these type aliases
/// helps with issues such as long function definitions.
-pub type SimpleRefPeerManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, SD, M, T, F, C, L> = PeerManager<SD, SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L>, &'e NetGraphMsgHandler<&'g C, &'f L>, &'f L>;
+pub type SimpleRefPeerManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, SD, M, T, F, C, L> = PeerManager<SD, SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L>, &'e NetGraphMsgHandler<&'g C, &'f L>, &'f L, IgnoringMessageHandler>;
/// A PeerManager manages a set of peers, described by their [`SocketDescriptor`] and marshalls
/// socket events into messages which it passes on to its [`MessageHandler`].
/// you're using lightning-net-tokio.
///
/// [`read_event`]: PeerManager::read_event
-pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> where
+pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> where
CM::Target: ChannelMessageHandler,
RM::Target: RoutingMessageHandler,
- L::Target: Logger {
+ L::Target: Logger,
+ CMH::Target: CustomMessageHandler {
message_handler: MessageHandler<CM, RM>,
peers: Mutex<PeerHolder<Descriptor>>,
our_node_secret: SecretKey,
ephemeral_key_midstate: Sha256Engine,
+ custom_message_handler: CMH,
// Usize needs to be at least 32 bits to avoid overflowing both low and high. If usize is 64
// bits we will never realistically count into high:
}}
}
-impl<Descriptor: SocketDescriptor, CM: Deref, L: Deref> PeerManager<Descriptor, CM, IgnoringMessageHandler, L> where
+impl<Descriptor: SocketDescriptor, CM: Deref, L: Deref> PeerManager<Descriptor, CM, IgnoringMessageHandler, L, IgnoringMessageHandler> where
CM::Target: ChannelMessageHandler,
L::Target: Logger {
/// Constructs a new PeerManager with the given ChannelMessageHandler. No routing message
Self::new(MessageHandler {
chan_handler: channel_message_handler,
route_handler: IgnoringMessageHandler{},
- }, our_node_secret, ephemeral_random_data, logger)
+ }, our_node_secret, ephemeral_random_data, logger, IgnoringMessageHandler{})
}
}
-impl<Descriptor: SocketDescriptor, RM: Deref, L: Deref> PeerManager<Descriptor, ErroringMessageHandler, RM, L> where
+impl<Descriptor: SocketDescriptor, RM: Deref, L: Deref> PeerManager<Descriptor, ErroringMessageHandler, RM, L, IgnoringMessageHandler> where
RM::Target: RoutingMessageHandler,
L::Target: Logger {
/// Constructs a new PeerManager with the given RoutingMessageHandler. No channel message
Self::new(MessageHandler {
chan_handler: ErroringMessageHandler::new(),
route_handler: routing_message_handler,
- }, our_node_secret, ephemeral_random_data, logger)
+ }, our_node_secret, ephemeral_random_data, logger, IgnoringMessageHandler{})
}
}
-impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<Descriptor, CM, RM, L> where
+impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> PeerManager<Descriptor, CM, RM, L, CMH> where
CM::Target: ChannelMessageHandler,
RM::Target: RoutingMessageHandler,
- L::Target: Logger {
+ L::Target: Logger,
+ CMH::Target: CustomMessageHandler + wire::CustomMessageReader {
/// Constructs a new PeerManager with the given message handlers and node_id secret key
/// ephemeral_random_data is used to derive per-connection ephemeral keys and must be
/// cryptographically secure random bytes.
- pub fn new(message_handler: MessageHandler<CM, RM>, our_node_secret: SecretKey, ephemeral_random_data: &[u8; 32], logger: L) -> Self {
+ pub fn new(message_handler: MessageHandler<CM, RM>, our_node_secret: SecretKey, ephemeral_random_data: &[u8; 32], logger: L, custom_message_handler: CMH) -> Self {
let mut ephemeral_key_midstate = Sha256::engine();
ephemeral_key_midstate.input(ephemeral_random_data);
peer_counter_low: AtomicUsize::new(0),
peer_counter_high: AtomicUsize::new(0),
logger,
+ custom_message_handler,
}
}
}
/// Append a message to a peer's pending outbound/write buffer, and update the map of peers needing sends accordingly.
- fn enqueue_message<M: Encode + Writeable + Debug>(&self, peer: &mut Peer, message: &M) {
+ fn enqueue_message<M: wire::Type + Writeable + Debug>(&self, peer: &mut Peer, message: &M) {
let mut buffer = VecWriter(Vec::new());
wire::write(message, &mut buffer).unwrap(); // crash if the write failed
let encoded_message = buffer.0;
peer.pending_read_is_header = true;
let mut reader = io::Cursor::new(&msg_data[..]);
- let message_result = wire::read(&mut reader);
+ let message_result = wire::read(&mut reader, &*self.custom_message_handler);
let message = match message_result {
Ok(x) => x,
Err(e) => {
/// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
/// Returns the message back if it needs to be broadcasted to all other peers.
- fn handle_message(&self, peer: &mut Peer, message: wire::Message) -> Result<Option<wire::Message>, MessageHandlingError> {
+ fn handle_message(
+ &self,
+ peer: &mut Peer,
+ message: wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>
+ ) -> Result<Option<wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> {
log_trace!(self.logger, "Received message {:?} from {}", message, log_pubkey!(peer.their_node_id.unwrap()));
// Need an Init as first message
},
// Unknown messages:
- wire::Message::Unknown(msg_type) if msg_type.is_even() => {
- log_debug!(self.logger, "Received unknown even message of type {}, disconnecting peer!", msg_type);
+ wire::Message::Unknown(type_id) if message.is_even() => {
+ log_debug!(self.logger, "Received unknown even message of type {}, disconnecting peer!", type_id);
// Fail the channel if message is an even, unknown type as per BOLT #1.
return Err(PeerHandleError{ no_connection_possible: true }.into());
},
- wire::Message::Unknown(msg_type) => {
- log_trace!(self.logger, "Received unknown odd message of type {}, ignoring", msg_type);
- }
+ wire::Message::Unknown(type_id) => {
+ log_trace!(self.logger, "Received unknown odd message of type {}, ignoring", type_id);
+ },
+ wire::Message::Custom(custom) => {
+ self.custom_message_handler.handle_custom_message(custom)?;
+ },
};
Ok(should_forward)
}
- fn forward_broadcast_msg(&self, peers: &mut PeerHolder<Descriptor>, msg: &wire::Message, except_node: Option<&PublicKey>) {
+ fn forward_broadcast_msg(&self, peers: &mut PeerHolder<Descriptor>, msg: &wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>, except_node: Option<&PublicKey>) {
match msg {
wire::Message::ChannelAnnouncement(ref msg) => {
log_trace!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg);
let mut events_generated = self.message_handler.chan_handler.get_and_clear_pending_msg_events();
events_generated.append(&mut self.message_handler.route_handler.get_and_clear_pending_msg_events());
let peers = &mut *peers_lock;
- for event in events_generated.drain(..) {
- macro_rules! get_peer_for_forwarding {
- ($node_id: expr) => {
- {
- match peers.node_id_to_descriptor.get($node_id) {
- Some(descriptor) => match peers.peers.get_mut(&descriptor) {
- Some(peer) => {
- if peer.their_features.is_none() {
- continue;
- }
- peer
- },
- None => panic!("Inconsistent peers set state!"),
- },
- None => {
- continue;
+ macro_rules! get_peer_for_forwarding {
+ ($node_id: expr) => {
+ {
+ match peers.node_id_to_descriptor.get($node_id) {
+ Some(descriptor) => match peers.peers.get_mut(&descriptor) {
+ Some(peer) => {
+ if peer.their_features.is_none() {
+ continue;
+ }
+ peer
},
- }
+ None => panic!("Inconsistent peers set state!"),
+ },
+ None => {
+ continue;
+ },
}
}
}
+ }
+ for event in events_generated.drain(..) {
match event {
MessageSendEvent::SendAcceptChannel { ref node_id, ref msg } => {
log_debug!(self.logger, "Handling SendAcceptChannel event in peer_handler for node {} for channel {}",
}
}
+ for (node_id, msg) in self.custom_message_handler.get_and_clear_pending_msg() {
+ self.enqueue_message(get_peer_for_forwarding!(&node_id), &msg);
+ }
+
for (descriptor, peer) in peers.peers.iter_mut() {
self.do_attempt_write_data(&mut (*descriptor).clone(), peer);
}
#[cfg(test)]
mod tests {
- use ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor};
+ use ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
use ln::msgs;
use util::events;
use util::test_utils;
cfgs
}
- fn create_network<'a>(peer_count: usize, cfgs: &'a Vec<PeerManagerCfg>) -> Vec<PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger>> {
+ fn create_network<'a>(peer_count: usize, cfgs: &'a Vec<PeerManagerCfg>) -> Vec<PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler>> {
let mut peers = Vec::new();
for i in 0..peer_count {
let node_secret = SecretKey::from_slice(&[42 + i as u8; 32]).unwrap();
let ephemeral_bytes = [i as u8; 32];
let msg_handler = MessageHandler { chan_handler: &cfgs[i].chan_handler, route_handler: &cfgs[i].routing_handler };
- let peer = PeerManager::new(msg_handler, node_secret, &ephemeral_bytes, &cfgs[i].logger);
+ let peer = PeerManager::new(msg_handler, node_secret, &ephemeral_bytes, &cfgs[i].logger, IgnoringMessageHandler {});
peers.push(peer);
}
peers
}
- fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger>) -> (FileDescriptor, FileDescriptor) {
+ fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler, &'a test_utils::TestRoutingMessageHandler, &'a test_utils::TestLogger, IgnoringMessageHandler>) -> (FileDescriptor, FileDescriptor) {
let secp_ctx = Secp256k1::new();
let a_id = PublicKey::from_secret_key(&secp_ctx, &peer_a.our_node_secret);
let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
// You may not use this file except in accordance with one or both of these
// licenses.
-//! Wire encoding/decoding for Lightning messages according to [BOLT #1].
-//!
-//! Messages known by this module can be read from the wire using [`read()`].
-//! The [`Message`] enum returned by [`read()`] wraps the decoded message or the message type (if
-//! unknown) to use with pattern matching.
-//!
-//! Messages implementing the [`Encode`] trait define a message type and can be sent over the wire
-//! using [`write()`].
-//!
+//! Wire encoding/decoding for Lightning messages according to [BOLT #1], and for
+//! custom message through the [`CustomMessageReader`] trait.
+//!
//! [BOLT #1]: https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md
use io;
use ln::msgs;
use util::ser::{Readable, Writeable, Writer};
+/// Trait to be implemented by custom message (unrelated to the channel/gossip LN layers)
+/// decoders.
+pub trait CustomMessageReader {
+ /// The type of the message decoded by the implementation.
+ type CustomMessage: core::fmt::Debug + Type + Writeable;
+ /// Decodes a custom message to `CustomMessageType`. If the given message type is known to the
+ /// implementation and the message could be decoded, must return `Ok(Some(message))`. If the
+ /// message type is unknown to the implementation, must return `Ok(None)`. If a decoding error
+ /// occur, must return `Err(DecodeError::X)` where `X` details the encountered error.
+ fn read<R: io::Read>(&self, message_type: u16, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
+}
+
/// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each
/// variant contains a message from [`msgs`] or otherwise the message type if unknown.
#[allow(missing_docs)]
#[derive(Debug)]
-pub enum Message {
+pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
Init(msgs::Init),
Error(msgs::ErrorMessage),
Ping(msgs::Ping),
ReplyChannelRange(msgs::ReplyChannelRange),
GossipTimestampFilter(msgs::GossipTimestampFilter),
/// A message that could not be decoded because its type is unknown.
- Unknown(MessageType),
+ Unknown(u16),
+ /// A message that was produced by a [`CustomMessageReader`] and is to be handled by a
+ /// [`::ln::peer_handler::CustomMessageHandler`].
+ Custom(T),
}
-/// A number identifying a message to determine how it is encoded on the wire.
-#[derive(Clone, Copy, Debug)]
-pub struct MessageType(u16);
-
-impl Message {
- #[allow(dead_code)] // This method is only used in tests
+impl<T> Message<T> where T: core::fmt::Debug + Type {
/// Returns the type that was used to decode the message payload.
- pub fn type_id(&self) -> MessageType {
+ pub fn type_id(&self) -> u16 {
match self {
&Message::Init(ref msg) => msg.type_id(),
&Message::Error(ref msg) => msg.type_id(),
&Message::ReplyChannelRange(ref msg) => msg.type_id(),
&Message::GossipTimestampFilter(ref msg) => msg.type_id(),
&Message::Unknown(type_id) => type_id,
+ &Message::Custom(ref msg) => msg.type_id(),
}
}
-}
-impl MessageType {
- /// Returns whether the message type is even, indicating both endpoints must support it.
+ /// Returns whether the message's type is even, indicating both endpoints must support it.
pub fn is_even(&self) -> bool {
- (self.0 & 1) == 0
- }
-}
-
-impl ::core::fmt::Display for MessageType {
- fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
- write!(f, "{}", self.0)
+ (self.type_id() & 1) == 0
}
}
/// # Errors
///
/// Returns an error if the message payload code not be decoded as the specified type.
-pub fn read<R: io::Read>(buffer: &mut R) -> Result<Message, msgs::DecodeError> {
+pub(crate) fn read<R: io::Read, T, H: core::ops::Deref>(
+ buffer: &mut R,
+ custom_reader: H,
+) -> Result<Message<T>, msgs::DecodeError>
+where
+ T: core::fmt::Debug + Type + Writeable,
+ H::Target: CustomMessageReader<CustomMessage = T>,
+{
let message_type = <u16 as Readable>::read(buffer)?;
match message_type {
msgs::Init::TYPE => {
Ok(Message::GossipTimestampFilter(Readable::read(buffer)?))
},
_ => {
- Ok(Message::Unknown(MessageType(message_type)))
+ if let Some(custom) = custom_reader.read(message_type, buffer)? {
+ Ok(Message::Custom(custom))
+ } else {
+ Ok(Message::Unknown(message_type))
+ }
},
}
}
/// # Errors
///
/// Returns an I/O error if the write could not be completed.
-pub fn write<M: Encode + Writeable, W: Writer>(message: &M, buffer: &mut W) -> Result<(), io::Error> {
- M::TYPE.write(buffer)?;
+pub(crate) fn write<M: Type + Writeable, W: Writer>(message: &M, buffer: &mut W) -> Result<(), io::Error> {
+ message.type_id().write(buffer)?;
message.write(buffer)
}
-/// Defines a type-identified encoding for sending messages over the wire.
+mod encode {
+ /// Defines a constant type identifier for reading messages from the wire.
+ pub trait Encode {
+ /// The type identifying the message payload.
+ const TYPE: u16;
+ }
+}
+
+pub(crate) use self::encode::Encode;
+
+/// Defines a type identifier for sending messages over the wire.
///
-/// Messages implementing this trait specify a type and must be [`Writeable`] to use with [`write()`].
-pub trait Encode {
- /// The type identifying the message payload.
- const TYPE: u16;
-
- /// Returns the type identifying the message payload. Convenience method for accessing
- /// [`Self::TYPE`].
- fn type_id(&self) -> MessageType {
- MessageType(Self::TYPE)
+/// Messages implementing this trait specify a type and must be [`Writeable`].
+pub trait Type {
+ /// Returns the type identifying the message payload.
+ fn type_id(&self) -> u16;
+}
+
+impl<T> Type for T where T: Encode {
+ fn type_id(&self) -> u16 {
+ T::TYPE
}
}
use super::*;
use prelude::*;
use core::convert::TryInto;
+ use ::ln::peer_handler::IgnoringMessageHandler;
// Big-endian wire encoding of Pong message (type = 19, byteslen = 2).
const ENCODED_PONG: [u8; 6] = [0u8, 19u8, 0u8, 2u8, 0u8, 0u8];
fn read_empty_buffer() {
let buffer = [];
let mut reader = io::Cursor::new(buffer);
- assert!(read(&mut reader).is_err());
+ assert!(read(&mut reader, &IgnoringMessageHandler{}).is_err());
}
#[test]
fn read_incomplete_type() {
let buffer = &ENCODED_PONG[..1];
let mut reader = io::Cursor::new(buffer);
- assert!(read(&mut reader).is_err());
+ assert!(read(&mut reader, &IgnoringMessageHandler{}).is_err());
}
#[test]
fn read_empty_payload() {
let buffer = &ENCODED_PONG[..2];
let mut reader = io::Cursor::new(buffer);
- assert!(read(&mut reader).is_err());
+ assert!(read(&mut reader, &IgnoringMessageHandler{}).is_err());
}
#[test]
fn read_invalid_message() {
let buffer = &ENCODED_PONG[..4];
let mut reader = io::Cursor::new(buffer);
- assert!(read(&mut reader).is_err());
+ assert!(read(&mut reader, &IgnoringMessageHandler{}).is_err());
}
#[test]
fn read_known_message() {
let buffer = &ENCODED_PONG[..];
let mut reader = io::Cursor::new(buffer);
- let message = read(&mut reader).unwrap();
+ let message = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match message {
Message::Pong(_) => (),
_ => panic!("Expected pong message; found message type: {}", message.type_id()),
fn read_unknown_message() {
let buffer = &::core::u16::MAX.to_be_bytes();
let mut reader = io::Cursor::new(buffer);
- let message = read(&mut reader).unwrap();
+ let message = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match message {
- Message::Unknown(MessageType(::core::u16::MAX)) => (),
+ Message::Unknown(::core::u16::MAX) => (),
_ => panic!("Expected message type {}; found: {}", ::core::u16::MAX, message.type_id()),
}
}
assert!(write(&message, &mut buffer).is_ok());
let mut reader = io::Cursor::new(buffer);
- let decoded_message = read(&mut reader).unwrap();
+ let decoded_message = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match decoded_message {
Message::Pong(msgs::Pong { byteslen: 2u16 }) => (),
Message::Pong(msgs::Pong { byteslen }) => {
#[test]
fn is_even_message_type() {
- let message = Message::Unknown(MessageType(42));
- assert!(message.type_id().is_even());
+ let message = Message::<()>::Unknown(42);
+ assert!(message.is_even());
}
#[test]
fn is_odd_message_type() {
- let message = Message::Unknown(MessageType(43));
- assert!(!message.type_id().is_even());
+ let message = Message::<()>::Unknown(43);
+ assert!(!message.is_even());
}
#[test]
fn check_init_msg(buffer: Vec<u8>, expect_unknown: bool) {
let mut reader = io::Cursor::new(buffer);
- let decoded_msg = read(&mut reader).unwrap();
+ let decoded_msg = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match decoded_msg {
Message::Init(msgs::Init { features }) => {
assert!(features.supports_variable_length_onion());
// Taken from lnd v0.9.0-beta.
let buffer = vec![1, 1, 91, 164, 146, 213, 213, 165, 21, 227, 102, 33, 105, 179, 214, 21, 221, 175, 228, 93, 57, 177, 191, 127, 107, 229, 31, 50, 21, 81, 179, 71, 39, 18, 35, 2, 89, 224, 110, 123, 66, 39, 148, 246, 177, 85, 12, 19, 70, 226, 173, 132, 156, 26, 122, 146, 71, 213, 247, 48, 93, 190, 185, 177, 12, 172, 0, 3, 2, 162, 161, 94, 103, 195, 37, 2, 37, 242, 97, 140, 2, 111, 69, 85, 39, 118, 30, 221, 99, 254, 120, 49, 103, 22, 170, 227, 111, 172, 164, 160, 49, 68, 138, 116, 16, 22, 206, 107, 51, 153, 255, 97, 108, 105, 99, 101, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 1, 172, 21, 0, 2, 38, 7];
let mut reader = io::Cursor::new(buffer);
- let decoded_msg = read(&mut reader).unwrap();
+ let decoded_msg = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match decoded_msg {
Message::NodeAnnouncement(msgs::NodeAnnouncement { contents: msgs::UnsignedNodeAnnouncement { features, ..}, ..}) => {
assert!(features.supports_variable_length_onion());
// Taken from lnd v0.9.0-beta.
let buffer = vec![1, 0, 82, 238, 153, 33, 128, 87, 215, 2, 28, 241, 140, 250, 98, 255, 56, 5, 79, 240, 214, 231, 172, 35, 240, 171, 44, 9, 78, 91, 8, 193, 102, 5, 17, 178, 142, 106, 180, 183, 46, 38, 217, 212, 25, 236, 69, 47, 92, 217, 181, 221, 161, 205, 121, 201, 99, 38, 158, 216, 186, 193, 230, 86, 222, 6, 206, 67, 22, 255, 137, 212, 141, 161, 62, 134, 76, 48, 241, 54, 50, 167, 187, 247, 73, 27, 74, 1, 129, 185, 197, 153, 38, 90, 255, 138, 39, 161, 102, 172, 213, 74, 107, 88, 150, 90, 0, 49, 104, 7, 182, 184, 194, 219, 181, 172, 8, 245, 65, 226, 19, 228, 101, 145, 25, 159, 52, 31, 58, 93, 53, 59, 218, 91, 37, 84, 103, 17, 74, 133, 33, 35, 2, 203, 101, 73, 19, 94, 175, 122, 46, 224, 47, 168, 128, 128, 25, 26, 25, 214, 52, 247, 43, 241, 117, 52, 206, 94, 135, 156, 52, 164, 143, 234, 58, 185, 50, 185, 140, 198, 174, 71, 65, 18, 105, 70, 131, 172, 137, 0, 164, 51, 215, 143, 117, 119, 217, 241, 197, 177, 227, 227, 170, 199, 114, 7, 218, 12, 107, 30, 191, 236, 203, 21, 61, 242, 48, 192, 90, 233, 200, 199, 111, 162, 68, 234, 54, 219, 1, 233, 66, 5, 82, 74, 84, 211, 95, 199, 245, 202, 89, 223, 102, 124, 62, 166, 253, 253, 90, 180, 118, 21, 61, 110, 37, 5, 96, 167, 0, 0, 6, 34, 110, 70, 17, 26, 11, 89, 202, 175, 18, 96, 67, 235, 91, 191, 40, 195, 79, 58, 94, 51, 42, 31, 199, 178, 183, 60, 241, 136, 145, 15, 0, 2, 65, 0, 0, 1, 0, 0, 2, 37, 242, 97, 140, 2, 111, 69, 85, 39, 118, 30, 221, 99, 254, 120, 49, 103, 22, 170, 227, 111, 172, 164, 160, 49, 68, 138, 116, 16, 22, 206, 107, 3, 54, 61, 144, 88, 171, 247, 136, 208, 99, 9, 135, 37, 201, 178, 253, 136, 0, 185, 235, 68, 160, 106, 110, 12, 46, 21, 125, 204, 18, 75, 234, 16, 3, 42, 171, 28, 52, 224, 11, 30, 30, 253, 156, 148, 175, 203, 121, 250, 111, 122, 195, 84, 122, 77, 183, 56, 135, 101, 88, 41, 60, 191, 99, 232, 85, 2, 36, 17, 156, 11, 8, 12, 189, 177, 68, 88, 28, 15, 207, 21, 179, 151, 56, 226, 158, 148, 3, 120, 113, 177, 243, 184, 17, 173, 37, 46, 222, 16];
let mut reader = io::Cursor::new(buffer);
- let decoded_msg = read(&mut reader).unwrap();
+ let decoded_msg = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
match decoded_msg {
Message::ChannelAnnouncement(msgs::ChannelAnnouncement { contents: msgs::UnsignedChannelAnnouncement { features, ..}, ..}) => {
assert!(!features.requires_unknown_bits());
_ => panic!("Expected node announcement, found message type: {}", decoded_msg.type_id())
}
}
+
+ #[derive(Eq, PartialEq, Debug)]
+ struct TestCustomMessage {}
+
+ const CUSTOM_MESSAGE_TYPE : u16 = 9000;
+
+ impl Type for TestCustomMessage {
+ fn type_id(&self) -> u16 {
+ CUSTOM_MESSAGE_TYPE
+ }
+ }
+
+ impl Writeable for TestCustomMessage {
+ fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
+ Ok(())
+ }
+ }
+
+ struct TestCustomMessageReader {}
+
+ impl CustomMessageReader for TestCustomMessageReader {
+ type CustomMessage = TestCustomMessage;
+ fn read<R: io::Read>(
+ &self,
+ message_type: u16,
+ _: &mut R
+ ) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
+ if message_type == CUSTOM_MESSAGE_TYPE {
+ return Ok(Some(TestCustomMessage{}));
+ }
+
+ Ok(None)
+ }
+ }
+
+ #[test]
+ fn read_custom_message() {
+ let buffer = vec![35, 40];
+ let mut reader = io::Cursor::new(buffer);
+ let decoded_msg = read(&mut reader, &TestCustomMessageReader{}).unwrap();
+ match decoded_msg {
+ Message::Custom(custom) => {
+ assert_eq!(custom.type_id(), CUSTOM_MESSAGE_TYPE);
+ assert_eq!(custom, TestCustomMessage {});
+ },
+ _ => panic!("Expected custom message, found message type: {}", decoded_msg.type_id()),
+ }
+ }
+
+ #[test]
+ fn read_with_custom_reader_unknown_message_type() {
+ let buffer = vec![35, 42];
+ let mut reader = io::Cursor::new(buffer);
+ let decoded_msg = read(&mut reader, &TestCustomMessageReader{}).unwrap();
+ match decoded_msg {
+ Message::Unknown(_) => {},
+ _ => panic!("Expected unknown message, found message type: {}", decoded_msg.type_id()),
+ }
+ }
+
+ #[test]
+ fn custom_reader_unknown_message_type() {
+ let buffer = Vec::new();
+ let mut reader = io::Cursor::new(buffer);
+ let res = TestCustomMessageReader{}.read(CUSTOM_MESSAGE_TYPE + 1, &mut reader).unwrap();
+ assert!(res.is_none());
+ }
}
return Err(LightningError{err: "Cannot send a payment of 0 msat".to_owned(), action: ErrorAction::IgnoreError});
}
- let last_hops = last_hops.iter().filter_map(|hops| hops.0.last()).collect::<Vec<_>>();
- for last_hop in last_hops.iter() {
- if last_hop.src_node_id == *payee {
- return Err(LightningError{err: "Last hop cannot have a payee as a source.".to_owned(), action: ErrorAction::IgnoreError});
+ for route in last_hops.iter() {
+ for hop in &route.0 {
+ if hop.src_node_id == *payee {
+ return Err(LightningError{err: "Last hop cannot have a payee as a source.".to_owned(), action: ErrorAction::IgnoreError});
+ }
}
}
// If a caller provided us with last hops, add them to routing targets. Since this happens
// earlier than general path finding, they will be somewhat prioritized, although currently
// it matters only if the fees are exactly the same.
- for hop in last_hops.iter() {
+ for route in last_hops.iter().filter(|route| !route.0.is_empty()) {
+ let first_hop_in_route = &(route.0)[0];
let have_hop_src_in_graph =
- // Only add the last hop to our candidate set if either we have a direct channel or
- // they are in the regular network graph.
- first_hop_targets.get(&hop.src_node_id).is_some() ||
- network.get_nodes().get(&hop.src_node_id).is_some();
+ // Only add the hops in this route to our candidate set if either
+ // we have a direct channel to the first hop or the first hop is
+ // in the regular network graph.
+ first_hop_targets.get(&first_hop_in_route.src_node_id).is_some() ||
+ network.get_nodes().get(&first_hop_in_route.src_node_id).is_some();
if have_hop_src_in_graph {
- // BOLT 11 doesn't allow inclusion of features for the last hop hints, which
- // really sucks, cause we're gonna need that eventually.
- let last_hop_htlc_minimum_msat: u64 = match hop.htlc_minimum_msat {
- Some(htlc_minimum_msat) => htlc_minimum_msat,
- None => 0
- };
- let directional_info = DummyDirectionalChannelInfo {
- cltv_expiry_delta: hop.cltv_expiry_delta as u32,
- htlc_minimum_msat: last_hop_htlc_minimum_msat,
- htlc_maximum_msat: hop.htlc_maximum_msat,
- fees: hop.fees,
- };
- // We assume that the recipient only included route hints for routes which had
- // sufficient value to route `final_value_msat`. Note that in the case of "0-value"
- // invoices where the invoice does not specify value this may not be the case, but
- // better to include the hints than not.
- if add_entry!(hop.short_channel_id, hop.src_node_id, payee, directional_info, Some((final_value_msat + 999) / 1000), &empty_channel_features, 0, path_value_msat, 0) {
- // If this hop connects to a node with which we have a direct channel,
- // ignore the network graph and, if the last hop was added, add our
- // direct channel to the candidate set.
- //
- // Note that we *must* check if the last hop was added as `add_entry`
- // always assumes that the third argument is a node to which we have a
- // path.
- if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&hop.src_node_id) {
- add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, 0, path_value_msat, 0);
+ // We start building the path from reverse, i.e., from payee
+ // to the first RouteHintHop in the path.
+ let hop_iter = route.0.iter().rev();
+ let prev_hop_iter = core::iter::once(payee).chain(
+ route.0.iter().skip(1).rev().map(|hop| &hop.src_node_id));
+ let mut hop_used = true;
+ let mut aggregate_next_hops_fee_msat: u64 = 0;
+ let mut aggregate_next_hops_path_htlc_minimum_msat: u64 = 0;
+
+ for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() {
+ // BOLT 11 doesn't allow inclusion of features for the last hop hints, which
+ // really sucks, cause we're gonna need that eventually.
+ let hop_htlc_minimum_msat: u64 = hop.htlc_minimum_msat.unwrap_or(0);
+
+ let directional_info = DummyDirectionalChannelInfo {
+ cltv_expiry_delta: hop.cltv_expiry_delta as u32,
+ htlc_minimum_msat: hop_htlc_minimum_msat,
+ htlc_maximum_msat: hop.htlc_maximum_msat,
+ fees: hop.fees,
+ };
+
+ let reqd_channel_cap = if let Some (val) = final_value_msat.checked_add(match idx {
+ 0 => 999,
+ _ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value())
+ }) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity
+
+
+ // We assume that the recipient only included route hints for routes which had
+ // sufficient value to route `final_value_msat`. Note that in the case of "0-value"
+ // invoices where the invoice does not specify value this may not be the case, but
+ // better to include the hints than not.
+ if !add_entry!(hop.short_channel_id, hop.src_node_id, prev_hop_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat) {
+ // If this hop was not used then there is no use checking the preceding hops
+ // in the RouteHint. We can break by just searching for a direct channel between
+ // last checked hop and first_hop_targets
+ hop_used = false;
+ }
+
+ // Searching for a direct channel between last checked hop and first_hop_targets
+ if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&prev_hop_id) {
+ add_entry!(first_hop, *our_node_id , prev_hop_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+ }
+
+ if !hop_used {
+ break;
+ }
+
+ // In the next values of the iterator, the aggregate fees already reflects
+ // the sum of value sent from payer (final_value_msat) and routing fees
+ // for the last node in the RouteHint. We need to just add the fees to
+ // route through the current node so that the preceeding node (next iteration)
+ // can use it.
+ let hops_fee = compute_fees(aggregate_next_hops_fee_msat + final_value_msat, hop.fees)
+ .map_or(None, |inc| inc.checked_add(aggregate_next_hops_fee_msat));
+ aggregate_next_hops_fee_msat = if let Some(val) = hops_fee { val } else { break; };
+
+ let hop_htlc_minimum_msat_inc = if let Some(val) = compute_fees(aggregate_next_hops_path_htlc_minimum_msat, hop.fees) { val } else { break; };
+ let hops_path_htlc_minimum = aggregate_next_hops_path_htlc_minimum_msat
+ .checked_add(hop_htlc_minimum_msat_inc);
+ aggregate_next_hops_path_htlc_minimum_msat = if let Some(val) = hops_path_htlc_minimum { cmp::max(hop_htlc_minimum_msat, val) } else { break; };
+
+ if idx == route.0.len() - 1 {
+ // The last hop in this iterator is the first hop in
+ // overall RouteHint.
+ // If this hop connects to a node with which we have a direct channel,
+ // ignore the network graph and, if the last hop was added, add our
+ // direct channel to the candidate set.
+ //
+ // Note that we *must* check if the last hop was added as `add_entry`
+ // always assumes that the third argument is a node to which we have a
+ // path.
+ if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&hop.src_node_id) {
+ add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+ }
}
}
}
let logger = Arc::new(test_utils::TestLogger::new());
let chain_monitor = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
let net_graph_msg_handler = NetGraphMsgHandler::new(genesis_block(Network::Testnet).header.block_hash(), None, Arc::clone(&logger));
- // Build network from our_id to node7:
+ // Build network from our_id to node6:
//
// -1(1)2- node0 -1(3)2-
// / \
// \ /
// -1(7)2- node5 -1(10)2-
//
+ // Channels 5, 8, 9 and 10 are private channels.
+ //
// chan5 1-to-2: enabled, 100 msat fee
// chan5 2-to-1: enabled, 0 fee
//
cltv_expiry_delta: (8 << 8) | 1,
htlc_minimum_msat: None,
htlc_maximum_msat: None,
+ }
+ ]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[4].clone(),
+ short_channel_id: 9,
+ fees: RoutingFees {
+ base_msat: 1001,
+ proportional_millionths: 0,
+ },
+ cltv_expiry_delta: (9 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
}]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[5].clone(),
+ short_channel_id: 10,
+ fees: zero_fees,
+ cltv_expiry_delta: (10 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }])]
+ }
+
+ fn last_hops_multi_private_channels(nodes: &Vec<PublicKey>) -> Vec<RouteHint> {
+ let zero_fees = RoutingFees {
+ base_msat: 0,
+ proportional_millionths: 0,
+ };
+ vec![RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[2].clone(),
+ short_channel_id: 5,
+ fees: RoutingFees {
+ base_msat: 100,
+ proportional_millionths: 0,
+ },
+ cltv_expiry_delta: (5 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }, RouteHintHop {
+ src_node_id: nodes[3].clone(),
+ short_channel_id: 8,
+ fees: zero_fees,
+ cltv_expiry_delta: (8 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }
+ ]), RouteHint(vec![RouteHintHop {
src_node_id: nodes[4].clone(),
short_channel_id: 9,
fees: RoutingFees {
}
#[test]
- fn last_hops_test() {
+ fn partial_route_hint_test() {
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
// Simple test across 2, 3, 5, and 4 via a last_hop channel
+ // Tests the behaviour when the RouteHint contains a suboptimal hop.
+ // RouteHint may be partially used by the algo to build the best path.
// First check that last hop can't have its source as the payee.
let invalid_last_hop = RouteHint(vec![RouteHintHop {
htlc_maximum_msat: None,
}]);
- let mut invalid_last_hops = last_hops(&nodes);
+ let mut invalid_last_hops = last_hops_multi_private_channels(&nodes);
invalid_last_hops.push(invalid_last_hop);
{
if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &invalid_last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)) {
} else { panic!(); }
}
- let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
+ let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &last_hops_multi_private_channels(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
+ assert_eq!(route.paths[0].len(), 5);
+
+ assert_eq!(route.paths[0][0].pubkey, nodes[1]);
+ assert_eq!(route.paths[0][0].short_channel_id, 2);
+ assert_eq!(route.paths[0][0].fee_msat, 100);
+ assert_eq!(route.paths[0][0].cltv_expiry_delta, (4 << 8) | 1);
+ assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2));
+ assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2));
+
+ assert_eq!(route.paths[0][1].pubkey, nodes[2]);
+ assert_eq!(route.paths[0][1].short_channel_id, 4);
+ assert_eq!(route.paths[0][1].fee_msat, 0);
+ assert_eq!(route.paths[0][1].cltv_expiry_delta, (6 << 8) | 1);
+ assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3));
+ assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4));
+
+ assert_eq!(route.paths[0][2].pubkey, nodes[4]);
+ assert_eq!(route.paths[0][2].short_channel_id, 6);
+ assert_eq!(route.paths[0][2].fee_msat, 0);
+ assert_eq!(route.paths[0][2].cltv_expiry_delta, (11 << 8) | 1);
+ assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(5));
+ assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(6));
+
+ assert_eq!(route.paths[0][3].pubkey, nodes[3]);
+ assert_eq!(route.paths[0][3].short_channel_id, 11);
+ assert_eq!(route.paths[0][3].fee_msat, 0);
+ assert_eq!(route.paths[0][3].cltv_expiry_delta, (8 << 8) | 1);
+ // If we have a peer in the node map, we'll use their features here since we don't have
+ // a way of figuring out their features from the invoice:
+ assert_eq!(route.paths[0][3].node_features.le_flags(), &id_to_feature_flags(4));
+ assert_eq!(route.paths[0][3].channel_features.le_flags(), &id_to_feature_flags(11));
+
+ assert_eq!(route.paths[0][4].pubkey, nodes[6]);
+ assert_eq!(route.paths[0][4].short_channel_id, 8);
+ assert_eq!(route.paths[0][4].fee_msat, 100);
+ assert_eq!(route.paths[0][4].cltv_expiry_delta, 42);
+ assert_eq!(route.paths[0][4].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
+ assert_eq!(route.paths[0][4].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
+ }
+
+ fn empty_last_hop(nodes: &Vec<PublicKey>) -> Vec<RouteHint> {
+ let zero_fees = RoutingFees {
+ base_msat: 0,
+ proportional_millionths: 0,
+ };
+ vec![RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[3].clone(),
+ short_channel_id: 8,
+ fees: zero_fees,
+ cltv_expiry_delta: (8 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }]), RouteHint(vec![
+
+ ]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[5].clone(),
+ short_channel_id: 10,
+ fees: zero_fees,
+ cltv_expiry_delta: (10 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }])]
+ }
+
+ #[test]
+ fn ignores_empty_last_hops_test() {
+ let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+ let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+
+ // Test handling of an empty RouteHint passed in Invoice.
+
+ let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &empty_last_hop(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
assert_eq!(route.paths[0].len(), 5);
assert_eq!(route.paths[0][0].pubkey, nodes[1]);
assert_eq!(route.paths[0][4].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
}
+ fn multi_hint_last_hops(nodes: &Vec<PublicKey>) -> Vec<RouteHint> {
+ let zero_fees = RoutingFees {
+ base_msat: 0,
+ proportional_millionths: 0,
+ };
+ vec![RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[2].clone(),
+ short_channel_id: 5,
+ fees: RoutingFees {
+ base_msat: 100,
+ proportional_millionths: 0,
+ },
+ cltv_expiry_delta: (5 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }, RouteHintHop {
+ src_node_id: nodes[3].clone(),
+ short_channel_id: 8,
+ fees: zero_fees,
+ cltv_expiry_delta: (8 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[5].clone(),
+ short_channel_id: 10,
+ fees: zero_fees,
+ cltv_expiry_delta: (10 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }])]
+ }
+
+ #[test]
+ fn multi_hint_last_hops_test() {
+ let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+ let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
+ // Test through channels 2, 3, 5, 8.
+ // Test shows that multiple hop hints are considered.
+
+ // Disabling channels 6 & 7 by flags=2
+ update_channel(&net_graph_msg_handler, &secp_ctx, &privkeys[2], UnsignedChannelUpdate {
+ chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+ short_channel_id: 6,
+ timestamp: 2,
+ flags: 2, // to disable
+ cltv_expiry_delta: 0,
+ htlc_minimum_msat: 0,
+ htlc_maximum_msat: OptionalField::Absent,
+ fee_base_msat: 0,
+ fee_proportional_millionths: 0,
+ excess_data: Vec::new()
+ });
+ update_channel(&net_graph_msg_handler, &secp_ctx, &privkeys[2], UnsignedChannelUpdate {
+ chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+ short_channel_id: 7,
+ timestamp: 2,
+ flags: 2, // to disable
+ cltv_expiry_delta: 0,
+ htlc_minimum_msat: 0,
+ htlc_maximum_msat: OptionalField::Absent,
+ fee_base_msat: 0,
+ fee_proportional_millionths: 0,
+ excess_data: Vec::new()
+ });
+
+ let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &multi_hint_last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
+ assert_eq!(route.paths[0].len(), 4);
+
+ assert_eq!(route.paths[0][0].pubkey, nodes[1]);
+ assert_eq!(route.paths[0][0].short_channel_id, 2);
+ assert_eq!(route.paths[0][0].fee_msat, 200);
+ assert_eq!(route.paths[0][0].cltv_expiry_delta, 1025);
+ assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2));
+ assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2));
+
+ assert_eq!(route.paths[0][1].pubkey, nodes[2]);
+ assert_eq!(route.paths[0][1].short_channel_id, 4);
+ assert_eq!(route.paths[0][1].fee_msat, 100);
+ assert_eq!(route.paths[0][1].cltv_expiry_delta, 1281);
+ assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3));
+ assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4));
+
+ assert_eq!(route.paths[0][2].pubkey, nodes[3]);
+ assert_eq!(route.paths[0][2].short_channel_id, 5);
+ assert_eq!(route.paths[0][2].fee_msat, 0);
+ assert_eq!(route.paths[0][2].cltv_expiry_delta, 2049);
+ assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(4));
+ assert_eq!(route.paths[0][2].channel_features.le_flags(), &Vec::<u8>::new());
+
+ assert_eq!(route.paths[0][3].pubkey, nodes[6]);
+ assert_eq!(route.paths[0][3].short_channel_id, 8);
+ assert_eq!(route.paths[0][3].fee_msat, 100);
+ assert_eq!(route.paths[0][3].cltv_expiry_delta, 42);
+ assert_eq!(route.paths[0][3].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
+ assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
+ }
+
+ fn last_hops_with_public_channel(nodes: &Vec<PublicKey>) -> Vec<RouteHint> {
+ let zero_fees = RoutingFees {
+ base_msat: 0,
+ proportional_millionths: 0,
+ };
+ vec![RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[4].clone(),
+ short_channel_id: 11,
+ fees: zero_fees,
+ cltv_expiry_delta: (11 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }, RouteHintHop {
+ src_node_id: nodes[3].clone(),
+ short_channel_id: 8,
+ fees: zero_fees,
+ cltv_expiry_delta: (8 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[4].clone(),
+ short_channel_id: 9,
+ fees: RoutingFees {
+ base_msat: 1001,
+ proportional_millionths: 0,
+ },
+ cltv_expiry_delta: (9 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }]), RouteHint(vec![RouteHintHop {
+ src_node_id: nodes[5].clone(),
+ short_channel_id: 10,
+ fees: zero_fees,
+ cltv_expiry_delta: (10 << 8) | 1,
+ htlc_minimum_msat: None,
+ htlc_maximum_msat: None,
+ }])]
+ }
+
+ #[test]
+ fn last_hops_with_public_channel_test() {
+ let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+ let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+ // This test shows that public routes can be present in the invoice
+ // which would be handled in the same manner.
+
+ let route = get_route(&our_id, &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[6], None, None, &last_hops_with_public_channel(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger)).unwrap();
+ assert_eq!(route.paths[0].len(), 5);
+
+ assert_eq!(route.paths[0][0].pubkey, nodes[1]);
+ assert_eq!(route.paths[0][0].short_channel_id, 2);
+ assert_eq!(route.paths[0][0].fee_msat, 100);
+ assert_eq!(route.paths[0][0].cltv_expiry_delta, (4 << 8) | 1);
+ assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2));
+ assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2));
+
+ assert_eq!(route.paths[0][1].pubkey, nodes[2]);
+ assert_eq!(route.paths[0][1].short_channel_id, 4);
+ assert_eq!(route.paths[0][1].fee_msat, 0);
+ assert_eq!(route.paths[0][1].cltv_expiry_delta, (6 << 8) | 1);
+ assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3));
+ assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4));
+
+ assert_eq!(route.paths[0][2].pubkey, nodes[4]);
+ assert_eq!(route.paths[0][2].short_channel_id, 6);
+ assert_eq!(route.paths[0][2].fee_msat, 0);
+ assert_eq!(route.paths[0][2].cltv_expiry_delta, (11 << 8) | 1);
+ assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(5));
+ assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(6));
+
+ assert_eq!(route.paths[0][3].pubkey, nodes[3]);
+ assert_eq!(route.paths[0][3].short_channel_id, 11);
+ assert_eq!(route.paths[0][3].fee_msat, 0);
+ assert_eq!(route.paths[0][3].cltv_expiry_delta, (8 << 8) | 1);
+ // If we have a peer in the node map, we'll use their features here since we don't have
+ // a way of figuring out their features from the invoice:
+ assert_eq!(route.paths[0][3].node_features.le_flags(), &id_to_feature_flags(4));
+ assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::<u8>::new());
+
+ assert_eq!(route.paths[0][4].pubkey, nodes[6]);
+ assert_eq!(route.paths[0][4].short_channel_id, 8);
+ assert_eq!(route.paths[0][4].fee_msat, 100);
+ assert_eq!(route.paths[0][4].cltv_expiry_delta, 42);
+ assert_eq!(route.paths[0][4].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
+ assert_eq!(route.paths[0][4].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
+ }
+
#[test]
fn our_chans_last_hop_connect_test() {
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
use ln::{chan_utils, msgs};
use chain::keysinterface::{Sign, InMemorySigner, BaseSign};
-use io;
use prelude::*;
use core::cmp;
use sync::{Mutex, Arc};
+#[cfg(test)] use sync::MutexGuard;
use bitcoin::blockdata::transaction::{Transaction, SigHashType};
use bitcoin::util::bip143;
use bitcoin::secp256k1;
use bitcoin::secp256k1::key::{SecretKey, PublicKey};
use bitcoin::secp256k1::{Secp256k1, Signature};
-use util::ser::{Writeable, Writer, Readable};
+use util::ser::{Writeable, Writer};
use io::Error;
-use ln::msgs::DecodeError;
/// Initial value for revoked commitment downward counter
pub const INITIAL_REVOKED_COMMITMENT_NUMBER: u64 = 1 << 48;
/// - When signing, the holder transaction has not been revoked
/// - When revoking, the holder transaction has not been signed
/// - The holder commitment number is monotonic and without gaps
+/// - The revoked holder commitment number is monotonic and without gaps
+/// - There is at least one unrevoked holder transaction at all times
/// - The counterparty commitment number is monotonic and without gaps
/// - The pre-derived keys and pre-built transaction in CommitmentTransaction were correctly built
///
/// Eventually we will probably want to expose a variant of this which would essentially
/// be what you'd want to run on a hardware wallet.
///
+/// Note that counterparty signatures on the holder transaction are not checked, but it should
+/// be in a complete implementation.
+///
/// Note that before we do so we should ensure its serialization format has backwards- and
/// forwards-compatibility prefix/suffixes!
#[derive(Clone)]
pub struct EnforcingSigner {
pub inner: InMemorySigner,
- /// The last counterparty commitment number we signed, backwards counting
- pub last_commitment_number: Arc<Mutex<Option<u64>>>,
- /// The last holder commitment number we revoked, backwards counting
- pub revoked_commitment: Arc<Mutex<u64>>,
+ /// Channel state used for policy enforcement
+ pub state: Arc<Mutex<EnforcementState>>,
pub disable_revocation_policy_check: bool,
}
impl EnforcingSigner {
/// Construct an EnforcingSigner
pub fn new(inner: InMemorySigner) -> Self {
+ let state = Arc::new(Mutex::new(EnforcementState::new()));
Self {
inner,
- last_commitment_number: Arc::new(Mutex::new(None)),
- revoked_commitment: Arc::new(Mutex::new(INITIAL_REVOKED_COMMITMENT_NUMBER)),
+ state,
disable_revocation_policy_check: false
}
}
/// Construct an EnforcingSigner with externally managed storage
///
/// Since there are multiple copies of this struct for each channel, some coordination is needed
- /// so that all copies are aware of revocations. A pointer to this state is provided here, usually
- /// by an implementation of KeysInterface.
- pub fn new_with_revoked(inner: InMemorySigner, revoked_commitment: Arc<Mutex<u64>>, disable_revocation_policy_check: bool) -> Self {
+ /// so that all copies are aware of enforcement state. A pointer to this state is provided
+ /// here, usually by an implementation of KeysInterface.
+ pub fn new_with_revoked(inner: InMemorySigner, state: Arc<Mutex<EnforcementState>>, disable_revocation_policy_check: bool) -> Self {
Self {
inner,
- last_commitment_number: Arc::new(Mutex::new(None)),
- revoked_commitment,
+ state,
disable_revocation_policy_check
}
}
+
+ #[cfg(test)]
+ pub fn get_enforcement_state(&self) -> MutexGuard<EnforcementState> {
+ self.state.lock().unwrap()
+ }
}
impl BaseSign for EnforcingSigner {
fn release_commitment_secret(&self, idx: u64) -> [u8; 32] {
{
- let mut revoked = self.revoked_commitment.lock().unwrap();
- assert!(idx == *revoked || idx == *revoked - 1, "can only revoke the current or next unrevoked commitment - trying {}, revoked {}", idx, *revoked);
- *revoked = idx;
+ let mut state = self.state.lock().unwrap();
+ assert!(idx == state.last_holder_revoked_commitment || idx == state.last_holder_revoked_commitment - 1, "can only revoke the current or next unrevoked commitment - trying {}, last revoked {}", idx, state.last_holder_revoked_commitment);
+ assert!(idx > state.last_holder_commitment, "cannot revoke the last holder commitment - attempted to revoke {} last commitment {}", idx, state.last_holder_commitment);
+ state.last_holder_revoked_commitment = idx;
}
self.inner.release_commitment_secret(idx)
}
+ fn validate_holder_commitment(&self, holder_tx: &HolderCommitmentTransaction) -> Result<(), ()> {
+ let mut state = self.state.lock().unwrap();
+ let idx = holder_tx.commitment_number();
+ assert!(idx == state.last_holder_commitment || idx == state.last_holder_commitment - 1, "expecting to validate the current or next holder commitment - trying {}, current {}", idx, state.last_holder_commitment);
+ state.last_holder_commitment = idx;
+ Ok(())
+ }
+
fn pubkeys(&self) -> &ChannelPublicKeys { self.inner.pubkeys() }
fn channel_keys_id(&self) -> [u8; 32] { self.inner.channel_keys_id() }
self.verify_counterparty_commitment_tx(commitment_tx, secp_ctx);
{
- let mut last_commitment_number_guard = self.last_commitment_number.lock().unwrap();
+ let mut state = self.state.lock().unwrap();
let actual_commitment_number = commitment_tx.commitment_number();
- let last_commitment_number = last_commitment_number_guard.unwrap_or(actual_commitment_number);
+ let last_commitment_number = state.last_counterparty_commitment;
// These commitment numbers are backwards counting. We expect either the same as the previously encountered,
// or the next one.
assert!(last_commitment_number == actual_commitment_number || last_commitment_number - 1 == actual_commitment_number, "{} doesn't come after {}", actual_commitment_number, last_commitment_number);
- *last_commitment_number_guard = Some(cmp::min(last_commitment_number, actual_commitment_number))
+ // Ensure that the counterparty doesn't get more than two broadcastable commitments -
+ // the last and the one we are trying to sign
+ assert!(actual_commitment_number >= state.last_counterparty_revoked_commitment - 2, "cannot sign a commitment if second to last wasn't revoked - signing {} revoked {}", actual_commitment_number, state.last_counterparty_revoked_commitment);
+ state.last_counterparty_commitment = cmp::min(last_commitment_number, actual_commitment_number)
}
Ok(self.inner.sign_counterparty_commitment(commitment_tx, secp_ctx).unwrap())
}
+ fn validate_counterparty_revocation(&self, idx: u64, _secret: &SecretKey) -> Result<(), ()> {
+ let mut state = self.state.lock().unwrap();
+ assert!(idx == state.last_counterparty_revoked_commitment || idx == state.last_counterparty_revoked_commitment - 1, "expecting to validate the current or next counterparty revocation - trying {}, current {}", idx, state.last_counterparty_revoked_commitment);
+ state.last_counterparty_revoked_commitment = idx;
+ Ok(())
+ }
+
fn sign_holder_commitment_and_htlcs(&self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<(Signature, Vec<Signature>), ()> {
let trusted_tx = self.verify_holder_commitment_tx(commitment_tx, secp_ctx);
let commitment_txid = trusted_tx.txid();
let holder_csv = self.inner.counterparty_selected_contest_delay();
- let revoked = self.revoked_commitment.lock().unwrap();
+ let state = self.state.lock().unwrap();
let commitment_number = trusted_tx.commitment_number();
- if *revoked - 1 != commitment_number && *revoked - 2 != commitment_number {
+ if state.last_holder_revoked_commitment - 1 != commitment_number && state.last_holder_revoked_commitment - 2 != commitment_number {
if !self.disable_revocation_policy_check {
panic!("can only sign the next two unrevoked commitment numbers, revoked={} vs requested={} for {}",
- *revoked, commitment_number, self.inner.commitment_seed[0])
+ state.last_holder_revoked_commitment, commitment_number, self.inner.commitment_seed[0])
}
}
impl Writeable for EnforcingSigner {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
+ // EnforcingSigner has two fields - `inner` ([`InMemorySigner`]) and `state`
+ // ([`EnforcementState`]). `inner` is serialized here and deserialized by
+ // [`KeysInterface::read_chan_signer`]. `state` is managed by [`KeysInterface`]
+ // and will be serialized as needed by the implementation of that trait.
self.inner.write(writer)?;
- let last = *self.last_commitment_number.lock().unwrap();
- last.write(writer)?;
Ok(())
}
}
-impl Readable for EnforcingSigner {
- fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
- let inner = Readable::read(reader)?;
- let last_commitment_number = Readable::read(reader)?;
- Ok(EnforcingSigner {
- inner,
- last_commitment_number: Arc::new(Mutex::new(last_commitment_number)),
- revoked_commitment: Arc::new(Mutex::new(INITIAL_REVOKED_COMMITMENT_NUMBER)),
- disable_revocation_policy_check: false,
- })
- }
-}
-
impl EnforcingSigner {
fn verify_counterparty_commitment_tx<'a, T: secp256k1::Signing + secp256k1::Verification>(&self, commitment_tx: &'a CommitmentTransaction, secp_ctx: &Secp256k1<T>) -> TrustedCommitmentTransaction<'a> {
commitment_tx.verify(&self.inner.get_channel_parameters().as_counterparty_broadcastable(),
.expect("derived different per-tx keys or built transaction")
}
}
+
+/// The state used by [`EnforcingSigner`] in order to enforce policy checks
+///
+/// This structure is maintained by KeysInterface since we may have multiple copies of
+/// the signer and they must coordinate their state.
+#[derive(Clone)]
+pub struct EnforcementState {
+ /// The last counterparty commitment number we signed, backwards counting
+ pub last_counterparty_commitment: u64,
+ /// The last counterparty commitment they revoked, backwards counting
+ pub last_counterparty_revoked_commitment: u64,
+ /// The last holder commitment number we revoked, backwards counting
+ pub last_holder_revoked_commitment: u64,
+ /// The last validated holder commitment number, backwards counting
+ pub last_holder_commitment: u64,
+}
+
+impl EnforcementState {
+ /// Enforcement state for a new channel
+ pub fn new() -> Self {
+ EnforcementState {
+ last_counterparty_commitment: INITIAL_REVOKED_COMMITMENT_NUMBER,
+ last_counterparty_revoked_commitment: INITIAL_REVOKED_COMMITMENT_NUMBER,
+ last_holder_revoked_commitment: INITIAL_REVOKED_COMMITMENT_NUMBER,
+ last_holder_commitment: INITIAL_REVOKED_COMMITMENT_NUMBER,
+ }
+ }
+}
use ln::msgs;
use ln::msgs::OptionalField;
use ln::script::ShutdownScript;
-use util::enforcing_trait_impls::{EnforcingSigner, INITIAL_REVOKED_COMMITMENT_NUMBER};
+use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
use util::events;
use util::logger::{Logger, Level, Record};
use util::ser::{Readable, ReadableArgs, Writer, Writeable};
fn get_channel_signer(&self, _inbound: bool, _channel_value_satoshis: u64) -> EnforcingSigner { unreachable!(); }
fn get_secure_random_bytes(&self) -> [u8; 32] { [0; 32] }
- fn read_chan_signer(&self, reader: &[u8]) -> Result<Self::Signer, msgs::DecodeError> {
- EnforcingSigner::read(&mut io::Cursor::new(reader))
+ fn read_chan_signer(&self, mut reader: &[u8]) -> Result<Self::Signer, msgs::DecodeError> {
+ let inner: InMemorySigner = Readable::read(&mut reader)?;
+ let state = Arc::new(Mutex::new(EnforcementState::new()));
+
+ Ok(EnforcingSigner::new_with_revoked(
+ inner,
+ state,
+ false
+ ))
}
fn sign_invoice(&self, _invoice_preimage: Vec<u8>) -> Result<RecoverableSignature, ()> { unreachable!(); }
}
pub override_session_priv: Mutex<Option<[u8; 32]>>,
pub override_channel_id_priv: Mutex<Option<[u8; 32]>>,
pub disable_revocation_policy_check: bool,
- revoked_commitments: Mutex<HashMap<[u8;32], Arc<Mutex<u64>>>>,
+ enforcement_states: Mutex<HashMap<[u8;32], Arc<Mutex<EnforcementState>>>>,
expectations: Mutex<Option<VecDeque<OnGetShutdownScriptpubkey>>>,
}
fn get_channel_signer(&self, inbound: bool, channel_value_satoshis: u64) -> EnforcingSigner {
let keys = self.backing.get_channel_signer(inbound, channel_value_satoshis);
- let revoked_commitment = self.make_revoked_commitment_cell(keys.commitment_seed);
- EnforcingSigner::new_with_revoked(keys, revoked_commitment, self.disable_revocation_policy_check)
+ let state = self.make_enforcement_state_cell(keys.commitment_seed);
+ EnforcingSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check)
}
fn get_secure_random_bytes(&self) -> [u8; 32] {
let mut reader = io::Cursor::new(buffer);
let inner: InMemorySigner = Readable::read(&mut reader)?;
- let revoked_commitment = self.make_revoked_commitment_cell(inner.commitment_seed);
-
- let last_commitment_number = Readable::read(&mut reader)?;
+ let state = self.make_enforcement_state_cell(inner.commitment_seed);
- Ok(EnforcingSigner {
+ Ok(EnforcingSigner::new_with_revoked(
inner,
- last_commitment_number: Arc::new(Mutex::new(last_commitment_number)),
- revoked_commitment,
- disable_revocation_policy_check: self.disable_revocation_policy_check,
- })
+ state,
+ self.disable_revocation_policy_check
+ ))
}
fn sign_invoice(&self, invoice_preimage: Vec<u8>) -> Result<RecoverableSignature, ()> {
override_session_priv: Mutex::new(None),
override_channel_id_priv: Mutex::new(None),
disable_revocation_policy_check: false,
- revoked_commitments: Mutex::new(HashMap::new()),
+ enforcement_states: Mutex::new(HashMap::new()),
expectations: Mutex::new(None),
}
}
pub fn derive_channel_keys(&self, channel_value_satoshis: u64, id: &[u8; 32]) -> EnforcingSigner {
let keys = self.backing.derive_channel_keys(channel_value_satoshis, id);
- let revoked_commitment = self.make_revoked_commitment_cell(keys.commitment_seed);
- EnforcingSigner::new_with_revoked(keys, revoked_commitment, self.disable_revocation_policy_check)
+ let state = self.make_enforcement_state_cell(keys.commitment_seed);
+ EnforcingSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check)
}
- fn make_revoked_commitment_cell(&self, commitment_seed: [u8; 32]) -> Arc<Mutex<u64>> {
- let mut revoked_commitments = self.revoked_commitments.lock().unwrap();
- if !revoked_commitments.contains_key(&commitment_seed) {
- revoked_commitments.insert(commitment_seed, Arc::new(Mutex::new(INITIAL_REVOKED_COMMITMENT_NUMBER)));
+ fn make_enforcement_state_cell(&self, commitment_seed: [u8; 32]) -> Arc<Mutex<EnforcementState>> {
+ let mut states = self.enforcement_states.lock().unwrap();
+ if !states.contains_key(&commitment_seed) {
+ let state = EnforcementState::new();
+ states.insert(commitment_seed, Arc::new(Mutex::new(state)));
}
- let cell = revoked_commitments.get(&commitment_seed).unwrap();
+ let cell = states.get(&commitment_seed).unwrap();
Arc::clone(cell)
}
}