Simplify + clarify random-bytes-fetching from KeysInterface
[rust-lightning] / lightning / src / util / test_utils.rs
index d684e9c07b5fce335d893e8b93ee5cdb698b7d7a..559277d2a68c674948479d329861e2471d719c81 100644 (file)
@@ -350,7 +350,7 @@ impl Logger for TestLogger {
 
 pub struct TestKeysInterface {
        backing: keysinterface::KeysManager,
-       pub override_session_priv: Mutex<Option<SecretKey>>,
+       pub override_session_priv: Mutex<Option<[u8; 32]>>,
        pub override_channel_id_priv: Mutex<Option<[u8; 32]>>,
 }
 
@@ -364,18 +364,19 @@ impl keysinterface::KeysInterface for TestKeysInterface {
                EnforcingChannelKeys::new(self.backing.get_channel_keys(inbound, channel_value_satoshis))
        }
 
-       fn get_onion_rand(&self) -> (SecretKey, [u8; 32]) {
-               match *self.override_session_priv.lock().unwrap() {
-                       Some(key) => (key.clone(), [0; 32]),
-                       None => self.backing.get_onion_rand()
+       fn get_secure_random_bytes(&self) -> [u8; 32] {
+               let override_channel_id = self.override_channel_id_priv.lock().unwrap();
+               let override_session_key = self.override_session_priv.lock().unwrap();
+               if override_channel_id.is_some() && override_session_key.is_some() {
+                       panic!("We don't know which override key to use!");
                }
-       }
-
-       fn get_channel_id(&self) -> [u8; 32] {
-               match *self.override_channel_id_priv.lock().unwrap() {
-                       Some(key) => key.clone(),
-                       None => self.backing.get_channel_id()
+               if let Some(key) = &*override_channel_id {
+                       return *key;
+               }
+               if let Some(key) = &*override_session_key {
+                       return *key;
                }
+               self.backing.get_secure_random_bytes()
        }
 }