Switch to streaming queries
author Matt Corallo <git@bluematt.me>
Sun, 16 Jul 2023 00:37:05 +0000 (00:37 +0000)
committer Matt Corallo <git@bluematt.me>
Sun, 16 Jul 2023 05:58:32 +0000 (05:58 +0000)
To stream query results we have to use `tokio-postgres`'s `query_raw`
method rather than `query`, which collects every row into a `Vec` before
returning. This should reduce our memory footprint from 10+GB to well
under 1GB.

src/lookup.rs

index 0534c1c8a7e82922bc240eab0be214f06d87eec8..04e46bd0d8ec9f1292445605085e659c5dc3c95b 100644 (file)
@@ -10,6 +10,8 @@ use lightning::util::ser::Readable;
 use tokio_postgres::{Client, Connection, NoTls, Socket};
 use tokio_postgres::tls::NoTlsStream;
 
+use futures::StreamExt;
+
 use crate::{config, TestLogger};
 use crate::serialization::MutatedProperties;
 
@@ -88,9 +90,11 @@ pub(super) async fn fetch_channel_announcements(delta_set: &mut DeltaSet, networ
 
        println!("Obtaining corresponding database entries");
        // get all the channel announcements that are currently in the network graph
-       let announcement_rows = client.query("SELECT announcement_signed, seen FROM channel_announcements WHERE short_channel_id = any($1) ORDER BY short_channel_id ASC", &[&channel_ids]).await.unwrap();
+       let announcement_rows = client.query_raw("SELECT announcement_signed, seen FROM channel_announcements WHERE short_channel_id = any($1) ORDER BY short_channel_id ASC", [&channel_ids]).await.unwrap();
+       let mut pinned_rows = Box::pin(announcement_rows);
 
-       for current_announcement_row in announcement_rows {
+       while let Some(row_res) = pinned_rows.next().await {
+               let current_announcement_row = row_res.unwrap();
                let blob: Vec<u8> = current_announcement_row.get("announcement_signed");
                let mut readable = Cursor::new(blob);
                let unsigned_announcement = ChannelAnnouncement::read(&mut readable).unwrap().contents;
@@ -117,7 +121,9 @@ pub(super) async fn fetch_channel_announcements(delta_set: &mut DeltaSet, networ
 
                // here is where the channels whose first update in either direction occurred after
                // `last_seen_timestamp` are added to the selection
-               let newer_oldest_directional_updates = client.query("
+               let params: [&(dyn tokio_postgres::types::ToSql + Sync); 2] =
+                       [&channel_ids, &last_sync_timestamp_object];
+               let newer_oldest_directional_updates = client.query_raw("
                        SELECT * FROM (
                                SELECT DISTINCT ON (short_channel_id) *
                                FROM (
@@ -129,9 +135,12 @@ pub(super) async fn fetch_channel_announcements(delta_set: &mut DeltaSet, networ
                                ORDER BY short_channel_id ASC, seen DESC
                        ) AS distinct_chans
                        WHERE distinct_chans.seen >= $2
-                       ", &[&channel_ids, &last_sync_timestamp_object]).await.unwrap();
+                       ", params).await.unwrap();
+               let mut pinned_updates = Box::pin(newer_oldest_directional_updates);
+
+               while let Some(row_res) = pinned_updates.next().await {
+                       let current_row = row_res.unwrap();
 
-               for current_row in newer_oldest_directional_updates {
                        let scid: i64 = current_row.get("short_channel_id");
                        let current_seen_timestamp_object: SystemTime = current_row.get("seen");
                        let current_seen_timestamp: u32 = current_seen_timestamp_object.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() as u32;
@@ -153,7 +162,9 @@ pub(super) async fn fetch_channel_announcements(delta_set: &mut DeltaSet, networ
                let reminder_threshold_timestamp = SystemTime::now().checked_sub(config::CHANNEL_REMINDER_AGE).unwrap();
                let read_only_graph = network_graph.read_only();
 
-               let older_latest_directional_updates = client.query("
+               let params: [&(dyn tokio_postgres::types::ToSql + Sync); 2] =
+                       [&channel_ids, &reminder_threshold_timestamp];
+               let older_latest_directional_updates = client.query_raw("
                        SELECT short_channel_id FROM (
                                SELECT DISTINCT ON (short_channel_id) *
                                FROM (
@@ -165,9 +176,11 @@ pub(super) async fn fetch_channel_announcements(delta_set: &mut DeltaSet, networ
                                ORDER BY short_channel_id ASC, seen ASC
                        ) AS distinct_chans
                        WHERE distinct_chans.seen <= $2
-                       ", &[&channel_ids, &reminder_threshold_timestamp]).await.unwrap();
+                       ", params).await.unwrap();
+               let mut pinned_updates = Box::pin(older_latest_directional_updates);
 
-               for current_row in older_latest_directional_updates {
+               while let Some(row_res) = pinned_updates.next().await {
+                       let current_row = row_res.unwrap();
                        let scid: i64 = current_row.get("short_channel_id");
 
                        // annotate this channel as requiring that reminders be sent to the client
@@ -208,7 +221,7 @@ pub(super) async fn fetch_channel_updates(delta_set: &mut DeltaSet, client: &Cli
        // get the latest channel update in each direction prior to last_sync_timestamp, provided
        // there was an update in either direction that happened after the last sync (to avoid
        // collecting too many reference updates)
-       let reference_rows = client.query("
+       let reference_rows = client.query_raw("
                SELECT id, direction, blob_signed FROM channel_updates
                WHERE id IN (
                        SELECT DISTINCT ON (short_channel_id, direction) id
@@ -220,14 +233,17 @@ pub(super) async fn fetch_channel_updates(delta_set: &mut DeltaSet, client: &Cli
                        FROM channel_updates
                        WHERE seen >= $1
                )
-               ", &[&last_sync_timestamp_object]).await.unwrap();
+               ", [last_sync_timestamp_object]).await.unwrap();
+       let mut pinned_rows = Box::pin(reference_rows);
 
-       println!("Fetched reference rows ({}): {:?}", reference_rows.len(), start.elapsed());
+       println!("Fetched reference rows in {:?}", start.elapsed());
 
-       let mut last_seen_update_ids: Vec<i32> = Vec::with_capacity(reference_rows.len());
+       let mut last_seen_update_ids: Vec<i32> = Vec::new();
        let mut non_intermediate_ids: HashSet<i32> = HashSet::new();
+       let mut reference_row_count = 0;
 
-       for current_reference in reference_rows {
+       while let Some(row_res) = pinned_rows.next().await {
+               let current_reference = row_res.unwrap();
                let update_id: i32 = current_reference.get("id");
                last_seen_update_ids.push(update_id);
                non_intermediate_ids.insert(update_id);
@@ -245,27 +261,31 @@ pub(super) async fn fetch_channel_updates(delta_set: &mut DeltaSet, client: &Cli
                        (*current_channel_delta).updates.1.get_or_insert(DirectedUpdateDelta::default())
                };
                update_delta.last_update_before_seen = Some(unsigned_channel_update);
+               reference_row_count += 1;
        }
 
-       println!("Processed reference rows (delta size: {}): {:?}", delta_set.len(), start.elapsed());
+       println!("Processed {} reference rows (delta size: {}) in {:?}",
+               reference_row_count, delta_set.len(), start.elapsed());
 
        // get all the intermediate channel updates
        // (to calculate the set of mutated fields for snapshotting, where intermediate updates may
        // have been omitted)
 
-       let intermediate_updates = client.query("
+       let intermediate_updates = client.query_raw("
                SELECT id, direction, blob_signed, seen
                FROM channel_updates
                WHERE seen >= $1
-               ", &[&last_sync_timestamp_object]).await.unwrap();
-       println!("Fetched intermediate rows ({}): {:?}", intermediate_updates.len(), start.elapsed());
+               ", [last_sync_timestamp_object]).await.unwrap();
+       let mut pinned_updates = Box::pin(intermediate_updates);
+       println!("Fetched intermediate rows in {:?}", start.elapsed());
 
        let mut previous_scid = u64::MAX;
        let mut previously_seen_directions = (false, false);
 
        // let mut previously_seen_directions = (false, false);
        let mut intermediate_update_count = 0;
-       for intermediate_update in intermediate_updates {
+       while let Some(row_res) = pinned_updates.next().await {
+               let intermediate_update = row_res.unwrap();
                let update_id: i32 = intermediate_update.get("id");
                if non_intermediate_ids.contains(&update_id) {
                        continue;