Created April 10, 2014 15:45
-
-
Save doublec/10395713 to your computer and use it in GitHub Desktop.
ATS version of dtls1_process_heartbeat
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* record_data_v (l, n): view of an n-byte heartbeat record that
   starts at address l, laid out as
     byte        hbtype
     ushort      payload length (2 bytes)
     byte[len]   payload, where len is the payload length
     byte[16]    padding
   The only constructor requires n > 16 + 2 + 1, so a well-typed
   record always has room for the type byte, the length field, the
   padding and at least one further byte. *)
dataview record_data_v (addr, int) =
  | {p:addr} {sz:nat | sz > 16 + 2 + 1}
    make_record_data_v (p, sz) of (ptr p, size_t sz)
(* free_record_data_v (pf): consume a record_data_v proof.  The view is
   linear (it is a dataview), so every code path holding one must
   dispose of it; this extern has no implementation and is trusted. *)
extern prfun free_record_data_v
  {p:addr} {sz:nat} (pf: record_data_v (p, sz)): void
(* get_record (s): take the pointer from get_record_data (s) and the
   length from get_record_length (s) and return both together with a
   record_data_v proof for that memory.
   The assertloc check fails (aborting the program) unless the length
   exceeds 16 + 2 + 1, the minimum that record_data_v admits; callers
   that must not abort should test the length first. *)
fun get_record (s: SSLptr): [l:addr] [n:nat] (record_data_v (l, n) | ptr l, size_t n) = let
  val sz = get_record_length (s)
  val p_rec = get_record_data (s)
  val () = assertloc (sz > 16 + 2 + 1)
  prval pf_rec = make_record_data_v (p_rec, sz)
in
  (pf_rec | p_rec, sz)
end
(* Trusted proof functions (extern, no implementation).  Each one
   splits a record_data_v (l, n) view into the view of one region of
   the record and returns it with a linear proof function that turns
   that view back into record_data_v (l, n).  Each region's start
   address and length are fixed in the signature.  As a result, code
   that reads or writes the record through these views is checked to
   stay inside that region and to use the matching pointer:

     function                      start          bytes
     extract_data_proof            l              n
     extract_hbtype_proof          l              1   (as byte @ l)
     extract_payload_length_proof  l + 1          2
     extract_payload_data_proof    l + 1 + 2      n - 16 - 2 - 1
     extract_padding_proof         l + m + 1 + 2  16

   extract_padding_proof takes the actual payload length m as an
   argument, and m must satisfy m <= n - 16 - 2 - 1. *)
extern prfun extract_data_proof {p:addr} {sz:nat}
  (pf: record_data_v (p, sz)):
  (array_v (byte, p, sz),
   array_v (byte, p, sz) -<lin,prf> record_data_v (p, sz))

extern prfun extract_hbtype_proof {p:addr} {sz:nat}
  (pf: record_data_v (p, sz)):
  (byte @ p,
   byte @ p -<lin,prf> record_data_v (p, sz))

extern prfun extract_payload_length_proof {p:addr} {sz:nat}
  (pf: record_data_v (p, sz)):
  (array_v (byte, p + 1, 2),
   array_v (byte, p + 1, 2) -<lin,prf> record_data_v (p, sz))

extern prfun extract_payload_data_proof {p:addr} {sz:nat}
  (pf: record_data_v (p, sz)):
  (array_v (byte, p + 1 + 2, sz - 16 - 2 - 1),
   array_v (byte, p + 1 + 2, sz - 16 - 2 - 1) -<lin,prf> record_data_v (p, sz))

extern prfun extract_padding_proof
  {p:addr} {sz:nat} {m:nat | m <= sz - 16 - 2 - 1}
  (pf: record_data_v (p, sz), payload_length: size_t m):
  (array_v (byte, p + m + 1 + 2, 16),
   array_v (byte, p + m + 1 + 2, 16) -<lin,prf> record_data_v (p, sz))
(* ats_dtls1_process_heartbeat (s): process the heartbeat record
   obtained from s.  This is an ATS version of dtls1_process_heartbeat.

   Record layout, as described by record_data_v:
     byte       hbtype
     byte[2]    payload_length (read with n2s)
     byte[...]  payload (payload_length bytes)
     byte[16]   padding

   - TLS1_HB_REQUEST: build a response that echoes the payload, with
     fresh RAND_pseudo_bytes padding, and send it with dtls1_write_bytes.
   - TLS1_HB_RESPONSE: if payload_length is 18 and the 2-byte sequence
     number equals get_tlsext_hb_seq (s), call dtls1_stop_timer,
     increment the sequence number and set the pending flag to 0.
   - Any other type is ignored.

   Malformed input is discarded silently (returning 0) instead of
   aborting the process through assertloc; RFC 6520 section 4
   requires silent discard.  The discarded cases are:
     - records of 16 + 2 + 1 bytes or fewer (get_record would abort);
     - records where payload_length + padding + 1 + 2 exceeds the
       record length (the Heartbleed request shape);
     - requests with payload_length 0 (record_data_v cannot describe
       a response buffer of only 16 + 2 + 1 bytes).

   Returns the dtls1_write_bytes result when it is negative, otherwise 0. *)
fun ats_dtls1_process_heartbeat(s: SSLptr): int =
  if get_record_length (s) > 16 + 2 + 1 then let
    val padding = i2sz(PADDING)
    val (pf_data | p_data, data_len) = get_record (s)
    (* byte 0: heartbeat type *)
    prval (pf, pff) = extract_hbtype_proof (pf_data)
    val hbtype = $UN.cast2int (!p_data)
    prval pf_data = pff (pf)
    (* bytes 1-2: payload_length as sent by the peer; not yet validated *)
    prval (pf, pff) = extract_payload_length_proof (pf_data)
    val p = ptr_succ<byte> (p_data)
    val payload_length = n2s (p)
    prval pf_data = pff (pf)
    (* Report the received record before payload_length is validated. *)
    val () = if (ptr_isnot_null (get_msg_callback (s))) then
               call_msg_callback (get_msg_callback (s),
                                  0, get_version (s), TLS1_RT_HEARTBEAT,
                                  p_data, data_len, s,
                                  get_msg_callback_arg (s))
  in
    if data_len >= payload_length + padding + 1 + 2 then (
      (* The claimed payload fits inside the record.  This test also
         gives the typechecker the bound that safe_memcpy and
         extract_padding_proof rely on below. *)
      if hbtype = TLS1_HB_REQUEST then (
        if payload_length > 0 then let
          val n = payload_length + padding + 1 + 2
          val (pf_buffer | p_buffer) = OPENSSL_malloc(n)
          prval pf_response = make_record_data_v (p_buffer, n)
          (* response byte 0: type *)
          prval (pf, pff) = extract_hbtype_proof (pf_response)
          val () = !p_buffer := cast2byte(TLS1_HB_RESPONSE)
          prval pf_response = pff(pf)
          (* response bytes 1-2: echoed payload_length *)
          prval (pf, pff) = extract_payload_length_proof (pf_response)
          val p = add_ptr1_bsz (p_buffer, i2sz 1)
          val () = s2n (pf | payload_length, p)
          prval pf_response = pff(pf)
          (* copy the payload from the request into the response *)
          prval (pf_dst, pff_dst) = extract_payload_data_proof (pf_response)
          prval (pf_src, pff_src) = extract_payload_data_proof (pf_data)
          val () = safe_memcpy (pf_dst, pf_src | add_ptr1_bsz (p_buffer, i2sz 3), add_ptr1_bsz (p_data, i2sz 3), payload_length)
          prval pf_response = pff_dst(pf_dst)
          prval pf_data = pff_src(pf_src)
          (* random padding after the payload *)
          prval (pf, pff) = extract_padding_proof (pf_response, payload_length)
          val () = RAND_pseudo_bytes (pf | add_ptr_bsz (p_buffer, payload_length + 1 + 2), padding)
          prval pf_response = pff(pf)
          (* send the whole response buffer *)
          prval (pf, pff) = extract_data_proof (pf_response)
          val r = dtls1_write_bytes (pf | s, TLS1_RT_HEARTBEAT, p_buffer, n)
          prval pf_response = pff(pf)
          val () = if r >= 0 && ptr_isnot_null (get_msg_callback (s)) then
                     call_msg_callback (get_msg_callback (s),
                                        1, get_version (s), TLS1_RT_HEARTBEAT,
                                        p_buffer, n, s,
                                        get_msg_callback_arg (s))
          prval () = free_record_data_v (pf_data)
          prval () = free_record_data_v (pf_response)
          val () = OPENSSL_free (pf_buffer | p_buffer)
        in
          if r < 0 then r else 0
        end else let
          (* zero-length request: no valid response buffer; discard silently *)
          prval () = free_record_data_v (pf_data)
        in
          0
        end
      ) else if hbtype = TLS1_HB_RESPONSE then let
        (* the first 2 payload bytes hold the sequence number *)
        prval (pf, pff) = extract_payload_data_proof (pf_data)
        val seq = n2s (add_ptr1_bsz (p_data, i2sz 3))
        prval pf_data = pff (pf)
        prval () = free_record_data_v (pf_data)
      in
        if $UN.cast2int(payload_length) = 18 && $UN.cast2int(seq) = $UN.cast2int(get_tlsext_hb_seq (s)) then let
          val () = dtls1_stop_timer (s)
          val () = increment_tlsext_hb_seq (s)
          val () = set_tlsext_hb_pending (s, $UN.cast2uint(0))
        in
          0
        end else 0
      end else let
        (* unknown heartbeat type: ignore *)
        prval () = free_record_data_v (pf_data)
      in
        0
      end
    ) else let
      (* payload_length claims more bytes than the record holds:
         discard silently rather than abort *)
      prval () = free_record_data_v (pf_data)
    in
      0
    end
  end else 0
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment