Skip to content

Instantly share code, notes, and snippets.

@doublec
Created April 10, 2014 15:45
Show Gist options
  • Save doublec/10395713 to your computer and use it in GitHub Desktop.
ATS version of dtls1_process_heartbeat
(* record_data_v (l, n): a view for the n-byte heartbeat record at address l.
   Note that n is the TOTAL record length, not the payload length.
   Layout, as carved up by the extract_*_proof functions in this file:
     offset 0         byte          hbtype
     offsets 1..2     2 bytes       payload length
     offset 3         n - 19 bytes  payload area
     after payload    16 bytes      padding
   The constructor only admits n > 1 + 2 + 16, i.e. a record with room
   for at least one payload byte besides the type, length and padding.
   NOTE(review): the constructor carries the pointer and length values
   themselves; the extract_*_proof axioms are what turn this view into
   concrete array_v views over the bytes. *)
dataview record_data_v (addr, int) =
| {l:addr} {n:nat | n > 16 + 2 + 1} make_record_data_v (l, n) of (ptr l, size_t n)
(* Consume a record_data_v proof once the record is no longer accessed.
   This is a proof function only (erased at runtime): it frees no memory.
   The response buffer is released separately via OPENSSL_free. *)
extern prfun free_record_data_v {l:addr} {n:nat} (pf: record_data_v (l, n)): void
(* get_record: fetch the current record's data pointer and length from the
   SSL object and return them together with a record_data_v proof.
   The record must be strictly longer than hbtype (1) + payload length (2)
   + padding (16) bytes; that is the bound the dataview constructor demands.
   NOTE(review): assertloc aborts the program when its condition is false.
   The record is presumably the one received from the peer, so a short
   heartbeat would terminate the process instead of being discarded as
   RFC 6520 asks. Fixing this needs a fallible (e.g. option-returning)
   signature. TODO. *)
fun get_record (s: SSLptr): [l:addr] [n:nat] (record_data_v (l, n) | ptr l, size_t n) = let
  val rec_len = get_record_length (s)
  val rec_ptr = get_record_data (s)
  (* Gives the constraint solver the fact make_record_data_v needs. *)
  val () = assertloc (rec_len > 1 + 2 + 16)
in
  (make_record_data_v (rec_ptr, rec_len) | rec_ptr, rec_len)
end
(* Proof functions that split a record_data_v into an array_v over one
   field of the record, so that the memory operations on that field are
   type-checked against the field's bounds and start address.
   Each one also returns a linear proof function (-<lin,prf>) that
   must be applied exactly once to the (possibly updated) array view to
   give back the record_data_v.
   NOTE(review): these are extern axioms with no implementation. The
   offsets and sizes they assert are trusted, not verified, so any change
   to the record layout must be mirrored here by hand.
*)
(* The whole record: n bytes starting at l. *)
extern prfun extract_data_proof {l:addr} {n:nat}
(pf: record_data_v (l, n)):
(array_v (byte, l, n),
array_v (byte, l, n) -<lin,prf> record_data_v (l,n))
(* The hbtype byte at offset 0. *)
extern prfun extract_hbtype_proof {l:addr} {n:nat}
(pf: record_data_v (l, n)):
(byte @ l, byte @ l -<lin,prf> record_data_v (l,n))
(* The 2-byte payload length at offsets 1..2. *)
extern prfun extract_payload_length_proof {l:addr} {n:nat}
(pf: record_data_v (l, n)):
(array_v (byte, l+1, 2),
array_v (byte, l+1, 2) -<lin,prf> record_data_v (l,n))
(* The payload area: n - 19 bytes starting at offset 3. *)
extern prfun extract_payload_data_proof {l:addr} {n:nat}
(pf: record_data_v (l, n)):
(array_v (byte, l+1+2, n-16-2-1),
array_v (byte, l+1+2, n-16-2-1) -<lin,prf> record_data_v (l,n))
(* The 16 padding bytes placed right after a payload of n2 bytes.
   n2 must fit in the payload area (n2 <= n - 19), which keeps the
   padding inside the record. *)
extern prfun extract_padding_proof {l:addr} {n:nat} {n2:nat | n2 <= n - 16 - 2 - 1}
(pf: record_data_v (l, n), payload_length: size_t n2):
(array_v (byte, l + n2 + 1 + 2, 16),
array_v (byte, l + n2 + 1 + 2, 16) -<lin, prf> record_data_v (l, n))
(* ats_dtls1_process_heartbeat: handle one received DTLS heartbeat record.

   The record (see record_data_v) holds hbtype (1 byte), payload_length
   (2 bytes), the payload and 16 bytes of padding. If a message callback
   is installed, it is first invoked on the whole incoming record.

   - TLS1_HB_REQUEST: build a TLS1_HB_RESPONSE of
     1 + 2 + payload_length + padding bytes that echoes the payload. Fill
     its padding with RAND_pseudo_bytes and send it with dtls1_write_bytes.
     If the send succeeded, pass the response to the message callback.
     Returns the negative result of dtls1_write_bytes on failure, else 0.
   - TLS1_HB_RESPONSE: if payload_length is 18 and the first two payload
     bytes equal get_tlsext_hb_seq (s), call dtls1_stop_timer, increment
     the heartbeat sequence number and clear the pending flag. Returns 0.
   - Any other hbtype: ignored, returns 0.

   payload_length is supplied by the peer. A request whose payload_length
   is 0, or larger than the received record can hold, is discarded
   silently (returns 0), as RFC 6520 section 4 requires. These checks are
   explicit branches rather than assertloc calls: assertloc aborts the
   process, which would let any peer crash it with one malformed
   heartbeat. The branches give the type checker the same facts. *)
fun ats_dtls1_process_heartbeat(s: SSLptr): int = let
  val padding = i2sz(PADDING)
  val (pf_data | p_data, data_len) = get_record (s)

  (* Offset 0: heartbeat message type. *)
  prval (pf, pff) = extract_hbtype_proof (pf_data)
  val hbtype = $UN.cast2int (!p_data)
  prval pf_data = pff (pf)

  (* Offsets 1..2: payload length as claimed by the peer (untrusted). *)
  prval (pf, pff) = extract_payload_length_proof (pf_data)
  val p = ptr_succ<byte> (p_data)
  val payload_length = n2s (p)
  prval pf_data = pff (pf)

  val () = if (ptr_isnot_null (get_msg_callback (s))) then
    call_msg_callback (get_msg_callback (s),
                       0, get_version (s), TLS1_RT_HEARTBEAT,
                       p_data, data_len, s,
                       get_msg_callback_arg (s))
in
  if hbtype = TLS1_HB_REQUEST then
    (* A zero-length payload cannot be described by record_data_v, which
       requires more than 1 + 2 + 16 bytes, so such a request is dropped. *)
    if payload_length > 0 then (
      (* The claimed payload (plus type, length and padding) must fit in
         the record we actually received; otherwise the copy below would
         read past it. *)
      if data_len >= payload_length + padding + 1 + 2 then let
        val n = payload_length + padding + 1 + 2
        val (pf_buffer | p_buffer) = OPENSSL_malloc(n)
        prval pf_response = make_record_data_v (p_buffer, n)

        (* Response type byte. *)
        prval (pf, pff) = extract_hbtype_proof (pf_response)
        val () = !p_buffer := cast2byte(TLS1_HB_RESPONSE)
        prval pf_response = pff(pf)

        (* Echo the payload length. *)
        prval (pf, pff) = extract_payload_length_proof (pf_response)
        val p = add_ptr1_bsz (p_buffer, i2sz 1)
        val () = s2n (pf | payload_length, p)
        prval pf_response = pff(pf)

        (* Copy the payload; in bounds for both buffers by the checks above. *)
        prval (pf_dst, pff_dst) = extract_payload_data_proof (pf_response)
        prval (pf_src, pff_src) = extract_payload_data_proof (pf_data)
        val () = safe_memcpy (pf_dst, pf_src | add_ptr1_bsz (p_buffer, i2sz 3), add_ptr1_bsz (p_data, i2sz 3), payload_length)
        prval pf_response = pff_dst(pf_dst)
        prval pf_data = pff_src(pf_src)

        (* Random padding after the payload. *)
        prval (pf, pff) = extract_padding_proof (pf_response, payload_length)
        val () = RAND_pseudo_bytes (pf | add_ptr_bsz (p_buffer, payload_length + 1 + 2), padding)
        prval pf_response = pff(pf)

        (* Send the whole response. *)
        prval (pf, pff) = extract_data_proof (pf_response)
        val r = dtls1_write_bytes (pf | s, TLS1_RT_HEARTBEAT, p_buffer, n)
        prval pf_response = pff(pf)

        val () = if r >=0 && ptr_isnot_null (get_msg_callback (s)) then
          call_msg_callback (get_msg_callback (s),
                             1, get_version (s), TLS1_RT_HEARTBEAT,
                             p_buffer, n, s,
                             get_msg_callback_arg (s))
        prval () = free_record_data_v (pf_data)
        prval () = free_record_data_v (pf_response)
        val () = OPENSSL_free (pf_buffer | p_buffer)
      in
        if r < 0 then r else 0
      end else let
        (* payload_length exceeds the received record: discard silently. *)
        prval () = free_record_data_v (pf_data)
      in
        0
      end
    ) else let
      (* Zero-length payload: discard silently. *)
      prval () = free_record_data_v (pf_data)
    in
      0
    end
  else if hbtype = TLS1_HB_RESPONSE then let
    (* First two payload bytes: the sequence number we sent. *)
    prval (pf, pff) = extract_payload_data_proof (pf_data)
    val seq = n2s (add_ptr1_bsz (p_data, i2sz 3))
    prval pf_data = pff (pf)
    prval () = free_record_data_v (pf_data)
  in
    if $UN.cast2int(payload_length) = 18 && $UN.cast2int(seq) = $UN.cast2int(get_tlsext_hb_seq (s)) then let
      val () = dtls1_stop_timer (s)
      val () = increment_tlsext_hb_seq (s)
      val () = set_tlsext_hb_pending (s, $UN.cast2uint(0))
    in
      0
    end else 0
  end else let
    (* Unknown heartbeat type: ignore. *)
    prval () = free_record_data_v (pf_data)
  in
    0
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment