Skip to content

Instantly share code, notes, and snippets.

@mfm24
Last active April 7, 2025 14:13
Show Gist options
  • Save mfm24/54185eb596813aa0fbebc6defa2df6d6 to your computer and use it in GitHub Desktop.
Save mfm24/54185eb596813aa0fbebc6defa2df6d6 to your computer and use it in GitHub Desktop.
SFB playing
import hashlib
import math
import random
from functools import lru_cache
class SFB:
    """Layered counting filter with per-bucket drop probabilities.

    The structure looks like a Stochastic Fair Blue (SFB) model: every value
    is hashed into one bucket in each of ``layers`` layers.  Each bucket has a
    counter (``self.buckets``) and a drop probability (``self.probs``).

    ``add`` returns ``None`` when the value was accepted, otherwise one of the
    ``drop_*`` class constants naming the reason it was rejected.
    """

    drop_bucket_full = "drop_bucket_full"
    drop_stochastic = "drop_stochastic"
    drop_flow_blocked = "drop_flow_blocked"

    def __init__(self, layers, buckets_per_layer, pdiff=0.1,
                 hashfunc=lambda s: hashlib.sha512(str(s).encode()).digest(),
                 threshold=16):
        """Create an empty filter.

        Args:
            layers: number of independent bucket arrays.
            buckets_per_layer: number of buckets in each layer (>= 1).
            pdiff: step by which a bucket's drop probability is raised or
                lowered in ``update_probs``.
            hashfunc: callable mapping a value to a bytes-like digest of at
                least ``bytes_per_layer * layers`` bytes.
            threshold: bucket count at which ``add`` treats a bucket as full.

        Raises:
            ValueError: if ``buckets_per_layer`` is smaller than 1.
        """
        if buckets_per_layer < 1:
            raise ValueError("buckets_per_layer must be at least 1")
        self.layers = layers
        self.buckets_per_layer = buckets_per_layer
        self.pdiff = pdiff
        self.hash = hashfunc
        self.threshold = threshold
        # Whole bytes needed to express the largest bucket index.  Integer
        # arithmetic avoids log2(0) when there is a single bucket.
        self.bytes_per_layer = max(
            1, ((buckets_per_layer - 1).bit_length() + 7) // 8
        )
        # Per-instance cache of the per-layer indices only.  Caching the
        # bucket/prob lists themselves would leave stale references behind
        # after reset(), and a class-level cache would keep instances alive.
        self._layer_indices = lru_cache(maxsize=128, typed=True)(
            self._compute_layer_indices
        )
        self.reset()

    def reset(self):
        """Zero every counter and every drop probability."""
        self.reset_buckets()
        self.probs = [[0.0] * self.buckets_per_layer for _ in range(self.layers)]

    def reset_buckets(self):
        """Zero every counter, leaving the drop probabilities untouched."""
        self.buckets = [[0] * self.buckets_per_layer for _ in range(self.layers)]

    def add(self, hashable):
        """Try to account for one occurrence of ``hashable``.

        Returns:
            ``None`` if the value was counted, otherwise the ``drop_*``
            constant describing why it was rejected.
        """
        if self.update_probs(hashable, self.threshold) == SFB.drop_bucket_full:
            return SFB.drop_bucket_full
        p = self.min_prob(hashable)
        if p >= 1.0:
            return SFB.drop_flow_blocked
        # Drop stochastically; random() is only consulted when p > 0.
        if p > 0.0 and random.random() < p:
            return SFB.drop_stochastic
        for counts, _, idx in self.buckets_probs_for_value(hashable):
            counts[idx] += 1
        return None

    def remove(self, hashable):
        """Decrement the counter of every bucket ``hashable`` maps to."""
        for counts, _, idx in self.buckets_probs_for_value(hashable):
            counts[idx] -= 1

    def update_probs(self, v, threshold):
        """Adjust the drop probabilities of the buckets ``v`` maps to.

        A bucket whose count is at least ``threshold`` has its probability
        raised by ``pdiff`` (capped at 1.0); an empty bucket has it lowered by
        ``pdiff`` (floored at 0.0).

        Returns:
            ``SFB.drop_bucket_full`` if any bucket was at or above
            ``threshold``, otherwise ``None``.
        """
        drop = False
        for counts, probs, idx in self.buckets_probs_for_value(v):
            if counts[idx] >= threshold:
                drop = True
                probs[idx] = min(probs[idx] + self.pdiff, 1.0)
            elif counts[idx] == 0:
                probs[idx] = max(probs[idx] - self.pdiff, 0.0)
        if drop:
            return SFB.drop_bucket_full
        return None

    def _compute_layer_indices(self, v):
        """Hash ``v`` and return its bucket index in every layer as a tuple."""
        h = self.hash(v)
        if len(h) < self.bytes_per_layer * self.layers:
            raise ValueError(
                f"Not enough bytes in hash!, need {self.bytes_per_layer} * {self.layers}, got {len(h)}"
            )
        return tuple(
            idx
            for _, _, idx in self.buckets_probs_for_hash(h, self.buckets, self.probs)
        )

    def buckets_probs_for_value(self, v):
        """Return ``[(bucket_array, prob_array, index), ...]``, one per layer.

        The arrays are always the instance's current ones, so the result stays
        valid after ``reset()`` / ``reset_buckets()``.

        Raises:
            ValueError: if the hash of ``v`` is too short for all layers.
        """
        return list(zip(self.buckets, self.probs, self._layer_indices(v)))

    def buckets_probs_for_hash(self, h, remaining_layer_buckets, remaining_layer_probs):
        """Yield bucket_array, prob_array, index for each layer of hash ``h``.

        Layer ``k`` takes its index from bytes
        ``[k * bytes_per_layer, (k + 1) * bytes_per_layer)`` of ``h``, reduced
        modulo ``buckets_per_layer``.
        """
        width = self.bytes_per_layer
        layers = zip(remaining_layer_buckets, remaining_layer_probs)
        for layer, (counts, probs) in enumerate(layers):
            chunk = h[layer * width:(layer + 1) * width]
            yield counts, probs, self.bytes_to_int(chunk) % self.buckets_per_layer

    def min_prob(self, v):
        """Return the min of all probs for this value (0.0 with no layers)."""
        return min(
            (probs[idx] for _, probs, idx in self.buckets_probs_for_value(v)),
            default=0.0,
        )

    def __repr__(self):
        lines = ["SFB:"]
        lines.extend(
            f" Layer: {i} sum: {sum(layer)}" for i, layer in enumerate(self.buckets)
        )
        return "\n".join(lines) + str(self.probs)

    @staticmethod
    def bytes_to_int(b):
        """Interpret ``b`` as a little-endian unsigned integer."""
        return int.from_bytes(b, "little")
class SFBBackedByList(SFB):
    """
    SFB and a list. Has a take_oldest function to remove n items.

    ``self.back`` holds every accepted value in insertion order, so the
    filter's counters can be rolled back oldest-first.  ``len()`` of the
    object is the number of values currently held.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.back = []

    def add(self, v):
        """Add ``v`` to the filter; remember it only if it was not dropped.

        Returns whatever ``SFB.add`` returns (``None`` on success).
        """
        ret = super().add(v)
        if ret is None:  # no errors adding
            self.back.append(v)
        return ret

    def remove(self, v):
        """Remove one stored occurrence of ``v``.

        Raises:
            ValueError: if ``v`` is not currently stored.  The list is updated
                first so that a failed removal leaves the counters untouched.
        """
        self.back.remove(v)
        super().remove(v)

    def take_oldest(self, n):
        """Remove the ``n`` oldest stored values (all of them if fewer exist).

        A negative ``n`` is treated as 0; slicing with it directly would
        select from the wrong end of the list.
        """
        n = max(n, 0)
        to_remove, new_back = self.back[:n], self.back[n:]
        for x in to_remove:
            # Only the counters need updating here; the list is replaced in
            # one step below instead of paying list.remove() per item.
            super().remove(x)
        self.back = new_back

    def __len__(self):
        return len(self.back)
def test_sfb_steady(layers, buckets_per_layer, steady_len, burst_size, loops=100, threshold=16):
    """Measure the drop fraction of an SFB held at a steady fill level.

    Creates a SFB, populates it with ``steady_len`` distinct values, then
    ``loops`` times adds values until ``burst_size`` more are held and removes
    the ``burst_size`` oldest, keeping track of stats.  A value counts as
    failed when its least-loaded bucket is already at ``threshold`` or when
    ``add`` reports a drop.  Prints the resulting drop fraction.
    """
    s = SFBBackedByList(layers, buckets_per_layer)

    def min_count(value):
        # Smallest counter among the buckets `value` maps to.
        return min(
            (counts[idx] for counts, _, idx in s.buckets_probs_for_value(value)),
            default=0,
        )

    # Start below zero so the burst phase still has a defined starting point
    # when steady_len is 0 and the populate loop never runs.
    value = -1
    for value in range(steady_len * 2):
        if min_count(value) < threshold:
            s.add(value)
        if len(s) >= steady_len:
            break
    assert len(s) >= steady_len, "Unable to populate!"

    ok, failed = 0, 0
    for _ in range(loops):
        while len(s) < steady_len + burst_size:
            value += 1
            if min_count(value) < threshold and s.add(value) is None:
                ok += 1
            else:
                failed += 1
        s.take_oldest(burst_size)

    attempts = failed + ok
    # Guard the empty run (loops == 0 or burst_size == 0).
    print(f"Drop fraction = {failed / attempts if attempts else 0.0}")
# pkts = set() # mirrors sfb
def test_sfb(layers, buckets_per_layer, steady_len, steady_burst_size, hot_flows, loops=100):
    """Exercise an SFB with a mix of one-off values and repeated hot flows.

    Creates a SFB, populates ``steady_len`` values, then ``loops`` times adds
    a burst of ``steady_burst_size`` fresh values plus the hot flows and trims
    back to ``steady_len``, keeping track of stats.

    ``hot_flows`` should be a dict of val: count to be added each burst.
    We expect hot flows to be blocked and all others to pass, so a dropped
    fresh value counts as a false drop and an accepted hot value counts as a
    false allow.  Prints the fraction of such misclassifications.
    """
    s = SFBBackedByList(layers, buckets_per_layer)
    stats = dict(count=0, false_drops=0, false_allowed=0, pkt=0)

    def populate():
        # Fill with distinct integers, remembering the last one used so that
        # run() continues with fresh values.
        for i in range(steady_len * 2):
            s.add(i)
            stats['pkt'] = i
            if len(s) >= steady_len:
                break
        for val, counts in hot_flows.items():
            for _ in range(counts):
                s.add(val)
        assert len(s) >= steady_len, "Unable to populate!"

    def run():
        for _ in range(steady_burst_size):
            stats['pkt'] += 1
            stats['count'] += 1
            if s.add(stats['pkt']) is not None:
                stats['false_drops'] += 1
        for val, counts in hot_flows.items():
            for _ in range(counts):
                stats['count'] += 1
                if s.add(val) is None:
                    stats['false_allowed'] += 1
        s.take_oldest(len(s) - steady_len)

    populate()
    for _ in range(loops):
        run()

    misclassified = stats['false_drops'] + stats['false_allowed']
    # Nothing is attempted when loops == 0 (or the burst is empty); report 0.
    fraction_false = misclassified / stats['count'] if stats['count'] else 0.0
    print(f"Drop fraction = {fraction_false}, {stats}, {s}")
    for val, counts in hot_flows.items():
        # NOTE: s.add(val) here mutates the filter while reporting.
        print(f" hot {val}: {s.min_prob(val)} add: {s.add(val)}")
# pkts = set() # mirrors sfb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment