Skip to content

Commit 6eca9d5

Browse files
alexlaparainliu
authored andcommitted
optimizations
1 parent 4070845 commit 6eca9d5

File tree

8 files changed

+52
-78
lines changed

8 files changed

+52
-78
lines changed

srtp/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ rtcp = { version = "0.9.0", path = "../rtcp" }
2727
byteorder = "1"
2828
bytes = "1"
2929
thiserror = "1.0"
30-
hmac = { version = "0.12.1", features = ["std", "reset"] }
30+
hmac = { version = "0.12.1", features = ["std"] }
3131
sha1 = "0.10.5"
3232
ctr = "0.8.0"
3333
aes = "0.7.5"

srtp/benches/srtp_bench.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use criterion::{criterion_group, criterion_main, Criterion};
33
use util::Marshal;
44
use webrtc_srtp::{context::Context, protection_profile::ProtectionProfile};
55

6-
fn benchmark_buffer(c: &mut Criterion) {
6+
fn benchmark_encrypt_rtp_aes_128_cm_hmac_sha1(c: &mut Criterion) {
77
let mut ctx = Context::new(
88
&vec![
99
96, 180, 31, 4, 119, 137, 128, 252, 75, 194, 252, 44, 63, 56, 61, 55,
@@ -16,20 +16,18 @@ fn benchmark_buffer(c: &mut Criterion) {
1616
.unwrap();
1717

1818
let mut pld = BytesMut::new();
19-
20-
for i in 0..1000 {
19+
for i in 0..1200 {
2120
pld.extend_from_slice(&[i as u8]);
2221
}
2322

24-
let mut count = 1;
25-
2623
c.bench_function("Benchmark context ", |b| {
24+
let mut seq = 1;
2725
b.iter_batched(
2826
|| {
2927
let pkt = rtp::packet::Packet {
3028
header: rtp::header::Header {
31-
sequence_number: count,
32-
timestamp: count.into(),
29+
sequence_number: seq,
30+
timestamp: seq.into(),
3331
extension_profile: 48862,
3432
marker: true,
3533
padding: false,
@@ -39,7 +37,7 @@ fn benchmark_buffer(c: &mut Criterion) {
3937
},
4038
payload: pld.clone().into(),
4139
};
42-
count += 1;
40+
seq += 1;
4341
pkt.marshal().unwrap()
4442
},
4543
|pkt_raw| {
@@ -50,5 +48,5 @@ fn benchmark_buffer(c: &mut Criterion) {
5048
});
5149
}
5250

53-
criterion_group!(benches, benchmark_buffer);
51+
criterion_group!(benches, benchmark_encrypt_rtp_aes_128_cm_hmac_sha1);
5452
criterion_main!(benches);

srtp/src/cipher/cipher_aead_aes_gcm.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,19 @@ impl Cipher for CipherAeadAesGcm {
3333
roc: u32,
3434
) -> Result<Bytes> {
3535
// Grow the given buffer to fit the output.
36+
let header_len = header.marshal_size();
3637
let mut writer =
37-
BytesMut::with_capacity(header.marshal_size() + payload.len() + self.auth_tag_len());
38+
BytesMut::with_capacity(payload.len() + self.auth_tag_len());
3839

39-
let data = header.marshal()?;
40-
writer.extend(data);
40+
// Copy header unencrypted.
41+
writer.extend_from_slice(&payload[..header_len]);
4142

4243
let nonce = self.rtp_initialization_vector(header, roc);
4344

4445
let encrypted = self.srtp_cipher.encrypt(
4546
Nonce::from_slice(&nonce),
4647
Payload {
47-
msg: payload,
48+
msg: &payload[header_len..],
4849
aad: &writer,
4950
},
5051
)?;

srtp/src/cipher/cipher_aes_cm_hmac_sha1.rs

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -121,19 +121,15 @@ impl CipherAesCmHmacSha1 {
121121
/// - Authenticated portion of the packet is everything BEFORE MKI
122122
/// - k_a is the session message authentication key
123123
/// - n_tag is the bit-length of the output authentication tag
124-
fn generate_srtp_auth_tag(&mut self, buf: &[u8], roc: u32) -> Vec<u8> {
125-
self.srtp_session_auth.reset();
124+
fn generate_srtp_auth_tag(&self, buf: &[u8], roc: u32) -> [u8; 20] {
125+
let mut signer = self.srtp_session_auth.clone();
126126

127-
self.srtp_session_auth.update(buf);
127+
signer.update(buf);
128128

129129
// For SRTP only, we need to hash the rollover counter as well.
130-
self.srtp_session_auth.update(&roc.to_be_bytes());
130+
signer.update(&roc.to_be_bytes());
131131

132-
let result = self.srtp_session_auth.clone().finalize();
133-
let code_bytes = result.into_bytes();
134-
135-
// Truncate the hash to the first AUTH_TAG_SIZE bytes.
136-
code_bytes[0..self.auth_tag_len()].to_vec()
132+
signer.finalize().into_bytes().into()
137133
}
138134

139135
/// https://tools.ietf.org/html/rfc3711#section-4.2
@@ -147,13 +143,12 @@ impl CipherAesCmHmacSha1 {
147143
/// - Authenticated portion of the packet is everything BEFORE MKI
148144
/// - k_a is the session message authentication key
149145
/// - n_tag is the bit-length of the output authentication tag
150-
fn generate_srtcp_auth_tag(&mut self, buf: &[u8]) -> Vec<u8> {
151-
self.srtcp_session_auth.reset();
146+
fn generate_srtcp_auth_tag(&self, buf: &[u8]) -> Vec<u8> {
147+
let mut signer = self.srtcp_session_auth.clone();
152148

153-
self.srtcp_session_auth.update(buf);
149+
signer.update(buf);
154150

155-
let result = self.srtcp_session_auth.clone().finalize();
156-
let code_bytes = result.into_bytes();
151+
let code_bytes = signer.finalize().into_bytes();
157152

158153
// Truncate the hash to the first AUTH_TAG_SIZE bytes.
159154
code_bytes[0..self.auth_tag_len()].to_vec()
@@ -179,26 +174,26 @@ impl Cipher for CipherAesCmHmacSha1 {
179174
) -> Result<Bytes> {
180175
let header_len = header.marshal_size();
181176
let mut writer =
182-
BytesMut::with_capacity(header_len + payload.len() + self.auth_tag_len());
177+
Vec::with_capacity(payload.len() + self.auth_tag_len());
183178

184179
// Copy the header unencrypted.
185-
writer.extend(header.marshal());
180+
writer.extend_from_slice(&payload[..header_len]);
181+
186182
// Encrypt the payload
187183
let nonce = generate_counter(
188184
header.sequence_number,
189185
roc,
190186
header.ssrc,
191187
&self.srtp_session_salt,
192188
);
193-
194-
writer.put_bytes(0, payload.len());
189+
writer.resize(payload.len(), 0);
195190
self.ctx.encrypt_init(None, None, Some(&nonce)).unwrap();
196-
let count = self.ctx.cipher_update(&payload, Some(&mut writer[header_len..])).unwrap();
191+
let count = self.ctx.cipher_update(&payload[header_len..], Some(&mut writer[header_len..])).unwrap();
197192
self.ctx.cipher_final(&mut writer[count..]).unwrap();
198193

199-
// Generate the auth tag.
200-
let auth_tag = self.generate_srtp_auth_tag(&writer, roc);
201-
writer.extend(auth_tag);
194+
// Generate and write the auth tag.
195+
let auth_tag = &self.generate_srtp_auth_tag(&writer, roc)[..self.auth_tag_len()];
196+
writer.extend_from_slice(auth_tag);
202197

203198
Ok(Bytes::from(writer))
204199
}
@@ -220,11 +215,11 @@ impl Cipher for CipherAesCmHmacSha1 {
220215
let cipher_text = &encrypted[..encrypted.len() - self.auth_tag_len()];
221216

222217
// Generate the auth tag we expect to see from the ciphertext.
223-
let expected_tag = self.generate_srtp_auth_tag(cipher_text, roc);
218+
let expected_tag = &self.generate_srtp_auth_tag(cipher_text, roc)[..self.auth_tag_len()];
224219

225220
// See if the auth tag actually matches.
226221
// We use a constant time comparison to prevent timing attacks.
227-
if actual_tag.ct_eq(&expected_tag).unwrap_u8() != 1 {
222+
if actual_tag.ct_eq(expected_tag).unwrap_u8() != 1 {
228223
return Err(Error::RtpFailedToVerifyAuthTag);
229224
}
230225

srtp/src/context/mod.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,13 @@ impl Context {
170170
self.srtp_ssrc_states.entry(ssrc).or_insert(s)
171171
}
172172

173-
fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> Option<&mut SrtcpSsrcState> {
173+
fn get_srtcp_ssrc_state(&mut self, ssrc: u32) -> &mut SrtcpSsrcState {
174174
let s = SrtcpSsrcState {
175175
ssrc,
176176
replay_detector: Some((self.new_srtcp_replay_detector)()),
177177
..Default::default()
178178
};
179-
self.srtcp_ssrc_states.entry(ssrc).or_insert(s);
180-
self.srtcp_ssrc_states.get_mut(&ssrc)
179+
self.srtcp_ssrc_states.entry(ssrc).or_insert(s)
181180
}
182181

183182
/// roc returns SRTP rollover counter value of specified SSRC.
@@ -197,8 +196,6 @@ impl Context {
197196

198197
/// set_index sets SRTCP index value of specified SSRC.
199198
fn set_index(&mut self, ssrc: u32, index: usize) {
200-
if let Some(s) = self.get_srtcp_ssrc_state(ssrc) {
201-
s.srtcp_index = index;
202-
}
199+
self.get_srtcp_ssrc_state(ssrc).srtcp_index = index;
203200
}
204201
}

srtp/src/context/srtcp.rs

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,16 @@ impl Context {
1313
let index = self.cipher.get_rtcp_index(encrypted);
1414
let ssrc = u32::from_be_bytes([encrypted[4], encrypted[5], encrypted[6], encrypted[7]]);
1515

16-
{
17-
if let Some(state) = self.get_srtcp_ssrc_state(ssrc) {
18-
if let Some(replay_detector) = &mut state.replay_detector {
19-
if !replay_detector.check(index as u64) {
20-
return Err(Error::SrtcpSsrcDuplicated(ssrc, index));
21-
}
22-
}
23-
} else {
24-
return Err(Error::SsrcMissingFromSrtcp(ssrc));
16+
if let Some(replay_detector) = &mut self.get_srtcp_ssrc_state(ssrc).replay_detector {
17+
if !replay_detector.check(index as u64) {
18+
return Err(Error::SrtcpSsrcDuplicated(ssrc, index));
2519
}
2620
}
2721

2822
let dst = self.cipher.decrypt_rtcp(encrypted, index, ssrc)?;
2923

30-
{
31-
if let Some(state) = self.get_srtcp_ssrc_state(ssrc) {
32-
if let Some(replay_detector) = &mut state.replay_detector {
33-
replay_detector.accept();
34-
}
35-
}
24+
if let Some(replay_detector) = &mut self.get_srtcp_ssrc_state(ssrc).replay_detector {
25+
replay_detector.accept();
3626
}
3727

3828
Ok(dst)
@@ -46,18 +36,14 @@ impl Context {
4636

4737
let ssrc = u32::from_be_bytes([decrypted[4], decrypted[5], decrypted[6], decrypted[7]]);
4838

49-
let index;
50-
{
51-
if let Some(state) = self.get_srtcp_ssrc_state(ssrc) {
52-
state.srtcp_index += 1;
53-
if state.srtcp_index > MAX_SRTCP_INDEX {
54-
state.srtcp_index = 0;
55-
}
56-
index = state.srtcp_index;
57-
} else {
58-
return Err(Error::SsrcMissingFromSrtcp(ssrc));
39+
let index = {
40+
let state = self.get_srtcp_ssrc_state(ssrc);
41+
state.srtcp_index += 1;
42+
if state.srtcp_index > MAX_SRTCP_INDEX {
43+
state.srtcp_index = 0;
5944
}
60-
}
45+
state.srtcp_index
46+
};
6147

6248
self.cipher.encrypt_rtcp(decrypted, index, ssrc)
6349
}

srtp/src/context/srtp.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ impl Context {
1010
encrypted: &[u8],
1111
header: &rtp::header::Header,
1212
) -> Result<Bytes> {
13-
let roc;
14-
{
13+
let roc = {
1514
let state = self.get_srtp_ssrc_state(header.ssrc);
1615
if let Some(replay_detector) = &mut state.replay_detector {
1716
if !replay_detector.check(header.sequence_number as u64) {
@@ -22,8 +21,8 @@ impl Context {
2221
}
2322
}
2423

25-
roc = state.next_rollover_count(header.sequence_number);
26-
}
24+
state.next_rollover_count(header.sequence_number)
25+
};
2726

2827
let dst = self.cipher.decrypt_rtp(encrypted, header, roc)?;
2928
{
@@ -46,14 +45,14 @@ impl Context {
4645

4746
pub fn encrypt_rtp_with_header(
4847
&mut self,
49-
plaintext: &[u8],
48+
payload: &[u8],
5049
header: &rtp::header::Header,
5150
) -> Result<Bytes> {
5251
let roc = self.get_srtp_ssrc_state(header.ssrc).next_rollover_count(header.sequence_number);
5352

5453
let dst = self
5554
.cipher
56-
.encrypt_rtp(&plaintext[header.marshal_size()..], header, roc)?;
55+
.encrypt_rtp(&payload, header, roc)?;
5756

5857
self.get_srtp_ssrc_state(header.ssrc).update_rollover_count(header.sequence_number);
5958

srtp/src/error.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ pub enum Error {
4949
SrtpSaltLength(usize, usize),
5050
#[error("SyntaxError: {0}")]
5151
ExtMapParse(String),
52-
#[error("ssrc {0} not exist in srtp_ssrc_state")]
53-
SsrcMissingFromSrtp(u32),
5452
#[error("srtp ssrc={0} index={1}: duplicated")]
5553
SrtpSsrcDuplicated(u32, u16),
5654
#[error("srtcp ssrc={0} index={1}: duplicated")]

0 commit comments

Comments
 (0)