diff --git a/src/sm4/cipher.rs b/src/sm4/cipher.rs
index fc59d29..6607365 100644
--- a/src/sm4/cipher.rs
+++ b/src/sm4/cipher.rs
@@ -73,6 +73,209 @@ fn combine_block(input: &[u32]) -> Sm4Result<[u8; 16]> {
     Ok(out)
 }
 
+// Shuffle masks and affine-transform matrices for evaluating the SM4 S-box
+// and linear transform with SSE/AES-NI instructions.
+static C0F: [u64; 2] = [0x0F0F0F0F0F0F0F0F, 0x0F0F0F0F0F0F0F0F];
+static FLP: [u64; 2] = [0x0405060700010203, 0x0C0D0E0F08090A0B];
+static SHR: [u64; 2] = [0x0B0E0104070A0D00, 0x0306090C0F020508];
+static M1L: [u64; 2] = [0x9197E2E474720701, 0xC7C1B4B222245157];
+static M1H: [u64; 2] = [0xE240AB09EB49A200, 0xF052B91BF95BB012];
+static M2L: [u64; 2] = [0x5B67F2CEA19D0834, 0xEDD14478172BBE82];
+static M2H: [u64; 2] = [0xAE7201DD73AFDC00, 0x11CDBE62CC1063BF];
+static R08: [u64; 2] = [0x0605040702010003, 0x0E0D0C0F0A09080B];
+static R16: [u64; 2] = [0x0504070601000302, 0x0D0C0F0E09080B0A];
+static R24: [u64; 2] = [0x0407060500030201, 0x0C0F0E0D080B0A09];
+
+// Processes four 16-byte blocks at once. `enc == 0` walks the key schedule
+// forward (encryption); any other value walks it in reverse (decryption).
+// Note: `_mm_shuffle_epi8` requires SSSE3, hence the feature set below.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "sse")]
+#[target_feature(enable = "sse2")]
+#[target_feature(enable = "ssse3")]
+#[target_feature(enable = "aes")]
+unsafe fn sm4_crypt_affine_ni(key: &[u32], sin: &[u8; 64], out: &mut [u8; 64], enc: i32) {
+    #[cfg(target_arch = "x86")]
+    use core::arch::x86::*;
+    #[cfg(target_arch = "x86_64")]
+    use core::arch::x86_64::*;
+
+    let c0f: __m128i = core::mem::transmute(C0F);
+    let flp: __m128i = core::mem::transmute(FLP);
+    let shr: __m128i = core::mem::transmute(SHR);
+    let m1l: __m128i = core::mem::transmute(M1L);
+    let m1h: __m128i = core::mem::transmute(M1H);
+    let m2l: __m128i = core::mem::transmute(M2L);
+    let m2h: __m128i = core::mem::transmute(M2H);
+    let r08: __m128i = core::mem::transmute(R08);
+    let r16: __m128i = core::mem::transmute(R16);
+    let r24: __m128i = core::mem::transmute(R24);
+
+    let mut t0: __m128i;
+    let mut t1: __m128i;
+    let mut t2: __m128i;
+    let mut t3: __m128i;
+
+    let p32: [i32; 16] = core::mem::transmute_copy(sin);
+
+    // Transpose the four blocks so t0..t3 hold word 0..3 of every block,
+    // then byte-swap each 32-bit word into big-endian order (FLP).
+    t0 = _mm_set_epi32(p32[12], p32[8], p32[4], p32[0]);
+    t0 = _mm_shuffle_epi8(t0, flp);
+
+    t1 = _mm_set_epi32(p32[13], p32[9], p32[5], p32[1]);
+    t1 = _mm_shuffle_epi8(t1, flp);
+
+    t2 = _mm_set_epi32(p32[14], p32[10], p32[6], p32[2]);
+    t2 = _mm_shuffle_epi8(t2, flp);
+
+    t3 = _mm_set_epi32(p32[15], p32[11], p32[7], p32[3]);
+    t3 = _mm_shuffle_epi8(t3, flp);
+
+    let mut x: __m128i;
+    let mut y: __m128i;
+    let mut t4: __m128i;
+
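+    // Eight iterations of four SM4 rounds each: 32 rounds over four blocks
+    // in parallel. A round XORs three state words with the broadcast round
+    // key, pushes all 16 bytes through the S-box at once, and applies the
+    // linear transform L. The S-box is computed without lookup tables: split
+    // into nibbles, map through the M1L/M1H affine half-tables, run the AES
+    // S-box core via _mm_aesenclast_si128 (SHR pre-shuffles the bytes so the
+    // ShiftRows step inside aesenclast cancels out), then map back through
+    // M2L/M2H. L is built from byte rotations (the R08/R16/R24 shuffles)
+    // plus a 2-bit word rotation. This mirrors the known "sm4ni" technique.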
+    for i in 0..8 {
+        let k = if enc == 0 { i * 4 } else { 31 - i * 4 };
+        let k1 = key[k];
+        t4 = core::mem::transmute_copy(&[k1, k1, k1, k1]);
+        x = _mm_xor_si128(_mm_xor_si128(_mm_xor_si128(t1, t2), t3), t4);
+
+        y = _mm_and_si128(x, c0f);
+        y = _mm_shuffle_epi8(m1l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m1h, x), y);
+        x = _mm_shuffle_epi8(x, shr);
+        x = _mm_aesenclast_si128(x, c0f);
+        y = _mm_andnot_si128(x, c0f);
+        y = _mm_shuffle_epi8(m2l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m2h, x), y);
+        y = _mm_xor_si128(
+            _mm_xor_si128(x, _mm_shuffle_epi8(x, r08)),
+            _mm_shuffle_epi8(x, r16),
+        );
+        y = _mm_xor_si128(_mm_slli_epi32(y, 2), _mm_srli_epi32(y, 30));
+        x = _mm_xor_si128(_mm_xor_si128(x, y), _mm_shuffle_epi8(x, r24));
+
+        t0 = _mm_xor_si128(t0, x);
+
+        let k = if enc == 0 { i * 4 + 1 } else { 30 - i * 4 };
+        let k2 = key[k];
+        t4 = core::mem::transmute_copy(&[k2, k2, k2, k2]);
+        x = _mm_xor_si128(_mm_xor_si128(_mm_xor_si128(t0, t2), t3), t4);
+
+        y = _mm_and_si128(x, c0f);
+        y = _mm_shuffle_epi8(m1l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m1h, x), y);
+        x = _mm_shuffle_epi8(x, shr);
+        x = _mm_aesenclast_si128(x, c0f);
+        y = _mm_andnot_si128(x, c0f);
+        y = _mm_shuffle_epi8(m2l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m2h, x), y);
+        y = _mm_xor_si128(
+            _mm_xor_si128(x, _mm_shuffle_epi8(x, r08)),
+            _mm_shuffle_epi8(x, r16),
+        );
+        y = _mm_xor_si128(_mm_slli_epi32(y, 2), _mm_srli_epi32(y, 30));
+        x = _mm_xor_si128(_mm_xor_si128(x, y), _mm_shuffle_epi8(x, r24));
+
+        t1 = _mm_xor_si128(t1, x);
+
+        let k = if enc == 0 { i * 4 + 2 } else { 29 - i * 4 };
+        let k3 = key[k];
+        t4 = core::mem::transmute_copy(&[k3, k3, k3, k3]);
+        x = _mm_xor_si128(_mm_xor_si128(_mm_xor_si128(t0, t1), t3), t4);
+
+        y = _mm_and_si128(x, c0f);
+        y = _mm_shuffle_epi8(m1l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m1h, x), y);
+        x = _mm_shuffle_epi8(x, shr);
+        x = _mm_aesenclast_si128(x, c0f);
+        y = _mm_andnot_si128(x, c0f);
+        y = _mm_shuffle_epi8(m2l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m2h, x), y);
+        y = _mm_xor_si128(
+            _mm_xor_si128(x, _mm_shuffle_epi8(x, r08)),
+            _mm_shuffle_epi8(x, r16),
+        );
+        y = _mm_xor_si128(_mm_slli_epi32(y, 2), _mm_srli_epi32(y, 30));
+        x = _mm_xor_si128(_mm_xor_si128(x, y), _mm_shuffle_epi8(x, r24));
+
+        t2 = _mm_xor_si128(t2, x);
+
+        let k = if enc == 0 { i * 4 + 3 } else { 28 - i * 4 };
+        let k4 = key[k];
+        t4 = core::mem::transmute_copy(&[k4, k4, k4, k4]);
+        x = _mm_xor_si128(_mm_xor_si128(_mm_xor_si128(t0, t1), t2), t4);
+
+        y = _mm_and_si128(x, c0f);
+        y = _mm_shuffle_epi8(m1l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m1h, x), y);
+        x = _mm_shuffle_epi8(x, shr);
+        x = _mm_aesenclast_si128(x, c0f);
+        y = _mm_andnot_si128(x, c0f);
+        y = _mm_shuffle_epi8(m2l, y);
+        x = _mm_srli_epi64(x, 4);
+        x = _mm_and_si128(x, c0f);
+        x = _mm_xor_si128(_mm_shuffle_epi8(m2h, x), y);
+        y = _mm_xor_si128(
+            _mm_xor_si128(x, _mm_shuffle_epi8(x, r08)),
+            _mm_shuffle_epi8(x, r16),
+        );
+        y = _mm_xor_si128(_mm_slli_epi32(y, 2), _mm_srli_epi32(y, 30));
+        x = _mm_xor_si128(_mm_xor_si128(x, y), _mm_shuffle_epi8(x, r24));
+
+        t3 = _mm_xor_si128(t3, x);
+    }
+
+    // Byte-swap back and transpose into output order; storing t3 first
+    // implements SM4's final word reversal (X35, X34, X33, X32).
+    let mut res: [u32; 16] = [0; 16];
+    let mut v: __m128i = _mm_set_epi64x(0x0, 0x0);
+    let v_ptr: *mut __m128i = &mut v;
+    _mm_store_si128(v_ptr, _mm_shuffle_epi8(t3, flp));
+    let vr: [u32; 4] = core::mem::transmute_copy(&v);
+    res[0] = vr[0];
+    res[4] = vr[1];
+    res[8] = vr[2];
+    res[12] = vr[3];
+
+    let mut v: __m128i = _mm_set_epi64x(0x0, 0x0);
+    let v_ptr: *mut __m128i = &mut v;
+    _mm_store_si128(v_ptr, _mm_shuffle_epi8(t2, flp));
+    let vr: [u32; 4] = core::mem::transmute_copy(&v);
+    res[1] = vr[0];
+    res[5] = vr[1];
+    res[9] = vr[2];
+    res[13] = vr[3];
+
+    let mut v: __m128i = _mm_set_epi64x(0x0, 0x0);
+    let v_ptr: *mut __m128i = &mut v;
+    _mm_store_si128(v_ptr, _mm_shuffle_epi8(t1, flp));
+    let vr: [u32; 4] = core::mem::transmute_copy(&v);
+    res[2] = vr[0];
+    res[6] = vr[1];
+    res[10] = vr[2];
+    res[14] = vr[3];
+
+    let mut v: __m128i = _mm_set_epi64x(0x0, 0x0);
+    let v_ptr: *mut __m128i = &mut v;
+    _mm_store_si128(v_ptr, _mm_shuffle_epi8(t0, flp));
+    let vr: [u32; 4] = core::mem::transmute_copy(&v);
+    res[3] = vr[0];
+    res[7] = vr[1];
+    res[11] = vr[2];
+    res[15] = vr[3];
+
+    *out = core::mem::transmute_copy(&res);
+}
+
 fn tau_trans(input: u32) -> u32 {
     let input = split(input);
     let mut out: [u8; 4] = [0; 4];
@@ -180,6 +383,41 @@ impl Sm4Cipher {
         combine_block(&y)
     }
 
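+    /// Encrypts four consecutive 16-byte blocks (64 bytes in, 64 bytes out)
+    /// in one call. When the CPU reports SSE/SSE2/SSSE3/AES support at
+    /// runtime, all four blocks are processed in parallel by the AES-NI
+    /// routine; otherwise each block goes through the scalar `encrypt` path.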
+    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+    pub fn encrypt_sm4ni(&self, block_in: &[u8; 64]) -> Result<[u8; 64], Sm4Error> {
+        let rk = &self.rk;
+
+        let mut res: [u8; 64] = [0; 64];
+        if is_x86_feature_detected!("sse")
+            && is_x86_feature_detected!("sse2")
+            && is_x86_feature_detected!("ssse3")
+            && is_x86_feature_detected!("aes")
+        {
+            unsafe { sm4_crypt_affine_ni(rk, block_in, &mut res, 0) };
+        } else {
+            for i in 0..4 {
+                let tmp_res = self.encrypt(&block_in[i * 16..i * 16 + 16])?;
+                res[i * 16..i * 16 + 16].copy_from_slice(&tmp_res);
+            }
+        }
+
+        Ok(res)
+    }
+
+    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+    pub fn encrypt_sm4ni(&self, block_in: &[u8; 64]) -> Result<[u8; 64], Sm4Error> {
+        let mut res: [u8; 64] = [0; 64];
+        for i in 0..4 {
+            let tmp_res = self.encrypt(&block_in[i * 16..i * 16 + 16])?;
+            res[i * 16..i * 16 + 16].copy_from_slice(&tmp_res);
+        }
+        Ok(res)
+    }
+
     pub fn decrypt(&self, block_in: &[u8]) -> Result<[u8; 16], Sm4Error> {
         let mut x: [u32; 4] = split_block(block_in)?;
         let rk = &self.rk;
@@ -241,4 +479,34 @@ mod tests {
             assert_eq!(pt[i], data[i]);
         }
     }
+
+    #[test]
+    fn enc_and_dec_sm4ni() {
+        let key: [u8; 16] = [
+            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
+            0x32, 0x10,
+        ];
+        let cipher = Sm4Cipher::new(&key).unwrap();
+
+        let data: [u8; 64] = [
+            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
+            0x32, 0x10, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98,
+            0x76, 0x54, 0x32, 0x10, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc,
+            0xba, 0x98, 0x76, 0x54, 0x32, 0x10, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
+            0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10,
+        ];
+        let ct = cipher.encrypt_sm4ni(&data).unwrap();
+        let standard_ct: [u8; 64] = [
+            0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e,
+            0x42, 0x46, 0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f,
+            0x53, 0x6e, 0x42, 0x46, 0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3,
+            0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46, 0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e,
+            0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e, 0x42, 0x46,
+        ];
+
+        // Compare against the standard SM4 known-answer ciphertext (one
+        // example block repeated four times).
+        for i in 0..64 {
+            assert_eq!(standard_ct[i], ct[i]);
+        }
+    }
 }
diff --git a/src/sm4/cipher_mode.rs b/src/sm4/cipher_mode.rs
index c768d23..7a2b427 100644
--- a/src/sm4/cipher_mode.rs
+++ b/src/sm4/cipher_mode.rs
@@ -35,6 +35,14 @@ fn block_xor(a: &[u8], b: &[u8]) -> [u8; 16] {
     out
 }
 
+fn block_xor_64(a: &[u8], b: &[u8]) -> [u8; 64] {
+    let mut out: [u8; 64] = [0; 64];
+    for i in 0..64 {
+        out[i] = a[i] ^ b[i];
+    }
+    out
+}
+
 fn block_add_one(a: &mut [u8]) {
     let mut carry = 1;
 
@@ -161,27 +169,45 @@ impl Sm4CipherMode {
     }
 
     fn ctr_encrypt(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, Sm4Error> {
-        let block_num = data.len() / 16;
-        let tail_len = data.len() - block_num * 16;
+        let block_num = data.len() / 64;
+        let tail_len = data.len() - block_num * 64;
 
         let mut out: Vec<u8> = Vec::new();
         let mut vec_buf: Vec<u8> = vec![0; 16];
         vec_buf.clone_from_slice(iv);
 
+        // Expand the counter into four consecutive 16-byte counter blocks
+        // (64 bytes) up front; vec_buf ends up holding the next counter value.
+        let mut vec_buf_64: [u8; 64] = [0; 64];
+        for z in 0..4 {
+            for i in 0..16 {
+                vec_buf_64[z * 16 + i] = vec_buf[i];
+            }
+            block_add_one(&mut vec_buf[..]);
+        }
+
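+        // Each pass encrypts four counter blocks at once via encrypt_sm4ni
+        // and XORs the resulting 64 bytes of keystream into the data, so the
+        // CTR loop advances in 64-byte steps rather than 16-byte ones.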
         // Normal
         for i in 0..block_num {
-            let enc = self.cipher.encrypt(&vec_buf[..])?;
-            let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
+            let enc = self.cipher.encrypt_sm4ni(&vec_buf_64)?;
+            let ct = block_xor_64(&enc, &data[i * 64..i * 64 + 64]);
             for i in ct.iter() {
                 out.push(*i);
             }
+
+            // Refill the buffer with the next four counter blocks; vec_buf
+            // already holds the next counter value, so just keep counting.
+            for z in 0..4 {
+                for i in 0..16 {
+                    vec_buf_64[z * 16 + i] = vec_buf[i];
+                }
+                block_add_one(&mut vec_buf[..]);
+            }
         }
 
         // Last block
-        let enc = self.cipher.encrypt(&vec_buf[..])?;
+        let enc = self.cipher.encrypt_sm4ni(&vec_buf_64)?;
         for i in 0..tail_len {
-            let b = data[block_num * 16 + i] ^ enc[i];
+            let b = data[block_num * 64 + i] ^ enc[i];
             out.push(b);
         }
         Ok(out)
@@ -322,6 +348,20 @@ mod tests {
         assert_eq!(lhs, rhs);
     }
 
+    #[test]
+    fn ctr_enc_long_test() {
+        let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
+        let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();
+
+        let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ctr).unwrap();
+        let msg = include_bytes!("example/textlong");
+        let lhs = cipher_mode.encrypt(msg, &iv).unwrap();
+        let lhs: &[u8] = lhs.as_ref();
+
+        let rhs: &[u8] = include_bytes!("example/text.sms4-ctr.long");
+        assert_eq!(lhs, rhs);
+    }
+
     #[test]
     fn cfb_enc_test() {
         let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
diff --git a/src/sm4/example/text.sms4-ctr.long b/src/sm4/example/text.sms4-ctr.long
new file mode 100644
index 0000000..b7f20ee
Binary files /dev/null and b/src/sm4/example/text.sms4-ctr.long differ
diff --git a/src/sm4/example/textlong b/src/sm4/example/textlong
new file mode 100644
index 0000000..08a7c7d
Binary files /dev/null and b/src/sm4/example/textlong differ