Skip to content

Instantly share code, notes, and snippets.

@HarryR
Created January 8, 2025 10:32
Show Gist options
  • Save HarryR/fd2339e6d321c49590bca371f8f7abea to your computer and use it in GitHub Desktop.
Save HarryR/fd2339e6d321c49590bca371f8f7abea to your computer and use it in GitHub Desktop.
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
import { IERC721 } from '@openzeppelin/contracts/token/ERC721/IERC721.sol';
import { IERC721Receiver } from '@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol';
import { Staking } from './Staking.sol';
import { randomBytes32 } from './Random.sol';
import { Moderation } from './Moderation.sol';
contract Distributor is IERC721Receiver
{
    struct Item {
        IERC721 nft;
        uint256 tokenId;
    }

    struct Win {
        Item item;
        address who;
    }

    /// NFTs that have been deposited and not yet raffled
    Item[] private items;

    /// Winners already drawn but not yet paid; drained by a call in a later block
    Win[] private winners;

    /// Block number in which `winners` was last drained
    uint256 private winBlock;

    Staking public immutable staker;

    Moderation public immutable moderation;

    error NotAllowed();

    constructor(
        Staking in_staking,
        Moderation in_moderation
    ) {
        staker = in_staking;
        moderation = in_moderation;
    }

    /// Moderators (allowed posters) can trigger the distribution function
    function tick()
        external
    {
        if( ! moderation.isModerator(msg.sender) )
        {
            revert NotAllowed();
        }
        internal_distribute();
    }

    /// Two phases, run on every deposit and every tick:
    ///  1. If called in a block other than the last drain, pay out every queued win.
    ///  2. If anyone has staked and more than `poolSize` NFTs are held, remove one
    ///     uniformly random NFT from the pool, draw a stake-weighted winner and queue
    ///     the win. It is paid by phase 1 of a call in a later block.
    function internal_distribute()
        internal
    {
        // Transfer previous winners whenever called in a future block
        if( winners.length > 0 && winBlock != block.number )
        {
            winBlock = block.number;
            while( winners.length > 0 )
            {
                Win memory x = winners[winners.length - 1];
                winners.pop();

                // Stakers are registered via Staking.stakeFor with a caller-chosen
                // `who`. A win for the zero address (transfer would revert) or for
                // this contract (NFT would be stranded outside `items`) cannot be
                // delivered, so return the NFT to the pool instead.
                if( x.who == address(0) || x.who == address(this) )
                {
                    items.push(x.item);
                    continue;
                }

                // Plain transferFrom rather than safeTransferFrom. Otherwise a winner
                // that is a contract rejecting (or not implementing)
                // onERC721Received would make this loop revert on every call,
                // permanently blocking deposits and ticks. It also avoids handing
                // the winner a callback while the queue is being drained.
                x.item.nft.transferFrom(address(this), x.who, x.item.tokenId);
            }
        }

        // Select a random winner & push into the stack (distributed above)
        if( staker.getSum() > 0 && items.length > moderation.poolSize() )
        {
            // Swap-and-pop a uniformly random item out of the pool
            uint idx = uint(randomBytes32()) % items.length;
            Item memory prize = items[idx];
            if( idx != (items.length - 1) )
            {
                items[idx] = items[items.length - 1];
            }
            items.pop();

            address winner = staker.random(randomBytes32());
            winners.push(Win({item: prize, who: winner}));
        }
    }

    /// Accepts an NFT deposit if Moderation allows this collection and poster,
    /// adds it to the pool, then runs a distribution step.
    function onERC721Received(
        address operator,
        address from,
        uint256 tokenId,
        bytes calldata /* data */
    )
        external
        returns (bytes4)
    {
        // msg.sender is the NFT contract performing the safe transfer
        if( ! moderation.isAllowed(msg.sender, operator, from) )
        {
            revert NotAllowed();
        }
        items.push(Item(IERC721(msg.sender), tokenId));
        internal_distribute();
        return IERC721Receiver.onERC721Received.selector;
    }
}
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";
import { Staking, ValueField } from './Staking.sol';
import { Moderation } from './Moderation.sol';
import { Distributor } from './Distributor.sol';
contract FrontendUtils {
    /// Everything the frontend displays for a single account, fetched in one call.
    struct UserOverview {
        bool isOwner;
        bool isModerator;
        uint userStakedAmount;
        uint stakedTotal;
        uint stakersCount;
        uint userTokenBalance;
        string tokenSymbol;
        string tokenName;
        uint tokenSupply;
        uint tokenDecimals;
    }

    Moderation public immutable moderation;
    Distributor public immutable distributor;
    Staking public immutable staking;

    constructor(
        Moderation in_moderation,
        Distributor in_distributor,
        Staking in_staking
    ) {
        moderation = in_moderation;
        distributor = in_distributor;
        staking = in_staking;
    }

    /// Aggregate role, staking and staking-token information for `who`.
    function getOverviewForUser(address who)
        external view
        returns (UserOverview memory)
    {
        IERC20Metadata token = IERC20Metadata(address(staking.stakingToken()));

        UserOverview memory overview;

        // Roles
        overview.isOwner = moderation.owner() == who;
        overview.isModerator = moderation.isModerator(who);

        // Staking state
        overview.userStakedAmount = ValueField.unwrap(staking.getValue(who));
        overview.stakedTotal = staking.getSum();
        overview.stakersCount = staking.getCount();

        // Staking token balance and metadata
        overview.userTokenBalance = token.balanceOf(who);
        overview.tokenSymbol = token.symbol();
        overview.tokenSupply = token.totalSupply();
        overview.tokenName = token.name();
        overview.tokenDecimals = token.decimals();

        return overview;
    }
}
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
import { ERC721 } from "@openzeppelin/contracts/token/ERC721/ERC721.sol";
contract MockNFT is ERC721 {
    /// Id that the next minted token will receive; ids start at 0 and increase by one.
    uint public tokenIdCounter;

    constructor()
        ERC721("Mock NFT", "MOCK")
    { }

    /// Test helper: mint `n` consecutive token ids to `to` via `_safeMint`.
    function mint(address to, uint n)
        external
    {
        for( uint remaining = n; remaining > 0; --remaining )
        {
            // Claim the current id and advance the counter before minting
            _safeMint(to, tokenIdCounter++);
        }
    }
}
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol";
contract MockToken is ERC20 {
    // This is a fungible token: name it as such (it was "Mock NFT").
    constructor()
        ERC20("Mock Token", "MOCK")
    { }

    /// Not used by this contract; kept so the public `tokenIdCounter()` getter
    /// stays in the ABI.
    uint public tokenIdCounter;

    /// Test helper with no access control: anyone may mint any amount to any address.
    function mint(address to, uint256 amount)
        external
    {
        _mint(to, amount);
    }
}
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
import { Ownable } from '@openzeppelin/contracts/access/Ownable.sol';
import { EnumerableSet } from '@openzeppelin/contracts/utils/structs/EnumerableSet.sol';
contract Moderation is Ownable {
    using EnumerableSet for EnumerableSet.AddressSet;

    /// NFT collections accepted by `isAllowed`
    EnumerableSet.AddressSet private allowedNFTs;

    /// Moderators; while non-empty, only members may post NFTs (see `isAllowed`)
    EnumerableSet.AddressSet private allowedPosters;

    /// Minimum number of NFTs the Distributor holds back before raffling one
    uint public poolSize;

    constructor (uint in_poolSize)
        Ownable()
    {
        poolSize = in_poolSize;
    }

    /// Retrieve list of allowed NFT collections
    function getNFTList()
        external view
        returns (address[] memory out_addrs)
    {
        out_addrs = _members(allowedNFTs);
    }

    /// Retrieve list of moderators
    function getModeratorList()
        external view
        returns (address[] memory out_addrs)
    {
        out_addrs = _members(allowedPosters);
    }

    function setPoolSize(uint newPoolSize)
        external onlyOwner
    {
        poolSize = newPoolSize;
    }

    /// Add (`state == true`) or remove (`state == false`) NFT collections.
    function modifyAllowedNFTs(address[] calldata in_nfts, bool state)
        external onlyOwner
    {
        _setMembership(allowedNFTs, in_nfts, state);
    }

    /// Add (`state == true`) or remove (`state == false`) moderators.
    function modifyModerators(address[] calldata in_addrs, bool state)
        external onlyOwner
    {
        _setMembership(allowedPosters, in_addrs, state);
    }

    function isModerator(address who)
        external view
        returns (bool)
    {
        return allowedPosters.contains(who);
    }

    /// True if `nft` is an allowed collection and either no moderators are
    /// configured, or `operator` or `from` is a moderator.
    function isAllowed(address nft, address operator, address from)
        external view
        returns (bool)
    {
        if( ! allowedNFTs.contains(nft) ) {
            return false;
        }
        // With an empty moderator set, posting of allowed collections is open to anyone
        if( allowedPosters.length() == 0 ) {
            return true;
        }
        return allowedPosters.contains(operator) || allowedPosters.contains(from);
    }

    /// Copy the members of `set` into a new memory array, in the set's index order.
    function _members(EnumerableSet.AddressSet storage set)
        private view
        returns (address[] memory list)
    {
        list = new address[](set.length());
        for( uint idx = 0; idx < list.length; ++idx )
        {
            list[idx] = set.at(idx);
        }
    }

    /// Add or remove every address in `addrs`. EnumerableSet's add/remove leave
    /// the set unchanged for already-present or absent members.
    function _setMembership(
        EnumerableSet.AddressSet storage set,
        address[] calldata addrs,
        bool state
    )
        private
    {
        for( uint idx = 0; idx < addrs.length; ++idx )
        {
            if( state ) {
                set.add(addrs[idx]);
            } else {
                set.remove(addrs[idx]);
            }
        }
    }
}
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
error RandomBytesFailure();

address constant RANDOM_PRECOMPILE = 0x0100000000000000000000000000000000000001;

/// Returns 32 bytes of entropy.
/// On chain id 1337 (presumably a local dev chain) this is a keccak of
/// publicly-known transaction/block values and is therefore predictable. On any
/// other chain it requests 32 bytes from the precompile at RANDOM_PRECOMPILE.
/// @dev Reverts with RandomBytesFailure if the precompile call fails or returns
///      fewer than 32 bytes.
function randomBytes32()
    view returns (bytes32)
{
    if( block.chainid == 1337 ) {
        return keccak256(abi.encodePacked(
            msg.sender,
            msg.value,
            block.number,
            block.timestamp,
            gasleft()
        ));
    }

    (bool success, bytes memory entropy) = RANDOM_PRECOMPILE.staticcall(
        abi.encode(32, "")
    );

    // A low-level call to an account without code reports success with empty
    // return data. On a chain lacking this precompile, checking `success` alone
    // would silently yield bytes32(0), so also require the full 32 bytes.
    if( ! success || entropy.length < 32 ) {
        revert RandomBytesFailure();
    }
    return bytes32(entropy);
}
import { ethers } from 'hardhat';
import { expect } from 'chai';
import { Distributor, MockNFT, MockToken, Moderation, Staking } from '../src/contracts';
import { HardhatEthersSigner } from '@nomicfoundation/hardhat-ethers/signers';
describe('Staker', () => {
    let d: Distributor;
    let s: Staking;
    let moderation: Moderation;
    let mockNFT: MockNFT;
    let mockToken: MockToken;

    before(async () => {
        mockNFT = await (await ethers.getContractFactory('MockNFT')).deploy();
        mockToken = await (await ethers.getContractFactory('MockToken')).deploy();
        const modf = await ethers.getContractFactory('Moderation');
        const m = moderation = await modf.deploy(10);
        await m.waitForDeployment();
        const stf = await ethers.getContractFactory('SumtreeLibrary');
        const st = await stf.deploy();
        await st.waitForDeployment();
        const stakingFactory = await ethers.getContractFactory('Staking', {
            libraries: {
                'SumtreeLibrary': await st.getAddress()
            }
        });
        s = await stakingFactory.deploy(await mockToken.getAddress());
        const distributorFactory = await ethers.getContractFactory('Distributor');
        d = await distributorFactory.deploy(await s.getAddress(), await m.getAddress());
        await s.waitForDeployment();
        await d.waitForDeployment();
    });

    /**
     * Number of NFTs to mint so that the least-likely staker's expected win count
     * is large compared to its standard deviation.
     * Returns ceil(25 * sigmaThreshold^2 / minProb), where minProb is the smallest
     * normalised weight.
     */
    function calculateIterationCount(weights: Record<string, number>, sigmaThreshold: number = 5): number
    {
        const values = Object.values(weights);
        const totalWeight = values.reduce((sum, w) => sum + w, 0);
        const minProb = Math.min(...values.map(w => w / totalWeight));
        const NORMAL_APPROXIMATION_THRESHOLD_SQUARED = 25;
        return Math.ceil(NORMAL_APPROXIMATION_THRESHOLD_SQUARED * sigmaThreshold * sigmaThreshold / minProb);
    }

    /**
     * Assert that each address's win count lies within `sigmaThreshold` binomial
     * standard deviations of the count expected from its weight.
     */
    function verifyDistribution(
        weightsMap: Record<string, number>,
        resultsMap: Record<string, bigint>,
        sigmaThreshold: number
    ) {
        const totalItems = Object.values(resultsMap)
            .reduce((sum, count) => sum + Number(count), 0);
        const totalWeight = Object.values(weightsMap)
            .reduce((sum, w) => sum + w, 0);
        for (const [address, weight] of Object.entries(weightsMap)) {
            const probability = weight / totalWeight;
            const expectedCount = totalItems * probability;
            const stdDev = Math.sqrt(totalItems * probability * (1 - probability));
            const margin = sigmaThreshold * stdDev;
            const range: [number, number] = [
                Math.floor(expectedCount - margin),
                Math.ceil(expectedCount + margin)
            ];
            const actualCount = Number(resultsMap[address] ?? 0n);
            const withinRange = actualCount >= range[0] && actualCount <= range[1];
            expect(
                withinRange,
                `${address}: won ${actualCount}, expected within [${range[0]}, ${range[1]}] (weight ${weight}/${totalWeight})`
            ).eq(true);
        }
    }

    it('Distribution', async () => {
        const nSigners = 5;
        const allSigners = await ethers.getSigners();
        if( allSigners.length < nSigners ) {
            throw new Error('Not enough signers!');
        }
        const signers: HardhatEthersSigner[] = allSigners.slice(0, nSigners);
        const weightsMap: Record<string, number> = {};

        // Stake whole tokens (multiples of 10**decimals) so each stake passes
        // Staking's MustStakeWhole check. weightsMap keeps the whole-token count
        // `w`, because only the ratios between stakers matter.
        const unit = 10n ** (await mockToken.decimals());
        const stakingAddress = await s.getAddress();
        for( const x of signers )
        {
            const w = Math.floor(1 + (Math.random() * 100));
            const amount = BigInt(w) * unit;
            await (await mockToken.mint(x, amount)).wait();
            await (await mockToken.connect(x).approve(stakingAddress, amount)).wait();
            await (await s.connect(x).stake(amount)).wait();
            weightsMap[x.address] = w;
        }

        // We must approve the NFTs before they can be used...
        await (await moderation.modifyAllowedNFTs([await mockNFT.getAddress()], true)).wait();

        // Mint in batches of 10, which will be randomly distributed to signers
        const batchCount = 10;
        const sigmaThreshold = 4;
        const m = calculateIterationCount(weightsMap, sigmaThreshold);
        const distributorAddress = await d.getAddress();
        for( let i = 0; i < Math.floor(m / batchCount); i++ )
        {
            await (await mockNFT.mint(distributorAddress, batchCount)).wait();
        }

        // Collect the count of how many NFTs each account has
        const resultsMap: Record<string, bigint> = {};
        for( const signer of signers )
        {
            resultsMap[signer.address] = await mockNFT.balanceOf(signer.address);
        }
        verifyDistribution(weightsMap, resultsMap, sigmaThreshold);
    });
});
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";
import { Sumtree, SumtreeLibrary, ValueField } from './Sumtree.sol';
contract Staking {
    IERC20Metadata public immutable stakingToken;

    Sumtree private tree;

    using SumtreeLibrary for Sumtree;

    constructor( address _stakingToken )
    {
        stakingToken = IERC20Metadata(_stakingToken);
    }

    error AlreadyStaked();
    error CannotStakeZero();
    error TransferFailed();
    error MustStakeWhole();
    error InvalidStaker();

    function stake(ValueField amount)
        external
    {
        stakeFor(msg.sender, amount);
    }

    /// Pull `amount` staking tokens from msg.sender and register them as the
    /// stake of `who`. Only `who` can later unstake.
    /// @dev Reverts with CannotStakeZero, InvalidStaker (who == 0),
    ///      MustStakeWhole (not a multiple of one whole token, 10**decimals),
    ///      AlreadyStaked, or TransferFailed.
    function stakeFor(address who, ValueField amount)
        public
    {
        uint256 raw = ValueField.unwrap(amount);
        if( raw == 0 ) revert CannotStakeZero();
        // A stake registered for the zero address can never be unstaked
        if( who == address(0) ) revert InvalidStaker();
        // One whole token is 10**decimals base units. The previous `% decimals()`
        // check was meaningless and divided by zero for 0-decimal tokens.
        if( raw % (10 ** uint256(stakingToken.decimals())) != 0 ) revert MustStakeWhole();
        if( tree.has(who) ) revert AlreadyStaked();
        bool ok = stakingToken.transferFrom(msg.sender, address(this), raw);
        if( false == ok ) revert TransferFailed();
        tree.add(who, amount);
    }

    /// Remove msg.sender's whole stake and return the tokens.
    /// @dev The tree entry is removed before the transfer. Reverts (via the
    ///      library's KeyNotFound) if msg.sender has no stake, and with
    ///      TransferFailed if the token reports failure.
    function unstake()
        external
    {
        ValueField amount = tree.remove(msg.sender);
        uint256 raw = ValueField.unwrap(amount);
        if( raw > 0 )
        {
            bool ok = stakingToken.transfer(msg.sender, raw);
            if( false == ok ) revert TransferFailed();
        }
    }

    /// Current stake of `staker`, or 0 if none.
    function getValue(address staker)
        external view
        returns (ValueField)
    {
        if( tree.has(staker) ) {
            return tree.nodes[tree.nodesByKey[staker]].value;
        }
        return ValueField.wrap(0);
    }

    /// Total of all stakes
    function getSum()
        external view
        returns (uint)
    {
        return tree.getSum();
    }

    /// Number of stakers
    function getCount()
        external view
        returns (uint)
    {
        return tree.getCount();
    }

    /// Pick a staker with probability proportional to their stake, derived from `seed`.
    function random(bytes32 seed)
        external view
        returns (address)
    {
        return tree.random(seed);
    }
}
// SPDX-License-Identifier: AGPL-3.0-only
pragma solidity ^0.8.0;
// Value types for the sum tree:
//   SumField   - sum of the values in a subtree
//   ValueField - a single entry's value (its weight)
//   NodeIndex  - a node id, also used for subtree node counts. Id 0 is the
//                "no node" sentinel (EMPTY in SumtreeLibrary).
// The widths are chosen so that the first six fields of Node
// (104 + 88 + 4*16 = 256 bits) share one storage slot and `key` takes a second.
type SumField is uint104;
type ValueField is uint88;
type NodeIndex is uint16;
// One tree node. Do not reorder fields: the declaration order determines the
// two-slot storage packing described above.
// sum   = left subtree sum + value + right subtree sum
// count = number of nodes in this subtree, including this one
struct Node {
SumField sum;
ValueField value;
NodeIndex left;
NodeIndex right;
NodeIndex count;
NodeIndex parent;
address key;
}
// Operator overloads for the value types; implemented by the free functions that follow.
// Arithmetic is checked (Solidity 0.8), so overflow/underflow reverts.
using { NodeIndex_add as +, NodeIndex_neq as !=, NodeIndex_eq as ==, NodeIndex_lt as <, NodeIndex_gt as > } for NodeIndex global;
using { ValueField_eq as ==, ValueField_lt as <, ValueField_gt as > } for ValueField global;
using { SumField_add as +, SumField_sub as -, SumField_gte as >=, SumField_lt as < } for SumField global;
// ---- NodeIndex operators (checked uint16 arithmetic / comparisons) ----

function NodeIndex_add(NodeIndex lhs, NodeIndex rhs) pure returns (NodeIndex) {
    uint16 total = NodeIndex.unwrap(lhs) + NodeIndex.unwrap(rhs);
    return NodeIndex.wrap(total);
}

function NodeIndex_eq(NodeIndex lhs, NodeIndex rhs) pure returns (bool) {
    return NodeIndex.unwrap(rhs) == NodeIndex.unwrap(lhs);
}

function NodeIndex_neq(NodeIndex lhs, NodeIndex rhs) pure returns (bool) {
    return !(NodeIndex.unwrap(lhs) == NodeIndex.unwrap(rhs));
}

function NodeIndex_lt(NodeIndex lhs, NodeIndex rhs) pure returns (bool) {
    return NodeIndex.unwrap(rhs) > NodeIndex.unwrap(lhs);
}

function NodeIndex_gt(NodeIndex lhs, NodeIndex rhs) pure returns (bool) {
    return NodeIndex.unwrap(rhs) < NodeIndex.unwrap(lhs);
}

// ---- ValueField comparisons ----

function ValueField_lt(ValueField lhs, ValueField rhs) pure returns (bool) {
    return ValueField.unwrap(rhs) > ValueField.unwrap(lhs);
}

function ValueField_gt(ValueField lhs, ValueField rhs) pure returns (bool) {
    return ValueField.unwrap(rhs) < ValueField.unwrap(lhs);
}

function ValueField_eq(ValueField lhs, ValueField rhs) pure returns (bool) {
    return ValueField.unwrap(rhs) == ValueField.unwrap(lhs);
}

// ---- SumField arithmetic / comparisons (checked uint104) ----

function SumField_add(SumField lhs, SumField rhs) pure returns (SumField) {
    uint104 total = SumField.unwrap(lhs) + SumField.unwrap(rhs);
    return SumField.wrap(total);
}

function SumField_gte(SumField lhs, SumField rhs) pure returns (bool) {
    return !(SumField.unwrap(lhs) < SumField.unwrap(rhs));
}

function SumField_sub(SumField lhs, SumField rhs) pure returns (SumField) {
    uint104 difference = SumField.unwrap(lhs) - SumField.unwrap(rhs);
    return SumField.wrap(difference);
}

function SumField_lt(SumField lhs, SumField rhs) pure returns (bool) {
    return SumField.unwrap(rhs) > SumField.unwrap(lhs);
}

/// Widen a single value (uint88) into the subtree-sum type (uint104).
function to_SumField(ValueField x) pure returns (SumField) {
    uint104 widened = ValueField.unwrap(x);
    return SumField.wrap(widened);
}
// Storage for one sum tree. Nodes are addressed by id. Id 0 (EMPTY in
// SumtreeLibrary) means "no node", both for child/parent links and for
// nodesByKey lookups of absent keys.
struct Sumtree {
// id -> node; a node's entry is deleted when it is removed
mapping(NodeIndex => Node) nodes;
mapping(address => NodeIndex) nodesByKey; // Direct key to nodeId mapping
// Last id handed out. It is incremented before each insert and never decremented,
// so ids are not reused after removal.
NodeIndex nextNodeId;
// Id of the root node, or 0 when the tree is empty
NodeIndex root_id;
}
/// Weighted binary search tree ordered by (value, key address). Every node caches
/// the count and value-sum of its subtree, so `pick` can map an integer in
/// [0, total sum) to a key with probability proportional to that key's value.
library SumtreeLibrary {
    error DuplicateKey(address key);
    error KeyNotFound(address key);
    error EmptyTree();
    error RandomFailed();
    error OutOfRange();

    NodeIndex private constant EMPTY = NodeIndex.wrap(0);
    NodeIndex private constant NodeIndex_ONE = NodeIndex.wrap(1);
    SumField private constant SumField_ZERO = SumField.wrap(0);

    /// Insert `key` with weight `value`, then rebalance.
    /// Node ids come from an ever-increasing uint16 counter and are never reused,
    /// so at most 65535 insertions can be made over a tree's lifetime. The next
    /// insertion reverts on checked overflow.
    /// @dev Reverts with DuplicateKey if `key` is already present.
    function add(Sumtree storage self, address key, ValueField value)
        public
    {
        if( has(self, key) ) {
            revert DuplicateKey(key);
        }

        Node memory new_node = Node({
            sum: to_SumField(value),
            value: value,
            left: EMPTY,
            right: EMPTY,
            parent: EMPTY,
            key: key,
            count: NodeIndex_ONE
        });

        NodeIndex new_id = self.nextNodeId = self.nextNodeId + NodeIndex_ONE;
        self.nodes[new_id] = new_node;
        self.nodesByKey[key] = new_id;

        if( self.root_id == EMPTY ) {
            self.root_id = new_id;
            return;
        }

        // Walk down to the insertion point, adding the new weight and count to
        // every node on the path.
        NodeIndex current_id = self.root_id;
        NodeIndex parent_id = EMPTY;
        while( current_id != EMPTY )
        {
            parent_id = current_id;
            Node storage currentNode = self.nodes[current_id];
            currentNode.sum = currentNode.sum + to_SumField(value);
            currentNode.count = currentNode.count + NodeIndex_ONE;
            // Order by value; ties are broken by key address
            if (value < currentNode.value ||
                (value == currentNode.value && uint160(key) < uint160(currentNode.key))) {
                current_id = currentNode.left;
            } else {
                current_id = currentNode.right;
            }
        }

        self.nodes[new_id].parent = parent_id;
        Node storage parentNode = self.nodes[parent_id];
        if (value < parentNode.value ||
            (value == parentNode.value && uint160(key) < uint160(parentNode.key))) {
            parentNode.left = new_id;
        } else {
            parentNode.right = new_id;
        }

        rebalance(self, parent_id);
    }

    /// Rotate the subtree rooted at `y` to the right; returns the new subtree root
    /// (y's former left child). Fixes parent links, cached count/sum and the
    /// parent's (or root's) child pointer.
    function rotateRight(Sumtree storage tree, NodeIndex y)
        private
        returns (NodeIndex)
    {
        Node storage yNode = tree.nodes[y];
        NodeIndex x = tree.nodes[y].left;
        Node storage xNode = tree.nodes[x];
        NodeIndex T2 = xNode.right;

        // Update parent references
        NodeIndex yParent = tree.nodes[y].parent;
        xNode.parent = yParent;
        yNode.parent = x;
        if (T2 != EMPTY) {
            tree.nodes[T2].parent = y;
        }

        // Perform rotation
        xNode.right = y;
        yNode.left = T2;

        // y is now x's child, so recompute y first
        updateCountAndSum(tree, y);
        updateCountAndSum(tree, x);

        // Update parent's child reference
        if (yParent != EMPTY) {
            Node storage yParentNode = tree.nodes[yParent];
            if (yParentNode.left == y) {
                yParentNode.left = x;
            } else {
                yParentNode.right = x;
            }
        } else {
            tree.root_id = x;
        }

        return x;
    }

    /// Mirror of rotateRight: rotate the subtree rooted at `x` to the left and
    /// return the new subtree root (x's former right child).
    function rotateLeft(Sumtree storage tree, NodeIndex x)
        private
        returns (NodeIndex)
    {
        Node storage xNode = tree.nodes[x];
        NodeIndex y = xNode.right;
        Node storage yNode = tree.nodes[y];
        NodeIndex T2 = yNode.left;

        // Update parent references
        NodeIndex xParent = xNode.parent;
        yNode.parent = xParent;
        xNode.parent = y;
        if (T2 != EMPTY) {
            tree.nodes[T2].parent = x;
        }

        // Perform rotation
        yNode.left = x;
        xNode.right = T2;

        // x is now y's child, so recompute x first
        updateCountAndSum(tree, x);
        updateCountAndSum(tree, y);

        // Update parent's child reference
        if (xParent != EMPTY) {
            Node storage xParentNode = tree.nodes[xParent];
            if (xParentNode.left == x) {
                xParentNode.left = y;
            } else {
                xParentNode.right = y;
            }
        } else {
            tree.root_id = y;
        }

        return y;
    }

    /// Walk from `nodeId` up to the root, recomputing cached count/sum and
    /// rotating any node whose subtree is imbalanced.
    function rebalance(Sumtree storage tree, NodeIndex nodeId)
        private
    {
        NodeIndex current = nodeId;
        while (current != EMPTY) {
            updateCountAndSum(tree, current);
            int256 balance = getBalance(tree, current);

            // balance = leftCount - rightCount. The test below is
            // leftCount > 3 * rightCount or rightCount > 3 * leftCount, so a node
            // with one non-empty side and one empty side always counts as imbalanced.
            Node storage currentNode = tree.nodes[current];
            bool isImbalanced = balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.right))) * 2) ||
                -balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.left))) * 2);

            if (isImbalanced) {
                if (balance > 0) {
                    // Left heavy
                    if (getBalance(tree, currentNode.left) < 0) {
                        // Left-Right case
                        currentNode.left = rotateLeft(tree, currentNode.left);
                    }
                    current = rotateRight(tree, current);
                } else {
                    // Right heavy
                    if (getBalance(tree, currentNode.right) > 0) {
                        // Right-Left case
                        currentNode.right = rotateRight(tree, currentNode.right);
                    }
                    current = rotateLeft(tree, current);
                }
            }

            current = tree.nodes[current].parent;
        }
    }

    /// Subtree node count, 0 for EMPTY
    function getCount(Sumtree storage tree, NodeIndex nodeId) private view returns (NodeIndex) {
        return nodeId == EMPTY ? EMPTY : tree.nodes[nodeId].count;
    }

    /// leftCount - rightCount for `nodeId` (0 for EMPTY)
    function getBalance(Sumtree storage tree, NodeIndex nodeId) private view returns (int256) {
        if (nodeId == EMPTY) return 0;
        Node storage node = tree.nodes[nodeId];
        return int256(uint256(NodeIndex.unwrap(getCount(tree, node.left)))) -
               int256(uint256(NodeIndex.unwrap(getCount(tree, node.right))));
    }

    /// Recompute a node's cached count and sum from its children
    function updateCountAndSum(Sumtree storage tree, NodeIndex nodeId) private {
        if (nodeId == EMPTY) return;

        Node storage node = tree.nodes[nodeId];
        node.count = NodeIndex_ONE + getCount(tree, node.left) + getCount(tree, node.right);

        SumField leftSum = node.left != EMPTY ? tree.nodes[node.left].sum : SumField_ZERO;
        SumField rightSum = node.right != EMPTY ? tree.nodes[node.right].sum : SumField_ZERO;
        node.sum = leftSum + to_SumField(node.value) + rightSum;
    }

    /// Remove `key` from the tree and return its value.
    /// @dev Reverts with KeyNotFound if `key` is absent.
    function remove(Sumtree storage self, address key)
        public
        returns (ValueField value)
    {
        if (!has(self, key)) {
            revert KeyNotFound(key);
        }

        NodeIndex nodeId = self.nodesByKey[key];
        Node storage node = self.nodes[nodeId];
        value = node.value;

        // Case 1: Node has no children or one child
        NodeIndex replacementId;
        if (node.left == EMPTY) {
            replacementId = node.right;
        } else if (node.right == EMPTY) {
            replacementId = node.left;
        }
        // Case 2: Node has both children
        else {
            // Find successor (smallest value in right subtree)
            NodeIndex successorId = node.right;
            while (self.nodes[successorId].left != EMPTY) {
                successorId = self.nodes[successorId].left;
            }

            Node storage successorNode = self.nodes[successorId];

            // Store the successor's original position details
            NodeIndex successorParent = successorNode.parent;
            NodeIndex successorRight = successorNode.right;

            // If successor is not the immediate right child
            if (successorParent != nodeId) {
                // Replace successor with its right child
                if (successorRight != EMPTY) {
                    self.nodes[successorRight].parent = successorParent;
                }
                self.nodes[successorParent].left = successorRight;

                // Set successor's right to node's right child
                successorNode.right = node.right;
                self.nodes[node.right].parent = successorId;
            }

            // Move node's left child to successor
            successorNode.left = node.left;
            self.nodes[node.left].parent = successorId;

            // Put successor in node's position
            successorNode.parent = node.parent;
            if (node.parent == EMPTY) {
                self.root_id = successorId;
            } else {
                Node storage parentNode = self.nodes[node.parent];
                if (parentNode.left == nodeId) {
                    parentNode.left = successorId;
                } else {
                    parentNode.right = successorId;
                }
            }

            // Start rebalancing from the lowest node whose subtree changed. The
            // walk up through parent links passes through the successor to the root.
            rebalance(self, successorParent != nodeId ? successorParent : successorId);

            delete self.nodes[nodeId];
            self.nodesByKey[key] = NodeIndex.wrap(0);
            return value;
        }

        // Handle Case 1 (no children or one child)
        NodeIndex parentId = node.parent;
        if (replacementId != EMPTY) {
            self.nodes[replacementId].parent = parentId;
        }

        if (parentId != EMPTY) {
            Node storage parentNode = self.nodes[parentId];
            if (parentNode.left == nodeId) {
                parentNode.left = replacementId;
            } else {
                parentNode.right = replacementId;
            }
            rebalance(self, parentId);
        } else {
            self.root_id = replacementId;
        }

        delete self.nodes[nodeId];
        self.nodesByKey[key] = NodeIndex.wrap(0);
        return value;
    }

    /// Pick a key with probability proportional to its value, derived from `seed`.
    /// @dev Reverts with EmptyTree if the tree is empty or its total value is 0.
    function random(Sumtree storage self, bytes32 seed)
        internal view
        returns (address)
    {
        uint total = getSum(self);
        // Guard before the modulo: otherwise an empty (or all-zero) tree hits a
        // division-by-zero panic instead of a descriptive error.
        if( total == 0 ) {
            revert EmptyTree();
        }
        uint seedHashed = uint(keccak256(abi.encodePacked(seed)));
        return pick(self, seedHashed % total);
    }

    /// Return the key whose cumulative in-order value range contains `random_value`.
    /// @dev Reverts with EmptyTree, or OutOfRange if random_value >= total sum.
    function pick(Sumtree storage self, uint random_value)
        internal view
        returns (address)
    {
        if( self.root_id == EMPTY ) {
            revert EmptyTree();
        }

        if( random_value >= getSum(self) ) {
            revert OutOfRange();
        }

        NodeIndex current_id = self.root_id;

        while( current_id != EMPTY )
        {
            Node storage current = self.nodes[current_id];

            // Get sum of left subtree
            SumField left_sum = SumField_ZERO;
            if( current.left != EMPTY ) {
                left_sum = self.nodes[current.left].sum;
            }

            // If random_value falls in left subtree
            if( random_value < SumField.unwrap(left_sum) )
            {
                current_id = current.left;
                continue;
            }

            // If random_value falls in current node's range
            if( random_value < SumField.unwrap(left_sum + to_SumField(current.value)) ) {
                return current.key;
            }

            // Otherwise descend right, rebasing into the right subtree's range
            random_value = random_value - SumField.unwrap((left_sum + to_SumField(current.value)));
            current_id = current.right;
        }

        revert RandomFailed();
    }

    /// Number of keys in the tree
    function getCount(Sumtree storage self)
        internal view
        returns (uint)
    {
        if( self.root_id == EMPTY ) {
            return 0;
        }
        return NodeIndex.unwrap(self.nodes[self.root_id].count);
    }

    /// Sum of all values in the tree
    function getSum(Sumtree storage self)
        internal view
        returns (uint)
    {
        if( self.root_id == EMPTY ) {
            return 0;
        }
        return SumField.unwrap(self.nodes[self.root_id].sum);
    }

    /// True if `key` is currently in the tree
    function has(Sumtree storage self, address key)
        internal view
        returns (bool)
    {
        return self.nodesByKey[key] != EMPTY;
    }
}
import { ethers } from 'hardhat';
import { SumtreeLibrary, TestSumtree } from '../src/contracts';
import { randomBytes, hexlify, getAddress, AddressLike } from 'ethers';
import { expect } from 'chai';
describe('Sumtree', () => {
let tst: TestSumtree;
let sumtreeLibrary: SumtreeLibrary;
before(async () => {
    // The library is deployed once and linked into every fresh TestSumtree.
    const libraryFactory = await ethers.getContractFactory('SumtreeLibrary');
    sumtreeLibrary = await libraryFactory.deploy();
    await sumtreeLibrary.waitForDeployment();
});
beforeEach(async () => {
    // Every test starts from an empty tree, so node ids restart at 1.
    const libraryAddress = await sumtreeLibrary.getAddress();
    const factory = await ethers.getContractFactory('TestSumtree', {
        libraries: { 'SumtreeLibrary': libraryAddress }
    });
    tst = await factory.deploy();
    await tst.waitForDeployment();
});
/** Outcome of validating a subtree; sum/count are only set when it is valid. */
type ValidationResult = {
    isValid: boolean;
    actualSum?: bigint;
    actualCount?: bigint;
};
/** Debug helper: pre-order dump of the subtree at `nodeId`, children tagged L/R. */
async function printNode(nodeId: bigint, indent: number = 0, prefix: string = '') {
    const n = await tst.node(nodeId);
    console.log(' '.repeat(indent), prefix, `id=${nodeId} s=${n.sum} v=${n.value} c=${n.count} k=${n.key} p=${n.parent}`);
    const children: Array<[bigint, string]> = [[n.left, 'L'], [n.right, 'R']];
    for (const [childId, tag] of children) {
        if (childId != 0n) {
            await printNode(childId, indent + 2, tag);
        }
    }
}
/** (value, key-as-integer) pair: the Sumtree's ordering key. */
type OrderKey = { value: bigint; key: bigint };

/** Lexicographic comparison on (value, key); returns <0, 0 or >0. */
function compareOrderKey(a: OrderKey, b: OrderKey): number {
    if (a.value !== b.value) return a.value < b.value ? -1 : 1;
    if (a.key !== b.key) return a.key < b.key ? -1 : 1;
    return 0;
}

/**
 * Recursively validate the subtree rooted at `nodeId`:
 *  - every node lies strictly between the `lower`/`upper` bounds inherited from
 *    its ancestors. This enforces the search-tree ordering across the whole
 *    subtree, not just between a node and its immediate children;
 *  - each child's `parent` points back to this node;
 *  - the stored `sum` and `count` equal the recomputed values.
 * On success returns the recomputed sum and count; otherwise logs the first
 * violation found and returns `{ isValid: false }`.
 */
async function validateNode(
    nodeId: bigint,
    lower?: OrderKey,
    upper?: OrderKey,
): Promise<ValidationResult> {
    if (nodeId === 0n) {
        return { isValid: true, actualSum: BigInt(0), actualCount: 0n };
    }

    const node = await tst.node(nodeId);
    const here: OrderKey = { value: node.value, key: BigInt(node.key) };

    // Check ordering against every ancestor via the inherited bounds
    if (lower !== undefined && compareOrderKey(here, lower) <= 0) {
        console.error(`Order violation at node ${nodeId}: (${node.value.toString()}, ${node.key}) is not above lower bound (${lower.value.toString()}, ${lower.key})`);
        return { isValid: false };
    }
    if (upper !== undefined && compareOrderKey(here, upper) >= 0) {
        console.error(`Order violation at node ${nodeId}: (${node.value.toString()}, ${node.key}) is not below upper bound (${upper.value.toString()}, ${upper.key})`);
        return { isValid: false };
    }

    // Check parent pointers of both children
    for (const childId of [node.left, node.right]) {
        if (childId === 0n) continue;
        const child = await tst.node(childId);
        if (child.parent !== nodeId) {
            console.error(`Parent pointer mismatch: node ${childId} should point to ${nodeId}`);
            return { isValid: false };
        }
    }

    // Left subtree must lie below this node, right subtree above it
    const leftResult = await validateNode(node.left, lower, here);
    if (!leftResult.isValid) return { isValid: false };
    const rightResult = await validateNode(node.right, here, upper);
    if (!rightResult.isValid) return { isValid: false };

    const actualSum = leftResult.actualSum! + node.value + rightResult.actualSum!;
    const actualCount = leftResult.actualCount! + 1n + rightResult.actualCount!;

    if (node.sum !== actualSum) {
        console.error(`Sum mismatch at node ${nodeId}:`);
        console.error(`Expected: ${node.sum.toString()}`);
        console.error(`Actual: ${actualSum.toString()}`);
        return { isValid: false };
    }

    if (node.count !== actualCount) {
        console.error(`Count mismatch at node ${nodeId}:`);
        console.error(`Expected: ${node.count}`);
        console.error(`Actual: ${actualCount}`);
        return { isValid: false };
    }

    return {
        isValid: true,
        actualSum,
        actualCount
    };
}
/**
 * Validate the whole tree rooted at `root`. Returns false on any structural
 * violation, or if reading the tree throws. The thrown error is logged rather
 * than silently discarded, so a broken RPC/ABI shows up as such instead of as a
 * bare "invalid tree".
 */
async function validateTree(root: bigint): Promise<boolean> {
    try {
        const result = await validateNode(root);
        return result.isValid;
    } catch (error: unknown) {
        console.error('Error validating tree:', error);
        return false;
    }
}
/** One (value, key) entry as read back from the tree. */
type NodeValue = {
    value: bigint;
    key: string;
};
/** In-order listing of every (value, key) in the tree, i.e. in tree order. */
async function treeToSortedList(): Promise<NodeValue[]> {
    const ordered: NodeValue[] = [];
    const ancestors: bigint[] = [];
    let cursor: bigint = await tst.root();

    // Iterative in-order walk: push the left spine, emit the deepest node, go right.
    while (cursor !== 0n || ancestors.length > 0) {
        while (cursor !== 0n) {
            ancestors.push(cursor);
            cursor = (await tst.node(cursor)).left;
        }
        const id = ancestors.pop();
        if (id === undefined) break;
        const node = await tst.node(id);
        ordered.push({ value: node.value, key: node.key });
        cursor = node.right;
    }
    return ordered;
}
it('Simple left rotation', async () => {
    // Strictly increasing inserts make the root right-heavy:
    //   100                200
    //     \               /   \
    //     200    ->     100   300
    //       \
    //       300
    const inserts: Array<[bigint, AddressLike]> = [
        [100n, getAddress(`0x${'1'.repeat(40)}`)], // node id 1
        [200n, getAddress(`0x${'2'.repeat(40)}`)], // node id 2
        [300n, getAddress(`0x${'3'.repeat(40)}`)], // node id 3
    ];
    for (const [weight, account] of inserts) {
        const response = await tst.add(account, weight);
        await response.wait();
        const valid = await validateTree(await tst.root());
        expect(valid).eq(true);
    }
    expect(await tst.root()).eq(2n);
    const top = await tst.node(2n);
    expect(top.left).eq(1n);
    expect(top.right).eq(3n);
});
it('Simple right rotation', async () => {
    // Strictly decreasing inserts make the root left-heavy:
    //       300          200
    //       /           /   \
    //     200    ->   100   300
    //     /
    //   100
    const inserts: Array<[bigint, AddressLike]> = [
        [300n, getAddress(`0x${'4'.repeat(40)}`)], // node id 1
        [200n, getAddress(`0x${'5'.repeat(40)}`)], // node id 2
        [100n, getAddress(`0x${'6'.repeat(40)}`)], // node id 3
    ];
    for (const [weight, account] of inserts) {
        const response = await tst.add(account, weight);
        await response.wait();
        const valid = await validateTree(await tst.root());
        expect(valid).eq(true);
    }
    expect(await tst.root()).eq(2n);
    const top = await tst.node(2n);
    expect(top.left).eq(3n);
    expect(top.right).eq(1n);
});
// Inserting 100, 300, 200 puts the new node in the LEFT subtree of the root's
// RIGHT child, which is the right-left double-rotation case. The old title,
// "Right-right", named the single-rotation case instead.
it('Right-left case (double rotation)', async () => {
    // Right rotation at 300, then left rotation at 100:
    //   100          100              200
    //     \            \             /   \
    //     300   ->     200    ->   100   300
    //     /              \
    //   200              300
    const pairs: [bigint, AddressLike][] = [
        [100n, getAddress('0x' + '7'.repeat(40))], // id=1
        [300n, getAddress('0x' + '8'.repeat(40))], // id=2
        [200n, getAddress('0x' + '9'.repeat(40))], // id=3
    ];
    for (const [w, a] of pairs) {
        const tx = await tst.add(a, w);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    expect(await tst.root()).eq(3n);
    expect((await tst.node(3n)).left).eq(1n);
    expect((await tst.node(3n)).right).eq(2n);
});
it('Left-right case (double rotation)', async () => {
    // Left rotation at 100, then right rotation at 300:
    //     300          300          200
    //     /            /           /   \
    //   100    ->    200    ->   100   300
    //     \          /
    //     200      100
    const inserts: Array<[bigint, AddressLike]> = [
        [300n, getAddress(`0x${'7'.repeat(40)}`)], // node id 1
        [100n, getAddress(`0x${'8'.repeat(40)}`)], // node id 2
        [200n, getAddress(`0x${'9'.repeat(40)}`)], // node id 3
    ];
    for (const [weight, account] of inserts) {
        const response = await tst.add(account, weight);
        await response.wait();
        const valid = await validateTree(await tst.root());
        expect(valid).eq(true);
    }
    expect(await tst.root()).eq(3n);
    const top = await tst.node(3n);
    expect(top.left).eq(2n);
    expect(top.right).eq(1n);
});
it('Key-based ordering (same values)', async () => {
    // With equal values the key address breaks ties, so ascending keys behave
    // like ascending values and trigger a left rotation.
    const inserts: Array<[bigint, AddressLike]> = [
        [100n, getAddress(`0x${'1'.repeat(40)}`)], // node id 1
        [100n, getAddress(`0x${'2'.repeat(40)}`)], // node id 2
        [100n, getAddress(`0x${'3'.repeat(40)}`)], // node id 3
    ];
    for (const [weight, account] of inserts) {
        const response = await tst.add(account, weight);
        await response.wait();
        const valid = await validateTree(await tst.root());
        expect(valid).eq(true);
    }
    expect(await tst.root()).eq(2n);
    const top = await tst.node(2n);
    expect(top.left).eq(1n);
    expect(top.right).eq(3n);
});
it('Remove leaf node from balanced tree', async () => {
    // Build:                 Then delete leaf 100; expected result:
    //        200                    300
    //       /   \                  /
    //     100   300              200
    //
    // NOTE(review): a textbook AVL delete would keep 200 as root with 300 as
    // its right child (a balance factor of -1 is allowed). These assertions
    // pin this implementation's behaviour; confirm it is intentional.
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))], // id=1
        [100n, getAddress('0x' + '2'.repeat(40))], // id=2
        [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    ];
    for (const [weight, account] of pairs) {
        await (await tst.add(account, weight)).wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());
    const [, leafAddress] = pairs[1];
    await (await tst.remove(leafAddress)).wait();
    expect(await validateTree(await tst.root())).eq(true);
    //await printNode(await tst.root());
    const rootNodeId = await tst.root();
    expect(rootNodeId).eq(3n);
    const rootNode = await tst.node(rootNodeId);
    expect(rootNode.left).eq(1n);
    expect(rootNode.right).eq(0n);
});
it('Remove node with one child', async () => {
    // Before:                After deleting 100:
    //        200                    200
    //       /   \                  /   \
    //     100   300              50    300
    //     /
    //   50
    // 100 has a single (left) child, so 50 is expected to take its place.
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))], // id=1
        [100n, getAddress('0x' + '2'.repeat(40))], // id=2
        [300n, getAddress('0x' + '3'.repeat(40))], // id=3
        [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    ];
    for (const [weight, account] of pairs) {
        await (await tst.add(account, weight)).wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    await (await tst.remove(pairs[1][1])).wait();
    expect(await validateTree(await tst.root())).eq(true);
    expect(await tst.root()).eq(1n);
    const rootNode = await tst.node(1n);
    expect(rootNode.left).eq(4n);
    expect(rootNode.right).eq(3n);
});
it('Remove node with two children', async () => {
    // Initial tree:
    //          200
    //         /   \
    //       100   300
    //      /   \
    //    50    150
    //
    // Deleting 100 (two children) replaces it with either its in-order
    // successor (150) or its predecessor (50). The disabled structural
    // assertions at the bottom assume the successor variant:
    //          200
    //         /   \
    //       150   300
    //      /
    //    50
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))], // id=1
        [100n, getAddress('0x' + '2'.repeat(40))], // id=2
        [300n, getAddress('0x' + '3'.repeat(40))], // id=3
        [50n, getAddress('0x' + '4'.repeat(40))], // id=4
        [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    ];
    // Build initial tree
    for (const [w, a] of pairs) {
        const tx = await tst.add(a, w);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());
    // Remove node 100
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    // Shape-independent checks: whichever node is promoted, the in-order
    // contents must be exactly the four surviving weights, and the root's sum
    // must be their total.
    const remaining = await treeToSortedList();
    expect(remaining.map((n) => n.value)).deep.eq([50n, 150n, 200n, 300n]);
    expect((await tst.node(await tst.root())).sum).eq(700n); // 50 + 150 + 200 + 300

    // NOTE(review): these structural assertions were disabled in the original.
    // Re-enable them once the expected replacement (successor vs predecessor)
    // is confirmed against the implementation.
    /*
    await printNode(await tst.root());
    // Verify structure
    expect(await tst.root()).eq(1n);
    expect((await tst.node(1n)).left).eq(5n);
    expect((await tst.node(1n)).right).eq(3n);
    expect((await tst.node(5n)).left).eq(4n);
    */
});
it('Remove root node triggers rebalance', async () => {
    // Before:                      After deleting the root (200):
    //          200                        300
    //         /   \                      /   \
    //       100   300                  100   400
    //      /   \     \                /   \
    //    50    150   400            50    150
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))], // id=1
        [100n, getAddress('0x' + '2'.repeat(40))], // id=2
        [300n, getAddress('0x' + '3'.repeat(40))], // id=3
        [50n, getAddress('0x' + '4'.repeat(40))], // id=4
        [150n, getAddress('0x' + '5'.repeat(40))], // id=5
        [400n, getAddress('0x' + '6'.repeat(40))], // id=6
    ];
    for (const [weight, account] of pairs) {
        await (await tst.add(account, weight)).wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    const [, rootAddress] = pairs[0];
    await (await tst.remove(rootAddress)).wait();
    expect(await validateTree(await tst.root())).eq(true);

    expect(await tst.root()).eq(3n);
    const top = await tst.node(3n);
    expect(top.left).eq(2n);
    expect(top.right).eq(6n);
    const leftChild = await tst.node(2n);
    expect(leftChild.left).eq(4n);
    expect(leftChild.right).eq(5n);
});
it('Insert after removal maintains valid tree', async () => {
    // Initial tree:
    //          200
    //         /   \
    //       100   300
    //      /   \
    //    50    150
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))], // id=1
        [100n, getAddress('0x' + '2'.repeat(40))], // id=2
        [300n, getAddress('0x' + '3'.repeat(40))], // id=3
        [50n, getAddress('0x' + '4'.repeat(40))], // id=4
        [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    ];
    // Build initial tree
    for (const [w, a] of pairs) {
        const tx = await tst.add(a, w);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    // Remove the node with value 100 (it has two children)
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    // Insert new nodes at several positions in the remaining tree
    const newPairs: [bigint, AddressLike][] = [
        [125n, getAddress('0x' + '6'.repeat(40))], // Between 50 and 150
        [175n, getAddress('0x' + '7'.repeat(40))], // Between 150 and 200
        [250n, getAddress('0x' + '8'.repeat(40))], // Between 200 and 300
    ];
    for (const [w, a] of newPairs) {
        const tx = await tst.add(a, w);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    // The loop above has already validated the final tree. Additionally pin
    // its contents: the in-order weights must be exactly the surviving and
    // newly inserted values, and the root's sum must be their total.
    const finalList = await treeToSortedList();
    expect(finalList.map((n) => n.value)).deep.eq([50n, 125n, 150n, 175n, 200n, 250n, 300n]);
    expect((await tst.node(await tst.root())).sum).eq(1250n);
    // Optional: Print final tree state for debugging
    //await printNode(await tst.root());
});
it('Remove should fail for non-existent key', async () => {
    // The all-0xff address is never inserted by this test, so removing it
    // must revert with the library's KeyNotFound error carrying that key.
    const missing = getAddress('0x' + 'f'.repeat(40));
    const pendingRemoval = tst.remove(missing);
    await expect(pendingRemoval)
        .to.be.revertedWithCustomError(sumtreeLibrary, 'KeyNotFound')
        .withArgs(missing);
});
it('Verify sum updates after removal', async () => {
    // Add three nodes, validating the tree after each insert like the other tests
    const pairs: [bigint, AddressLike][] = [
        [200n, getAddress('0x' + '1'.repeat(40))],
        [100n, getAddress('0x' + '2'.repeat(40))],
        [300n, getAddress('0x' + '3'.repeat(40))],
    ];
    for (const [w, a] of pairs) {
        const tx = await tst.add(a, w);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }
    // The root node's `sum` covers every node in the tree
    const initialSum = (await tst.node(await tst.root())).sum;
    expect(initialSum).eq(600n); // 200 + 100 + 300
    // Remove the entry with value 100 (the smallest value, inserted second)
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    // Verify sum is updated
    const finalSum = (await tst.node(await tst.root())).sum;
    expect(finalSum).eq(500n); // 200 + 300
});
function validateOrdering(list: NodeValue[]): { isValid: boolean; violations: string[] } {
    // Each adjacent pair must be strictly increasing by (value, key), with keys
    // compared case-insensitively as strings. Every pair that breaks this rule
    // contributes one human-readable message, indexed by the earlier element.
    const violations: string[] = [];
    for (let idx = 1; idx < list.length; idx++) {
        const prev = list[idx - 1];
        const item = list[idx];
        const sameValue = item.value === prev.value;
        const keyNotAfter = item.key.toLowerCase() <= prev.key.toLowerCase();
        const outOfOrder = item.value < prev.value || (sameValue && keyNotAfter);
        if (!outOfOrder) {
            continue;
        }
        const header = `Ordering violation at index ${idx - 1}:\n`;
        const prevLine = `Current: (value: ${prev.value.toString()}, key: ${prev.key})\n`;
        const itemLine = `Next: (value: ${item.value.toString()}, key: ${item.key})`;
        violations.push(header + prevLine + itemLine);
    }
    const isValid = violations.length === 0;
    return { isValid, violations };
}
it('Random insert & verify pick', async () => {
    // Grow the tree with 100 random (address, weight) entries, checking the
    // tree invariants and the in-order ordering after every insertion.
    for (let step = 0; step < 100; step++) {
        const account = getAddress(hexlify(randomBytes(20)));
        const weight = 1 + Math.floor(Math.random() * 1000);
        await (await tst.add(account, weight)).wait();
        expect(await validateTree(await tst.root())).eq(true);
        const snapshot = await treeToSortedList();
        const report = validateOrdering(snapshot);
        if (!report.isValid) {
            console.log(snapshot);
            console.log(report.violations);
        }
        expect(report.isValid).eq(true);
    }
    //await printNode(await tst.root());
    const sorted = await treeToSortedList();
    expect(validateOrdering(sorted).isValid).eq(true);

    // Walking the in-order list, each node is expected to own the cumulative
    // weight range [start, end): pick() at its first, last and middle point
    // must all return that node's key.
    let end: bigint = 0n;
    for (const entry of sorted) {
        const start = end;
        end += entry.value;
        const first = await tst.pick(start);
        const last = await tst.pick(end - 1n);
        expect(first).eq(last);
        expect(last).eq(entry.key);
        const middle = await tst.pick(start + ((end - 1n - start) / 2n));
        expect(middle).eq(entry.key);
    }
});
it('Random insert, remove & verify pick', async () => {
    // Addresses currently stored in the tree. Removed addresses are swap-popped
    // out, so every remove step targets a live key (the original list never
    // shrank, so many remove steps picked a dead key and did nothing), and the
    // array length always equals the number of nodes the tree should hold.
    const liveAddresses: string[] = [];

    // Walks the in-order node list and checks that pick() maps the first, last
    // and middle point of every node's cumulative-weight range to that node's key.
    const verifyPicks = async (list: NodeValue[]): Promise<void> => {
        let total: bigint = 0n;
        for (const z of list) {
            const totalBefore = total;
            total += z.value;
            // Picks at range boundaries
            const a = await tst.pick(totalBefore);
            const b = await tst.pick(total - 1n);
            expect(a).eq(b);
            expect(b).eq(z.key);
            // Pick in the middle of the range
            const c = await tst.pick(totalBefore + ((total - 1n - totalBefore) / 2n));
            expect(c).eq(z.key);
        }
    };

    for (let i = 0; i < 200; i++) {
        // 70% chance to add, 30% chance to remove; always add while the tree is empty
        const shouldAdd = liveAddresses.length === 0 || Math.random() < 0.7;
        if (shouldAdd) {
            const a = getAddress(hexlify(randomBytes(20)));
            const w = 1 + Math.floor(Math.random() * 1000);
            const tx = await tst.add(a, w);
            await tx.wait();
            liveAddresses.push(a);
        } else {
            const indexToRemove = Math.floor(Math.random() * liveAddresses.length);
            const addressToRemove = liveAddresses[indexToRemove];
            const tx = await tst.remove(addressToRemove);
            await tx.wait();
            // O(1) swap-remove; the order of liveAddresses does not matter
            liveAddresses[indexToRemove] = liveAddresses[liveAddresses.length - 1];
            liveAddresses.pop();
        }
        // Every iteration mutates the tree, so validate after each one
        expect(await validateTree(await tst.root())).eq(true);

        // Verify ordering after each operation
        const x = await treeToSortedList();
        const y = validateOrdering(x);
        if (!y.isValid) {
            console.log("Operation:", i);
            console.log("Current tree:", x);
            console.log("Violations:", y.violations);
        }
        expect(y.isValid).eq(true);

        // Every 10 operations, verify pick functionality for all nodes
        if (i % 10 === 0) {
            await verifyPicks(x);
        }
    }

    // Final verification of the entire tree
    const finalList = await treeToSortedList();
    expect(validateOrdering(finalList).isValid).eq(true);
    // Final tree size must match our tracking
    expect(finalList.length).eq(liveAddresses.length);
    await verifyPicks(finalList);
});
});
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;
import {Node, Sumtree, SumtreeLibrary, ValueField, NodeIndex} from '../Sumtree.sol';
/// @title TestSumtree
/// @notice Test harness that owns a single `Sumtree` in storage and exposes the
///         `SumtreeLibrary` operations, plus raw root/node reads, as contract
///         functions so off-chain tests can drive and inspect the tree.
contract TestSumtree {
    Sumtree private st;

    /// @notice Forwards to `SumtreeLibrary.add` with key `k` and value `v`.
    function add(address k, ValueField v) external {
        SumtreeLibrary.add(st, k, v);
    }

    /// @notice Forwards to `SumtreeLibrary.remove` for `key`.
    /// @return value Whatever the library returns for the removed key.
    function remove(address key) public returns (ValueField value) {
        value = SumtreeLibrary.remove(st, key);
    }

    /// @notice Forwards to `SumtreeLibrary.random` with `seed`.
    function random(bytes32 seed) public view returns (address) {
        return SumtreeLibrary.random(st, seed);
    }

    /// @notice Forwards to `SumtreeLibrary.pick` with `r`.
    function pick(uint r) public view returns (address) {
        return SumtreeLibrary.pick(st, r);
    }

    /// @notice Forwards to `SumtreeLibrary.getCount`.
    function count() public view returns (uint) {
        return SumtreeLibrary.getCount(st);
    }

    /// @notice Forwards to `SumtreeLibrary.getSum`.
    function total() public view returns (uint) {
        return SumtreeLibrary.getSum(st);
    }

    /// @notice Raw read of the tree's `root_id` field.
    function root() public view returns (NodeIndex) {
        return st.root_id;
    }

    /// @notice Raw read of the stored node at `idx`.
    function node(NodeIndex idx) public view returns (Node memory) {
        return st.nodes[idx];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment