diff --git a/libavutil/aes.c b/libavutil/aes.c index 5f31412149..52a250bc00 100644 --- a/libavutil/aes.c +++ b/libavutil/aes.c @@ -234,6 +234,7 @@ int av_aes_init(AVAES *a, const uint8_t *key, int key_bits, int decrypt) int KC = key_bits >> 5; int rounds = KC + 6; + a->rounds = rounds; a->crypt = decrypt ? aes_decrypt : aes_encrypt; if (ARCH_X86) ff_init_aes_x86(a, decrypt); @@ -243,8 +244,6 @@ int av_aes_init(AVAES *a, const uint8_t *key, int key_bits, int decrypt) if (key_bits != 128 && key_bits != 192 && key_bits != 256) return AVERROR(EINVAL); - a->rounds = rounds; - memcpy(tk, key, KC * 4); memcpy(a->round_key[0].u8, key, KC * 4); diff --git a/libavutil/x86/aes.asm b/libavutil/x86/aes.asm index 7084c46055..9a36991ca0 100644 --- a/libavutil/x86/aes.asm +++ b/libavutil/x86/aes.asm @@ -26,13 +26,12 @@ SECTION .text ; void ff_aes_decrypt(AVAES *a, uint8_t *dst, const uint8_t *src, ; int count, uint8_t *iv, int rounds) ;----------------------------------------------------------------------------- -%macro AES_CRYPT 1 -cglobal aes_%1rypt, 6,6,2 +%macro AES_CRYPT 2 +cglobal aes_%1rypt_%2, 5, 5, 2 test r3d, r3d je .ret shl r3d, 4 - add r5d, r5d - add r0, 0x60 + add r0, 0x70 add r2, r3 add r1, r3 neg r3 @@ -45,17 +44,15 @@ cglobal aes_%1rypt, 6,6,2 %ifidn %1, enc pxor m0, m1 %endif - pxor m0, [r0+8*r5-0x60] - cmp r5d, 24 - je .rounds12 - jl .rounds10 - aes%1 m0, [r0+0x70] + pxor m0, [r0+8*2*%2-0x70] +%if %2 > 10 +%if %2 > 12 aes%1 m0, [r0+0x60] -.rounds12: aes%1 m0, [r0+0x50] +%endif aes%1 m0, [r0+0x40] -.rounds10: aes%1 m0, [r0+0x30] +%endif aes%1 m0, [r0+0x20] aes%1 m0, [r0+0x10] aes%1 m0, [r0+0x00] @@ -64,7 +61,8 @@ cglobal aes_%1rypt, 6,6,2 aes%1 m0, [r0-0x30] aes%1 m0, [r0-0x40] aes%1 m0, [r0-0x50] - aes%1last m0, [r0-0x60] + aes%1 m0, [r0-0x60] + aes%1last m0, [r0-0x70] test r4, r4 je .noiv %ifidn %1, enc @@ -90,6 +88,10 @@ cglobal aes_%1rypt, 6,6,2 %if HAVE_AESNI_EXTERNAL INIT_XMM aesni -AES_CRYPT enc -AES_CRYPT dec +AES_CRYPT enc, 10 +AES_CRYPT enc, 12 +AES_CRYPT enc, 14 +AES_CRYPT dec, 10 +AES_CRYPT dec, 12 +AES_CRYPT dec, 14 %endif diff --git a/libavutil/x86/aes_init.c b/libavutil/x86/aes_init.c index 0ac8c20239..f825e0799c 100644 --- a/libavutil/x86/aes_init.c +++ b/libavutil/x86/aes_init.c @@ -22,15 +22,29 @@ #include "libavutil/aes_internal.h" #include "libavutil/x86/cpu.h" -void ff_aes_decrypt_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, - int count, uint8_t *iv, int rounds); -void ff_aes_encrypt_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, - int count, uint8_t *iv, int rounds); +void ff_aes_decrypt_10_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); +void ff_aes_decrypt_12_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); +void ff_aes_decrypt_14_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); +void ff_aes_encrypt_10_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); +void ff_aes_encrypt_12_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); +void ff_aes_encrypt_14_aesni(AVAES *a, uint8_t *dst, const uint8_t *src, + int count, uint8_t *iv, int rounds); void ff_init_aes_x86(AVAES *a, int decrypt) { int cpu_flags = av_get_cpu_flags(); - if (EXTERNAL_AESNI(cpu_flags)) - a->crypt = decrypt ? ff_aes_decrypt_aesni : ff_aes_encrypt_aesni; + if (EXTERNAL_AESNI(cpu_flags)) { + if (a->rounds == 10) + a->crypt = decrypt ? ff_aes_decrypt_10_aesni : ff_aes_encrypt_10_aesni; + else if (a->rounds == 12) + a->crypt = decrypt ? ff_aes_decrypt_12_aesni : ff_aes_encrypt_12_aesni; + else if (a->rounds == 14) + a->crypt = decrypt ? ff_aes_decrypt_14_aesni : ff_aes_encrypt_14_aesni; + } }