Browse Source

linuxkm feature additions:

add build-time support for module signing using native Linux facility;

add support for alternative licenses using WOLFSSL_LICENSE macro;

improve load-time kernel log messages;

add support for sp-math-all asm/AVX2 acceleration;

add error-checking and return in SAVE_VECTOR_REGISTERS();

implement support for x86 accelerated crypto from interrupt handlers, gated on WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED:

  * wolfcrypt_irq_fpu_states
  * am_in_hard_interrupt_handler()
  * allocate_wolfcrypt_irq_fpu_states()
  * free_wolfcrypt_irq_fpu_states()
  * save_vector_registers_x86()
  * restore_vector_registers_x86()

add WOLFSSL_LINUXKM_SIMD, WOLFSSL_LINUXKM_SIMD_X86, and WOLFSSL_LINUXKM_SIMD_ARM macros for more readable gating.
Daniel Pouzzner 2 years ago
parent
commit
83e0e19e03

+ 12 - 5
linuxkm/Kbuild

@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
-#/
+#
 
 SHELL=/bin/bash
 
@@ -48,6 +48,7 @@ $(obj)/linuxkm/module_exports.o: $(WOLFSSL_OBJ_TARGETS)
 # this mechanism only works in kernel 5.x+ (fallback to hardcoded value)
 hostprogs := linuxkm/get_thread_size
 always-y := $(hostprogs)
+always := $(hostprogs)
 HOST_EXTRACFLAGS += $(NOSTDINC_FLAGS) $(LINUXINCLUDE) $(KBUILD_CFLAGS) -static
 
 # this rule is needed to get build to succeed in 4.x (get_thread_size still doesn't get built)
@@ -91,14 +92,20 @@ endif
 
 asflags-y := $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPUSIMD_DISABLE)
 
-# vectorized algorithms equipped with {SAVE,RESTORE}_VECTOR_REGISTERS()
-# can be safely included here:
+# vectorized algorithms protected by {SAVE,RESTORE}_VECTOR_REGISTERS() can be
+# safely included here, though many of these are not yet kernel-compatible:
 $(obj)/wolfcrypt/src/aes_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
 $(obj)/wolfcrypt/src/aes_gcm_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
+$(obj)/wolfcrypt/src/sha256_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
+$(obj)/wolfcrypt/src/sp_x86_64_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
+$(obj)/wolfcrypt/src/sha512_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
+$(obj)/wolfcrypt/src/poly1305_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
+$(obj)/wolfcrypt/src/chacha_asm.o: asflags-y = $(WOLFSSL_ASFLAGS) $(ASFLAGS_FPU_DISABLE_SIMD_ENABLE)
 
-# these _asms are kernel-compatible, but they still irritate objtool:
+# these _asms are known kernel-compatible, but they still irritate objtool:
 $(obj)/wolfcrypt/src/aes_asm.o: OBJECT_FILES_NON_STANDARD := y
 $(obj)/wolfcrypt/src/aes_gcm_asm.o: OBJECT_FILES_NON_STANDARD := y
+$(obj)/wolfcrypt/src/sp_x86_64_asm.o: OBJECT_FILES_NON_STANDARD := y
 
 ifeq "$(ENABLED_LINUXKM_PIE)" "yes"
 
@@ -158,7 +165,7 @@ $(src)/linuxkm/module_exports.c: $(src)/linuxkm/module_exports.c.template $(WOLF
 	@cp $< $@
 	@readelf --symbols --wide $(WOLFSSL_OBJ_TARGETS) |				\
 		awk '/^ *[0-9]+: / {							\
-		  if ($$8 !~ /^(wc_|wolf)/){next;}					\
+		  if ($$8 !~ /^(wc_|wolf|WOLF|TLSX_)/){next;}				\
 		  if (($$4 == "FUNC") && ($$5 == "GLOBAL") && ($$6 == "DEFAULT")) {	\
 		    print "EXPORT_SYMBOL_NS(" $$8 ", WOLFSSL);";    		  	\
 		  }									\

+ 36 - 1
linuxkm/Makefile

@@ -21,7 +21,7 @@
 
 SHELL=/bin/bash
 
-all: libwolfssl.ko
+all: libwolfssl.ko libwolfssl.ko.signed
 
 .PHONY: libwolfssl.ko
 
@@ -61,7 +61,42 @@ libwolfssl.ko:
 	@if test -z "$(src_libwolfssl_la_OBJECTS)"; then echo '$$src_libwolfssl_la_OBJECTS is unset.' >&2; exit 1; fi
 	@mkdir -p linuxkm src wolfcrypt/src wolfcrypt/test
 	@if test ! -h $(SRC_TOP)/Kbuild; then ln -s $(MODULE_TOP)/Kbuild $(SRC_TOP)/Kbuild; fi
+ifeq "$(ENABLED_LINUXKM_PIE)" "yes"
+	+$(MAKE) -C $(KERNEL_ROOT) M=$(MODULE_TOP) src=$(SRC_TOP) CC_FLAGS_FTRACE=
+else
 	+$(MAKE) -C $(KERNEL_ROOT) M=$(MODULE_TOP) src=$(SRC_TOP)
+endif
+
+libwolfssl.ko.signed: libwolfssl.ko
+	@cd '$(KERNEL_ROOT)' || exit $$?;							\
+	while read configline; do								\
+		case "$$configline" in								\
+		CONFIG_MODULE_SIG_KEY=*)							\
+				CONFIG_MODULE_SIG_KEY="$${configline#CONFIG_MODULE_SIG_KEY=}"	\
+			;;									\
+		CONFIG_MODULE_SIG_HASH=*)							\
+				CONFIG_MODULE_SIG_HASH="$${configline#CONFIG_MODULE_SIG_HASH=}"	\
+			;;									\
+		esac;										\
+	done < .config || exit $$?;								\
+	if [[ -n "$${CONFIG_MODULE_SIG_KEY}" && -n "$${CONFIG_MODULE_SIG_HASH}" &&		\
+			( ! -f '$(MODULE_TOP)/$@' ||						\
+			'$(MODULE_TOP)/$<' -nt '$(MODULE_TOP)/$@' ) ]]; then			\
+		CONFIG_MODULE_SIG_KEY="$${CONFIG_MODULE_SIG_KEY#\"}";				\
+		CONFIG_MODULE_SIG_KEY="$${CONFIG_MODULE_SIG_KEY%\"}";				\
+		CONFIG_MODULE_SIG_HASH="$${CONFIG_MODULE_SIG_HASH#\"}";				\
+		CONFIG_MODULE_SIG_HASH="$${CONFIG_MODULE_SIG_HASH%\"}";				\
+		cp -p '$(MODULE_TOP)/$<' '$(MODULE_TOP)/$@' || exit $$?;			\
+		./scripts/sign-file "$${CONFIG_MODULE_SIG_HASH}"				\
+				    "$${CONFIG_MODULE_SIG_KEY}"					\
+				    "$${CONFIG_MODULE_SIG_KEY/%.pem/.x509}"			\
+				    '$(MODULE_TOP)/$@' ||					\
+					$(RM) -f '$(MODULE_TOP)/$@' || exit $$?;		\
+		if [[ "$(quiet)" != "silent_" ]]; then						\
+			echo "  Module $@ signed by $${CONFIG_MODULE_SIG_KEY}.";		\
+		fi										\
+	fi
+
 
 .PHONY: install modules_install
 install modules_install:

+ 72 - 19
linuxkm/module_hooks.c

@@ -19,6 +19,10 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
  */
 
+#ifndef WOLFSSL_LICENSE
+#define WOLFSSL_LICENSE "GPL v2"
+#endif
+
 #define FIPS_NO_WRAPPERS
 
 #ifdef HAVE_CONFIG_H
