Skip to content

Instantly share code, notes, and snippets.

@wolfspider
Created October 31, 2024 15:07
Show Gist options
  • Save wolfspider/c6486b30e7a74beb87188c61ffc5f9e8 to your computer and use it in GitHub Desktop.
Save wolfspider/c6486b30e7a74beb87188c61ffc5f9e8 to your computer and use it in GitHub Desktop.
Formal Methods Ring Buffers
#include <stdio.h>
#include <stdint.h>
/* Ring buffer view over caller-owned storage.
   first and length are pointers, so push/pop (which take this struct by
   value) still update the caller's indices. */
typedef struct t__int32_t_s
{
int32_t *b;            /* backing array with total_length slots */
uint32_t *first;       /* index of the front element (pop reads b[*first]) */
uint32_t *length;      /* number of elements currently stored */
uint32_t total_length; /* capacity: number of slots in b */
}
t__int32_t;
/* Index after i on a ring of total_length slots; the last slot wraps to 0.
   Expects total_length > 0. */
uint32_t next(uint32_t i, uint32_t total_length)
{
  uint32_t last = total_length - 1U;
  return (i == last) ? 0U : i + 1U;
}

/* Index before i on the ring; slot 0 wraps to the last slot. */
uint32_t prev(uint32_t i, uint32_t total_length)
{
  return (i == 0U) ? total_length - 1U : i - 1U;
}

/* Slot just after the last stored element, given front index i and length
   stored elements. When the ring is full this is the front slot itself. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
  if (length == total_length)
    return i;
  if (i < total_length - length)
    return i + length;                /* no wrap past the array end */
  return length - (total_length - i); /* wraps around to the start */
}
/* Push e, overwriting when the ring is full.
   - Room left: e goes into the slot before *first, that slot becomes the new
     front, and *length grows by one.
   - Full: one_past_last() equals *first, so the current front slot is
     overwritten with e and *first advances. e becomes the last element,
     the old front value is lost, and *length is unchanged. After a run of
     front-pushes the front slot holds the most recently pushed value (in
     main, 30 is replaced by 40), not the oldest. */
void push__int32_t(t__int32_t x, int32_t e)
{
  if (*x.length < x.total_length) {
    uint32_t dest_slot = prev(*x.first, x.total_length);
    x.b[dest_slot] = e;
    *x.first = dest_slot;
    *x.length = *x.length + 1U;
  } else {
    /* Only needed on the full path, so computed here rather than up front. */
    uint32_t one_past_last_index = one_past_last(*x.first, *x.length, x.total_length);
    x.b[one_past_last_index] = e;
    *x.first = next(*x.first, x.total_length);
  }
}
/* Remove and return the front element (slot *first), then advance *first.
   No emptiness check: with *x.length == 0 the unsigned length wraps. */
int32_t pop__int32_t(t__int32_t x)
{
  uint32_t front = *x.first;
  int32_t popped = x.b[front];
  *x.first = next(front, x.total_length);
  (*x.length)--;
  return popped;
}
/* Demo: push four values into a 3-slot ring (the 4th lands on a full ring),
   pop one, print it and use it as the exit status. */
int32_t main(void)
{
  int32_t b[3U];
  for (uint32_t k = 0U; k < 3U; k++)
    b[k] = (int32_t)1;
  uint32_t buf0 = 0U;
  uint32_t buf = 0U;
  t__int32_t rb = { .b = b, .first = &buf0, .length = &buf, .total_length = 3U };
  static const int32_t values[] = { 10, 20, 30, 40 };
  for (uint32_t k = 0U; k < 4U; k++)
    push__int32_t(rb, values[k]); /* the last push hits the full-ring path */
  int32_t r = pop__int32_t(rb);
  printf("out: %d\n", r);
  return r;
}
@wolfspider
Copy link
Author

This ring buffer was updated to reflect the correct behavior of what we would want to see. The original happens to be:

/* Successor index on a ring of total_length slots (last slot -> 0). */
uint32_t next(uint32_t i, uint32_t total_length)
{
  if (i != total_length - 1U)
    return i + 1U;
  return 0U;
}

/* Predecessor index on the ring (0 -> last slot). */
uint32_t prev(uint32_t i, uint32_t total_length)
{
  if (i == 0U)
    return total_length - 1U;
  return i - 1U;
}

/* Slot after the last element for front index i holding length elements;
   equals i when the ring is full. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
  if (length == total_length)
    return i;
  return (i >= total_length - length) ? length - (total_length - i)
                                      : i + length;
}

/* Prepend e: store it in the slot before *first and make that slot the new
   front. No capacity check here; the caller must ensure
   *x.length < x.total_length (the F* version states this as its space_left
   precondition), otherwise *x.length grows past x.total_length. */
void push__int32_t(t__int32_t x, int32_t e)
{
  uint32_t slot = prev(*x.first, x.total_length);
  x.b[slot] = e;
  *x.first = slot;
  (*x.length)++;
}

/* Take the element at the front slot and advance the front by one.
   Requires a non-empty ring (*x.length > 0). */
int32_t pop__int32_t(t__int32_t x)
{
  uint32_t idx = *x.first;
  int32_t value = x.b[idx];
  *x.first = next(idx, x.total_length);
  *x.length -= 1U;
  return value;
}

/* Smoke test: push 0 into an empty 32-slot ring and pop it back; the popped
   value is returned as the exit status. */
int32_t main(void)
{
  int32_t b[32U];
  for (uint32_t j = 0U; j < 32U; j++)
    b[j] = (int32_t)1;
  uint32_t buf0 = 0U;
  uint32_t buf = 0U;
  t__int32_t rb = { .b = b, .first = &buf0, .length = &buf, .total_length = 32U };
  push__int32_t(rb, (int32_t)0);
  return pop__int32_t(rb);
}

So the open question is whether we can prove this behavior inherently or not.

@wolfspider
Copy link
Author

Well, this is somewhat humorous: in trying to be clever, I attempted some new branching logic that fills the ring buffer in reverse, but at least it does something when the buffer is full, preventing an overflow:

/// ``push`` is slightly more involved and crucially relies on the lemma above.
/// It prepends ``e``: ``e`` is written into the slot just before ``first``,
/// that slot becomes the new ``first``, and ``length`` grows by one.
/// ``space_left`` in the precondition is what guarantees the slot is free, so
/// no runtime capacity test is needed. A runtime ``if`` here was redundant, and
/// without ``begin``/``end`` it only guarded the first statement after
/// ``then``; the remaining updates ran unconditionally.
let push (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h ->
    well_formed h x /\ space_left h x )
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    U32.(remaining_space h1 x =^ remaining_space h0 x -^ 1ul) /\
    M.(modifies (loc_union
      (loc_buffer x.length)
        (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
    as_list h1 x = e :: as_list h0 x)
=
  let dest_slot = prev !*x.first x.total_length in
  // Heap snapshot before the write; the lemma below is stated against it.
  let h0 = ST.get () in
  x.b.(dest_slot) <- e;
  seq_update_unused_preserves_list (B.as_seq h0 x.b) dest_slot e
    (deref h0 x.first) (deref h0 x.length) x.total_length;
  x.first *= dest_slot;
  x.length *= U32.(!*x.length +^ 1ul)

/// Overwrite path of ``push_cont``, valid only on a *full* ring. It writes ``e``
/// into the slot one past the last element (which, when full, is the
/// ``first`` slot) and advances ``first``; ``length`` is unchanged.
/// The precondition states fullness explicitly. Requiring ``space_left``
/// instead contradicts the only call site (the ``else`` branch of
/// ``push_cont``). Under ``space_left`` a ``length >= total_length`` guard is
/// never true, and an unbracketed ``if`` would only have guarded the write.
let push_end (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h ->
    well_formed h x /\ U32.(deref h x.length =^ x.total_length))
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    U32.(remaining_space h1 x =^ remaining_space h0 x) /\
    M.(modifies (loc_union
      (loc_buffer x.length)
      (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1))
=
  let o = one_past_last !*x.first !*x.length x.total_length in
  x.b.(o) <- e;
  x.first *= next !*x.first x.total_length


/// Push that works whether or not the ring is full. With room left it behaves
/// exactly like ``push``; when full it overwrites through ``push_end``.
/// The precondition is only ``well_formed``. If it also required
/// ``space_left``, the ``else`` branch would be unreachable and the
/// wraparound path would not be covered by the proof.
let push_cont (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h -> well_formed h x)
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    M.(modifies (loc_union
      (loc_buffer x.length)
      (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
    (if U32.(deref h0 x.length <^ x.total_length)
     then U32.(remaining_space h1 x =^ remaining_space h0 x -^ 1ul) /\
          as_list h1 x = e :: as_list h0 x
     else U32.(remaining_space h1 x =^ remaining_space h0 x)))
=
  if U32.(!*x.length <^ x.total_length) then
    push x e
  else
    push_end x e

@wolfspider
Copy link
Author

wolfspider commented Nov 2, 2024

I think what happened, again, is that I was looking at this so hard I started to go cross-eyed and didn't realize the code was already exhibiting the correct behavior! After cleaning up the code and adding some lemmas back in, here is the update:

/// Prepends ``e``. It is stored in the slot just before ``first``, that slot
/// becomes the new ``first``, and ``length`` grows by one. ``space_left`` in
/// the precondition is what makes the target slot free; there is no runtime
/// capacity check.
let push (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h ->
    well_formed h x /\ space_left h x )
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    // exactly one slot of free space is consumed
    U32.(remaining_space h1 x =^ remaining_space h0 x -^ 1ul) /\
    M.(modifies (loc_union
      (loc_buffer x.length)
        (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
    // abstractly, e is prepended to the list view
    as_list h1 x = e :: as_list h0 x)
=
  let dest_slot = prev !*x.first x.total_length in
  // Heap snapshot taken before the write; the lemma below is stated over h0.
  let h0 = ST.get () in
  x.b.(dest_slot) <- e;
  // NOTE(review): presumably establishes that updating the unused slot
  // dest_slot does not change the list view; check the lemma's statement.
  seq_update_unused_preserves_list (B.as_seq h0 x.b) dest_slot e
    (deref h0 x.first) (deref h0 x.length) x.total_length;
  x.first *= dest_slot;
  x.length *= U32.(!*x.length +^ 1ul)

/// Overwrite path of ``push_cont``, valid only on a *full* ring. It writes ``e``
/// into the slot one past the last element (the ``first`` slot when full) and
/// advances ``first``; ``length`` is unchanged.
/// The precondition states fullness explicitly; requiring ``space_left``
/// contradicts the only call site, the ``else`` branch of ``push_cont``.
let push_end (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h ->
    well_formed h x /\ U32.(deref h x.length =^ x.total_length))
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    U32.(remaining_space h1 x =^ remaining_space h0 x) /\
    M.(modifies (loc_union
      (loc_buffer x.length)
      (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1))
=
  let o = one_past_last !*x.first !*x.length x.total_length in
  x.b.(o) <- e;
  x.first *= next !*x.first x.total_length


/// Push that works whether or not the ring is full. With room left it behaves
/// like ``push``; when full it overwrites through ``push_end``.
/// Only ``well_formed`` is required. With ``space_left`` also required, the
/// ``else`` branch is dead and the wraparound behaviour is never verified.
let push_cont (#a: eqtype) (x: t a) (e: a): Stack unit
  (requires fun h -> well_formed h x)
  (ensures fun h0 _ h1 ->
    well_formed h1 x /\
    M.(modifies (loc_union
      (loc_buffer x.length)
      (loc_union (loc_buffer x.first) (loc_buffer x.b))) h0 h1) /\
    (if U32.(deref h0 x.length <^ x.total_length)
     then U32.(remaining_space h1 x =^ remaining_space h0 x -^ 1ul) /\
          as_list h1 x = e :: as_list h0 x
     else U32.(remaining_space h1 x =^ remaining_space h0 x)))
=
  if U32.(!*x.length <^ x.total_length) then
    push x e
  else
    push_end x e

Which can be tested with the following C code:

#include <stdio.h> 
#include <stdint.h>

/* Ring buffer view over caller-owned storage. first/length are pointers so
   the by-value struct passed to push/pop still updates the caller's state. */
typedef struct t__int32_t_s
{
  int32_t *b;            /* backing array of total_length slots */
  uint32_t *first;       /* index of the front element (pop reads b[*first]) */
  uint32_t *length;      /* number of stored elements */
  uint32_t total_length; /* number of slots in b */
}
t__int32_t;

/* Step forward one slot, wrapping from the last slot back to 0. */
uint32_t next(uint32_t i, uint32_t total_length)
{
  uint32_t stepped = i + 1U;
  if (i == total_length - 1U)
    stepped = 0U;
  return stepped;
}

/* Step back one slot, wrapping from 0 to the last slot. */
uint32_t prev(uint32_t i, uint32_t total_length)
{
  uint32_t stepped = total_length - 1U;
  if (i > 0U)
    stepped = i - 1U;
  return stepped;
}

/* Slot after the last of `length` elements starting at front index i.
   For a full ring this is i itself, which push_end__int32_t relies on. */
uint32_t one_past_last(uint32_t i, uint32_t length, uint32_t total_length)
{
  uint32_t slot;
  if (length == total_length)
    slot = i;
  else if (i >= total_length - length)
    slot = length - (total_length - i);
  else
    slot = i + length;
  return slot;
}

/* Prepend e: write it to the slot before *first and make that slot the new
   front. The caller must ensure *x.length < x.total_length
   (push_cont__int32_t checks this); otherwise *x.length exceeds capacity. */
void push__int32_t(t__int32_t x, int32_t e)
{
  uint32_t slot = prev(*x.first, x.total_length);
  x.b[slot] = e;
  *x.first = slot;
  ++*x.length;
}

/* Write e into the slot one past the last element, then advance *first.
   Used by push_cont__int32_t only when the ring is full. There the target
   slot is *first itself, so the front value is replaced, e becomes the last
   element, and *x.length stays the same.
   NOTE(review): on a non-full ring this still advances *first, so the old
   front element leaves the ring while e is appended. */
void push_end__int32_t(t__int32_t x, int32_t e)
{
  x.b[one_past_last(*x.first, *x.length, x.total_length)] = e;
  *x.first = next(*x.first, x.total_length);
}

/* Push with wraparound: prepend while there is room, otherwise overwrite via
   push_end__int32_t so the ring never holds more than total_length items. */
void push_cont__int32_t(t__int32_t x, int32_t e)
{
  if (*x.length >= x.total_length) {
    push_end__int32_t(x, e);
    return;
  }
  push__int32_t(x, e);
}

/* Remove and return the value at the front slot, advancing *first.
   Requires *x.length > 0; otherwise the unsigned length wraps. */
int32_t pop__int32_t(t__int32_t x)
{
  uint32_t front = *x.first;
  int32_t popped = x.b[front];
  *x.first = next(front, x.total_length);
  *x.length -= 1U;
  return popped;
}


/* Print every slot of the backing array, one value per line, in index order. */
static void print_slots__int32_t(const int32_t *b, uint32_t n)
{
  for (uint32_t i = 0U; i < n; ++i)
    printf("%d\n", b[i]);
}

/* Wraparound demo. Fill a 3-slot ring with push_cont, then push 40 three
   times while full, printing the raw slots after each step. Finally push once
   more, pop the front value, print it and return it as the exit status. */
int32_t main(void)
{
  int32_t b[3U];
  for (uint32_t _i = 0U; _i < 3U; ++_i)
    b[_i] = (int32_t)1;
  uint32_t buf0 = 0U;
  uint32_t buf = 0U;
  t__int32_t rb = { .b = b, .first = &buf0, .length = &buf, .total_length = 3U };
  push_cont__int32_t(rb, (int32_t)10);
  push_cont__int32_t(rb, (int32_t)20);
  push_cont__int32_t(rb, (int32_t)30);
  print_slots__int32_t(b, 3U);

  /* The ring is full from here on: every push overwrites one slot. */
  for (int round = 0; round < 3; ++round) {
    push_cont__int32_t(rb, (int32_t)40);
    printf("pushed 40\n");
    print_slots__int32_t(b, 3U);
  }

  /* One extra push. It runs exactly once, after the loop above (it was
     previously indented as if it belonged to the last print loop). */
  push_cont__int32_t(rb, (int32_t)40);

  int32_t r = pop__int32_t(rb);
  printf("out: %d\n", r);
  return r;
}

@wolfspider
Copy link
Author

The output looks something like this which does present the wraparound behavior:

30
20
10
pushed 40
40
20
10
pushed 40
40
40
10
pushed 40
40
40
40
out: 40

The interesting thing is that, working with the F* plugin and Low*, I did get warning messages about memory handling, ownership, and so on. It reminded me of Rust, which is remarkable because it is not doing anything like rust-analyzer; it reasons about memory with Z3 instead. This lets me keep all of that extra machinery out of the generated code when things are performance-critical. In previous gists I have shown that Rust does not automatically run faster than C. It may beat C++ most of the time, but C is a different story altogether.

@wolfspider
Copy link
Author

The next issue is that we can prove filling the list in reverse more easily than filling it forward. If we could fill it forward, we would have something that looks like most implementations. So far, filling the list forward while keeping it well formed has been much more difficult in practice than I originally expected. The wraparound behavior is there, but in this case I cannot change all the lower-level code to run both backwards and forwards. It is more practical to decrement a pointer than to increment it, which is presumably why the ring buffer implementation does this as-is.

@wolfspider
Copy link
Author

Rather than reversing the whole thing, I think this makes more sense: use it as-is, knowing that push_cont has the correct behavior after more testing:

Program returned: 40
30
20
10
pushed 40
40
20
10
pushed 40
40
40
10
pushed 40
40
40
40
pushed 50
50
40
40
out: 40

What this means is that the ring should initially be filled with as many values as its size. From that point on it is guaranteed to be "safe" to push continuously and get the right behavior. In practice, when filling it with data, we push once for each slot of the ring, and after that push_cont takes over with the wraparound behavior.

@wolfspider
Copy link
Author

wolfspider commented Nov 16, 2024

Putting it all together: what does a safe ring buffer do? No matter where you push, nothing becomes invalid. It is common to initialize a ring buffer with zeros, and our provable code works by handling that as well. The following comparison is in Python:

class RingBuffer:
    """Python model of the verified ring buffer: array ``b``, front index, length.

    ``push`` prepends before ``first``; ``push_end`` writes one past the last
    element and advances ``first``; ``pop`` removes the element at ``first``.
    """

    def __init__(self, total_length):
        self.b = [None] * total_length  # Backing slots
        self.first = 0  # Index of the front element
        self.length = 0  # Number of stored elements
        self.total_length = total_length

    def next(self, i):
        """Index after ``i``, wrapping from the last slot to 0."""
        return 0 if i == self.total_length - 1 else i + 1

    def prev(self, i):
        """Index before ``i``, wrapping from 0 to the last slot."""
        return i - 1 if i > 0 else self.total_length - 1

    def one_past_last(self):
        """Slot just after the last element (equals ``first`` when full)."""
        first, length, total = self.first, self.length, self.total_length
        if length == total:
            return first
        if first >= total - length:
            return length - (total - first)
        return first + length

    def push(self, e):
        """Store ``e`` before the current front and make it the new front."""
        slot = self.prev(self.first)
        self.b[slot] = e
        self.first = slot
        self.length = min(self.length + 1, self.total_length)

    def push_end(self, e):
        """Write ``e`` one past the last element and advance the front."""
        slot = self.one_past_last()
        self.b[slot] = e
        self.first = self.next(self.first)

    def push_cont(self, e):
        """Prepend while there is room; once full, overwrite via ``push_end``."""
        if self.length >= self.total_length:
            self.push_end(e)
        else:
            self.push(e)

    def pop(self):
        """Remove and return the front element.

        Raises:
            IndexError: if the buffer is empty.
        """
        if self.length == 0:
            raise IndexError("Pop from empty buffer")
        front = self.first
        self.first = self.next(front)
        self.length -= 1
        return self.b[front]

    def __repr__(self):
        return f"RingBuffer({self.b}, first={self.first}, length={self.length})"
    
class SimpleRing:
    """Write-cursor ring: each insert stores at ``cur`` and advances it with wraparound.

    ``tail`` is only displayed; nothing here advances it.
    """

    def __init__(self, num_slots):
        self.num_slots = num_slots
        self.head = 0
        self.tail = 0
        self.cur = 0
        self.buffer = [None] * num_slots

    def insert(self, value):
        """Insert a single value at the current position and advance the head.

        There is no free-space check: once every slot is used, inserts
        overwrite the oldest slots in order.
        """
        self.buffer[self.cur] = value
        self.cur += 1
        if self.cur >= self.num_slots:
            self.cur -= self.num_slots
        # head mirrors cur after every insert
        self.head = self.cur

    def __repr__(self):
        return (f"SimpleRing(buffer={self.buffer}, head={self.head}, "
                f"tail={self.tail}, cur={self.cur})")
    

# Initialize both implementations with the same size
sz = 5
formally_proven_buf = RingBuffer(sz)
simple_ring = SimpleRing(sz)

# Test data
test_values = [10, 20, 30, 40, 50, 60, 70]
# One zero per slot, so the pre-fill below always fills the buffer exactly.
init_values = [0] * sz

# Insert elements into both buffers
print("Formally Proven Implementation:")
# Pre-fill so the buffer is full and push_cont takes the wraparound path.
for v in init_values:
    formally_proven_buf.push_cont(v)
for val in test_values:
    formally_proven_buf.push_cont(val)
    print(f"Inserted {val}: {formally_proven_buf}")

# Drain the buffer; it holds exactly sz elements at this point.
for _ in range(sz):
    print(f"Pop: {formally_proven_buf.pop()}")

print("\nSimple Increment/Wraparound Implementation:")
for val in test_values:
    simple_ring.insert(val)
    print(f"Inserted {val}: {simple_ring}")

The results of this are:

Formally Proven Implementation:
Inserted 10: RingBuffer([10, 0, 0, 0, 0], first=1, length=5)
Inserted 20: RingBuffer([10, 20, 0, 0, 0], first=2, length=5)
Inserted 30: RingBuffer([10, 20, 30, 0, 0], first=3, length=5)
Inserted 40: RingBuffer([10, 20, 30, 40, 0], first=4, length=5)
Inserted 50: RingBuffer([10, 20, 30, 40, 50], first=0, length=5)
Inserted 60: RingBuffer([60, 20, 30, 40, 50], first=1, length=5)
Inserted 70: RingBuffer([60, 70, 30, 40, 50], first=2, length=5)
Pop: 30
Pop: 40
Pop: 50
Pop: 60
Pop: 70

Simple Increment/Wraparound Implementation:
Inserted 10: SimpleRing(buffer=[10, None, None, None, None], head=1, tail=0, cur=1)
Inserted 20: SimpleRing(buffer=[10, 20, None, None, None], head=2, tail=0, cur=2)
Inserted 30: SimpleRing(buffer=[10, 20, 30, None, None], head=3, tail=0, cur=3)
Inserted 40: SimpleRing(buffer=[10, 20, 30, 40, None], head=4, tail=0, cur=4)
Inserted 50: SimpleRing(buffer=[10, 20, 30, 40, 50], head=0, tail=0, cur=0)
Inserted 60: SimpleRing(buffer=[60, 20, 30, 40, 50], head=1, tail=0, cur=1)
Inserted 70: SimpleRing(buffer=[60, 70, 30, 40, 50], head=2, tail=0, cur=2)

@wolfspider
Copy link
Author

Interesting. Nonetheless, when the rubber meets the road, this is actually better for organizing the prospective API. I have updated the netmap Python bindings elsewhere and modeled the ring buffer. This was done before implementing the full program, with assistance from AI and the Python Black formatter (in preview mode). What real significance could this have?

import struct
import time
import select
import argparse
import netmap  # Assuming this is your netmap wrapper module


# Function to build a packet (like in your original code)
def build_packet():
    """Return a 60-byte Ethernet frame: broadcast destination, zero source MAC,
    IPv4 EtherType and a 46-byte zero payload."""
    fmt = "!6s6sH46s"  # 6 + 6 + 2 + 46 = 60 bytes
    return struct.pack(
        fmt,
        b"\xff" * 6,  # Destination MAC address (broadcast)
        b"\x00" * 6,  # Source MAC address
        0x0800,  # EtherType (IPv4)
        # Payload must match the "46s" field; struct would silently truncate a
        # longer bytes object to 46 bytes.
        b"\x00" * 46,
    )


class RingBuffer:
    """Thin wrapper over a netmap transmit ring ``txr``.

    Userspace owns the slots ``cur .. tail - 1`` (mod ``num_slots``). ``tail``
    is re-read from ``txr`` before computing free space because the ring's
    tail moves on every txsync. The procedural version of this loop also reads
    ``txr.tail`` each iteration.
    """

    def __init__(self, txr, num_slots):
        """Initialize the RingBuffer with the transmit ring and buffer length."""
        self.txr = txr
        self.num_slots = num_slots
        self.cur = txr.cur  # Current position in the ring
        self.tail = txr.tail  # Last observed tail (first slot not available to us)
        self.head = txr.head
        self.cnt = 0
        self.batch = 256

    def front_load(self, pkt):
        """Pre-fill every slot with ``pkt`` ahead of time."""
        for i in range(self.num_slots):
            self.txr.slots[i].buf[0 : len(pkt)] = pkt  # Fill the buffer with the packet
            self.txr.slots[i].len = len(pkt)  # Set the length of the packet in the slot

    def _avail(self):
        """Refresh ``tail`` from the ring and return the number of free slots."""
        self.tail = self.txr.tail
        return (self.tail - self.cur) % self.num_slots

    def space_left(self):
        """Reserve up to ``batch`` free slots by advancing ``cur``; return how many.

        Returns 0 when the ring is full. (Previously a stale ``tail`` was used
        and ``num_slots - free`` was returned, i.e. the occupied count.)
        """
        spcn = min(self._avail(), self.batch)
        self.cur = (self.cur + spcn) % self.num_slots
        return spcn

    def push(self):
        """Claim one more slot if one is free, then publish ``cur``/``head`` to the ring.

        The slot claimed here is not included in ``space_left``'s return value.
        """
        if self._avail() > 0:
            self.cur = (self.cur + 1) % self.num_slots
        self.txr.cur = self.txr.head = self.cur

    def pop(self):
        """Pop the next available slot from the ring buffer."""
        tl = self.tail
        self.tail = (self.tail + 1) % self.num_slots  # Circular increment of the tail pointer
        self.txr.tail = self.tail  # Update the txr's tail pointer
        return tl  # Return the slot index that was popped

    def sync(self, nm):
        """Sync the Netmap transmit ring to transmit the packets."""
        nm.txsync()


def main():
    """Open the netmap interface and transmit until Ctrl-C or a poll timeout.

    Prints the packet count and average rate, and always closes the netmap
    handle, even if registration or transmission raises.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", help="Interface to register with netmap", default="vale0:0"
    )
    args = parser.parse_args()

    # Build the packet
    pkt = build_packet()

    print(f"Opening interface {args.interface}")

    # Open the netmap device and register the interface
    nm = netmap.Netmap()
    nm.open()
    try:
        nfd = nm.getfd()
        nm.if_name = args.interface
        nm.register()
        time.sleep(1)

        # Get the first transmit ring and pre-fill its buffers
        txr = nm.transmit_rings[0]
        num_slots = txr.num_slots
        ring_buffer = RingBuffer(txr, num_slots)
        ring_buffer.front_load(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        cnt = 0  # Packet counter
        poller = select.poll()
        poller.register(nfd, select.POLLOUT)  # Monitor for available transmit slots
        t_start = time.time()
        try:
            while True:
                ready_list = poller.poll(2)  # Wait for available slots to transmit
                if len(ready_list) == 0:
                    print("Timeout occurred")
                    break

                n = ring_buffer.space_left()
                ring_buffer.push()
                ring_buffer.sync(nm)
                cnt += n
        except KeyboardInterrupt:
            pass

        # Calculate transmission rate (avoid dividing by a zero interval)
        elapsed = time.time() - t_start
        rate = 0.001 * cnt / elapsed if elapsed > 0 else 0.0
        unit = "K"
        if rate > 1000:
            rate /= 1000.0
            unit = "M"

        print(f"\nPackets sent: {cnt}, Avg rate {rate:6.3f} {unit}pps")
    finally:
        # Close the netmap interface on every exit path
        nm.close()


if __name__ == "__main__":
    main()

Performance has taken a hit, but this will all be rewritten in a lower-level language anyway, so this is just a rough draft. The output:

python git:(master) ✗ python3 tx2.py
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Packets sent: 130473419, Avg rate 20.103 Mpps

On the receiving end:

Waiting for a packet to come
Received a packet with len 60
ffffffffffff000000000000080000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

The original code does 30 Mpps but is completely procedural, so I'm fine taking the hit:

try:
    cur = txr.cur
    while 1:
        ready_list = poller.poll(2)
        if len(ready_list) == 0:
            print("Timeout occurred")
            break;
        n = txr.tail - cur  # avail
        if n < 0:
            n += num_slots
        if n > batch:
            n = batch
        cur += n
        if cur >= num_slots:
            cur -= num_slots
        txr.cur = txr.head = cur # lazy update txr.cur and txr.head
        nm.txsync()
        cnt += n

So it is a little ho-hum in terms of making a massive improvement, but an A+ for organization and for generating new ideas to work with. This is talking to C, and if you actually ran the code, there are some Python-isms that would make anyone wonder why you would feature them in the code!

@wolfspider
Copy link
Author

So anyhow, run it through the wash again and- now we have something better.

import struct
import time
import select
import argparse
import netmap
from array import array
from typing import Optional


def build_packet() -> bytes:
    """Return the 60-byte test frame.

    The frame has a broadcast destination MAC, a zero source MAC, the IPv4
    EtherType and 46 zero payload bytes.
    """
    dst_mac = b"\xff" * 6
    src_mac = bytes(6)
    ethertype = 0x0800  # IPv4
    payload = bytes(46)
    return struct.pack("!6s6sH46s", dst_mac, src_mac, ethertype, payload)


class RingBuffer:
    """Wrapper over a netmap transmit ring ``txr``.

    Userspace owns slots ``cur .. tail - 1`` (mod ``num_slots``). ``tail`` is
    re-read from ``txr`` whenever free space is computed because the ring's
    tail moves on each txsync. Index arithmetic uses ``%`` so rings whose size
    is not a power of two wrap correctly.
    """

    __slots__ = (
        "txr",
        "num_slots",
        "cur",
        "tail",
        "head",
        "cnt",
        "batch",
        "_batch_mask",
    )

    def __init__(self, txr, num_slots: int):
        """Capture the ring's current cur/tail/head and set the default batch size."""
        self.txr = txr
        self.num_slots = num_slots
        self.cur = txr.cur
        self.tail = txr.tail
        self.head = txr.head
        self.cnt = 0
        self.batch = 256  # Max slots reserved per space_left() call
        self._batch_mask = self.batch - 1  # Retained for compatibility; unused

    def front_load(self, pkt: bytes) -> None:
        """Copy ``pkt`` into every slot and set each slot's length."""
        pkt_view = memoryview(pkt)
        pkt_len = len(pkt)
        for slot in self.txr.slots[: self.num_slots]:
            slot.buf[0:pkt_len] = pkt_view
            slot.len = pkt_len

    def _avail(self) -> int:
        """Refresh ``tail`` from the ring and return the number of free slots."""
        self.tail = self.txr.tail
        return (self.tail - self.cur) % self.num_slots

    def space_left(self) -> int:
        """Reserve up to ``batch`` free slots by advancing ``cur``; return the count.

        Returns 0 when the ring is full.
        """
        spcn = min(self._avail(), self.batch)
        self.cur = (self.cur + spcn) % self.num_slots
        return spcn

    def push(self) -> None:
        """Claim one more slot if one is free, then publish ``cur``/``head``.

        The slot claimed here is not counted by ``space_left``.
        """
        if self._avail() > 0:
            self.cur = (self.cur + 1) % self.num_slots
        self.txr.cur = self.txr.head = self.cur

    def pop(self) -> int:
        """Return the current tail index and advance ``tail`` (also on ``txr``)."""
        tl = self.tail
        self.tail = (self.tail + 1) % self.num_slots
        self.txr.tail = self.tail
        return tl

    def sync(self, nm: "netmap.Netmap") -> None:
        """Sync the transmit ring."""
        nm.txsync()


def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open and register ``interface`` with netmap.

    Returns:
        The ``(netmap handle, file descriptor)`` pair.

    Raises:
        RuntimeError: if opening or registering fails; the original exception
            is chained as ``__cause__``.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Allow interface to initialize
        time.sleep(0.1)
        return nm, nm.getfd()
    except Exception as e:
        try:
            nm.close()
        except Exception:
            # Best-effort cleanup; a failure here must not hide the root cause
            # reported below.
            pass
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e


def main():
    """Transmit packets on a netmap interface until Ctrl-C, a poll timeout or an error.

    Always prints the packet count and average rate, and closes the netmap
    handle if one was opened.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Batch size for packet transmission (power of 2)",
    )
    args = parser.parse_args()

    # Ensure batch size is a power of 2
    batch_size = args.batch_size
    if batch_size & (batch_size - 1) != 0:
        batch_size = 1 << (batch_size - 1).bit_length()
        print(f"Adjusting batch size to nearest power of 2: {batch_size}")

    pkt = build_packet()
    print(f"Opening interface {args.interface}")

    # Bound before the try so the finally clause works even when setup fails
    # (previously it raised NameError on t_start/cnt/nm).
    nm = None
    cnt = 0
    t_start = time.monotonic()

    try:
        nm, nfd = setup_netmap(args.interface)
        txr = nm.transmit_rings[0]
        num_slots = txr.num_slots

        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(txr, num_slots)
        ring_buffer.batch = batch_size
        ring_buffer.front_load(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        poller = select.poll()
        poller.register(nfd, select.POLLOUT)

        t_start = time.monotonic()  # Measure transmission only, not setup

        while True:
            if not poller.poll(2):
                print("Timeout occurred")
                break

            n = ring_buffer.space_left()
            ring_buffer.push()
            ring_buffer.sync(nm)
            cnt += n

    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
    except Exception as e:
        print(f"\nError during transmission: {e}")
    finally:
        duration = time.monotonic() - t_start

        # Calculate rates (guard against a zero-length interval)
        rate = cnt / (duration * 1000) if duration > 0 else 0.0
        unit = "K"
        if rate > 1000:
            rate /= 1000
            unit = "M"

        print(f"\nPackets sent: [{cnt:,}], Duration: {duration:.2f}s")
        print(f"Average rate: [{rate:,.3f}] {unit}pps")

        if nm is not None:
            nm.close()


if __name__ == "__main__":
    main()

@wolfspider
Copy link
Author

wolfspider commented Nov 28, 2024

This still isn't ring buffer-y enough and after working with Netmap even more I think we may need to review some more examples to get the Python code into even better shape. The concern here is that in a real setting packets will be generated on the fly and even though we have our methods defined it is still just pushing the buffer through. I have a hazy idea about seeing something more like this in the examples somewhere so we will have to go searching through them to find something adequate.

@wolfspider
Copy link
Author

wolfspider commented Nov 28, 2024

Alright, after going back through it, we have something formally verified, speed is back up to where it was before, and packets are being generated ad hoc.

➜  python git:(master) ✗ python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user

Packets sent: [159,631,103], Duration: 4.77s
Average rate: [33.465] Mpps

Comparison with the old code:

➜  python git:(master) ✗ python3.10 tx.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Packets sent: 160381952, Avg rate 34.695 Mpps

This should be the truly final version of the Python code:

import struct
import time
import select
import argparse
import netmap


def build_packet() -> bytes:
    """Return the 60-byte frame sent on every slot.

    Layout (network byte order): 6-byte broadcast destination MAC, 6-byte zero
    source MAC, 2-byte EtherType 0x0800 (IPv4), 46 zero payload bytes.
    """
    frame = struct.Struct("!6s6sH46s")
    return frame.pack(bytes([0xFF] * 6), bytes(6), 0x0800, bytes(46))


class RingBuffer:
    """Netmap transmit ring combined with the verified ring-buffer model.

    Two views share ``txr.slots``:

    * the model (``first``, ``length``, ``push``/``push_end``/``push_cont``/``pop``),
      which writes packet bytes into slots;
    * the netmap cursor (``cur``/``tail``), which reserves free slots and
      publishes them with ``transmit``.

    Userspace owns slots ``cur .. tail - 1`` (mod ``num_slots``). ``tail`` is
    re-read from ``txr`` in ``space_left`` because the ring's tail moves on
    each txsync.
    """

    __slots__ = (
        "txr",
        "num_slots",
        "cur",
        "tail",
        "head",
        "cnt",
        "length",
        "first",
        "batch",
    )

    def __init__(self, txr, num_slots: int):
        """Capture the ring cursors and start with an empty model (first=0, length=0)."""
        self.txr = txr
        self.num_slots = num_slots
        self.cur = txr.cur
        self.tail = txr.tail
        self.head = txr.head
        self.cnt = 0
        self.length = 0
        self.first = 0
        self.batch = 256  # Max slots reserved per space_left() call

    def init(self, pkt: bytes) -> None:
        """
        Pre-fill the buffer by repeatedly calling `push_cont`.
        Stops when all slots are filled.
        """
        pkt_view = memoryview(pkt)

        # Call `push_cont` to fill the buffer until it is full
        while self.length < self.num_slots:
            self.push_cont(pkt_view)

    def next(self, i):
        """Get the next index in a circular manner."""
        if i == self.num_slots - 1:
            return 0
        else:
            return i + 1

    def prev(self, i):
        """Get the previous index in a circular manner."""
        if i > 0:
            return i - 1
        else:
            return self.num_slots - 1

    def one_past_last(self):
        """Get the index one past the last element (``first`` when full)."""
        if self.length == self.num_slots:
            return self.first
        elif self.first >= self.num_slots - self.length:
            return self.length - (self.num_slots - self.first)
        else:
            return self.first + self.length

    def space_left(self) -> int:
        """Reserve up to ``batch`` free slots by advancing ``cur``; return the count.

        Free slots are ``(tail - cur) mod num_slots`` with ``tail`` freshly read
        from the ring, mirroring the procedural loop. Returns 0 when the ring is
        full. (Previously the constructor-time ``tail`` was used and
        ``num_slots - free``, the occupied count, was returned.)
        """
        self.tail = self.txr.tail
        if self.tail >= self.cur:
            n = self.tail - self.cur
        else:
            n = self.num_slots - (self.cur - self.tail)

        spcn = min(n, self.batch)

        # Update self.cur to reflect reserved space
        self.cur += spcn
        if self.cur >= self.num_slots:
            self.cur -= self.num_slots

        return spcn

    def transmit(self) -> None:
        """Publish the reserved slots by setting the ring's cur and head."""
        self.txr.cur = self.txr.head = self.cur

    def push(self, e):
        """Push an element to the start of the buffer."""
        dest_slot = self.prev(self.first)
        self.txr.slots[dest_slot].buf[: len(e)] = e
        self.txr.slots[dest_slot].len = len(e)
        self.first = dest_slot
        self.length = min(self.length + 1, self.num_slots)

    def push_end(self, e):
        """Write ``e`` one past the last element and advance ``first`` (overwrite path)."""
        dest_slot = self.one_past_last()
        self.txr.slots[dest_slot].buf[: len(e)] = e
        self.txr.slots[dest_slot].len = len(e)
        self.first = self.next(self.first)

    def push_cont(self, e):
        """Push element `e` with wraparound."""
        if self.length < self.num_slots:
            self.push(e)
        else:
            self.push_end(e)

    def pop(self):
        """Pop an element from the start of the buffer.

        Raises:
            IndexError: if the buffer is empty.
        """
        if self.length == 0:
            raise IndexError("Pop from empty buffer")
        src_slot = self.txr.slots[self.first]
        pkt = bytes(src_slot.buf[: src_slot.len])
        self.first = self.next(self.first)
        self.length -= 1
        return pkt

    def sync(self, nm: "netmap.Netmap") -> None:
        """Sync the transmit ring."""
        nm.txsync()


def setup_netmap(interface: str) -> tuple[netmap.Netmap, int]:
    """Open a netmap handle and register ``interface`` with it.

    Args:
        interface: netmap port name to register (``main`` passes the
            ``-i/--interface`` argument, default ``"vale0:0"``).

    Returns:
        ``(nm, fd)``: the registered ``netmap.Netmap`` handle and the file
        descriptor from ``nm.getfd()``, which ``main`` registers with
        ``select.poll``.

    Raises:
        RuntimeError: if opening or registering fails. The original exception
            is chained as ``__cause__``. The handle is closed (best effort)
            before raising.
    """
    nm = netmap.Netmap()
    try:
        nm.open()
        nm.if_name = interface
        nm.register()
        # Give the interface a moment to initialize before the caller transmits.
        time.sleep(0.1)
        return nm, nm.getfd()
    except Exception as e:
        try:
            nm.close()
        except Exception:
            # Best-effort cleanup: a failure here must not hide the original
            # setup error, which is reported below.
            pass
        raise RuntimeError(f"Failed to setup netmap interface: {e}") from e


def main():
    """Command-line entry point: transmit packets on a netmap interface.

    Parses ``-i/--interface`` and builds the packet with ``build_packet()``.
    It then registers the interface and pre-fills the first TX ring via
    ``RingBuffer.init``. After that it loops:

    1. poll for writability;
    2. reserve slots with ``space_left()``;
    3. publish them with ``transmit()``;
    4. flush with ``sync()``.

    The loop ends on Ctrl-C, a poll timeout, or an error. If transmission
    started, the packet count and average rate are printed. The interface is
    closed in all cases where it was opened.
    """
    parser = argparse.ArgumentParser(
        description="High-performance packet generator using netmap API",
        epilog="Press Ctrl-C to stop",
    )
    parser.add_argument(
        "-i", "--interface", default="vale0:0", help="Interface to register with netmap"
    )
    args = parser.parse_args()

    pkt = build_packet()
    print(f"Opening interface {args.interface}")

    # Bind everything the ``finally`` clause reads before entering ``try``.
    # A failure during setup (e.g. setup_netmap raising, or Ctrl-C) is then
    # reported instead of being masked by an UnboundLocalError.
    nm = None
    cnt = 0
    t_start = None

    try:
        nm, nfd = setup_netmap(args.interface)
        txr = nm.transmit_rings[0]
        num_slots = txr.num_slots

        # Initialize and pre-fill ring buffer
        ring_buffer = RingBuffer(txr, num_slots)
        ring_buffer.init(pkt)

        print("Starting transmission, press Ctrl-C to stop")

        # Wait for the netmap fd to become writable before each batch.
        poller = select.poll()
        poller.register(nfd, select.POLLOUT)

        t_start = time.monotonic()  # monotonic: immune to wall-clock changes

        while True:
            if not poller.poll(2):  # timeout is in milliseconds
                print("Timeout occurred")
                break

            n = ring_buffer.space_left()
            ring_buffer.transmit()
            ring_buffer.sync(nm)
            cnt += n

    except KeyboardInterrupt:
        print("\nTransmission interrupted by user")
    except Exception as e:
        print(f"\nError during transmission: {e}")
    finally:
        # Statistics only make sense once the transmit loop was reached.
        if t_start is not None:
            duration = time.monotonic() - t_start

            # Rate in thousands of packets per second, switched to millions
            # when large; guard against a zero-length interval.
            rate = cnt / (duration * 1000) if duration > 0 else 0.0
            unit = "K"
            if rate > 1000:
                rate /= 1000
                unit = "M"

            print(f"\nPackets sent: [{cnt:,}], Duration: {duration:.2f}s")
            print(f"Average rate: [{rate:,.3f}] {unit}pps")

        # setup_netmap closes the handle itself on failure, so only close a
        # handle that was successfully returned.
        if nm is not None:
            nm.close()


if __name__ == "__main__":
    # Run the packet generator only when executed as a script, not on import.
    main()

@wolfspider
Copy link
Author

The benefit of this approach is that it is easy to extend to sending an arbitrary array of bytes, where the number of slots may impose an upper bound before the payload is sent. That functionality could easily be dropped in here. Now we can finally move on to part 2: remaking this at a lower level.

A couple of edits to the code also brought the speed up to this:

python git:(master) ✗ python3.10 tx3.py -i vale0:0
Opening interface vale0:0
Starting transmission, press Ctrl-C to stop
^C
Transmission interrupted by user

Packets sent: [201,336,063], Duration: 5.84s
Average rate: [34.456] Mpps

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment