Explorar o código

dtls13: support KeyUpdate messages

Marco Oliverio %!s(int64=2) %!d(string=hai) anos
pai
achega
dfc9873c0f
Modificáronse 3 ficheiros con 127 adicións e 8 borrados
  1. 48 0
      src/dtls13.c
  2. 78 8
      src/tls13.c
  3. 1 0
      wolfssl/internal.h

+ 48 - 0
src/dtls13.c

@@ -1551,6 +1551,9 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize,
     maxFrag = wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE);
     maxLen = length;
 
+    if (handshakeType == key_update)
+        ssl->dtls13WaitKeyUpdateAck = 1;
+
     if (maxLen < maxFrag) {
         ret = Dtls13SendOneFragmentRtx(ssl, handshakeType, outputSize, message,
             length, hashOutput);
@@ -2106,6 +2109,26 @@ static int Dtls13RtxIsTrackedByRn(const Dtls13RtxRecord* r, w64wrapper epoch,
     return 0;
 }
 
+static int Dtls13KeyUpdateAckReceived(WOLFSSL* ssl)
+{
+    int ret;
+    w64Increment(&ssl->dtls13Epoch);
+
+    /* Epoch wrapped up */
+    if (w64IsZero(ssl->dtls13Epoch))
+        return BAD_STATE_E;
+
+    ret = DeriveTls13Keys(ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1);
+    if (ret != 0)
+        return ret;
+
+    ret = Dtls13NewEpoch(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY);
+    if (ret != 0)
+        return ret;
+
+    return Dtls13SetEpochKeys(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY);
+}
+
 #ifdef WOLFSSL_DEBUG_TLS
 static void Dtls13PrintRtxRecord(Dtls13RtxRecord* r)
 {
@@ -2200,12 +2223,27 @@ int Dtls13RtxTimeout(WOLFSSL* ssl)
     return Dtls13RtxSendBuffered(ssl);
 }
 
+static int Dtls13RtxHasKeyUpdateBuffered(WOLFSSL* ssl)
+{
+    Dtls13RtxRecord* r = ssl->dtls13Rtx.rtxRecords;
+
+    while (r != NULL) {
+        if (r->handshakeType == key_update)
+            return 1;
+
+        r = r->next;
+    }
+
+    return 0;
+}
+
 int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
     word32* processedSize)
 {
     const byte* ackMessage;
     w64wrapper epoch, seq;
     word16 length;
+    int ret;
     int i;
 
     if (inputSize < OPAQUE16_LEN)
@@ -2234,6 +2272,16 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize,
         ssl->options.serverState = SERVER_FINISHED_ACKED;
     }
 
+    if (ssl->dtls13WaitKeyUpdateAck) {
+        if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) {
+            /* we removed the KeyUpdate message because it was ACKed */
+            ssl->dtls13WaitKeyUpdateAck = 0;
+            ret = Dtls13KeyUpdateAckReceived(ssl);
+            if (ret != 0)
+                return ret;
+        }
+    }
+
     *processedSize = length + OPAQUE16_LEN;
 
     /* After the handshake, not retransmitting here may incur in some extra time

+ 78 - 8
src/tls13.c

@@ -7896,6 +7896,11 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl)
     WOLFSSL_START(WC_FUNC_KEY_UPDATE_SEND);
     WOLFSSL_ENTER("SendTls13KeyUpdate");
 
+#ifdef WOLFSSL_DTLS13
+    if (ssl->options.dtls)
+        i = Dtls13GetRlHeaderLength(1) + DTLS_HANDSHAKE_HEADER_SZ;
+#endif /* WOLFSSL_DTLS13 */
+
     outputSz = OPAQUE8_LEN + MAX_MSG_EXTRA;
     /* Check buffers are big enough and grow if needed. */
     if ((ret = CheckAvailableSize(ssl, outputSz)) != 0)
@@ -7906,6 +7911,11 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl)
              ssl->buffers.outputBuffer.length;
     input = output + RECORD_HEADER_SZ;
 
+#ifdef WOLFSSL_DTLS13
+    if (ssl->options.dtls)
+        input = output + Dtls13GetRlHeaderLength(1);
+#endif /* WOLFSSL_DTLS13 */
+
     AddTls13Headers(output, OPAQUE8_LEN, key_update, ssl);
 
     /* If:
@@ -7918,6 +7928,15 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl)
     /* Sent response, no longer need to respond. */
     ssl->keys.keyUpdateRespond = 0;
 
+#ifdef WOLFSSL_DTLS13
+    if (ssl->options.dtls) {
+        ret = Dtls13HandshakeSend(ssl, output, outputSz,
+            OPAQUE8_LEN + Dtls13GetRlHeaderLength(1) + DTLS_HANDSHAKE_HEADER_SZ,
+            key_update, 0);
+    }
+    else {
+#endif /* WOLFSSL_DTLS13 */
+
     /* This message is always encrypted. */
     sendSz = BuildTls13Message(ssl, output, outputSz, input,
                                headerSz + OPAQUE8_LEN, handshake, 0, 0, 0);
@@ -7935,15 +7954,26 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl)
     ssl->buffers.outputBuffer.length += sendSz;
 
     ret = SendBuffered(ssl);
+
+
     if (ret != 0 && ret != WANT_WRITE)
         return ret;
+#ifdef WOLFSSL_DTLS13
+    }
+#endif /* WOLFSSL_DTLS13 */
+
+    /* In DTLS we must wait for the ack before setting up the new keys */
+    if (!ssl->options.dtls) {
+
+        /* Future traffic uses new encryption keys. */
+        if ((ret = DeriveTls13Keys(
+                       ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1))
+            != 0)
+            return ret;
+        if ((ret = SetKeysSide(ssl, ENCRYPT_SIDE_ONLY)) != 0)
+            return ret;
+    }
 
-    /* Future traffic uses new encryption keys. */
-    if ((ret = DeriveTls13Keys(ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1))
-                                                                           != 0)
-        return ret;
-    if ((ret = SetKeysSide(ssl, ENCRYPT_SIDE_ONLY)) != 0)
-        return ret;
 
     WOLFSSL_LEAVE("SendTls13KeyUpdate", ret);
     WOLFSSL_END(WC_FUNC_KEY_UPDATE_SEND);
@@ -8001,8 +8031,37 @@ static int DoTls13KeyUpdate(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
     if ((ret = SetKeysSide(ssl, DECRYPT_SIDE_ONLY)) != 0)
         return ret;
 
-    if (ssl->keys.keyUpdateRespond)
+#ifdef WOLFSSL_DTLS13
+    if (ssl->options.dtls) {
+        w64Increment(&ssl->dtls13PeerEpoch);
+
+        ret = Dtls13NewEpoch(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY);
+        if (ret != 0)
+            return ret;
+
+        ret = Dtls13SetEpochKeys(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY);
+        if (ret != 0)
+            return ret;
+    }
+#endif /* WOLFSSL_DTLS13 */
+
+    if (ssl->keys.keyUpdateRespond) {
+
+#ifdef WOLFSSL_DTLS13
+        /* we already sent a keyUpdate (either in response to a previous
+           KeyUpdate or initiated by the application) and we are waiting for the
+           ack. We can't send a new KeyUpdate right away but to honor the RFC we
+           should send another KeyUpdate after the one in-flight is acked. We
+           don't do that as it looks redundant, it will make the code more
+           complex and I don't see a good use case for that. */
+        if (ssl->options.dtls && ssl->dtls13WaitKeyUpdateAck) {
+            ssl->keys.keyUpdateRespond = 0;
+            return 0;
+        }
+#endif /* WOLFSSL_DTLS13 */
+
         return SendTls13KeyUpdate(ssl);
+    }
 
     WOLFSSL_LEAVE("DoTls13KeyUpdate", ret);
     WOLFSSL_END(WC_FUNC_KEY_UPDATE_DO);
@@ -9029,7 +9088,7 @@ int DoTls13HandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx,
         break;
 
     case key_update:
-        WOLFSSL_MSG("processing finished");
+        WOLFSSL_MSG("processing key update");
         ret = DoTls13KeyUpdate(ssl, input, inOutIdx, size);
         break;
 
@@ -9894,6 +9953,17 @@ int wolfSSL_update_keys(WOLFSSL* ssl)
     if (ssl == NULL || !IsAtLeastTLSv1_3(ssl->version))
         return BAD_FUNC_ARG;
 
+#ifdef WOLFSSL_DTLS13
+    /* we are already waiting for the ack of a sent key update message. We can't
+       send another one before receiving its ack. Either wolfSSL_update_keys()
+       was invoked multiple times over a short period of time or we replied to a
+       KeyUpdate with update request. We'll just ignore sending this
+       KeyUpdate. */
+    /* TODO: add WOLFSSL_ERROR_ALREADY_IN_PROGRESS type of error here */
+    if (ssl->options.dtls && ssl->dtls13WaitKeyUpdateAck)
+            return WOLFSSL_SUCCESS;
+#endif /* WOLFSSL_DTLS13 */
+
     ret = SendTls13KeyUpdate(ssl);
     if (ret == WANT_WRITE)
         ret = WOLFSSL_ERROR_WANT_WRITE;

+ 1 - 0
wolfssl/internal.h

@@ -4635,6 +4635,7 @@ struct WOLFSSL {
     byte dtls13SendingFragments:1;
     byte dtls13SendingAckOrRtx:1;
     byte dtls13FastTimeout:1;
+    byte dtls13WaitKeyUpdateAck:1;
     word32 dtls13MessageLength;
     word32 dtls13FragOffset;
     byte dtls13FragHandshakeType;