]> git.bitcoin.ninja Git - dnssec-prover/commitdiff
Print info about which test failed when a test panics
authorMatt Corallo <git@bluematt.me>
Fri, 26 Jul 2024 14:52:13 +0000 (14:52 +0000)
committerMatt Corallo <git@bluematt.me>
Fri, 26 Jul 2024 14:52:13 +0000 (14:52 +0000)
src/test/crypto.rs

index 31787cf912d379e5ad2c5c898d2b52886e3e4915..194fcd8b248a50ebe7c92092fb896933354f9ca7 100644 (file)
@@ -61,7 +61,7 @@ fn decode_asn(sig: &str, int_len: usize) -> Result<Vec<u8>, ()> {
 }
 
 fn test_ecdsa<
-       Validate: Fn(&[u8], &[u8], &[u8]) -> Result<(), ()>,
+       Validate: Fn(&[u8], &[u8], &[u8]) -> Result<(), ()> + Clone,
        Hash: Fn(&[u8]) -> HashResult,
 >(v: Value, int_len: usize, validate_fn: Validate, hash_fn: Hash) {
        for (group_idx, group) in v["testGroups"].as_array().unwrap().into_iter().enumerate() {
@@ -71,7 +71,7 @@ fn test_ecdsa<
                for test in group["tests"].as_array().unwrap() {
                        let msg = Vec::from_hex(test["msg"].as_str().unwrap()).unwrap();
 
-                       let result = match test["result"].as_str().unwrap() {
+                       let expected_result = match test["result"].as_str().unwrap() {
                                "valid" => Ok(()),
                                "invalid" => Err(()),
                                r => panic!("Unknown result type {}", r),
@@ -79,13 +79,20 @@ fn test_ecdsa<
 
                        let sig = decode_asn(test["sig"].as_str().unwrap(), int_len);
                        if sig.is_err() {
-                               assert_eq!(result, Err(()));
+                               assert_eq!(expected_result, Err(()));
                                continue;
                        }
 
                        let hash = hash_fn(&msg);
-                       assert_eq!(result, validate_fn(&pk, &sig.unwrap(), hash.as_ref()),
-                               "Failed test case group {}, test id {}, comment {}", group_idx, test["tcId"], test["comment"]);
+                       let pk = &pk[..];
+                       let validate_fn = validate_fn.clone();
+                       let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
+                               let result = validate_fn(&pk, &sig.unwrap(), hash.as_ref());
+                               assert_eq!(result, expected_result);
+                       }));
+                       if panicked.is_err() {
+                               panic!("Test case group {}, test id {}, comment {} panicked!", group_idx, test["tcId"], test["comment"]);
+                       }
                }
        }
 }