Browse Source

Move increment of dtls epoch to change cipher state function

Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/23212)
Frederik Wedel-Heinen 5 months ago
parent
commit
4897bd2022

+ 11 - 0
ssl/record/rec_layer_d1.c

@@ -679,3 +679,14 @@ void dtls1_increment_epoch(SSL_CONNECTION *s, int rw)
         s->rlayer.d->w_epoch++;
     }
 }
+
+uint16_t dtls1_get_epoch(SSL_CONNECTION *s, int rw) {
+    uint16_t epoch;
+
+    if (rw & SSL3_CC_READ)
+        epoch = s->rlayer.d->r_epoch;
+    else
+        epoch = s->rlayer.d->w_epoch;
+
+    return epoch;
+}

+ 2 - 2
ssl/record/rec_layer_s3.c

@@ -1322,7 +1322,7 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
             prev = s->rlayer.rrlnext;
             if (SSL_CONNECTION_IS_DTLS(s)
                     && level != OSSL_RECORD_PROTECTION_LEVEL_NONE)
-                epoch =  DTLS_RECORD_LAYER_get_r_epoch(&s->rlayer) + 1; /* new epoch */
+                epoch = dtls1_get_epoch(s, SSL3_CC_READ); /* new epoch */
 
 #ifndef OPENSSL_NO_DGRAM
             if (SSL_CONNECTION_IS_DTLS(s))
@@ -1339,7 +1339,7 @@ int ssl_set_new_record_layer(SSL_CONNECTION *s, int version,
         } else {
             if (SSL_CONNECTION_IS_DTLS(s)
                     && level != OSSL_RECORD_PROTECTION_LEVEL_NONE)
-                epoch =  DTLS_RECORD_LAYER_get_w_epoch(&s->rlayer) + 1; /* new epoch */
+                epoch = dtls1_get_epoch(s, SSL3_CC_WRITE); /* new epoch */
         }
 
         /*

+ 1 - 1
ssl/record/record.h

@@ -137,7 +137,6 @@ typedef struct record_layer_st {
 
 #define RECORD_LAYER_set_read_ahead(rl, ra)     ((rl)->read_ahead = (ra))
 #define RECORD_LAYER_get_read_ahead(rl)         ((rl)->read_ahead)
-#define DTLS_RECORD_LAYER_get_w_epoch(rl)       ((rl)->d->w_epoch)
 
 void RECORD_LAYER_init(RECORD_LAYER *rl, SSL_CONNECTION *s);
 void RECORD_LAYER_clear(RECORD_LAYER *rl);
@@ -163,6 +162,7 @@ __owur int dtls1_write_bytes(SSL_CONNECTION *s, uint8_t type, const void *buf,
 int do_dtls1_write(SSL_CONNECTION *s, uint8_t type, const unsigned char *buf,
                    size_t len, size_t *written);
 void dtls1_increment_epoch(SSL_CONNECTION *s, int rw);
+uint16_t dtls1_get_epoch(SSL_CONNECTION *s, int rw);
 int ssl_release_record(SSL_CONNECTION *s, TLS_RECORD *rr, size_t length);
 
 # define HANDLE_RLAYER_READ_RETURN(s, ret) \

+ 0 - 4
ssl/record/record_local.h

@@ -15,7 +15,3 @@
  *****************************************************************************/
 
 #define MAX_WARN_ALERT_COUNT    5
-
-/* Functions/macros provided by the RECORD_LAYER component */
-
-#define DTLS_RECORD_LAYER_get_r_epoch(rl)       ((rl)->d->r_epoch)

+ 8 - 12
ssl/statem/statem_clnt.c

@@ -871,20 +871,16 @@ WORK_STATE ossl_statem_client_post_work(SSL_CONNECTION *s, WORK_STATE wst)
             return WORK_ERROR;
         }
 
-        if (SSL_CONNECTION_IS_DTLS(s)) {
 #ifndef OPENSSL_NO_SCTP
-            if (s->hit) {
-                /*
-                 * Change to new shared key of SCTP-Auth, will be ignored if
-                 * no SCTP used.
-                 */
-                BIO_ctrl(SSL_get_wbio(ssl), BIO_CTRL_DGRAM_SCTP_NEXT_AUTH_KEY,
-                         0, NULL);
-            }
-#endif
-
-            dtls1_increment_epoch(s, SSL3_CC_WRITE);
+        if (SSL_CONNECTION_IS_DTLS(s) && s->hit) {
+            /*
+            * Change to new shared key of SCTP-Auth, will be ignored if
+            * no SCTP used.
+            */
+            BIO_ctrl(SSL_get_wbio(ssl), BIO_CTRL_DGRAM_SCTP_NEXT_AUTH_KEY,
+                     0, NULL);
         }
+#endif
         break;
 
     case TLS_ST_CW_FINISHED:

+ 0 - 2
ssl/statem/statem_lib.c

@@ -808,8 +808,6 @@ MSG_PROCESS_RETURN tls_process_change_cipher_spec(SSL_CONNECTION *s,
     }
 
     if (SSL_CONNECTION_IS_DTLS(s)) {
-        dtls1_increment_epoch(s, SSL3_CC_READ);
-
         if (s->version == DTLS1_BAD_VER)
             s->d1->handshake_read_seq++;
 

+ 0 - 3
ssl/statem/statem_srvr.c

@@ -994,9 +994,6 @@ WORK_STATE ossl_statem_server_post_work(SSL_CONNECTION *s, WORK_STATE wst)
             /* SSLfatal() already called */
             return WORK_ERROR;
         }
-
-        if (SSL_CONNECTION_IS_DTLS(s))
-            dtls1_increment_epoch(s, SSL3_CC_WRITE);
         break;
 
     case TLS_ST_SW_SRVR_DONE:

+ 3 - 0
ssl/t1_enc.c

@@ -228,6 +228,9 @@ int tls1_change_cipher_state(SSL_CONNECTION *s, int which)
         direction = OSSL_RECORD_DIRECTION_WRITE;
     }
 
+    if (SSL_CONNECTION_IS_DTLS(s))
+        dtls1_increment_epoch(s, which);
+
     if (!ssl_set_new_record_layer(s, s->version, direction,
                                     OSSL_RECORD_PROTECTION_LEVEL_APPLICATION,
                                     NULL, 0, key, cl, iv, (size_t)k, mac_secret,