Created
January 8, 2025 10:32
-
-
Save HarryR/fd2339e6d321c49590bca371f8f7abea to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
import { IERC721 } from '@openzeppelin/contracts/token/ERC721/IERC721.sol'; | |
import { IERC721Receiver } from '@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol'; | |
import { Staking } from './Staking.sol'; | |
import { randomBytes32 } from './Random.sol'; | |
import { Moderation } from './Moderation.sol'; | |
/// @title Distributor
/// @notice Collects NFTs sent to it and awards each one to an address returned
///         by `staker.random`.
/// @dev Awarding is two-phase: a call first records a (item, winner) pair, and
///      a later call made in a different block performs the actual transfer.
contract Distributor is IERC721Receiver
{
    /// Raised when the caller (for `tick`) or the sender of an NFT (for
    /// `onERC721Received`) is rejected by the `moderation` contract.
    error NotAllowed();

    /// An NFT currently held by this contract.
    struct Item {
        IERC721 nft;
        uint256 tokenId;
    }

    /// An item that has been assigned to a winner but not transferred yet.
    struct Win {
        Item item;
        address who;
    }

    /// Pool of received NFTs that have not been assigned to anyone.
    Item[] private items;

    /// Stack of pending wins, drained by `internal_distribute`.
    Win[] private winners;

    /// Block number in which the most recent winner was selected (or the
    /// pending stack was last drained).
    uint256 private winBlock;

    Staking public immutable staker;

    Moderation public immutable moderation;

    constructor(
        Staking in_staking,
        Moderation in_moderation
    ) {
        staker = in_staking;
        moderation = in_moderation;
    }

    /// Moderators (allowed posters) can trigger the distribution function.
    /// Reverts with `NotAllowed` for any other caller.
    function tick()
        external
    {
        if( ! moderation.isModerator(msg.sender) )
        {
            revert NotAllowed();
        }

        internal_distribute();
    }

    /// Pays out pending winners selected in an earlier block, then selects at
    /// most one new winner.
    function internal_distribute()
        internal
    {
        // Transfer previous winners, but only when the last selection was made
        // in a different block than the current one.
        if( winners.length > 0 && winBlock != block.number )
        {
            winBlock = block.number;

            while( winners.length > 0 )
            {
                Win memory x = winners[winners.length - 1];

                // Pop before the external call so a re-entrant call cannot
                // see (and pay) the same entry twice.
                winners.pop();

                // NOTE(review): if the recipient rejects the token this call
                // presumably reverts the whole distribution - confirm that
                // this is acceptable.
                x.item.nft.safeTransferFrom(address(this), x.who, x.item.tokenId);
            }
        }

        // Select a random winner & push into the stack (distributed above).
        // Only happens while somebody is staked and the pool holds more than
        // `poolSize` items.
        if( staker.getSum() > 0 && items.length > moderation.poolSize() )
        {
            uint x = uint(randomBytes32()) % items.length;

            Item memory y = items[x];

            // Swap-and-pop removal: move the last item into the vacated slot.
            if( x != (items.length - 1) )
            {
                items[x] = items[items.length - 1];
            }

            items.pop();

            address winner = staker.random(randomBytes32());

            winners.push(Win({item: y, who: winner}));

            // Record the selection block. Without this, the first selection
            // (winBlock still 0) or any selection after an idle period could
            // be paid out by another call within the very same block.
            winBlock = block.number;
        }
    }

    /// ERC-721 receive hook. `msg.sender` is treated as the NFT contract.
    /// Reverts with `NotAllowed` unless `moderation.isAllowed` accepts the
    /// (collection, operator, from) triple; otherwise the token is added to
    /// the pool and a distribution round is run.
    function onERC721Received(
        address operator,
        address from,
        uint256 tokenId,
        bytes calldata /* data */
    )
        external
        returns (bytes4)
    {
        if( ! moderation.isAllowed(msg.sender, operator, from) )
        {
            revert NotAllowed();
        }

        items.push(Item(IERC721(msg.sender), tokenId));

        internal_distribute();

        return IERC721Receiver.onERC721Received.selector;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; | |
import { Staking, ValueField } from './Staking.sol'; | |
import { Moderation } from './Moderation.sol'; | |
import { Distributor } from './Distributor.sol'; | |
/// @title FrontendUtils
/// @notice Read-only helper that gathers several values from the moderation,
///         staking and token contracts in a single call.
contract FrontendUtils {
    /// Aggregate returned by `getOverviewForUser`.
    struct UserOverview {
        bool isOwner;
        bool isModerator;
        uint userStakedAmount;
        uint stakedTotal;
        uint stakersCount;
        uint userTokenBalance;
        string tokenSymbol;
        string tokenName;
        uint tokenSupply;
        uint tokenDecimals;
    }

    Moderation public immutable moderation;

    Distributor public immutable distributor;

    Staking public immutable staking;

    constructor(
        Moderation in_moderation,
        Distributor in_distributor,
        Staking in_staking
    ) {
        staking = in_staking;
        distributor = in_distributor;
        moderation = in_moderation;
    }

    /// Builds the overview for `who`.
    /// @param who Account whose role flags, stake and token balance are read.
    /// @return The populated `UserOverview`.
    function getOverviewForUser(address who)
        external view
        returns (UserOverview memory)
    {
        IERC20Metadata stakedToken = IERC20Metadata(address(staking.stakingToken()));

        UserOverview memory overview;

        // Role flags
        overview.isOwner = (moderation.owner() == who);
        overview.isModerator = moderation.isModerator(who);

        // Staking figures
        overview.userStakedAmount = ValueField.unwrap(staking.getValue(who));
        overview.stakedTotal = staking.getSum();
        overview.stakersCount = staking.getCount();

        // Token metadata and the user's balance
        overview.userTokenBalance = stakedToken.balanceOf(who);
        overview.tokenSymbol = stakedToken.symbol();
        overview.tokenName = stakedToken.name();
        overview.tokenSupply = stakedToken.totalSupply();
        overview.tokenDecimals = stakedToken.decimals();

        return overview;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: MIT | |
pragma solidity ^0.8.0; | |
import { ERC721 } from "@openzeppelin/contracts/token/ERC721/ERC721.sol"; | |
/// Test-only ERC-721 with an unrestricted batch `mint`.
contract MockNFT is ERC721 {
    /// Id that the next minted token will receive.
    uint public tokenIdCounter;

    constructor()
        ERC721("Mock NFT", "MOCK")
    { }

    /// Mints `n` consecutively numbered tokens to `to`.
    function mint(address to, uint n)
        external
    {
        uint minted = 0;

        while( minted < n )
        {
            // Post-increment: the counter is already advanced in storage
            // before `_safeMint` runs with the previous value.
            _safeMint(to, tokenIdCounter++);

            minted += 1;
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: MIT | |
pragma solidity ^0.8.0; | |
import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; | |
/// Test-only ERC-20 with an unrestricted `mint`.
contract MockToken is ERC20 {
    // The name previously read "Mock NFT" (copied from MockNFT), which is
    // misleading wherever the token name is displayed.
    constructor()
        ERC20("Mock Token", "MOCK")
    { }

    // NOTE(review): never written by this contract; looks like a leftover from
    // MockNFT. Kept only so the public getter stays available to callers.
    uint public tokenIdCounter;

    /// Mints `amount` base units to `to`.
    function mint(address to, uint256 amount)
        external
    {
        _mint(to, amount);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
import { Ownable } from '@openzeppelin/contracts/access/Ownable.sol'; | |
import { EnumerableSet } from '@openzeppelin/contracts/utils/structs/EnumerableSet.sol'; | |
/// @title Moderation
/// @notice Owner-managed allow-lists for NFT collections and posters
///         (moderators), plus the `poolSize` setting.
contract Moderation is Ownable {
    using EnumerableSet for EnumerableSet.AddressSet;

    /// NFT contracts accepted by `isAllowed`.
    EnumerableSet.AddressSet private allowedNFTs;

    /// Moderators. While this set is empty `isAllowed` accepts any poster.
    EnumerableSet.AddressSet private allowedPosters;

    uint public poolSize;

    constructor (uint in_poolSize)
        Ownable()
    {
        poolSize = in_poolSize;
    }

    /// Retrieve list of allowed NFT collections
    function getNFTList()
        external view
        returns (address[] memory out_addrs)
    {
        out_addrs = _snapshot(allowedNFTs);
    }

    /// Retrieve list of moderators
    function getModeratorList()
        external view
        returns (address[] memory out_addrs)
    {
        out_addrs = _snapshot(allowedPosters);
    }

    /// Owner-only: overwrite `poolSize`.
    function setPoolSize(uint newPoolSize)
        external onlyOwner
    {
        poolSize = newPoolSize;
    }

    /// Owner-only: add (`state == true`) or remove (`state == false`) every
    /// address in `in_nfts` from the allowed NFT collections.
    function modifyAllowedNFTs(address[] calldata in_nfts, bool state)
        external onlyOwner
    {
        _setMembership(allowedNFTs, in_nfts, state);
    }

    /// Owner-only: add (`state == true`) or remove (`state == false`) every
    /// address in `in_addrs` from the moderator set.
    function modifyModerators(address[] calldata in_addrs, bool state)
        external onlyOwner
    {
        _setMembership(allowedPosters, in_addrs, state);
    }

    /// True when `who` is in the moderator set.
    function isModerator(address who)
        external view
        returns (bool)
    {
        return allowedPosters.contains(who);
    }

    /// Decides whether a token from collection `nft` may be posted.
    /// The collection must be allow-listed; in addition, once at least one
    /// moderator exists, either `operator` or `from` must be a moderator.
    function isAllowed(address nft, address operator, address from)
        external view
        returns (bool)
    {
        if( ! allowedNFTs.contains(nft) )
        {
            return false;
        }

        if( allowedPosters.length() == 0 )
        {
            return true;
        }

        return allowedPosters.contains(operator) || allowedPosters.contains(from);
    }

    /// Copies the members of `set` into a freshly allocated memory array.
    function _snapshot(EnumerableSet.AddressSet storage set)
        private view
        returns (address[] memory out)
    {
        uint total = set.length();

        out = new address[](total);

        for( uint idx = 0; idx < total; idx++ )
        {
            out[idx] = set.at(idx);
        }
    }

    /// Inserts or removes each of `addrs` in `set`. `add` on a member and
    /// `remove` on a non-member leave the set untouched, so no membership
    /// pre-check is needed.
    function _setMembership(
        EnumerableSet.AddressSet storage set,
        address[] calldata addrs,
        bool state
    )
        private
    {
        for( uint idx = 0; idx < addrs.length; idx++ )
        {
            if( state )
            {
                set.add(addrs[idx]);
            }
            else
            {
                set.remove(addrs[idx]);
            }
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
/// Raised when the randomness precompile call fails or returns too little data.
error RandomBytesFailure();

/// Address queried for entropy on every chain except chain id 1337.
/// NOTE(review): presumably the Oasis Sapphire random-bytes precompile - confirm.
address constant RANDOM_PRECOMPILE = 0x0100000000000000000000000000000000000001;

/// Returns 32 bytes of randomness.
///
/// On chain id 1337 (NOTE(review): presumably a local development chain) the
/// value is a hash of transaction/block context and is NOT unpredictable.
/// Everywhere else the bytes come from `RANDOM_PRECOMPILE`.
///
/// Reverts with `RandomBytesFailure` when the precompile call fails or yields
/// fewer than 32 bytes.
function randomBytes32()
    view returns (bytes32)
{
    if( block.chainid == 1337 ) {
        return keccak256(abi.encodePacked(
            msg.sender,
            msg.value,
            block.number,
            block.timestamp,
            gasleft()
        ));
    }

    (bool success, bytes memory entropy) = RANDOM_PRECOMPILE.staticcall(
        abi.encode(32, "")
    );

    // A low-level call to an address without code reports success with empty
    // return data; converting that to bytes32 would silently produce zero.
    // Hence the length check in addition to the success flag.
    if( ! success || entropy.length < 32 ) {
        revert RandomBytesFailure();
    }

    return bytes32(entropy);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { ethers } from 'hardhat'; | |
import { expect } from 'chai'; | |
import { Distributor, MockNFT, MockToken, Moderation, Staking } from '../src/contracts'; | |
import { HardhatEthersSigner } from '@nomicfoundation/hardhat-ethers/signers'; | |
// Statistical test: stakes random weights for a handful of signers, feeds NFTs
// into the Distributor, and checks that the number of NFTs each signer ends up
// with is within `sigmaThreshold` standard deviations of its expected share.
describe('Staker', () => {
  let d: Distributor;
  let s: Staking;
  let moderation: Moderation;
  let mockNFT: MockNFT;
  let mockToken: MockToken;

  before(async () => {
    mockNFT = await (await ethers.getContractFactory('MockNFT')).deploy();
    await mockNFT.waitForDeployment();

    mockToken = await (await ethers.getContractFactory('MockToken')).deploy();
    await mockToken.waitForDeployment();

    const modf = await ethers.getContractFactory('Moderation');
    moderation = await modf.deploy(10);
    await moderation.waitForDeployment();

    const stf = await ethers.getContractFactory('SumtreeLibrary');
    const st = await stf.deploy();
    await st.waitForDeployment();

    const stakingFactory = await ethers.getContractFactory('Staking', {
      libraries: {
        'SumtreeLibrary': await st.getAddress()
      }
    });
    s = await stakingFactory.deploy(await mockToken.getAddress());
    // Wait for each deployment before another contract is pointed at it.
    await s.waitForDeployment();

    const distributorFactory = await ethers.getContractFactory('Distributor');
    d = await distributorFactory.deploy(await s.getAddress(), await moderation.getAddress());
    await d.waitForDeployment();
  });

  /**
   * Number of samples to draw so that even the least likely outcome is hit
   * often enough for the normal approximation used by `verifyDistribution`.
   *
   * @param weights Weight per address.
   * @param sigmaThreshold Sigma multiple the caller intends to verify against.
   * @returns ceil(25 * sigmaThreshold^2 / minProbability).
   */
  function calculateIterationCount(weights: Record<string,number>, sigmaThreshold: number = 5): number
  {
    // Get minimum weight probability
    const allWeights = Object.values(weights);
    const totalWeight = allWeights.reduce((sum, w) => sum + w, 0);
    const minProb = Math.min(...allWeights.map(w => w / totalWeight));

    // Basic sample size calculation based on minimum probability
    // n * p * (1-p) needs to be large enough for normal approximation
    // and we want enough samples for reliable statistics
    const NORMAL_APPROXIMATION_THRESHOLD_SQUARED = 25;
    return Math.ceil(NORMAL_APPROXIMATION_THRESHOLD_SQUARED * sigmaThreshold * sigmaThreshold / minProb);
  }

  /**
   * Asserts that every address in `weightsMap` received a count within
   * `sigmaThreshold` binomial standard deviations of its expected value.
   *
   * @param weightsMap Weight per address.
   * @param resultsMap Observed count per address; missing entries count as 0.
   * @param sigmaThreshold Width of the accepted band, in standard deviations.
   */
  function verifyDistribution(
    weightsMap: Record<string, number>,
    resultsMap: Record<string, bigint>,
    sigmaThreshold: number
  ) {
    const totalItems = Object.values(resultsMap)
      .reduce((sum, count) => sum + Number(count), 0);

    const totalWeight = Object.values(weightsMap)
      .reduce((sum, w) => sum + w, 0);

    for (const [address, weight] of Object.entries(weightsMap)) {
      const probability = weight / totalWeight;
      const expectedCount = totalItems * probability;

      // Standard deviation of a binomial(totalItems, probability) variable
      const stdDev = Math.sqrt(totalItems * probability * (1 - probability));

      // Acceptable range, rounded outwards
      const margin = sigmaThreshold * stdDev;
      const lowest = Math.floor(expectedCount - margin);
      const highest = Math.ceil(expectedCount + margin);

      const actualCount = Number(resultsMap[address] || 0n);

      // `within` is inclusive on both ends and reports the offending numbers
      // on failure (a bare boolean comparison would not).
      expect(actualCount, `NFT count for ${address}`).to.be.within(lowest, highest);
    }
  }

  it('Distribution', async () => {
    const nSigners = 5;
    const allSigners = await ethers.getSigners();
    if( allSigners.length < nSigners ) {
      throw new Error('Not enough signers!');
    }
    const signers: HardhatEthersSigner[] = allSigners.slice(0, nSigners);

    const weightsMap: Record<string,number> = {};

    // Distribute tokens to the signers, then have the signers stake them
    // NOTE(review): the amounts are raw integers in [1, 100]; confirm they
    // satisfy whatever granularity check Staking.stakeFor enforces.
    for( const signer of signers )
    {
      const w = Math.floor(1 + (Math.random() * 100));
      await (await mockToken.mint(signer, w)).wait();
      await (await mockToken.connect(signer).approve(await s.getAddress(), w)).wait();
      // The receipt must be awaited; otherwise a failure surfaces as an
      // unhandled rejection and later steps race the stake.
      await (await s.connect(signer).stake(w)).wait();
      weightsMap[signer.address] = w;
    }

    // We must approve the NFTs before they can be used...
    await (await moderation.modifyAllowedNFTs([await mockNFT.getAddress()], true)).wait();

    // Mint in batches of 10, which will be randomly distributed to signers
    const batchCount = 10;
    const sigmaThreshold = 4;
    const iterations = calculateIterationCount(weightsMap, sigmaThreshold);
    const distributorAddress = await d.getAddress();
    for( let i = 0; i < Math.floor(iterations/batchCount); i++ )
    {
      const tx = await mockNFT.mint(distributorAddress, batchCount);
      await tx.wait();
    }

    // Collect the count of how many NFTs each account has
    const resultsMap: Record<string,bigint> = {};
    for( const signer of signers )
    {
      resultsMap[signer.address] = await mockNFT.balanceOf(signer.address);
    }

    verifyDistribution(weightsMap, resultsMap, sigmaThreshold);
  });
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; | |
import { Sumtree, SumtreeLibrary, ValueField } from './Sumtree.sol'; | |
/// @title Staking
/// @notice Holds one stake per address in a sum tree so that a staker can be
///         picked with probability proportional to the staked amount.
contract Staking {
    using SumtreeLibrary for Sumtree;

    /// Token that is transferred in on `stake` and back out on `unstake`.
    IERC20Metadata public immutable stakingToken;

    Sumtree private tree;

    constructor( address _stakingToken )
    {
        stakingToken = IERC20Metadata(_stakingToken);
    }

    error AlreadyStaked();

    error CannotStakeZero();

    error TransferFailed();

    error MustStakeWhole();

    /// Stakes `amount` on behalf of the caller. See `stakeFor`.
    function stake(ValueField amount)
        external
    {
        stakeFor(msg.sender, amount);
    }

    /// Pulls `amount` tokens from the caller and records them as the stake of
    /// `who`.
    ///
    /// Reverts with `CannotStakeZero` for a zero amount, `MustStakeWhole` when
    /// the amount is not a multiple of one whole token, `AlreadyStaked` when
    /// `who` already has a stake, and `TransferFailed` when `transferFrom`
    /// returns false.
    function stakeFor(address who, ValueField amount)
        public
    {
        uint256 rawAmount = ValueField.unwrap(amount);

        if( rawAmount == 0 ) revert CannotStakeZero();

        // One whole token is 10**decimals base units. Taking the remainder by
        // `decimals()` itself (e.g. 18) would test the wrong divisor and would
        // panic with a division by zero for tokens that use 0 decimals.
        uint256 wholeUnit = uint256(10) ** stakingToken.decimals();

        if( rawAmount % wholeUnit != 0 ) revert MustStakeWhole();

        if( tree.has(who) ) revert AlreadyStaked();

        bool ok = stakingToken.transferFrom(msg.sender, address(this), rawAmount);

        if( false == ok ) revert TransferFailed();

        tree.add(who, amount);
    }

    /// Removes the caller's stake and sends the tokens back.
    /// Reverts with `TransferFailed` when `transfer` returns false; the tree
    /// removal itself reverts when the caller has no stake.
    function unstake()
        external
    {
        // The stake is removed before the external call is made.
        ValueField amount = tree.remove(msg.sender);

        if( ValueField.unwrap(amount) > 0 )
        {
            // A `false` result used to be ignored, which would have erased the
            // stake without returning the tokens.
            bool ok = stakingToken.transfer(msg.sender, ValueField.unwrap(amount));

            if( false == ok ) revert TransferFailed();
        }
    }

    /// Amount staked by `staker`, or zero when there is none.
    function getValue(address staker)
        external view
        returns (ValueField)
    {
        if( tree.has(staker) ) {
            return tree.nodes[tree.nodesByKey[staker]].value;
        }
        return ValueField.wrap(0);
    }

    /// Total of all staked amounts.
    function getSum()
        external view
        returns (uint)
    {
        return tree.getSum();
    }

    /// Number of addresses that currently have a stake.
    function getCount()
        external view
        returns (uint)
    {
        return tree.getCount();
    }

    /// Staker selected by the sum tree for the given `seed`.
    function random(bytes32 seed)
        external view
        returns (address)
    {
        return tree.random(seed);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: AGPL-3.0-only | |
pragma solidity ^0.8.0; | |
/// Sum of the values stored in a subtree.
type SumField is uint104;

/// Value (weight) attached to a single key.
type ValueField is uint88;

/// Identifier of a node; also used for subtree node counts.
type NodeIndex is uint16;

/// One tree node. The six leading fields add up to 104 + 88 + 4 * 16 = 256
/// bits; do not reorder the fields, as that would change the storage layout.
struct Node {
    SumField sum;       // value of this node plus the sums of both subtrees
    ValueField value;   // this node's own value
    NodeIndex left;     // left child id, 0 when absent
    NodeIndex right;    // right child id, 0 when absent
    NodeIndex count;    // number of nodes in this subtree, itself included
    NodeIndex parent;   // parent id, 0 for the root
    address key;        // address this node belongs to
}

// Operator bindings for the user-defined value types above.
using {
    NodeIndex_add as +,
    NodeIndex_neq as !=,
    NodeIndex_eq as ==,
    NodeIndex_lt as <,
    NodeIndex_gt as >
} for NodeIndex global;

using {
    ValueField_eq as ==,
    ValueField_lt as <,
    ValueField_gt as >
} for ValueField global;

using {
    SumField_add as +,
    SumField_sub as -,
    SumField_gte as >=,
    SumField_lt as <
} for SumField global;
/// `a + b` for NodeIndex, using the checked addition of the underlying uint16.
function NodeIndex_add(NodeIndex a, NodeIndex b) pure returns (NodeIndex total) {
    total = NodeIndex.wrap(NodeIndex.unwrap(a) + NodeIndex.unwrap(b));
}

/// `a == b` for NodeIndex.
function NodeIndex_eq(NodeIndex a, NodeIndex b) pure returns (bool same) {
    same = NodeIndex.unwrap(a) == NodeIndex.unwrap(b);
}

/// `a != b` for NodeIndex.
function NodeIndex_neq(NodeIndex a, NodeIndex b) pure returns (bool differ) {
    differ = NodeIndex.unwrap(a) != NodeIndex.unwrap(b);
}

/// `a < b` for NodeIndex.
function NodeIndex_lt(NodeIndex a, NodeIndex b) pure returns (bool smaller) {
    smaller = NodeIndex.unwrap(a) < NodeIndex.unwrap(b);
}

/// `a > b` for NodeIndex.
function NodeIndex_gt(NodeIndex a, NodeIndex b) pure returns (bool larger) {
    larger = NodeIndex.unwrap(a) > NodeIndex.unwrap(b);
}
/// `a < b` for ValueField.
function ValueField_lt(ValueField a, ValueField b) pure returns (bool smaller) {
    smaller = ValueField.unwrap(a) < ValueField.unwrap(b);
}

/// `a > b` for ValueField.
function ValueField_gt(ValueField a, ValueField b) pure returns (bool larger) {
    larger = ValueField.unwrap(a) > ValueField.unwrap(b);
}

/// `a == b` for ValueField.
function ValueField_eq(ValueField a, ValueField b) pure returns (bool same) {
    same = ValueField.unwrap(a) == ValueField.unwrap(b);
}
/// `a + b` for SumField, using the checked addition of the underlying uint104.
function SumField_add(SumField a, SumField b) pure returns (SumField total) {
    total = SumField.wrap(SumField.unwrap(a) + SumField.unwrap(b));
}

/// `a >= b` for SumField.
function SumField_gte(SumField a, SumField b) pure returns (bool atLeast) {
    atLeast = SumField.unwrap(a) >= SumField.unwrap(b);
}

/// `a - b` for SumField, using the checked subtraction of the underlying uint104.
function SumField_sub(SumField a, SumField b) pure returns (SumField difference) {
    difference = SumField.wrap(SumField.unwrap(a) - SumField.unwrap(b));
}

/// `a < b` for SumField.
function SumField_lt(SumField a, SumField b) pure returns (bool smaller) {
    smaller = SumField.unwrap(a) < SumField.unwrap(b);
}

/// Widens a ValueField (uint88) into a SumField (uint104).
function to_SumField(ValueField x) pure returns (SumField widened) {
    widened = SumField.wrap(ValueField.unwrap(x));
}
/// Storage container for one sum tree; operated on through SumtreeLibrary.
struct Sumtree {
    // Node storage, addressed by node id
    mapping(NodeIndex => Node) nodes;

    // Direct key to nodeId mapping (0 means the key is not in the tree)
    mapping(address => NodeIndex) nodesByKey;

    // Id handed to the most recently created node
    NodeIndex nextNodeId;

    // Id of the root node, 0 while the tree is empty
    NodeIndex root_id;
}
/// @title SumtreeLibrary
/// @notice Binary search tree keyed by (value, address) in which every node
///         caches the node count and value sum of its subtree. The cached sums
///         allow picking a key with probability proportional to its value.
library SumtreeLibrary {
    error DuplicateKey(address key);

    error KeyNotFound(address key);

    error EmptyTree();

    error RandomFailed();

    error OutOfRange();

    // Node ids start at 1 (see `add`), so 0 doubles as the "no node" link.
    NodeIndex private constant EMPTY = NodeIndex.wrap(0);

    NodeIndex private constant NodeIndex_ONE = NodeIndex.wrap(1);

    SumField private constant SumField_ZERO = SumField.wrap(0);

    /// Inserts `key` with weight `value`.
    /// Reverts with `DuplicateKey` when the key is already present.
    /// @dev Ids are taken from a uint16 counter that is never decremented, so
    ///      the checked increment reverts once 65535 nodes have been created
    ///      over the lifetime of the tree, removals included.
    function add(Sumtree storage self, address key, ValueField value)
        public
    {
        if( has(self, key) ) {
            revert DuplicateKey(key);
        }

        Node memory new_node = Node({
            sum: to_SumField(value),
            value: value,
            left: EMPTY,
            right: EMPTY,
            parent: EMPTY,
            key: key,
            count: NodeIndex_ONE
        });

        NodeIndex new_id = self.nextNodeId = self.nextNodeId + NodeIndex_ONE;
        self.nodes[new_id] = new_node;
        self.nodesByKey[key] = new_id;

        // First node becomes the root
        if( self.root_id == EMPTY ) {
            self.root_id = new_id;
            return;
        }

        // Walk down to the insertion point, adding the new value and node to
        // the cached sum/count of every node on the path.
        NodeIndex current_id = self.root_id;
        NodeIndex parent_id = EMPTY;

        while( current_id != EMPTY )
        {
            parent_id = current_id;
            Node storage currentNode = self.nodes[current_id];

            currentNode.sum = currentNode.sum + to_SumField(value);
            currentNode.count = currentNode.count + NodeIndex_ONE;

            // Ordered by value; equal values are ordered by key
            if (value < currentNode.value ||
                (value == currentNode.value && uint160(key) < uint160(currentNode.key))) {
                current_id = currentNode.left;
            } else {
                current_id = currentNode.right;
            }
        }

        // Attach below the last node visited, on the side the ordering dictates
        self.nodes[new_id].parent = parent_id;
        Node storage parentNode = self.nodes[parent_id];

        if (value < parentNode.value ||
            (value == parentNode.value && uint160(key) < uint160(parentNode.key))) {
            parentNode.left = new_id;
        } else {
            parentNode.right = new_id;
        }

        rebalance(self, parent_id);
    }

    // Right rotation: the left child `x` of `y` takes y's place and `y`
    // becomes x's right child. Returns the id of the new subtree root.
    function rotateRight(Sumtree storage tree, NodeIndex y)
        private
        returns (NodeIndex)
    {
        Node storage yNode = tree.nodes[y];
        NodeIndex x = tree.nodes[y].left;
        Node storage xNode = tree.nodes[x];
        NodeIndex T2 = xNode.right;

        // Update parent references
        NodeIndex yParent = tree.nodes[y].parent;
        xNode.parent = yParent;
        yNode.parent = x;
        if (T2 != EMPTY) {
            tree.nodes[T2].parent = y;
        }

        // Perform rotation
        xNode.right = y;
        yNode.left = T2;

        // Update counts and sums (y first: it is now a child of x)
        updateCountAndSum(tree, y);
        updateCountAndSum(tree, x);

        // Update parent's child reference
        if (yParent != EMPTY) {
            Node storage yParentNode = tree.nodes[yParent];
            if (yParentNode.left == y) {
                yParentNode.left = x;
            } else {
                yParentNode.right = x;
            }
        } else {
            tree.root_id = x;
        }

        return x;
    }

    // Left rotation: the right child `y` of `x` takes x's place and `x`
    // becomes y's left child. Returns the id of the new subtree root.
    function rotateLeft(Sumtree storage tree, NodeIndex x)
        private
        returns (NodeIndex)
    {
        Node storage xNode = tree.nodes[x];
        NodeIndex y = xNode.right;
        Node storage yNode = tree.nodes[y];
        NodeIndex T2 = yNode.left;

        // Update parent references
        NodeIndex xParent = xNode.parent;
        yNode.parent = xParent;
        xNode.parent = y;
        if (T2 != EMPTY) {
            tree.nodes[T2].parent = x;
        }

        // Perform rotation
        yNode.left = x;
        xNode.right = T2;

        // Update counts and sums (x first: it is now a child of y)
        updateCountAndSum(tree, x);
        updateCountAndSum(tree, y);

        // Update parent's child reference
        if (xParent != EMPTY) {
            Node storage xParentNode = tree.nodes[xParent];
            if (xParentNode.left == x) {
                xParentNode.left = y;
            } else {
                xParentNode.right = y;
            }
        } else {
            tree.root_id = y;
        }

        return y;
    }

    // Rebalance tree after insertion or deletion.
    // Walks from `nodeId` up to the root, refreshing the cached count/sum of
    // every node on the way and rotating where a node is imbalanced.
    function rebalance(Sumtree storage tree, NodeIndex nodeId)
        private
    {
        NodeIndex current = nodeId;

        while (current != EMPTY) {
            updateCountAndSum(tree, current);

            int256 balance = getBalance(tree, current);

            // Imbalanced when (left - right) > 2 * right, or when
            // (right - left) > 2 * left, measured in node counts.
            Node storage currentNode = tree.nodes[current];
            bool isImbalanced = balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.right))) * 2) ||
                -balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.left))) * 2);

            if (isImbalanced) {
                if (balance > 0) {
                    // Left heavy
                    if (getBalance(tree, currentNode.left) < 0) {
                        // Left-Right case
                        currentNode.left = rotateLeft(tree, currentNode.left);
                    }
                    current = rotateRight(tree, current);
                } else {
                    // Right heavy
                    if (getBalance(tree, currentNode.right) > 0) {
                        // Right-Left case
                        currentNode.right = rotateRight(tree, currentNode.right);
                    }
                    current = rotateLeft(tree, current);
                }
            }

            current = tree.nodes[current].parent;
        }
    }

    // Cached node count of the subtree rooted at `nodeId` (0 for EMPTY).
    function getCount(Sumtree storage tree, NodeIndex nodeId) private view returns (NodeIndex) {
        return nodeId == EMPTY ? EMPTY : tree.nodes[nodeId].count;
    }

    // Left subtree count minus right subtree count; positive = left heavy.
    function getBalance(Sumtree storage tree, NodeIndex nodeId) private view returns (int256) {
        if (nodeId == EMPTY) return 0;
        Node storage node = tree.nodes[nodeId];
        return int256(uint256(NodeIndex.unwrap(getCount(tree, node.left)))) -
               int256(uint256(NodeIndex.unwrap(getCount(tree, node.right))));
    }

    // Recomputes the cached count and sum of `nodeId` from its own value and
    // the cached figures of its two children.
    function updateCountAndSum(Sumtree storage tree, NodeIndex nodeId) private {
        if (nodeId == EMPTY) return;

        Node storage node = tree.nodes[nodeId];
        node.count = NodeIndex_ONE + getCount(tree, node.left) + getCount(tree, node.right);

        SumField leftSum = node.left != EMPTY ? tree.nodes[node.left].sum : SumField_ZERO;
        SumField rightSum = node.right != EMPTY ? tree.nodes[node.right].sum : SumField_ZERO;
        node.sum = leftSum + to_SumField(node.value) + rightSum;
    }

    /// Removes `key` from the tree and returns the value it carried.
    /// Reverts with `KeyNotFound` when the key is absent.
    function remove(Sumtree storage self, address key)
        public
        returns (ValueField value)
    {
        if (!has(self, key)) {
            revert KeyNotFound(key);
        }

        NodeIndex nodeId = self.nodesByKey[key];
        Node storage node = self.nodes[nodeId];
        value = node.value;

        if (node.left == EMPTY || node.right == EMPTY) {
            // Case 1: Node has no children or one child. The only child (or
            // EMPTY) is spliced into the node's position.
            NodeIndex replacementId = node.left == EMPTY ? node.right : node.left;
            NodeIndex parentId = node.parent;

            replaceNode(self, nodeId, replacementId);

            // With no parent the replacement is the new root and its cached
            // figures are already correct.
            if (parentId != EMPTY) {
                rebalance(self, parentId);
            }
        } else {
            // Case 2: Node has both children.
            // Find successor (smallest value in right subtree)
            NodeIndex successorId = node.right;
            while (self.nodes[successorId].left != EMPTY) {
                successorId = self.nodes[successorId].left;
            }

            Node storage successorNode = self.nodes[successorId];

            // Store the successor's original position details
            NodeIndex successorParent = successorNode.parent;
            NodeIndex successorRight = successorNode.right;

            // If successor is not the immediate right child
            if (successorParent != nodeId) {
                // Replace successor with its right child
                if (successorRight != EMPTY) {
                    self.nodes[successorRight].parent = successorParent;
                }
                self.nodes[successorParent].left = successorRight;

                // Set successor's right to node's right child
                successorNode.right = node.right;
                self.nodes[node.right].parent = successorId;
            }

            // Move node's left child to successor
            successorNode.left = node.left;
            self.nodes[node.left].parent = successorId;

            // Put successor in node's position
            replaceNode(self, nodeId, successorId);

            // Start rebalancing from successor's original parent (or from the
            // successor itself when it was the direct right child).
            rebalance(self, successorParent != nodeId ? successorParent : successorId);
        }

        delete self.nodes[nodeId];
        self.nodesByKey[key] = EMPTY;
    }

    // Makes `newId` occupy the tree position of `oldId`: rewires the child
    // link of oldId's parent (or the root pointer) and newId's parent link.
    // The children of both nodes are left untouched.
    function replaceNode(Sumtree storage self, NodeIndex oldId, NodeIndex newId)
        private
    {
        NodeIndex parentId = self.nodes[oldId].parent;

        // Update parent's child pointer
        if (parentId != EMPTY) {
            Node storage parentNode = self.nodes[parentId];
            if (parentNode.left == oldId) {
                parentNode.left = newId;
            } else {
                parentNode.right = newId;
            }
        } else {
            self.root_id = newId;
        }

        // Update new node's parent pointer
        if (newId != EMPTY) {
            self.nodes[newId].parent = parentId;
        }
    }

    /// Picks a key with probability proportional to its value, using the
    /// keccak256 hash of `seed` reduced modulo the total sum.
    /// Reverts with `EmptyTree` when the tree has no nodes and with
    /// `OutOfRange` when the total sum is zero; both cases used to fail with
    /// a modulo-by-zero panic.
    function random(Sumtree storage self, bytes32 seed)
        internal view
        returns (address)
    {
        if( self.root_id == EMPTY ) {
            revert EmptyTree();
        }

        uint total = getSum(self);

        if( total == 0 ) {
            revert OutOfRange();
        }

        uint seedHashed = uint(keccak256(abi.encodePacked(seed)));

        return pick(self, seedHashed % total);
    }

    /// Returns the key whose slice of [0, getSum) contains `random_value`,
    /// where the slices are laid out in in-order (left, node, right) sequence
    /// and each slice is as wide as the node's value.
    /// Reverts with `EmptyTree` or `OutOfRange` for invalid input.
    function pick(Sumtree storage self, uint random_value)
        internal view
        returns (address)
    {
        if( self.root_id == EMPTY ) {
            revert EmptyTree();
        }

        if( random_value >= getSum(self) ) {
            revert OutOfRange();
        }

        NodeIndex current_id = self.root_id;

        while( current_id != EMPTY )
        {
            Node storage current = self.nodes[current_id];
            SumField left_sum = SumField_ZERO;

            // Get sum of left subtree
            if( current.left != EMPTY ) {
                left_sum = self.nodes[current.left].sum;
            }

            // If random_value falls in left subtree
            if( random_value < SumField.unwrap(left_sum) )
            {
                current_id = current.left;
                continue;
            }

            // If random_value falls in current node's range
            if( random_value < SumField.unwrap(left_sum + to_SumField(current.value)) ) {
                return current.key;
            }

            // Otherwise continue in the right subtree, with the offset made
            // relative to that subtree.
            random_value = random_value - SumField.unwrap((left_sum + to_SumField(current.value)));
            current_id = current.right;
        }

        // Only reachable when the cached sums disagree with the tree contents
        revert RandomFailed();
    }

    /// Number of keys in the tree.
    function getCount(Sumtree storage self)
        internal view
        returns (uint)
    {
        if( self.root_id == EMPTY ) {
            return 0;
        }
        return NodeIndex.unwrap(self.nodes[self.root_id].count);
    }

    /// Sum of all values in the tree.
    function getSum(Sumtree storage self)
        internal view
        returns (uint)
    {
        if( self.root_id == EMPTY ) {
            return 0;
        }
        return SumField.unwrap(self.nodes[self.root_id].sum);
    }

    /// True when `key` is present in the tree.
    function has(Sumtree storage self, address key)
        internal view
        returns (bool)
    {
        return self.nodesByKey[key] != EMPTY;
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { ethers } from 'hardhat'; | |
import { SumtreeLibrary, TestSumtree } from '../src/contracts'; | |
import { randomBytes, hexlify, getAddress, AddressLike } from 'ethers'; | |
import { expect } from 'chai'; | |
describe('Sumtree', () => { | |
let tst: TestSumtree; | |
let sumtreeLibrary: SumtreeLibrary; | |
before(async () => { | |
const stf = await ethers.getContractFactory('SumtreeLibrary'); | |
sumtreeLibrary = await stf.deploy(); | |
await sumtreeLibrary.waitForDeployment(); | |
}) | |
beforeEach(async () => { | |
const f = await ethers.getContractFactory('TestSumtree', { | |
libraries: { | |
'SumtreeLibrary': await sumtreeLibrary.getAddress() | |
} | |
}); | |
tst = await f.deploy() | |
await tst.waitForDeployment(); | |
}); | |
interface ValidationResult { | |
isValid: boolean; | |
actualSum?: bigint; | |
actualCount?: bigint; | |
} | |
async function printNode(nodeId: bigint, indent:number = 0, prefix:string='') { | |
const n = await tst.node(nodeId); | |
console.log(' '.repeat(indent), prefix, `id=${nodeId} s=${n.sum} v=${n.value} c=${n.count} k=${n.key} p=${n.parent}`); | |
if( n.left != 0n ) { | |
await printNode(n.left, indent + 2, 'L'); | |
} | |
if( n.right != 0n ) { | |
await printNode(n.right, indent + 2, 'R'); | |
} | |
} | |
/**
 * Recursively checks the subtree rooted at `nodeId`.
 *
 * For every node it verifies that each direct child is ordered correctly
 * relative to it (by value, ties broken by key), that each child's `parent`
 * points back at it, and that the stored `sum` / `count` equal the totals
 * recomputed from its children. Failures are reported with `console.error`.
 *
 * @param nodeId root of the subtree; 0 denotes an empty subtree
 * @returns `{ isValid: false }` on the first failed check, otherwise
 *          `isValid: true` together with the recomputed sum and count
 */
async function validateNode(nodeId: bigint): Promise<ValidationResult> {
  // An absent child (id 0) is a valid, empty subtree.
  if (nodeId === 0n) {
    return { isValid: true, actualSum: BigInt(0), actualCount: 0n };
  }
  const node = await tst.node(nodeId);

  // Checks ordering and the parent back-pointer for one direct child.
  const checkChild = async (side: 'left' | 'right'): Promise<boolean> => {
    const childId = side === 'left' ? node.left : node.right;
    if (childId === 0n) {
      return true;
    }
    const child = await tst.node(childId);
    // A left child must sort strictly before the node, a right child strictly after.
    const misplaced = side === 'left'
      ? child.value > node.value ||
        (child.value === node.value && BigInt(child.key) >= BigInt(node.key))
      : child.value < node.value ||
        (child.value === node.value && BigInt(child.key) <= BigInt(node.key));
    if (misplaced) {
      const label = side === 'left' ? 'Left' : 'Right';
      console.error(`Order violation at node ${nodeId} with ${side} child ${childId}`);
      console.error(`Parent: (${node.value.toString()}, ${node.key})`);
      console.error(`${label}: (${child.value.toString()}, ${child.key})`);
      return false;
    }
    if (child.parent !== nodeId) {
      console.error(`Parent pointer mismatch: node ${childId} should point to ${nodeId}`);
      return false;
    }
    return true;
  };

  if (!(await checkChild('left'))) {
    return { isValid: false };
  }
  if (!(await checkChild('right'))) {
    return { isValid: false };
  }

  // Descend into both subtrees to obtain their recomputed totals.
  const leftResult = await validateNode(node.left);
  if (!leftResult.isValid) {
    return { isValid: false };
  }
  const rightResult = await validateNode(node.right);
  if (!rightResult.isValid) {
    return { isValid: false };
  }

  const actualSum = leftResult.actualSum! + node.value + rightResult.actualSum!;
  const actualCount = leftResult.actualCount! + 1n + rightResult.actualCount!;

  // Stored aggregate sum must match the recomputed one.
  if (node.sum !== actualSum) {
    console.error(`Sum mismatch at node ${nodeId}:`);
    console.error(`Expected: ${node.sum.toString()}`);
    console.error(`Actual: ${actualSum.toString()}`);
    return { isValid: false };
  }
  // Stored node count must match the recomputed one.
  if (node.count !== actualCount) {
    console.error(`Count mismatch at node ${nodeId}:`);
    console.error(`Expected: ${node.count}`);
    console.error(`Actual: ${actualCount}`);
    return { isValid: false };
  }
  return { isValid: true, actualSum, actualCount };
}
/**
 * Validates the whole tree starting at `root`.
 *
 * @param root id of the root node (0 for an empty tree)
 * @returns true when `validateNode` reports the tree as valid; false when it
 *          reports a violation or when reading the tree throws
 */
async function validateTree(root: bigint): Promise<boolean> {
  try {
    const result = await validateNode(root);
    return result.isValid;
  } catch (error) {
    // Still report the tree as invalid, but surface the cause: without this a
    // failing contract call looks exactly like a structural violation.
    console.error("Error validating tree:", error);
    return false;
  }
}
/** A node reduced to the two fields the tree is ordered by. */
interface NodeValue {
  value: bigint;
  key: string;
}

/**
 * Flattens the tree into a list by walking it in-order
 * (left subtree, node, right subtree), starting from the current root.
 *
 * @returns the `(value, key)` of every node, in traversal order
 */
async function treeToSortedList(): Promise<NodeValue[]> {
  const values: NodeValue[] = [];
  const rootId = await tst.root();

  // Explicit stack of nodes whose left subtree is being (or has been) visited.
  const pending: Awaited<ReturnType<typeof tst.node>>[] = [];
  let cursor: bigint = rootId;
  while (cursor !== 0n || pending.length > 0) {
    // Walk down the left spine, fetching each node as it is first reached.
    while (cursor !== 0n) {
      const fetched = await tst.node(cursor);
      pending.push(fetched);
      cursor = fetched.left;
    }
    // Left side exhausted: emit the node, then continue with its right subtree.
    const node = pending.pop()!;
    values.push({
      value: node.value,
      key: node.key
    });
    cursor = node.right;
  }
  return values;
}
it('Simple left rotation', async () => {
  // Ascending inserts build a right-leaning chain that gets rotated left:
  //   100                 200
  //      \       ->      /   \
  //       200          100   300
  //          \
  //           300
  const inserts: [bigint, AddressLike][] = [
    [100n, getAddress('0x' + '1'.repeat(40))], // id=1
    [200n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
  ];
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    const rootId = await tst.root();
    expect(await validateTree(rootId)).eq(true);
  }
  //await printNode(await tst.root());
  // Node 2 (value 200) must end up on top, flanked by nodes 1 and 3.
  expect(await tst.root()).eq(2n);
  const top = await tst.node(2n);
  expect(top.left).eq(1n);
  expect(top.right).eq(3n);
});
it('Simple right rotation', async () => {
  // Descending inserts build a left-leaning chain that gets rotated right:
  //         300           200
  //        /      ->     /   \
  //      200           100   300
  //     /
  //   100
  const inserts: [bigint, AddressLike][] = [
    [300n, getAddress('0x' + '4'.repeat(40))], // id=1
    [200n, getAddress('0x' + '5'.repeat(40))], // id=2
    [100n, getAddress('0x' + '6'.repeat(40))], // id=3
  ];
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    const rootId = await tst.root();
    expect(await validateTree(rootId)).eq(true);
  }
  //await printNode(await tst.root());
  // Node 2 (value 200) must end up on top, with node 3 left and node 1 right.
  expect(await tst.root()).eq(2n);
  const top = await tst.node(2n);
  expect(top.left).eq(3n);
  expect(top.right).eq(1n);
});
// The new node lands in the LEFT subtree of the RIGHT child, i.e. the
// right-left case, which needs a double rotation. (The right-right case is
// the single left rotation covered by 'Simple left rotation'.)
it('Right-left case (double rotation)', async () => {
  // Tree becomes:
  //   100          100              200
  //      \    ->      \      ->    /   \
  //       300          200       100   300
  //      /                \
  //   200                  300
  const pairs: [bigint,AddressLike][] = [
    [100n, getAddress('0x' + '7'.repeat(40))], // id=1
    [300n, getAddress('0x' + '8'.repeat(40))], // id=2
    [200n, getAddress('0x' + '9'.repeat(40))], // id=3
  ];
  for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
  }
  //await printNode(await tst.root());
  // Node 3 (value 200) is lifted to the root between nodes 1 and 2.
  expect(await tst.root()).eq(3n);
  expect((await tst.node(3n)).left).eq(1n);
  expect((await tst.node(3n)).right).eq(2n);
});
it('Left-right case (double rotation)', async () => {
  // The new node lands in the right subtree of the left child:
  //     300            300          200
  //    /       ->     /      ->    /   \
  //  100            200          100   300
  //     \          /
  //      200     100
  const inserts: [bigint, AddressLike][] = [
    [300n, getAddress('0x' + '7'.repeat(40))], // id=1
    [100n, getAddress('0x' + '8'.repeat(40))], // id=2
    [200n, getAddress('0x' + '9'.repeat(40))], // id=3
  ];
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    const rootId = await tst.root();
    expect(await validateTree(rootId)).eq(true);
  }
  //await printNode(await tst.root());
  // Node 3 (value 200) is lifted to the root between nodes 2 and 1.
  expect(await tst.root()).eq(3n);
  const top = await tst.node(3n);
  expect(top.left).eq(2n);
  expect(top.right).eq(1n);
});
it('Key-based ordering (same values)', async () => {
  // All three entries share the same value, so only the key can order them.
  const inserts: [bigint, AddressLike][] = [
    [100n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [100n, getAddress('0x' + '3'.repeat(40))], // id=3
  ];
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    const rootId = await tst.root();
    expect(await validateTree(rootId)).eq(true);
  }
  //await printNode(await tst.root());
  // Expect the same shape as ascending values: node 2 on top of 1 and 3.
  expect(await tst.root()).eq(2n);
  const top = await tst.node(2n);
  expect(top.left).eq(1n);
  expect(top.right).eq(3n);
});
it('Remove leaf node from balanced tree', async () => {
  // Before:            After removing 100 (as asserted below):
  //     200                300
  //    /   \              /
  //  100   300          200
  const inserts: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
  ];
  // Build the initial tree, validating after every insert.
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }
  //await printNode(await tst.root());

  // Remove the leaf holding 100 (second inserted entry).
  const [, leafAccount] = inserts[1];
  await (await tst.remove(leafAccount)).wait();
  expect(await validateTree(await tst.root())).eq(true);
  //await printNode(await tst.root());

  // Node 3 is now the root with node 1 as its only (left) child.
  const rootNodeId = await tst.root();
  expect(rootNodeId).eq(3n);
  const rootNode = await tst.node(rootNodeId);
  expect(rootNode.left).eq(1n);
  expect(rootNode.right).eq(0n);
});
it('Remove node with one child', async () => {
  // Before:              After removing 100:
  //       200                 200
  //      /   \               /   \
  //    100   300           50    300
  //    /
  //  50
  const inserts: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
  ];
  // Build the initial tree, validating after every insert.
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }

  // Remove the entry holding 100; its single child must take its place.
  const [, removedAccount] = inserts[1];
  await (await tst.remove(removedAccount)).wait();
  expect(await validateTree(await tst.root())).eq(true);

  // Root is unchanged; node 4 (value 50) is now its left child.
  expect(await tst.root()).eq(1n);
  const top = await tst.node(1n);
  expect(top.left).eq(4n);
  expect(top.right).eq(3n);
});
it('Remove node with two children', async () => {
  // Initial tree:
  //       200
  //      /   \
  //    100   300
  //   /   \
  //  50   150
  //
  // Removing 100 leaves 50 and 150 under 200's left side. Which of the two
  // ends up on top is NOT asserted here (the structural checks below are
  // disabled); only the tree invariants are validated.
  const inserts: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
  ];
  // Build the initial tree, validating after every insert.
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }
  //await printNode(await tst.root());

  // Remove the entry holding 100, which has both children.
  const [, removedAccount] = inserts[1];
  await (await tst.remove(removedAccount)).wait();
  expect(await validateTree(await tst.root())).eq(true);

  // Disabled structural checks, kept for reference:
  //   await printNode(await tst.root());
  //   expect(await tst.root()).eq(1n);
  //   expect((await tst.node(1n)).left).eq(5n);
  //   expect((await tst.node(1n)).right).eq(3n);
  //   expect((await tst.node(5n)).left).eq(4n);
});
it('Remove root node triggers rebalance', async () => {
  // Before:                  After removing 200:
  //        200                     300
  //       /   \                   /   \
  //     100   300               100   400
  //    /   \     \             /   \
  //   50   150   400          50   150
  const inserts: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    [400n, getAddress('0x' + '6'.repeat(40))], // id=6
  ];
  // Build the initial tree, validating after every insert.
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }

  // Remove the root (first inserted entry, value 200).
  const [, rootAccount] = inserts[0];
  await (await tst.remove(rootAccount)).wait();
  expect(await validateTree(await tst.root())).eq(true);

  // Node 3 (value 300) takes over as root; the left subtree keeps its shape.
  expect(await tst.root()).eq(3n);
  const top = await tst.node(3n);
  expect(top.left).eq(2n);
  expect(top.right).eq(6n);
  const leftSubtree = await tst.node(2n);
  expect(leftSubtree.left).eq(4n);
  expect(leftSubtree.right).eq(5n);
});
it('Insert after removal maintains valid tree', async () => {
  // Initial tree:
  //       200
  //      /   \
  //    100   300
  //   /   \
  //  50   150
  const initial: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
  ];
  // Build the initial tree, validating after every insert.
  for (const [weight, account] of initial) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }

  // Remove the entry holding 100.
  const [, removedAccount] = initial[1];
  await (await tst.remove(removedAccount)).wait();
  expect(await validateTree(await tst.root())).eq(true);

  // Insert values that fall into different gaps of the remaining tree.
  const followUps: [bigint, AddressLike][] = [
    [125n, getAddress('0x' + '6'.repeat(40))], // Between 50 and 150
    [175n, getAddress('0x' + '7'.repeat(40))], // Between 150 and 200
    [250n, getAddress('0x' + '8'.repeat(40))], // Between 200 and 300
  ];
  for (const [weight, account] of followUps) {
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);
  }

  // The final tree must still satisfy every invariant.
  const finalRoot = await tst.root();
  expect(await validateTree(finalRoot)).eq(true);
  // Optional: print the final tree state for debugging
  //await printNode(await tst.root());
});
it('Remove should fail for non-existent key', async () => {
  // No entry is ever added for this address in a freshly deployed tree.
  const missingKey = getAddress('0x' + 'f'.repeat(40));
  const removal = tst.remove(missingKey);
  // The library must revert with KeyNotFound(<the missing address>).
  await expect(removal)
    .to.be.revertedWithCustomError(sumtreeLibrary, 'KeyNotFound')
    .withArgs(missingKey);
});
it('Verify sum updates after removal', async () => {
  // Populate the tree with three entries.
  const inserts: [bigint, AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))],
    [100n, getAddress('0x' + '2'.repeat(40))],
    [300n, getAddress('0x' + '3'.repeat(40))],
  ];
  for (const [weight, account] of inserts) {
    await (await tst.add(account, weight)).wait();
  }

  // The root's aggregate covers all three values.
  const rootBefore = await tst.node(await tst.root());
  expect(rootBefore.sum).eq(600n); // 200 + 100 + 300

  // Remove the entry holding 100 (second inserted entry).
  const [, removedAccount] = inserts[1];
  await (await tst.remove(removedAccount)).wait();

  // The root's aggregate must drop by exactly the removed value.
  const rootAfter = await tst.node(await tst.root());
  expect(rootAfter.sum).eq(500n); // 200 + 300
});
/**
 * Checks that `list` is strictly increasing by `(value, key)`: each entry must
 * have a larger value than its predecessor, or the same value and a larger
 * key (keys compared as lower-cased strings).
 *
 * @param list entries in the order to verify
 * @returns `isValid` plus one human-readable message per offending adjacent pair
 */
