Skip to content

Instantly share code, notes, and snippets.

@nalinbhardwaj
Created December 27, 2022 18:57
Show Gist options
  • Select an option

  • Save nalinbhardwaj/aa82c6d9324baa1aaaf41ebd4966d912 to your computer and use it in GitHub Desktop.

Select an option

Save nalinbhardwaj/aa82c6d9324baa1aaaf41ebd4966d912 to your computer and use it in GitHub Desktop.

Open this in zkREPL →

This file can be included into other zkREPLs with include "gist:aa82c6d9324baa1aaaf41ebd4966d912";

pragma circom 2.0.8;
include "circomlib/gates.circom";
include "circomlib/bitify.circom";
include "circomlib/poseidon.circom";
// include "https://github.com/0xPARC/circom-secp256k1/blob/master/circuits/bigint.circom";
// Returns the smaller of `a` and `b`; when they are equal, `b` is returned.
function min(a, b) {
    if (b <= a) {
        return b;
    }
    return a;
}
// Returns the number of binary digits of `n` (0 for n == 0), capped at 254.
// NOTE(review): despite the name this is floor(log2(n)) + 1, not ceil(log2(n));
// e.g. log_ceil(4) == 3 and log_ceil(1) == 1. Kept as-is since including
// circuits may depend on the current values.
function log_ceil(n) {
    var remaining = n;
    var digits = 0;
    while (digits < 254 && remaining != 0) {
        remaining = remaining \ 2;
        digits++;
    }
    return digits;
}
// Constrained integer division of 32-bit values:
//   dividend == quotient * divisor + remainder,  0 <= remainder < divisor.
//
// Inputs:  dividend (expected < 2^32), divisor (non-zero; ShiftRight passes
//          powers of two, including 2^32 for a shift of 32).
// Outputs: quotient, remainder, both constrained to 32 bits.
template Modulo32() {
    signal input dividend;
    signal input divisor;
    signal output quotient;
    signal output remainder;

    // Witness hints only; soundness comes from the constraints below.
    remainder <-- dividend % divisor;
    quotient <-- (dividend - remainder) / divisor;

    // Range-check both hints. LessThan(32) is only sound for operands < 2^32,
    // and without a bound on quotient a prover could pick ANY remainder
    // r < divisor and satisfy the final equation with the field element
    // quotient = (dividend - r) / divisor.
    component remainderBits = Num2Bits(32);
    remainderBits.in <== remainder;
    component quotientBits = Num2Bits(32);
    quotientBits.in <== quotient;

    // remainder < divisor. Given remainder < 2^32, this also forces
    // divisor <= remainder + 2^32 < 2^33.
    component lt = LessThan(32);
    lt.in[0] <== remainder;
    lt.in[1] <== divisor;
    lt.out === 1;
    log("remainder: ", remainder);
    log("quotient: ", quotient);
    log("lt.out: ", lt.out);
    log("divisor: ", divisor);

    // All terms are bounded (product < 2^65), far below the field modulus,
    // so this field equation is also an integer equation.
    dividend === quotient * divisor + remainder;
}
// Returns the low `bits` bits of (in[0] AND in[1]).
//
// Inputs are decomposed over the full field width, not over `bits`, so values
// wider than `bits` are accepted and their high bits discarded. Callers in this
// file rely on that to reduce shifted products and sums modulo 2^bits.
//
// Num2Bits_strict (254 bits + AliasCheck) replaces the former Num2Bits(256).
// A 256-bit decomposition is not unique over the ~254-bit field (x, x + p,
// x + 2p, ... all fit in 256 bits), so a prover could supply the bits of
// x + k*p and make `out` an arbitrary wrong value.
template bitwiseAND(bits) {
    assert(bits <= 254);
    signal input in[2];
    signal output out;
    component bitified[2];
    for (var i = 0; i < 2; i++) {
        bitified[i] = Num2Bits_strict();
        bitified[i].in <== in[i];
    }
    component fullAND = Bits2Num(bits);
    component bitAND[bits];
    for (var i = 0; i < bits; i++) {
        bitAND[i] = AND();
        bitAND[i].a <== bitified[0].out[i];
        bitAND[i].b <== bitified[1].out[i];
        fullAND.in[i] <== bitAND[i].out;
    }
    out <== fullAND.out;
}
// Computes (in << shift) truncated to 32 bits.
//
// `shift` is constrained to [0, 64) by Num2Bits(6). The multiplier 2^shift is
// built as the product over the 6 shift bits of (bit_i ? 2^(2^i) : 1).
// The per-bit factors 2^(2^i) are compile-time constants rather than signals.
// Each factor is then a linear expression, and only the running product costs
// a constraint.
template ShiftLeft() {
    signal input in;
    signal input shift;
    signal output out;

    component n2b = Num2Bits(6);
    n2b.in <== shift;

    // acc[i] = 2^(shift mod 2^(i+1)); acc[5] = 2^shift.
    signal acc[6];
    var factor = 2; // 2^(2^i) for the current i
    for (var i = 0; i < 6; i++) {
        if (i == 0) {
            acc[i] <== n2b.out[i] * (factor - 1) + 1;
        } else {
            acc[i] <== acc[i - 1] * (n2b.out[i] * (factor - 1) + 1);
        }
        factor = factor * factor;
    }

    // Keep only the low 32 bits of in * 2^shift.
    component and = bitwiseAND(32);
    and.in[0] <== acc[5] * in;
    and.in[1] <== 0xFFFFFFFF;
    out <== and.out;
}
// Computes in >> shift via constrained division by 2^shift (see Modulo32).
//
// Only the low 6 bits of `shift` are used, so the effective shift is
// shift mod 64. Modulo32 requires the divisor to be at most remainder + 2^32,
// so effective shifts of 33..63 make the circuit unsatisfiable.
// The shift is decomposed with Num2Bits_strict (unique 254-bit decomposition)
// instead of the former Num2Bits(256), whose aliasing let a prover choose
// different low bits and therefore a different shift amount.
template ShiftRight() {
    signal input in;
    signal input shift;
    signal output out;
    log("shiftRight");
    log("in: ", in);
    log("shift: ", shift);

    component n2b = Num2Bits_strict();
    n2b.in <== shift;

    // acc[i] = 2^(shift mod 2^(i+1)); acc[5] = 2^(shift mod 64).
    // Per-bit factors 2^(2^i) are compile-time constants.
    signal acc[6];
    var factor = 2; // 2^(2^i) for the current i
    for (var i = 0; i < 6; i++) {
        if (i == 0) {
            acc[i] <== n2b.out[i] * (factor - 1) + 1;
        } else {
            acc[i] <== acc[i - 1] * (n2b.out[i] * (factor - 1) + 1);
        }
        factor = factor * factor;
    }
    log("divisor", acc[5]);

    component div = Modulo32();
    div.dividend <== in;
    div.divisor <== acc[5];
    out <== div.quotient;
    // Logged after assignment; previously `out` was logged before it was computed.
    log("out", out);
}
// shift 11001
// in 2*
//
// Returns the integer formed by bits [startBit, endBit) of `full`
// (startBit inclusive, endBit exclusive; bit 0 is the least significant).
// Decomposing `full` with Num2Bits(fullBitCount) also constrains it to
// fit in fullBitCount bits.
template getBits(fullBitCount, startBit, endBit) {
    assert(startBit < endBit);
    assert(endBit <= fullBitCount);
    signal input full;
    signal output specificBitsValue;

    var width = endBit - startBit;

    component decomposed = Num2Bits(fullBitCount);
    decomposed.in <== full;

    component recomposed = Bits2Num(width);
    for (var k = 0; k < width; k++) {
        recomposed.in[k] <== decomposed.out[startBit + k];
    }
    specificBitsValue <== recomposed.out;
}
// Bitwise OR of two `bits`-bit values. Num2Bits(bits) on each input also
// constrains it to fit in `bits` bits.
template bitwiseOR(bits) {
    signal input in[2];
    signal output out;

    component lhsBits = Num2Bits(bits);
    lhsBits.in <== in[0];
    component rhsBits = Num2Bits(bits);
    rhsBits.in <== in[1];

    component packed = Bits2Num(bits);
    component perBit[bits];
    for (var k = 0; k < bits; k++) {
        perBit[k] = OR();
        perBit[k].a <== lhsBits.out[k];
        perBit[k].b <== rhsBits.out[k];
        packed.in[k] <== perBit[k].out;
    }
    out <== packed.out;
}
// Bitwise XOR of two `bits`-bit values. Num2Bits(bits) on each input also
// constrains it to fit in `bits` bits.
template bitwiseXOR(bits) {
    signal input in[2];
    signal output out;

    component lhsBits = Num2Bits(bits);
    lhsBits.in <== in[0];
    component rhsBits = Num2Bits(bits);
    rhsBits.in <== in[1];

    component packed = Bits2Num(bits);
    component perBit[bits];
    for (var k = 0; k < bits; k++) {
        perBit[k] = XOR();
        perBit[k].a <== lhsBits.out[k];
        perBit[k].b <== rhsBits.out[k];
        packed.in[k] <== perBit[k].out;
    }
    out <== packed.out;
}
// Bitwise complement of a `bits`-bit value, i.e. (2^bits - 1) - in.
// Num2Bits(bits) also constrains `in` to fit in `bits` bits.
template bitwiseNOT(bits) {
    signal input in;
    signal output out;

    component inBits = Num2Bits(bits);
    inBits.in <== in;

    component packed = Bits2Num(bits);
    component inverted[bits];
    for (var k = 0; k < bits; k++) {
        inverted[k] = NOT();
        inverted[k].in <== inBits.out[k];
        packed.in[k] <== inverted[k].out;
    }
    out <== packed.out;
}
// Sign-extends the low `idx` bits of the 32-bit value `dat`:
//   if (dat >> (idx - 1)) != 0:  out = (dat & mask) | fill
//   else:                        out =  dat & mask
// where mask = (1 << idx) - 1 (ones in bits [0, idx)) and
//       fill = ((1 << (32 - idx)) - 1) << idx (ones in bits [idx, 32)).
// Valid for idx in [1, 32]; idx == 0 makes the ShiftRight shift negative.
template SE() {
    signal input dat;
    signal input idx;
    signal output out;
    log("dat", dat);
    log("idx", idx);

    // Sign test: is anything set at or above bit idx - 1?
    component shiftedDat = ShiftRight();
    shiftedDat.in <== dat;
    shiftedDat.shift <== idx - 1;
    component isShiftedDatZero = IsZero();
    isShiftedDatZero.in <== shiftedDat.out;
    log("isShiftedDatZero", isShiftedDatZero.out, shiftedDat.out);

    // fill = (2^(32 - idx) - 1) << idx = 2^32 - 2^idx.
    component computeSigned = ShiftLeft();
    computeSigned.in <== 1;
    computeSigned.shift <== 32 - idx;
    log("computeSigned.out[0]", computeSigned.out);
    component computeSignedFull = ShiftLeft();
    computeSignedFull.in <== computeSigned.out - 1;
    computeSignedFull.shift <== idx;

    // mask = 2^idx - 1, derived as the 32-bit complement of `fill`.
    // The former ShiftLeft(1, idx) - 1 is wrong for idx == 32: ShiftLeft
    // truncates 2^32 to 0, so the mask became the field element -1 (on the
    // default BN254 field its low 32 bits are 0xF0000000, not 0xFFFFFFFF).
    signal mask <== 0xFFFFFFFF - computeSignedFull.out;

    component computebitAND = bitwiseAND(32);
    computebitAND.in[0] <== dat;
    computebitAND.in[1] <== mask;
    component computebitOR = bitwiseOR(32);
    computebitOR.in[0] <== computebitAND.out;
    computebitOR.in[1] <== (1 - isShiftedDatZero.out) * computeSignedFull.out;
    log("out", computebitOR.out);
    out <== computebitOR.out;
}
// Returns (in[0] + in[1]) mod 2^32.
//
// The sum of two 32-bit words fits in 33 bits. It is decomposed with
// Num2Bits(33) and its low 32 bits are recomposed. Because 2^33 is far below
// the field modulus, this decomposition is unique. The former truncation
// through bitwiseAND used a 256-bit decomposition, which is aliasable, and
// cost two full-width decompositions.
// Requires in[0] + in[1] < 2^33 (true for 32-bit operands); otherwise the
// constraint is unsatisfiable.
template AddUint32() {
    signal input in[2];
    signal output out;
    signal normalAdd <== in[0] + in[1];

    component sumBits = Num2Bits(33);
    sumBits.in <== normalAdd;

    component low = Bits2Num(32);
    for (var i = 0; i < 32; i++) {
        low.in[i] <== sumBits.out[i];
    }
    out <== low.out;
}
// Returns (in[0] - in[1]) mod 2^32.
//
// Adding 2^32 keeps the difference non-negative for 32-bit operands, so the
// biased value lies in [1, 2^33). Adding 2^32 does not change its low
// 32 bits. The value is decomposed with Num2Bits(33), which is a unique
// decomposition, unlike the former 256-bit one inside bitwiseAND. The low
// 32 bits are then recomposed.
// Requires the biased difference to be in [0, 2^33) (true for 32-bit
// operands); otherwise the constraint is unsatisfiable.
template SubUint32() {
    signal input in[2];
    signal output out;
    signal biasedDiff <== in[0] - in[1] + 4294967296;

    component diffBits = Num2Bits(33);
    diffBits.in <== biasedDiff;

    component low = Bits2Num(32);
    for (var i = 0; i < 32; i++) {
        low.in[i] <== diffBits.out[i];
    }
    out <== low.out;
}
// Returns the low 32 bits of the product of two 32-bit words, which is the
// result of a signed 32x32 multiply truncated to 32 bits.
//
// The low 32 bits of a two's-complement product equal those of the unsigned
// product. With signed values a = A - 2^32*sa and b = B - 2^32*sb:
//   a*b = A*B - 2^32*(sa*B + sb*A) + 2^64*sa*sb  ≡  A*B  (mod 2^32).
// So the raw words are multiplied directly.
//
// The former version added only 2^32 to the signed product. For products
// below -2^32 the sum is negative, wraps to a field element p - k, and its
// low bits are not (-k) mod 2^32. For example, 0xFFFF0000 * 0x00020000
// returned 0xF0000001 instead of 0.
template SignedMultiply() {
    signal input in[2];
    signal output out;

    // Range-check both operands to 32 bits.
    component n2b[2];
    for (var i = 0; i < 2; i++) {
        n2b[i] = Num2Bits(32);
        n2b[i].in <== in[i];
    }

    // product < 2^64, far below the field modulus, so Num2Bits(64) is a unique
    // decomposition.
    signal product <== in[0] * in[1];
    component productBits = Num2Bits(64);
    productBits.in <== product;

    component low = Bits2Num(32);
    for (var i = 0; i < 32; i++) {
        low.in[i] <== productBits.out[i];
    }
    out <== low.out;
}
// Counts the leading (most-significant) zero bits of a 32-bit value; returns
// 32 for in == 0. Num2Bits(32) also constrains `in` to 32 bits.
template CountLeadingZeroes() {
    signal input in;
    signal output out;

    component decomposed = Num2Bits(32);
    decomposed.in <== in;

    // onesFromTop[k]  = number of set bits among positions k..31.
    // zerosFromTop[k] = number of positions m in k..31 for which bits m..31
    //                   are all zero, i.e. the leading-zero count limited to k..31.
    signal onesFromTop[32];
    signal zerosFromTop[32];
    component noneSet[32];

    onesFromTop[31] <== decomposed.out[31];
    noneSet[31] = IsZero();
    noneSet[31].in <== onesFromTop[31];
    zerosFromTop[31] <== noneSet[31].out;

    for (var k = 30; k >= 0; k--) {
        onesFromTop[k] <== onesFromTop[k + 1] + decomposed.out[k];
        noneSet[k] = IsZero();
        noneSet[k].in <== onesFromTop[k];
        zerosFromTop[k] <== zerosFromTop[k + 1] + noneSet[k].out;
    }
    out <== zerosFromTop[0];
}
// Maps opcodes in [8, 15) to 0 and passes every other value through unchanged.
// The immediate ALU opcodes addi..xori are 8..14, so Execute handles them in
// row 0 of its result table.
// `in` is expected to be a 6-bit opcode (LessThan/GreaterEqThan use 7 bits).
template remapOpcode() {
    signal input in;
    signal output out;

    component atLeast8 = GreaterEqThan(7);
    atLeast8.in[0] <== in;
    atLeast8.in[1] <== 8;

    component below15 = LessThan(7);
    below15.in[0] <== in;
    below15.in[1] <== 15;

    signal isBetween <== below15.out * atLeast8.out;
    out <== in - isBetween * in;
}
// For an immediate ALU opcode, returns the funct code of the equivalent
// R-type instruction:
//   8 addi -> 32, 9 addiu -> 33, 10 slti -> 42, 11 sltiu -> 43,
//   12 andi -> 36, 13 ori -> 37, 14 xori -> 38.
// Any other opcode yields 0.
template remapFunc() {
    signal input inOpc;
    signal output out;

    var OPCODES[7] = [8, 9, 10, 11, 12, 13, 14];
    var FUNCS[7] = [32, 33, 42, 43, 36, 37, 38];

    component isEq[7];
    signal sumFunc[7];
    for (var k = 0; k < 7; k++) {
        isEq[k] = IsEqual();
        isEq[k].in[0] <== inOpc;
        isEq[k].in[1] <== OPCODES[k];
        if (k == 0) {
            sumFunc[k] <== isEq[k].out * FUNCS[k];
        } else {
            sumFunc[k] <== isEq[k].out * FUNCS[k] + sumFunc[k - 1];
        }
    }
    out <== sumFunc[6];
}
// Computes the 32-bit result of executing one MIPS instruction.
//
// Inputs:
//   insn   — the 32-bit instruction word (range-checked by getBits).
//   rs, rt — operand words; expected to be 32-bit (SignedMultiply and
//            CountLeadingZeroes range-check them).
//   mem    — a memory word; returned by `ll` and merged with rt by the
//            partial stores (sb, sh, swl, swr).
// Output:
//   memOut — the entry of the result table at [opcode][func], or 0 when the
//            instruction is not one of the supported (opcode, func) pairs.
template Execute() {
    var FULL_LEN = 32;
    signal input insn;
    signal input rs;
    signal input rt;
    signal input mem;
    signal preOpcode;
    signal preFunc;
    signal opcode;
    signal func;
    signal output memOut;

    // ---- Decode ----
    component computeOpcode = getBits(FULL_LEN, 26, 32);
    computeOpcode.full <== insn;
    preOpcode <== computeOpcode.specificBitsValue;
    component computeFunc = getBits(FULL_LEN, 0, 6);
    computeFunc.full <== insn;
    preFunc <== computeFunc.specificBitsValue;
    log("preOpcode", preOpcode);
    log("preFunc", preFunc);

    // Immediate ALU opcodes 8..14 are folded into row 0 of the table, using
    // the funct code of the equivalent R-type instruction.
    component remapperOp = remapOpcode();
    remapperOp.in <== preOpcode;
    opcode <== remapperOp.out;
    component remapperFunc = remapFunc();
    remapperFunc.inOpc <== preOpcode;

    // remapFunc returns 0 for every opcode outside 8..14. Using it alone made
    // func 0 for genuine R-type (opcode 0) and SPECIAL2 (opcode 28)
    // instructions, so e.g. `or` executed as `sll` and mul/clz/clo could
    // never be selected. Those two opcodes take func from the funct field.
    // All other opcodes (lui, stores, ll, sc) keep func = 0, the column their
    // table entries are stored in.
    component isRType = IsZero();
    isRType.in <== preOpcode;
    component isSpecial2 = IsEqual();
    isSpecial2.in[0] <== preOpcode;
    isSpecial2.in[1] <== 28;
    signal usesFunctField <== isRType.out + isSpecial2.out;
    func <== remapperFunc.out + usesFunctField * preFunc;
    log("opcode", opcode);
    log("func", func);

    // shamt = (insn >> 6) & 0x1F
    component computeShamt = ShiftRight();
    computeShamt.in <== insn;
    computeShamt.shift <== 6;
    component computeShamt2 = bitwiseAND(32);
    computeShamt2.in[0] <== computeShamt.out;
    computeShamt2.in[1] <== 31;
    signal shamt <== computeShamt2.out;

    // ---- Result table, indexed [opcode][func] ----
    signal opcodeFuncOutput[64][64];

    // sll
    component sll = ShiftLeft();
    sll.in <== rt;
    sll.shift <== shamt;
    opcodeFuncOutput[0][0] <== sll.out;
    // srl
    component srl = ShiftRight();
    srl.in <== rt;
    srl.shift <== shamt;
    opcodeFuncOutput[0][2] <== srl.out;
    // sra = SE(rt >> shamt, 32 - shamt)
    component sra = SE();
    sra.dat <== srl.out;
    sra.idx <== 32 - shamt;
    opcodeFuncOutput[0][3] <== sra.out;
    // variable shifts use rs & 0x1F
    component rsAND0x1F = bitwiseAND(FULL_LEN);
    rsAND0x1F.in[0] <== rs;
    rsAND0x1F.in[1] <== 31;
    // sllv
    component sllv = ShiftLeft();
    sllv.in <== rt;
    sllv.shift <== rsAND0x1F.out;
    opcodeFuncOutput[0][4] <== sllv.out;
    // srlv
    component srlv = ShiftRight();
    srlv.in <== rt;
    srlv.shift <== rsAND0x1F.out;
    opcodeFuncOutput[0][6] <== srlv.out;
    // srav = SE(rt >> (rs & 0x1F), 32 - (rs & 0x1F)), mirroring sra.
    // The previous version hard-wired this entry to 0. It also instantiated an
    // unused ShiftRight(rt, rs), which shifts by rs mod 64. For any rs with
    // (rs & 63) >= 33, that made Modulo32's LessThan(32) unsatisfiable and
    // blocked every instruction.
    component srav = SE();
    srav.dat <== srlv.out;
    srav.idx <== 32 - rsAND0x1F.out;
    opcodeFuncOutput[0][7] <== srav.out;
    // add or addu
    component add = AddUint32();
    add.in[0] <== rs;
    add.in[1] <== rt;
    opcodeFuncOutput[0][32] <== add.out;
    opcodeFuncOutput[0][33] <== add.out;
    // sub or subu
    component sub = SubUint32();
    sub.in[0] <== rs;
    sub.in[1] <== rt;
    opcodeFuncOutput[0][34] <== sub.out;
    opcodeFuncOutput[0][35] <== sub.out;
    // and
    component and = bitwiseAND(FULL_LEN);
    and.in[0] <== rs;
    and.in[1] <== rt;
    opcodeFuncOutput[0][36] <== and.out;
    // or
    component or = bitwiseOR(FULL_LEN);
    or.in[0] <== rs;
    or.in[1] <== rt;
    opcodeFuncOutput[0][37] <== or.out;
    log("opcodeFuncOutput[0][37]", opcodeFuncOutput[0][37]);
    // xor
    component xor = bitwiseXOR(FULL_LEN);
    xor.in[0] <== rs;
    xor.in[1] <== rt;
    opcodeFuncOutput[0][38] <== xor.out;
    // nor
    component nor = bitwiseNOT(FULL_LEN);
    nor.in <== or.out;
    opcodeFuncOutput[0][39] <== nor.out;
    // slt: signed comparison. Flipping the sign bit of both operands maps
    // two's-complement order onto unsigned order; the previous version compared
    // rs and rt as unsigned values.
    component rsBiased = bitwiseXOR(FULL_LEN);
    rsBiased.in[0] <== rs;
    rsBiased.in[1] <== 0x80000000;
    component rtBiased = bitwiseXOR(FULL_LEN);
    rtBiased.in[0] <== rt;
    rtBiased.in[1] <== 0x80000000;
    component slt = LessThan(32);
    slt.in[0] <== rsBiased.out;
    slt.in[1] <== rtBiased.out;
    opcodeFuncOutput[0][42] <== slt.out;
    // sltu
    component sltu = LessThan(33);
    sltu.in[0] <== rs;
    sltu.in[1] <== rt;
    opcodeFuncOutput[0][43] <== sltu.out;
    // lui
    component lui = ShiftLeft();
    lui.in <== rt;
    lui.shift <== 16;
    opcodeFuncOutput[15][0] <== lui.out;
    // sb
    component rsAND0x3 = bitwiseAND(FULL_LEN);
    rsAND0x3.in[0] <== rs;
    rsAND0x3.in[1] <== 3;
    component rtAND0xFF = bitwiseAND(32);
    rtAND0xFF.in[0] <== rt;
    rtAND0xFF.in[1] <== 0xFF;
    signal rsAND0x3Mask <== 24 - rsAND0x3.out * 8;
    component sbVal = ShiftLeft();
    sbVal.in <== rtAND0xFF.out;
    sbVal.shift <== rsAND0x3Mask;
    component sbMask = ShiftLeft();
    sbMask.in <== 0xFF;
    sbMask.shift <== rsAND0x3Mask;
    component sbMask2 = bitwiseXOR(32);
    sbMask2.in[0] <== 0xFFFFFFFF;
    sbMask2.in[1] <== sbMask.out;
    component sbAnd = bitwiseAND(32);
    sbAnd.in[0] <== mem;
    sbAnd.in[1] <== sbMask2.out;
    component sb = bitwiseOR(32);
    sb.in[0] <== sbAnd.out;
    sb.in[1] <== sbVal.out;
    opcodeFuncOutput[40][0] <== sb.out;
    // sh
    component rsAND0x2 = bitwiseAND(FULL_LEN);
    rsAND0x2.in[0] <== rs;
    rsAND0x2.in[1] <== 2;
    component rtAND0xFFFF = bitwiseAND(32);
    rtAND0xFFFF.in[0] <== rt;
    rtAND0xFFFF.in[1] <== 0xFFFF;
    signal rsAND0x2Mask <== 16 - rsAND0x2.out * 8;
    component shVal = ShiftLeft();
    shVal.in <== rtAND0xFFFF.out;
    shVal.shift <== rsAND0x2Mask;
    component shMask = ShiftLeft();
    shMask.in <== 0xFFFF;
    shMask.shift <== rsAND0x2Mask;
    component shMask2 = bitwiseXOR(32);
    shMask2.in[0] <== 0xFFFFFFFF;
    shMask2.in[1] <== shMask.out;
    component shAnd = bitwiseAND(32);
    shAnd.in[0] <== mem;
    shAnd.in[1] <== shMask2.out;
    component sh = bitwiseOR(32);
    sh.in[0] <== shAnd.out;
    sh.in[1] <== shVal.out;
    opcodeFuncOutput[41][0] <== sh.out;
    // swl
    component swlVal = ShiftRight();
    swlVal.in <== rt;
    swlVal.shift <== rsAND0x3.out * 8;
    component swlMask = ShiftRight();
    swlMask.in <== 0xFFFFFFFF;
    swlMask.shift <== rsAND0x3.out * 8;
    signal swlMaskNeg <== (1<<32) - 1 - swlMask.out;
    component swlAnd = bitwiseAND(32);
    swlAnd.in[0] <== mem;
    swlAnd.in[1] <== swlMaskNeg;
    component swl = bitwiseOR(32);
    swl.in[0] <== swlAnd.out;
    swl.in[1] <== swlVal.out;
    opcodeFuncOutput[42][0] <== swl.out;
    // sw
    opcodeFuncOutput[43][0] <== rt;
    // swr
    component swrVal = ShiftLeft();
    swrVal.in <== rt;
    swrVal.shift <== 24 - rsAND0x3.out * 8;
    component swrMask = ShiftLeft();
    swrMask.in <== 0xFFFFFFFF;
    swrMask.shift <== 24 - rsAND0x3.out * 8;
    signal swrMaskNeg <== (1<<32) - 1 - swrMask.out;
    component swrAnd = bitwiseAND(32);
    swrAnd.in[0] <== mem;
    swrAnd.in[1] <== swrMaskNeg;
    component swr = bitwiseOR(32);
    swr.in[0] <== swrAnd.out;
    swr.in[1] <== swrVal.out;
    opcodeFuncOutput[46][0] <== swr.out;
    // ll
    opcodeFuncOutput[48][0] <== mem;
    // sc
    opcodeFuncOutput[56][0] <== rt;
    // mul
    component mul = SignedMultiply();
    mul.in[0] <== rs;
    mul.in[1] <== rt;
    opcodeFuncOutput[28][2] <== mul.out;
    // clz
    component clz = CountLeadingZeroes();
    clz.in <== rs;
    opcodeFuncOutput[28][32] <== clz.out;
    // clo = clz(~rs)
    component not = bitwiseNOT(32);
    not.in <== rs;
    component clo = CountLeadingZeroes();
    clo.in <== not.out;
    opcodeFuncOutput[28][33] <== clo.out;

    // Populated (opcode, func) cells; all pairs are distinct.
    var CNT = 27;
    var actualFuncOutputsO[CNT] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 40, 41, 42, 43, 46, 48, 56, 28, 28, 28];
    var actualFuncOutputsF[CNT] = [0, 2, 3, 4, 6, 7, 32, 33, 34, 35, 36, 37, 38, 39, 42, 43, 0, 0, 0, 0, 0, 0, 0, 0, 2, 32, 33];

    // Every other cell is zero.
    for (var i = 0; i < 64; i++) {
        for (var j = 0; j < 64; j++) {
            var isAssigned = 0;
            for (var k = 0; k < CNT; k++) {
                if (i == actualFuncOutputsO[k] && j == actualFuncOutputsF[k]) {
                    isAssigned = 1;
                }
            }
            if (isAssigned == 0) {
                opcodeFuncOutput[i][j] <== 0;
            }
        }
    }

    // ---- Select table[opcode][func] ----
    // Only the CNT populated cells can be non-zero, so it suffices to test
    // those pairs. This needs 2*CNT IsEqual checks instead of the previous
    // 2*64*64, with the same result.
    component OpcodeIsEqual[CNT];
    component FuncIsEqual[CNT];
    signal matched[CNT];
    signal selected[CNT];
    var selectedSum = 0;
    for (var k = 0; k < CNT; k++) {
        OpcodeIsEqual[k] = IsEqual();
        OpcodeIsEqual[k].in[0] <== actualFuncOutputsO[k];
        OpcodeIsEqual[k].in[1] <== opcode;
        FuncIsEqual[k] = IsEqual();
        FuncIsEqual[k].in[0] <== actualFuncOutputsF[k];
        FuncIsEqual[k].in[1] <== func;
        matched[k] <== OpcodeIsEqual[k].out * FuncIsEqual[k].out;
        selected[k] <== matched[k] * opcodeFuncOutput[actualFuncOutputsO[k]][actualFuncOutputsF[k]];
        selectedSum += selected[k];
    }
    memOut <== selectedSum;
    log("memOut", memOut);
}
// Circuit entry point. No `{public [...]}` list is given, so the inputs
// insn, rs, rt and mem are all private; the output memOut is public.
component main = Execute();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment