@@ -16,11 +16,20 @@ pub const MAX_ADDRESS_COUNT: usize = 4;
16
16
pub const MAX_SERVER_COUNT : usize = 4 ;
17
17
18
18
const DNS_PORT : u16 = 53 ;
19
+ const MDNS_DNS_PORT : u16 = 5353 ;
19
20
const MAX_NAME_LEN : usize = 255 ;
20
21
const RETRANSMIT_DELAY : Duration = Duration :: from_millis ( 1_000 ) ;
21
22
const MAX_RETRANSMIT_DELAY : Duration = Duration :: from_millis ( 10_000 ) ;
22
23
const RETRANSMIT_TIMEOUT : Duration = Duration :: from_millis ( 10_000 ) ; // Should generally be 2-10 secs
23
24
25
+ #[ cfg( feature = "proto-ipv6" ) ]
26
+ const MDNS_IPV6_ADDR : IpAddress = IpAddress :: Ipv6 ( crate :: wire:: Ipv6Address ( [
27
+ 0xff , 0x02 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0x00 , 0xfb ,
28
+ ] ) ) ;
29
+
30
+ #[ cfg( feature = "proto-ipv4" ) ]
31
+ const MDNS_IPV4_ADDR : IpAddress = IpAddress :: Ipv4 ( crate :: wire:: Ipv4Address ( [ 224 , 0 , 0 , 251 ] ) ) ;
32
+
24
33
/// Error returned by [`Socket::start_query`]
25
34
#[ derive( Debug , PartialEq , Eq , Clone , Copy ) ]
26
35
#[ cfg_attr( feature = "defmt" , derive( defmt:: Format ) ) ]
@@ -81,6 +90,14 @@ struct PendingQuery {
81
90
delay : Duration ,
82
91
83
92
server_idx : usize ,
93
+ mdns : MulticastDns ,
94
+ }
95
+
96
+ #[ derive( Debug ) ]
97
+ pub enum MulticastDns {
98
+ Disabled ,
99
+ #[ cfg( feature = "socket-mdns" ) ]
100
+ Enabled ,
84
101
}
85
102
86
103
#[ derive( Debug ) ]
@@ -185,6 +202,7 @@ impl<'a> Socket<'a> {
185
202
& mut self ,
186
203
cx : & mut Context ,
187
204
name : & str ,
205
+ query_type : Type ,
188
206
) -> Result < QueryHandle , StartQueryError > {
189
207
let mut name = name. as_bytes ( ) ;
190
208
@@ -200,6 +218,13 @@ impl<'a> Socket<'a> {
200
218
201
219
let mut raw_name: Vec < u8 , MAX_NAME_LEN > = Vec :: new ( ) ;
202
220
221
+ let mut mdns = MulticastDns :: Disabled ;
222
+ #[ cfg( feature = "socket-mdns" ) ]
223
+ if name. split ( |& c| c == b'.' ) . last ( ) . unwrap ( ) == b"local" {
224
+ net_trace ! ( "Starting a mDNS query" ) ;
225
+ mdns = MulticastDns :: Enabled ;
226
+ }
227
+
203
228
for s in name. split ( |& c| c == b'.' ) {
204
229
if s. len ( ) > 63 {
205
230
net_trace ! ( "invalid name: too long label" ) ;
@@ -224,7 +249,7 @@ impl<'a> Socket<'a> {
224
249
. push ( 0x00 )
225
250
. map_err ( |_| StartQueryError :: NameTooLong ) ?;
226
251
227
- self . start_query_raw ( cx, & raw_name)
252
+ self . start_query_raw ( cx, & raw_name, query_type , mdns )
228
253
}
229
254
230
255
/// Start a query with a raw (wire-format) DNS name.
@@ -235,19 +260,22 @@ impl<'a> Socket<'a> {
235
260
& mut self ,
236
261
cx : & mut Context ,
237
262
raw_name : & [ u8 ] ,
263
+ query_type : Type ,
264
+ mdns : MulticastDns ,
238
265
) -> Result < QueryHandle , StartQueryError > {
239
266
let handle = self . find_free_query ( ) . ok_or ( StartQueryError :: NoFreeSlot ) ?;
240
267
241
268
self . queries [ handle. 0 ] = Some ( DnsQuery {
242
269
state : State :: Pending ( PendingQuery {
243
270
name : Vec :: from_slice ( raw_name) . map_err ( |_| StartQueryError :: NameTooLong ) ?,
244
- type_ : Type :: A ,
271
+ type_ : query_type ,
245
272
txid : cx. rand ( ) . rand_u16 ( ) ,
246
273
port : cx. rand ( ) . rand_source_port ( ) ,
247
274
delay : RETRANSMIT_DELAY ,
248
275
timeout_at : None ,
249
276
retransmit_at : Instant :: ZERO ,
250
277
server_idx : 0 ,
278
+ mdns,
251
279
} ) ,
252
280
#[ cfg( feature = "async" ) ]
253
281
waker : WakerRegistration :: new ( ) ,
@@ -313,11 +341,12 @@ impl<'a> Socket<'a> {
313
341
}
314
342
315
343
pub ( crate ) fn accepts ( & self , ip_repr : & IpRepr , udp_repr : & UdpRepr ) -> bool {
316
- udp_repr. src_port == DNS_PORT
344
+ ( udp_repr. src_port == DNS_PORT
317
345
&& self
318
346
. servers
319
347
. iter ( )
320
- . any ( |server| * server == ip_repr. src_addr ( ) )
348
+ . any ( |server| * server == ip_repr. src_addr ( ) ) )
349
+ || ( udp_repr. src_port == MDNS_DNS_PORT )
321
350
}
322
351
323
352
pub ( crate ) fn process (
@@ -482,6 +511,20 @@ impl<'a> Socket<'a> {
482
511
483
512
for q in self . queries . iter_mut ( ) . flatten ( ) {
484
513
if let State :: Pending ( pq) = & mut q. state {
514
+ // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
515
+ // so we internally overwrite the servers for any of those queries
516
+ // in this function.
517
+ let servers = match pq. mdns {
518
+ #[ cfg( feature = "socket-mdns" ) ]
519
+ MulticastDns :: Enabled => & [
520
+ #[ cfg( feature = "proto-ipv6" ) ]
521
+ MDNS_IPV6_ADDR ,
522
+ #[ cfg( feature = "proto-ipv4" ) ]
523
+ MDNS_IPV4_ADDR ,
524
+ ] ,
525
+ MulticastDns :: Disabled => self . servers . as_slice ( ) ,
526
+ } ;
527
+
485
528
let timeout = if let Some ( timeout) = pq. timeout_at {
486
529
timeout
487
530
} else {
@@ -500,16 +543,15 @@ impl<'a> Socket<'a> {
500
543
// Try next server. We check below whether we've tried all servers.
501
544
pq. server_idx += 1 ;
502
545
}
503
-
504
546
// Check if we've run out of servers to try.
505
- if pq. server_idx >= self . servers . len ( ) {
547
+ if pq. server_idx >= servers. len ( ) {
506
548
net_trace ! ( "already tried all servers." ) ;
507
549
q. set_state ( State :: Failure ) ;
508
550
continue ;
509
551
}
510
552
511
553
// Check so the IP address is valid
512
- if self . servers [ pq. server_idx ] . is_unspecified ( ) {
554
+ if servers[ pq. server_idx ] . is_unspecified ( ) {
513
555
net_trace ! ( "invalid unspecified DNS server addr." ) ;
514
556
q. set_state ( State :: Failure ) ;
515
557
continue ;
@@ -526,20 +568,26 @@ impl<'a> Socket<'a> {
526
568
opcode : Opcode :: Query ,
527
569
question : Question {
528
570
name : & pq. name ,
529
- type_ : Type :: A ,
571
+ type_ : pq . type_ ,
530
572
} ,
531
573
} ;
532
574
533
575
let mut payload = [ 0u8 ; 512 ] ;
534
576
let payload = & mut payload[ ..repr. buffer_len ( ) ] ;
535
577
repr. emit ( & mut Packet :: new_unchecked ( payload) ) ;
536
578
579
+ let dst_port = match pq. mdns {
580
+ #[ cfg( feature = "socket-mdns" ) ]
581
+ MulticastDns :: Enabled => MDNS_DNS_PORT ,
582
+ MulticastDns :: Disabled => DNS_PORT ,
583
+ } ;
584
+
537
585
let udp_repr = UdpRepr {
538
586
src_port : pq. port ,
539
- dst_port : 53 ,
587
+ dst_port,
540
588
} ;
541
589
542
- let dst_addr = self . servers [ pq. server_idx ] ;
590
+ let dst_addr = servers[ pq. server_idx ] ;
543
591
let src_addr = cx. get_source_address ( dst_addr) . unwrap ( ) ; // TODO remove unwrap
544
592
let ip_repr = IpRepr :: new (
545
593
src_addr,
@@ -550,7 +598,7 @@ impl<'a> Socket<'a> {
550
598
) ;
551
599
552
600
net_trace ! (
553
- "sending {} octets to {:?}: {}" ,
601
+ "sending {} octets to {} from port {}" ,
554
602
payload. len( ) ,
555
603
ip_repr. dst_addr( ) ,
556
604
udp_repr. src_port
0 commit comments