# SFB playing
# Gist: mfm24/54185eb596813aa0fbebc6defa2df6d6 (last active April 7, 2025 14:13).
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the
# file in an editor that reveals hidden Unicode characters.
import hashlib | |
import math | |
import random | |
from functools import lru_cache | |
class SFB:
    """Layered counting filter with a drop probability attached to every bucket.

    A value is hashed once; consecutive ``bytes_per_layer``-byte slices of the
    digest select one bucket in each layer.  Every bucket has a counter
    (``self.buckets``) and a drop probability (``self.probs``).
    NOTE(review): the name suggests "Stochastic Fair Blue" -- confirm.

    ``add`` returns ``None`` when the value was accepted, otherwise one of the
    ``drop_*`` class constants naming the reason it was rejected.
    """

    drop_bucket_full = "drop_bucket_full"
    drop_stochastic = "drop_stochastic"
    drop_flow_blocked = "drop_flow_blocked"

    def __init__(self, layers, buckets_per_layer, pdiff=0.1,
                 hashfunc=lambda s: hashlib.sha512(str(s).encode()).digest(),
                 threshold=16):
        """Create an empty filter.

        Args:
            layers: number of independent bucket arrays.
            buckets_per_layer: number of buckets in each layer (>= 1).
            pdiff: step by which a bucket's drop probability is raised or
                lowered in ``update_probs``.
            hashfunc: callable mapping a value to a bytes digest that is at
                least ``bytes_per_layer * layers`` bytes long.
            threshold: counter value at which ``add`` treats a bucket as full.

        Raises:
            ValueError: if ``buckets_per_layer`` is smaller than 1.
        """
        if buckets_per_layer < 1:
            raise ValueError("buckets_per_layer must be at least 1")
        self.layers = layers
        self.buckets_per_layer = buckets_per_layer
        self.pdiff = pdiff
        self.hash = hashfunc
        self.threshold = threshold
        # Whole bytes needed to address the largest bucket index.  Integer
        # arithmetic also covers the single-bucket case, where log2(0) would
        # raise a math domain error.
        self.bytes_per_layer = max(1, ((buckets_per_layer - 1).bit_length() + 7) // 8)
        # Cache only the per-layer indices, and only per instance: cached
        # references to the bucket arrays would go stale as soon as reset()
        # rebinds them, and a class-level cache would keep instances alive.
        self._indices_for = lru_cache(maxsize=128)(self._compute_indices)
        self.reset()

    def reset(self):
        """Zero all counters and all drop probabilities."""
        self.reset_buckets()
        self.probs = [[0.0] * self.buckets_per_layer for __ in range(self.layers)]

    def reset_buckets(self):
        """Zero all counters, leaving the drop probabilities untouched."""
        self.buckets = [[0] * self.buckets_per_layer for __ in range(self.layers)]

    def add(self, hashable):
        """Try to account for one occurrence of ``hashable``.

        Returns:
            ``None`` if the value was counted, otherwise the ``drop_*``
            constant describing why it was rejected.
        """
        if self.update_probs(hashable, self.threshold) is not None:
            return SFB.drop_bucket_full
        p = self.min_prob(hashable)
        if p >= 1.0:
            return SFB.drop_flow_blocked
        # random.random() is only consulted when there is a non-zero chance
        # of dropping.
        if p > 0.0 and random.random() < p:
            return SFB.drop_stochastic
        for array, _, i in self.buckets_probs_for_value(hashable):
            array[i] += 1
        return None

    def remove(self, hashable):
        """Decrement the counter of every bucket ``hashable`` maps to.

        No check is made that the value was previously added.
        """
        for array, _, i in self.buckets_probs_for_value(hashable):
            array[i] -= 1

    def update_probs(self, v, threshold):
        """Adjust the drop probabilities of the buckets ``v`` maps to.

        A bucket whose counter is at or above ``threshold`` has its
        probability raised by ``pdiff`` (capped at 1.0); an empty bucket has
        it lowered by ``pdiff`` (floored at 0.0).

        Returns:
            ``SFB.drop_bucket_full`` if any of the buckets was at or above
            ``threshold``, otherwise ``None``.
        """
        drop = False
        for buckets, ps, i in self.buckets_probs_for_value(v):
            if buckets[i] >= threshold:
                drop = True
                ps[i] = min(1.0, ps[i] + self.pdiff)
            elif buckets[i] == 0:
                ps[i] = max(0.0, ps[i] - self.pdiff)
        return SFB.drop_bucket_full if drop else None

    def _compute_indices(self, v):
        """Hash ``v`` and return its bucket index in each layer as a tuple."""
        h = self.hash(v)
        assert len(h) >= self.bytes_per_layer * self.layers, f"Not enough bytes in hash!, need {self.bytes_per_layer} * {self.layers}, got {len(h)}"
        step = self.bytes_per_layer
        return tuple(
            self.bytes_to_int(h[layer * step:(layer + 1) * step]) % self.buckets_per_layer
            for layer in range(self.layers)
        )

    def buckets_probs_for_value(self, v):
        """Return a list of (bucket_array, prob_array, index) for the given value.

        The arrays are looked up on every call, so the result always refers
        to the current ``self.buckets`` / ``self.probs``.
        """
        return [
            (self.buckets[layer], self.probs[layer], index)
            for layer, index in enumerate(self._indices_for(v))
        ]

    def buckets_probs_for_hash(self, h, remaining_layer_buckets, remaining_layer_probs):
        """Yield bucket_array, prob_array, index for each layer, given a digest.

        Layer ``k`` uses the ``k``-th ``bytes_per_layer``-byte slice of ``h``.
        """
        step = self.bytes_per_layer
        for layer, bucket_array in enumerate(remaining_layer_buckets):
            chunk = h[layer * step:(layer + 1) * step]
            yield (
                bucket_array,
                remaining_layer_probs[layer],
                self.bytes_to_int(chunk) % self.buckets_per_layer,
            )

    def min_prob(self, v):
        """Return the min of all probs for this value."""
        return min(ps[i] for _, ps, i in self.buckets_probs_for_value(v))

    def __repr__(self):
        def _s():
            yield "SFB:"
            for i, layer in enumerate(self.buckets):
                yield f" Layer: {i} sum: {sum(layer)}"
        return "\n".join(_s()) + str(self.probs)

    @staticmethod
    def bytes_to_int(b):
        """Interpret ``b`` as a little-endian unsigned integer."""
        return int.from_bytes(b, "little")
class SFBBackedByList(SFB):
    """
    SFB and a list. Has a take_oldest function to remove n items

    ``self.back`` records every accepted value in insertion order, so
    ``len(instance)`` is the number of values currently held.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``SFB`` and start with an empty list."""
        super().__init__(*args, **kwargs)
        self.back = []

    def add(self, v):
        """Add ``v``; remember it in the list only if the filter accepted it.

        Returns whatever ``SFB.add`` returned (``None`` on success).
        """
        ret = super().add(v)
        if ret is None:  # no errors adding
            self.back.append(v)
        return ret

    def remove(self, v):
        """Remove the oldest occurrence of ``v`` from the list and the filter.

        Raises:
            ValueError: if ``v`` is not currently held.
        """
        # Drop it from the list first: if v is absent this raises before the
        # counters are touched, so the filter stays consistent with the list.
        self.back.remove(v)
        super().remove(v)

    def take_oldest(self, n):
        """Evict the ``n`` oldest values; a non-positive ``n`` evicts nothing."""
        if n <= 0:
            # A negative n would otherwise slice from the end of the list and
            # evict almost everything.
            return
        to_remove, self.back = self.back[:n], self.back[n:]
        for x in to_remove:
            # The list has already been trimmed; only the counters need updating.
            super().remove(x)

    def __len__(self):
        return len(self.back)
def test_sfb_steady(layers, buckets_per_layer, steady_len, burst_size, loops=100, threshold=16):
    """Measure the drop fraction of an SFB held at a steady fill level.

    Create a SFB, populate steady_len, then add and remove burst size,
    keeping track of stats.  Every packet is a fresh integer, so each one
    belongs to its own flow.

    A packet counts as "ok" when the smallest counter among its buckets is
    below ``threshold`` and the filter accepts it; otherwise it counts as
    "failed".  The drop fraction is printed after every loop and once more
    at the end.

    NOTE(review): the refill loop keeps generating packets until the target
    length is reached, so it does not terminate if the filter can never hold
    ``steady_len + burst_size`` values.
    """
    s = SFBBackedByList(layers, buckets_per_layer)

    def min_count(value):
        # Smallest counter among the buckets this value maps to.
        return min(buckets[idx] for buckets, _, idx in s.buckets_probs_for_value(value))

    def drop_fraction():
        total = ok + failed
        return failed / total if total else 0.0

    i = -1  # keeps `i` defined for the burst phase when steady_len is 0
    for i in range(steady_len * 2):
        if min_count(i) < threshold:
            s.add(i)
        if len(s) >= steady_len:
            break
    assert len(s) >= steady_len, "Unable to populate!"
    ok, failed = 0, 0
    for _ in range(loops):
        while len(s) < steady_len + burst_size:
            i += 1
            # add() can still reject a packet that passed the counter check,
            # so only count it as ok when it was really stored.
            if min_count(i) < threshold and s.add(i) is None:
                ok += 1
            else:
                failed += 1
        s.take_oldest(burst_size)
        print(f"Drop fraction = {drop_fraction()}")
    print(f"Drop fraction = {drop_fraction()}")
    # pkts = set() # mirrors sfb
def test_sfb(layers, buckets_per_layer, steady_len, steady_burst_size, hot_flows, loops=100):
    """Exercise an SFB with steady background traffic plus repeated hot flows.

    Create a SFB, populate steady_len, then add and remove burst size +
    hot_flows, keeping track of stats.  ``hot_flows`` should be a dict of
    val: count to be added each burst.  We expect hot flows to be blocked
    and all others to pass, so a dropped background packet is a "false drop"
    and an accepted hot-flow packet is a "false allowed".

    Prints the fraction of misclassified packets, the raw stats and the
    filter, then one line per hot flow.  The last step calls ``add`` on each
    hot flow once more, so it changes the filter.
    """
    s = SFBBackedByList(layers, buckets_per_layer)
    stats = dict(count=0, false_drops=0, false_allowed=0, pkt=0)

    def populate():
        # Fill with distinct background packets, then seed the hot flows.
        for i in range(steady_len * 2):
            s.add(i)
            stats['pkt'] = i
            if len(s) >= steady_len:
                break
        for val, counts in hot_flows.items():
            for _ in range(counts):
                s.add(val)
        assert len(s) >= steady_len, "Unable to populate!"

    def run():
        # One burst: fresh background packets first, then the hot flows.
        for _ in range(steady_burst_size):
            stats['pkt'] += 1
            stats['count'] += 1
            if s.add(stats['pkt']) is not None:
                stats['false_drops'] += 1
        for val, counts in hot_flows.items():
            for _ in range(counts):
                stats['count'] += 1
                if s.add(val) is None:
                    stats['false_allowed'] += 1
        # Trim back down to the steady-state length.
        s.take_oldest(len(s) - steady_len)

    populate()
    for _ in range(loops):
        run()
    misclassified = stats['false_drops'] + stats['false_allowed']
    # Nothing is counted when loops is 0 or no packets are sent per burst.
    fraction_false = misclassified / stats['count'] if stats['count'] else 0.0
    # print(f"Drop fraction = {fraction_false}")
    print(f"Drop fraction = {fraction_false}, {stats}, {s}")
    for val in hot_flows:
        print(f" hot {val}: {s.min_prob(val)} add: {s.add(val)}")
    # pkts = set() # mirrors sfb
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment