@raphaelsc
Created August 30, 2024 16:33
#!/usr/bin/env python3
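"""Summarize Scylla SSTables per shard.

Scans a table directory for *-Scylla.db components, groups each shard's
SSTables by level, and reports size, partition, row, and tombstone counts,
including an estimate of how many tombstones are already expired.

Usage: <script> /path/to/table shards [gc_grace_seconds]
"""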
import os
from collections import defaultdict
import struct
import sys
import re
import uuid
import time
class Stream:
size = {
'c': 1, # char
'b': 1, # signed char (int8)
'B': 1, # unsigned char (uint8)
'?': 1, # bool
'h': 2, # short (int16)
'H': 2, # unsigned short (uint16)
'i': 4, # int (int32)
'I': 4, # unsigned int (uint32)
        'l': 4, # long (int32)
        'L': 4, # unsigned long (uint32)
'q': 8, # long long (int64)
'Q': 8, # unsigned long long (uint64)
'f': 4, # float
'd': 8, # double
}
def __init__(self, data, offset=0):
self.data = data
self.offset = offset
def skip(self, n):
self.offset += n
def read(self, typ):
try:
(val,) = struct.unpack_from('>{}'.format(typ), self.data, self.offset)
except Exception as e:
            raise ValueError('Failed to read type `{}\' from stream at offset {}: {}'.format(typ, self.offset, e))
self.offset += self.size[typ]
return val
def bool(self):
return self.read('?')
def int8(self):
return self.read('b')
def uint8(self):
return self.read('B')
def int16(self):
return self.read('h')
def uint16(self):
return self.read('H')
def int32(self):
return self.read('i')
def uint32(self):
return self.read('I')
def int64(self):
return self.read('q')
def uint64(self):
return self.read('Q')
def float(self):
return self.read('f')
def double(self):
return self.read('d')
    def bytes(self, len_type):
        n = len_type(self)
        val = self.data[self.offset:self.offset + n]
        self.offset += n
        return val
def bytes16(self):
return self.bytes(Stream.uint16)
def bytes32(self):
return self.bytes(Stream.uint32)
def string(self, len_type):
buf = self.bytes(len_type)
try:
return buf.decode('utf-8')
except UnicodeDecodeError:
            # FIXME why are some strings unintelligible?
            return 'INVALID(size={}, bytes={})'.format(len(buf), buf.hex())
def string16(self):
return self.string(Stream.uint16)
def string32(self):
return self.string(Stream.uint32)
def map16(self, keytype=string16, valuetype=string16):
        return {keytype(self): valuetype(self) for _ in range(self.int16())}
def map32(self, keytype=string16, valuetype=string16):
return {keytype(self): valuetype(self) for _ in range(self.int32())}
def array32(self, valuetype):
return [valuetype(self) for _ in range(self.int32())]
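    # Example: array32 reads a big-endian 32-bit element count, then that many
    # elements; stream.array32(Stream.uint32) over
    # b'\x00\x00\x00\x02\x00\x00\x00\x0a\x00\x00\x00\x14' yields [10, 20].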
    def tuple(self, *member_types):
        # use the builtin tuple() so members are read eagerly, in stream order
        return tuple(mt(self) for mt in member_types)
def struct(self, *members):
return {member_name: member_type(self) for member_name, member_type in members}
def set_of_tagged_union(self, tag_type, *members):
members_by_keys = {k: (n, t) for k, n, t in members}
value = {}
for _ in range(tag_type(self)):
key = tag_type(self)
size = self.uint32()
if key in members_by_keys:
name, typ = members_by_keys[key]
value[name] = typ(self)
#TODO: check we haven't read more than size
else:
self.skip(size)
return value
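    # On-disk layout handled above: a 32-bit element count, then for each
    # element a 32-bit tag and a 32-bit serialized size; the size lets
    # readers skip tags they do not understand.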
def enum32(self, *values):
d = {v: n for v, n in values}
return d[self.uint32()]
@staticmethod
def instantiate(template_type, *args):
        def instantiated_type(stream):
            return template_type(stream, *args)
        return instantiated_type
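# Illustrative sketch: Stream.instantiate partially applies a parser so the
# combinators above can be nested; field order follows the stream, e.g.:
#   pair = Stream.instantiate(Stream.tuple, Stream.uint32, Stream.uint32)
#   pair(Stream(b'\x00\x00\x00\x01\x00\x00\x00\x02'))  # -> (1, 2)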
def parse(stream, schema):
return {name: typ(stream) for name, typ in schema}
def scylla_parse(data, sstable_format):
disk_token_bound = Stream.instantiate(
Stream.struct,
('exclusive', Stream.uint8),
('token', Stream.string16),
)
disk_token_range = Stream.instantiate(
Stream.struct,
('left', disk_token_bound),
('right', disk_token_bound),
)
sharding_metadata = Stream.instantiate(
Stream.struct,
('token_ranges', Stream.instantiate(Stream.array32, disk_token_range)),
)
sstable_enabled_features = Stream.instantiate(
Stream.struct,
('enabled_features', Stream.uint64),
)
extension_attributes = Stream.instantiate(
Stream.map32, Stream.string32, Stream.string32,
)
UUID = Stream.instantiate(
Stream.struct,
('msb', Stream.uint64),
('lsb', Stream.uint64),
)
run_identifier = Stream.instantiate(
Stream.struct,
('id', UUID),
)
large_data_type = Stream.instantiate(
Stream.enum32,
(1, "partition_size"),
(2, "row_size"),
(3, "cell_size"),
(4, "rows_in_partition"),
)
large_data_stats_entry = Stream.instantiate(
Stream.struct,
('max_value', Stream.uint64),
('threshold', Stream.uint64),
('above_threshold', Stream.uint32),
)
large_data_stats = Stream.instantiate(
Stream.map32, large_data_type, large_data_stats_entry,
)
scylla_component_data = Stream.instantiate(
Stream.set_of_tagged_union,
Stream.uint32,
(1, "sharding", sharding_metadata),
(2, "features", sstable_enabled_features),
(3, "extension_attributes", extension_attributes),
(4, "run_identifier", run_identifier),
(5, "large_data_stats", large_data_stats),
(6, "sstable_origin", Stream.string32),
)
schema = (
('data', scylla_component_data),
)
return parse(Stream(data), schema)
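# The parsed result mirrors the schema above, e.g. (values elided):
#   {'data': {'run_identifier': {'id': {'msb': ..., 'lsb': ...}},
#             'features': {'enabled_features': ...}, ...}}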
METADATA_TYPE_TO_NAME = {
0: "Validation",
1: "Compaction",
2: "Stats",
3: "Serialization",
}
def read_validation(stream, fmt):
return parse(
stream,
(
('partitioner', Stream.string16),
('filter_chance', Stream.double),
)
)
def read_compaction(stream, fmt):
ka_la_schema = (
('ancestors', Stream.instantiate(Stream.array32, Stream.uint32)),
('cardinality', Stream.instantiate(Stream.array32, Stream.uint8)),
)
mc_schema = (
('cardinality', Stream.instantiate(Stream.array32, Stream.uint8)),
)
if re.match('m[cde]', fmt):
return parse(stream, mc_schema)
else:
return parse(stream, ka_la_schema)
def read_stats(stream, fmt):
replay_position = Stream.instantiate(
Stream.struct,
('id', Stream.uint64),
('pos', Stream.uint32),
)
estimated_histogram = Stream.instantiate(
Stream.array32,
Stream.instantiate(
Stream.struct,
('offset', Stream.uint64),
('bucket', Stream.uint64),
),
)
streaming_histogram = Stream.instantiate(
Stream.struct,
('max_bin_size', Stream.uint32),
('elements', Stream.instantiate(
Stream.array32,
Stream.instantiate(
Stream.struct,
('key', Stream.double),
('value', Stream.uint64),
),
)),
)
commitlog_interval = Stream.instantiate(
Stream.tuple,
replay_position,
replay_position,
)
ka_la_schema = (
('estimated_partition_size', estimated_histogram),
('estimated_cells_count', estimated_histogram),
('position', replay_position),
('min_timestamp', Stream.int64),
('max_timestamp', Stream.int64),
('max_local_deletion_time', Stream.int32),
('compression_ratio', Stream.double),
('estimated_tombstone_drop_time', streaming_histogram),
('sstable_level', Stream.uint32),
('repaired_at', Stream.uint64),
('min_column_names', Stream.instantiate(Stream.array32, Stream.string16)),
('max_column_names', Stream.instantiate(Stream.array32, Stream.string16)),
('has_legacy_counter_shards', Stream.bool),
)
mc_schema = (
('estimated_partition_size', estimated_histogram),
('estimated_cells_count', estimated_histogram),
('position', replay_position),
('min_timestamp', Stream.int64),
('max_timestamp', Stream.int64),
('min_local_deletion_time', Stream.int32),
('max_local_deletion_time', Stream.int32),
('min_ttl', Stream.int32),
('max_ttl', Stream.int32),
('compression_ratio', Stream.double),
('estimated_tombstone_drop_time', streaming_histogram),
('sstable_level', Stream.uint32),
('repaired_at', Stream.uint64),
('min_column_names', Stream.instantiate(Stream.array32, Stream.string16)),
('max_column_names', Stream.instantiate(Stream.array32, Stream.string16)),
('has_legacy_counter_shards', Stream.bool),
('columns_count', Stream.int64),
('rows_count', Stream.int64),
('commitlog_lower_bound', replay_position),
('commitlog_intervals', Stream.instantiate(Stream.array32, commitlog_interval)),
)
if re.match('m[cde]', fmt):
return parse(stream, mc_schema)
else:
return parse(stream, ka_la_schema)
def read_serialization(stream, fmt):
# TODO (those vints are scary)
return {}
READ_METADATA = {
0: read_validation,
1: read_compaction,
2: read_stats,
3: read_serialization,
}
def stats_parse(data, sstable_format):
offsets = Stream(data).array32(
Stream.instantiate(
Stream.tuple, Stream.uint32, Stream.uint32,
)
)
return {METADATA_TYPE_TO_NAME[typ]: READ_METADATA[typ](Stream(data, offset), sstable_format)
for typ, offset in offsets}
def sizeof_fmt(num, suffix='B'):
for unit in ['','Ki','Mi','Gi','Ti','Pi','Ei','Zi']:
if abs(num) < 1024.0:
return "%3.1f%s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
class sstable:
def __init__(self, f, s, p, r, t, e):
self.filename = f
self.size = s
self.partitions = p
self.rows = r
self.tombstones = t
self.expired = e
def describe(self):
return "{ %s: size: %s, partitions: %d, rows: %d, tombstones: %d (out of which %d are expired) }" % (self.filename, sizeof_fmt(self.size), self.partitions, self.rows, self.tombstones, self.expired)
def get_run_size(run):
run_size = 0
for sst in run[1]:
run_size += sst.size
return run_size
def get_run_id(scylla_file, sst_format):
with open(scylla_file, 'rb') as f:
data = f.read()
metadata = scylla_parse(data, sst_format)
try:
run_id_struct = metadata['data']['run_identifier']['id']
return str(run_id_struct['msb']) + str(run_id_struct['lsb'])
    except KeyError:
# generate random run id if none is found in an old Scylla.db
return str(uuid.uuid4())
def get_data_stats(stats_file, sst_format, gc_before):
with open(stats_file, 'rb') as f:
data = f.read()
metadata = stats_parse(data, sst_format)
    level = metadata['Stats']['sstable_level']
partitions = 0
elements = metadata['Stats']['estimated_partition_size']
for element in elements:
partitions += element['bucket']
rows = 0
try:
rows = metadata['Stats']['rows_count']
    except KeyError:
        # ka/la format stats have no rows_count field
        rows = 0
tombstones = expired = 0
try:
elements = metadata['Stats']['estimated_tombstone_drop_time']['elements']
#print("{}: {}".format(stats_file, elements))
for element in elements:
value = element['value']
tombstones += value
# tombstone is expired if its deletion_time < gc_before
if element['key'] < gc_before:
expired += value
    except Exception:
tombstones = expired = 0
    return level, partitions, rows, tombstones, expired
class per_shard_info:
    def __init__(self, shard_id):
        self.shard_id = shard_id
        self.runs_to_sstables = defaultdict(set)
        self.size = 0
        self.partitions = 0
        self.rows = 0
        self.tombstones = 0
        self.expired = 0
def add_sstable(self, level, sst):
self.runs_to_sstables[level].add(sst)
def dump(self):
print("--- SHARD #{} ---".format(self.shard_id))
for level,sstables in sorted(self.runs_to_sstables.items(), key=get_run_size, reverse=True):
run_size = 0
run_partitions = 0
run_rows = 0
run_tombstones = 0
run_expired = 0
sst_descriptions = ""
for sst in sstables:
run_size += sst.size
run_partitions += sst.partitions
run_rows += sst.rows
run_tombstones += sst.tombstones
run_expired += sst.expired
sst_descriptions += "\n\t" + sst.describe()
print("[Level %d: size: %s, partitions: %d rows: %d, tombstones: %d (out of which %d are expired) %s\n]" % (level, sizeof_fmt(run_size), run_partitions, run_rows, run_tombstones, run_expired, sst_descriptions))
self.size += run_size
self.partitions += run_partitions
self.rows += run_rows
self.tombstones += run_tombstones
self.expired += run_expired
def summary(self):
estimated_droppable_pctg = 0
if self.rows > 0:
estimated_droppable_pctg = float(self.expired) / float(self.rows) * 100.0
print("--- SHARD #{} ---".format(self.shard_id))
print("size: %s, partitions: %d rows: %d, tombstones: %d (out of which %d are expired), estimated droppable tombstone: %f%%" % (sizeof_fmt(self.size), self.partitions, self.rows, self.tombstones, self.expired, estimated_droppable_pctg))
def shard_of(token, shards):
sharding_ignore_msb_bits=12
token = (token << sharding_ignore_msb_bits) & 0xffffffffffffffff #64 bits
res = (token * shards) >> 64
res = (res & 0xffffffff) # 32 bits
assert(res < shards)
return res
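# Example: with the default 12 ignored MSB bits, the shift promotes bit 51 of
# the token to the top bit, so with 2 shards the upper half of the biased
# token space lands on shard 1: shard_of(0x0008000000000000, 2) == 1.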
# constants below belong to class murmur3_token
INT64_MAX = int(2 ** 63 - 1)
INT64_MIN = -INT64_MAX - 1
INT64_OVF_OFFSET = INT64_MAX + 1
INT64_OVF_DIV = 2 * INT64_OVF_OFFSET
class murmur3_token:
@staticmethod
def body_and_tail(data):
l = len(data)
nblocks = l // 16
tail = l % 16
if nblocks:
# we use '<', specifying little-endian byte order for data bigger than
# a byte so behavior is the same on little- and big-endian platforms
return struct.unpack_from('<' + ('qq' * nblocks), data), struct.unpack_from('b' * tail, data, -tail), l
else:
return tuple(), struct.unpack_from('b' * tail, data, -tail), l
@staticmethod
def rotl64(x, r):
# note: not a general-purpose function because it leaves the high-order bits intact
# suitable for this use case without wasting cycles
mask = 2 ** r - 1
rotated = (x << r) | ((x >> 64 - r) & mask)
return rotated
@staticmethod
def fmix(k):
        # keep only the 31 bits a 64-bit value would have left after >> 33,
        # so unbounded Python ints emulate a 64-bit unsigned shift
k ^= (k >> 33) & 0x7fffffff
k *= 0xff51afd7ed558ccd
k ^= (k >> 33) & 0x7fffffff
k *= 0xc4ceb9fe1a85ec53
k ^= (k >> 33) & 0x7fffffff
return k
@staticmethod
def truncate_int64(x):
if not INT64_MIN <= x <= INT64_MAX:
x = (x + INT64_OVF_OFFSET) % INT64_OVF_DIV - INT64_OVF_OFFSET
return x
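    # e.g. truncate_int64(INT64_MAX + 1) == INT64_MIN (two's-complement wrap)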
@staticmethod
def get_token(data):
h1 = h2 = 0
c1 = -8663945395140668459 # 0x87c37b91114253d5
c2 = 0x4cf5ad432745937f
body, tail, total_len = murmur3_token.body_and_tail(data)
# body
for i in range(0, len(body), 2):
k1 = body[i]
k2 = body[i + 1]
k1 *= c1
k1 = murmur3_token.rotl64(k1, 31)
k1 *= c2
h1 ^= k1
h1 = murmur3_token.rotl64(h1, 27)
h1 += h2
h1 = h1 * 5 + 0x52dce729
k2 *= c2
k2 = murmur3_token.rotl64(k2, 33)
k2 *= c1
h2 ^= k2
h2 = murmur3_token.rotl64(h2, 31)
h2 += h1
h2 = h2 * 5 + 0x38495ab5
# tail
k1 = k2 = 0
len_tail = len(tail)
if len_tail > 8:
for i in range(len_tail - 1, 7, -1):
k2 ^= tail[i] << (i - 8) * 8
k2 *= c2
k2 = murmur3_token.rotl64(k2, 33)
k2 *= c1
h2 ^= k2
if len_tail:
for i in range(min(7, len_tail - 1), -1, -1):
k1 ^= tail[i] << i * 8
k1 *= c1
k1 = murmur3_token.rotl64(k1, 31)
k1 *= c2
h1 ^= k1
# finalization
h1 ^= total_len
h2 ^= total_len
h1 += h2
h2 += h1
h1 = murmur3_token.fmix(h1)
h2 = murmur3_token.fmix(h2)
h1 += h2
return murmur3_token.truncate_int64(h1)
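# Usage sketch: get_token(serialized_key_bytes) yields the signed 64-bit
# Murmur3 token (the h1 half of murmur3_128), which should agree with the
# Murmur3Partitioner token Scylla/Cassandra assign to the same key.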
def get_sstable_tokens(summary_filename):
try:
with open(summary_filename, 'rb') as f:
minIndexInterval = struct.unpack('>I', f.read(4))[0]
offsetCount = struct.unpack('>I', f.read(4))[0]
offheapSize = struct.unpack('>q', f.read(8))[0]
samplingLevel = struct.unpack('>I', f.read(4))[0]
fullSamplingSummarySize = struct.unpack('>I', f.read(4))[0]
f.read(offsetCount * 4)
            f.read(offheapSize - offsetCount * 4)
firstSize = struct.unpack('>I', f.read(4))[0]
firstString = '>' + str(firstSize) + 's'
first = struct.unpack(firstString, f.read(firstSize))[0]
lastSize = struct.unpack('>I', f.read(4))[0]
lastString = '>' + str(lastSize) + 's'
last = struct.unpack(lastString, f.read(lastSize))[0]
except Exception as e:
print("get_sstable_tokens: {}".format(e))
sys.exit(1)
return murmur3_token.get_token(first), murmur3_token.get_token(last)
def is_uuid(v):
pattern = re.compile("([0-9a-z]{4})_([0-9a-z]{4})_([0-9a-z]{5})([0-9a-z]{13})")
return pattern.match(v)
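# Note: despite its name, is_uuid matches the underscore-separated base-36
# encoding of UUID generations used in sstable filenames, e.g. (illustrative)
# 'gfcq_0h1o_4akc822i24vc9kd8d1', not the canonical hyphenated UUID form.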
def main():
if len(sys.argv) != 3 and len(sys.argv) != 4:
print("usage: {} /path/to/table shards [gc_grace_seconds]".format(sys.argv[0]))
exit(1)
    directory = sys.argv[1]
    shards = sys.argv[2]
gc_grace_seconds = 3600*24*7
if len(sys.argv) == 4:
gc_grace_seconds = int(sys.argv[3])
gc_before = time.time() - gc_grace_seconds
print("gc_grace_seconds = {}, gc_before = {}".format(gc_grace_seconds, gc_before))
per_shard_info_set = dict()
for shard_id in range(0, int(shards)):
per_shard_info_set[shard_id] = per_shard_info(shard_id)
for filename in os.listdir(directory):
if not filename.endswith("Scylla.db"):
continue
sst_format = ""
if filename.startswith("me-"):
sst_format = "me"
elif filename.startswith("mc-"):
sst_format = "mc"
elif filename.startswith("md-"):
sst_format = "md"
elif filename.count("ka-"):
sst_format = "ka"
elif filename.count("la-"):
sst_format = "la"
else:
print("unable to find sst format in {}", filename)
exit(1)
scylla_file = os.path.join(directory, filename)
if not os.path.exists(scylla_file):
continue
data_file = scylla_file.replace("Scylla.db", "Data.db")
if not os.path.exists(data_file):
continue
size = os.stat(data_file).st_size
stats_file = scylla_file.replace("Scylla.db", "Statistics.db")
level,partitions,rows,tombstones,expired = get_data_stats(stats_file, sst_format, gc_before)
summary_file = scylla_file.replace("Scylla.db", "Summary.db")
first_token, last_token = get_sstable_tokens(summary_file)
sst = sstable(filename, size, partitions, rows, tombstones, expired)
sst_generation = filename.split('-')[1]
shard_id = 0
if sst_generation.isdigit():
shard_id = int(sst_generation) % int(shards)
elif is_uuid(sst_generation):
shard_id = shard_of(first_token, int(shards))
else:
print("Generation type not recognized for {}".format(filename))
sys.exit(1)
per_shard_info_set[shard_id].add_sstable(level, sst)
for shard_id in range(0, int(shards)):
per_shard_info_set[shard_id].dump()
print("NOTE: please take 'estimated droppable pctg' with a grain of salt. It's an estimation which results from dividing # of expired data by # of rows.")
for shard_id in range(0, int(shards)):
per_shard_info_set[shard_id].summary()
if __name__ == '__main__':
    main()