Add support for variable-length onion payload reads using TLV
[rust-lightning] / lightning / src / ln / features.rs
1 //! Lightning exposes sets of supported operations through "feature flags". This module includes
2 //! types to store those feature flags and query for specific flags.
3
4 use std::{cmp, fmt};
5 use std::result::Result;
6 use std::marker::PhantomData;
7
8 use ln::msgs::DecodeError;
9 use util::ser::{Readable, Writeable, Writer};
10
11 mod sealed { // You should just use the type aliases instead.
12         pub struct InitContext {}
13         pub struct NodeContext {}
14         pub struct ChannelContext {}
15
16         /// An internal trait capturing the various feature context types
17         pub trait Context {}
18         impl Context for InitContext {}
19         impl Context for NodeContext {}
20         impl Context for ChannelContext {}
21
22         pub trait DataLossProtect: Context {}
23         impl DataLossProtect for InitContext {}
24         impl DataLossProtect for NodeContext {}
25
26         pub trait InitialRoutingSync: Context {}
27         impl InitialRoutingSync for InitContext {}
28
29         pub trait UpfrontShutdownScript: Context {}
30         impl UpfrontShutdownScript for InitContext {}
31         impl UpfrontShutdownScript for NodeContext {}
32 }
33
34 /// Tracks the set of features which a node implements, templated by the context in which it
35 /// appears.
36 pub struct Features<T: sealed::Context> {
37         /// Note that, for convinience, flags is LITTLE endian (despite being big-endian on the wire)
38         flags: Vec<u8>,
39         mark: PhantomData<T>,
40 }
41
42 impl<T: sealed::Context> Clone for Features<T> {
43         fn clone(&self) -> Self {
44                 Self {
45                         flags: self.flags.clone(),
46                         mark: PhantomData,
47                 }
48         }
49 }
50 impl<T: sealed::Context> PartialEq for Features<T> {
51         fn eq(&self, o: &Self) -> bool {
52                 self.flags.eq(&o.flags)
53         }
54 }
55 impl<T: sealed::Context> fmt::Debug for Features<T> {
56         fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
57                 self.flags.fmt(fmt)
58         }
59 }
60
61 /// A feature message as it appears in an init message
62 pub type InitFeatures = Features<sealed::InitContext>;
63 /// A feature message as it appears in a node_announcement message
64 pub type NodeFeatures = Features<sealed::NodeContext>;
65 /// A feature message as it appears in a channel_announcement message
66 pub type ChannelFeatures = Features<sealed::ChannelContext>;
67
68 impl InitFeatures {
69         /// Create a Features with the features we support
70         pub fn supported() -> InitFeatures {
71                 InitFeatures {
72                         flags: vec![2 | 1 << 5],
73                         mark: PhantomData,
74                 }
75         }
76
77         /// Writes all features present up to, and including, 13.
78         pub(crate) fn write_up_to_13<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
79                 let len = cmp::min(2, self.flags.len());
80                 w.size_hint(len + 2);
81                 (len as u16).write(w)?;
82                 for i in (0..len).rev() {
83                         if i == 0 {
84                                 self.flags[i].write(w)?;
85                         } else {
86                                 // On byte 1, we want up-to-and-including-bit-13, 0-indexed, which is
87                                 // up-to-and-including-bit-5, 0-indexed, on this byte:
88                                 (self.flags[i] & 0b00_11_11_11).write(w)?;
89                         }
90                 }
91                 Ok(())
92         }
93
94         /// or's another InitFeatures into this one.
95         pub(crate) fn or(mut self, o: InitFeatures) -> InitFeatures {
96                 let total_feature_len = cmp::max(self.flags.len(), o.flags.len());
97                 self.flags.resize(total_feature_len, 0u8);
98                 for (byte, o_byte) in self.flags.iter_mut().zip(o.flags.iter()) {
99                         *byte |= *o_byte;
100                 }
101                 self
102         }
103 }
104
105 impl ChannelFeatures {
106         /// Create a Features with the features we support
107         #[cfg(not(feature = "fuzztarget"))]
108         pub(crate) fn supported() -> ChannelFeatures {
109                 ChannelFeatures {
110                         flags: Vec::new(),
111                         mark: PhantomData,
112                 }
113         }
114         #[cfg(feature = "fuzztarget")]
115         pub fn supported() -> ChannelFeatures {
116                 ChannelFeatures {
117                         flags: Vec::new(),
118                         mark: PhantomData,
119                 }
120         }
121
122         /// Takes the flags that we know how to interpret in an init-context features that are also
123         /// relevant in a channel-context features and creates a channel-context features from them.
124         pub(crate) fn with_known_relevant_init_flags(_init_ctx: &InitFeatures) -> Self {
125                 // There are currently no channel flags defined that we understand.
126                 Self { flags: Vec::new(), mark: PhantomData, }
127         }
128 }
129
130 impl NodeFeatures {
131         /// Create a Features with the features we support
132         #[cfg(not(feature = "fuzztarget"))]
133         pub(crate) fn supported() -> NodeFeatures {
134                 NodeFeatures {
135                         flags: vec![2 | 1 << 5],
136                         mark: PhantomData,
137                 }
138         }
139         #[cfg(feature = "fuzztarget")]
140         pub fn supported() -> NodeFeatures {
141                 NodeFeatures {
142                         flags: vec![2 | 1 << 5],
143                         mark: PhantomData,
144                 }
145         }
146
147         /// Takes the flags that we know how to interpret in an init-context features that are also
148         /// relevant in a node-context features and creates a node-context features from them.
149         pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
150                 let mut flags = Vec::new();
151                 if init_ctx.flags.len() > 0 {
152                         // Pull out data_loss_protect and upfront_shutdown_script (bits 0, 1, 4, and 5)
153                         flags.push(init_ctx.flags.last().unwrap() & 0b00110011);
154                 }
155                 Self { flags, mark: PhantomData, }
156         }
157 }
158
159 impl<T: sealed::Context> Features<T> {
160         /// Create a blank Features with no features set
161         pub fn empty() -> Features<T> {
162                 Features {
163                         flags: Vec::new(),
164                         mark: PhantomData,
165                 }
166         }
167
168         #[cfg(test)]
169         /// Create a Features given a set of flags, in LE.
170         pub fn from_le_bytes(flags: Vec<u8>) -> Features<T> {
171                 Features {
172                         flags,
173                         mark: PhantomData,
174                 }
175         }
176
177         #[cfg(test)]
178         /// Gets the underlying flags set, in LE.
179         pub fn le_flags(&self) -> &Vec<u8> {
180                 &self.flags
181         }
182
183         pub(crate) fn requires_unknown_bits(&self) -> bool {
184                 self.flags.iter().enumerate().any(|(idx, &byte)| {
185                         (match idx {
186                                 0 => (byte & 0b00010100),
187                                 1 => (byte & 0b01010100),
188                                 _ => (byte & 0b01010101),
189                         }) != 0
190                 })
191         }
192
193         pub(crate) fn supports_unknown_bits(&self) -> bool {
194                 self.flags.iter().enumerate().any(|(idx, &byte)| {
195                         (match idx {
196                                 0 => (byte & 0b11000100),
197                                 1 => (byte & 0b11111100),
198                                 _ => byte,
199                         }) != 0
200                 })
201         }
202
203         /// The number of bytes required to represent the feature flags present. This does not include
204         /// the length bytes which are included in the serialized form.
205         pub(crate) fn byte_count(&self) -> usize {
206                 self.flags.len()
207         }
208
209         #[cfg(test)]
210         pub(crate) fn set_require_unknown_bits(&mut self) {
211                 let newlen = cmp::max(2, self.flags.len());
212                 self.flags.resize(newlen, 0u8);
213                 self.flags[1] |= 0x40;
214         }
215
216         #[cfg(test)]
217         pub(crate) fn clear_require_unknown_bits(&mut self) {
218                 let newlen = cmp::max(2, self.flags.len());
219                 self.flags.resize(newlen, 0u8);
220                 self.flags[1] &= !0x40;
221                 if self.flags.len() == 2 && self.flags[1] == 0 {
222                         self.flags.resize(1, 0u8);
223                 }
224         }
225 }
226
227 impl<T: sealed::DataLossProtect> Features<T> {
228         pub(crate) fn supports_data_loss_protect(&self) -> bool {
229                 self.flags.len() > 0 && (self.flags[0] & 3) != 0
230         }
231 }
232
233 impl<T: sealed::UpfrontShutdownScript> Features<T> {
234         pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
235                 self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
236         }
237         #[cfg(test)]
238         pub(crate) fn unset_upfront_shutdown_script(&mut self) {
239                 self.flags[0] ^= 1 << 5;
240         }
241 }
242
243 impl<T: sealed::InitialRoutingSync> Features<T> {
244         pub(crate) fn initial_routing_sync(&self) -> bool {
245                 self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
246         }
247         pub(crate) fn set_initial_routing_sync(&mut self) {
248                 if self.flags.len() == 0 {
249                         self.flags.resize(1, 1 << 3);
250                 } else {
251                         self.flags[0] |= 1 << 3;
252                 }
253         }
254 }
255
256 impl<T: sealed::Context> Writeable for Features<T> {
257         fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
258                 w.size_hint(self.flags.len() + 2);
259                 (self.flags.len() as u16).write(w)?;
260                 for f in self.flags.iter().rev() { // Swap back to big-endian
261                         f.write(w)?;
262                 }
263                 Ok(())
264         }
265 }
266
267 impl<R: ::std::io::Read, T: sealed::Context> Readable<R> for Features<T> {
268         fn read(r: &mut R) -> Result<Self, DecodeError> {
269                 let mut flags: Vec<u8> = Readable::read(r)?;
270                 flags.reverse(); // Swap to little-endian
271                 Ok(Self {
272                         flags,
273                         mark: PhantomData,
274                 })
275         }
276 }