From d4af94b39cba9fde969243bc3e00dc8d21681170 Mon Sep 17 00:00:00 2001 From: benthecarman Date: Sun, 29 Jun 2025 19:59:26 -0500 Subject: [PATCH] Allow for custom HTTPHrnResolvers This allows for giving a custom reqwest client when creating a HTTPHrnResolver. This is useful for adding custom headers, proxying connections, etc. This also has the added benefit of using the same reqwest client across every call so it'll be better and reusing tls connections. --- src/http_resolver.rs | 64 ++++++++++++++++++++++++++++++++++---------- src/lib.rs | 11 +++----- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/src/http_resolver.rs b/src/http_resolver.rs index ea02610..04d0fe2 100644 --- a/src/http_resolver.rs +++ b/src/http_resolver.rs @@ -29,7 +29,28 @@ const DOH_ENDPOINT: &'static str = "https://dns.google/dns-query?dns="; /// /// Note that using this may reveal our IP address to the recipient and information about who we're /// paying to Google (via `dns.google`). -pub struct HTTPHrnResolver; +#[derive(Debug, Clone)] +pub struct HTTPHrnResolver { + client: reqwest::Client, +} + +impl HTTPHrnResolver { + /// Create a new `HTTPHrnResolver` with a default `reqwest::Client`. + pub fn new() -> Self { + HTTPHrnResolver::default() + } + + /// Create a new `HTTPHrnResolver` with a custom `reqwest::Client`. + pub fn with_client(client: reqwest::Client) -> Self { + HTTPHrnResolver { client } + } +} + +impl Default for HTTPHrnResolver { + fn default() -> Self { + HTTPHrnResolver { client: reqwest::Client::new() } + } +} const B64_CHAR: [u8; 64] = [ b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', @@ -107,11 +128,10 @@ impl HTTPHrnResolver { let mut pending_queries = vec![initial_query]; while let Some(query) = pending_queries.pop() { - let client = reqwest::Client::new(); - let request_url = query_to_url(query); - let req = client.get(request_url).header("accept", "application/dns-message").build(); - let resp = client.execute(req.map_err(|_| DNS_ERR)?).await.map_err(|_| DNS_ERR)?; + let req = + self.client.get(request_url).header("accept", "application/dns-message").build(); + let resp = self.client.execute(req.map_err(|_| DNS_ERR)?).await.map_err(|_| DNS_ERR)?; let body = resp.bytes().await.map_err(|_| DNS_ERR)?; let mut answer = QueryBuf::new_zeroed(0); @@ -136,8 +156,15 @@ impl HTTPHrnResolver { async fn resolve_lnurl_impl(&self, lnurl_url: &str) -> Result { let err = "Failed to fetch LN-Address initial well-known endpoint"; - let init: LNURLInitResponse = - reqwest::get(lnurl_url).await.map_err(|_| err)?.json().await.map_err(|_| err)?; + let init: LNURLInitResponse = self + .client + .get(lnurl_url) + .send() + .await + .map_err(|_| err)? + .json() + .await + .map_err(|_| err)?; if init.tag != "payRequest" { return Err("LNURL initial init_response had an incorrect tag value"); @@ -198,8 +225,15 @@ impl HrnResolver for HTTPHrnResolver { } else { write!(&mut callback, "?amount={}", amt.milli_sats()).expect("Write to String"); } - let callback_response: LNURLCallbackResponse = - reqwest::get(callback).await.map_err(|_| err)?.json().await.map_err(|_| err)?; + let callback_response: LNURLCallbackResponse = self + .client + .get(callback) + .send() + .await + .map_err(|_| err)? + .json() + .await + .map_err(|_| err)?; if !callback_response.routes.is_empty() { return Err("LNURL callback response contained a non-empty routes array"); @@ -257,7 +291,7 @@ mod tests { #[tokio::test] async fn test_dns_via_http_hrn_resolver() { - let resolver = HTTPHrnResolver; + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( "send.some@satsto.me", bitcoin::Network::Bitcoin, @@ -303,10 +337,11 @@ mod tests { #[tokio::test] async fn test_http_hrn_resolver() { + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( "lnurltest@bitcoin.ninja", bitcoin::Network::Bitcoin, - &HTTPHrnResolver, + &resolver, true, ) .await @@ -323,7 +358,7 @@ mod tests { assert_eq!(hrn.user(), "lnurltest"); assert_eq!(hrn.domain(), "bitcoin.ninja"); - instr.set_amount(Amount::from_sats(100_000).unwrap(), &HTTPHrnResolver).await.unwrap() + instr.set_amount(Amount::from_sats(100_000).unwrap(), &resolver).await.unwrap() } else { panic!(); }; @@ -348,11 +383,12 @@ mod tests { #[tokio::test] async fn test_http_lnurl_resolver() { + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( // lnurl encoding for lnurltest@bitcoin.ninja "lnurl1dp68gurn8ghj7cnfw33k76tw9ehxjmn2vyhjuam9d3kz66mwdamkutmvde6hymrs9akxuatjd36x2um5ahcq39", Network::Bitcoin, - &HTTPHrnResolver, + &resolver, true, ) .await @@ -365,7 +401,7 @@ mod tests { assert_eq!(instr.pop_callback(), None); assert!(instr.bip_353_dnssec_proof().is_none()); - instr.set_amount(Amount::from_sats(100_000).unwrap(), &HTTPHrnResolver).await.unwrap() + instr.set_amount(Amount::from_sats(100_000).unwrap(), &resolver).await.unwrap() } else { panic!(); }; diff --git a/src/lib.rs b/src/lib.rs index 2ea6c7c..f3fffe3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1351,14 +1351,9 @@ mod tests { #[cfg(feature = "http")] async fn test_lnurl(str: &str) { - let parsed = PaymentInstructions::parse( - str, - Network::Signet, - &http_resolver::HTTPHrnResolver, - false, - ) - .await - .unwrap(); + let resolver = http_resolver::HTTPHrnResolver::default(); + let parsed = + PaymentInstructions::parse(str, Network::Signet, &resolver, false).await.unwrap(); let parsed = match parsed { PaymentInstructions::ConfigurableAmount(parsed) => parsed,