diff --git a/src/http_resolver.rs b/src/http_resolver.rs index ea02610..04d0fe2 100644 --- a/src/http_resolver.rs +++ b/src/http_resolver.rs @@ -29,7 +29,28 @@ const DOH_ENDPOINT: &'static str = "https://dns.google/dns-query?dns="; /// /// Note that using this may reveal our IP address to the recipient and information about who we're /// paying to Google (via `dns.google`). -pub struct HTTPHrnResolver; +#[derive(Debug, Clone)] +pub struct HTTPHrnResolver { + client: reqwest::Client, +} + +impl HTTPHrnResolver { + /// Create a new `HTTPHrnResolver` with a default `reqwest::Client`. + pub fn new() -> Self { + HTTPHrnResolver::default() + } + + /// Create a new `HTTPHrnResolver` with a custom `reqwest::Client`. + pub fn with_client(client: reqwest::Client) -> Self { + HTTPHrnResolver { client } + } +} + +impl Default for HTTPHrnResolver { + fn default() -> Self { + HTTPHrnResolver { client: reqwest::Client::new() } + } +} const B64_CHAR: [u8; 64] = [ b'A', b'B', b'C', b'D', b'E', b'F', b'G', b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O', b'P', @@ -107,11 +128,10 @@ impl HTTPHrnResolver { let mut pending_queries = vec![initial_query]; while let Some(query) = pending_queries.pop() { - let client = reqwest::Client::new(); - let request_url = query_to_url(query); - let req = client.get(request_url).header("accept", "application/dns-message").build(); - let resp = client.execute(req.map_err(|_| DNS_ERR)?).await.map_err(|_| DNS_ERR)?; + let req = + self.client.get(request_url).header("accept", "application/dns-message").build(); + let resp = self.client.execute(req.map_err(|_| DNS_ERR)?).await.map_err(|_| DNS_ERR)?; let body = resp.bytes().await.map_err(|_| DNS_ERR)?; let mut answer = QueryBuf::new_zeroed(0); @@ -136,8 +156,15 @@ impl HTTPHrnResolver { async fn resolve_lnurl_impl(&self, lnurl_url: &str) -> Result { let err = "Failed to fetch LN-Address initial well-known endpoint"; - let init: LNURLInitResponse = - reqwest::get(lnurl_url).await.map_err(|_| err)?.json().await.map_err(|_| err)?; + let init: LNURLInitResponse = self + .client + .get(lnurl_url) + .send() + .await + .map_err(|_| err)? + .json() + .await + .map_err(|_| err)?; if init.tag != "payRequest" { return Err("LNURL initial init_response had an incorrect tag value"); @@ -198,8 +225,15 @@ impl HrnResolver for HTTPHrnResolver { } else { write!(&mut callback, "?amount={}", amt.milli_sats()).expect("Write to String"); } - let callback_response: LNURLCallbackResponse = - reqwest::get(callback).await.map_err(|_| err)?.json().await.map_err(|_| err)?; + let callback_response: LNURLCallbackResponse = self + .client + .get(callback) + .send() + .await + .map_err(|_| err)? + .json() + .await + .map_err(|_| err)?; if !callback_response.routes.is_empty() { return Err("LNURL callback response contained a non-empty routes array"); @@ -257,7 +291,7 @@ mod tests { #[tokio::test] async fn test_dns_via_http_hrn_resolver() { - let resolver = HTTPHrnResolver; + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( "send.some@satsto.me", bitcoin::Network::Bitcoin, @@ -303,10 +337,11 @@ mod tests { #[tokio::test] async fn test_http_hrn_resolver() { + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( "lnurltest@bitcoin.ninja", bitcoin::Network::Bitcoin, - &HTTPHrnResolver, + &resolver, true, ) .await @@ -323,7 +358,7 @@ mod tests { assert_eq!(hrn.user(), "lnurltest"); assert_eq!(hrn.domain(), "bitcoin.ninja"); - instr.set_amount(Amount::from_sats(100_000).unwrap(), &HTTPHrnResolver).await.unwrap() + instr.set_amount(Amount::from_sats(100_000).unwrap(), &resolver).await.unwrap() } else { panic!(); }; @@ -348,11 +383,12 @@ mod tests { #[tokio::test] async fn test_http_lnurl_resolver() { + let resolver = HTTPHrnResolver::default(); let instructions = PaymentInstructions::parse( // lnurl encoding for lnurltest@bitcoin.ninja "lnurl1dp68gurn8ghj7cnfw33k76tw9ehxjmn2vyhjuam9d3kz66mwdamkutmvde6hymrs9akxuatjd36x2um5ahcq39", Network::Bitcoin, - &HTTPHrnResolver, + &resolver, true, ) .await @@ -365,7 +401,7 @@ mod tests { assert_eq!(instr.pop_callback(), None); assert!(instr.bip_353_dnssec_proof().is_none()); - instr.set_amount(Amount::from_sats(100_000).unwrap(), &HTTPHrnResolver).await.unwrap() + instr.set_amount(Amount::from_sats(100_000).unwrap(), &resolver).await.unwrap() } else { panic!(); }; diff --git a/src/lib.rs b/src/lib.rs index 2ea6c7c..f3fffe3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1351,14 +1351,9 @@ mod tests { #[cfg(feature = "http")] async fn test_lnurl(str: &str) { - let parsed = PaymentInstructions::parse( - str, - Network::Signet, - &http_resolver::HTTPHrnResolver, - false, - ) - .await - .unwrap(); + let resolver = http_resolver::HTTPHrnResolver::default(); + let parsed = + PaymentInstructions::parse(str, Network::Signet, &resolver, false).await.unwrap(); let parsed = match parsed { PaymentInstructions::ConfigurableAmount(parsed) => parsed,