Browse Source

dtls13: change encryption keys dynamically based on the epoch

In DTLSv1.3, because of retransmission and reordering, we may need to encrypt or
decrypt records with older keys. As an example, if the server finished message
is lost, the server will need to retransmit that message using handshake traffic
keys, even if he already used the traffic0 ones (as, for example, to send
NewSessionTicket just after the finished message).

This commit implements a way to save the key bound to a DTLS epoch and setting
the right key/epoch when needed.
Marco Oliverio 2 years ago
parent
commit
2696c3cdd3
3 changed files with 282 additions and 0 deletions
  1. 224 0
      src/dtls13.c
  2. 9 0
      src/internal.c
  3. 49 0
      wolfssl/internal.h

+ 224 - 0
src/dtls13.c

@@ -159,6 +159,230 @@ static int Dtls13InitChaChaCipher(RecordNumberCiphers* c, byte* key,
 }
 #endif /* HAVE_CHACHA */
 
+struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl, w64wrapper epochNumber)
+{
+    Dtls13Epoch* e;
+    int i;
+
+    for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) {
+        e = &ssl->dtls13Epochs[i];
+        if (w64Equal(e->epochNumber, epochNumber) && e->isValid)
+            return e;
+    }
+
+    return NULL;
+}
+
+static void Dtls13EpochCopyKeys(WOLFSSL* ssl, Dtls13Epoch* e, Keys* k, int side)
+{
+    byte clientWrite, serverWrite;
+    byte enc, dec;
+
+    WOLFSSL_ENTER("Dtls13SetEpochKeys");
+
+    clientWrite = serverWrite = 0;
+    enc = dec = 0;
+    switch (side) {
+
+    case ENCRYPT_SIDE_ONLY:
+        if (ssl->options.side == WOLFSSL_CLIENT_END)
+            clientWrite = 1;
+        if (ssl->options.side == WOLFSSL_SERVER_END)
+            serverWrite = 1;
+        enc = 1;
+        break;
+
+    case DECRYPT_SIDE_ONLY:
+        if (ssl->options.side == WOLFSSL_CLIENT_END)
+            serverWrite = 1;
+        if (ssl->options.side == WOLFSSL_SERVER_END)
+            clientWrite = 1;
+        dec = 1;
+        break;
+
+    case ENCRYPT_AND_DECRYPT_SIDE:
+        clientWrite = serverWrite = 1;
+        enc = dec = 1;
+        break;
+    }
+
+    if (clientWrite) {
+        XMEMCPY(e->client_write_key, k->client_write_key,
+            sizeof(e->client_write_key));
+
+        XMEMCPY(e->client_write_IV, k->client_write_IV,
+            sizeof(e->client_write_IV));
+
+        XMEMCPY(e->client_sn_key, k->client_sn_key, sizeof(e->client_sn_key));
+    }
+
+    if (serverWrite) {
+        XMEMCPY(e->server_write_key, k->server_write_key,
+            sizeof(e->server_write_key));
+        XMEMCPY(e->server_write_IV, k->server_write_IV,
+            sizeof(e->server_write_IV));
+        XMEMCPY(e->server_sn_key, k->server_sn_key, sizeof(e->server_sn_key));
+    }
+
+    if (enc)
+        XMEMCPY(e->aead_enc_imp_IV, k->aead_enc_imp_IV,
+            sizeof(e->aead_enc_imp_IV));
+
+    if (dec)
+        XMEMCPY(e->aead_dec_imp_IV, k->aead_dec_imp_IV,
+            sizeof(e->aead_dec_imp_IV));
+}
+
+static Dtls13Epoch* Dtls13NewEpochSlot(WOLFSSL* ssl)
+{
+    Dtls13Epoch *e, *oldest = NULL;
+    w64wrapper oldestNumber;
+    int i;
+
+    /* FIXME: add max function */
+    oldestNumber = w64From32((word32)-1, (word32)-1);
+    oldest = NULL;
+
+    for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) {
+        e = &ssl->dtls13Epochs[i];
+        if (!e->isValid)
+            return e;
+
+        if (!w64Equal(e->epochNumber, ssl->dtls13Epoch) &&
+            !w64Equal(e->epochNumber, ssl->dtls13PeerEpoch) &&
+            w64LT(e->epochNumber, oldestNumber))
+            oldest = e;
+    }
+
+    if (oldest == NULL)
+        return NULL;
+
+    e = oldest;
+
+#ifdef WOLFSSL_DEBUG_TLS
+    WOLFSSL_MSG_EX("Delete epoch: %d", e->epochNumber);
+#endif /* WOLFSSL_DEBUG_TLS */
+
+    XMEMSET(e, 0, sizeof(*e));
+
+    return e;
+}
+
+int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, int side)
+{
+    Dtls13Epoch* e;
+
+#ifdef WOLFSSL_DEBUG_TLS
+    WOLFSSL_MSG_EX("New epoch: %d", w64GetLow32(epochNumber));
+#endif /* WOLFSSL_DEBUG_TLS */
+
+    e = Dtls13GetEpoch(ssl, epochNumber);
+    if (e == NULL) {
+        e = Dtls13NewEpochSlot(ssl);
+        if (e == NULL)
+            return BAD_STATE_E;
+    }
+
+    Dtls13EpochCopyKeys(ssl, e, &ssl->keys, side);
+
+    if (!e->isValid) {
+        /* fresh epoch, initialize fields */
+        e->epochNumber = epochNumber;
+        e->isValid = 1;
+        e->side = side;
+    }
+    else if (e->side != side) {
+        /* epoch used for the other side already. update side */
+        e->side = ENCRYPT_AND_DECRYPT_SIDE;
+    }
+
+    return 0;
+}
+
+int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber,
+    enum encrypt_side side)
+{
+    byte clientWrite, serverWrite;
+    Dtls13Epoch* e;
+    byte enc, dec;
+
+    WOLFSSL_ENTER("Dtls13SetEpochKeys");
+
+    clientWrite = serverWrite = 0;
+    enc = dec = 0;
+    switch (side) {
+
+    case ENCRYPT_SIDE_ONLY:
+        if (ssl->options.side == WOLFSSL_CLIENT_END)
+            clientWrite = 1;
+        if (ssl->options.side == WOLFSSL_SERVER_END)
+            serverWrite = 1;
+        enc = 1;
+        break;
+
+    case DECRYPT_SIDE_ONLY:
+        if (ssl->options.side == WOLFSSL_CLIENT_END)
+            serverWrite = 1;
+        if (ssl->options.side == WOLFSSL_SERVER_END)
+            clientWrite = 1;
+        dec = 1;
+        break;
+
+    case ENCRYPT_AND_DECRYPT_SIDE:
+        clientWrite = serverWrite = 1;
+        enc = dec = 1;
+        break;
+    }
+
+    e = Dtls13GetEpoch(ssl, epochNumber);
+    /* we don't have the requested key */
+    if (e == NULL)
+        return BAD_STATE_E;
+
+    if (e->side != ENCRYPT_AND_DECRYPT_SIDE && e->side != side)
+        return BAD_STATE_E;
+
+    if (enc)
+        ssl->dtls13EncryptEpoch = e;
+    if (dec)
+        ssl->dtls13DecryptEpoch = e;
+
+    /* epoch 0 has no key to copy */
+    if (w64IsZero(epochNumber))
+        return 0;
+
+    if (clientWrite) {
+        XMEMCPY(ssl->keys.client_write_key, e->client_write_key,
+            sizeof(ssl->keys.client_write_key));
+
+        XMEMCPY(ssl->keys.client_write_IV, e->client_write_IV,
+            sizeof(ssl->keys.client_write_IV));
+
+        XMEMCPY(ssl->keys.client_sn_key, e->client_sn_key,
+            sizeof(ssl->keys.client_sn_key));
+    }
+
+    if (serverWrite) {
+        XMEMCPY(ssl->keys.server_write_key, e->server_write_key,
+            sizeof(ssl->keys.server_write_key));
+
+        XMEMCPY(ssl->keys.server_write_IV, e->server_write_IV,
+            sizeof(ssl->keys.server_write_IV));
+
+        XMEMCPY(ssl->keys.server_sn_key, e->server_sn_key,
+            sizeof(ssl->keys.server_sn_key));
+    }
+
+    if (enc)
+        XMEMCPY(ssl->keys.aead_enc_imp_IV, e->aead_enc_imp_IV,
+            sizeof(ssl->keys.aead_enc_imp_IV));
+    if (dec)
+        XMEMCPY(ssl->keys.aead_dec_imp_IV, e->aead_dec_imp_IV,
+            sizeof(ssl->keys.aead_dec_imp_IV));
+
+    return SetKeysSide(ssl, side);
+}
+
 int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side)
 {
     RecordNumberCiphers* enc = NULL;

+ 9 - 0
src/internal.c

@@ -6762,6 +6762,15 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup)
     }
 #endif /* HAVE_SECURE_RENEGOTIATION */
 
+
+#ifdef WOLFSSL_DTLS13
+    /* setup 0 (un-protected) epoch */
+    ssl->dtls13Epochs[0].isValid = 1;
+    ssl->dtls13Epochs[0].side = ENCRYPT_AND_DECRYPT_SIDE;
+    ssl->dtls13EncryptEpoch = &ssl->dtls13Epochs[0];
+    ssl->dtls13DecryptEpoch = &ssl->dtls13Epochs[0];
+#endif /* WOLFSSL_DTLS13 */
+
     return 0;
 }
 

+ 49 - 0
wolfssl/internal.h

@@ -4317,6 +4317,43 @@ typedef enum EarlyDataState {
 } EarlyDataState;
 #endif
 
+#ifdef WOLFSSL_DTLS13
+
+enum  {
+    DTLS13_EPOCH_EARLYDATA = 1,
+    DTLS13_EPOCH_HANDSHAKE = 2,
+    DTLS13_EPOCH_TRAFFIC0 = 3
+};
+
+typedef struct Dtls13Epoch {
+    w64wrapper epochNumber;
+
+    w64wrapper nextSeqNumber;
+    w64wrapper nextPeerSeqNumber;
+
+    word32 window[WOLFSSL_DTLS_WINDOW_WORDS];
+
+    /* key material for the epoch */
+    byte client_write_key[MAX_SYM_KEY_SIZE];
+    byte server_write_key[MAX_SYM_KEY_SIZE];
+    byte client_write_IV[MAX_WRITE_IV_SZ];
+    byte server_write_IV[MAX_WRITE_IV_SZ];
+
+    byte aead_exp_IV[AEAD_MAX_EXP_SZ];
+    byte aead_enc_imp_IV[AEAD_MAX_IMP_SZ];
+    byte aead_dec_imp_IV[AEAD_MAX_IMP_SZ];
+
+    byte client_sn_key[MAX_SYM_KEY_SIZE];
+    byte server_sn_key[MAX_SYM_KEY_SIZE];
+
+    byte isValid;
+    byte side;
+} Dtls13Epoch;
+
+#define DTLS13_EPOCH_SIZE 3
+
+#endif /* WOLFSSL_DTLS13 */
+
 /* wolfSSL ssl type */
 struct WOLFSSL {
     WOLFSSL_CTX*    ctx;
@@ -4514,6 +4551,12 @@ struct WOLFSSL {
 #ifdef WOLFSSL_DTLS13
     RecordNumberCiphers dtlsRecordNumberEncrypt;
     RecordNumberCiphers dtlsRecordNumberDecrypt;
+    Dtls13Epoch dtls13Epochs[DTLS13_EPOCH_SIZE];
+    Dtls13Epoch *dtls13EncryptEpoch;
+    Dtls13Epoch *dtls13DecryptEpoch;
+    w64wrapper dtls13Epoch;
+    w64wrapper dtls13PeerEpoch;
+
 #endif /* WOLFSSL_DTLS13 */
 
 #endif /* WOLFSSL_DTLS */
@@ -5260,6 +5303,12 @@ WOLFSSL_LOCAL word32 nid2oid(int nid, int grp);
 
 #ifdef WOLFSSL_DTLS13
 
+WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,
+    w64wrapper epochNumber);
+WOLFSSL_LOCAL int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber,
+    int side);
+WOLFSSL_LOCAL int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber,
+    enum encrypt_side side);
 WOLFSSL_LOCAL int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision);
 WOLFSSL_LOCAL int Dtls13SetRecordNumberKeys(WOLFSSL* ssl,
     enum encrypt_side side);