function validateOrdering(list: NodeValue[]): { isValid: boolean; violations: string[] } {
  const violations: string[] = [];
  // Pair every entry (from the second on) with the one directly before it.
  list.slice(1).forEach((next, i) => {
    const current = list[i];
    const valueDecreases = next.value < current.value;
    const tieWithoutLargerKey =
      next.value === current.value &&
      next.key.toLowerCase() <= current.key.toLowerCase();
    if (!valueDecreases && !tieWithoutLargerKey) {
      return;
    }
    violations.push([
      `Ordering violation at index ${i}:`,
      `Current: (value: ${current.value.toString()}, key: ${current.key})`,
      `Next: (value: ${next.value.toString()}, key: ${next.key})`,
    ].join('\n'));
  });
  return {
    isValid: violations.length === 0,
    violations
  };
}
it('Random insert & verify pick', async () => {
  // Insert 100 random (address, weight) entries, checking invariants each time.
  for (let round = 0; round < 100; round += 1) {
    const account = getAddress(hexlify(randomBytes(20)));
    const weight = 1 + Math.floor(Math.random() * 1000);
    await (await tst.add(account, weight)).wait();
    expect(await validateTree(await tst.root())).eq(true);

    const snapshot = await treeToSortedList();
    const ordering = validateOrdering(snapshot);
    if (!ordering.isValid) {
      console.log(snapshot);
      console.log(ordering.violations);
    }
    expect(ordering.isValid).eq(true);
  }
  //await printNode(await tst.root());

  const sorted = await treeToSortedList();
  expect(validateOrdering(sorted).isValid).eq(true);

  // Each entry owns the half-open range [rangeStart, rangeEnd) of the running
  // total; picking anywhere inside it must return that entry's key.
  let rangeStart: bigint = 0n;
  for (const entry of sorted) {
    const rangeEnd = rangeStart + entry.value;
    // First and last value of the range
    const first = await tst.pick(rangeStart);
    const last = await tst.pick(rangeEnd - 1n);
    expect(first).eq(last);
    expect(last).eq(entry.key);
    // Somewhere in the middle
    const middle = await tst.pick(rangeStart + ((rangeEnd - 1n - rangeStart) / 2n));
    expect(middle).eq(entry.key);
    rangeStart = rangeEnd;
  }
});
it('Random insert, remove & verify pick', async () => {
  // Addresses currently present in the tree. Removed addresses are dropped
  // from this list so that every "remove" draw hits a live entry.
  const liveAddresses: string[] = [];

  // Each entry of the in-order list owns the half-open range
  // [start, start + value) of the running total; picking at either end or in
  // the middle of that range must return the entry's key.
  async function verifyPicks(sorted: NodeValue[]): Promise<void> {
    let total: bigint = 0n;
    for (const z of sorted) {
      const totalBefore = total;
      total += z.value;
      // Range boundaries
      const a = await tst.pick(totalBefore);
      const b = await tst.pick(total - 1n);
      expect(a).eq(b);
      expect(b).eq(z.key);
      // Middle of the range
      const c = await tst.pick(totalBefore + ((total - 1n - totalBefore) / 2n));
      expect(c).eq(z.key);
    }
  }

  for (let i = 0; i < 200; i++) {
    // 70% chance to add, 30% chance to remove; always add when the tree is empty
    const shouldAdd = liveAddresses.length === 0 || Math.random() < 0.7;
    if (shouldAdd) {
      // Add a new random entry
      const a = getAddress(hexlify(randomBytes(20)));
      const w = 1 + Math.floor(Math.random() * 1000);
      const tx = await tst.add(a, w);
      await tx.wait();
      liveAddresses.push(a);
    } else {
      // Remove a random live entry (swap-remove keeps the list compact)
      const indexToRemove = Math.floor(Math.random() * liveAddresses.length);
      const addressToRemove = liveAddresses[indexToRemove];
      liveAddresses[indexToRemove] = liveAddresses[liveAddresses.length - 1];
      liveAddresses.pop();
      const tx = await tst.remove(addressToRemove);
      await tx.wait();
    }

    // Verify tree invariants after every operation
    expect(await validateTree(await tst.root())).eq(true);

    // Verify ordering and size after every operation
    const x = await treeToSortedList();
    const y = validateOrdering(x);
    if (!y.isValid) {
      console.log("Operation:", i);
      console.log("Current tree:", x);
      console.log("Violations:", y.violations);
    }
    expect(y.isValid).eq(true);
    expect(x.length).eq(liveAddresses.length);

    // Every 10 operations, verify pick functionality for all nodes
    if (i % 10 === 0) {
      await verifyPicks(x);
    }
  }

  // Final verification of the entire tree
  const finalList = await treeToSortedList();
  expect(validateOrdering(finalList).isValid).eq(true);
  // Final tree size must match our tracking
  expect(finalList.length).eq(liveAddresses.length);
  // Final verification of pick functionality for all nodes
  await verifyPicks(finalList);
});
}); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: MIT | |
pragma solidity ^0.8.0; | |
import {Node, Sumtree, SumtreeLibrary, ValueField, NodeIndex} from '../Sumtree.sol'; | |
/// @title TestSumtree
/// @notice Test harness: holds one private `Sumtree` and forwards each call to
///         the matching `SumtreeLibrary` function so tests can drive the
///         library and inspect its nodes.
contract TestSumtree {
    using SumtreeLibrary for Sumtree;

    /// Tree instance all functions below operate on.
    Sumtree private st;

    /// @notice Forwards to `Sumtree.add` with key `k` and value `v`.
    function add(address k, ValueField v)
        external
    {
        st.add(k, v);
    }

    /// @notice Forwards to `Sumtree.remove` for `key`.
    /// @return value Whatever the library's `remove` returns for `key`.
    function remove(address key)
        public
        returns (ValueField value)
    {
        return st.remove(key);
    }

    /// @notice Forwards to `Sumtree.random` with the given `seed`.
    function random(bytes32 seed)
        public view
        returns (address)
    {
        return st.random(seed);
    }

    /// @notice Forwards to `Sumtree.pick` with the given `r`.
    function pick(uint r)
        public view
        returns (address)
    {
        return st.pick(r);
    }

    /// @notice Result of `Sumtree.getCount` for the tree.
    function count()
        public view
        returns (uint)
    {
        return st.getCount();
    }

    /// @notice Result of `Sumtree.getSum` for the tree.
    function total()
        public view
        returns (uint)
    {
        return st.getSum();
    }

    /// @notice Current `root_id` of the tree.
    function root()
        public view
        returns (NodeIndex)
    {
        return st.root_id;
    }

    /// @notice Copy of the node stored at `idx` in the tree's `nodes` mapping.
    function node(NodeIndex idx)
        public view
        returns (Node memory)
    {
        return st.nodes[idx];
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment