Add/announce features for payment_secret and basic_mpp
[rust-lightning] / lightning / src / ln / features.rs
1 //! Lightning exposes sets of supported operations through "feature flags". This module includes
2 //! types to store those feature flags and query for specific flags.
3
4 use std::{cmp, fmt};
5 use std::result::Result;
6 use std::marker::PhantomData;
7
8 use ln::msgs::DecodeError;
9 use util::ser::{Readable, Writeable, Writer};
10
mod sealed { // You should just use the type aliases instead.
	/// Context marker for features carried in an `init` message.
	pub struct InitContext {}
	/// Context marker for features carried in a `node_announcement` message.
	pub struct NodeContext {}
	/// Context marker for features carried in a `channel_announcement` message.
	pub struct ChannelContext {}

	/// An internal trait capturing the various feature context types
	pub trait Context {}
	impl Context for InitContext {}
	impl Context for NodeContext {}
	impl Context for ChannelContext {}

	/// Declares a per-feature marker trait (a sub-trait of `Context`) and implements it for every
	/// context listed after the colon, i.e. every context in which that feature may be set.
	macro_rules! feature_in_contexts {
		($feature: ident: $($context: ident),+) => {
			pub trait $feature: Context {}
			$(impl $feature for $context {})+
		}
	}

	feature_in_contexts!(DataLossProtect: InitContext, NodeContext);
	feature_in_contexts!(InitialRoutingSync: InitContext);
	feature_in_contexts!(UpfrontShutdownScript: InitContext, NodeContext);
	feature_in_contexts!(VariableLengthOnion: InitContext, NodeContext);
	feature_in_contexts!(PaymentSecret: InitContext, NodeContext);
	feature_in_contexts!(BasicMPP: InitContext, NodeContext);
}
45
46 /// Tracks the set of features which a node implements, templated by the context in which it
47 /// appears.
48 pub struct Features<T: sealed::Context> {
49         /// Note that, for convinience, flags is LITTLE endian (despite being big-endian on the wire)
50         flags: Vec<u8>,
51         mark: PhantomData<T>,
52 }
53
54 impl<T: sealed::Context> Clone for Features<T> {
55         fn clone(&self) -> Self {
56                 Self {
57                         flags: self.flags.clone(),
58                         mark: PhantomData,
59                 }
60         }
61 }
62 impl<T: sealed::Context> PartialEq for Features<T> {
63         fn eq(&self, o: &Self) -> bool {
64                 self.flags.eq(&o.flags)
65         }
66 }
67 impl<T: sealed::Context> fmt::Debug for Features<T> {
68         fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
69                 self.flags.fmt(fmt)
70         }
71 }
72
73 /// A feature message as it appears in an init message
74 pub type InitFeatures = Features<sealed::InitContext>;
75 /// A feature message as it appears in a node_announcement message
76 pub type NodeFeatures = Features<sealed::NodeContext>;
77 /// A feature message as it appears in a channel_announcement message
78 pub type ChannelFeatures = Features<sealed::ChannelContext>;
79
80 impl InitFeatures {
81         /// Create a Features with the features we support
82         pub fn supported() -> InitFeatures {
83                 InitFeatures {
84                         flags: vec![2 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
85                         mark: PhantomData,
86                 }
87         }
88
89         /// Writes all features present up to, and including, 13.
90         pub(crate) fn write_up_to_13<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
91                 let len = cmp::min(2, self.flags.len());
92                 w.size_hint(len + 2);
93                 (len as u16).write(w)?;
94                 for i in (0..len).rev() {
95                         if i == 0 {
96                                 self.flags[i].write(w)?;
97                         } else {
98                                 // On byte 1, we want up-to-and-including-bit-13, 0-indexed, which is
99                                 // up-to-and-including-bit-5, 0-indexed, on this byte:
100                                 (self.flags[i] & 0b00_11_11_11).write(w)?;
101                         }
102                 }
103                 Ok(())
104         }
105
106         /// or's another InitFeatures into this one.
107         pub(crate) fn or(mut self, o: InitFeatures) -> InitFeatures {
108                 let total_feature_len = cmp::max(self.flags.len(), o.flags.len());
109                 self.flags.resize(total_feature_len, 0u8);
110                 for (byte, o_byte) in self.flags.iter_mut().zip(o.flags.iter()) {
111                         *byte |= *o_byte;
112                 }
113                 self
114         }
115 }
116
117 impl ChannelFeatures {
118         /// Create a Features with the features we support
119         #[cfg(not(feature = "fuzztarget"))]
120         pub(crate) fn supported() -> ChannelFeatures {
121                 ChannelFeatures {
122                         flags: Vec::new(),
123                         mark: PhantomData,
124                 }
125         }
126         #[cfg(feature = "fuzztarget")]
127         pub fn supported() -> ChannelFeatures {
128                 ChannelFeatures {
129                         flags: Vec::new(),
130                         mark: PhantomData,
131                 }
132         }
133
134         /// Takes the flags that we know how to interpret in an init-context features that are also
135         /// relevant in a channel-context features and creates a channel-context features from them.
136         pub(crate) fn with_known_relevant_init_flags(_init_ctx: &InitFeatures) -> Self {
137                 // There are currently no channel flags defined that we understand.
138                 Self { flags: Vec::new(), mark: PhantomData, }
139         }
140 }
141
142 impl NodeFeatures {
143         /// Create a Features with the features we support
144         #[cfg(not(feature = "fuzztarget"))]
145         pub(crate) fn supported() -> NodeFeatures {
146                 NodeFeatures {
147                         flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
148                         mark: PhantomData,
149                 }
150         }
151         #[cfg(feature = "fuzztarget")]
152         pub fn supported() -> NodeFeatures {
153                 NodeFeatures {
154                         flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
155                         mark: PhantomData,
156                 }
157         }
158
159         /// Takes the flags that we know how to interpret in an init-context features that are also
160         /// relevant in a node-context features and creates a node-context features from them.
161         /// Be sure to blank out features that are unknown to us.
162         pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
163                 let mut flags = Vec::new();
164                 for (i, feature_byte)in init_ctx.flags.iter().enumerate() {
165                         match i {
166                                 // Blank out initial_routing_sync (feature bits 2/3), gossip_queries (6/7),
167                                 // gossip_queries_ex (10/11), option_static_remotekey (12/13), and
168                                 // payment_secret (14/15)
169                                 0 => flags.push(feature_byte & 0b00110011),
170                                 1 => flags.push(feature_byte & 0b00000011),
171                                 _ => (),
172                         }
173                 }
174                 Self { flags, mark: PhantomData, }
175         }
176 }
177
178 impl<T: sealed::Context> Features<T> {
179         /// Create a blank Features with no features set
180         pub fn empty() -> Features<T> {
181                 Features {
182                         flags: Vec::new(),
183                         mark: PhantomData,
184                 }
185         }
186
187         #[cfg(test)]
188         /// Create a Features given a set of flags, in LE.
189         pub fn from_le_bytes(flags: Vec<u8>) -> Features<T> {
190                 Features {
191                         flags,
192                         mark: PhantomData,
193                 }
194         }
195
196         #[cfg(test)]
197         /// Gets the underlying flags set, in LE.
198         pub fn le_flags(&self) -> &Vec<u8> {
199                 &self.flags
200         }
201
202         pub(crate) fn requires_unknown_bits(&self) -> bool {
203                 self.flags.iter().enumerate().any(|(idx, &byte)| {
204                         (match idx {
205                                 // Unknown bits are even bits which we don't understand, we list ones which we do
206                                 // here:
207                                 // unknown, upfront_shutdown_script, unknown (actually initial_routing_sync, but it
208                                 // is only valid as an optional feature), and data_loss_protect:
209                                 0 => (byte & 0b01000100),
210                                 // payment_secret, unknown, unknown, var_onion_optin:
211                                 1 => (byte & 0b00010100),
212                                 // unknown, unknown, unknown, basic_mpp:
213                                 2 => (byte & 0b01010100),
214                                 // fallback, all even bits set:
215                                 _ => (byte & 0b01010101),
216                         }) != 0
217                 })
218         }
219
220         pub(crate) fn supports_unknown_bits(&self) -> bool {
221                 self.flags.iter().enumerate().any(|(idx, &byte)| {
222                         (match idx {
223                                 // unknown, upfront_shutdown_script, initial_routing_sync (is only valid as an
224                                 // optional feature), and data_loss_protect:
225                                 0 => (byte & 0b11000100),
226                                 // payment_secret, unknown, unknown, var_onion_optin:
227                                 1 => (byte & 0b00111100),
228                                 // unknown, unknown, unknown, basic_mpp:
229                                 2 => (byte & 0b11111100),
230                                 _ => byte,
231                         }) != 0
232                 })
233         }
234
235         /// The number of bytes required to represent the feature flags present. This does not include
236         /// the length bytes which are included in the serialized form.
237         pub(crate) fn byte_count(&self) -> usize {
238                 self.flags.len()
239         }
240
241         #[cfg(test)]
242         pub(crate) fn set_require_unknown_bits(&mut self) {
243                 let newlen = cmp::max(3, self.flags.len());
244                 self.flags.resize(newlen, 0u8);
245                 self.flags[2] |= 0x40;
246         }
247
248         #[cfg(test)]
249         pub(crate) fn clear_require_unknown_bits(&mut self) {
250                 let newlen = cmp::max(3, self.flags.len());
251                 self.flags.resize(newlen, 0u8);
252                 self.flags[2] &= !0x40;
253                 if self.flags.len() == 3 && self.flags[2] == 0 {
254                         self.flags.resize(2, 0u8);
255                 }
256                 if self.flags.len() == 2 && self.flags[1] == 0 {
257                         self.flags.resize(1, 0u8);
258                 }
259         }
260 }
261
262 impl<T: sealed::DataLossProtect> Features<T> {
263         pub(crate) fn supports_data_loss_protect(&self) -> bool {
264                 self.flags.len() > 0 && (self.flags[0] & 3) != 0
265         }
266 }
267
268 impl<T: sealed::UpfrontShutdownScript> Features<T> {
269         pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
270                 self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
271         }
272         #[cfg(test)]
273         pub(crate) fn unset_upfront_shutdown_script(&mut self) {
274                 self.flags[0] ^= 1 << 5;
275         }
276 }
277
278 impl<T: sealed::VariableLengthOnion> Features<T> {
279         pub(crate) fn supports_variable_length_onion(&self) -> bool {
280                 self.flags.len() > 1 && (self.flags[1] & 3) != 0
281         }
282 }
283
284 impl<T: sealed::InitialRoutingSync> Features<T> {
285         pub(crate) fn initial_routing_sync(&self) -> bool {
286                 self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
287         }
288         pub(crate) fn set_initial_routing_sync(&mut self) {
289                 if self.flags.len() == 0 {
290                         self.flags.resize(1, 1 << 3);
291                 } else {
292                         self.flags[0] |= 1 << 3;
293                 }
294         }
295 }
296
297 impl<T: sealed::PaymentSecret> Features<T> {
298         #[allow(dead_code)]
299         // Note that we never need to test this since what really matters is the invoice - iff the
300         // invoice provides a payment_secret, we assume that we can use it (ie that the recipient
301         // supports payment_secret).
302         pub(crate) fn payment_secret(&self) -> bool {
303                 self.flags.len() > 1 && (self.flags[1] & (3 << (14-8))) != 0
304         }
305 }
306
307 impl<T: sealed::BasicMPP> Features<T> {
308         // We currently never test for this since we don't actually *generate* multipath routes.
309         #[allow(dead_code)]
310         pub(crate) fn basic_mpp(&self) -> bool {
311                 self.flags.len() > 2 && (self.flags[2] & (3 << (16-8*2))) != 0
312         }
313 }
314
315 impl<T: sealed::Context> Writeable for Features<T> {
316         fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
317                 w.size_hint(self.flags.len() + 2);
318                 (self.flags.len() as u16).write(w)?;
319                 for f in self.flags.iter().rev() { // Swap back to big-endian
320                         f.write(w)?;
321                 }
322                 Ok(())
323         }
324 }
325
326 impl<T: sealed::Context> Readable for Features<T> {
327         fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
328                 let mut flags: Vec<u8> = Readable::read(r)?;
329                 flags.reverse(); // Swap to little-endian
330                 Ok(Self {
331                         flags,
332                         mark: PhantomData,
333                 })
334         }
335 }
336
337 #[cfg(test)]
338 mod tests {
339         use super::{ChannelFeatures, InitFeatures, NodeFeatures, Features};
340
341         #[test]
342         fn sanity_test_our_features() {
343                 assert!(!ChannelFeatures::supported().requires_unknown_bits());
344                 assert!(!ChannelFeatures::supported().supports_unknown_bits());
345                 assert!(!InitFeatures::supported().requires_unknown_bits());
346                 assert!(!InitFeatures::supported().supports_unknown_bits());
347                 assert!(!NodeFeatures::supported().requires_unknown_bits());
348                 assert!(!NodeFeatures::supported().supports_unknown_bits());
349
350                 assert!(InitFeatures::supported().supports_upfront_shutdown_script());
351                 assert!(NodeFeatures::supported().supports_upfront_shutdown_script());
352
353                 assert!(InitFeatures::supported().supports_data_loss_protect());
354                 assert!(NodeFeatures::supported().supports_data_loss_protect());
355
356                 assert!(InitFeatures::supported().supports_variable_length_onion());
357                 assert!(NodeFeatures::supported().supports_variable_length_onion());
358
359                 let mut init_features = InitFeatures::supported();
360                 init_features.set_initial_routing_sync();
361                 assert!(!init_features.requires_unknown_bits());
362                 assert!(!init_features.supports_unknown_bits());
363         }
364
365         #[test]
366         fn sanity_test_unkown_bits_testing() {
367                 let mut features = ChannelFeatures::supported();
368                 features.set_require_unknown_bits();
369                 assert!(features.requires_unknown_bits());
370                 features.clear_require_unknown_bits();
371                 assert!(!features.requires_unknown_bits());
372         }
373
374         #[test]
375         fn test_node_with_known_relevant_init_flags() {
376                 // Create an InitFeatures with initial_routing_sync supported.
377                 let mut init_features = InitFeatures::supported();
378                 init_features.set_initial_routing_sync();
379
380                 // Attempt to pull out non-node-context feature flags from these InitFeatures.
381                 let res = NodeFeatures::with_known_relevant_init_flags(&init_features);
382
383                 {
384                         // Check that the flags are as expected: optional_data_loss_protect,
385                         // option_upfront_shutdown_script, and var_onion_optin set.
386                         assert_eq!(res.flags[0], 0b00100010);
387                         assert_eq!(res.flags[1], 0b00000010);
388                         assert_eq!(res.flags.len(), 2);
389                 }
390
391                 // Check that the initial_routing_sync feature was correctly blanked out.
392                 let new_features: InitFeatures = Features::from_le_bytes(res.flags);
393                 assert!(!new_features.initial_routing_sync());
394         }
395 }