diff --git a/src/maxminddb/decoder.rs b/src/maxminddb/decoder.rs index 84bbb541..041f712f 100644 --- a/src/maxminddb/decoder.rs +++ b/src/maxminddb/decoder.rs @@ -379,7 +379,7 @@ struct ArrayAccess<'a, 'de: 'a> { // `SeqAccess` is provided to the `Visitor` to give it the ability to iterate // through elements of the sequence. -impl<'de, 'a> SeqAccess<'de> for ArrayAccess<'a, 'de> { +impl<'de> SeqAccess<'de> for ArrayAccess<'_, 'de> { type Error = MaxMindDBError; fn next_element_seed(&mut self, seed: T) -> DecodeResult> @@ -404,7 +404,7 @@ struct MapAccessor<'a, 'de: 'a> { // `MapAccess` is provided to the `Visitor` to give it the ability to iterate // through entries of the map. -impl<'de, 'a> MapAccess<'de> for MapAccessor<'a, 'de> { +impl<'de> MapAccess<'de> for MapAccessor<'_, 'de> { type Error = MaxMindDBError; fn next_key_seed(&mut self, seed: K) -> DecodeResult> diff --git a/src/maxminddb/lib.rs b/src/maxminddb/lib.rs index b9705c05..d8de143a 100644 --- a/src/maxminddb/lib.rs +++ b/src/maxminddb/lib.rs @@ -78,7 +78,7 @@ pub struct Metadata { #[derive(Debug)] struct WithinNode { node: usize, - ip_bytes: Vec, + ip_int: IpInt, prefix_len: usize, } @@ -96,18 +96,54 @@ pub struct WithinItem { pub info: T, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum IpInt { + V4(u32), + V6(u128), +} + +impl IpInt { + fn new(ip_addr: IpAddr) -> Self { + match ip_addr { + IpAddr::V4(v4) => IpInt::V4(v4.into()), + IpAddr::V6(v6) => IpInt::V6(v6.into()), + } + } + + fn get_bit(&self, index: usize) -> bool { + match self { + IpInt::V4(ip) => (ip >> (31 - index)) & 1 == 1, + IpInt::V6(ip) => (ip >> (127 - index)) & 1 == 1, + } + } + + fn bit_count(&self) -> usize { + match self { + IpInt::V4(_) => 32, + IpInt::V6(_) => 128, + } + } + + fn is_ipv4_in_ipv6(&self) -> bool { + match self { + IpInt::V4(_) => false, + IpInt::V6(ip) => *ip <= 0xFFFFFFFF, + } + } +} + impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> { type Item = Result, MaxMindDBError>; fn next(&mut self) -> Option { while let Some(current) = self.stack.pop() { - let bit_count = current.ip_bytes.len() * 8; + let bit_count = current.ip_int.bit_count(); // Skip networks that are aliases for the IPv4 network if self.reader.ipv4_start != 0 && current.node == self.reader.ipv4_start && bit_count == 128 - && current.ip_bytes[..12].iter().any(|&b| b != 0) + && !current.ip_int.is_ipv4_in_ipv6() { continue; } @@ -115,13 +151,11 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> { match current.node.cmp(&self.node_count) { Ordering::Greater => { // This is a data node, emit it and we're done (until the following next call) - let ip_net = match bytes_and_prefix_to_net( - ¤t.ip_bytes, - current.prefix_len as u8, - ) { - Ok(ip_net) => ip_net, - Err(e) => return Some(Err(e)), - }; + let ip_net = + match bytes_and_prefix_to_net(¤t.ip_int, current.prefix_len as u8) { + Ok(ip_net) => ip_net, + Err(e) => return Some(Err(e)), + }; // TODO: should this block become a helper method on reader? let rec = match self.reader.resolve_data_pointer(current.node) { Ok(rec) => rec, @@ -142,16 +176,23 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> { Ordering::Less => { // In order traversal of our children // right/1-bit - let mut right_ip_bytes = current.ip_bytes.clone(); - right_ip_bytes[current.prefix_len >> 3] |= - 1 << ((bit_count - current.prefix_len - 1) % 8); + let mut right_ip_int = current.ip_int; + + if current.prefix_len < bit_count { + let bit = current.prefix_len; + match &mut right_ip_int { + IpInt::V4(ip) => *ip |= 1 << (31 - bit), + IpInt::V6(ip) => *ip |= 1 << (127 - bit), + }; + } + let node = match self.reader.read_node(current.node, 1) { Ok(node) => node, Err(e) => return Some(Err(e)), }; self.stack.push(WithinNode { node, - ip_bytes: right_ip_bytes, + ip_int: right_ip_int, prefix_len: current.prefix_len + 1, }); // left/0-bit @@ -161,7 +202,7 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> { }; self.stack.push(WithinNode { node, - ip_bytes: current.ip_bytes.clone(), + ip_int: current.ip_int, prefix_len: current.prefix_len + 1, }); } @@ -283,8 +324,8 @@ impl<'de, S: AsRef<[u8]>> Reader { where T: Deserialize<'de>, { - let ip_bytes = ip_to_bytes(address); - let (pointer, prefix_len) = self.find_address_in_tree(&ip_bytes)?; + let ip_int = IpInt::new(address); + let (pointer, prefix_len) = self.find_address_in_tree(&ip_int)?; if pointer == 0 { return Err(MaxMindDBError::AddressNotFoundError( "Address not found in database".to_owned(), @@ -314,14 +355,14 @@ impl<'de, S: AsRef<[u8]>> Reader { /// println!("ip_net={}, city={:?}", item.ip_net, item.info); /// } /// ``` - pub fn within(&'de self, cidr: IpNetwork) -> Result, MaxMindDBError> + pub fn within(&'de self, cidr: IpNetwork) -> Result, MaxMindDBError> where T: Deserialize<'de>, { let ip_address = cidr.network(); let prefix_len = cidr.prefix() as usize; - let ip_bytes = ip_to_bytes(ip_address); - let bit_count = ip_bytes.len() * 8; + let ip_int = IpInt::new(ip_address); + let bit_count = ip_int.bit_count(); let mut node = self.start_node(bit_count); let node_count = self.metadata.node_count as usize; @@ -331,8 +372,8 @@ impl<'de, S: AsRef<[u8]>> Reader { // Traverse down the tree to the level that matches the cidr mark let mut i = 0_usize; while i < prefix_len { - let bit = 1 & (ip_bytes[i >> 3] >> (7 - (i % 8))) as usize; - node = self.read_node(node, bit)?; + let bit = ip_int.get_bit(i); + node = self.read_node(node, bit as usize)?; if node >= node_count { // We've hit a dead end before we exhausted our prefix break; @@ -346,7 +387,7 @@ impl<'de, S: AsRef<[u8]>> Reader { // traversed to as our to be processed stack. stack.push(WithinNode { node, - ip_bytes, + ip_int, prefix_len, }); } @@ -363,8 +404,8 @@ impl<'de, S: AsRef<[u8]>> Reader { Ok(within) } - fn find_address_in_tree(&self, ip_address: &[u8]) -> Result<(usize, usize), MaxMindDBError> { - let bit_count = ip_address.len() * 8; + fn find_address_in_tree(&self, ip_int: &IpInt) -> Result<(usize, usize), MaxMindDBError> { + let bit_count = ip_int.bit_count(); let mut node = self.start_node(bit_count); let node_count = self.metadata.node_count as usize; @@ -375,8 +416,7 @@ impl<'de, S: AsRef<[u8]>> Reader { prefix_len = i; break; } - let bit = 1 & (ip_address[i >> 3] >> (7 - (i % 8))); - + let bit = ip_int.get_bit(i); node = self.read_node(node, bit as usize)?; } match node_count { @@ -469,60 +509,16 @@ fn to_usize(base: u8, bytes: &[u8]) -> usize { .fold(base as usize, |acc, &b| (acc << 8) | b as usize) } -fn ip_to_bytes(address: IpAddr) -> Vec { - match address { - IpAddr::V4(a) => a.octets().to_vec(), - IpAddr::V6(a) => a.octets().to_vec(), - } -} - -#[allow(clippy::many_single_char_names)] -fn bytes_and_prefix_to_net(bytes: &[u8], prefix: u8) -> Result { - let (ip, pre) = match bytes.len() { - 4 => ( - IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3])), - prefix, - ), - 16 => { - if bytes[0] == 0 - && bytes[1] == 0 - && bytes[2] == 0 - && bytes[3] == 0 - && bytes[4] == 0 - && bytes[5] == 0 - && bytes[6] == 0 - && bytes[7] == 0 - && bytes[8] == 0 - && bytes[9] == 0 - && bytes[10] == 0 - && bytes[11] == 0 - { - // It's actually v4, but in v6 form, convert would be nice if ipnetwork had this - // logic. - ( - IpAddr::V4(Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15])), - prefix - 96, - ) - } else { - let a = u16::from(bytes[0]) << 8 | u16::from(bytes[1]); - let b = u16::from(bytes[2]) << 8 | u16::from(bytes[3]); - let c = u16::from(bytes[4]) << 8 | u16::from(bytes[5]); - let d = u16::from(bytes[6]) << 8 | u16::from(bytes[7]); - let e = u16::from(bytes[8]) << 8 | u16::from(bytes[9]); - let f = u16::from(bytes[10]) << 8 | u16::from(bytes[11]); - let g = u16::from(bytes[12]) << 8 | u16::from(bytes[13]); - let h = u16::from(bytes[14]) << 8 | u16::from(bytes[15]); - (IpAddr::V6(Ipv6Addr::new(a, b, c, d, e, f, g, h)), prefix) - } - } - // This should never happen - _ => { - return Err(MaxMindDBError::InvalidNetworkError( - "invalid address".to_owned(), - )) +#[inline] +fn bytes_and_prefix_to_net(bytes: &IpInt, prefix: u8) -> Result { + let (ip, prefix) = match bytes { + IpInt::V4(ip) => (IpAddr::V4(Ipv4Addr::from(*ip)), prefix), + IpInt::V6(ip) if bytes.is_ipv4_in_ipv6() => { + (IpAddr::V4(Ipv4Addr::from(*ip as u32)), prefix - 96) } + IpInt::V6(ip) => (IpAddr::V6(Ipv6Addr::from(*ip)), prefix), }; - IpNetwork::new(ip, pre).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string())) + IpNetwork::new(ip, prefix).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string())) } fn find_metadata_start(buf: &[u8]) -> Result {