@@ -41,13 +45,13 @@ static int libwolfssl_cleanup(void) {
 #ifdef WOLFCRYPT_ONLY
     ret = wolfCrypt_Cleanup();
     if (ret != 0)
-        pr_err("wolfCrypt_Cleanup() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfCrypt_Cleanup() failed: %s\n", wc_GetErrorString(ret));
     else
         pr_info("wolfCrypt " LIBWOLFSSL_VERSION_STRING " cleanup complete.\n");
 #else
     ret = wolfSSL_Cleanup();
     if (ret != WOLFSSL_SUCCESS)
-        pr_err("wolfSSL_Cleanup() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfSSL_Cleanup() failed: %s\n", wc_GetErrorString(ret));
     else
         pr_info("wolfSSL " LIBWOLFSSL_VERSION_STRING " cleanup complete.\n");
 #endif
@@ -89,8 +93,8 @@ static void lkmFipsCb(int ok, int err, const char* hash)
     if ((! ok) || (err != 0))
         pr_err("libwolfssl FIPS error: %s\n", wc_GetErrorString(err));
     if (err == IN_CORE_FIPS_E) {
-        pr_err("In-core integrity hash check failure.\n");
-        pr_err("Update verifyCore[] in fips_test.c with new hash \"%s\" and rebuild.\n",
+        pr_err("In-core integrity hash check failure.\n"
+               "Update verifyCore[] in fips_test.c with new hash \"%s\" and rebuild.\n",
                hash ? hash : "<null>");
     }
 }
@@ -104,6 +108,13 @@ static int wolfssl_init(void)
 {
     int ret;
 
+#ifdef CONFIG_MODULE_SIG
+    if (THIS_MODULE->sig_ok == false) {
+        pr_err("wolfSSL module load aborted -- bad or missing module signature with CONFIG_MODULE_SIG kernel.\n");
+        return -ECANCELED;
+    }
+#endif
+
 #ifdef USE_WOLFSSL_LINUXKM_PIE_REDIRECT_TABLE
     ret = set_up_wolfssl_linuxkm_pie_redirect_table();
     if (ret < 0)
@@ -165,13 +176,13 @@ static int wolfssl_init(void)
 #ifdef HAVE_FIPS
     ret = wolfCrypt_SetCb_fips(lkmFipsCb);
     if (ret != 0) {
-        pr_err("wolfCrypt_SetCb_fips() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfCrypt_SetCb_fips() failed: %s\n", wc_GetErrorString(ret));
         return -ECANCELED;
     }
     fipsEntry();
     ret = wolfCrypt_GetStatus_fips();
     if (ret != 0) {
-        pr_err("wolfCrypt_GetStatus_fips() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfCrypt_GetStatus_fips() failed: %s\n", wc_GetErrorString(ret));
         if (ret == IN_CORE_FIPS_E) {
             const char *newhash = wolfCrypt_GetCoreHash_fips();
             pr_err("Update verifyCore[] in fips_test.c with new hash \"%s\" and rebuild.\n",
@@ -198,13 +209,13 @@ static int wolfssl_init(void)
 #ifdef WOLFCRYPT_ONLY
     ret = wolfCrypt_Init();
     if (ret != 0) {
-        pr_err("wolfCrypt_Init() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfCrypt_Init() failed: %s\n", wc_GetErrorString(ret));
         return -ECANCELED;
     }
 #else
     ret = wolfSSL_Init();
     if (ret != WOLFSSL_SUCCESS) {
-        pr_err("wolfSSL_Init() failed: %s", wc_GetErrorString(ret));
+        pr_err("wolfSSL_Init() failed: %s\n", wc_GetErrorString(ret));
         return -ECANCELED;
     }
 #endif
@@ -212,7 +223,7 @@ static int wolfssl_init(void)
 #ifndef NO_CRYPT_TEST
     ret = wolfcrypt_test(NULL);
     if (ret < 0) {
-        pr_err("wolfcrypt self-test failed with return code %d.", ret);
+        pr_err("wolfcrypt self-test failed with return code %d.\n", ret);
         (void)libwolfssl_cleanup();
         msleep(10);
         return -ECANCELED;
@@ -221,12 +232,20 @@ static int wolfssl_init(void)
 #endif
 
 #ifdef WOLFCRYPT_ONLY
-    pr_info("wolfCrypt " LIBWOLFSSL_VERSION_STRING " loaded. See https://www.wolfssl.com/ for information.\n");
+    pr_info("wolfCrypt " LIBWOLFSSL_VERSION_STRING " loaded"
+#ifdef CONFIG_MODULE_SIG
+            " with valid module signature"
+#endif
+            ".\nSee https://www.wolfssl.com/ for more information.\n"
+            "wolfCrypt Copyright (C) 2006-present wolfSSL Inc.  Licensed under " WOLFSSL_LICENSE ".\n");
 #else
-    pr_info("wolfSSL " LIBWOLFSSL_VERSION_STRING " loaded. See https://www.wolfssl.com/ for information.\n");
+    pr_info("wolfSSL " LIBWOLFSSL_VERSION_STRING " loaded"
+#ifdef CONFIG_MODULE_SIG
+            " with valid module signature"
+#endif
+            ".\nSee https://www.wolfssl.com/ for more information.\n"
+            "wolfSSL Copyright (C) 2006-present wolfSSL Inc.  Licensed under " WOLFSSL_LICENSE ".\n");
 #endif
-
-    pr_info("Copyright (C) 2006-2020 wolfSSL Inc. All Rights Reserved.\n");
 
     return 0;
 }
@@ -240,18 +259,38 @@ static void wolfssl_exit(void)
 #endif
 {
     (void)libwolfssl_cleanup();
+
     return;
 }
 
 module_exit(wolfssl_exit);
 
-MODULE_LICENSE("GPL v2");
+MODULE_LICENSE(WOLFSSL_LICENSE);
 MODULE_AUTHOR("https://www.wolfssl.com/");
 MODULE_DESCRIPTION("libwolfssl cryptographic and protocol facilities");
 MODULE_VERSION(LIBWOLFSSL_VERSION_STRING);
 
 #ifdef USE_WOLFSSL_LINUXKM_PIE_REDIRECT_TABLE
 
+/* get_current() is an inline or macro, depending on the target -- sidestep the whole issue with a wrapper func. */
+static struct task_struct *my_get_current_thread(void) {
+    return get_current();
+}
+
+/* ditto for preempt_count(). */
+static int my_preempt_count(void) {
+    return preempt_count();
+}
+
+#if defined(WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED) && (LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0))
+static int my_copy_fpregs_to_fpstate(struct fpu *fpu) {
+    return copy_fpregs_to_fpstate(fpu);
+}
+static void my_copy_kernel_to_fpregs(union fpregs_state *fpstate) {
+    copy_kernel_to_fpregs(fpstate);
+}
+#endif
+
 static int set_up_wolfssl_linuxkm_pie_redirect_table(void) {
     memset(
         &wolfssl_linuxkm_pie_redirect_table,
@@ -310,12 +349,14 @@ static int set_up_wolfssl_linuxkm_pie_redirect_table(void) {
         kmalloc_order_trace;
 
     wolfssl_linuxkm_pie_redirect_table.get_random_bytes = get_random_bytes;
-    wolfssl_linuxkm_pie_redirect_table.ktime_get_real_seconds =
-        ktime_get_real_seconds;
-    wolfssl_linuxkm_pie_redirect_table.ktime_get_with_offset =
-        ktime_get_with_offset;
+    wolfssl_linuxkm_pie_redirect_table.ktime_get_coarse_real_ts64 =
+        ktime_get_coarse_real_ts64;
+
+    wolfssl_linuxkm_pie_redirect_table.get_current = my_get_current_thread;
+    wolfssl_linuxkm_pie_redirect_table.preempt_count = my_preempt_count;
 
-#if defined(WOLFSSL_AESNI) || defined(USE_INTEL_SPEEDUP)
+#ifdef WOLFSSL_LINUXKM_SIMD_X86
+    wolfssl_linuxkm_pie_redirect_table.irq_fpu_usable = irq_fpu_usable;
     #ifdef kernel_fpu_begin
     wolfssl_linuxkm_pie_redirect_table.kernel_fpu_begin_mask =
         kernel_fpu_begin_mask;
@@ -324,6 +365,18 @@ static int set_up_wolfssl_linuxkm_pie_redirect_table(void) {
         kernel_fpu_begin;
     #endif
     wolfssl_linuxkm_pie_redirect_table.kernel_fpu_end = kernel_fpu_end;
+    #ifdef WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED
+        #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0)
+            wolfssl_linuxkm_pie_redirect_table.copy_fpregs_to_fpstate = my_copy_fpregs_to_fpstate;
+            wolfssl_linuxkm_pie_redirect_table.copy_kernel_to_fpregs = my_copy_kernel_to_fpregs;
+        #else
+            wolfssl_linuxkm_pie_redirect_table.save_fpregs_to_fpstate = save_fpregs_to_fpstate;
+            wolfssl_linuxkm_pie_redirect_table.__restore_fpregs_from_fpstate = __restore_fpregs_from_fpstate;
+            wolfssl_linuxkm_pie_redirect_table.xfeatures_mask_all = &xfeatures_mask_all;
+        #endif
+        wolfssl_linuxkm_pie_redirect_table.cpu_number = &cpu_number;
+        wolfssl_linuxkm_pie_redirect_table.nr_cpu_ids = &nr_cpu_ids;
+#endif /* WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */
 #endif
 
     wolfssl_linuxkm_pie_redirect_table.__mutex_init = __mutex_init;

+ 178 - 40
wolfcrypt/src/aes.c

@@ -805,7 +805,12 @@ block cipher mechanism that uses n-bit binary string parameter key with 128-bits
             nr = temp_key->rounds;
             aes->rounds = nr;
 
-            SAVE_VECTOR_REGISTERS();
+            if (SAVE_VECTOR_REGISTERS() != 0) {
+#ifdef WOLFSSL_SMALL_STACK
+                XFREE(temp_key, aes->heap, DYNAMIC_TYPE_AES);
+#endif
+                return BAD_STATE_E;
+            }
 
             Key_Schedule[nr] = Temp_Key_Schedule[0];
             Key_Schedule[nr-1] = _mm_aesimc_si128(Temp_Key_Schedule[1]);
@@ -1738,10 +1743,8 @@ static void wc_AesEncrypt(Aes* aes, const byte* inBlock, byte* outBlock)
             tmp_align = tmp + (AESNI_ALIGN - ((wc_ptr_t)tmp % AESNI_ALIGN));
 
             XMEMCPY(tmp_align, inBlock, AES_BLOCK_SIZE);
-            SAVE_VECTOR_REGISTERS();
             AES_ECB_encrypt(tmp_align, tmp_align, AES_BLOCK_SIZE,
                     (byte*)aes->key, aes->rounds);
-            RESTORE_VECTOR_REGISTERS();
             XMEMCPY(outBlock, tmp_align, AES_BLOCK_SIZE);
             XFREE(tmp, aes->heap, DYNAMIC_TYPE_TMP_BUFFER);
             return;
@@ -1751,10 +1754,8 @@ static void wc_AesEncrypt(Aes* aes, const byte* inBlock, byte* outBlock)
         #endif
         }
 
-        SAVE_VECTOR_REGISTERS();
         AES_ECB_encrypt(inBlock, outBlock, AES_BLOCK_SIZE, (byte*)aes->key,
                         aes->rounds);
-        RESTORE_VECTOR_REGISTERS();
 
         return;
     }
@@ -2089,10 +2090,8 @@ static void wc_AesDecrypt(Aes* aes, const byte* inBlock, byte* outBlock)
         /* if input and output same will overwrite input iv */
         if ((const byte*)aes->tmp != inBlock)
             XMEMCPY(aes->tmp, inBlock, AES_BLOCK_SIZE);
-        SAVE_VECTOR_REGISTERS();
         AES_ECB_decrypt(inBlock, outBlock, AES_BLOCK_SIZE, (byte*)aes->key,
                         aes->rounds);
-        RESTORE_VECTOR_REGISTERS();
         return;
     }
     else {
@@ -3060,6 +3059,62 @@ int wc_AesSetIV(Aes* aes, const byte* iv)
     #elif defined(WOLFSSL_DEVCRYPTO_AES)
         /* implemented in wolfcrypt/src/port/devcrypt/devcrypto_aes.c */
 
+    #elif defined(WOLFSSL_LINUXKM)
+
+        #ifdef WOLFSSL_AESNI
+
+        __must_check int wc_AesEncryptDirect(Aes* aes, byte* out, const byte* in)
+        {
+            if (haveAESNI && aes->use_aesni) {
+                if (SAVE_VECTOR_REGISTERS() != 0)
+                    return BAD_STATE_E;
+            }
+            wc_AesEncrypt(aes, in, out);
+            if (haveAESNI && aes->use_aesni)
+                RESTORE_VECTOR_REGISTERS();
+            return 0;
+        }
+        /* vector reg save/restore is explicit in all below calls to
+         * wc_Aes{En,De}cryptDirect(), so bypass the public version with a
+         * macro.
+         */
+        #define wc_AesEncryptDirect(aes, out, in) wc_AesEncrypt(aes, in, out)
+        #ifdef HAVE_AES_DECRYPT
+        /* Allow direct access to one block decrypt */
+        __must_check int wc_AesDecryptDirect(Aes* aes, byte* out, const byte* in)
+        {
+            if (haveAESNI && aes->use_aesni) {
+                if (SAVE_VECTOR_REGISTERS() != 0)
+                    return BAD_STATE_E;
+            }
+            wc_AesDecrypt(aes, in, out);
+            if (haveAESNI && aes->use_aesni)
+                RESTORE_VECTOR_REGISTERS();
+            return 0;
+        }
+        #define wc_AesDecryptDirect(aes, out, in) wc_AesDecrypt(aes, in, out)
+        #endif /* HAVE_AES_DECRYPT */
+
+        #else /* !WOLFSSL_AESNI */
+
+        __must_check int wc_AesEncryptDirect(Aes* aes, byte* out, const byte* in)
+        {
+            wc_AesEncrypt(aes, in, out);
+            return 0;
+        }
+        #define wc_AesEncryptDirect(aes, out, in) wc_AesEncrypt(aes, in, out)
+        #ifdef HAVE_AES_DECRYPT
+        /* Allow direct access to one block decrypt */
+        __must_check int wc_AesDecryptDirect(Aes* aes, byte* out, const byte* in)
+        {
+            wc_AesDecrypt(aes, in, out);
+            return 0;
+        }
+        #define wc_AesDecryptDirect(aes, out, in) wc_AesDecrypt(aes, in, out)
+        #endif /* HAVE_AES_DECRYPT */
+
+        #endif /* WOLFSSL_AESNI */
+
     #else
         /* Allow direct access to one block encrypt */
         void wc_AesEncryptDirect(Aes* aes, byte* out, const byte* in)
@@ -3834,7 +3889,10 @@ int wc_AesSetIV(Aes* aes, const byte* iv)
 
                 tmp_align = tmp + (AESNI_ALIGN - ((wc_ptr_t)tmp % AESNI_ALIGN));
                 XMEMCPY(tmp_align, in, sz);
-                SAVE_VECTOR_REGISTERS();
+                if (SAVE_VECTOR_REGISTERS() != 0) {
+                    XFREE(tmp, aes->heap, DYNAMIC_TYPE_TMP_BUFFER);
+                    return BAD_STATE_E;
+                }
                 AES_CBC_encrypt(tmp_align, tmp_align, (byte*)aes->reg, sz,
                                                   (byte*)aes->key, aes->rounds);
                 RESTORE_VECTOR_REGISTERS();
@@ -3850,7 +3908,8 @@ int wc_AesSetIV(Aes* aes, const byte* iv)
             #endif
             }
 
-            SAVE_VECTOR_REGISTERS();
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
             AES_CBC_encrypt(in, out, (byte*)aes->reg, sz, (byte*)aes->key,
                             aes->rounds);
             RESTORE_VECTOR_REGISTERS();
@@ -3947,7 +4006,8 @@ int wc_AesSetIV(Aes* aes, const byte* iv)
 
             /* if input and output same will overwrite input iv */
             XMEMCPY(aes->tmp, in + sz - AES_BLOCK_SIZE, AES_BLOCK_SIZE);
-            SAVE_VECTOR_REGISTERS();
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
             #if defined(WOLFSSL_AESNI_BY4)
             AES_CBC_decrypt_by4(in, out, (byte*)aes->reg, sz, (byte*)aes->key,
                             aes->rounds);
@@ -7519,7 +7579,8 @@ int wc_AesGcmEncrypt(Aes* aes, byte* out, const byte* in, word32 sz,
 #ifdef WOLFSSL_AESNI
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_encrypt_avx2(in, out, authIn, iv, authTag, sz, authInSz, ivSz,
                                  authTagSz, (const byte*)aes->key, aes->rounds);
         RESTORE_VECTOR_REGISTERS();
@@ -7529,7 +7590,8 @@ int wc_AesGcmEncrypt(Aes* aes, byte* out, const byte* in, word32 sz,
     #endif
     #ifdef HAVE_INTEL_AVX1
     if (IS_INTEL_AVX1(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_encrypt_avx1(in, out, authIn, iv, authTag, sz, authInSz, ivSz,
                                  authTagSz, (const byte*)aes->key, aes->rounds);
         RESTORE_VECTOR_REGISTERS();
@@ -8041,7 +8103,8 @@ int wc_AesGcmDecrypt(Aes* aes, byte* out, const byte* in, word32 sz,
 #ifdef WOLFSSL_AESNI
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_decrypt_avx2(in, out, authIn, iv, authTag, sz, authInSz, ivSz,
                                  authTagSz, (byte*)aes->key, aes->rounds, &res);
         RESTORE_VECTOR_REGISTERS();
@@ -8053,7 +8116,8 @@ int wc_AesGcmDecrypt(Aes* aes, byte* out, const byte* in, word32 sz,
     #endif
     #ifdef HAVE_INTEL_AVX1
     if (IS_INTEL_AVX1(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_decrypt_avx1(in, out, authIn, iv, authTag, sz, authInSz, ivSz,
                                  authTagSz, (byte*)aes->key, aes->rounds, &res);
         RESTORE_VECTOR_REGISTERS();
@@ -8296,7 +8360,7 @@ extern void AES_GCM_encrypt_final_aesni(unsigned char* tag,
  * @param [in]      iv    IV/nonce buffer.
  * @param [in]      ivSz  Length of IV/nonce data.
  */
-static void AesGcmInit_aesni(Aes* aes, const byte* iv, word32 ivSz)
+static int AesGcmInit_aesni(Aes* aes, const byte* iv, word32 ivSz)
 {
     /* Reset state fields. */
     aes->aSz = 0;
@@ -8309,7 +8373,8 @@ static void AesGcmInit_aesni(Aes* aes, const byte* iv, word32 ivSz)
 
 #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_init_avx2((byte*)aes->key, aes->rounds, iv, ivSz, aes->H,
                           AES_COUNTER(aes), AES_INITCTR(aes));
         RESTORE_VECTOR_REGISTERS();
@@ -8318,7 +8383,8 @@ static void AesGcmInit_aesni(Aes* aes, const byte* iv, word32 ivSz)
 #endif
 #ifdef HAVE_INTEL_AVX1
     if (IS_INTEL_AVX1(intel_flags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_init_avx1((byte*)aes->key, aes->rounds, iv, ivSz, aes->H,
                           AES_COUNTER(aes), AES_INITCTR(aes));
         RESTORE_VECTOR_REGISTERS();
@@ -8326,11 +8392,13 @@ static void AesGcmInit_aesni(Aes* aes, const byte* iv, word32 ivSz)
     else
 #endif
     {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         AES_GCM_init_aesni((byte*)aes->key, aes->rounds, iv, ivSz, aes->H,
                            AES_COUNTER(aes), AES_INITCTR(aes));
         RESTORE_VECTOR_REGISTERS();
     }
+    return 0;
 }
 
 /* Update the AES GCM for encryption with authentication data.
@@ -8458,13 +8526,14 @@ static void AesGcmAadUpdate_aesni(Aes* aes, const byte* a, word32 aSz, int endA)
  * @param [in]      a    Buffer holding authentication data.
  * @param [in]      aSz  Length of authentication data in bytes.
  */
-static void AesGcmEncryptUpdate_aesni(Aes* aes, byte* c, const byte* p,
+static int AesGcmEncryptUpdate_aesni(Aes* aes, byte* c, const byte* p,
     word32 cSz, const byte* a, word32 aSz)
 {
     word32 blocks;
     int partial;
 
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     /* Hash in A, the Authentication Data */
     AesGcmAadUpdate_aesni(aes, a, aSz, (cSz > 0) && (c != NULL));
 
@@ -8573,6 +8642,7 @@ static void AesGcmEncryptUpdate_aesni(Aes* aes, byte* c, const byte* p,
         }
     }
     RESTORE_VECTOR_REGISTERS();
+    return 0;
 }
 
 /* Finalize the AES GCM for encryption and calculate the authentication tag.
@@ -8584,12 +8654,13 @@ static void AesGcmEncryptUpdate_aesni(Aes* aes, byte* c, const byte* p,
  * @param [in]      authTagSz  Length of authentication tag in bytes.
  * @return  0 on success.
  */
-static void AesGcmEncryptFinal_aesni(Aes* aes, byte* authTag, word32 authTagSz)
+static int AesGcmEncryptFinal_aesni(Aes* aes, byte* authTag, word32 authTagSz)
 {
     /* AAD block incomplete when > 0 */
     byte over = aes->aOver;
 
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     if (aes->cOver > 0) {
         /* Cipher text block incomplete. */
         over = aes->cOver;
@@ -8635,6 +8706,7 @@ static void AesGcmEncryptFinal_aesni(Aes* aes, byte* authTag, word32 authTagSz)
             aes->aSz, aes->H, AES_INITCTR(aes));
     }
     RESTORE_VECTOR_REGISTERS();
+    return 0;
 }
 
 #if defined(HAVE_AES_DECRYPT) || defined(HAVE_AESGCM_DECRYPT)
@@ -8680,13 +8752,14 @@ extern void AES_GCM_decrypt_final_aesni(unsigned char* tag,
  * @param [in]      a    Buffer holding authentication data.
  * @param [in]      aSz  Length of authentication data in bytes.
  */
-static void AesGcmDecryptUpdate_aesni(Aes* aes, byte* p, const byte* c,
+static int AesGcmDecryptUpdate_aesni(Aes* aes, byte* p, const byte* c,
     word32 cSz, const byte* a, word32 aSz)
 {
     word32 blocks;
     int partial;
 
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     /* Hash in A, the Authentication Data */
     AesGcmAadUpdate_aesni(aes, a, aSz, (cSz > 0) && (c != NULL));
 
@@ -8797,6 +8870,7 @@ static void AesGcmDecryptUpdate_aesni(Aes* aes, byte* p, const byte* c,
         }
     }
     RESTORE_VECTOR_REGISTERS();
+    return 0;
 }
 
 /* Finalize the AES GCM for decryption and check the authentication tag.
@@ -8819,7 +8893,8 @@ static int AesGcmDecryptFinal_aesni(Aes* aes, const byte* authTag,
     byte over = aes->aOver;
     byte *lastBlock = AES_LASTGBLOCK(aes);
 
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     if (aes->cOver > 0) {
         /* Cipher text block incomplete. */
         over = aes->cOver;
@@ -8940,7 +9015,7 @@ int wc_AesGcmInit(Aes* aes, const byte* key, word32 len, const byte* iv,
                 || IS_INTEL_AVX1(intel_flags)
             #endif
                 ) {
-                AesGcmInit_aesni(aes, iv, ivSz);
+                ret = AesGcmInit_aesni(aes, iv, ivSz);
             }
             else
         #endif
@@ -9052,7 +9127,7 @@ int wc_AesGcmEncryptUpdate(Aes* aes, byte* out, const byte* in, word32 sz,
             || IS_INTEL_AVX1(intel_flags)
         #endif
             ) {
-            AesGcmEncryptUpdate_aesni(aes, out, in, sz, authIn, authInSz);
+            ret = AesGcmEncryptUpdate_aesni(aes, out, in, sz, authIn, authInSz);
         }
         else
     #endif
@@ -9818,13 +9893,13 @@ int wc_AesCcmEncrypt(Aes* aes, byte* out, const byte* in, word32 inSz,
     B[15] = 1;
 #ifdef WOLFSSL_AESNI
     if (haveAESNI && aes->use_aesni) {
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         while (inSz >= AES_BLOCK_SIZE * 4) {
             AesCcmCtrIncSet4(B, lenSz);
 
-            SAVE_VECTOR_REGISTERS();
             AES_ECB_encrypt(B, A, AES_BLOCK_SIZE * 4, (byte*)aes->key,
                             aes->rounds);
-            RESTORE_VECTOR_REGISTERS();
 
             xorbuf(A, in, AES_BLOCK_SIZE * 4);
             XMEMCPY(out, A, AES_BLOCK_SIZE * 4);
@@ -9835,6 +9910,7 @@ int wc_AesCcmEncrypt(Aes* aes, byte* out, const byte* in, word32 inSz,
 
             AesCcmCtrInc4(B, lenSz);
         }
+        RESTORE_VECTOR_REGISTERS();
     }
 #endif
     while (inSz >= AES_BLOCK_SIZE) {
@@ -9903,13 +9979,13 @@ int  wc_AesCcmDecrypt(Aes* aes, byte* out, const byte* in, word32 inSz,
 
 #ifdef WOLFSSL_AESNI
     if (haveAESNI && aes->use_aesni) {
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         while (oSz >= AES_BLOCK_SIZE * 4) {
             AesCcmCtrIncSet4(B, lenSz);
 
-            SAVE_VECTOR_REGISTERS();
             AES_ECB_encrypt(B, A, AES_BLOCK_SIZE * 4, (byte*)aes->key,
                             aes->rounds);
-            RESTORE_VECTOR_REGISTERS();
 
             xorbuf(A, in, AES_BLOCK_SIZE * 4);
             XMEMCPY(o, A, AES_BLOCK_SIZE * 4);
@@ -9920,6 +9996,7 @@ int  wc_AesCcmDecrypt(Aes* aes, byte* out, const byte* in, word32 inSz,
 
             AesCcmCtrInc4(B, lenSz);
         }
+        RESTORE_VECTOR_REGISTERS();
     }
 #endif
     while (oSz >= AES_BLOCK_SIZE) {
@@ -10274,12 +10351,10 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz)
 #else
 
 /* Software AES - ECB */
-int wc_AesEcbEncrypt(Aes* aes, byte* out, const byte* in, word32 sz)
+static int _AesEcbEncrypt(Aes* aes, byte* out, const byte* in, word32 sz)
 {
     word32 blocks = sz / AES_BLOCK_SIZE;
 
-    if ((in == NULL) || (out == NULL) || (aes == NULL))
-      return BAD_FUNC_ARG;
 #ifdef WOLFSSL_IMXRT_DCP
     if (aes->keylen == 16)
         return DCPAesEcbEncrypt(aes, out, in, sz);
@@ -10293,13 +10368,10 @@ int wc_AesEcbEncrypt(Aes* aes, byte* out, const byte* in, word32 sz)
     return 0;
 }
 
-
-int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz)
+static int _AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz)
 {
     word32 blocks = sz / AES_BLOCK_SIZE;
 
-    if ((in == NULL) || (out == NULL) || (aes == NULL))
-      return BAD_FUNC_ARG;
 #ifdef WOLFSSL_IMXRT_DCP
     if (aes->keylen == 16)
         return DCPAesEcbDecrypt(aes, out, in, sz);
@@ -10312,6 +10384,36 @@ int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz)
     }
     return 0;
 }
+
+int wc_AesEcbEncrypt(Aes* aes, byte* out, const byte* in, word32 sz)
+{
+    int ret;
+
+    if ((in == NULL) || (out == NULL) || (aes == NULL))
+      return BAD_FUNC_ARG;
+
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+    ret = _AesEcbEncrypt(aes, out, in, sz);
+    RESTORE_VECTOR_REGISTERS();
+
+    return ret;
+}
+
+int wc_AesEcbDecrypt(Aes* aes, byte* out, const byte* in, word32 sz)
+{
+    int ret;
+
+    if ((in == NULL) || (out == NULL) || (aes == NULL))
+      return BAD_FUNC_ARG;
+
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+    ret = _AesEcbDecrypt(aes, out, in, sz);
+    RESTORE_VECTOR_REGISTERS();
+
+    return ret;
+}
 #endif
 #endif /* HAVE_AES_ECB */
 
@@ -10360,6 +10462,9 @@ static int wc_AesFeedbackEncrypt(Aes* aes, byte* out, const byte* in,
         sz--;
     }
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     while (sz >= AES_BLOCK_SIZE) {
         /* Using aes->tmp here for inline case i.e. in=out */
         wc_AesEncryptDirect(aes, (byte*)aes->tmp, (byte*)aes->reg);
@@ -10406,6 +10511,7 @@ static int wc_AesFeedbackEncrypt(Aes* aes, byte* out, const byte* in,
             aes->left--;
         }
     }
+    RESTORE_VECTOR_REGISTERS();
 
     return 0;
 }
@@ -10448,6 +10554,9 @@ static int wc_AesFeedbackDecrypt(Aes* aes, byte* out, const byte* in, word32 sz,
         sz--;
     }
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     while (sz > AES_BLOCK_SIZE) {
         /* Using aes->tmp here for inline case i.e. in=out */
         wc_AesEncryptDirect(aes, (byte*)aes->tmp, (byte*)aes->reg);
@@ -10491,6 +10600,7 @@ static int wc_AesFeedbackDecrypt(Aes* aes, byte* out, const byte* in, word32 sz,
             aes->left--;
         }
     }
+    RESTORE_VECTOR_REGISTERS();
 
     return 0;
 }
@@ -10572,6 +10682,9 @@ static int wc_AesFeedbackCFB8(Aes* aes, byte* out, const byte* in,
         return 0;
     }
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     while (sz > 0) {
         wc_AesEncryptDirect(aes, (byte*)aes->tmp, (byte*)aes->reg);
         if (dir == AES_DECRYPTION) {
@@ -10600,6 +10713,8 @@ static int wc_AesFeedbackCFB8(Aes* aes, byte* out, const byte* in,
         sz  -= 1;
     }
 
+    RESTORE_VECTOR_REGISTERS();
+
     return 0;
 }
 
@@ -10621,6 +10736,9 @@ static int wc_AesFeedbackCFB1(Aes* aes, byte* out, const byte* in,
         return 0;
     }
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     while (sz > 0) {
         wc_AesEncryptDirect(aes, (byte*)aes->tmp, (byte*)aes->reg);
         if (dir == AES_DECRYPTION) {
@@ -10667,6 +10785,7 @@ static int wc_AesFeedbackCFB1(Aes* aes, byte* out, const byte* in,
     if (bit > 0 && bit < 7) {
         out[0] = cur;
     }
+    RESTORE_VECTOR_REGISTERS();
 
     return 0;
 }
@@ -10843,6 +10962,9 @@ int wc_AesKeyWrap_ex(Aes *aes, const byte* in, word32 inSz, byte* out,
         XMEMCPY(tmp, iv, KEYWRAP_BLOCK_SIZE);
     }
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     for (j = 0; j <= 5; j++) {
         for (i = 1; i <= inSz / KEYWRAP_BLOCK_SIZE; i++) {
             /* load R[i] */
@@ -10860,6 +10982,7 @@ int wc_AesKeyWrap_ex(Aes *aes, const byte* in, word32 inSz, byte* out,
         }
         r = out + KEYWRAP_BLOCK_SIZE;
     }
+    RESTORE_VECTOR_REGISTERS();
 
     /* C[0] = A */
     XMEMCPY(out, tmp, KEYWRAP_BLOCK_SIZE);
@@ -10944,6 +11067,9 @@ int wc_AesKeyUnWrap_ex(Aes *aes, const byte* in, word32 inSz, byte* out,
     XMEMCPY(out, in + KEYWRAP_BLOCK_SIZE, inSz - KEYWRAP_BLOCK_SIZE);
     XMEMSET(t, 0, sizeof(t));
 
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
+
     /* initialize counter to 6n */
     n = (inSz - 1) / KEYWRAP_BLOCK_SIZE;
     InitKeyWrapCounter(t, 6 * n);
@@ -10964,6 +11090,7 @@ int wc_AesKeyUnWrap_ex(Aes *aes, const byte* in, word32 inSz, byte* out,
             XMEMCPY(r, tmp + KEYWRAP_BLOCK_SIZE, KEYWRAP_BLOCK_SIZE);
         }
     }
+    RESTORE_VECTOR_REGISTERS();
 
     /* verify IV */
     if (XMEMCMP(tmp, expIv, KEYWRAP_BLOCK_SIZE) != 0)
@@ -11178,10 +11305,10 @@ static int _AesXtsHelper(Aes* aes, byte* out, const byte* in, word32 sz, int dir
 
     xorbuf(out, in, totalSz);
     if (dir == AES_ENCRYPTION) {
-        return wc_AesEcbEncrypt(aes, out, out, totalSz);
+        return _AesEcbEncrypt(aes, out, out, totalSz);
     }
     else {
-        return wc_AesEcbDecrypt(aes, out, out, totalSz);
+        return _AesEcbDecrypt(aes, out, out, totalSz);
     }
 }
 #endif /* HAVE_AES_ECB */
@@ -11224,6 +11351,9 @@ int wc_AesXtsEncrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
         XMEMSET(tmp, 0, AES_BLOCK_SIZE); /* set to 0's in case of improper AES
                                           * key setup passed to encrypt direct*/
 
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
+
         wc_AesEncryptDirect(tweak, tmp, i);
 
     #ifdef HAVE_AES_ECB
@@ -11231,6 +11361,7 @@ int wc_AesXtsEncrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
         if (in != out) { /* can not handle inline */
             XMEMCPY(out, tmp, AES_BLOCK_SIZE);
             if ((ret = _AesXtsHelper(aes, out, in, sz, AES_ENCRYPTION)) != 0) {
+                RESTORE_VECTOR_REGISTERS();
                 return ret;
             }
         }
@@ -11285,6 +11416,7 @@ int wc_AesXtsEncrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
             wc_AesEncryptDirect(aes, out - AES_BLOCK_SIZE, buf);
             xorbuf(out - AES_BLOCK_SIZE, tmp, AES_BLOCK_SIZE);
         }
+        RESTORE_VECTOR_REGISTERS();
     }
     else {
         WOLFSSL_MSG("Plain text input too small for encryption");
@@ -11335,6 +11467,9 @@ int wc_AesXtsDecrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
         XMEMSET(tmp, 0, AES_BLOCK_SIZE); /* set to 0's in case of improper AES
                                           * key setup passed to decrypt direct*/
 
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
+
         wc_AesEncryptDirect(tweak, tmp, i);
 
         /* if Stealing then break out of loop one block early to handle special
@@ -11348,6 +11483,7 @@ int wc_AesXtsDecrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
         if (in != out) { /* can not handle inline */
             XMEMCPY(out, tmp, AES_BLOCK_SIZE);
             if ((ret = _AesXtsHelper(aes, out, in, sz, AES_DECRYPTION)) != 0) {
+                RESTORE_VECTOR_REGISTERS();
                 return ret;
             }
         }
@@ -11416,6 +11552,7 @@ int wc_AesXtsDecrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
             /* Make buffer with end of cipher text | last */
             XMEMCPY(buf, tmp2, AES_BLOCK_SIZE);
             if (sz >= AES_BLOCK_SIZE) { /* extra sanity check before copy */
+                RESTORE_VECTOR_REGISTERS();
                 return BUFFER_E;
             }
             XMEMCPY(buf, in,   sz);
@@ -11426,6 +11563,7 @@ int wc_AesXtsDecrypt(XtsAes* xaes, byte* out, const byte* in, word32 sz,
             xorbuf(tmp2, tmp, AES_BLOCK_SIZE);
             XMEMCPY(out - AES_BLOCK_SIZE, tmp2, AES_BLOCK_SIZE);
         }
+        RESTORE_VECTOR_REGISTERS();
     }
     else {
         WOLFSSL_MSG("Plain text input too small for encryption");

+ 4 - 2
wolfcrypt/src/chacha.c

@@ -418,14 +418,16 @@ int wc_Chacha_Process(ChaCha* ctx, byte* output, const byte* input,
 
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(cpuidFlags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         chacha_encrypt_avx2(ctx, input, output, msglen);
         RESTORE_VECTOR_REGISTERS();
         return 0;
     }
     #endif
     if (IS_INTEL_AVX1(cpuidFlags)) {
-        SAVE_VECTOR_REGISTERS();
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
         chacha_encrypt_avx1(ctx, input, output, msglen);
         RESTORE_VECTOR_REGISTERS();
         return 0;

+ 36 - 10
wolfcrypt/src/cmac.c

@@ -119,10 +119,19 @@ int wc_InitCmac_ex(Cmac* cmac, const byte* key, word32 keySz,
         byte l[AES_BLOCK_SIZE];
 
         XMEMSET(l, 0, AES_BLOCK_SIZE);
-        wc_AesEncryptDirect(&cmac->aes, l, l);
-        ShiftAndXorRb(cmac->k1, l);
-        ShiftAndXorRb(cmac->k2, cmac->k1);
-        ForceZero(l, AES_BLOCK_SIZE);
+#ifdef WOLFSSL_LINUXKM
+        ret =
+#endif
+            wc_AesEncryptDirect(&cmac->aes, l, l);
+#ifdef WOLFSSL_LINUXKM
+        if (ret == 0) {
+#endif
+            ShiftAndXorRb(cmac->k1, l);
+            ShiftAndXorRb(cmac->k2, cmac->k1);
+            ForceZero(l, AES_BLOCK_SIZE);
+#ifdef WOLFSSL_LINUXKM
+        }
+#endif
     }
     return ret;
 }
@@ -172,10 +181,19 @@ int wc_CmacUpdate(Cmac* cmac, const byte* in, word32 inSz)
             if (cmac->totalSz != 0) {
                 xorbuf(cmac->buffer, cmac->digest, AES_BLOCK_SIZE);
             }
-            wc_AesEncryptDirect(&cmac->aes, cmac->digest, cmac->buffer);
-            cmac->totalSz += AES_BLOCK_SIZE;
-            cmac->bufferSz = 0;
+#ifdef WOLFSSL_LINUXKM
+            ret =
+#endif
+                wc_AesEncryptDirect(&cmac->aes, cmac->digest, cmac->buffer);
+#ifdef WOLFSSL_LINUXKM
+            if (ret == 0) {
+#endif
+                cmac->totalSz += AES_BLOCK_SIZE;
+                cmac->bufferSz = 0;
+            }
+#ifdef WOLFSSL_LINUXKM
         }
+#endif
     }
 
     return ret;
@@ -221,9 +239,17 @@ int wc_CmacFinal(Cmac* cmac, byte* out, word32* outSz)
     }
     xorbuf(cmac->buffer, cmac->digest, AES_BLOCK_SIZE);
     xorbuf(cmac->buffer, subKey, AES_BLOCK_SIZE);
-    wc_AesEncryptDirect(&cmac->aes, cmac->digest, cmac->buffer);
-
-    XMEMCPY(out, cmac->digest, *outSz);
+#ifdef WOLFSSL_LINUXKM
+    ret =
+#endif
+        wc_AesEncryptDirect(&cmac->aes, cmac->digest, cmac->buffer);
+#ifdef WOLFSSL_LINUXKM
+    if (ret == 0) {
+#endif
+        XMEMCPY(out, cmac->digest, *outSz);
+#ifdef WOLFSSL_LINUXKM
+    }
+#endif
 
     ForceZero(cmac, sizeof(Cmac));
 

+ 9 - 18
wolfcrypt/src/curve25519.c

@@ -128,15 +128,12 @@ int wc_curve25519_make_pub(int public_size, byte* pub, int private_size,
 #else
     fe_init();
 
-    #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-        SAVE_VECTOR_REGISTERS();
-    #endif
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
 
     ret = curve25519(pub, priv, kCurve25519BasePoint);
 
-    #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-        RESTORE_VECTOR_REGISTERS();
-    #endif
+    RESTORE_VECTOR_REGISTERS();
 #endif
 
     return ret;
@@ -174,15 +171,12 @@ int wc_curve25519_generic(int public_size, byte* pub,
 
     fe_init();
 
-    #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-        SAVE_VECTOR_REGISTERS();
-    #endif
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
 
     ret = curve25519(pub, priv, basepoint);
 
-    #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-        RESTORE_VECTOR_REGISTERS();
-    #endif
+    RESTORE_VECTOR_REGISTERS();
 
     return ret;
 #endif /* FREESCALE_LTC_ECC */
@@ -295,15 +289,12 @@ int wc_curve25519_shared_secret_ex(curve25519_key* private_key,
         ret = nxp_ltc_curve25519(&o, private_key->k, &public_key->p,
                                  kLTC_Curve25519);
     #else
-        #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-            SAVE_VECTOR_REGISTERS();
-        #endif
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
 
         ret = curve25519(o.point, private_key->k, public_key->p.point);
 
-        #if defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_ARMASM)
-            RESTORE_VECTOR_REGISTERS();
-        #endif
+        RESTORE_VECTOR_REGISTERS();
     #endif
     if (ret != 0) {
         ForceZero(&o, sizeof(o));

+ 138 - 0
wolfcrypt/src/memory.c

@@ -1133,3 +1133,141 @@ void __attribute__((no_instrument_function))
 }
 #endif
 
+#ifdef WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED
+union fpregs_state **wolfcrypt_irq_fpu_states = NULL;
+#endif
+
+#if defined(WOLFSSL_LINUXKM_SIMD_X86) && defined(WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED)
+
+    /* Returns nonzero iff the current context is a hard interrupt or NMI
+     * handler, per the HARDIRQ/NMI bits of preempt_count().
+     */
+    static __must_check inline int am_in_hard_interrupt_handler(void) {
+        return (preempt_count() & (NMI_MASK | HARDIRQ_MASK)) != 0;
+    }
+
+    /* Allocate one page-sized, page-aligned fpregs_state save slot per CPU,
+     * for use by save_vector_registers_x86() when called from a hard
+     * interrupt handler.  Returns 0 on success, MEMORY_E on allocation
+     * failure (with no allocations left outstanding).
+     */
+    __must_check int allocate_wolfcrypt_irq_fpu_states(void) {
+        /* fix: size the table by the actual element type, union fpregs_state *
+         * -- "struct fpu_state" is not a kernel type (it compiled only as a
+         * pointer to an incomplete type).
+         */
+        wolfcrypt_irq_fpu_states = (union fpregs_state **)kzalloc(nr_cpu_ids * sizeof(union fpregs_state *), GFP_KERNEL);
+        if (! wolfcrypt_irq_fpu_states) {
+            pr_err("warning, allocation of %lu bytes for wolfcrypt_irq_fpu_states failed.\n", nr_cpu_ids * sizeof(union fpregs_state *));
+            return MEMORY_E;
+        }
+        {
+            unsigned int i;
+            for (i=0; i<nr_cpu_ids; ++i) {
+                _Static_assert(sizeof(union fpregs_state) <= PAGE_SIZE, "union fpregs_state is larger than expected.");
+                /* allocate a full page: guarantees the 64-byte alignment
+                 * needed by xsave and reserves the final byte as the
+                 * slot-in-use flag checked by save_vector_registers_x86().
+                 */
+                wolfcrypt_irq_fpu_states[i] = (union fpregs_state *)kzalloc(PAGE_SIZE /* sizeof(union fpregs_state) */, GFP_KERNEL);
+                if (! wolfcrypt_irq_fpu_states[i])
+                    break;
+                /* double-check that the allocation is 64-byte-aligned as needed for xsave. */
+                if ((unsigned long)wolfcrypt_irq_fpu_states[i] & 63UL) {
+                    pr_err("warning, allocation for wolfcrypt_irq_fpu_states was not properly aligned (%px).\n", wolfcrypt_irq_fpu_states[i]);
+                    kfree(wolfcrypt_irq_fpu_states[i]);
+                    wolfcrypt_irq_fpu_states[i] = 0;
+                    break;
+                }
+            }
+            if (i < nr_cpu_ids) {
+                pr_err("warning, only %u/%u allocations succeeded for wolfcrypt_irq_fpu_states.\n", i, nr_cpu_ids);
+                /* fix: release the partial allocations before failing --
+                 * callers treat a nonzero return as total failure and will
+                 * not call free_wolfcrypt_irq_fpu_states() themselves.
+                 */
+                free_wolfcrypt_irq_fpu_states();
+                return MEMORY_E;
+            }
+        }
+        return 0;
+    }
+
+    /* Free the per-CPU IRQ fpregs save slots and the table itself.  Safe to
+     * call on a partially populated table or when allocation never ran
+     * (wolfcrypt_irq_fpu_states NULL).
+     */
+    void free_wolfcrypt_irq_fpu_states(void) {
+        if (wolfcrypt_irq_fpu_states) {
+            unsigned int i;
+            for (i=0; i<nr_cpu_ids; ++i) {
+                if (wolfcrypt_irq_fpu_states[i])
+                    kfree(wolfcrypt_irq_fpu_states[i]);
+            }
+            kfree(wolfcrypt_irq_fpu_states);
+            wolfcrypt_irq_fpu_states = 0;
+        }
+    }
+
+    /* Save the x86 vector (FPU/SIMD) register state so wolfcrypt code may
+     * use those registers.  Outside hard-IRQ context this is a thin wrapper
+     * around kernel_fpu_begin().  In hard-IRQ context, where
+     * kernel_fpu_begin() is not usable, the state is saved manually into
+     * this CPU's wolfcrypt_irq_fpu_states slot.
+     *
+     * Returns 0 on success, or EFAULT/EPERM on failure (note: positive
+     * error values -- callers test only for nonzero).  On success,
+     * preemption stays disabled until restore_vector_registers_x86(),
+     * mirroring kernel_fpu_begin() semantics.
+     */
+    __must_check int save_vector_registers_x86(void) {
+        preempt_disable();
+        if (! irq_fpu_usable()) {
+            if (am_in_hard_interrupt_handler()) {
+                int processor_id;
+
+                if (! wolfcrypt_irq_fpu_states) {
+                    /* warn only once, to avoid flooding the kernel log. */
+                    static int warned_on_null_wolfcrypt_irq_fpu_states = 0;
+                    preempt_enable();
+                    if (! warned_on_null_wolfcrypt_irq_fpu_states) {
+                        warned_on_null_wolfcrypt_irq_fpu_states = 1;
+                        pr_err("save_vector_registers_x86 with null wolfcrypt_irq_fpu_states.\n");
+                    }
+                    return EFAULT;
+                }
+
+                processor_id = __smp_processor_id();
+
+                if (! wolfcrypt_irq_fpu_states[processor_id]) {
+                    /* warn at most once per CPU id, in ascending order. */
+                    static int warned_on_null_wolfcrypt_irq_fpu_states_processor_id = -1;
+                    preempt_enable();
+                    if (warned_on_null_wolfcrypt_irq_fpu_states_processor_id < processor_id) {
+                        warned_on_null_wolfcrypt_irq_fpu_states_processor_id = processor_id;
+                        pr_err("save_vector_registers_x86 for cpu id %d with null wolfcrypt_irq_fpu_states[id].\n", processor_id);
+                    }
+                    return EFAULT;
+                }
+
+                /* check for nested interrupts -- doesn't exist on x86, but make
+                 * sure, in case something changes.
+                 *
+                 * (see https://stackoverflow.com/questions/23324084/nested-interrupt-handling-in-arm)
+                 */
+                if (((char *)wolfcrypt_irq_fpu_states[processor_id])[PAGE_SIZE-1] != 0) {
+                    preempt_enable();
+                    pr_err("save_vector_registers_x86 called recursively for cpu id %d.\n", processor_id);
+                    return EPERM;
+                }
+
+                /* note, fpregs_lock() is not needed here, because
+                 * interrupts/preemptions are already disabled here.
+                 */
+                {
+                    /* save_fpregs_to_fpstate() only accesses fpu->state, which has
+                     * stringent alignment requirements (64 byte cache line), but takes
+                     * a pointer to the parent struct.  work around this.
+                     */
+                    struct fpu *fake_fpu_pointer = (struct fpu *)(((char *)wolfcrypt_irq_fpu_states[processor_id]) - offsetof(struct fpu, state));
+                #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0)
+                    copy_fpregs_to_fpstate(fake_fpu_pointer);
+                #else
+                    save_fpregs_to_fpstate(fake_fpu_pointer);
+                #endif
+                }
+                ((char *)wolfcrypt_irq_fpu_states[processor_id])[PAGE_SIZE-1] = 1; /* mark the slot as used. */
+                /* note, not preempt_enable()ing, mirroring kernel_fpu_begin() semantics. */
+                return 0;
+            }
+            preempt_enable();
+            return EPERM;
+        } else {
+            kernel_fpu_begin();
+            preempt_enable(); /* kernel_fpu_begin() does its own preempt_disable().  decrement ours. */
+            return 0;
+        }
+    }
+    /* Counterpart to save_vector_registers_x86().  In hard-IRQ context,
+     * reloads the state saved in this CPU's wolfcrypt_irq_fpu_states slot
+     * and clears the slot's in-use flag; otherwise defers to
+     * kernel_fpu_end().  Re-enables the preemption disable left in place by
+     * the save.
+     */
+    void restore_vector_registers_x86(void) {
+        if (am_in_hard_interrupt_handler()) {
+            int processor_id = __smp_processor_id();
+            if (((char *)wolfcrypt_irq_fpu_states[processor_id])[PAGE_SIZE-1]) {
+            #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0)
+                copy_kernel_to_fpregs(wolfcrypt_irq_fpu_states[processor_id]);
+            #else
+                __restore_fpregs_from_fpstate(wolfcrypt_irq_fpu_states[processor_id], xfeatures_mask_all);
+            #endif
+                ((char *)wolfcrypt_irq_fpu_states[processor_id])[PAGE_SIZE-1] = 0;
+                preempt_enable();
+                return;
+            } else {
+                /* restore without a matching save: report, but still drop
+                 * the preempt count so state stays balanced. */
+                pr_err("restore_vector_registers_x86 called for cpu id %d without saved context.\n", processor_id);
+                preempt_enable(); /* just in case */
+                return;
+            }
+        }
+        kernel_fpu_end();
+    }
+#endif /* WOLFSSL_LINUXKM_SIMD_X86 && WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */

+ 26 - 12
wolfcrypt/src/poly1305.c

@@ -262,14 +262,16 @@ static WC_INLINE void u32tole64(const word32 inLe32, byte outLe64[8])
 This local function operates on a message with a given number of bytes
 with a given ctx pointer to a Poly1305 structure.
 */
-static void poly1305_blocks(Poly1305* ctx, const unsigned char *m,
+static int poly1305_blocks(Poly1305* ctx, const unsigned char *m,
                      size_t bytes)
 {
 #ifdef USE_INTEL_SPEEDUP
     /* AVX2 is handled in wc_Poly1305Update. */
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     poly1305_blocks_avx(ctx, m, bytes);
     RESTORE_VECTOR_REGISTERS();
+    return 0;
 #elif defined(POLY130564)
     const word64 hibit = (ctx->finished) ? 0 : ((word64)1 << 40); /* 1 << 128 */
     word64 r0,r1,r2;
@@ -320,6 +322,8 @@ static void poly1305_blocks(Poly1305* ctx, const unsigned char *m,
     ctx->h[1] = h1;
     ctx->h[2] = h2;
 
+    return 0;
+
 #else /* if not 64 bit then use 32 bit */
     const word32 hibit = (ctx->finished) ? 0 : ((word32)1 << 24); /* 1 << 128 */
     word32 r0,r1,r2,r3,r4;
@@ -385,6 +389,8 @@ static void poly1305_blocks(Poly1305* ctx, const unsigned char *m,
     ctx->h[3] = h3;
     ctx->h[4] = h4;
 
+    return 0;
+
 #endif /* end of 64 bit cpu blocks or 32 bit cpu */
 }
 
@@ -392,15 +398,17 @@ static void poly1305_blocks(Poly1305* ctx, const unsigned char *m,
 This local function is used for the last call when a message with a given
 number of bytes is less than the block size.
 */
-static void poly1305_block(Poly1305* ctx, const unsigned char *m)
+static int poly1305_block(Poly1305* ctx, const unsigned char *m)
 {
 #ifdef USE_INTEL_SPEEDUP
     /* No call to poly1305_block when AVX2, AVX2 does 4 blocks at a time. */
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     poly1305_block_avx(ctx, m);
     RESTORE_VECTOR_REGISTERS();
+    return 0;
 #else
-    poly1305_blocks(ctx, m, POLY1305_BLOCK_SIZE);
+    return poly1305_blocks(ctx, m, POLY1305_BLOCK_SIZE);
 #endif
 }
 #endif /* !defined(WOLFSSL_ARMASM) || !defined(__aarch64__) */
@@ -434,7 +442,8 @@ int wc_Poly1305SetKey(Poly1305* ctx, const byte* key, word32 keySz)
         intel_flags = cpuid_get_flags();
         cpu_flags_set = 1;
     }
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags))
         poly1305_setkey_avx2(ctx, key);
@@ -516,7 +525,8 @@ int wc_Poly1305Final(Poly1305* ctx, byte* mac)
         return BAD_FUNC_ARG;
 
 #ifdef USE_INTEL_SPEEDUP
-    SAVE_VECTOR_REGISTERS();
+    if (SAVE_VECTOR_REGISTERS() != 0)
+        return BAD_STATE_E;
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags))
         poly1305_final_avx2(ctx, mac);
@@ -704,7 +714,12 @@ int wc_Poly1305Update(Poly1305* ctx, const byte* m, word32 bytes)
 #ifdef USE_INTEL_SPEEDUP
     #ifdef HAVE_INTEL_AVX2
     if (IS_INTEL_AVX2(intel_flags)) {
+
+        if (SAVE_VECTOR_REGISTERS() != 0)
+            return BAD_STATE_E;
+
         /* handle leftover */
+
         if (ctx->leftover) {
             size_t want = sizeof(ctx->buffer) - ctx->leftover;
             if (want > bytes)
@@ -718,15 +733,13 @@ int wc_Poly1305Update(Poly1305* ctx, const byte* m, word32 bytes)
-            if (ctx->leftover < sizeof(ctx->buffer))
-                return 0;
+            if (ctx->leftover < sizeof(ctx->buffer)) {
+                RESTORE_VECTOR_REGISTERS();
+                return 0;
+            }
 
-            SAVE_VECTOR_REGISTERS();
             if (!ctx->started)
                 poly1305_calc_powers_avx2(ctx);
             poly1305_blocks_avx2(ctx, ctx->buffer, sizeof(ctx->buffer));
             ctx->leftover = 0;
         }
-        else {
-            SAVE_VECTOR_REGISTERS();
-        }
 
         /* process full blocks */
         if (bytes >= sizeof(ctx->buffer)) {
@@ -769,8 +780,11 @@ int wc_Poly1305Update(Poly1305* ctx, const byte* m, word32 bytes)
 
         /* process full blocks */
         if (bytes >= POLY1305_BLOCK_SIZE) {
+            int ret;
             size_t want = (bytes & ~(POLY1305_BLOCK_SIZE - 1));
-            poly1305_blocks(ctx, m, want);
+            ret = poly1305_blocks(ctx, m, want);
+            if (ret != 0)
+                return ret;
             m += want;
             bytes -= (word32)want;
         }

+ 8 - 4
wolfcrypt/src/sha256.c

@@ -319,8 +319,10 @@ static int InitSha256(wc_Sha256* sha256)
 
     static WC_INLINE int inline_XTRANSFORM(wc_Sha256* S, const byte* D) {
         int ret;
-        if (Transform_Sha256_is_vectorized)
-            SAVE_VECTOR_REGISTERS();
+        if (Transform_Sha256_is_vectorized) {
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
+        }
         ret = (*Transform_Sha256_p)(S, D);
         if (Transform_Sha256_is_vectorized)
             RESTORE_VECTOR_REGISTERS();
@@ -330,8 +332,10 @@ static int InitSha256(wc_Sha256* sha256)
 
     static WC_INLINE int inline_XTRANSFORM_LEN(wc_Sha256* S, const byte* D, word32 L) {
         int ret;
-        if (Transform_Sha256_is_vectorized)
-            SAVE_VECTOR_REGISTERS();
+        if (Transform_Sha256_is_vectorized) {
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
+        }
         ret = (*Transform_Sha256_Len_p)(S, D, L);
         if (Transform_Sha256_is_vectorized)
             RESTORE_VECTOR_REGISTERS();

+ 8 - 4
wolfcrypt/src/sha512.c

@@ -448,8 +448,10 @@ static int InitSha512_256(wc_Sha512* sha512)
 
     static WC_INLINE int Transform_Sha512(wc_Sha512 *sha512) {
         int ret;
-        if (Transform_Sha512_is_vectorized)
-            SAVE_VECTOR_REGISTERS();
+        if (Transform_Sha512_is_vectorized) {
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
+        }
         ret = (*Transform_Sha512_p)(sha512);
         if (Transform_Sha512_is_vectorized)
             RESTORE_VECTOR_REGISTERS();
@@ -457,8 +459,10 @@ static int InitSha512_256(wc_Sha512* sha512)
     }
     static WC_INLINE int Transform_Sha512_Len(wc_Sha512 *sha512, word32 len) {
         int ret;
-        if (Transform_Sha512_is_vectorized)
-            SAVE_VECTOR_REGISTERS();
+        if (Transform_Sha512_is_vectorized) {
+            if (SAVE_VECTOR_REGISTERS() != 0)
+                return BAD_STATE_E;
+        }
         ret = (*Transform_Sha512_Len_p)(sha512, len);
         if (Transform_Sha512_is_vectorized)
             RESTORE_VECTOR_REGISTERS();

+ 13 - 0
wolfcrypt/src/wc_port.c

@@ -157,6 +157,15 @@ int wolfCrypt_Init(void)
         }
     #endif
 
+    #if defined(WOLFSSL_LINUXKM_SIMD_X86) \
+        && defined(WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED)
+        ret = allocate_wolfcrypt_irq_fpu_states();
+        if (ret != 0) {
+            WOLFSSL_MSG("allocate_wolfcrypt_irq_fpu_states failed");
+            return ret;
+        }
+    #endif
+
     #if WOLFSSL_CRYPT_HW_MUTEX
         /* If crypto hardware mutex protection is enabled, then initialize it */
         ret = wolfSSL_CryptHwMutexInit();
@@ -356,6 +365,10 @@ int wolfCrypt_Cleanup(void)
         rpcmem_deinit();
         wolfSSL_CleanupHandle();
     #endif
+    #if defined(WOLFSSL_LINUXKM_SIMD_X86) \
+        && defined(WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED)
+        free_wolfcrypt_irq_fpu_states();
+    #endif
     }
 
     return ret;

+ 10 - 0
wolfcrypt/test/test.c

@@ -9084,7 +9084,12 @@ WOLFSSL_TEST_SUBROUTINE int aes_test(void)
         ret = wc_AesSetKey(enc, niKey, sizeof(niKey), cipher, AES_ENCRYPTION);
         if (ret != 0)
             ERROR_OUT(-5943, out);
+#ifdef WOLFSSL_LINUXKM
+        if (wc_AesEncryptDirect(enc, cipher, niPlain) != 0)
+            ERROR_OUT(-5950, out);
+#else
         wc_AesEncryptDirect(enc, cipher, niPlain);
+#endif
         if (XMEMCMP(cipher, niCipher, AES_BLOCK_SIZE) != 0)
             ERROR_OUT(-5944, out);
 
@@ -9092,7 +9097,12 @@ WOLFSSL_TEST_SUBROUTINE int aes_test(void)
         ret = wc_AesSetKey(dec, niKey, sizeof(niKey), plain, AES_DECRYPTION);
         if (ret != 0)
             ERROR_OUT(-5945, out);
+#ifdef WOLFSSL_LINUXKM
+        if (wc_AesDecryptDirect(dec, plain, niCipher) != 0)
+            ERROR_OUT(-5951, out);
+#else
         wc_AesDecryptDirect(dec, plain, niCipher);
+#endif
         if (XMEMCMP(plain, niPlain, AES_BLOCK_SIZE) != 0)
             ERROR_OUT(-5946, out);
     }

+ 5 - 0
wolfssl/wolfcrypt/aes.h

@@ -351,8 +351,13 @@ WOLFSSL_API int wc_AesEcbDecrypt(Aes* aes, byte* out,
 #endif
 /* AES-DIRECT */
 #if defined(WOLFSSL_AES_DIRECT)
+#ifdef WOLFSSL_LINUXKM
+ WOLFSSL_API __must_check int wc_AesEncryptDirect(Aes* aes, byte* out, const byte* in);
+ WOLFSSL_API __must_check int wc_AesDecryptDirect(Aes* aes, byte* out, const byte* in);
+#else
  WOLFSSL_API void wc_AesEncryptDirect(Aes* aes, byte* out, const byte* in);
  WOLFSSL_API void wc_AesDecryptDirect(Aes* aes, byte* out, const byte* in);
+#endif
  WOLFSSL_API int  wc_AesSetKeyDirect(Aes* aes, const byte* key, word32 len,
                                 const byte* iv, int dir);
 #endif

+ 113 - 14
wolfssl/wolfcrypt/wc_port.h

@@ -122,29 +122,45 @@
     #endif
     #include <linux/net.h>
     #include <linux/slab.h>
-    #if defined(WOLFSSL_AESNI) || defined(USE_INTEL_SPEEDUP)
+    #if defined(WOLFSSL_AESNI) || defined(USE_INTEL_SPEEDUP) || defined(WOLFSSL_SP_X86_64_ASM)
+        #ifndef CONFIG_X86
+            #error X86 SIMD extensions requested, but CONFIG_X86 is not set.
+        #endif
+        #define WOLFSSL_LINUXKM_SIMD
+        #define WOLFSSL_LINUXKM_SIMD_X86
         #if LINUX_VERSION_CODE < KERNEL_VERSION(4, 0, 0)
             #include <asm/i387.h>
         #else
             #include <asm/simd.h>
         #endif
+        #include <asm/fpu/internal.h>
         #ifndef SAVE_VECTOR_REGISTERS
-            #define SAVE_VECTOR_REGISTERS() kernel_fpu_begin()
+            #define SAVE_VECTOR_REGISTERS() save_vector_registers_x86()
         #endif
         #ifndef RESTORE_VECTOR_REGISTERS
-            #define RESTORE_VECTOR_REGISTERS() kernel_fpu_end()
+            #define RESTORE_VECTOR_REGISTERS() restore_vector_registers_x86()
+        #endif
+    #elif defined(WOLFSSL_ARMASM) || defined(WOLFSSL_SP_ARM32_ASM) || \
+          defined(WOLFSSL_SP_ARM64_ASM) || defined(WOLFSSL_SP_ARM_THUMB_ASM) ||\
+          defined(WOLFSSL_SP_ARM_CORTEX_M_ASM)
+        #if !defined(CONFIG_ARM) && !defined(CONFIG_ARM64)
+            #error ARM SIMD extensions requested, but CONFIG_ARM* is not set.
         #endif
-    #elif defined(WOLFSSL_ARMASM)
+        #define WOLFSSL_LINUXKM_SIMD
+        #define WOLFSSL_LINUXKM_SIMD_ARM
         #include <asm/fpsimd.h>
         #ifndef SAVE_VECTOR_REGISTERS
-            #define SAVE_VECTOR_REGISTERS() ({ preempt_disable(); fpsimd_preserve_current_state(); })
+            #define SAVE_VECTOR_REGISTERS() save_vector_registers_arm()
         #endif
         #ifndef RESTORE_VECTOR_REGISTERS
-            #define RESTORE_VECTOR_REGISTERS() ({ fpsimd_restore_current_state(); preempt_enable(); })
+            #define RESTORE_VECTOR_REGISTERS() restore_vector_registers_arm()
         #endif
     #else
+        #ifndef WOLFSSL_NO_ASM
+            #define WOLFSSL_NO_ASM
+        #endif
         #ifndef SAVE_VECTOR_REGISTERS
-            #define SAVE_VECTOR_REGISTERS() ({})
+            #define SAVE_VECTOR_REGISTERS() 0
         #endif
         #ifndef RESTORE_VECTOR_REGISTERS
             #define RESTORE_VECTOR_REGISTERS() ({})
@@ -247,10 +263,13 @@
         typeof(kmalloc_order_trace) *kmalloc_order_trace;
 
         typeof(get_random_bytes) *get_random_bytes;
-        typeof(ktime_get_real_seconds) *ktime_get_real_seconds;
-        typeof(ktime_get_with_offset) *ktime_get_with_offset;
+        typeof(ktime_get_coarse_real_ts64) *ktime_get_coarse_real_ts64;
+
+        struct task_struct *(*get_current)(void);
+        int (*preempt_count)(void);
 
-        #if defined(WOLFSSL_AESNI) || defined(USE_INTEL_SPEEDUP)
+        #ifdef WOLFSSL_LINUXKM_SIMD_X86
+        typeof(irq_fpu_usable) *irq_fpu_usable;
         /* kernel_fpu_begin() replaced by kernel_fpu_begin_mask() in commit e4512289,
          * released in kernel 5.11, backported to 5.4.93
          */
@@ -260,7 +279,21 @@
             typeof(kernel_fpu_begin) *kernel_fpu_begin;
         #endif
         typeof(kernel_fpu_end) *kernel_fpu_end;
+
+        #ifdef WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED
+        #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0)
+            typeof(copy_fpregs_to_fpstate) *copy_fpregs_to_fpstate;
+            typeof(copy_kernel_to_fpregs) *copy_kernel_to_fpregs;
+        #else
+            typeof(save_fpregs_to_fpstate) *save_fpregs_to_fpstate;
+            typeof(__restore_fpregs_from_fpstate) *__restore_fpregs_from_fpstate;
+            typeof(xfeatures_mask_all) *xfeatures_mask_all;
         #endif
+        typeof(cpu_number) *cpu_number;
+        typeof(nr_cpu_ids) *nr_cpu_ids;
+        #endif /* WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */
+
+        #endif /* WOLFSSL_LINUXKM_SIMD_X86 */
 
         typeof(__mutex_init) *__mutex_init;
         typeof(mutex_lock) *mutex_lock;
@@ -326,6 +359,7 @@
     #define kfree (wolfssl_linuxkm_get_pie_redirect_table()->kfree)
     #define ksize (wolfssl_linuxkm_get_pie_redirect_table()->ksize)
     #define krealloc (wolfssl_linuxkm_get_pie_redirect_table()->krealloc)
+    #define kzalloc(size, flags) kmalloc((size), (flags) | __GFP_ZERO)
     #ifdef HAVE_KVMALLOC
         #define kvmalloc_node (wolfssl_linuxkm_get_pie_redirect_table()->kvmalloc_node)
         #define kvfree (wolfssl_linuxkm_get_pie_redirect_table()->kvfree)
@@ -335,16 +369,33 @@
     #define kmalloc_order_trace (wolfssl_linuxkm_get_pie_redirect_table()->kmalloc_order_trace)
 
     #define get_random_bytes (wolfssl_linuxkm_get_pie_redirect_table()->get_random_bytes)
-    #define ktime_get_real_seconds (wolfssl_linuxkm_get_pie_redirect_table()->ktime_get_real_seconds)
-    #define ktime_get_with_offset (wolfssl_linuxkm_get_pie_redirect_table()->ktime_get_with_offset)
+    #define ktime_get_coarse_real_ts64 (wolfssl_linuxkm_get_pie_redirect_table()->ktime_get_coarse_real_ts64)
+
+    #undef get_current
+    #define get_current (wolfssl_linuxkm_get_pie_redirect_table()->get_current)
+    #undef preempt_count
+    #define preempt_count (wolfssl_linuxkm_get_pie_redirect_table()->preempt_count)
 
-    #if defined(WOLFSSL_AESNI) || defined(USE_INTEL_SPEEDUP)
+    #ifdef WOLFSSL_LINUXKM_SIMD_X86
+        #define irq_fpu_usable (wolfssl_linuxkm_get_pie_redirect_table()->irq_fpu_usable)
         #ifdef kernel_fpu_begin
             #define kernel_fpu_begin_mask (wolfssl_linuxkm_get_pie_redirect_table()->kernel_fpu_begin_mask)
         #else
             #define kernel_fpu_begin (wolfssl_linuxkm_get_pie_redirect_table()->kernel_fpu_begin)
         #endif
         #define kernel_fpu_end (wolfssl_linuxkm_get_pie_redirect_table()->kernel_fpu_end)
+        #ifdef WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED
+            #if LINUX_VERSION_CODE < KERNEL_VERSION(5, 14, 0)
+                #define copy_fpregs_to_fpstate (wolfssl_linuxkm_get_pie_redirect_table()->copy_fpregs_to_fpstate)
+                #define copy_kernel_to_fpregs (wolfssl_linuxkm_get_pie_redirect_table()->copy_kernel_to_fpregs)
+            #else
+                #define save_fpregs_to_fpstate (wolfssl_linuxkm_get_pie_redirect_table()->save_fpregs_to_fpstate)
+                #define __restore_fpregs_from_fpstate (wolfssl_linuxkm_get_pie_redirect_table()->__restore_fpregs_from_fpstate)
+                #define xfeatures_mask_all (*(wolfssl_linuxkm_get_pie_redirect_table()->xfeatures_mask_all))
+            #endif
+            #define cpu_number (*(wolfssl_linuxkm_get_pie_redirect_table()->cpu_number))
+            #define nr_cpu_ids (*(wolfssl_linuxkm_get_pie_redirect_table()->nr_cpu_ids))
+        #endif /* WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */
     #endif
 
     #define __mutex_init (wolfssl_linuxkm_get_pie_redirect_table()->__mutex_init)
@@ -371,6 +422,54 @@
 
     #endif /* USE_WOLFSSL_LINUXKM_PIE_REDIRECT_TABLE */
 
+#ifdef WOLFSSL_LINUXKM_SIMD
+
+#ifdef WOLFSSL_LINUXKM_SIMD_X86
+
+#ifdef WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED
+    extern __must_check int allocate_wolfcrypt_irq_fpu_states(void);
+    extern void free_wolfcrypt_irq_fpu_states(void);
+    extern __must_check int save_vector_registers_x86(void);
+    extern void restore_vector_registers_x86(void);
+#else /* !WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */
+    /* Non-IRQ-capable variant: wraps kernel_fpu_begin(), refusing (EPERM)
+     * when the FPU is not usable in the current context.  Returns 0 on
+     * success; the matching restore is restore_vector_registers_x86().
+     */
+    static __must_check inline int save_vector_registers_x86(void) {
+        preempt_disable();
+        if (! irq_fpu_usable()) {
+            preempt_enable();
+            return EPERM;
+        } else {
+            kernel_fpu_begin();
+            preempt_enable(); /* kernel_fpu_begin() does its own preempt_disable().  decrement ours. */
+            return 0;
+        }
+    }
+    /* Counterpart to the inline save_vector_registers_x86() above. */
+    static inline void restore_vector_registers_x86(void) {
+        kernel_fpu_end();
+    }
+#endif /* !WOLFSSL_LINUXKM_SIMD_X86_IRQ_ALLOWED */
+
+#elif defined(WOLFSSL_LINUXKM_SIMD_ARM)
+
+    /* Disable preemption and preserve the ARM FP/SIMD state.  Returns 0 on
+     * success (preemption left disabled until
+     * restore_vector_registers_arm()), or EPERM when SIMD is not usable in
+     * the current context.
+     * NOTE(review): may_use_simd() is declared in <asm/simd.h>, which the
+     * ARM include path above does not pull in -- confirm it is reachable.
+     */
+    static __must_check inline int save_vector_registers_arm(void) {
+        preempt_disable();
+        if (! may_use_simd()) {
+            preempt_enable();
+            return EPERM;
+        } else {
+            fpsimd_preserve_current_state();
+            return 0;
+        }
+    }
+    /* Counterpart to save_vector_registers_arm(): restore FP/SIMD state and
+     * drop the preemption disable taken by the save. */
+    static inline void restore_vector_registers_arm(void) {
+        fpsimd_restore_current_state();
+        preempt_enable();
+
+#endif
+
+#endif /* WOLFSSL_LINUXKM_SIMD */
+
+
     /* Linux headers define these using C expressions, but we need
      * them to be evaluable by the preprocessor, for use in sp_int.h.
      */
@@ -484,7 +583,7 @@
 #else /* ! WOLFSSL_LINUXKM */
 
     #ifndef SAVE_VECTOR_REGISTERS
-        #define SAVE_VECTOR_REGISTERS() do{}while(0)
+        #define SAVE_VECTOR_REGISTERS() 0
     #endif
     #ifndef RESTORE_VECTOR_REGISTERS
         #define RESTORE_VECTOR_REGISTERS() do{}while(0)