Unset upfront_shutdown_script using bit clearing
[rust-lightning] / lightning / src / ln / features.rs
1 //! Lightning exposes sets of supported operations through "feature flags". This module includes
2 //! types to store those feature flags and query for specific flags.
3
4 use std::{cmp, fmt};
5 use std::result::Result;
6 use std::marker::PhantomData;
7
8 use ln::msgs::DecodeError;
9 use util::ser::{Readable, Writeable, Writer};
10
mod sealed { // You should just use the type aliases instead.
	/// Marker for features as they appear in an `init` message.
	pub struct InitContext {}
	/// Marker for features as they appear in a `node_announcement` message.
	pub struct NodeContext {}
	/// Marker for features as they appear in a `channel_announcement` message.
	pub struct ChannelContext {}

	/// An internal trait capturing the various feature context types
	pub trait Context {}

	// One marker trait per feature. A context implements a feature's trait only when that
	// feature may appear in that context, so accessors bounded on the trait are simply absent
	// for contexts where the feature is meaningless.
	pub trait DataLossProtect: Context {}
	pub trait InitialRoutingSync: Context {}
	pub trait UpfrontShutdownScript: Context {}
	pub trait VariableLengthOnion: Context {}
	pub trait PaymentSecret: Context {}
	pub trait BasicMPP: Context {}

	// Features valid in `init` messages.
	impl Context for InitContext {}
	impl DataLossProtect for InitContext {}
	impl InitialRoutingSync for InitContext {}
	impl UpfrontShutdownScript for InitContext {}
	impl VariableLengthOnion for InitContext {}
	impl PaymentSecret for InitContext {}
	impl BasicMPP for InitContext {}

	// Features valid in `node_announcement` messages (everything above except
	// initial_routing_sync).
	impl Context for NodeContext {}
	impl DataLossProtect for NodeContext {}
	impl UpfrontShutdownScript for NodeContext {}
	impl VariableLengthOnion for NodeContext {}
	impl PaymentSecret for NodeContext {}
	impl BasicMPP for NodeContext {}

	// No feature-specific traits are implemented for the channel context.
	impl Context for ChannelContext {}
}
45
46 /// Tracks the set of features which a node implements, templated by the context in which it
47 /// appears.
48 pub struct Features<T: sealed::Context> {
49         /// Note that, for convinience, flags is LITTLE endian (despite being big-endian on the wire)
50         flags: Vec<u8>,
51         mark: PhantomData<T>,
52 }
53
54 impl<T: sealed::Context> Clone for Features<T> {
55         fn clone(&self) -> Self {
56                 Self {
57                         flags: self.flags.clone(),
58                         mark: PhantomData,
59                 }
60         }
61 }
62 impl<T: sealed::Context> PartialEq for Features<T> {
63         fn eq(&self, o: &Self) -> bool {
64                 self.flags.eq(&o.flags)
65         }
66 }
67 impl<T: sealed::Context> fmt::Debug for Features<T> {
68         fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
69                 self.flags.fmt(fmt)
70         }
71 }
72
73 /// A feature message as it appears in an init message
74 pub type InitFeatures = Features<sealed::InitContext>;
75 /// A feature message as it appears in a node_announcement message
76 pub type NodeFeatures = Features<sealed::NodeContext>;
77 /// A feature message as it appears in a channel_announcement message
78 pub type ChannelFeatures = Features<sealed::ChannelContext>;
79
80 impl InitFeatures {
81         /// Create a Features with the features we support
82         pub fn supported() -> InitFeatures {
83                 InitFeatures {
84                         flags: vec![2 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
85                         mark: PhantomData,
86                 }
87         }
88
89         /// Writes all features present up to, and including, 13.
90         pub(crate) fn write_up_to_13<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
91                 let len = cmp::min(2, self.flags.len());
92                 w.size_hint(len + 2);
93                 (len as u16).write(w)?;
94                 for i in (0..len).rev() {
95                         if i == 0 {
96                                 self.flags[i].write(w)?;
97                         } else {
98                                 // On byte 1, we want up-to-and-including-bit-13, 0-indexed, which is
99                                 // up-to-and-including-bit-5, 0-indexed, on this byte:
100                                 (self.flags[i] & 0b00_11_11_11).write(w)?;
101                         }
102                 }
103                 Ok(())
104         }
105
106         /// or's another InitFeatures into this one.
107         pub(crate) fn or(mut self, o: InitFeatures) -> InitFeatures {
108                 let total_feature_len = cmp::max(self.flags.len(), o.flags.len());
109                 self.flags.resize(total_feature_len, 0u8);
110                 for (byte, o_byte) in self.flags.iter_mut().zip(o.flags.iter()) {
111                         *byte |= *o_byte;
112                 }
113                 self
114         }
115 }
116
117 impl ChannelFeatures {
118         /// Create a Features with the features we support
119         #[cfg(not(feature = "fuzztarget"))]
120         pub(crate) fn supported() -> ChannelFeatures {
121                 ChannelFeatures {
122                         flags: Vec::new(),
123                         mark: PhantomData,
124                 }
125         }
126         #[cfg(feature = "fuzztarget")]
127         pub fn supported() -> ChannelFeatures {
128                 ChannelFeatures {
129                         flags: Vec::new(),
130                         mark: PhantomData,
131                 }
132         }
133
134         /// Takes the flags that we know how to interpret in an init-context features that are also
135         /// relevant in a channel-context features and creates a channel-context features from them.
136         pub(crate) fn with_known_relevant_init_flags(_init_ctx: &InitFeatures) -> Self {
137                 // There are currently no channel flags defined that we understand.
138                 Self { flags: Vec::new(), mark: PhantomData, }
139         }
140 }
141
142 impl NodeFeatures {
143         /// Create a Features with the features we support
144         #[cfg(not(feature = "fuzztarget"))]
145         pub(crate) fn supported() -> NodeFeatures {
146                 NodeFeatures {
147                         flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
148                         mark: PhantomData,
149                 }
150         }
151         #[cfg(feature = "fuzztarget")]
152         pub fn supported() -> NodeFeatures {
153                 NodeFeatures {
154                         flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
155                         mark: PhantomData,
156                 }
157         }
158
159         /// Takes the flags that we know how to interpret in an init-context features that are also
160         /// relevant in a node-context features and creates a node-context features from them.
161         /// Be sure to blank out features that are unknown to us.
162         pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
163                 let mut flags = Vec::new();
164                 for (i, feature_byte)in init_ctx.flags.iter().enumerate() {
165                         match i {
166                                 // Blank out initial_routing_sync (feature bits 2/3), gossip_queries (6/7),
167                                 // gossip_queries_ex (10/11), option_static_remotekey (12/13), and
168                                 // option_support_large_channel (16/17)
169                                 0 => flags.push(feature_byte & 0b00110011),
170                                 1 => flags.push(feature_byte & 0b11000011),
171                                 2 => flags.push(feature_byte & 0b00000011),
172                                 _ => (),
173                         }
174                 }
175                 Self { flags, mark: PhantomData, }
176         }
177 }
178
179 impl<T: sealed::Context> Features<T> {
180         /// Create a blank Features with no features set
181         pub fn empty() -> Features<T> {
182                 Features {
183                         flags: Vec::new(),
184                         mark: PhantomData,
185                 }
186         }
187
188         #[cfg(test)]
189         /// Create a Features given a set of flags, in LE.
190         pub fn from_le_bytes(flags: Vec<u8>) -> Features<T> {
191                 Features {
192                         flags,
193                         mark: PhantomData,
194                 }
195         }
196
197         #[cfg(test)]
198         /// Gets the underlying flags set, in LE.
199         pub fn le_flags(&self) -> &Vec<u8> {
200                 &self.flags
201         }
202
203         pub(crate) fn requires_unknown_bits(&self) -> bool {
204                 self.flags.iter().enumerate().any(|(idx, &byte)| {
205                         (match idx {
206                                 // Unknown bits are even bits which we don't understand, we list ones which we do
207                                 // here:
208                                 // unknown, upfront_shutdown_script, unknown (actually initial_routing_sync, but it
209                                 // is only valid as an optional feature), and data_loss_protect:
210                                 0 => (byte & 0b01000100),
211                                 // payment_secret, unknown, unknown, var_onion_optin:
212                                 1 => (byte & 0b00010100),
213                                 // unknown, unknown, unknown, basic_mpp:
214                                 2 => (byte & 0b01010100),
215                                 // fallback, all even bits set:
216                                 _ => (byte & 0b01010101),
217                         }) != 0
218                 })
219         }
220
221         pub(crate) fn supports_unknown_bits(&self) -> bool {
222                 self.flags.iter().enumerate().any(|(idx, &byte)| {
223                         (match idx {
224                                 // unknown, upfront_shutdown_script, initial_routing_sync (is only valid as an
225                                 // optional feature), and data_loss_protect:
226                                 0 => (byte & 0b11000100),
227                                 // payment_secret, unknown, unknown, var_onion_optin:
228                                 1 => (byte & 0b00111100),
229                                 // unknown, unknown, unknown, basic_mpp:
230                                 2 => (byte & 0b11111100),
231                                 _ => byte,
232                         }) != 0
233                 })
234         }
235
236         /// The number of bytes required to represent the feature flags present. This does not include
237         /// the length bytes which are included in the serialized form.
238         pub(crate) fn byte_count(&self) -> usize {
239                 self.flags.len()
240         }
241
242         #[cfg(test)]
243         pub(crate) fn set_require_unknown_bits(&mut self) {
244                 let newlen = cmp::max(3, self.flags.len());
245                 self.flags.resize(newlen, 0u8);
246                 self.flags[2] |= 0x40;
247         }
248
249         #[cfg(test)]
250         pub(crate) fn clear_require_unknown_bits(&mut self) {
251                 let newlen = cmp::max(3, self.flags.len());
252                 self.flags.resize(newlen, 0u8);
253                 self.flags[2] &= !0x40;
254                 if self.flags.len() == 3 && self.flags[2] == 0 {
255                         self.flags.resize(2, 0u8);
256                 }
257                 if self.flags.len() == 2 && self.flags[1] == 0 {
258                         self.flags.resize(1, 0u8);
259                 }
260         }
261 }
262
263 impl<T: sealed::DataLossProtect> Features<T> {
264         pub(crate) fn supports_data_loss_protect(&self) -> bool {
265                 self.flags.len() > 0 && (self.flags[0] & 3) != 0
266         }
267 }
268
269 impl<T: sealed::UpfrontShutdownScript> Features<T> {
270         pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
271                 self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
272         }
273         #[cfg(test)]
274         pub(crate) fn unset_upfront_shutdown_script(&mut self) {
275                 self.flags[0] &= !(1 << 5);
276         }
277 }
278
279 impl<T: sealed::VariableLengthOnion> Features<T> {
280         pub(crate) fn supports_variable_length_onion(&self) -> bool {
281                 self.flags.len() > 1 && (self.flags[1] & 3) != 0
282         }
283 }
284
285 impl<T: sealed::InitialRoutingSync> Features<T> {
286         pub(crate) fn initial_routing_sync(&self) -> bool {
287                 self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
288         }
289         pub(crate) fn set_initial_routing_sync(&mut self) {
290                 if self.flags.len() == 0 {
291                         self.flags.resize(1, 1 << 3);
292                 } else {
293                         self.flags[0] |= 1 << 3;
294                 }
295         }
296 }
297
298 impl<T: sealed::PaymentSecret> Features<T> {
299         #[allow(dead_code)]
300         // Note that we never need to test this since what really matters is the invoice - iff the
301         // invoice provides a payment_secret, we assume that we can use it (ie that the recipient
302         // supports payment_secret).
303         pub(crate) fn supports_payment_secret(&self) -> bool {
304                 self.flags.len() > 1 && (self.flags[1] & (3 << (14-8))) != 0
305         }
306 }
307
308 impl<T: sealed::BasicMPP> Features<T> {
309         // We currently never test for this since we don't actually *generate* multipath routes.
310         #[allow(dead_code)]
311         pub(crate) fn supports_basic_mpp(&self) -> bool {
312                 self.flags.len() > 2 && (self.flags[2] & (3 << (16-8*2))) != 0
313         }
314 }
315
316 impl<T: sealed::Context> Writeable for Features<T> {
317         fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
318                 w.size_hint(self.flags.len() + 2);
319                 (self.flags.len() as u16).write(w)?;
320                 for f in self.flags.iter().rev() { // Swap back to big-endian
321                         f.write(w)?;
322                 }
323                 Ok(())
324         }
325 }
326
327 impl<T: sealed::Context> Readable for Features<T> {
328         fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
329                 let mut flags: Vec<u8> = Readable::read(r)?;
330                 flags.reverse(); // Swap to little-endian
331                 Ok(Self {
332                         flags,
333                         mark: PhantomData,
334                 })
335         }
336 }
337
#[cfg(test)]
mod tests {
	use super::{ChannelFeatures, InitFeatures, NodeFeatures, Features};

	#[test]
	fn sanity_test_our_features() {
		assert!(!ChannelFeatures::supported().requires_unknown_bits());
		assert!(!ChannelFeatures::supported().supports_unknown_bits());
		assert!(!InitFeatures::supported().requires_unknown_bits());
		assert!(!InitFeatures::supported().supports_unknown_bits());
		assert!(!NodeFeatures::supported().requires_unknown_bits());
		assert!(!NodeFeatures::supported().supports_unknown_bits());

		assert!(InitFeatures::supported().supports_upfront_shutdown_script());
		assert!(NodeFeatures::supported().supports_upfront_shutdown_script());

		assert!(InitFeatures::supported().supports_data_loss_protect());
		assert!(NodeFeatures::supported().supports_data_loss_protect());

		assert!(InitFeatures::supported().supports_variable_length_onion());
		assert!(NodeFeatures::supported().supports_variable_length_onion());

		assert!(InitFeatures::supported().supports_payment_secret());
		assert!(NodeFeatures::supported().supports_payment_secret());

		assert!(InitFeatures::supported().supports_basic_mpp());
		assert!(NodeFeatures::supported().supports_basic_mpp());

		let mut init_features = InitFeatures::supported();
		init_features.set_initial_routing_sync();
		assert!(!init_features.requires_unknown_bits());
		assert!(!init_features.supports_unknown_bits());
	}

	#[test]
	fn sanity_test_unknown_bits_testing() {
		let mut features = ChannelFeatures::supported();
		features.set_require_unknown_bits();
		assert!(features.requires_unknown_bits());
		features.clear_require_unknown_bits();
		assert!(!features.requires_unknown_bits());
	}

	#[test]
	fn test_node_with_known_relevant_init_flags() {
		// Create an InitFeatures with initial_routing_sync supported.
		let mut init_features = InitFeatures::supported();
		init_features.set_initial_routing_sync();

		// Attempt to pull out non-node-context feature flags from these InitFeatures.
		let res = NodeFeatures::with_known_relevant_init_flags(&init_features);

		{
			// Check that the flags are as expected: optional_data_loss_protect,
			// option_upfront_shutdown_script, var_onion_optin, payment_secret, and
			// basic_mpp.
			assert_eq!(res.flags.len(), 3);
			assert_eq!(res.flags[0], 0b00100010);
			assert_eq!(res.flags[1], 0b10000010);
			assert_eq!(res.flags[2], 0b00000010);
		}

		// Check that the initial_routing_sync feature was correctly blanked out.
		let new_features: InitFeatures = Features::from_le_bytes(res.flags);
		assert!(!new_features.initial_routing_sync());
	}

	#[test]
	fn test_unset_upfront_shutdown_script() {
		let mut init_features = InitFeatures::supported();
		assert!(init_features.supports_upfront_shutdown_script());
		init_features.unset_upfront_shutdown_script();
		assert!(!init_features.supports_upfront_shutdown_script());

		// The other supported features must be left untouched.
		assert!(init_features.supports_data_loss_protect());
		assert!(init_features.supports_variable_length_onion());
		assert!(init_features.supports_payment_secret());
		assert!(init_features.supports_basic_mpp());

		// Clearing (rather than toggling) means a second call cannot re-enable the feature.
		init_features.unset_upfront_shutdown_script();
		assert!(!init_features.supports_upfront_shutdown_script());
	}
}
403         }
404 }