Skip to content

Commit 62e6231

Browse files
committed
Create a new enum for internal use of IP address
1 parent a74ad8c commit 62e6231

File tree

1 file changed

+74
-77
lines changed

1 file changed

+74
-77
lines changed

src/maxminddb/lib.rs

Lines changed: 74 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub struct Metadata {
7878
#[derive(Debug)]
7979
struct WithinNode {
8080
node: usize,
81-
ip_bytes: Vec<u8>,
81+
ip_int: IpInt,
8282
prefix_len: usize,
8383
}
8484

@@ -96,32 +96,66 @@ pub struct WithinItem<T> {
9696
pub info: T,
9797
}
9898

99+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
100+
enum IpInt {
101+
V4(u32),
102+
V6(u128),
103+
}
104+
105+
impl IpInt {
106+
fn new(ip_addr: IpAddr) -> Self {
107+
match ip_addr {
108+
IpAddr::V4(v4) => IpInt::V4(v4.into()),
109+
IpAddr::V6(v6) => IpInt::V6(v6.into()),
110+
}
111+
}
112+
113+
fn get_bit(&self, index: usize) -> bool {
114+
match self {
115+
IpInt::V4(ip) => (ip >> (31 - index)) & 1 == 1,
116+
IpInt::V6(ip) => (ip >> (127 - index)) & 1 == 1,
117+
}
118+
}
119+
120+
fn bit_count(&self) -> usize {
121+
match self {
122+
IpInt::V4(_) => 32,
123+
IpInt::V6(_) => 128,
124+
}
125+
}
126+
127+
fn is_ipv4_in_ipv6(&self) -> bool {
128+
match self {
129+
IpInt::V4(_) => false,
130+
IpInt::V6(ip) => *ip <= 0xFFFFFFFF,
131+
}
132+
}
133+
}
134+
99135
impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
100136
type Item = Result<WithinItem<T>, MaxMindDBError>;
101137

102138
fn next(&mut self) -> Option<Self::Item> {
103139
while let Some(current) = self.stack.pop() {
104-
let bit_count = current.ip_bytes.len() * 8;
140+
let bit_count = current.ip_int.bit_count();
105141

106142
// Skip networks that are aliases for the IPv4 network
107143
if self.reader.ipv4_start != 0
108144
&& current.node == self.reader.ipv4_start
109145
&& bit_count == 128
110-
&& current.ip_bytes[..12].iter().any(|&b| b != 0)
146+
&& !current.ip_int.is_ipv4_in_ipv6()
111147
{
112148
continue;
113149
}
114150

115151
match current.node.cmp(&self.node_count) {
116152
Ordering::Greater => {
117153
// This is a data node, emit it and we're done (until the following next call)
118-
let ip_net = match bytes_and_prefix_to_net(
119-
&current.ip_bytes,
120-
current.prefix_len as u8,
121-
) {
122-
Ok(ip_net) => ip_net,
123-
Err(e) => return Some(Err(e)),
124-
};
154+
let ip_net =
155+
match bytes_and_prefix_to_net(&current.ip_int, current.prefix_len as u8) {
156+
Ok(ip_net) => ip_net,
157+
Err(e) => return Some(Err(e)),
158+
};
125159
// TODO: should this block become a helper method on reader?
126160
let rec = match self.reader.resolve_data_pointer(current.node) {
127161
Ok(rec) => rec,
@@ -142,16 +176,23 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
142176
Ordering::Less => {
143177
// In order traversal of our children
144178
// right/1-bit
145-
let mut right_ip_bytes = current.ip_bytes.clone();
146-
right_ip_bytes[current.prefix_len >> 3] |=
147-
1 << ((bit_count - current.prefix_len - 1) % 8);
179+
let mut right_ip_int = current.ip_int;
180+
181+
if current.prefix_len < bit_count {
182+
let bit = current.prefix_len;
183+
match &mut right_ip_int {
184+
IpInt::V4(ip) => *ip |= 1 << (31 - bit),
185+
IpInt::V6(ip) => *ip |= 1 << (127 - bit),
186+
};
187+
}
188+
148189
let node = match self.reader.read_node(current.node, 1) {
149190
Ok(node) => node,
150191
Err(e) => return Some(Err(e)),
151192
};
152193
self.stack.push(WithinNode {
153194
node,
154-
ip_bytes: right_ip_bytes,
195+
ip_int: right_ip_int,
155196
prefix_len: current.prefix_len + 1,
156197
});
157198
// left/0-bit
@@ -161,7 +202,7 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
161202
};
162203
self.stack.push(WithinNode {
163204
node,
164-
ip_bytes: current.ip_bytes.clone(),
205+
ip_int: current.ip_int,
165206
prefix_len: current.prefix_len + 1,
166207
});
167208
}
@@ -283,8 +324,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
283324
where
284325
T: Deserialize<'de>,
285326
{
286-
let ip_bytes = ip_to_bytes(address);
287-
let (pointer, prefix_len) = self.find_address_in_tree(&ip_bytes)?;
327+
let ip_int = IpInt::new(address);
328+
let (pointer, prefix_len) = self.find_address_in_tree(&ip_int)?;
288329
if pointer == 0 {
289330
return Err(MaxMindDBError::AddressNotFoundError(
290331
"Address not found in database".to_owned(),
@@ -320,8 +361,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
320361
{
321362
let ip_address = cidr.network();
322363
let prefix_len = cidr.prefix() as usize;
323-
let ip_bytes = ip_to_bytes(ip_address);
324-
let bit_count = ip_bytes.len() * 8;
364+
let ip_int = IpInt::new(ip_address);
365+
let bit_count = ip_int.bit_count();
325366

326367
let mut node = self.start_node(bit_count);
327368
let node_count = self.metadata.node_count as usize;
@@ -331,8 +372,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
331372
// Traverse down the tree to the level that matches the cidr mark
332373
let mut i = 0_usize;
333374
while i < prefix_len {
334-
let bit = 1 & (ip_bytes[i >> 3] >> (7 - (i % 8))) as usize;
335-
node = self.read_node(node, bit)?;
375+
let bit = ip_int.get_bit(i);
376+
node = self.read_node(node, bit as usize)?;
336377
if node >= node_count {
337378
// We've hit a dead end before we exhausted our prefix
338379
break;
@@ -346,7 +387,7 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
346387
// traversed to as our to be processed stack.
347388
stack.push(WithinNode {
348389
node,
349-
ip_bytes,
390+
ip_int,
350391
prefix_len,
351392
});
352393
}
@@ -363,8 +404,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
363404
Ok(within)
364405
}
365406

366-
fn find_address_in_tree(&self, ip_address: &[u8]) -> Result<(usize, usize), MaxMindDBError> {
367-
let bit_count = ip_address.len() * 8;
407+
fn find_address_in_tree(&self, ip_int: &IpInt) -> Result<(usize, usize), MaxMindDBError> {
408+
let bit_count = ip_int.bit_count();
368409
let mut node = self.start_node(bit_count);
369410

370411
let node_count = self.metadata.node_count as usize;
@@ -375,8 +416,7 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
375416
prefix_len = i;
376417
break;
377418
}
378-
let bit = 1 & (ip_address[i >> 3] >> (7 - (i % 8)));
379-
419+
let bit = ip_int.get_bit(i);
380420
node = self.read_node(node, bit as usize)?;
381421
}
382422
match node_count {
@@ -468,61 +508,18 @@ fn to_usize(base: u8, bytes: &[u8]) -> usize {
468508
.iter()
469509
.fold(base as usize, |acc, &b| (acc << 8) | b as usize)
470510
}
471-
472-
fn ip_to_bytes(address: IpAddr) -> Vec<u8> {
473-
match address {
474-
IpAddr::V4(a) => a.octets().to_vec(),
475-
IpAddr::V6(a) => a.octets().to_vec(),
476-
}
477-
}
478-
479-
#[allow(clippy::many_single_char_names)]
480-
fn bytes_and_prefix_to_net(bytes: &[u8], prefix: u8) -> Result<IpNetwork, MaxMindDBError> {
481-
let (ip, pre) = match bytes.len() {
482-
4 => (
483-
IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3])),
484-
prefix,
485-
),
486-
16 => {
487-
if bytes[0] == 0
488-
&& bytes[1] == 0
489-
&& bytes[2] == 0
490-
&& bytes[3] == 0
491-
&& bytes[4] == 0
492-
&& bytes[5] == 0
493-
&& bytes[6] == 0
494-
&& bytes[7] == 0
495-
&& bytes[8] == 0
496-
&& bytes[9] == 0
497-
&& bytes[10] == 0
498-
&& bytes[11] == 0
499-
{
500-
// It's actually v4, but in v6 form, convert would be nice if ipnetwork had this
501-
// logic.
502-
(
503-
IpAddr::V4(Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15])),
504-
prefix - 96,
505-
)
511+
fn bytes_and_prefix_to_net(bytes: &IpInt, prefix: u8) -> Result<IpNetwork, MaxMindDBError> {
512+
let (ip, adj_prefix) = match bytes {
513+
IpInt::V4(ip) => (IpAddr::V4(Ipv4Addr::from(*ip)), prefix),
514+
IpInt::V6(ip) => {
515+
if bytes.is_ipv4_in_ipv6() {
516+
(IpAddr::V4(Ipv4Addr::from(*ip as u32)), prefix - 96)
506517
} else {
507-
let a = u16::from(bytes[0]) << 8 | u16::from(bytes[1]);
508-
let b = u16::from(bytes[2]) << 8 | u16::from(bytes[3]);
509-
let c = u16::from(bytes[4]) << 8 | u16::from(bytes[5]);
510-
let d = u16::from(bytes[6]) << 8 | u16::from(bytes[7]);
511-
let e = u16::from(bytes[8]) << 8 | u16::from(bytes[9]);
512-
let f = u16::from(bytes[10]) << 8 | u16::from(bytes[11]);
513-
let g = u16::from(bytes[12]) << 8 | u16::from(bytes[13]);
514-
let h = u16::from(bytes[14]) << 8 | u16::from(bytes[15]);
515-
(IpAddr::V6(Ipv6Addr::new(a, b, c, d, e, f, g, h)), prefix)
518+
(IpAddr::V6(Ipv6Addr::from(*ip)), prefix)
516519
}
517520
}
518-
// This should never happen
519-
_ => {
520-
return Err(MaxMindDBError::InvalidNetworkError(
521-
"invalid address".to_owned(),
522-
))
523-
}
524521
};
525-
IpNetwork::new(ip, pre).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string()))
522+
IpNetwork::new(ip, adj_prefix).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string()))
526523
}
527524

528525
fn find_metadata_start(buf: &[u8]) -> Result<usize, MaxMindDBError> {

0 commit comments

Comments
 (0)