Merge pull request #2212 from wpaulino/off-by-one-locktime
[rust-lightning] / lightning / src / ln / onion_utils.rs
index 491188fcae13daf4cd17af099e129399197f5d47..b9f36bbd8dadf94592c9082a303ef378879a4819 100644 (file)
@@ -7,15 +7,15 @@
 // You may not use this file except in accordance with one or both of these
 // licenses.
 
-use crate::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
-use crate::ln::channelmanager::HTLCSource;
+use crate::ln::{PaymentHash, PaymentPreimage};
+use crate::ln::channelmanager::{HTLCSource, RecipientOnionFields};
 use crate::ln::msgs;
 use crate::ln::wire::Encode;
 use crate::routing::gossip::NetworkUpdate;
 use crate::routing::router::RouteHop;
 use crate::util::chacha20::{ChaCha20, ChaChaReader};
 use crate::util::errors::{self, APIError};
-use crate::util::ser::{Readable, ReadableArgs, Writeable, LengthCalculatingWriter};
+use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer, LengthCalculatingWriter};
 use crate::util::logger::Logger;
 
 use bitcoin::hashes::{Hash, HashEngine};
@@ -149,7 +149,7 @@ pub(super) fn construct_onion_keys<T: secp256k1::Signing>(secp_ctx: &Secp256k1<T
 }
 
 /// returns the hop data, as well as the first-hop value_msat and CLTV value we should send.
-pub(super) fn build_onion_payloads(path: &Vec<RouteHop>, total_msat: u64, payment_secret_option: &Option<PaymentSecret>, starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>) -> Result<(Vec<msgs::OnionHopData>, u64, u32), APIError> {
+pub(super) fn build_onion_payloads(path: &Vec<RouteHop>, total_msat: u64, mut recipient_onion: RecipientOnionFields, starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>) -> Result<(Vec<msgs::OnionHopData>, u64, u32), APIError> {
        let mut cur_value_msat = 0u64;
        let mut cur_cltv = starting_htlc_offset;
        let mut last_short_channel_id = 0;
@@ -164,12 +164,13 @@ pub(super) fn build_onion_payloads(path: &Vec<RouteHop>, total_msat: u64, paymen
                res.insert(0, msgs::OnionHopData {
                        format: if idx == 0 {
                                msgs::OnionHopDataFormat::FinalNode {
-                                       payment_data: if let &Some(ref payment_secret) = payment_secret_option {
+                                       payment_data: if let Some(secret) = recipient_onion.payment_secret.take() {
                                                Some(msgs::FinalOnionHopData {
-                                                       payment_secret: payment_secret.clone(),
+                                                       payment_secret: secret,
                                                        total_msat,
                                                })
                                        } else { None },
+                                       payment_metadata: recipient_onion.payment_metadata.take(),
                                        keysend_preimage: *keysend_preimage,
                                }
                        } else {
@@ -182,11 +183,11 @@ pub(super) fn build_onion_payloads(path: &Vec<RouteHop>, total_msat: u64, paymen
                });
                cur_value_msat += hop.fee_msat;
                if cur_value_msat >= 21000000 * 100000000 * 1000 {
-                       return Err(APIError::RouteError{err: "Channel fees overflowed?"});
+                       return Err(APIError::InvalidRoute{err: "Channel fees overflowed?".to_owned()});
                }
                cur_cltv += hop.cltv_expiry_delta as u32;
                if cur_cltv >= 500000000 {
-                       return Err(APIError::RouteError{err: "Channel CLTV overflowed?"});
+                       return Err(APIError::InvalidRoute{err: "Channel CLTV overflowed?".to_owned()});
                }
                last_short_channel_id = hop.short_channel_id;
        }
@@ -382,7 +383,7 @@ pub(super) fn build_failure_packet(shared_secret: &[u8], failure_type: u16, fail
        packet
 }
 
-#[inline]
+#[cfg(test)]
 pub(super) fn build_first_hop_failure_packet(shared_secret: &[u8], failure_type: u16, failure_data: &[u8]) -> msgs::OnionErrorPacket {
        let failure_packet = build_failure_packet(shared_secret, failure_type, failure_data);
        encrypt_failure_packet(shared_secret, &failure_packet.encode()[..])
@@ -592,29 +593,144 @@ pub(super) fn process_onion_failure<T: secp256k1::Signing, L: Deref>(secp_ctx: &
        } else { unreachable!(); }
 }
 
-/// An input used when decoding an onion packet.
-pub(crate) trait DecodeInput {
-       type Arg;
-       /// If Some, this is the input when checking the hmac of the onion packet.
-       fn payment_hash(&self) -> Option<&PaymentHash>;
-       /// Read argument when decrypting our hop payload.
-       fn read_arg(self) -> Self::Arg;
+#[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
+pub(super) struct HTLCFailReason(HTLCFailReasonRepr);
+
+#[derive(Clone)] // See Channel::revoke_and_ack for why, tl;dr: Rust bug
+enum HTLCFailReasonRepr {
+       LightningError {
+               err: msgs::OnionErrorPacket,
+       },
+       Reason {
+               failure_code: u16,
+               data: Vec<u8>,
+       }
 }
 
-impl DecodeInput for PaymentHash {
-       type Arg = ();
-       fn payment_hash(&self) -> Option<&PaymentHash> {
-               Some(self)
+impl core::fmt::Debug for HTLCFailReason {
+       fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
+               match self.0 {
+                       HTLCFailReasonRepr::Reason { ref failure_code, .. } => {
+                               write!(f, "HTLC error code {}", failure_code)
+                       },
+                       HTLCFailReasonRepr::LightningError { .. } => {
+                               write!(f, "pre-built LightningError")
+                       }
+               }
        }
-       fn read_arg(self) -> Self::Arg { () }
 }
 
-impl DecodeInput for SharedSecret {
-       type Arg = SharedSecret;
-       fn payment_hash(&self) -> Option<&PaymentHash> {
-               None
+impl Writeable for HTLCFailReason {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
+               self.0.write(writer)
+       }
+}
+impl Readable for HTLCFailReason {
+       fn read<R: Read>(reader: &mut R) -> Result<Self, msgs::DecodeError> {
+               Ok(Self(Readable::read(reader)?))
+       }
+}
+
+impl_writeable_tlv_based_enum!(HTLCFailReasonRepr,
+       (0, LightningError) => {
+               (0, err, required),
+       },
+       (1, Reason) => {
+               (0, failure_code, required),
+               (2, data, vec_type),
+       },
+;);
+
+impl HTLCFailReason {
+       pub(super) fn reason(failure_code: u16, data: Vec<u8>) -> Self {
+               const BADONION: u16 = 0x8000;
+               const PERM: u16 = 0x4000;
+               const NODE: u16 = 0x2000;
+               const UPDATE: u16 = 0x1000;
+
+                    if failure_code == 1  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 2  | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 2  | PERM | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 3  | PERM | NODE { debug_assert!(data.is_empty()) }
+               else if failure_code == 4  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 5  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 6  | BADONION | PERM { debug_assert_eq!(data.len(), 32) }
+               else if failure_code == 7  | UPDATE {
+                       debug_assert_eq!(data.len() - 2, u16::from_be_bytes(data[0..2].try_into().unwrap()) as usize) }
+               else if failure_code == 8  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 9  | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 10 | PERM { debug_assert!(data.is_empty()) }
+               else if failure_code == 11 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 8, u16::from_be_bytes(data[8..10].try_into().unwrap()) as usize) }
+               else if failure_code == 12 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 8, u16::from_be_bytes(data[8..10].try_into().unwrap()) as usize) }
+               else if failure_code == 13 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 4, u16::from_be_bytes(data[4..6].try_into().unwrap()) as usize) }
+               else if failure_code == 14 | UPDATE {
+                       debug_assert_eq!(data.len() - 2, u16::from_be_bytes(data[0..2].try_into().unwrap()) as usize) }
+               else if failure_code == 15 | PERM { debug_assert_eq!(data.len(), 12) }
+               else if failure_code == 18 { debug_assert_eq!(data.len(), 4) }
+               else if failure_code == 19 { debug_assert_eq!(data.len(), 8) }
+               else if failure_code == 20 | UPDATE {
+                       debug_assert_eq!(data.len() - 2 - 2, u16::from_be_bytes(data[2..4].try_into().unwrap()) as usize) }
+               else if failure_code == 21 { debug_assert!(data.is_empty()) }
+               else if failure_code == 22 | PERM { debug_assert!(data.len() <= 11) }
+               else if failure_code == 23 { debug_assert!(data.is_empty()) }
+               else if failure_code & BADONION != 0 {
+                       // We set some bogus BADONION failure codes in test, so ignore unknown ones.
+               }
+               else { debug_assert!(false, "Unknown failure code: {}", failure_code) }
+
+               Self(HTLCFailReasonRepr::Reason { failure_code, data })
+       }
+
+       pub(super) fn from_failure_code(failure_code: u16) -> Self {
+               Self::reason(failure_code, Vec::new())
+       }
+
+       pub(super) fn from_msg(msg: &msgs::UpdateFailHTLC) -> Self {
+               Self(HTLCFailReasonRepr::LightningError { err: msg.reason.clone() })
+       }
+
+       pub(super) fn get_encrypted_failure_packet(&self, incoming_packet_shared_secret: &[u8; 32], phantom_shared_secret: &Option<[u8; 32]>)
+       -> msgs::OnionErrorPacket {
+               match self.0 {
+                       HTLCFailReasonRepr::Reason { ref failure_code, ref data } => {
+                               if let Some(phantom_ss) = phantom_shared_secret {
+                                       let phantom_packet = build_failure_packet(phantom_ss, *failure_code, &data[..]).encode();
+                                       let encrypted_phantom_packet = encrypt_failure_packet(phantom_ss, &phantom_packet);
+                                       encrypt_failure_packet(incoming_packet_shared_secret, &encrypted_phantom_packet.data[..])
+                               } else {
+                                       let packet = build_failure_packet(incoming_packet_shared_secret, *failure_code, &data[..]).encode();
+                                       encrypt_failure_packet(incoming_packet_shared_secret, &packet)
+                               }
+                       },
+                       HTLCFailReasonRepr::LightningError { ref err } => {
+                               encrypt_failure_packet(incoming_packet_shared_secret, &err.data)
+                       }
+               }
+       }
+
+       pub(super) fn decode_onion_failure<T: secp256k1::Signing, L: Deref>(
+               &self, secp_ctx: &Secp256k1<T>, logger: &L, htlc_source: &HTLCSource
+       ) -> (Option<NetworkUpdate>, Option<u64>, bool, Option<u16>, Option<Vec<u8>>)
+       where L::Target: Logger {
+               match self.0 {
+                       HTLCFailReasonRepr::LightningError { ref err } => {
+                               process_onion_failure(secp_ctx, logger, &htlc_source, err.data.clone())
+                       },
+                       HTLCFailReasonRepr::Reason { ref failure_code, ref data, .. } => {
+                               // we get a fail_malformed_htlc from the first hop
+                               // TODO: We'd like to generate a NetworkUpdate for temporary
+                               // failures here, but that would be insufficient as find_route
+                               // generally ignores its view of our own channels as we provide them via
+                               // ChannelDetails.
+                               if let &HTLCSource::OutboundRoute { ref path, .. } = htlc_source {
+                                       (None, Some(path.first().unwrap().short_channel_id), true, Some(*failure_code), Some(data.clone()))
+                               } else { unreachable!(); }
+                       }
+               }
        }
-       fn read_arg(self) -> Self::Arg { self }
 }
 
 /// Allows `decode_next_hop` to return the next hop packet bytes for either payments or onion
@@ -667,7 +783,7 @@ pub(crate) enum OnionDecodeErr {
 }
 
 pub(crate) fn decode_next_payment_hop(shared_secret: [u8; 32], hop_data: &[u8], hmac_bytes: [u8; 32], payment_hash: PaymentHash) -> Result<Hop, OnionDecodeErr> {
-       match decode_next_hop(shared_secret, hop_data, hmac_bytes, payment_hash) {
+       match decode_next_hop(shared_secret, hop_data, hmac_bytes, Some(payment_hash), ()) {
                Ok((next_hop_data, None)) => Ok(Hop::Receive(next_hop_data)),
                Ok((next_hop_data, Some((next_hop_hmac, FixedSizeOnionPacket(new_packet_bytes))))) => {
                        Ok(Hop::Forward {
@@ -680,12 +796,16 @@ pub(crate) fn decode_next_payment_hop(shared_secret: [u8; 32], hop_data: &[u8],
        }
 }
 
-pub(crate) fn decode_next_hop<D: DecodeInput, R: ReadableArgs<D::Arg>, N: NextPacketBytes>(shared_secret: [u8; 32], hop_data: &[u8], hmac_bytes: [u8; 32], decode_input: D) -> Result<(R, Option<([u8; 32], N)>), OnionDecodeErr> {
+pub(crate) fn decode_next_untagged_hop<T, R: ReadableArgs<T>, N: NextPacketBytes>(shared_secret: [u8; 32], hop_data: &[u8], hmac_bytes: [u8; 32], read_args: T) -> Result<(R, Option<([u8; 32], N)>), OnionDecodeErr> {
+       decode_next_hop(shared_secret, hop_data, hmac_bytes, None, read_args)
+}
+
+fn decode_next_hop<T, R: ReadableArgs<T>, N: NextPacketBytes>(shared_secret: [u8; 32], hop_data: &[u8], hmac_bytes: [u8; 32], payment_hash: Option<PaymentHash>, read_args: T) -> Result<(R, Option<([u8; 32], N)>), OnionDecodeErr> {
        let (rho, mu) = gen_rho_mu_from_shared_secret(&shared_secret);
        let mut hmac = HmacEngine::<Sha256>::new(&mu);
        hmac.input(hop_data);
-       if let Some(payment_hash) = decode_input.payment_hash() {
-               hmac.input(&payment_hash.0[..]);
+       if let Some(tag) = payment_hash {
+               hmac.input(&tag.0[..]);
        }
        if !fixed_time_eq(&Hmac::from_engine(hmac).into_inner(), &hmac_bytes) {
                return Err(OnionDecodeErr::Malformed {
@@ -696,7 +816,7 @@ pub(crate) fn decode_next_hop<D: DecodeInput, R: ReadableArgs<D::Arg>, N: NextPa
 
        let mut chacha = ChaCha20::new(&rho, &[0u8; 8]);
        let mut chacha_stream = ChaChaReader { chacha: &mut chacha, read: Cursor::new(&hop_data[..]) };
-       match R::read(&mut chacha_stream, decode_input.read_arg()) {
+       match R::read(&mut chacha_stream, read_args) {
                Err(err) => {
                        let error_code = match err {
                                msgs::DecodeError::UnknownVersion => 0x4000 | 1, // unknown realm byte