Browse Source

Merge pull request #5609 from philljj/master

Fixes DTLS 1.3 client use-after-free error
David Garske 1 year ago
parent
commit
a36604079b
1 changed files with 31 additions and 29 deletions
  1. 31 29
      src/internal.c

+ 31 - 29
src/internal.c

@@ -10125,8 +10125,8 @@ int CheckAvailableSize(WOLFSSL *ssl, int size)
 
 #ifdef WOLFSSL_DTLS13
 static int GetInputData(WOLFSSL *ssl, word32 size);
-static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
-    word32* inOutIdx, RecordLayerHeader* rh, word16* size)
+static int GetDtls13RecordHeader(WOLFSSL* ssl, word32* inOutIdx,
+    RecordLayerHeader* rh, word16* size)
 {
 
     Dtls13UnifiedHdrInfo hdrInfo;
@@ -10140,7 +10140,7 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
     if (readSize < DTLS_UNIFIED_HEADER_MIN_SZ)
         return BUFFER_ERROR;
 
-    epochBits = *(input + *inOutIdx) & EE_MASK;
+    epochBits = *(ssl->buffers.inputBuffer.buffer + *inOutIdx) & EE_MASK;
     ret = Dtls13ReconstructEpochNumber(ssl, epochBits, &epochNumber);
     if (ret != 0)
         return ret;
@@ -10172,7 +10172,7 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
     }
 
     ret = Dtls13GetUnifiedHeaderSize(ssl,
-        *(input+*inOutIdx), &ssl->dtls13CurRlLength);
+        *(ssl->buffers.inputBuffer.buffer+*inOutIdx), &ssl->dtls13CurRlLength);
     if (ret != 0)
         return ret;
 
@@ -10185,7 +10185,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
             return ret;
     }
 
-    ret = Dtls13ParseUnifiedRecordLayer(ssl, input + *inOutIdx, (word16)readSize,
+    ret = Dtls13ParseUnifiedRecordLayer(ssl,
+        ssl->buffers.inputBuffer.buffer + *inOutIdx, (word16)readSize,
         &hdrInfo);
 
     if (ret != 0)
@@ -10212,7 +10213,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
                    ssl->keys.curSeq);
 #endif /* WOLFSSL_DEBUG_TLS */
 
-    XMEMCPY(ssl->dtls13CurRL, input + *inOutIdx, ssl->dtls13CurRlLength);
+    XMEMCPY(ssl->dtls13CurRL, ssl->buffers.inputBuffer.buffer + *inOutIdx,
+            ssl->dtls13CurRlLength);
     *inOutIdx += ssl->dtls13CurRlLength;
 
     return 0;
@@ -10221,14 +10223,14 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
 #endif /* WOLFSSL_DTLS13 */
 
 #ifdef WOLFSSL_DTLS
-static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
-    word32* inOutIdx, RecordLayerHeader* rh, word16* size)
+static int GetDtlsRecordHeader(WOLFSSL* ssl, word32* inOutIdx,
+    RecordLayerHeader* rh, word16* size)
 {
 
 #ifdef HAVE_FUZZER
     if (ssl->fuzzerCb)
-        ssl->fuzzerCb(ssl, input + *inOutIdx, DTLS_RECORD_HEADER_SZ,
-                       FUZZ_HEAD, ssl->fuzzerCtx);
+        ssl->fuzzerCb(ssl, ssl->buffers.inputBuffer.buffer + *inOutIdx,
+                      DTLS_RECORD_HEADER_SZ, FUZZ_HEAD, ssl->fuzzerCtx);
 #endif
 
 #ifdef WOLFSSL_DTLS13
@@ -10237,11 +10239,11 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
 
     read_size = ssl->buffers.inputBuffer.length - *inOutIdx;
 
-    if (Dtls13IsUnifiedHeader(*(input + *inOutIdx))) {
+    if (Dtls13IsUnifiedHeader(*(ssl->buffers.inputBuffer.buffer + *inOutIdx))) {
 
         /* version 1.3 already negotiated */
         if (ssl->options.tls1_3) {
-            ret = GetDtls13RecordHeader(ssl, input, inOutIdx, rh, size);
+            ret = GetDtls13RecordHeader(ssl, inOutIdx, rh, size);
             if (ret == 0 || ret != SEQUENCE_ERROR || ret != DTLS_CID_ERROR)
                 return ret;
         }
@@ -10269,9 +10271,10 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
 #endif /* WOLFSSL_DTLS13 */
 
     /* type and version in same spot */
-    XMEMCPY(rh, input + *inOutIdx, ENUM_LEN + VERSION_SZ);
+    XMEMCPY(rh, ssl->buffers.inputBuffer.buffer + *inOutIdx,
+            ENUM_LEN + VERSION_SZ);
     *inOutIdx += ENUM_LEN + VERSION_SZ;
-    ato16(input + *inOutIdx, &ssl->keys.curEpoch);
+    ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curEpoch);
 #ifdef WOLFSSL_DTLS13
     /* only non protected message can use the DTLSPlaintext record header */
     if (ssl->options.tls1_3 && ssl->keys.curEpoch != 0)
@@ -10285,14 +10288,14 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
     *inOutIdx += OPAQUE16_LEN;
     if (ssl->options.haveMcast) {
     #ifdef WOLFSSL_MULTICAST
-        ssl->keys.curPeerId = input[*inOutIdx];
-        ssl->keys.curSeq_hi = input[*inOutIdx+1];
+        ssl->keys.curPeerId = ssl->buffers.inputBuffer.buffer[*inOutIdx];
+        ssl->keys.curSeq_hi = ssl->buffers.inputBuffer.buffer[*inOutIdx+1];
     #endif
     }
     else
-        ato16(input + *inOutIdx, &ssl->keys.curSeq_hi);
+        ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curSeq_hi);
     *inOutIdx += OPAQUE16_LEN;
-    ato32(input + *inOutIdx, &ssl->keys.curSeq_lo);
+    ato32(ssl->buffers.inputBuffer.buffer + *inOutIdx, &ssl->keys.curSeq_lo);
     *inOutIdx += OPAQUE32_LEN;  /* advance past rest of seq */
 
 #ifdef WOLFSSL_DTLS13
@@ -10301,7 +10304,7 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
     ssl->keys.curSeq = w64From32(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo);
 #endif /* WOLFSSL_DTLS13 */
 
-    ato16(input + *inOutIdx, size);
+    ato16(ssl->buffers.inputBuffer.buffer + *inOutIdx, size);
     *inOutIdx += LENGTH_SZ;
 
     return 0;
@@ -10309,7 +10312,7 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
 #endif /* WOLFSSL_DTLS */
 
 /* do all verify and sanity checks on record header */
-static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
+static int GetRecordHeader(WOLFSSL* ssl, word32* inOutIdx,
                            RecordLayerHeader* rh, word16 *size)
 {
     byte tls12minor;
@@ -10326,16 +10329,16 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
     if (!ssl->options.dtls) {
 #ifdef HAVE_FUZZER
         if (ssl->fuzzerCb)
-            ssl->fuzzerCb(ssl, input + *inOutIdx, RECORD_HEADER_SZ, FUZZ_HEAD,
-                    ssl->fuzzerCtx);
+            ssl->fuzzerCb(ssl, ssl->buffers.inputBuffer.buffer + *inOutIdx,
+                          RECORD_HEADER_SZ, FUZZ_HEAD, ssl->fuzzerCtx);
 #endif
-        XMEMCPY(rh, input + *inOutIdx, RECORD_HEADER_SZ);
+        XMEMCPY(rh, ssl->buffers.inputBuffer.buffer + *inOutIdx, RECORD_HEADER_SZ);
         *inOutIdx += RECORD_HEADER_SZ;
         ato16(rh->length, size);
     }
     else {
 #ifdef WOLFSSL_DTLS
-        ret = GetDtlsRecordHeader(ssl, input, inOutIdx, rh, size);
+        ret = GetDtlsRecordHeader(ssl, inOutIdx, rh, size);
         if (ret != 0)
             return ret;
 #endif
@@ -10385,7 +10388,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
         else if (ssl->options.dtls && !ssl->options.handShakeDone) {
             /* we may have lost the ServerHello and this is a unified record
                before version been negotiated */
-            if (Dtls13IsUnifiedHeader(*input)) {
+            if (Dtls13IsUnifiedHeader(*ssl->buffers.inputBuffer.buffer)) {
                 return SEQUENCE_ERROR;
             }
         }
@@ -10438,8 +10441,8 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
         case no_type:
         default:
 #ifdef OPENSSL_ALL
-            {
-                char *method = (char*)input + start;
+            if (!ssl->options.dtls) {
+                char *method = (char*)ssl->buffers.inputBuffer.buffer + start;
                 /* Attempt to identify if this is a plain HTTP request.
                  * No size checks because this function assumes at least
                  * RECORD_HEADER_SZ size of data has been read which is
@@ -19056,8 +19059,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
              * header, decrypting the numbers inside
              * DtlsParseUnifiedRecordLayer(). This violates the const attribute
              * of the buffer parameter of GetRecordHeader() used here. */
-            ret = GetRecordHeader(ssl, ssl->buffers.inputBuffer.buffer,
-                                       &ssl->buffers.inputBuffer.idx,
+            ret = GetRecordHeader(ssl, &ssl->buffers.inputBuffer.idx,
                                        &ssl->curRL, &ssl->curSize);
 
 #ifdef WOLFSSL_DTLS