Skip to content

Commit a9634dd

Browse files
committed
Create a new IpBytes enum
1 parent a74ad8c commit a9634dd

File tree

1 file changed

+111
-73
lines changed

1 file changed

+111
-73
lines changed

src/maxminddb/lib.rs

Lines changed: 111 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub struct Metadata {
7878
#[derive(Debug)]
7979
struct WithinNode {
8080
node: usize,
81-
ip_bytes: Vec<u8>,
81+
ip_bytes: IpBytes,
8282
prefix_len: usize,
8383
}
8484

@@ -96,6 +96,85 @@ pub struct WithinItem<T> {
9696
pub info: T,
9797
}
9898

99+
#[derive(Debug, Clone, Copy)]
100+
enum IpBytes {
101+
V4([u8; 4]),
102+
V6([u8; 16]),
103+
}
104+
105+
impl IpBytes {
106+
fn new(ip_addr: IpAddr) -> Self {
107+
match ip_addr {
108+
IpAddr::V4(v4) => IpBytes::V4(v4.octets()),
109+
IpAddr::V6(v6) => IpBytes::V6(v6.octets()),
110+
}
111+
}
112+
113+
fn len(&self) -> usize {
114+
match self {
115+
IpBytes::V4(_) => 4,
116+
IpBytes::V6(_) => 16,
117+
}
118+
}
119+
120+
fn get_byte(&self, index: usize) -> Option<u8> {
121+
match self {
122+
IpBytes::V4(bytes) => bytes.get(index).copied(),
123+
IpBytes::V6(bytes) => bytes.get(index).copied(),
124+
}
125+
}
126+
fn set_byte(&mut self, index: usize, value: u8) -> Result<(), MaxMindDBError> {
127+
match self {
128+
IpBytes::V4(bytes) => {
129+
if index < 4 {
130+
bytes[index] = value;
131+
Ok(())
132+
} else {
133+
Err(MaxMindDBError::InvalidNetworkError(
134+
"Index out of range for Ipv4".to_string(),
135+
))
136+
}
137+
}
138+
139+
IpBytes::V6(bytes) => {
140+
if index < 16 {
141+
bytes[index] = value;
142+
Ok(())
143+
} else {
144+
Err(MaxMindDBError::InvalidNetworkError(
145+
"Index out of range for Ipv6".to_string(),
146+
))
147+
}
148+
}
149+
}
150+
}
151+
152+
fn is_ipv4_in_ipv6(&self) -> bool {
153+
match self{
154+
IpBytes::V6(bytes) => bytes[..12] == [0; 12],
155+
_ => false
156+
}
157+
}
158+
159+
fn to_ip_addr(&self) -> IpAddr {
160+
match self {
161+
IpBytes::V4(bytes) => IpAddr::V4(Ipv4Addr::from(*bytes)),
162+
IpBytes::V6(bytes) => {
163+
if self.is_ipv4_in_ipv6() {
164+
IpAddr::V4(Ipv4Addr::new(
165+
bytes[12],
166+
bytes[13],
167+
bytes[14],
168+
bytes[15],
169+
))
170+
} else {
171+
IpAddr::V6(Ipv6Addr::from(*bytes))
172+
}
173+
}
174+
}
175+
}
176+
}
177+
99178
impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
100179
type Item = Result<WithinItem<T>, MaxMindDBError>;
101180

@@ -107,21 +186,20 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
107186
if self.reader.ipv4_start != 0
108187
&& current.node == self.reader.ipv4_start
109188
&& bit_count == 128
110-
&& current.ip_bytes[..12].iter().any(|&b| b != 0)
189+
&& !current.ip_bytes.is_ipv4_in_ipv6()
111190
{
112191
continue;
113192
}
114193

115194
match current.node.cmp(&self.node_count) {
116195
Ordering::Greater => {
117196
// This is a data node, emit it and we're done (until the following next call)
118-
let ip_net = match bytes_and_prefix_to_net(
119-
&current.ip_bytes,
120-
current.prefix_len as u8,
121-
) {
122-
Ok(ip_net) => ip_net,
123-
Err(e) => return Some(Err(e)),
124-
};
197+
let ip_net =
198+
match bytes_and_prefix_to_net(&current.ip_bytes, current.prefix_len as u8)
199+
{
200+
Ok(ip_net) => ip_net,
201+
Err(e) => return Some(Err(e)),
202+
};
125203
// TODO: should this block become a helper method on reader?
126204
let rec = match self.reader.resolve_data_pointer(current.node) {
127205
Ok(rec) => rec,
@@ -142,9 +220,14 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
142220
Ordering::Less => {
143221
// In order traversal of our children
144222
// right/1-bit
145-
let mut right_ip_bytes = current.ip_bytes.clone();
146-
right_ip_bytes[current.prefix_len >> 3] |=
147-
1 << ((bit_count - current.prefix_len - 1) % 8);
223+
let mut right_ip_bytes = current.ip_bytes;
224+
if let Some(mut byte) =
225+
right_ip_bytes.get_byte(current.prefix_len >> 3)
226+
{
227+
byte |= 1 << ((bit_count - current.prefix_len - 1) % 8);
228+
let _ = right_ip_bytes.set_byte(current.prefix_len >> 3, byte); //safe due to bound check
229+
}
230+
148231
let node = match self.reader.read_node(current.node, 1) {
149232
Ok(node) => node,
150233
Err(e) => return Some(Err(e)),
@@ -161,7 +244,7 @@ impl<'de, T: Deserialize<'de>, S: AsRef<[u8]>> Iterator for Within<'de, T, S> {
161244
};
162245
self.stack.push(WithinNode {
163246
node,
164-
ip_bytes: current.ip_bytes.clone(),
247+
ip_bytes: current.ip_bytes,
165248
prefix_len: current.prefix_len + 1,
166249
});
167250
}
@@ -283,7 +366,7 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
283366
where
284367
T: Deserialize<'de>,
285368
{
286-
let ip_bytes = ip_to_bytes(address);
369+
let ip_bytes = IpBytes::new(address);
287370
let (pointer, prefix_len) = self.find_address_in_tree(&ip_bytes)?;
288371
if pointer == 0 {
289372
return Err(MaxMindDBError::AddressNotFoundError(
@@ -320,7 +403,7 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
320403
{
321404
let ip_address = cidr.network();
322405
let prefix_len = cidr.prefix() as usize;
323-
let ip_bytes = ip_to_bytes(ip_address);
406+
let ip_bytes = IpBytes::new(ip_address);
324407
let bit_count = ip_bytes.len() * 8;
325408

326409
let mut node = self.start_node(bit_count);
@@ -331,7 +414,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
331414
// Traverse down the tree to the level that matches the cidr mark
332415
let mut i = 0_usize;
333416
while i < prefix_len {
334-
let bit = 1 & (ip_bytes[i >> 3] >> (7 - (i % 8))) as usize;
417+
let byte = ip_bytes.get_byte(i >> 3).unwrap(); //Safe due to the previous bound check
418+
let bit = 1 & (byte >> (7 - (i % 8))) as usize;
335419
node = self.read_node(node, bit)?;
336420
if node >= node_count {
337421
// We've hit a dead end before we exhausted our prefix
@@ -363,8 +447,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
363447
Ok(within)
364448
}
365449

366-
fn find_address_in_tree(&self, ip_address: &[u8]) -> Result<(usize, usize), MaxMindDBError> {
367-
let bit_count = ip_address.len() * 8;
450+
fn find_address_in_tree(&self, ip_bytes: &IpBytes) -> Result<(usize, usize), MaxMindDBError> {
451+
let bit_count = ip_bytes.len() * 8;
368452
let mut node = self.start_node(bit_count);
369453

370454
let node_count = self.metadata.node_count as usize;
@@ -375,7 +459,8 @@ impl<'de, S: AsRef<[u8]>> Reader<S> {
375459
prefix_len = i;
376460
break;
377461
}
378-
let bit = 1 & (ip_address[i >> 3] >> (7 - (i % 8)));
462+
let byte = ip_bytes.get_byte(i >> 3).unwrap(); //Safe due to bound check
463+
let bit = 1 & (byte >> (7 - (i % 8)));
379464

380465
node = self.read_node(node, bit as usize)?;
381466
}
@@ -468,61 +553,14 @@ fn to_usize(base: u8, bytes: &[u8]) -> usize {
468553
.iter()
469554
.fold(base as usize, |acc, &b| (acc << 8) | b as usize)
470555
}
471-
472-
fn ip_to_bytes(address: IpAddr) -> Vec<u8> {
473-
match address {
474-
IpAddr::V4(a) => a.octets().to_vec(),
475-
IpAddr::V6(a) => a.octets().to_vec(),
476-
}
477-
}
478-
479-
#[allow(clippy::many_single_char_names)]
480-
fn bytes_and_prefix_to_net(bytes: &[u8], prefix: u8) -> Result<IpNetwork, MaxMindDBError> {
481-
let (ip, pre) = match bytes.len() {
482-
4 => (
483-
IpAddr::V4(Ipv4Addr::new(bytes[0], bytes[1], bytes[2], bytes[3])),
484-
prefix,
485-
),
486-
16 => {
487-
if bytes[0] == 0
488-
&& bytes[1] == 0
489-
&& bytes[2] == 0
490-
&& bytes[3] == 0
491-
&& bytes[4] == 0
492-
&& bytes[5] == 0
493-
&& bytes[6] == 0
494-
&& bytes[7] == 0
495-
&& bytes[8] == 0
496-
&& bytes[9] == 0
497-
&& bytes[10] == 0
498-
&& bytes[11] == 0
499-
{
500-
// It's actually v4, but in v6 form, convert would be nice if ipnetwork had this
501-
// logic.
502-
(
503-
IpAddr::V4(Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15])),
504-
prefix - 96,
505-
)
506-
} else {
507-
let a = u16::from(bytes[0]) << 8 | u16::from(bytes[1]);
508-
let b = u16::from(bytes[2]) << 8 | u16::from(bytes[3]);
509-
let c = u16::from(bytes[4]) << 8 | u16::from(bytes[5]);
510-
let d = u16::from(bytes[6]) << 8 | u16::from(bytes[7]);
511-
let e = u16::from(bytes[8]) << 8 | u16::from(bytes[9]);
512-
let f = u16::from(bytes[10]) << 8 | u16::from(bytes[11]);
513-
let g = u16::from(bytes[12]) << 8 | u16::from(bytes[13]);
514-
let h = u16::from(bytes[14]) << 8 | u16::from(bytes[15]);
515-
(IpAddr::V6(Ipv6Addr::new(a, b, c, d, e, f, g, h)), prefix)
516-
}
517-
}
518-
// This should never happen
519-
_ => {
520-
return Err(MaxMindDBError::InvalidNetworkError(
521-
"invalid address".to_owned(),
522-
))
523-
}
556+
fn bytes_and_prefix_to_net(bytes: &IpBytes, prefix: u8) -> Result<IpNetwork, MaxMindDBError> {
557+
let ip = bytes.to_ip_addr();
558+
let adjusted_prefix = match (ip, bytes, prefix) {
559+
(IpAddr::V4(_), IpBytes::V6(_), p) => p - 96,
560+
(_, _, p) => p
524561
};
525-
IpNetwork::new(ip, pre).map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string()))
562+
IpNetwork::new(ip, adjusted_prefix)
563+
.map_err(|e| MaxMindDBError::InvalidNetworkError(e.to_string()))
526564
}
527565

528566
fn find_metadata_start(buf: &[u8]) -> Result<usize, MaxMindDBError> {

0 commit comments

Comments
 (0)