Skip to content

Commit

Permalink
Adds check to ensure no switching between state machines (aws#3747)
Browse files Browse the repository at this point in the history
* state machine switch

* Fixing SAW proofs

Co-authored-by: Yan Peng <[email protected]>
  • Loading branch information
maddeleine and pennyannn authored Jan 18, 2023
1 parent 945b1b2 commit 0be885c
Show file tree
Hide file tree
Showing 17 changed files with 173 additions and 23 deletions.
25 changes: 14 additions & 11 deletions tests/saw/spec/handshake/handshake_io_lowlevel.saw
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ let conn_handshake_handshake_type pconn =
let conn_handshake_message_number pconn =
crucible_field (crucible_field pconn "handshake") "message_number";

//conn->handshake.state_machine
let conn_handshake_state_machine pconn =
crucible_field (crucible_field pconn "handshake") "state_machine";

//conn->handshake_params.our_chain_and_key
let conn_chain_and_key pconn =
crucible_field (crucible_field pconn "handshake_params") "our_chain_and_key";
Expand Down Expand Up @@ -143,6 +147,10 @@ let setup_connection_common chosen_psk_null = do {
crucible_points_to (conn_handshake_message_number pconn)
(crucible_term message_number);

state_machine <- crucible_fresh_var "state_machine" (llvm_int 32);
crucible_points_to (conn_handshake_state_machine pconn)
(crucible_term state_machine);

cork_val <- crucible_fresh_var "corked" (llvm_int 2);
crucible_ghost_value corked cork_val;

Expand Down Expand Up @@ -229,7 +237,8 @@ let setup_connection_common chosen_psk_null = do {
return (pconn, {{ {corked_io = corked_io
,mode = mode
,handshake = {message_number = message_number
,handshake_type = handshake_type}
,handshake_type = handshake_type
,state_machine = state_machine }
,corked = cork_val
,is_caching_enabled = False
,key_exchange_eph = eph_flag != zero
Expand Down Expand Up @@ -313,12 +322,8 @@ let s2n_connection_get_client_auth_type_spec = do{
// by conn_set_handshake_type (low-level model function)
let s2n_conn_set_handshake_type_spec chosen_psk_null = do {
(pconn, conn) <- setup_connection_common chosen_psk_null;
// we assume that the handshake struct denotes a valid handshake state
// (e.g. it will not index out of bounds in the state transition array
// "handshakes")
// conn.handshake is defined in s2n_handshake_io.cry.
// valid_handshake is defined in s2n_handshake_io.cry as well.
crucible_precond {{ valid_handshake conn.handshake }};
// We assume that the connection struct denotes a valid connection state
crucible_precond {{ valid_connection conn }};

// symbolically execute s2n_conn_set_handshake_type
crucible_execute_func [pconn];
Expand Down Expand Up @@ -346,10 +351,8 @@ let s2n_advance_message_spec = do {
// chosen_psk, so we arbitrarily set it to NULL here by using
// setup_connection. Using setup_psk_connection would work just as well.
(pconn, conn) <- setup_connection;
// we assume that the handshake struct denotes a valid handshake state
// (e.g. it will not index out of bounds in the state transition array
// "handshakes")
crucible_precond {{ valid_handshake conn.handshake }};
// We assume that the connection struct denotes a valid connection state
crucible_precond {{ valid_connection conn }};

// symbolically execute s2n_advance_message
crucible_execute_func [pconn];
Expand Down
3 changes: 2 additions & 1 deletion tests/saw/spec/handshake/rfc_handshake_tls13.cry
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ testParameters = {
testConnection : connection
testConnection = {
handshake = {handshake_type = 0x00000000,
message_number = 0x00000000},
message_number = 0x00000000,
state_machine = 0x00000000},
mode = 0x00000000,
corked_io = False,
corked = zero,
Expand Down
64 changes: 57 additions & 7 deletions tests/saw/spec/handshake/s2n_handshake_io.cry
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,39 @@ module s2n_handshake_io where
// This function models the update of the s2n_connection struct by the
// s2n_conn_set_handshake_type function in s2n.
conn_set_handshake_type : connection -> connection
conn_set_handshake_type conn =
if IS_TLS13_HANDSHAKE conn
then conn_set_tls13_handshake_type conn
else conn_set_pre_tls13_handshake_type conn
conn_set_handshake_type conn = conn''
where conn'' = if IS_TLS13_HANDSHAKE conn'
then conn_set_tls13_handshake_type conn'
else conn_set_pre_tls13_handshake_type conn'
conn' = conn_choose_state_machine conn conn.actual_protocol_version

// This function models the state machine choosing by the
// s2n_conn_choose_state_machine function in s2n.
conn_choose_state_machine : connection -> [8] -> connection
conn_choose_state_machine conn protocol_version = conn'
where conn' = {handshake = handshake'
,mode = conn.mode
,corked_io = conn.corked_io
,corked = conn.corked
,is_caching_enabled = conn.is_caching_enabled
,resume_from_cache = conn.resume_from_cache
,server_can_send_ocsp = conn.server_can_send_ocsp
,key_exchange_eph = conn.key_exchange_eph
,client_auth_flag = conn.client_auth_flag
,actual_protocol_version = conn.actual_protocol_version
,no_client_cert = conn.no_client_cert
,early_data_state = conn.early_data_state
,chosen_psk_null = conn.chosen_psk_null
,quic_enabled = conn.quic_enabled
,npn_negotiated = conn.npn_negotiated
}
(handshake' : handshake) = {handshake_type = conn.handshake.handshake_type
,message_number = conn.handshake.message_number
,state_machine = state_machine'
}
state_machine' = if protocol_version == S2N_TLS13
then S2N_STATE_MACHINE_TLS13
else S2N_STATE_MACHINE_TLS12

// This function models the update of the s2n_connection struct by the
// s2n_conn_set_handshake_type function in s2n. This only models the
Expand All @@ -36,6 +65,7 @@ conn_set_pre_tls13_handshake_type conn = conn'
}
(handshake' : handshake) = {handshake_type = handshake_type'
,message_number = conn.handshake.message_number
,state_machine = conn.handshake.state_machine
}
handshake_type' = NEGOTIATED || full_handshake || with_npn ||
(if (full_handshake != 0) then
Expand Down Expand Up @@ -79,6 +109,7 @@ conn_set_tls13_handshake_type conn = conn'
}
(handshake' : handshake) = {handshake_type = handshake_type'
,message_number = conn.handshake.message_number
,state_machine = conn.handshake.state_machine
}
handshake_type' = (conn.handshake.handshake_type && (HELLO_RETRY_REQUEST || MIDDLEBOX_COMPAT || EARLY_CLIENT_CCS))
|| NEGOTIATED || full_handshake || with_early_data || client_auth || middlebox_compat
Expand Down Expand Up @@ -116,7 +147,8 @@ advance_message conn = conn2
,npn_negotiated = conn.npn_negotiated
}
(handshake2 : handshake) = { handshake_type = conn.handshake.handshake_type,
message_number = message_number2 }
message_number = message_number2,
state_machine = conn.handshake.state_machine }
cork2 = if (ACTIVE_STATE conn2).writer != (ACTIVE_STATE conn).writer
then if (ACTIVE_STATE conn2).writer == mode_writer conn2.mode
then s2n_cork conn else s2n_uncork conn
Expand Down Expand Up @@ -147,6 +179,7 @@ conn_set_handshake_no_client_cert conn = conn2
}
(handshake' : handshake) = {handshake_type = handshake_type'
,message_number = conn.handshake.message_number
,state_machine = conn.handshake.state_machine
}
handshake_type' = if conn.client_auth_flag
then conn.handshake.handshake_type || NO_CLIENT_CERT
Expand Down Expand Up @@ -227,6 +260,7 @@ type connection = {handshake : handshake

type handshake = {handshake_type : [32]
,message_number : [32]
,state_machine : [32]
}

type S2N_HANDSHAKES_COUNT = 256
Expand All @@ -235,7 +269,7 @@ type S2N_MAX_HANDSHAKE_LENGTH = 32
// functions model the corresponding macros in C

IS_TLS13_HANDSHAKE : connection -> Bit
IS_TLS13_HANDSHAKE conn = conn.actual_protocol_version == S2N_TLS13
IS_TLS13_HANDSHAKE conn = conn.handshake.state_machine == S2N_STATE_MACHINE_TLS13

ACTIVE_STATE_MACHINE : connection -> [TOTAL_HANDSHAKE_ACTIONS]handshake_action
ACTIVE_STATE_MACHINE conn = if IS_TLS13_HANDSHAKE conn then tls13_state_machine else state_machine # zero
Expand Down Expand Up @@ -942,8 +976,19 @@ valid_handshake hs = handshakefn != zero /\
((hs.message_number > 0) ==> (handshakefn@(hs.message_number + 1) != 0)) where
handshakefn = (handshakes_fn hs.handshake_type)

// A predicate to tell whether a given connection struct is valid
valid_connection : connection -> Bit
valid_connection conn = valid_handshake (conn.handshake)
valid_connection conn = valid_handshake (conn.handshake) /\
// Either conn.handshake.state_machine is S2N_STATE_MACHINE_INITIAL
((conn.handshake.state_machine == S2N_STATE_MACHINE_INITIAL) \/
// Or conn.actual_protocol_version is less than or equal to S2N_TLS12
// and conn.hanshake.state_machine has been set to S2N_STATE_MACHINE_TLS12
(conn.actual_protocol_version <= S2N_TLS12 /\
conn.handshake.state_machine == S2N_STATE_MACHINE_TLS12) \/
// Or conn.actual_protocol_version is greater or equal to S2N_TLS13
// and conn.hanshake.state_machine has been set to S2N_STATE_MACHINE_TLS13
(conn.actual_protocol_version >= S2N_TLS13 /\
conn.handshake.state_machine == S2N_STATE_MACHINE_TLS13))

// Tells if the connection struct is in a valid initial state
initial_connection : connection -> Bit
Expand Down Expand Up @@ -1070,3 +1115,8 @@ S2N_CERT_AUTH_REQUIRED = 1

//S2N early data states
S2N_EARLY_DATA_ACCEPTED = 3

// s2n_state_machine version
S2N_STATE_MACHINE_INITIAL = 0
S2N_STATE_MACHINE_TLS12 = 1
S2N_STATE_MACHINE_TLS13 = 2
4 changes: 2 additions & 2 deletions tests/saw/verify_imperative_cryptol_spec.saw
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
//
////////////////////////////////////////////////////////////////

import "spec/HMAC_iterative.cry";
import "spec/HMAC_properties.cry";
import "HMAC/spec/HMAC_iterative.cry";
import "HMAC/spec/HMAC_properties.cry";

let check n = do {
print (str_concat "Checking 'hmac_c_state_correct' for byte count " (show n));
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/s2n_early_data_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ int main(int argc, char **argv)
EXPECT_NOT_NULL(conn);
EXPECT_SUCCESS(s2n_connection_set_cipher_preferences(conn, "default_tls13"));
conn->secure->cipher_suite = &s2n_tls13_aes_256_gcm_sha384;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
conn->actual_protocol_version = S2N_TLS13;
conn->early_data_state = S2N_EARLY_DATA_REQUESTED;

Expand Down Expand Up @@ -477,9 +478,14 @@ int main(int argc, char **argv)
conn->early_data_state = S2N_EARLY_DATA_REQUESTED;

conn->actual_protocol_version = S2N_TLS12;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS12));
EXPECT_FALSE(s2n_early_data_is_valid_for_connection(conn));

/* Reset state machine */
conn->handshake.state_machine = S2N_STATE_MACHINE_INITIAL;

conn->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
EXPECT_TRUE(s2n_early_data_is_valid_for_connection(conn));

EXPECT_SUCCESS(s2n_connection_free(conn));
Expand All @@ -495,6 +501,7 @@ int main(int argc, char **argv)
EXPECT_NOT_NULL(conn);
EXPECT_SUCCESS(s2n_connection_set_cipher_preferences(conn, "default_tls13"));
EXPECT_OK(s2n_append_test_chosen_psk_with_early_data(conn, nonzero_max_early_data, &s2n_tls13_aes_256_gcm_sha384));
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
conn->actual_protocol_version = S2N_TLS13;
conn->early_data_state = S2N_EARLY_DATA_REQUESTED;

Expand All @@ -519,6 +526,7 @@ int main(int argc, char **argv)
EXPECT_OK(s2n_append_test_chosen_psk_with_early_data(conn, nonzero_max_early_data, &s2n_tls13_aes_256_gcm_sha384));
conn->secure->cipher_suite = &s2n_tls13_aes_256_gcm_sha384;
conn->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
conn->early_data_state = S2N_EARLY_DATA_REQUESTED;

const uint8_t empty_protocol[] = "";
Expand Down Expand Up @@ -683,6 +691,7 @@ int main(int argc, char **argv)
EXPECT_OK(s2n_append_test_chosen_psk_with_early_data(conn, nonzero_max_early_data, &s2n_tls13_aes_256_gcm_sha384));
EXPECT_SUCCESS(s2n_connection_set_early_data_expected(conn));
conn->secure->cipher_suite = &s2n_tls13_aes_256_gcm_sha384;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
conn->actual_protocol_version = S2N_TLS13;
conn->early_data_state = S2N_EARLY_DATA_REQUESTED;

Expand Down Expand Up @@ -1250,6 +1259,7 @@ int main(int argc, char **argv)
valid_connection->early_data_state = S2N_EARLY_DATA_ACCEPTED;
valid_connection->early_data_bytes = 0;
valid_connection->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(valid_connection, S2N_TLS13));
valid_connection->handshake.handshake_type = NEGOTIATED | WITH_EARLY_DATA;
while (s2n_conn_get_current_message_type(valid_connection) != END_OF_EARLY_DATA) {
valid_connection->handshake.message_number++;
Expand Down
1 change: 1 addition & 0 deletions tests/unit/s2n_handshake_io_async_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ int main(int argc, char **argv)
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
struct s2n_connection *conn = s2n_connection_new(S2N_CLIENT);
EXPECT_SUCCESS(s2n_connection_set_io_stuffers(NULL, &io_buffer, conn));
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));

/* Consistently blocks */
async_blocked = true;
Expand Down
1 change: 1 addition & 0 deletions tests/unit/s2n_quic_support_io_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct s2n_stuffer input_stuffer, output_stuffer;
static S2N_RESULT s2n_setup_conn(struct s2n_connection *conn)
{
conn->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));

RESULT_GUARD_POSIX(s2n_stuffer_wipe(&input_stuffer));
RESULT_GUARD_POSIX(s2n_stuffer_wipe(&output_stuffer));
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/s2n_server_extensions_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ int main(int argc, char **argv)
EXPECT_NOT_NULL(conn);
EXPECT_SUCCESS(s2n_connection_allow_all_response_extensions(conn));
conn->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));

struct s2n_stuffer *io_stuffer = &conn->handshake.io;

/* Setup required for PSK extension */
Expand Down Expand Up @@ -621,6 +623,7 @@ int main(int argc, char **argv)

server_conn->actual_protocol_version = S2N_TLS13;
server_conn->server_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(server_conn, S2N_TLS13));
server_conn->psk_params.chosen_psk = &empty_psk;
server_conn->psk_params.chosen_psk_wire_index = test_wire_index;

Expand All @@ -641,6 +644,7 @@ int main(int argc, char **argv)
EXPECT_NOT_NULL(client_conn = s2n_connection_new(S2N_CLIENT));
EXPECT_SUCCESS(s2n_connection_allow_all_response_extensions(client_conn));
client_conn->actual_protocol_version = S2N_TLS13;
EXPECT_OK(s2n_conn_choose_state_machine(client_conn, S2N_TLS13));

EXPECT_SUCCESS(s2n_connection_mark_extension_received(client_conn, s2n_server_key_share_extension.iana_value));

Expand Down
4 changes: 4 additions & 0 deletions tests/unit/s2n_server_hello_retry_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_ecc_evp_generate_ephemeral_key(&server_conn->kex_params.client_ecc_evp_params));

EXPECT_SUCCESS(s2n_set_connection_hello_retry_flags(server_conn));
EXPECT_OK(s2n_conn_choose_state_machine(server_conn, S2N_TLS13));

/* The client will need a key share extension to properly parse the hello */
/* Total extension size + size of each extension */
Expand Down Expand Up @@ -237,6 +238,7 @@ int main(int argc, char **argv)
conn->kex_params.client_ecc_evp_params.negotiated_curve = s2n_all_supported_curves_list[0];
EXPECT_SUCCESS(s2n_ecc_evp_generate_ephemeral_key(&conn->kex_params.client_ecc_evp_params));

EXPECT_OK(s2n_conn_choose_state_machine(conn, S2N_TLS13));
EXPECT_SUCCESS(s2n_set_connection_hello_retry_flags(conn));

EXPECT_TRUE(s2n_is_hello_retry_message(conn));
Expand Down Expand Up @@ -303,6 +305,7 @@ int main(int argc, char **argv)
server_conn->kex_params.server_ecc_evp_params.negotiated_curve = s2n_all_supported_curves_list[0];
server_conn->kex_params.client_ecc_evp_params.negotiated_curve = s2n_all_supported_curves_list[0];
EXPECT_SUCCESS(s2n_set_connection_hello_retry_flags(server_conn));
EXPECT_OK(s2n_conn_choose_state_machine(server_conn, S2N_TLS13));
EXPECT_SUCCESS(s2n_ecc_evp_generate_ephemeral_key(&server_conn->kex_params.client_ecc_evp_params));
EXPECT_SUCCESS(s2n_extensions_server_key_share_send(server_conn, extension_stuffer));

Expand All @@ -319,6 +322,7 @@ int main(int argc, char **argv)

/* Setup the handshake type and message number to simulate a condition where a HelloRetry should be sent */
EXPECT_SUCCESS(s2n_set_connection_hello_retry_flags(client_conn));
EXPECT_OK(s2n_conn_choose_state_machine(client_conn, S2N_TLS13));
EXPECT_SUCCESS(s2n_set_hello_retry_required(client_conn));

/* Parse the key share */
Expand Down
Loading

0 comments on commit 0be885c

Please sign in to comment.