Make HTLCDescriptor writeable
[rust-lightning] / lightning / src / onion_message / packet.rs
index 8b677e7bb611ea307067493bc2252685e7632636..a76371b39eb9fa12769e2144fb6cdd02067b70f7 100644 (file)
@@ -264,8 +264,9 @@ ReadableArgs<(SharedSecret, &H, &L)> for Payload<<H as CustomOnionMessageHandler
 }
 
 /// When reading a packet off the wire, we don't know a priori whether the packet is to be forwarded
-/// or received. Thus we read a ControlTlvs rather than reading a ForwardControlTlvs or
-/// ReceiveControlTlvs directly.
+/// or received. Thus we read a `ControlTlvs` rather than reading a [`ForwardTlvs`] or
+/// [`ReceiveTlvs`] directly. Also useful on the encoding side to keep forward and receive TLVs in
+/// the same iterator.
 pub(crate) enum ControlTlvs {
        /// This onion message is intended to be forwarded.
        Forward(ForwardTlvs),
@@ -274,19 +275,16 @@ pub(crate) enum ControlTlvs {
 }
 
 impl Readable for ControlTlvs {
-       fn read<R: Read>(mut r: &mut R) -> Result<Self, DecodeError> {
-               let mut _padding: Option<Padding> = None;
-               let mut _short_channel_id: Option<u64> = None;
-               let mut next_node_id: Option<PublicKey> = None;
-               let mut path_id: Option<[u8; 32]> = None;
-               let mut next_blinding_override: Option<PublicKey> = None;
-               decode_tlv_stream!(&mut r, {
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               _init_and_read_tlv_stream!(r, {
                        (1, _padding, option),
                        (2, _short_channel_id, option),
                        (4, next_node_id, option),
                        (6, path_id, option),
                        (8, next_blinding_override, option),
                });
+               let _padding: Option<Padding> = _padding;
+               let _short_channel_id: Option<u64> = _short_channel_id;
 
                let valid_fwd_fmt  = next_node_id.is_some() && path_id.is_none();
                let valid_recv_fmt = next_node_id.is_none() && next_blinding_override.is_none();
@@ -307,3 +305,12 @@ impl Readable for ControlTlvs {
                Ok(payload_fmt)
        }
 }
+
+impl Writeable for ControlTlvs {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               match self {
+                       Self::Forward(tlvs) => tlvs.write(w),
+                       Self::Receive(tlvs) => tlvs.write(w),
+               }
+       }
+}