Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jessiepathfinder/71babe4dd7ada987aaed9d01078ed040 to your computer and use it in GitHub Desktop.
Save jessiepathfinder/71babe4dd7ada987aaed9d01078ed040 to your computer and use it in GitHub Desktop.
ZKBridge.sol
pragma solidity ^0.7.0;
pragma experimental "ABIEncoderV2";
// Bit-manipulation helpers operating on 256-bit words.
// Bit index 0 is the least significant bit, index 255 the most significant.
library Bits {
    uint constant internal ONE = uint(1);
    uint constant internal ONES = ~uint(0);

    // Sets the bit at the given 'index' in 'self' to '1'.
    // Returns the modified value.
    function setBit(uint self, uint8 index) internal pure returns (uint) {
        return self | (ONE << index);
    }

    // Sets the bit at the given 'index' in 'self' to '0'.
    // Returns the modified value.
    function clearBit(uint self, uint8 index) internal pure returns (uint) {
        return self & ~(ONE << index);
    }

    // Sets the bit at the given 'index' in 'self' to:
    // '1' - if the bit is '0'
    // '0' - if the bit is '1'
    // Returns the modified value.
    function toggleBit(uint self, uint8 index) internal pure returns (uint) {
        return self ^ (ONE << index);
    }

    // Get the value of the bit at the given 'index' in 'self'.
    function bit(uint self, uint8 index) internal pure returns (uint8) {
        return uint8((self >> index) & 1);
    }

    // Check if the bit at the given 'index' in 'self' is set.
    // Returns:
    // 'true' - if the value of the bit is '1'
    // 'false' - if the value of the bit is '0'
    function bitSet(uint self, uint8 index) internal pure returns (bool) {
        return ((self >> index) & 1) == 1;
    }

    // Checks if the bit at the given 'index' in 'self' is equal to the corresponding
    // bit in 'other'.
    // Returns:
    // 'true' - if both bits are '0' or both bits are '1'
    // 'false' - otherwise
    function bitEqual(uint self, uint other, uint8 index) internal pure returns (bool) {
        return (((self ^ other) >> index) & 1) == 0;
    }

    // Get the bitwise NOT of the bit at the given 'index' in 'self'.
    function bitNot(uint self, uint8 index) internal pure returns (uint8) {
        return uint8(1 - ((self >> index) & 1));
    }

    // Computes the bitwise AND of the bit at the given 'index' in 'self', and the
    // corresponding bit in 'other', and returns the value.
    function bitAnd(uint self, uint other, uint8 index) internal pure returns (uint8) {
        return uint8(((self & other) >> index) & 1);
    }

    // Computes the bitwise OR of the bit at the given 'index' in 'self', and the
    // corresponding bit in 'other', and returns the value.
    function bitOr(uint self, uint other, uint8 index) internal pure returns (uint8) {
        return uint8(((self | other) >> index) & 1);
    }

    // Computes the bitwise XOR of the bit at the given 'index' in 'self', and the
    // corresponding bit in 'other', and returns the value.
    function bitXor(uint self, uint other, uint8 index) internal pure returns (uint8) {
        return uint8(((self ^ other) >> index) & 1);
    }

    // Gets 'numBits' consecutive bits from 'self', starting from the bit at 'startIndex'.
    // Returns the bits as a 'uint'.
    // Requires that:
    // - '0 < numBits <= 256'
    // - 'startIndex < 256' (guaranteed by the 'uint8' type)
    // - 'numBits + startIndex <= 256'
    function bits(uint self, uint8 startIndex, uint16 numBits) internal pure returns (uint) {
        // Arithmetic wraps silently before Solidity 0.8, so 'numBits' is bounded
        // explicitly and the sum is computed in 256 bits. Without this, a large
        // 'numBits' could wrap the uint16 sum below 256 and slip past the check.
        require(numBits > 0 && numBits <= 256 && uint(startIndex) + uint(numBits) <= 256);
        return (self >> startIndex) & (ONES >> (256 - uint(numBits)));
    }

    // Computes the index of the highest bit set in 'self'.
    // Returns the highest bit set as an 'uint8'.
    // Requires that 'self != 0'.
    function highestBitSet(uint self) internal pure returns (uint8 highest) {
        require(self != 0);
        uint val = self;
        // Binary search: at each step test whether the upper half of the
        // remaining 2*i-bit window contains a set bit.
        for (uint8 i = 128; i >= 1; i >>= 1) {
            if (val & (((ONE << i) - 1) << i) != 0) {
                highest += i;
                val >>= i;
            }
        }
    }

    // Computes the index of the lowest bit set in 'self'.
    // Returns the lowest bit set as an 'uint8'.
    // Requires that 'self != 0'.
    function lowestBitSet(uint self) internal pure returns (uint8 lowest) {
        require(self != 0);
        uint val = self;
        // Binary search: whenever the low i bits are all zero, skip past them.
        for (uint8 i = 128; i >= 1; i >>= 1) {
            if (val & ((ONE << i) - 1) == 0) {
                lowest += i;
                val >>= i;
            }
        }
    }
}
/*
* Data structures and utilities used in the Patricia Tree.
*
* More info at: https://github.com/chriseth/patricia-trie
*/
// Data structures and low-level operations for a binary Patricia (radix) tree
// whose keys are bit strings of up to 256 bits, stored left-aligned in a bytes32.
library Data {
    // A bit string: the first `length` bits of `data`, counted from the most
    // significant bit.
    struct Label {
        bytes32 data;
        uint length;
    }

    // A labelled edge pointing at `node`. `node` is either the hash of a branch
    // node or, for a leaf, the value hash stored by `insert`.
    struct Edge {
        bytes32 node;
        Label label;
    }

    // A branch node with one child edge per bit value (0 and 1).
    struct Node {
        Edge[2] children;
    }

    // The tree: `root` is the hash of `rootEdge`; `nodes` maps node hash to node.
    struct Tree {
        bytes32 root;
        Data.Edge rootEdge;
        mapping(bytes32 => Data.Node) nodes;
    }

    // Returns a label containing the longest common prefix of `self` and `other`,
    // and a label consisting of the remaining part of `self`.
    function splitCommonPrefix(Label memory self, Label memory other) internal pure returns (
        Label memory prefix,
        Label memory labelSuffix
    ) {
        return splitAt(self, commonPrefix(self, other));
    }

    // Splits the label at the given position and returns prefix and suffix,
    // i.e. 'prefix.length == pos' and 'prefix.data . suffix.data == l.data'.
    function splitAt(Label memory self, uint pos) internal pure returns (Label memory prefix, Label memory suffix) {
        assert(pos <= self.length && pos <= 256);
        prefix.length = pos;
        if (pos == 0) {
            prefix.data = bytes32(0);
        } else {
            // Mask with the top `pos` bits set. Written as a shift by
            // `256 - pos` so that `pos == 256` yields an all-ones mask; the
            // previous form `~uint(1) << 255 - pos` underflowed for pos == 256
            // and zeroed the prefix.
            prefix.data = bytes32(uint(self.data) & (~uint(0) << (256 - pos)));
        }
        suffix.length = self.length - pos;
        suffix.data = self.data << pos;
    }

    // Returns the length of the longest common prefix of the two labels.
    /*
    function commonPrefix(Label memory self, Label memory other) internal pure returns (uint prefix) {
    uint length = self.length < other.length ? self.length : other.length;
    // TODO: This could actually use a "highestBitSet" helper
    uint diff = uint(self.data ^ other.data);
    uint mask = uint(1) << 255;
    for (; prefix < length; prefix++) {
    if ((mask & diff) != 0) {
    break;
    }
    diff += diff;
    }
    }
    */
    function commonPrefix(Label memory self, Label memory other) internal pure returns (uint prefix) {
        uint length = self.length < other.length ? self.length : other.length;
        if (length == 0) {
            return 0;
        }
        // Keep only the differing bits that fall inside the shorter label.
        uint diff = uint(self.data ^ other.data) & (~uint(0) << (256 - length)); // TODO Mask should not be needed.
        if (diff == 0) {
            return length;
        }
        // The highest differing bit, counted from the MSB, is the prefix length.
        return 255 - Bits.highestBitSet(diff);
    }

    // Returns the result of removing a prefix of length `prefix` bits from the
    // given label (shifting its data to the left).
    function removePrefix(Label memory self, uint prefix) internal pure returns (Label memory r) {
        require(prefix <= self.length);
        r.length = self.length - prefix;
        r.data = self.data << prefix;
    }

    // Removes the first bit from a label and returns the bit and a
    // label containing the rest of the label (shifted to the left).
    function chopFirstBit(Label memory self) internal pure returns (uint firstBit, Label memory tail) {
        require(self.length > 0);
        return (uint(self.data >> 255), Label(self.data << 1, self.length - 1));
    }

    // Returns the hash of an edge: keccak256 over (node, label length, label data).
    function edgeHash(Data.Edge memory self) internal pure returns (bytes32) {
        return keccak256(abi.encode(self.node, self.label.length, self.label.data));
    }

    // Returns the hash of the encoding of a node.
    function hash(Data.Node memory self) internal pure returns (bytes32) {
        return keccak256(abi.encode(edgeHash(self.children[0]), edgeHash(self.children[1])));
    }

    // Stores node `n` under its own hash and returns that hash.
    function insertNode(Data.Tree storage tree, Data.Node memory n) internal returns (bytes32 newHash) {
        bytes32 h = hash(n);
        tree.nodes[h].children[0] = n.children[0];
        tree.nodes[h].children[1] = n.children[1];
        return h;
    }

    // Deletes the node stored under `oldHash`, stores `n` and returns its hash.
    function replaceNode(Data.Tree storage self, bytes32 oldHash, Data.Node memory n) internal returns (bytes32 newHash) {
        delete self.nodes[oldHash];
        return insertNode(self, n);
    }

    // Inserts `value` under the remaining key bits `key` below edge `e` and
    // returns the replacement edge for the caller to link in.
    function insertAtEdge(Tree storage self, Edge memory e, Label memory key, bytes32 value) internal returns (Edge memory) {
        assert(key.length >= e.label.length);
        (Label memory prefix, Label memory suffix) = splitCommonPrefix(key, e.label);
        bytes32 newNodeHash;
        if (suffix.length == 0) {
            // Full match with the key, update operation
            newNodeHash = value;
        } else if (prefix.length >= e.label.length) {
            // Partial match, just follow the path
            // NOTE(review): a remaining suffix of exactly one bit (branch on the
            // final key bit) trips this assert when updating an existing key —
            // confirm whether that is intended.
            assert(suffix.length > 1);
            Node memory n = self.nodes[e.node];
            (uint head, Label memory tail) = chopFirstBit(suffix);
            n.children[head] = insertAtEdge(self, n.children[head], tail, value);
            delete self.nodes[e.node];
            newNodeHash = insertNode(self, n);
        } else {
            // Mismatch, so let us create a new branch node.
            (uint head, Label memory tail) = chopFirstBit(suffix);
            Node memory branchNode;
            branchNode.children[head] = Edge(value, tail);
            // The existing subtree keeps what is left of its label after the
            // common prefix and the branching bit.
            branchNode.children[1 - head] = Edge(e.node, removePrefix(e.label, prefix.length + 1));
            newNodeHash = insertNode(self, branchNode);
        }
        return Edge(newNodeHash, prefix);
    }

    // Inserts keccak256(value) under the 256-bit key keccak256(key) and updates
    // the root edge and root hash.
    function insert(Tree storage self, bytes memory key, bytes memory value) internal {
        Label memory k = Label(keccak256(key), 256);
        bytes32 valueHash = keccak256(value);
        Edge memory e;
        if (self.root == 0) {
            // Empty Trie
            e.label = k;
            e.node = valueHash;
        } else {
            e = insertAtEdge(self, self.rootEdge, k, valueHash);
        }
        self.root = edgeHash(e);
        self.rootEdge = e;
    }
}
/*
* Patricia tree implementation.
*
* More info at: https://github.com/chriseth/patricia-trie
*/
// Merkle-proof generation and verification on top of the `Data` tree library.
library PatriciaTree {
    using Data for Data.Tree;
    using Data for Data.Node;
    using Data for Data.Edge;
    using Data for Data.Label;
    using Bits for uint;

    // Returns the Merkle-proof for the given key
    // Proof format should be:
    // - uint branchMask - bitmask with high bits at the positions in the key
    // where we have branch nodes (bit in key denotes direction)
    // - bytes32[] _siblings - hashes of sibling edges
    function getProof(Data.Tree storage tree, bytes memory key) internal view returns (uint branchMask, bytes32[] memory _siblings) {
        require(tree.root != 0);
        Data.Label memory remaining = Data.Label(keccak256(key), 256);
        Data.Edge memory edge = tree.rootEdge;
        // A 256-bit key can pass through at most 256 branch nodes.
        bytes32[256] memory collected;
        uint depth;
        uint count;
        while (true) {
            (Data.Label memory prefix, Data.Label memory suffix) = Data.splitCommonPrefix(remaining, edge.label);
            // The key must follow the whole edge label, otherwise it is not in the tree.
            assert(prefix.length == edge.label.length);
            if (suffix.length == 0) {
                // Found it
                break;
            }
            depth += prefix.length;
            // Mark the branching position, counted from the most significant bit.
            branchMask |= uint(1) << (255 - depth);
            depth += 1;
            (uint head, Data.Label memory tail) = suffix.chopFirstBit();
            collected[count] = tree.nodes[edge.node].children[1 - head].edgeHash();
            count++;
            edge = tree.nodes[edge.node].children[head];
            remaining = tail;
        }
        if (count == 0) {
            return (branchMask, _siblings);
        }
        // Copy the used part of the fixed-size buffer into a right-sized array.
        _siblings = new bytes32[](count);
        for (uint i = 0; i < count; i++) {
            _siblings[i] = collected[i];
        }
    }

    // Recomputes the root edge hash from a leaf (key, value) and its proof, walking
    // from the deepest branch up to the root. Reverts if the result differs from
    // `rootHash`; returns true otherwise.
    function verifyProof(bytes32 rootHash, bytes memory key, bytes memory value, uint branchMask, bytes32[] memory siblings) internal pure returns (bool) {
        Data.Label memory remaining = Data.Label(keccak256(key), 256);
        Data.Edge memory edge;
        edge.node = keccak256(value);
        for (uint i = 0; branchMask != 0; i++) {
            // The lowest set bit is the deepest branch not yet processed.
            uint branchBit = branchMask.lowestBitSet();
            // That bit is known to be set, so XOR clears it.
            branchMask ^= uint(1) << branchBit;
            (remaining, edge.label) = remaining.splitAt(255 - branchBit);
            uint direction;
            (direction, edge.label) = edge.label.chopFirstBit();
            bytes32[2] memory edgeHashes;
            edgeHashes[direction] = edge.edgeHash();
            // Siblings are ordered root-to-leaf, so consume them from the end.
            edgeHashes[1 - direction] = siblings[siblings.length - i - 1];
            edge.node = keccak256(abi.encode(edgeHashes));
        }
        edge.label = remaining;
        require(rootHash == edge.edgeHash());
        return true;
    }

    // Inserts (key, value) into `tree`; thin wrapper around Data.insert.
    function insert(Data.Tree storage tree, bytes memory key, bytes memory value) internal {
        tree.insert(key, value);
    }
}
// ERC-20 token interface: the events and external functions a token exposes.
interface IERC20 {
    // Emitted on a token transfer of `value` from `from` to `to`.
    event Transfer(address indexed from, address indexed to, uint256 value);

    // Emitted when `owner` sets the allowance of `spender` to `value`.
    event Approval(address indexed owner, address indexed spender, uint256 value);

    // Total amount of tokens in existence.
    function totalSupply() external view returns (uint256);

    // Token balance of `account`.
    function balanceOf(address account) external view returns (uint256);

    // Moves `amount` tokens from the caller to `to`.
    function transfer(address to, uint256 amount) external returns (bool);

    // Remaining amount `spender` may move on behalf of `owner`.
    function allowance(address owner, address spender) external view returns (uint256);

    // Sets the caller's allowance for `spender` to `amount`.
    function approve(address spender, uint256 amount) external returns (bool);

    // Moves `amount` tokens from `from` to `to` on behalf of `from`.
    function transferFrom(address from, address to, uint256 amount) external returns (bool);
}
// Basic ERC-20 token implementation with internal mint/burn primitives and
// before/after transfer hooks.
//
// This file targets Solidity ^0.7.0, where arithmetic wraps silently. Every
// addition that is not bounded by an invariant is therefore checked explicitly.
contract ERC20 is IERC20 {
    mapping(address => uint256) private _balances;
    mapping(address => mapping(address => uint256)) private _allowances;
    uint256 private _totalSupply;
    string private _name;
    string private _symbol;

    // Sets the token name and symbol.
    constructor(string memory name_, string memory symbol_) {
        _name = name_;
        _symbol = symbol_;
    }

    // Returns the name of the token.
    function name() public view virtual returns (string memory) {
        return _name;
    }

    // Returns the symbol of the token.
    function symbol() public view virtual returns (string memory) {
        return _symbol;
    }

    // Returns the number of decimals used for display purposes (18).
    function decimals() external view virtual returns (uint8) {
        return 18;
    }

    // Returns the total token supply.
    function totalSupply() public override view virtual returns (uint256) {
        return _totalSupply;
    }

    // Returns the balance of `account`.
    function balanceOf(address account) external view virtual override returns (uint256) {
        return _balances[account];
    }

    // Transfers `amount` from the caller to `to`. Always returns true; reverts on failure.
    function transfer(address to, uint256 amount) external virtual override returns (bool) {
        _transfer(msg.sender, to, amount);
        return true;
    }

    // Returns the remaining allowance of `spender` over `owner`'s tokens.
    function allowance(address owner, address spender) external view virtual override returns (uint256) {
        return _allowances[owner][spender];
    }

    // Sets the caller's allowance for `spender` to `amount`.
    function approve(address spender, uint256 amount) external virtual override returns (bool) {
        _approve(msg.sender, spender, amount);
        return true;
    }

    // Spends the caller's allowance over `from` and transfers `amount` from `from` to `to`.
    function transferFrom(
        address from,
        address to,
        uint256 amount
    ) external virtual override returns (bool) {
        _spendAllowance(from, msg.sender, amount);
        _transfer(from, to, amount);
        return true;
    }

    // Increases the caller's allowance for `spender` by `addedValue`.
    // Reverts if the new allowance would exceed the uint256 range.
    function increaseAllowance(address spender, uint256 addedValue) external virtual returns (bool) {
        uint256 currentAllowance = _allowances[msg.sender][spender];
        uint256 newAllowance = currentAllowance + addedValue;
        // Unchecked arithmetic in 0.7: detect wrap-around explicitly.
        require(newAllowance >= currentAllowance, "ERC20: allowance overflow");
        _approve(msg.sender, spender, newAllowance);
        return true;
    }

    // Decreases the caller's allowance for `spender` by `subtractedValue`.
    // Reverts if the allowance would drop below zero.
    function decreaseAllowance(address spender, uint256 subtractedValue) external virtual returns (bool) {
        uint256 currentAllowance = _allowances[msg.sender][spender];
        require(currentAllowance >= subtractedValue, "ERC20: decreased allowance below zero");
        _approve(msg.sender, spender, currentAllowance - subtractedValue);
        return true;
    }

    // Moves `amount` tokens from `from` to `to`, running the transfer hooks
    // before and after. Reverts on a zero address or insufficient balance.
    function _transfer(
        address from,
        address to,
        uint256 amount
    ) internal virtual {
        require(from != address(0), "ERC20: transfer from the zero address");
        require(to != address(0), "ERC20: transfer to the zero address");
        _beforeTokenTransfer(from, to, amount);
        uint256 fromBalance = _balances[from];
        require(fromBalance >= amount, "ERC20: transfer amount exceeds balance");
        _balances[from] = fromBalance - amount;
        // Overflow not possible: the sum of all balances is capped by totalSupply, and the sum is preserved by
        // decrementing then incrementing.
        _balances[to] += amount;
        emit Transfer(from, to, amount);
        _afterTokenTransfer(from, to, amount);
    }

    // Creates `amount` tokens and assigns them to `account`.
    // Reverts on the zero address or if the total supply would overflow.
    function _mint(address account, uint256 amount) internal virtual {
        require(account != address(0), "ERC20: mint to the zero address");
        _beforeTokenTransfer(address(0), account, amount);
        uint256 newSupply = _totalSupply + amount;
        // Unchecked arithmetic in 0.7: detect wrap-around explicitly.
        require(newSupply >= amount, "ERC20: mint amount overflows total supply");
        _totalSupply = newSupply;
        // Overflow not possible: balance + amount is at most totalSupply + amount, which is checked above.
        _balances[account] += amount;
        emit Transfer(address(0), account, amount);
        _afterTokenTransfer(address(0), account, amount);
    }

    // Destroys `amount` tokens from `account`.
    // Reverts on the zero address or insufficient balance.
    function _burn(address account, uint256 amount) internal virtual {
        require(account != address(0), "ERC20: burn from the zero address");
        _beforeTokenTransfer(account, address(0), amount);
        uint256 accountBalance = _balances[account];
        require(accountBalance >= amount, "ERC20: burn amount exceeds balance");
        _balances[account] = accountBalance - amount;
        // Overflow not possible: amount <= accountBalance <= totalSupply.
        _totalSupply -= amount;
        emit Transfer(account, address(0), amount);
        _afterTokenTransfer(account, address(0), amount);
    }

    // Sets `spender`'s allowance over `owner`'s tokens to `amount` and emits Approval.
    function _approve(
        address owner,
        address spender,
        uint256 amount
    ) internal virtual {
        require(owner != address(0), "ERC20: approve from the zero address");
        require(spender != address(0), "ERC20: approve to the zero address");
        _allowances[owner][spender] = amount;
        emit Approval(owner, spender, amount);
    }

    // Deducts `amount` from `spender`'s allowance over `owner`'s tokens.
    // An allowance of type(uint256).max is treated as infinite and left untouched.
    function _spendAllowance(
        address owner,
        address spender,
        uint256 amount
    ) internal virtual {
        uint256 currentAllowance = _allowances[owner][spender];
        if (currentAllowance != type(uint256).max) {
            require(currentAllowance >= amount, "ERC20: insufficient allowance");
            _approve(owner, spender, currentAllowance - amount);
        }
    }

    // Hook called before any transfer, mint (from == 0) or burn (to == 0).
    function _beforeTokenTransfer(
        address from,
        address to,
        uint256 amount
    ) internal virtual {}

    // Hook called after any transfer, mint (from == 0) or burn (to == 0).
    function _afterTokenTransfer(
        address from,
        address to,
        uint256 amount
    ) internal virtual {}
}
// ERC-3156 flash-loan borrower interface: the callback a lender invokes on the receiver.
interface IERC3156FlashBorrower {
    // Called by the lender after the loaned tokens have been made available.
    // initiator - account that started the loan
    // token     - loaned currency
    // amount    - amount lent
    // fee       - additional amount to repay
    // data      - arbitrary data forwarded from the flashLoan call
    // Returns a bytes32 value that the lender checks.
    function onFlashLoan(address initiator, address token, uint256 amount, uint256 fee, bytes calldata data)
        external
        returns (bytes32);
}
// ERC-3156 flash-loan lender interface.
interface IERC3156FlashLender {
    // Maximum amount of `token` available for a flash loan.
    function maxFlashLoan(address token) external view returns (uint256);

    // Fee charged for lending `amount` of `token`.
    function flashFee(address token, uint256 amount) external view returns (uint256);

    // Lends `amount` of `token` to `receiver`, forwarding `data` to its callback.
    function flashLoan(IERC3156FlashBorrower receiver, address token, uint256 amount, bytes calldata data)
        external
        returns (bool);
}
// ERC-3156 flash-loan extension for ERC20: loans are minted to the receiver and
// burned again (plus an optional fee) at the end of the same call.
abstract contract ERC20FlashMint is ERC20, IERC3156FlashLender {
    // Value the borrower callback must return for the loan to be accepted.
    bytes32 private constant _RETURN_VALUE = keccak256("ERC3156FlashBorrower.onFlashLoan");

    // Returns the maximum loan size: the headroom left in the uint256 supply
    // for this token, and 0 for any other token.
    function maxFlashLoan(address token) public view virtual override returns (uint256) {
        return token == address(this) ? type(uint256).max - ERC20.totalSupply() : 0;
    }

    // Returns the fee for lending `amount`. Reverts if `token` is not this token.
    function flashFee(address token, uint256 amount) public view virtual override returns (uint256) {
        require(token == address(this), "ERC20FlashMint: wrong token");
        return _flashFee(token, amount);
    }

    // Fee schedule hook; returns 0 unless overridden.
    function _flashFee(address token, uint256 amount) internal view virtual returns (uint256) {
        // silence warning about unused variable without the addition of bytecode.
        token;
        amount;
        return 0;
    }

    // Recipient of the flash fee; address(0) (the default) means the fee is burned.
    function _flashFeeReceiver() internal view virtual returns (address) {
        return address(0);
    }

    // Mints `amount` to `receiver`, invokes its onFlashLoan callback, then takes
    // back `amount + fee` through the receiver's allowance to this contract.
    // Reverts if the amount exceeds maxFlashLoan, the callback returns the wrong
    // value, or the repayment cannot be collected.
    function flashLoan(
        IERC3156FlashBorrower receiver,
        address token,
        uint256 amount,
        bytes calldata data
    ) public virtual override returns (bool) {
        require(amount <= maxFlashLoan(token), "ERC20FlashMint: amount exceeds maxFlashLoan");
        uint256 fee = flashFee(token, amount);
        // Arithmetic wraps silently under Solidity 0.7. A wrapped sum would let
        // the borrower repay less than it was lent, so check it before minting.
        uint256 repayment = amount + fee;
        require(repayment >= amount, "ERC20FlashMint: repayment overflow");
        _mint(address(receiver), amount);
        require(
            receiver.onFlashLoan(msg.sender, token, amount, fee, data) == _RETURN_VALUE,
            "ERC20FlashMint: invalid return value"
        );
        address flashFeeReceiver = _flashFeeReceiver();
        _spendAllowance(address(receiver), address(this), repayment);
        if (fee == 0 || flashFeeReceiver == address(0)) {
            // No fee recipient: burn principal and fee together.
            _burn(address(receiver), repayment);
        } else {
            _burn(address(receiver), amount);
            _transfer(address(receiver), flashFeeReceiver, fee);
        }
        return true;
    }
}
//NOW FOR THE ACTUAL FUN
// Base bridge contract. Deposits are recorded as leaves of a Patricia tree keyed
// by (depositor, per-depositor id); withdrawals are paid out against Merkle
// proofs for root hashes the owner has approved.
abstract contract ZKBridge {
    // Account allowed to approve withdrawal root hashes (the deployer).
    address public owner;
    // Tree of deposits made on this contract.
    Data.Tree public depositTree;
    // Root hashes accepted by `withdraw`.
    mapping(bytes32 => bool) public isValidWithdrawalRootHash;
    // Number of deposits made per account; also the next deposit id.
    mapping(address => uint256) public depositCount;
    // Not written anywhere in this contract.
    mapping(uint256 => uint256) public depositAmounts;
    // withdrawn[account][id] is true once that deposit has been withdrawn here.
    mapping(address => mapping(uint256 => bool)) public withdrawn;

    constructor() {
        owner = msg.sender;
    }

    // Marks `rh` as an accepted withdrawal root hash. Owner only.
    function addWithdrawalRootHash(bytes32 rh) external {
        require(msg.sender == owner, "Unauthorized!");
        isValidWithdrawalRootHash[rh] = true;
    }

    // Records a deposit of `amount` by the caller under the caller's next id.
    function _deposit(uint256 amount) internal {
        uint256 id = depositCount[msg.sender];
        depositCount[msg.sender] = id + 1;
        PatriciaTree.insert(
            depositTree,
            abi.encode(msg.sender, id),
            abi.encode(msg.sender, amount)
        );
    }

    // Returns the current deposit tree root together with the Merkle proof for
    // the deposit (`sender`, `id`).
    function getDepositProof(address sender, uint256 id)
        external
        view
        returns (bytes32 withdrawalRootHash, uint branchMask, bytes32[] memory siblings)
    {
        bytes memory leafKey = abi.encode(sender, id);
        withdrawalRootHash = depositTree.root;
        (branchMask, siblings) = PatriciaTree.getProof(depositTree, leafKey);
    }

    // Checks a deposit proof against `rootHash`. Returns true when it matches;
    // the underlying verifier reverts when it does not.
    function verifyDepositProof(
        address sender,
        uint256 amount,
        uint256 id,
        bytes32 rootHash,
        uint branchMask,
        bytes32[] calldata siblings
    ) external pure returns (bool) {
        return PatriciaTree.verifyProof(
            rootHash,
            abi.encode(sender, id),
            abi.encode(sender, amount),
            branchMask,
            siblings
        );
    }

    // Pays out `amount` to msg.sender; implemented by the concrete bridge.
    function _sendWithdrawal(uint256 amount) internal virtual;

    // Withdraws the caller's deposit `id` of `amount`, proven against an
    // approved root hash. Each (caller, id) pair can be withdrawn once.
    function withdraw(
        uint256 amount,
        uint256 id,
        bytes32 withdrawalRootHash,
        uint branchMask,
        bytes32[] calldata siblings
    ) external {
        require(isValidWithdrawalRootHash[withdrawalRootHash], "Invalid withdrawal root hash!");
        require(
            PatriciaTree.verifyProof(
                withdrawalRootHash,
                abi.encode(msg.sender, id),
                abi.encode(msg.sender, amount),
                branchMask,
                siblings
            ),
            "Invalid deposit proof!"
        );
        require(!withdrawn[msg.sender][id], "Already withdrawn!");
        // Mark as withdrawn before paying out.
        withdrawn[msg.sender][id] = true;
        _sendWithdrawal(amount);
    }
}
// Bridge endpoint that holds an existing ERC-20 token: deposits pull tokens in
// with transferFrom, withdrawals send them out with transfer.
contract ZKBridgeERC20Input is ZKBridge {
    // The ERC-20 token held by this bridge.
    IERC20 immutable public token;

    constructor(IERC20 t) {
        token = t;
    }

    // Pays a withdrawal by transferring `amount` of the token to the caller.
    function _sendWithdrawal(uint256 amount) internal override {
        safeTransfer(msg.sender, amount);
    }

    // Calls token.transfer(to, value), tolerating tokens that return no data.
    function safeTransfer(
        address to,
        uint256 value
    ) private {
        _callOptionalReturn(abi.encodeWithSelector(token.transfer.selector, to, value));
    }

    // Calls token.transferFrom(from, to, value), tolerating tokens that return no data.
    function safeTransferFrom(
        address from,
        address to,
        uint256 value
    ) private {
        _callOptionalReturn(abi.encodeWithSelector(token.transferFrom.selector, from, to, value));
    }

    // Pulls `amount` tokens from the caller into this contract and records the deposit.
    function deposit(uint256 amount) external {
        safeTransferFrom(msg.sender, address(this), amount);
        _deposit(amount);
    }

    // Performs a low-level call to the token with `data` and returns the raw
    // return data. Reverts (bubbling the token's revert reason, or with
    // `errorMessage`) if the call fails, and reverts if the call "succeeds"
    // with no return data against an address that has no code.
    function functionCallWithValue(
        bytes memory data,
        string memory errorMessage
    ) private returns (bytes memory) {
        address target = address(token);
        (bool success, bytes memory returndata) = target.call(data);
        if (success && returndata.length == 0) {
            // A call to an address without code succeeds and returns nothing.
            // Without this check a deposit would be recorded with no transfer.
            uint256 codeSize;
            assembly {
                codeSize := extcodesize(target)
            }
            require(codeSize > 0, "Address: call to non-contract");
        }
        return verifyCallResultFromTarget(success, returndata, errorMessage);
    }

    // Runs an ERC-20 call and, if the token returned data, requires it to decode to `true`.
    function _callOptionalReturn(bytes memory data) private {
        // We need to perform a low level call here, to bypass Solidity's return data size checking mechanism, since
        // we're implementing it ourselves. functionCallWithValue verifies that the call succeeded and, when no data
        // is returned, that the token address contains contract code.
        bytes memory returndata = functionCallWithValue(data, "SafeERC20: low-level call failed");
        if (returndata.length > 0) {
            // Return data is optional
            require(abi.decode(returndata, (bool)), "SafeERC20: ERC20 operation did not succeed");
        }
    }

    // Returns `returndata` on success. On failure, re-raises the callee's revert
    // data if there is any, otherwise reverts with `errorMessage`.
    function verifyCallResultFromTarget(
        bool success,
        bytes memory returndata,
        string memory errorMessage
    ) private pure returns (bytes memory) {
        if (success) {
            return returndata;
        } else {
            if (returndata.length > 0) {
                // Bubble up the original revert reason: skip the 32-byte length
                // word and revert with the payload.
                assembly {
                    let returndata_size := mload(returndata)
                    revert(add(32, returndata), returndata_size)
                }
            }
            revert(errorMessage);
        }
    }
}
// Bridge endpoint that is itself an ERC-20 token: withdrawals mint tokens to the
// caller, and `exit` burns the caller's tokens and records a deposit.
contract ZKBridgeOutputERC20Token is ZKBridge, ERC20FlashMint {
    // Forwards the token name and symbol to the ERC20 base constructor.
    constructor(string memory name, string memory symbol) ERC20(name, symbol) {}

    // Burns `amount` of the caller's tokens, then records it as a deposit.
    function exit(uint256 amount) external {
        _burn(msg.sender, amount);
        _deposit(amount);
    }

    // Pays a withdrawal by minting `amount` tokens to the caller.
    function _sendWithdrawal(uint256 amount) internal override {
        _mint(msg.sender, amount);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment