WIP tokio 1 conversion
[dnsseed-rust] / src / bgp_client.rs
index 9038005a3d412aadc075bf26422c38ee32ed494b..17b6598936f7dc31e02fc696e000bd6b02eea4ce 100644 (file)
@@ -1,7 +1,7 @@
 use std::sync::{Arc, Mutex};
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::cmp;
-use std::collections::HashMap;
+use std::collections::{HashMap, hash_map};
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
 use std::time::{Duration, Instant};
 
@@ -11,25 +11,21 @@ use bgp_rs::Segment;
 use bgp_rs::Message;
 use bgp_rs::Reader;
 
-use tokio::prelude::*;
-use tokio::codec;
-use tokio::codec::Framed;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::TcpStream;
-use tokio::timer::Delay;
-
-use futures::sync::mpsc;
+use tokio::time;
 
 use crate::printer::{Printer, Stat};
-use crate::timeout_stream::TimeoutStream;
 
 const PATH_SUFFIX_LEN: usize = 3;
 #[derive(Clone)]
-struct Route { // 32 bytes
+struct Route { // six u32 fields (PATH_SUFFIX_LEN = 3): 24 bytes, 28 alongside a u32 path id
         path_suffix: [u32; PATH_SUFFIX_LEN],
         path_len: u32,
         pref: u32,
         med: u32,
 }
+#[allow(dead_code)] // never read; the constant exists only for the check it performs
+const ROUTE_LEN: usize = 36 - std::mem::size_of::<(u32, Route)>(); // compile-time size guard: the subtraction underflows (a const-evaluation error) if (u32, Route) exceeds 36 bytes
 
 // To keep memory tight (and since we dont' need such close alignment), newtype the v4/v6 routing
 // table entries to make sure they are aligned to single bytes.
@@ -48,6 +44,10 @@ impl From<(Ipv4Addr, u8)> for V4Addr {
                }
        }
 }
+#[allow(dead_code)] // never read; the constant exists only for the check it performs
+const V4_ALIGN: usize = 1 - std::mem::align_of::<V4Addr>(); // compile-time guard: underflows (a const-evaluation error) if V4Addr's alignment exceeds 1
+#[allow(dead_code)] // never read; the constant exists only for the check it performs
+const V4_SIZE: usize = 5 - std::mem::size_of::<V4Addr>(); // compile-time guard: underflows (a const-evaluation error) if V4Addr is larger than 5 bytes
 
 #[repr(packed)]
 #[derive(PartialEq, Eq, Hash)]
@@ -63,17 +63,27 @@ impl From<(Ipv6Addr, u8)> for V6Addr {
                }
        }
 }
+#[allow(dead_code)] // never read; the constant exists only for the check it performs
+const V6_ALIGN: usize = 1 - std::mem::align_of::<V6Addr>(); // compile-time guard: underflows (a const-evaluation error) if V6Addr's alignment exceeds 1
+#[allow(dead_code)] // never read; the constant exists only for the check it performs
+const V6_SIZE: usize = 17 - std::mem::size_of::<V6Addr>(); // compile-time guard: underflows (a const-evaluation error) if V6Addr is larger than 17 bytes
 
 struct RoutingTable {
-       v4_table: HashMap<V4Addr, HashMap<u32, Route>>,
-       v6_table: HashMap<V6Addr, HashMap<u32, Route>>,
+       // We really want a HashMap for the values here, but they'll only ever contain a few entries,
+       // and Vecs are way more memory-efficient in that case.
+       v4_table: HashMap<V4Addr, Vec<(u32, Route)>>, // prefix -> list of (path id, route) pairs
+       v6_table: HashMap<V6Addr, Vec<(u32, Route)>>, // prefix -> list of (path id, route) pairs
+       max_paths: usize, // largest per-prefix path count recorded by announce(); withdraw() resets it to 0 once routes_with_max drops to 0
+       routes_with_max: usize, // number of prefixes counted as currently holding exactly max_paths paths
 }
 
 impl RoutingTable {
        fn new() -> Self {
                Self {
-                       v4_table: HashMap::new(),
-                       v6_table: HashMap::new(),
+                       v4_table: HashMap::with_capacity(900_000), // pre-sized up front; presumably matches the expected IPv4 table size (TODO confirm)
+                       v6_table: HashMap::with_capacity(100_000), // pre-sized up front; presumably matches the expected IPv6 table size (TODO confirm)
+                       max_paths: 0, // no paths stored yet
+                       routes_with_max: 0, // no prefixes stored yet
                }
        }
 
@@ -83,9 +93,9 @@ impl RoutingTable {
                                //TODO: Optimize this (probably means making the tables btrees)!
                                let mut lookup = <$addrty>::from(($addr, $addr_bits));
                                for i in 0..$addr_bits {
-                                       if let Some(routes) = $table.get(&lookup).map(|hm| hm.values()) {
+                                       if let Some(routes) = $table.get(&lookup) {
                                                if routes.len() > 0 {
-                                                       return (lookup.pfxlen, routes.collect());
+                                                       return (lookup.pfxlen, routes.iter().map(|v| &v.1).collect());
                                                }
                                        }
                                        lookup.addr[lookup.addr.len() - (i/8) - 1] &= !(1u8 << (i % 8));
@@ -101,47 +111,92 @@ impl RoutingTable {
        }
 
+       /// Removes the path described by `route` from the routing table.
+       ///
+       /// Plain prefixes are looked up under path id 0, prefixes with a path id under that id.
+       /// A prefix whose last path is removed is dropped from the table. All other NLRI
+       /// encodings are ignored.
        fn withdraw(&mut self, route: NLRIEncoding) {
+               // remove!(table, prefix, path id): drop one path and keep the max-path statistics
+               // (`max_paths` / `routes_with_max`) in step with what was actually removed.
+               macro_rules! remove {
+                       ($rt: expr, $v: expr, $id: expr) => { {
+                               match $rt.entry($v.into()) {
+                                       hash_map::Entry::Occupied(mut entry) => {
+                                               let old_len = entry.get().len();
+                                               entry.get_mut().retain(|e| e.0 != $id);
+                                               // Only touch the statistics if a path really went away: withdrawing
+                                               // an unknown path id must not decrement the counter.
+                                               if entry.get().len() != old_len && old_len == self.max_paths {
+                                                       self.routes_with_max -= 1;
+                                                       if self.routes_with_max == 0 {
+                                                               // No prefix is known to hold the maximum any more;
+                                                               // the next announce() re-establishes it.
+                                                               self.max_paths = 0;
+                                                       }
+                                               }
+                                               if entry.get().is_empty() {
+                                                       entry.remove();
+                                               }
+                                       },
+                                       hash_map::Entry::Vacant(_) => {},
+                               }
+                       } }
+               }
                match route {
                        NLRIEncoding::IP(p) => {
                                let (ip, len) = <(IpAddr, u8)>::from(&p);
                                match ip {
-                                       IpAddr::V4(v4a) => self.v4_table.get_mut(&(v4a, len).into()).and_then(|hm| hm.remove(&0)),
-                                       IpAddr::V6(v6a) => self.v6_table.get_mut(&(v6a, len).into()).and_then(|hm| hm.remove(&0)),
+                                       IpAddr::V4(v4a) => remove!(self.v4_table, (v4a, len), 0),
+                                       IpAddr::V6(v6a) => remove!(self.v6_table, (v6a, len), 0),
                                }
                        },
                        NLRIEncoding::IP_WITH_PATH_ID((p, id)) => {
                                let (ip, len) = <(IpAddr, u8)>::from(&p);
                                match ip {
-                                       IpAddr::V4(v4a) => self.v4_table.get_mut(&(v4a, len).into()).and_then(|hm| hm.remove(&id)),
-                                       IpAddr::V6(v6a) => self.v6_table.get_mut(&(v6a, len).into()).and_then(|hm| hm.remove(&id)),
+                                       IpAddr::V4(v4a) => remove!(self.v4_table, (v4a, len), id),
+                                       IpAddr::V6(v6a) => remove!(self.v6_table, (v6a, len), id),
                                }
                        },
-                       NLRIEncoding::IP_MPLS(_) => None,
+                       NLRIEncoding::IP_MPLS(_) => (),
+                       NLRIEncoding::IP_MPLS_WITH_PATH_ID(_) => (),
+                       NLRIEncoding::IP_VPN_MPLS(_) => (),
+                       NLRIEncoding::L2VPN(_) => (),
                };
        }
 
        fn announce(&mut self, prefix: NLRIEncoding, route: Route) {
+               macro_rules! insert { // insert!(table, prefix, path id): add or replace one path for the prefix and update max_paths / routes_with_max
+                       ($rt: expr, $v: expr, $id: expr) => { {
+                               let old_max_paths = self.max_paths;
+                               let entry = $rt.entry($v.into()).or_insert_with(|| Vec::with_capacity(old_max_paths)); // a new prefix starts with room for max_paths entries
+                               let entry_had_max = entry.len() == self.max_paths; // sampled before the retain below changes the length
+                               entry.retain(|e| e.0 != $id); // drop any path already stored under this id; the push below replaces it
+                               if entry_had_max {
+                                       entry.reserve_exact(1); // guarantee room for the one entry pushed below
+                               } else {
+                                       entry.reserve_exact(cmp::max(self.max_paths, entry.len() + 1) - entry.len()); // room for max_paths entries in total, or len + 1 if that is larger
+                               }
+                               entry.push(($id, route));
+                               if entry.len() > self.max_paths {
+                                       self.max_paths = entry.len(); // new record: this prefix is the only one at the new maximum
+                                       self.routes_with_max = 1;
+                               } else if entry.len() == self.max_paths {
+                                       if !entry_had_max { self.routes_with_max += 1; } // the prefix newly reached the maximum; one that was already there is not counted twice
+                               }
+                       } }
+               }
                match prefix {
                        NLRIEncoding::IP(p) => {
                                let (ip, len) = <(IpAddr, u8)>::from(&p);
                                match ip {
-                                       IpAddr::V4(v4a) => self.v4_table.entry((v4a, len).into()).or_insert(HashMap::new()).insert(0, route),
-                                       IpAddr::V6(v6a) => self.v6_table.entry((v6a, len).into()).or_insert(HashMap::new()).insert(0, route),
+                                       IpAddr::V4(v4a) => insert!(self.v4_table, (v4a, len), 0), // prefixes without a path id are stored under id 0
+                                       IpAddr::V6(v6a) => insert!(self.v6_table, (v6a, len), 0), // prefixes without a path id are stored under id 0
                                }
                        },
                        NLRIEncoding::IP_WITH_PATH_ID((p, id)) => {
                                let (ip, len) = <(IpAddr, u8)>::from(&p);
                                match ip {
-                                       IpAddr::V4(v4a) => self.v4_table.entry((v4a, len).into()).or_insert(HashMap::new()).insert(id, route),
-                                       IpAddr::V6(v6a) => self.v6_table.entry((v6a, len).into()).or_insert(HashMap::new()).insert(id, route),
+                                       IpAddr::V4(v4a) => insert!(self.v4_table, (v4a, len), id),
+                                       IpAddr::V6(v6a) => insert!(self.v6_table, (v6a, len), id),
                                }
                        },
-                       NLRIEncoding::IP_MPLS(_) => None,
+                       NLRIEncoding::IP_MPLS(_) => (), // this and the encodings below are not stored in the routing table
+                       NLRIEncoding::IP_MPLS_WITH_PATH_ID(_) => (),
+                       NLRIEncoding::IP_VPN_MPLS(_) => (),
+                       NLRIEncoding::L2VPN(_) => (),
                };
        }
 }
 
-struct BytesCoder<'a>(&'a mut bytes::BytesMut);
+/*struct BytesCoder<'a>(&'a mut bytes::BytesMut);
 impl<'a> std::io::Write for BytesCoder<'a> {
        fn write(&mut self, b: &[u8]) -> Result<usize, std::io::Error> {
                self.0.extend_from_slice(&b);
@@ -164,8 +219,8 @@ impl<'a> std::io::Read for BytesDecoder<'a> {
        }
 }
 
-struct MsgCoder<'a>(&'a Printer);
-impl<'a> codec::Decoder for MsgCoder<'a> {
+struct MsgCoder(Option<Capabilities>);
+impl codec::Decoder for MsgCoder {
        type Item = Message;
        type Error = std::io::Error;
 
@@ -174,15 +229,16 @@ impl<'a> codec::Decoder for MsgCoder<'a> {
                        buf: bytes,
                        pos: 0
                };
-               match (Reader {
+               let def_cap = Default::default();
+               let mut reader = Reader {
                        stream: &mut decoder,
-                       capabilities: Capabilities {
-                               FOUR_OCTET_ASN_SUPPORT: true,
-                               EXTENDED_PATH_NLRI_SUPPORT: true,
-                       }
-               }).read() {
+                       capabilities: if let Some(cap) = &self.0 { cap } else { &def_cap },
+               };
+               match reader.read() {
                        Ok((_header, msg)) => {
                                decoder.buf.advance(decoder.pos);
+                               if let Message::Open(ref o) = &msg {
+                               }
                                Ok(Some(msg))
                        },
                        Err(e) => match e.kind() {
@@ -192,15 +248,15 @@ impl<'a> codec::Decoder for MsgCoder<'a> {
                }
        }
 }
-impl<'a> codec::Encoder for MsgCoder<'a> {
+impl codec::Encoder for MsgCoder {
        type Item = Message;
        type Error = std::io::Error;
 
        fn encode(&mut self, msg: Message, res: &mut bytes::BytesMut) -> Result<(), std::io::Error> {
-               msg.write(&mut BytesCoder(res))?;
+               msg.encode(&mut BytesCoder(res))?;
                Ok(())
        }
-}
+}*/
 
 pub struct BGPClient {
        routes: Mutex<RoutingTable>,
@@ -219,14 +275,35 @@ impl BGPClient {
                });
 
                let primary_route = path_vecs.pop().unwrap();
-               'asn_candidates: for asn in primary_route.path_suffix.iter().rev() {
-                       if *asn == 0 { continue 'asn_candidates; }
-                       for secondary_route in path_vecs.iter() {
-                               if !secondary_route.path_suffix.contains(asn) {
-                                       continue 'asn_candidates;
+               if path_vecs.len() > 3 {
+                       // If we have at least 3 paths, try to find the last unique ASN which doesn't show up in other paths
+                       // If we hit a T1 that is reasonably assumed to care about net neutrality, return the
+                       // previous ASN.
+                       let mut prev_asn = 0;
+                       'asn_candidates: for asn in primary_route.path_suffix.iter().rev() {
+                               if *asn == 0 { continue 'asn_candidates; }
+                               match *asn {
+                                       // Included: CenturyLink (L3), Cogent, Telia, NTT, GTT, Level3,
+                                       //           GBLX (L3), Zayo, TI Sparkle Seabone, HE, Telefonica
+                                       // Left out from Caida top-20: TATA, PCCW, Vodafone, RETN, Orange, Telstra,
+                                       //                             Singtel, Rostelecom, DTAG
+                                       209|174|1299|2914|3257|3356|3549|6461|6762|6939|12956 if prev_asn != 0 => return prev_asn,
+                                       _ => if path_vecs.iter().any(|route| !route.path_suffix.contains(asn)) {
+                                               if prev_asn != 0 { return prev_asn } else {
+                                                       // Multi-origin prefix, just give up and take the last AS in the
+                                                       // default path
+                                                       break 'asn_candidates;
+                                               }
+                                       } else {
+                                               // We only ever possibly return an ASN if it appears in all paths
+                                               prev_asn = *asn;
+                                       },
                                }
                        }
-                       return *asn;
+                       // All paths were the same, if the first ASN is non-0, return it.
+                       if prev_asn != 0 {
+                               return prev_asn;
+                       }
                }
 
                for asn in primary_route.path_suffix.iter().rev() {
@@ -296,89 +373,131 @@ impl BGPClient {
                } else { None }
        }
 
-       fn connect_given_client(addr: SocketAddr, timeout: Duration, printer: &'static Printer, client: Arc<BGPClient>) {
-               tokio::spawn(Delay::new(Instant::now() + timeout / 4).then(move |_| {
-                       let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| {
-                               future::err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout reached"))
-                       });
+       /// Drives a single BGP session over an already-connected `stream`.
+       ///
+       /// Sends `open_msg`, then reads, parses and applies incoming messages: the peer's OPEN
+       /// resets the routing table, KEEPALIVEs are answered, UPDATEs are applied to
+       /// `client.routes`. The function only ever returns `Err`: on an I/O or parse error, when
+       /// the peer closes the connection, when neither a KEEPALIVE nor an UPDATE arrives within
+       /// `timeout`, or once `client.shutdown` is observed to be set.
+       async fn handle_peer(open_msg: Message, mut stream: TcpStream, timeout: Duration, printer: &'static Printer, client: Arc<BGPClient>) -> Result<(), std::io::Error> {
+               // Scratch buffer for outgoing messages (the OPEN now, KEEPALIVE replies later).
+               let mut write_buf = Vec::with_capacity(64);
+               open_msg.encode(&mut write_buf)?;
+               stream.write_all(&write_buf).await?;
+
+               // Until the peer's OPEN has been seen, parse with the default capabilities.
+               let mut cap = Capabilities::default();
+               // Bytes received so far that do not yet form a complete message.
+               let mut readpending: Vec<u8> = Vec::new();
+               let mut readbuf = [0u8; 8192];
+               // `Sleep` has to be pinned so it can be polled by reference and reset in place.
+               let msg_timeout = time::sleep(timeout);
+               tokio::pin!(msg_timeout);
+               loop {
+                       if client.shutdown.load(Ordering::Relaxed) {
+                               return Err(std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down"));
+                       }
+                       let bytecnt = tokio::select! {
+                               _ = &mut msg_timeout => {
+                                       return Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Keepalive expired"));
+                               },
+                               res = stream.read(&mut readbuf) => res?,
+                       };
+                       if bytecnt == 0 {
+                               // A zero-length read means the peer closed the connection; without this
+                               // check the loop would spin on an endlessly "ready" socket.
+                               return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "Connection closed by peer"));
+                       }
+                       readpending.extend_from_slice(&readbuf[..bytecnt]);
+
+                       // Parse as many complete messages as the buffer currently holds.
+                       let mut consumed = 0;
+                       loop {
+                               let mut cursor = std::io::Cursor::new(&readpending[consumed..]);
+                               let parsed = Reader { stream: &mut cursor, capabilities: &cap }.read();
+                               let bgp_msg = match parsed {
+                                       Ok((_header, msg)) => msg,
+                                       // Incomplete message: keep the unparsed tail untouched and wait for
+                                       // more bytes instead of discarding what the parser already skipped.
+                                       Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
+                                       Err(e) => return Err(e),
+                               };
+                               consumed += cursor.position() as usize;
+
+                               match bgp_msg {
+                                       Message::Open(o) => {
+                                               cap = Capabilities::from_parameters(o.parameters);
+                                               {
+                                                       // A fresh session starts from an empty table, including
+                                                       // the max-path statistics derived from it.
+                                                       let mut route_table = client.routes.lock().unwrap();
+                                                       route_table.v4_table.clear();
+                                                       route_table.v6_table.clear();
+                                                       route_table.max_paths = 0;
+                                                       route_table.routes_with_max = 0;
+                                               }
+                                               printer.add_line("Connected to BGP route provider".to_string(), false);
+                                       },
+                                       Message::KeepAlive => {
+                                               // Answer every KEEPALIVE with our own so the peer's hold timer
+                                               // does not expire, then restart ours.
+                                               write_buf.clear();
+                                               Message::KeepAlive.encode(&mut write_buf)?;
+                                               stream.write_all(&write_buf).await?;
+                                               msg_timeout.as_mut().reset(time::Instant::now() + timeout);
+                                       },
+                                       Message::Update(mut upd) => {
+                                               upd.normalize();
+                                               {
+                                                       let mut route_table = client.routes.lock().unwrap();
+                                                       for r in upd.withdrawn_routes {
+                                                               route_table.withdraw(r);
+                                                       }
+                                                       if let Some(path) = Self::map_attrs(upd.attributes) {
+                                                               for r in upd.announced_routes {
+                                                                       route_table.announce(r, path.clone());
+                                                               }
+                                                       }
+                                                       printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len()));
+                                                       printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len()));
+                                                       printer.set_stat(Stat::RoutingTablePaths(route_table.max_paths));
+                                               }
+                                               // RFC 4271: an UPDATE restarts the hold timer just like a KEEPALIVE.
+                                               msg_timeout.as_mut().reset(time::Instant::now() + timeout);
+                                       },
+                                       _ => {}
+                               }
+                       }
+                       // Drop everything that was parsed in one go; the tail (if any) is a partial message.
+                       readpending.drain(..consumed);
+               }
+       }
+
+       /// Spawns a background task that connects to `addr` and runs one BGP session on it.
+       ///
+       /// The task first waits `timeout / 4`, then tries to connect for at most `timeout`. On
+       /// success it sends an OPEN carrying `remote_asn` and hands the stream to `handle_peer`;
+       /// on a connect failure or timeout it backs off for `timeout / 2`. Unless
+       /// `client.shutdown` is set, it finally schedules the next attempt by calling itself.
+       fn connect_given_client(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer, client: Arc<BGPClient>) {
+               tokio::spawn(async move {
+                       time::sleep(timeout / 4).await;
+
                        let client_reconn = Arc::clone(&client);
-                       TcpStream::connect(&addr).select(connect_timeout)
-                               .or_else(move |_| {
-                                       Delay::new(Instant::now() + timeout / 2).then(|_| {
-                                               future::err(())
-                                       })
-                               }).and_then(move |stream| {
-                                       let (write, read) = Framed::new(stream.0, MsgCoder(printer)).split();
-                                       let (mut sender, receiver) = mpsc::channel(10); // We never really should send more than 10 messages unless they're dumb
-                                       tokio::spawn(write.sink_map_err(|_| { () }).send_all(receiver)
-                                               .then(|_| {
-                                                       future::err(())
-                                               }));
-                                       let _ = sender.try_send(Message::Open(Open {
+                       tokio::select! {
+                               _ = time::sleep(timeout) => {
+                                       // Connect timed out: back off before the retry below.
+                                       time::sleep(timeout / 2).await;
+                               },
+                               res = TcpStream::connect(&addr) => if let Ok(stream) = res {
+                                       // The 16-bit OPEN field cannot hold a 4-byte ASN; 23456 (AS_TRANS) stands in
+                                       // for it and the real value travels in the FourByteASN capability.
+                                       let peer_asn = if remote_asn > u16::max_value() as u32 { 23456 } else { remote_asn as u16 };
+                                       let open_msg = Message::Open(Open {
                                                version: 4,
-                                               peer_asn: 23456,
+                                               peer_asn,
                                                hold_timer: timeout.as_secs() as u16,
-                                               identifier: 0x453b1215, // 69.59.18.21
+                                               identifier: 0x453b1215, // 69.59.18.21. Note that you never actually need to change this.
                                                parameters: vec![OpenParameter::Capabilities(vec![
                                                        OpenCapability::MultiProtocol((AFI::IPV4, SAFI::Unicast)),
                                                        OpenCapability::MultiProtocol((AFI::IPV6, SAFI::Unicast)),
-                                                       OpenCapability::FourByteASN(397444),
+                                                       OpenCapability::FourByteASN(remote_asn),
                                                        OpenCapability::RouteRefresh,
                                                        OpenCapability::AddPath(vec![
                                                                (AFI::IPV4, SAFI::Unicast, AddPathDirection::ReceivePaths),
                                                                (AFI::IPV6, SAFI::Unicast, AddPathDirection::ReceivePaths)]),
-                                               ])]
-                                       }));
-                                       TimeoutStream::new_persistent(read, timeout).for_each(move |bgp_msg| {
-                                               if client.shutdown.load(Ordering::Relaxed) {
-                                                       return future::err(std::io::Error::new(std::io::ErrorKind::Other, "Shutting Down"));
-                                               }
-                                               match bgp_msg {
-                                                       Message::Open(_) => {
-                                                               client.routes.lock().unwrap().v4_table.clear();
-                                                               client.routes.lock().unwrap().v6_table.clear();
-                                                               printer.add_line("Connected to BGP route provider".to_string(), false);
-                                                       },
-                                                       Message::KeepAlive => {
-                                                               let _ = sender.try_send(Message::KeepAlive);
-                                                       },
-                                                       Message::Update(mut upd) => {
-                                                               upd.normalize();
-                                                               let mut route_table = client.routes.lock().unwrap();
-                                                               for r in upd.withdrawn_routes {
-                                                                       route_table.withdraw(r);
-                                                               }
-                                                               if let Some(path) = Self::map_attrs(upd.attributes) {
-                                                                       for r in upd.announced_routes {
-                                                                               route_table.announce(r, path.clone());
-                                                                       }
-                                                               }
-                                                               printer.set_stat(Stat::V4RoutingTableSize(route_table.v4_table.len()));
-                                                               printer.set_stat(Stat::V6RoutingTableSize(route_table.v6_table.len()));
-                                                       },
-                                                       _ => {}
-                                               }
-                                               future::ok(())
-                                       }).or_else(move |e| {
-                                               printer.add_line(format!("Got error from BGP stream: {:?}", e), true);
-                                               future::ok(())
-                                       })
-                               }).then(move |_| {
-                                       if !client_reconn.shutdown.load(Ordering::Relaxed) {
-                                               BGPClient::connect_given_client(addr, timeout, printer, client_reconn);
-                                       }
-                                       future::ok(())
-                               })
-                       })
-               );
+                                               ])],
+                                       });
+                                       // The session future has to be awaited to run at all; it only resolves
+                                       // once the session has ended, always with an error describing why.
+                                       if let Err(e) = Self::handle_peer(open_msg, stream, timeout, printer, client).await {
+                                               printer.add_line(format!("Got error from BGP stream: {:?}", e), true);
+                                       }
+                               } else {
+                                       // Connect failed outright: back off before the retry below.
+                                       time::sleep(timeout / 2).await;
+                               }
+                       };
+                       if !client_reconn.shutdown.load(Ordering::Relaxed) {
+                               BGPClient::connect_given_client(remote_asn, addr, timeout, printer, client_reconn);
+                       }
+               });
        }
 
-       pub fn new(addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> Arc<BGPClient> {
+       pub fn new(remote_asn: u32, addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> Arc<BGPClient> { // creates the shared client state (empty routing table, shutdown flag cleared) and starts connecting to `addr`; `remote_asn` is the ASN placed in the OPEN message
                let client = Arc::new(BGPClient {
                        routes: Mutex::new(RoutingTable::new()),
                        shutdown: AtomicBool::new(false),
                });
-               BGPClient::connect_given_client(addr, timeout, printer, Arc::clone(&client));
+               BGPClient::connect_given_client(remote_asn, addr, timeout, printer, Arc::clone(&client)); // hands a second Arc to the connect logic; the first one is returned to the caller
                client
        }
 }