Skip to content

Instantly share code, notes, and snippets.

@djkazic
Created August 7, 2025 16:08
Show Gist options
  • Save djkazic/9d52c21bcfc56e6c5f9c59dc66d96838 to your computer and use it in GitHub Desktop.
Save djkazic/9d52c21bcfc56e6c5f9c59dc66d96838 to your computer and use it in GitHub Desktop.
package main
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/big"
"sort"
)
// =============================================================================
// GLOBAL PARAMETERS
// =============================================================================
const (
	// fieldPrime is the field modulus 3*2^30 + 1 = 3221225473, a prime of the
	// form k*2^n + 1. Primes of this shape allow efficient FFTs in a full
	// STARK, although this simplified polynomial arithmetic does not use FFTs.
	fieldPrime int64 = 3*(1<<30) + 1

	// traceLength is the number of steps in the computation trace.
	traceLength = 8

	// blowupFactor is the ratio between the size of the evaluation (FRI)
	// domain and the size of the trace domain.
	blowupFactor = 8

	// numQueries is the number of FRI queries; more queries give stronger
	// soundness guarantees.
	numQueries = 3
)
// =============================================================================
// FINITE FIELD ARITHMETIC (GF_p)
// =============================================================================
// FieldElement represents an element in a prime finite field.
// NewFieldElement and the arithmetic methods in this file always store
// value reduced into the range [0, prime-1].
type FieldElement struct {
value *big.Int // residue of the element modulo prime
prime *big.Int // modulus of the field this element belongs to
}
// normalizeBigInt returns a newly allocated big.Int holding val reduced into
// [0, |prime|-1]. Neither argument is modified.
//
// big.Int.Mod implements Euclidean modulus, so the result is never negative
// and no sign correction is needed afterwards. It panics (inside math/big) if
// prime is zero.
func normalizeBigInt(val *big.Int, prime *big.Int) *big.Int {
	return new(big.Int).Mod(val, prime)
}
// NewFieldElement builds the field element congruent to value in the prime
// field with the given modulus. The stored value is reduced into
// [0, prime-1], so negative inputs are accepted.
func NewFieldElement(value int64, prime int64) *FieldElement {
	modulus := big.NewInt(prime)
	reduced := normalizeBigInt(big.NewInt(value), modulus)
	return &FieldElement{value: reduced, prime: modulus}
}
// Zero returns the additive identity (0) of the field with the given prime.
func Zero(prime int64) *FieldElement {
	return NewFieldElement(0, prime)
}
// One returns the multiplicative identity (1) of the field with the given prime.
func One(prime int64) *FieldElement {
	return NewFieldElement(1, prime)
}
// Add returns fe + other, reduced modulo fe's prime. The receiver's prime is
// used for the result; other's prime is not consulted.
func (fe *FieldElement) Add(other *FieldElement) *FieldElement {
	sum := new(big.Int).Add(fe.value, other.value)
	reduced := normalizeBigInt(sum, fe.prime)
	return &FieldElement{value: reduced, prime: fe.prime}
}
// Sub returns fe - other, reduced modulo fe's prime into [0, prime-1].
func (fe *FieldElement) Sub(other *FieldElement) *FieldElement {
	diff := new(big.Int).Sub(fe.value, other.value)
	reduced := normalizeBigInt(diff, fe.prime)
	return &FieldElement{value: reduced, prime: fe.prime}
}
// Mul returns fe * other, reduced modulo fe's prime.
func (fe *FieldElement) Mul(other *FieldElement) *FieldElement {
	product := new(big.Int).Mul(fe.value, other.value)
	reduced := normalizeBigInt(product, fe.prime)
	return &FieldElement{value: reduced, prime: fe.prime}
}
// Div returns fe / other, computed as fe times the modular inverse of other
// modulo fe's prime. It panics when other has no inverse (for example when
// other is zero).
func (fe *FieldElement) Div(other *FieldElement) *FieldElement {
	inverse := new(big.Int).ModInverse(other.value, fe.prime)
	if inverse != nil {
		quotient := new(big.Int).Mul(fe.value, inverse)
		return &FieldElement{value: normalizeBigInt(quotient, fe.prime), prime: fe.prime}
	}
	// ModInverse reports a missing inverse by returning nil.
	panic(fmt.Sprintf("division by zero or no inverse for %s in field %s", other.value, fe.prime))
}
// Exp returns fe raised to power, reduced modulo fe's prime.
func (fe *FieldElement) Exp(power *big.Int) *FieldElement {
	raised := new(big.Int).Exp(fe.value, power, fe.prime)
	reduced := normalizeBigInt(raised, fe.prime)
	return &FieldElement{value: reduced, prime: fe.prime}
}
// Neg returns the additive inverse of fe, reduced into [0, prime-1].
func (fe *FieldElement) Neg() *FieldElement {
	negated := new(big.Int).Neg(fe.value)
	reduced := normalizeBigInt(negated, fe.prime)
	return &FieldElement{value: reduced, prime: fe.prime}
}
// Equal reports whether fe and other hold the same value in the same field
// (both value and prime must match).
func (fe *FieldElement) Equal(other *FieldElement) bool {
	if fe.value.Cmp(other.value) != 0 {
		return false
	}
	return fe.prime.Cmp(other.prime) == 0
}
// IsZero reports whether the stored value is zero.
func (fe *FieldElement) IsZero() bool {
	// Sign is 0 exactly when the big.Int equals zero.
	return fe.value.Sign() == 0
}
// String returns the decimal representation of the element's value.
func (fe *FieldElement) String() string {
	const decimalBase = 10
	return fe.value.Text(decimalBase)
}
// Bytes returns the big-endian bytes of the element's value with no leading
// zero padding, so the length varies with the value. Use Bytes32 when a
// fixed-length encoding is needed for hashing.
func (fe *FieldElement) Bytes() []byte {
	raw := fe.value.Bytes()
	return raw
}
// Bytes32 returns the element's value as exactly 32 big-endian bytes,
// left-padded with zeros. It panics if the value needs more than 32 bytes,
// which should not happen for a normalized element of a field whose prime
// fits in 32 bytes.
func (fe *FieldElement) Bytes32() []byte {
	const width = 32
	raw := fe.value.Bytes()
	if len(raw) > width {
		panic(fmt.Sprintf("Field element byte representation too large: %d bytes, expected <= 32", len(raw)))
	}
	out := make([]byte, width)
	// Right-align the value; the untouched prefix stays zero.
	copy(out[width-len(raw):], raw)
	return out
}
// =============================================================================
// POLYNOMIAL ARITHMETIC
// =============================================================================
// Polynomial represents a polynomial with coefficients in a finite field.
// Coefficients are stored lowest degree first: index i holds the coefficient
// of x^i. The zero polynomial is the empty slice (see NewPolynomial, Degree).
type Polynomial []*FieldElement
// NewPolynomial wraps coeffs (lowest degree first) as a Polynomial after
// dropping zero coefficients from the high-degree end. An all-zero input
// yields the empty Polynomial, which represents the zero polynomial
// (degree -1). The result shares its backing array with coeffs.
func NewPolynomial(coeffs []*FieldElement) Polynomial {
	end := len(coeffs)
	for end > 1 && coeffs[end-1].IsZero() {
		end--
	}
	trimmed := coeffs[:end]
	// A lone zero coefficient is still the zero polynomial.
	if end == 1 && trimmed[0].IsZero() {
		return Polynomial{}
	}
	return Polynomial(trimmed)
}
// Degree returns len(p)-1, which is the polynomial's degree when p has been
// trimmed by NewPolynomial. The empty (zero) polynomial has degree -1.
func (p Polynomial) Degree() int {
	return len(p) - 1
}
// Evaluate returns p(x) using Horner's rule. The zero polynomial evaluates
// to zero in x's field.
func (p Polynomial) Evaluate(x *FieldElement) *FieldElement {
	acc := Zero(x.prime.Int64())
	// Walk from the highest-degree coefficient down to the constant term.
	for remaining := len(p); remaining > 0; remaining-- {
		acc = acc.Mul(x).Add(p[remaining-1])
	}
	return acc
}
// Add returns p + q. The shorter operand is treated as if padded with zero
// coefficients, and the result is trimmed by NewPolynomial.
func (p Polynomial) Add(q Polynomial) Polynomial {
	size := len(p)
	if len(q) > size {
		size = len(q)
	}
	// coeffAt yields the i-th coefficient, or zero beyond the polynomial's end.
	coeffAt := func(poly Polynomial, i int) *FieldElement {
		if i < len(poly) {
			return poly[i]
		}
		return Zero(fieldPrime)
	}
	sum := make([]*FieldElement, size)
	for i := range sum {
		sum[i] = coeffAt(p, i).Add(coeffAt(q, i))
	}
	return NewPolynomial(sum)
}
// Sub returns p - q, computed as p plus the negation of q.
func (p Polynomial) Sub(q Polynomial) Polynomial {
	negated := q.Neg()
	return p.Add(negated)
}
// Mul returns the product p * q using schoolbook multiplication. If either
// operand is the zero polynomial the result is the zero polynomial.
func (p Polynomial) Mul(q Polynomial) Polynomial {
	if len(p) == 0 || len(q) == 0 {
		return NewPolynomial([]*FieldElement{})
	}
	product := make([]*FieldElement, len(p)+len(q)-1)
	for k := range product {
		product[k] = Zero(fieldPrime)
	}
	// The coefficient of x^(i+j) accumulates every p[i]*q[j].
	for i, a := range p {
		for j, b := range q {
			product[i+j] = product[i+j].Add(a.Mul(b))
		}
	}
	return NewPolynomial(product)
}
// Div performs polynomial long division of p by q and returns
// (quotient, remainder) such that p = quotient*q + remainder with
// deg(remainder) < deg(q).
//
// It panics with "division by zero polynomial" when q is the zero
// polynomial, and (via FieldElement.Div) when q's highest stored coefficient
// has no inverse. When deg(p) < deg(q) the quotient is the zero polynomial
// and p itself is returned as the remainder.
//
// The division is done in place on a copy of p's coefficients: each step
// cancels the current top coefficient by subtracting a scaled, shifted copy
// of q, avoiding a temporary monomial and a full polynomial multiplication
// per step.
func (p Polynomial) Div(q Polynomial) (Polynomial, Polynomial) {
	if q.Degree() < 0 || (q.Degree() == 0 && q[0].IsZero()) {
		panic("division by zero polynomial")
	}
	if p.Degree() < q.Degree() {
		return NewPolynomial([]*FieldElement{Zero(fieldPrime)}), p
	}

	divDeg := q.Degree()
	lead := q[divDeg]

	// Work on a copy so the caller's coefficients are never modified.
	rem := make([]*FieldElement, len(p))
	copy(rem, p)
	quot := make([]*FieldElement, len(p)-divDeg)

	for top := len(rem) - 1; top >= divDeg; top-- {
		shift := top - divDeg
		factor := rem[top].Div(lead)
		quot[shift] = factor
		if factor.IsZero() {
			// Nothing to cancel at this degree.
			continue
		}
		// rem -= factor * x^shift * q; this zeroes rem[top].
		for j, qc := range q {
			rem[shift+j] = rem[shift+j].Sub(factor.Mul(qc))
		}
	}

	// Every coefficient at index >= divDeg has been cancelled, so only the
	// low part can be non-zero.
	return NewPolynomial(quot), NewPolynomial(rem[:divDeg])
}
// Interpolate returns the polynomial passing through the points
// (domain[i], values[i]) using Lagrange interpolation. It panics if the two
// slices differ in length, and (via FieldElement.Div) if a basis denominator
// has no inverse, e.g. when domain contains a repeated point.
func Interpolate(domain, values []*FieldElement) Polynomial {
	if len(domain) != len(values) {
		panic("domain and values length mismatch")
	}
	result := NewPolynomial([]*FieldElement{Zero(fieldPrime)}) // zero polynomial
	for i, xi := range domain {
		// basis accumulates prod_{j != i} (x - domain[j]);
		// denom accumulates prod_{j != i} (domain[i] - domain[j]).
		basis := NewPolynomial([]*FieldElement{One(fieldPrime)})
		denom := One(fieldPrime)
		for j, xj := range domain {
			if j == i {
				continue
			}
			linear := NewPolynomial([]*FieldElement{xj.Neg(), One(fieldPrime)})
			basis = basis.Mul(linear)
			denom = denom.Mul(xi.Sub(xj))
		}
		// i-th Lagrange term: values[i] * basis / denom.
		weight := values[i].Div(denom)
		result = result.Add(basis.Scale(weight))
	}
	return result
}
// Scale returns a new polynomial whose coefficients are those of p, each
// multiplied by s. The result is trimmed by NewPolynomial.
func (p Polynomial) Scale(s *FieldElement) Polynomial {
	scaled := make([]*FieldElement, 0, len(p))
	for _, coeff := range p {
		scaled = append(scaled, coeff.Mul(s))
	}
	return NewPolynomial(scaled)
}
// Neg returns -p: a new polynomial with every coefficient negated.
func (p Polynomial) Neg() Polynomial {
	negated := make([]*FieldElement, 0, len(p))
	for _, coeff := range p {
		negated = append(negated, coeff.Neg())
	}
	return NewPolynomial(negated)
}
// ScaleByPower returns the polynomial whose i-th coefficient is
// c_i * scalar^i, i.e. the coefficient form of p(scalar * x).
//
// The power scalar^i is carried from one iteration to the next with a single
// multiplication instead of being recomputed by modular exponentiation for
// every coefficient.
func (p Polynomial) ScaleByPower(scalar *FieldElement) Polynomial {
	scaled := make([]*FieldElement, len(p))
	// power holds scalar^i at the start of iteration i (scalar^0 = 1).
	power := &FieldElement{value: big.NewInt(1), prime: scalar.prime}
	for i, coeff := range p {
		scaled[i] = coeff.Mul(power)
		power = power.Mul(scalar)
	}
	return NewPolynomial(scaled)
}
// =============================================================================
// MERKLE TREE
// =============================================================================
// MerkleNode represents a node in the Merkle tree.
// Leaves have nil Left and Right; internal nodes built by NewMerkleNode keep
// pointers to the children they were hashed from.
type MerkleNode struct {
Left *MerkleNode // left child, nil for a leaf
Right *MerkleNode // right child, nil for a leaf
Hash []byte // SHA-256 digest of the leaf data or of the concatenated child hashes
}
// NewMerkleNode creates a Merkle node.
//
// With both children nil it creates a leaf whose Hash is SHA-256(data).
// Otherwise it creates an internal node whose Hash is
// SHA-256(left.Hash || right.Hash); if right is nil the left hash is used on
// both sides (duplicate-last-node rule). data is ignored for internal nodes.
// left must be non-nil whenever right is non-nil.
//
// The concatenation is built in a fresh buffer so a child's Hash slice is
// never appended to in place, which could overwrite memory following it if
// that slice had spare capacity.
func NewMerkleNode(left, right *MerkleNode, data []byte) *MerkleNode {
	node := &MerkleNode{Left: left, Right: right}
	if left == nil && right == nil {
		digest := sha256.Sum256(data)
		node.Hash = digest[:]
		return node
	}

	rightHash := left.Hash
	if right != nil {
		rightHash = right.Hash
	}
	combined := make([]byte, 0, len(left.Hash)+len(rightHash))
	combined = append(combined, left.Hash...)
	combined = append(combined, rightHash...)

	digest := sha256.Sum256(combined)
	node.Hash = digest[:]
	return node
}
// NewMerkleTree builds a Merkle tree over data, one leaf per byte slice.
//
// It returns the root node and every level of the tree, leaves first and the
// single-root level last, for use by GetAuthenticationPath. When a level has
// an odd number of nodes, its last node is paired with itself.
//
// For empty input there is no tree to build, so it returns (nil, nil) rather
// than indexing into an empty level.
func NewMerkleTree(data [][]byte) (*MerkleNode, [][]*MerkleNode) {
	if len(data) == 0 {
		return nil, nil
	}

	current := make([]*MerkleNode, 0, len(data))
	for _, leaf := range data {
		current = append(current, NewMerkleNode(nil, nil, leaf))
	}
	levels := [][]*MerkleNode{current}

	for len(current) > 1 {
		parents := make([]*MerkleNode, 0, (len(current)+1)/2)
		for i := 0; i < len(current); i += 2 {
			right := current[i] // odd tail: pair the node with itself
			if i+1 < len(current) {
				right = current[i+1]
			}
			parents = append(parents, NewMerkleNode(current[i], right, nil))
		}
		current = parents
		levels = append(levels, current)
	}
	return current[0], levels
}
// GetAuthenticationPath returns the sibling hashes needed to recompute the
// root from the leaf at index, ordered from the leaf level upwards. levels
// is the slice returned by NewMerkleTree; its last (root) level adds no
// entry to the path.
func GetAuthenticationPath(index int, levels [][]*MerkleNode) [][]byte {
	var path [][]byte
	for depth := 0; depth < len(levels)-1; depth++ {
		level := levels[depth]
		// Flipping the lowest bit gives the sibling position. The last node
		// of an odd-sized level has no sibling; the tree pairs it with
		// itself, so its own hash is used.
		pick := index ^ 1
		if pick >= len(level) {
			pick = index
		}
		path = append(path, level[pick].Hash)
		index /= 2 // position of the parent in the next level
	}
	return path
}
// VerifyMerklePath reports whether leafData at position index is committed
// to by root. It hashes leafData, folds in each sibling hash from path
// (leaf level first), and compares the result with root.
//
// Each pair of hashes is concatenated in a fresh buffer. Appending directly
// to a sibling slice could write into the caller's path when that slice has
// spare capacity (for example when all siblings are sub-slices of one
// buffer), corrupting later path entries and rejecting a valid proof.
func VerifyMerklePath(root, leafData []byte, path [][]byte, index int) bool {
	current := sha256.Sum256(leafData)
	for _, sibling := range path {
		pair := make([]byte, 0, len(current)+len(sibling))
		if index%2 == 0 {
			// Current node is a left child: current || sibling.
			pair = append(pair, current[:]...)
			pair = append(pair, sibling...)
		} else {
			// Current node is a right child: sibling || current.
			pair = append(pair, sibling...)
			pair = append(pair, current[:]...)
		}
		current = sha256.Sum256(pair)
		index /= 2 // move up to the parent level
	}
	return hex.EncodeToString(root) == hex.EncodeToString(current[:])
}
// =============================================================================
// FIAT-SHAMIR TRANSCRIPT
// =============================================================================
// Transcript handles the Fiat-Shamir transform to make the protocol non-interactive.
// It deterministically generates challenges based on the proof elements:
// Add folds data into the state, and GetChallenge / GetQueryIndices derive
// values from the state and advance it.
type Transcript struct {
state []byte // current transcript state; a SHA-256 digest after the first Add or challenge
}
// NewTranscript returns a transcript whose state is empty.
func NewTranscript() *Transcript {
	t := new(Transcript)
	t.state = []byte{}
	return t
}
// Add absorbs data into the transcript by hash chaining: the new state is
// SHA-256(old state || data).
func (t *Transcript) Add(data []byte) {
	chained := make([]byte, 0, len(t.state)+len(data))
	chained = append(chained, t.state...)
	chained = append(chained, data...)
	digest := sha256.Sum256(chained)
	t.state = digest[:]
}
// GetChallenge derives a pseudo-random element of the fieldPrime field from
// the transcript. It hashes the current state, stores the digest as the new
// state (so consecutive challenges differ), and reduces the digest,
// interpreted as a big-endian integer, modulo fieldPrime.
func (t *Transcript) GetChallenge() *FieldElement {
	digest := sha256.Sum256(t.state)
	t.state = digest[:]

	modulus := big.NewInt(fieldPrime)
	reduced := new(big.Int).SetBytes(digest[:])
	reduced.Mod(reduced, modulus)
	return &FieldElement{value: reduced, prime: modulus}
}
// GetQueryIndices derives count distinct pseudo-random indices in
// [0, maxIndex) from the transcript and returns them in ascending order.
// Each attempt hashes the state, stores the digest as the new state, and maps
// the first 8 bytes of the digest (big-endian) into range with a modulo.
//
// It panics if count > 0 and either maxIndex <= 0 or count > maxIndex: there
// are not enough distinct indices to return, and the rejection loop would
// otherwise never terminate (or divide by zero when maxIndex is 0).
func (t *Transcript) GetQueryIndices(maxIndex int, count int) []int {
	if count > 0 && (maxIndex <= 0 || count > maxIndex) {
		panic(fmt.Sprintf("cannot draw %d distinct query indices from a domain of size %d", count, maxIndex))
	}

	var indices []int
	seen := make(map[int]struct{}) // indices already chosen
	for len(indices) < count {
		digest := sha256.Sum256(t.state)
		t.state = digest[:] // advance the state for the next draw

		candidate := int(binary.BigEndian.Uint64(digest[:8]) % uint64(maxIndex))
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		indices = append(indices, candidate)
	}
	sort.Ints(indices)
	return indices
}
// =============================================================================
// STARK PROOF STRUCTURES
// =============================================================================
// FriProof contains the proof for a single FRI layer.
// Committed layers carry Root, Paths and Values; the entry appended for the
// final layer only sets FinalCoeffs.
type FriProof struct {
Root []byte // Merkle root of the evaluations for this layer
Paths [][][]byte // Merkle paths for the queried indices, one path per query
Values [][]*FieldElement // Values at the queried indices for this layer, one slice per query
// For the final layer, this contains the coefficients of the final polynomial
FinalCoeffs []*FieldElement
}
// StarkProof contains the entire proof.
type StarkProof struct {
TraceRoot []byte // Merkle root of the trace polynomial evaluations
CompositionRoot []byte // Merkle root of the composition polynomial evaluations
FriProofs []FriProof // one entry per committed FRI layer, followed by the final-layer entry
// Evaluations of trace and composition polynomials at the DEEP point z
TraceEvalAtZ *FieldElement // P(z)
TraceEvalAtGZ *FieldElement // P(g*z)
TraceEvalAtG2Z *FieldElement // P(g^2*z)
CompositionEvalAtZ *FieldElement // C(z)
// Query data for the initial trace and composition polynomials
TraceQueryData [][]byte // Leaf data for trace polynomial evaluations at query indices
TraceQueryPaths [][][]byte // Merkle paths for TraceQueryData, one per query
CompositionQueryData [][]byte // Leaf data for composition polynomial evaluations at query indices
CompositionQueryPaths [][][]byte // Merkle paths for CompositionQueryData, one per query
}
// =============================================================================
// PROVER
// =============================================================================
// Prover holds the domains, execution trace and Fiat-Shamir transcript used
// to produce a StarkProof.
type Prover struct {
prime int64 // field modulus
g *FieldElement // Generator of the entire multiplicative group Z_p^*
traceDomainGen *FieldElement // Generator of the trace evaluation domain
evalDomainGen *FieldElement // Generator of the larger evaluation domain (for FRI)
traceDomain []*FieldElement // powers of traceDomainGen, traceLength points
evalDomain []*FieldElement // powers of evalDomainGen, traceLength*blowupFactor points
trace []*FieldElement // execution trace, filled by GenerateTrace
transcript *Transcript // Fiat-Shamir transcript used to derive challenges
}
// NewProver creates a Prover over the fieldPrime field with a fresh
// transcript and precomputed domains: the trace domain (traceLength points)
// and the evaluation domain (traceLength*blowupFactor points). Each domain
// consists of the successive powers of g^((p-1)/size), where g is the value
// returned by findGenerator.
func NewProver() *Prover {
	prover := &Prover{prime: fieldPrime, transcript: NewTranscript()}
	prover.g = findGenerator(prover.prime)

	groupOrder := big.NewInt(prover.prime - 1)
	// subgroup returns the generator g^((p-1)/size) and its first size powers.
	subgroup := func(size int64) (*FieldElement, []*FieldElement) {
		exponent := new(big.Int).Div(groupOrder, big.NewInt(size))
		gen := prover.g.Exp(exponent)
		points := make([]*FieldElement, size)
		for i := range points {
			points[i] = gen.Exp(big.NewInt(int64(i)))
		}
		return gen, points
	}

	prover.traceDomainGen, prover.traceDomain = subgroup(int64(traceLength))
	prover.evalDomainGen, prover.evalDomain = subgroup(int64(traceLength * blowupFactor))
	return prover
}
// GenerateTrace fills p.trace with traceLength elements of the sequence
// a_0 = 2, a_1 = 3, a_n = a_{n-1}^2 + a_{n-2}^2 (a Fibonacci-like recurrence
// on squares), computed in the prover's field.
func (p *Prover) GenerateTrace() {
	trace := make([]*FieldElement, traceLength)
	trace[0], trace[1] = NewFieldElement(2, p.prime), NewFieldElement(3, p.prime)
	for n := 2; n < len(trace); n++ {
		prev, prevPrev := trace[n-1], trace[n-2]
		trace[n] = prev.Mul(prev).Add(prevPrev.Mul(prevPrev))
	}
	p.trace = trace
}
// Prove runs the whole proving pipeline and returns the resulting StarkProof:
//  1. generate the trace, interpolate it into P(x) and commit to its
//     evaluations on the evaluation domain;
//  2. build the composition polynomial
//     C(x) = (P(g^2*x) - P(g*x)^2 - P(x)^2) / Z(x) and commit to it;
//  3. derive the DEEP point z from the transcript and combine the DEEP
//     quotients into D(x);
//  4. fold D(x) layer by layer (FRI commit phase), committing to each layer;
//  5. derive query indices and collect leaf values and Merkle paths for them.
//
// It panics if any of the polynomial divisions leaves a non-zero remainder,
// and it prints debug information to stdout.
func (p *Prover) Prove() *StarkProof {
// 1. Generate and commit to the execution trace
p.GenerateTrace()
fmt.Printf("Trace: %v\n", p.trace) // DEBUG PRINT
// Interpolate the trace points to get the trace polynomial P(x)
tracePoly := Interpolate(p.traceDomain, p.trace)
// Evaluate P(x) on the larger evaluation domain for FRI
traceEvals := make([]*FieldElement, len(p.evalDomain))
traceEvalsBytes := make([][]byte, len(p.evalDomain))
for i, x := range p.evalDomain {
traceEvals[i] = tracePoly.Evaluate(x)
traceEvalsBytes[i] = traceEvals[i].Bytes32() // Use Bytes32 for fixed-length hashing
}
// Commit to the trace evaluations with a Merkle tree
traceMerkleRoot, traceMerkleLevels := NewMerkleTree(traceEvalsBytes)
p.transcript.Add(traceMerkleRoot.Hash) // Add root to transcript for Fiat-Shamir
// 2. Define AIR constraints and create the low-degree composition polynomial C(x)
// Constraint: P(g^2*x) - P(g*x)^2 - P(x)^2 = 0
// This constraint holds for points in the trace domain, specifically traceDomain[0] to traceDomain[traceLength-3].
// Let constraint_poly_numerator(x) = P(g^2*x) - P(g*x)^2 - P(x)^2.
// The composition polynomial C(x) = constraint_poly_numerator(x) / Z(x), where Z(x) is the vanishing polynomial for the points
// where the constraint holds.
// Step 2a: Construct the constraint polynomial (numerator) in coefficient form.
// P(x) is tracePoly
p_x_poly := tracePoly
gTraceSquared := p.traceDomainGen.Mul(p.traceDomainGen)
// P(g*x) in coefficient form
p_gx_poly := tracePoly.ScaleByPower(p.traceDomainGen)
// P(g^2*x) in coefficient form
p_g2x_poly := tracePoly.ScaleByPower(gTraceSquared)
// Calculate P(g*x)^2 and P(x)^2
p_gx_squared_poly := p_gx_poly.Mul(p_gx_poly)
p_x_squared_poly := p_x_poly.Mul(p_x_poly)
// Construct the actual constraint polynomial (numerator)
// constraint_poly_numerator(x) = P(g^2*x) - P(g*x)^2 - P(x)^2
actualConstraintPolyNumerator := p_g2x_poly.Sub(p_gx_squared_poly).Sub(p_x_squared_poly)
// DEBUG: Evaluate this polynomial on the constraint domain to ensure it's zero there
constraintDomainForInterpolation := p.traceDomain[:traceLength-2] // Points x_0, ..., x_{n-3}
debugConstraintEvals := make([]*FieldElement, traceLength-2)
for i, x := range constraintDomainForInterpolation {
debugConstraintEvals[i] = actualConstraintPolyNumerator.Evaluate(x)
}
fmt.Printf("DEBUG: Constraint Evaluations on Constraint Domain (from numerator poly): %v\n", debugConstraintEvals) // DEBUG PRINT
// Step 2b: Define the vanishing polynomial Z(x) in coefficient form.
// Z(x) = product_{i=0}^{traceLength-3} (x - traceDomain[i])
vanishingPoly := NewPolynomial([]*FieldElement{One(p.prime)}) // Start with 1
for i := 0; i < traceLength-2; i++ { // Iterate over traceDomain[0] to traceDomain[traceLength-3]
term := NewPolynomial([]*FieldElement{p.traceDomain[i].Neg(), One(p.prime)}) // (x - traceDomain[i])
vanishingPoly = vanishingPoly.Mul(term)
}
// Step 2c: Perform polynomial division to get the low-degree composition polynomial C(x).
// C(x) = actualConstraintPolyNumerator(x) / vanishingPoly(x)
actualCompositionPoly, rem := actualConstraintPolyNumerator.Div(vanishingPoly)
if rem.Degree() != -1 { // Remainder must be zero
panic(fmt.Sprintf("Actual composition polynomial division remainder not zero. Remainder degree: %d", rem.Degree()))
}
fmt.Printf("Actual Composition Poly Degree: %d\n", actualCompositionPoly.Degree()) // Debugging
// Step 2d: Evaluate this actual low-degree composition polynomial on the large evaluation domain.
// These are the evaluations that will be committed to.
compositionEvals := make([]*FieldElement, len(p.evalDomain))
compositionEvalsBytes := make([][]byte, len(p.evalDomain))
for i, x := range p.evalDomain {
compositionEvals[i] = actualCompositionPoly.Evaluate(x) // Evaluate the low-degree poly
compositionEvalsBytes[i] = compositionEvals[i].Bytes32() // Use Bytes32 for fixed-length hashing
}
// Commit to the composition polynomial evaluations
compositionMerkleRoot, compositionMerkleLevels := NewMerkleTree(compositionEvalsBytes)
p.transcript.Add(compositionMerkleRoot.Hash)
// 3. DEEP step: Evaluate polynomials at a random point z
z := p.transcript.GetChallenge() // Challenge point derived from the transcript (Fiat-Shamir)
// Evaluate P(x) at z and g*z and g^2*z
traceEvalAtZ := tracePoly.Evaluate(z)
traceEvalAtGZ := tracePoly.Evaluate(z.Mul(p.traceDomainGen))
traceEvalAtG2Z := tracePoly.Evaluate(z.Mul(gTraceSquared))
// Evaluate the *actual low-degree* composition polynomial at z
compositionEvalAtZ := actualCompositionPoly.Evaluate(z) // Use the correct low-degree poly
// Add these evaluations to the transcript
p.transcript.Add(traceEvalAtZ.Bytes32()) // Use Bytes32
p.transcript.Add(traceEvalAtGZ.Bytes32()) // Use Bytes32
p.transcript.Add(traceEvalAtG2Z.Bytes32()) // Use Bytes32
p.transcript.Add(compositionEvalAtZ.Bytes32()) // Use Bytes32
// Construct DEEP polynomials:
// P_z(x) = (P(x) - P(z)) / (x - z)
// P_gz(x) = (P(x) - P(g*z)) / (x - g*z)
// C_z(x) = (C(x) - C(z)) / (x - z)
// NOTE(review): P(g^2*z) is evaluated and added to the transcript above, but
// no DEEP quotient is built for it here.
// Denominators for DEEP polynomials
xMinusZ := NewPolynomial([]*FieldElement{z.Neg(), One(p.prime)})
xMinusGZ := NewPolynomial([]*FieldElement{z.Mul(p.traceDomainGen).Neg(), One(p.prime)})
// Numerators for DEEP polynomials
pMinusPZ := tracePoly.Add(NewPolynomial([]*FieldElement{traceEvalAtZ.Neg()}))
pMinusPGZ := tracePoly.Add(NewPolynomial([]*FieldElement{traceEvalAtGZ.Neg()}))
cMinusCZ := actualCompositionPoly.Add(NewPolynomial([]*FieldElement{compositionEvalAtZ.Neg()})) // Use actualCompositionPoly
// Perform polynomial division to get DEEP polynomials
p_z, rem1 := pMinusPZ.Div(xMinusZ)
if rem1.Degree() != -1 { // Check if remainder is zero polynomial (degree -1)
panic("P_z division remainder not zero")
}
p_gz, rem2 := pMinusPGZ.Div(xMinusGZ)
if rem2.Degree() != -1 {
panic("P_gz division remainder not zero")
}
c_z, rem3 := cMinusCZ.Div(xMinusZ)
if rem3.Degree() != -1 {
panic("C_z division remainder not zero")
}
// Get random challenges for linear combination of DEEP polynomials
alpha1 := p.transcript.GetChallenge()
alpha2 := p.transcript.GetChallenge()
alpha3 := p.transcript.GetChallenge()
// Form the combined DEEP polynomial: D(x) = alpha1*P_z(x) + alpha2*P_gz(x) + alpha3*C_z(x)
dPoly := p_z.Scale(alpha1).Add(p_gz.Scale(alpha2)).Add(c_z.Scale(alpha3))
// 4. FRI Protocol
// The first polynomial for FRI is the combined DEEP polynomial D(x).
currentFriPoly := dPoly
currentFriDomain := p.evalDomain // Start with the full evaluation domain
var friProofs []FriProof
var allFriEvals [][]*FieldElement // Store all evaluations for query phase
for currentFriPoly.Degree() > 0 {
// Evaluate the current FRI polynomial on its domain
currentFriEvals := make([]*FieldElement, len(currentFriDomain))
for i, x := range currentFriDomain {
currentFriEvals[i] = currentFriPoly.Evaluate(x)
}
allFriEvals = append(allFriEvals, currentFriEvals) // Store for later queries
// Commit to the current FRI layer's evaluations
friEvalsBytes := make([][]byte, len(currentFriEvals))
for i, e := range currentFriEvals {
friEvalsBytes[i] = e.Bytes32() // Use Bytes32 for fixed-length hashing
}
friRoot, _ := NewMerkleTree(friEvalsBytes) // Discard friLevels here, as they are regenerated later
friProofs = append(friProofs, FriProof{Root: friRoot.Hash})
p.transcript.Add(friRoot.Hash) // Add root to transcript
// Get folding challenge for the next layer
alpha := p.transcript.GetChallenge()
// Fold the polynomial: P_next(x^2) = (P(x) + P(-x))/2 + alpha * (P(x) - P(-x))/(2x)
// This is equivalent to: P_next(y) = P_even(y) + alpha * P_odd(y) where y = x^2
// P_even(x) = sum(c_2i * x^(2i)) and P_odd(x) = sum(c_2i+1 * x^(2i+1))
// P(x) = P_even(x) + P_odd(x)
// P_even(x) = (P(x) + P(-x)) / 2
// P_odd(x) = (P(x) - P(-x)) / 2
// P_next(y) = (P(sqrt(y)) + P(-sqrt(y)))/2 + alpha * (P(sqrt(y)) - P(-sqrt(y)))/(2*sqrt(y))
// The polynomial `currentFriPoly` has coefficients c_0, c_1, c_2, ...
// P_even(x) has coefficients c_0, c_2, c_4, ...
// P_odd(x) has coefficients c_1, c_3, c_5, ...
// P_next(y) = (c_0 + c_2*y + c_4*y^2 + ...) + alpha * (c_1 + c_3*y + c_5*y^2 + ...)
nextPolyCoeffs := make([]*FieldElement, (currentFriPoly.Degree()/2)+1)
for i := 0; i < len(nextPolyCoeffs); i++ {
evenCoeff := currentFriPoly[2*i]
oddCoeff := Zero(p.prime)
if 2*i+1 < len(currentFriPoly) { // Check if odd coefficient exists
oddCoeff = currentFriPoly[2*i+1]
}
nextPolyCoeffs[i] = evenCoeff.Add(alpha.Mul(oddCoeff))
}
currentFriPoly = NewPolynomial(nextPolyCoeffs)
// Update the domain for the next layer: x -> x^2
nextDomainSize := len(currentFriDomain) / 2
nextDomain := make([]*FieldElement, nextDomainSize)
for i := 0; i < nextDomainSize; i++ {
nextDomain[i] = currentFriDomain[i].Mul(currentFriDomain[i])
}
currentFriDomain = nextDomain
}
// The last layer is a constant polynomial (degree 0)
// NOTE(review): the loop above also exits when the polynomial is the zero
// polynomial (Degree() == -1, empty slice); indexing [0] would then panic.
finalCoeffs := []*FieldElement{currentFriPoly[0]}
p.transcript.Add(finalCoeffs[0].Bytes32()) // Use Bytes32
friProofs = append(friProofs, FriProof{FinalCoeffs: finalCoeffs})
// 5. FRI Query Phase
// Get random query indices from the transcript. These indices are for the *initial* FRI domain.
queryIndices := p.transcript.GetQueryIndices(len(p.evalDomain), numQueries)
// Collect query data for the initial trace and composition polynomials
traceQueryData := make([][]byte, numQueries)
traceQueryPaths := make([][][]byte, numQueries)
compositionQueryData := make([][]byte, numQueries)
compositionQueryPaths := make([][][]byte, numQueries)
for i, index := range queryIndices {
traceQueryData[i] = traceEvals[index].Bytes32() // Use Bytes32
traceQueryPaths[i] = GetAuthenticationPath(index, traceMerkleLevels)
compositionQueryData[i] = compositionEvals[index].Bytes32() // Use Bytes32
compositionQueryPaths[i] = GetAuthenticationPath(index, compositionMerkleLevels)
}
// Collect query data for all FRI layers.
// The Merkle tree of each layer is rebuilt here from the evaluations stored in
// `allFriEvals` during the commit phase, because the tree levels were discarded
// there. This is inefficient; a real implementation would keep the levels.
for layerIdx := range friProofs {
if friProofs[layerIdx].FinalCoeffs != nil {
break // Skip the final layer as it has no Merkle root/paths
}
// Rebuild this layer's Merkle tree from its stored evaluations.
currentLayerEvals := allFriEvals[layerIdx]
_, currentLayerMerkleLevels := NewMerkleTree(fieldElementsToBytes32(currentLayerEvals)) // Use Bytes32
layerPaths := make([][][]byte, numQueries)
layerValues := make([][]*FieldElement, numQueries)
for i, initialQueryIndex := range queryIndices {
// The query index for a specific FRI layer `k` is `initialQueryIndex / (2^k)`.
// We need to ensure the index is still within the domain of that layer.
layerQueryIndex := initialQueryIndex >> layerIdx // Divide by 2^layerIdx
if layerQueryIndex >= len(currentLayerEvals) {
// Defensive fallback. Presumably unreachable: initial indices are below
// len(p.evalDomain) and each layer halves the domain size, so the shifted
// index should stay in range — TODO confirm.
fmt.Printf("Warning: Query index %d out of bounds for FRI layer %d (domain size %d)\n", layerQueryIndex, layerIdx, len(currentLayerEvals))
// Record an empty path and a zero value for this query.
layerPaths[i] = [][]byte{}
layerValues[i] = []*FieldElement{Zero(p.prime)}
continue
}
layerPaths[i] = GetAuthenticationPath(layerQueryIndex, currentLayerMerkleLevels)
layerValues[i] = []*FieldElement{currentLayerEvals[layerQueryIndex]}
}
friProofs[layerIdx].Paths = layerPaths
friProofs[layerIdx].Values = layerValues
}
return &StarkProof{
TraceRoot: traceMerkleRoot.Hash,
CompositionRoot: compositionMerkleRoot.Hash,
FriProofs: friProofs,
TraceEvalAtZ: traceEvalAtZ,
TraceEvalAtGZ: traceEvalAtGZ,
TraceEvalAtG2Z: traceEvalAtG2Z, // P(g^2*z)
CompositionEvalAtZ: compositionEvalAtZ,
TraceQueryData: traceQueryData,
TraceQueryPaths: traceQueryPaths,
CompositionQueryData: compositionQueryData,
CompositionQueryPaths: compositionQueryPaths,
}
}
// =============================================================================
// VERIFIER
// =============================================================================
// Verifier bundles the state used to check a StarkProof: the field modulus,
// a group generator, the trace and evaluation domains with their generators,
// and the transcript from which challenges are re-derived.
type Verifier struct {
	// prime is the field modulus.
	prime int64
	// g is the generator returned by findGenerator for prime.
	g *FieldElement
	// Generators of the trace domain and of the (larger) evaluation domain.
	traceDomainGen, evalDomainGen *FieldElement
	// Successive powers of the two generators above.
	traceDomain, evalDomain []*FieldElement
	// transcript is replayed by Verify to obtain the prover's challenges.
	transcript *Transcript
}
// NewVerifier returns a Verifier configured from the package constants
// fieldPrime, traceLength and blowupFactor, with a fresh transcript and
// precomputed trace and evaluation domains.
func NewVerifier() *Verifier {
	verifier := &Verifier{
		prime:      fieldPrime,
		transcript: NewTranscript(),
	}
	verifier.g = findGenerator(verifier.prime)

	// subgroup returns the element g^((prime-1)/size) together with its
	// powers 0..size-1.
	subgroup := func(size int64) (*FieldElement, []*FieldElement) {
		exponent := new(big.Int).Div(big.NewInt(verifier.prime-1), big.NewInt(size))
		gen := verifier.g.Exp(exponent)
		points := make([]*FieldElement, size)
		for k := int64(0); k < size; k++ {
			points[k] = gen.Exp(big.NewInt(k))
		}
		return gen, points
	}

	verifier.traceDomainGen, verifier.traceDomain = subgroup(int64(traceLength))
	verifier.evalDomainGen, verifier.evalDomain = subgroup(int64(traceLength * blowupFactor))
	return verifier
}
// Verify replays the prover's transcript and checks proof, returning true only
// if every implemented check passes. A diagnostic line is printed to stdout
// for the first failing check.
//
// The checks performed are:
//  1. the constraint P(g^2*z) - P(g*z)^2 - P(z)^2, computed from the claimed
//     openings, equals C(z)*Z(z), where Z vanishes on the first traceLength-2
//     points of the trace domain;
//  2. the Merkle openings of the trace and composition commitments at every
//     query index;
//  3. the Merkle openings of every non-final FRI layer at the query index
//     shifted right by the layer number;
//  4. the final FRI layer carries exactly one coefficient.
//
// Structurally incomplete proofs (nil proof, missing openings, too few query
// entries, no FRI layers, FinalCoeffs on a non-final layer) are rejected with
// false rather than causing an index-out-of-range or nil-pointer panic.
//
// NOTE(review): this is not a sound verifier. The FRI folding relation between
// consecutive layers is never checked (the proof carries a single value per
// query, so it cannot be), the queried trace/composition values are not tied
// to the first FRI layer, and the DEEP challenges are drawn but unused.
//
// NOTE(review): Verify appends to v.transcript and never resets it, so a
// Verifier is presumably single-use; confirm against Transcript before
// calling Verify twice on the same value.
func (v *Verifier) Verify(proof *StarkProof) bool {
	// Reject structurally incomplete proofs before touching the transcript.
	if proof == nil {
		fmt.Println("Malformed proof: proof is nil.")
		return false
	}
	if proof.TraceEvalAtZ == nil || proof.TraceEvalAtGZ == nil ||
		proof.TraceEvalAtG2Z == nil || proof.CompositionEvalAtZ == nil {
		fmt.Println("Malformed proof: missing evaluation at the challenge point.")
		return false
	}
	if len(proof.FriProofs) == 0 {
		fmt.Println("Malformed proof: no FRI layers.")
		return false
	}

	// Rebuild transcript state to get the same random challenges as the prover.
	v.transcript.Add(proof.TraceRoot)
	v.transcript.Add(proof.CompositionRoot)
	z := v.transcript.GetChallenge()

	// Re-add the evaluations at z so subsequent challenges depend on them.
	v.transcript.Add(proof.TraceEvalAtZ.Bytes32())
	v.transcript.Add(proof.TraceEvalAtGZ.Bytes32())
	v.transcript.Add(proof.TraceEvalAtG2Z.Bytes32())
	v.transcript.Add(proof.CompositionEvalAtZ.Bytes32())

	// The three DEEP challenges are not used by this simplified check, but
	// they are still drawn, in the same number as before, so that later
	// challenges are derived from the same sequence of transcript calls.
	for i := 0; i < 3; i++ {
		_ = v.transcript.GetChallenge()
	}

	// 1. DEEP consistency check at z.
	// AIR constraint: P(g^2*z) - P(g*z)^2 - P(z)^2.
	gzSquared := proof.TraceEvalAtGZ.Mul(proof.TraceEvalAtGZ)
	zSquared := proof.TraceEvalAtZ.Mul(proof.TraceEvalAtZ)
	expectedConstraintEvalAtZ := proof.TraceEvalAtG2Z.Sub(gzSquared).Sub(zSquared)

	// Vanishing polynomial Z(x) = prod (x - traceDomain[i]) over the first
	// traceLength-2 trace points, built in coefficient form and evaluated at z.
	verifierVanishingPoly := NewPolynomial([]*FieldElement{One(v.prime)})
	for i := 0; i < traceLength-2; i++ {
		term := NewPolynomial([]*FieldElement{v.traceDomain[i].Neg(), One(v.prime)})
		verifierVanishingPoly = verifierVanishingPoly.Mul(term)
	}
	vanishingPolyAtZ := verifierVanishingPoly.Evaluate(z)

	// C(z) * Z(z) should equal the constraint evaluation at z.
	expectedConstraintFromComposition := proof.CompositionEvalAtZ.Mul(vanishingPolyAtZ)
	if vanishingPolyAtZ.IsZero() {
		// z is a root of Z, so the constraint itself must vanish at z.
		if !expectedConstraintEvalAtZ.IsZero() {
			fmt.Println("DEEP consistency check failed: Constraint should be zero at vanishing point, but isn't.")
			return false
		}
	} else if !expectedConstraintFromComposition.Equal(expectedConstraintEvalAtZ) {
		fmt.Println("DEEP consistency check failed: C(z) * Z(z) != Constraint(z)")
		fmt.Printf("Expected Constraint(z): %s, C(z)*Z(z): %s\n", expectedConstraintEvalAtZ, expectedConstraintFromComposition)
		return false
	}
	fmt.Println("DEEP consistency check passed.")

	// 2. Rebuild the FRI challenges from the layer commitments in the proof.
	var friChallenges []*FieldElement
	for _, friProof := range proof.FriProofs {
		if friProof.FinalCoeffs == nil {
			v.transcript.Add(friProof.Root)
			friChallenges = append(friChallenges, v.transcript.GetChallenge())
			continue
		}
		if len(friProof.FinalCoeffs) == 0 {
			fmt.Println("Malformed proof: FRI layer has an empty final coefficient list.")
			return false
		}
		v.transcript.Add(friProof.FinalCoeffs[0].Bytes32())
	}
	// Every non-final layer must have contributed a folding challenge; fewer
	// means some non-final layer carried FinalCoeffs instead of a commitment.
	numFoldingLayers := len(proof.FriProofs) - 1
	if len(friChallenges) < numFoldingLayers {
		fmt.Println("Malformed proof: non-final FRI layer carries final coefficients.")
		return false
	}

	// Get query indices (same derivation as the prover).
	queryIndices := v.transcript.GetQueryIndices(traceLength*blowupFactor, numQueries)
	numIndices := len(queryIndices)

	// 3. Verify query openings for the trace and composition commitments.
	if len(proof.TraceQueryData) < numIndices || len(proof.TraceQueryPaths) < numIndices ||
		len(proof.CompositionQueryData) < numIndices || len(proof.CompositionQueryPaths) < numIndices {
		fmt.Println("Malformed proof: missing trace or composition query data.")
		return false
	}
	for i, index := range queryIndices {
		if !VerifyMerklePath(proof.TraceRoot, proof.TraceQueryData[i], proof.TraceQueryPaths[i], index) {
			fmt.Printf("Trace Merkle path verification failed for query %d (index %d)\n", i, index)
			return false
		}
		if !VerifyMerklePath(proof.CompositionRoot, proof.CompositionQueryData[i], proof.CompositionQueryPaths[i], index) {
			fmt.Printf("Composition Merkle path verification failed for query %d (index %d)\n", i, index)
			return false
		}
	}

	// 4. Verify the Merkle openings of every non-final FRI layer.
	// Only the size of each layer's domain is needed here, so it is tracked
	// as a number that halves per layer instead of materialising the squared
	// domain elements.
	layerDomainSize := len(v.evalDomain)
	for layerIdx := 0; layerIdx < numFoldingLayers; layerIdx++ {
		friProofLayer := proof.FriProofs[layerIdx]
		if len(friProofLayer.Values) < numIndices || len(friProofLayer.Paths) < numIndices {
			fmt.Printf("Malformed proof: missing query data for FRI layer %d\n", layerIdx)
			return false
		}
		for i, initialQueryIndex := range queryIndices {
			// The query index for layer k is initialQueryIndex / 2^k.
			currentLayerQueryIndex := initialQueryIndex >> layerIdx
			if currentLayerQueryIndex >= layerDomainSize {
				fmt.Printf("Error: Query index %d out of bounds for FRI layer %d (domain size %d)\n", currentLayerQueryIndex, layerIdx, layerDomainSize)
				return false
			}
			if len(friProofLayer.Values[i]) == 0 || friProofLayer.Values[i][0] == nil {
				fmt.Printf("Malformed proof: missing value for FRI layer %d, query %d\n", layerIdx, i)
				return false
			}
			currentLayerValue := friProofLayer.Values[i][0]
			if !VerifyMerklePath(friProofLayer.Root, currentLayerValue.Bytes32(), friProofLayer.Paths[i], currentLayerQueryIndex) {
				fmt.Printf("FRI Merkle path verification failed for layer %d, query %d\n", layerIdx, i)
				return false
			}
			// A full folding check would verify
			//   P_next(x^2) = P_even(x^2) + alpha * P_odd(x^2)
			// using friChallenges[layerIdx] as alpha, which needs both P(x)
			// and P(-x) per query. The proof stores one value per query, so
			// only the Merkle opening is verified here.
		}
		layerDomainSize /= 2
	}

	// 5. Final FRI layer check: it must be a constant (one coefficient).
	// Whether the second-to-last layer folds into that constant is not
	// checked, for the same reason as above.
	lastFriLayer := proof.FriProofs[len(proof.FriProofs)-1]
	if len(lastFriLayer.FinalCoeffs) != 1 {
		fmt.Println("FRI final layer check failed: Expected a constant polynomial.")
		return false
	}
	fmt.Println("FRI Merkle path and final constant checks passed.")
	return true
}
// =============================================================================
// HELPER FUNCTIONS & MAIN
// =============================================================================
// findGenerator returns the smallest candidate in [2, prime) that passes the
// generator test for the multiplicative group of Z_prime: for every distinct
// prime factor q of prime-1, candidate^((prime-1)/q) must differ from 1.
// It panics with "no generator found" if no candidate passes.
func findGenerator(prime int64) *FieldElement {
	order := new(big.Int).Sub(big.NewInt(prime), big.NewInt(1))
	orderFactors := primeFactors(order)

	// generates reports whether c passes the test above for every factor.
	generates := func(c *FieldElement) bool {
		for _, q := range orderFactors {
			exponent := new(big.Int).Div(order, q)
			if c.Exp(exponent).Equal(One(prime)) {
				return false
			}
		}
		return true
	}

	for c := int64(2); c < prime; c++ {
		if candidate := NewFieldElement(c, prime); generates(candidate) {
			return candidate
		}
	}
	panic("no generator found")
}
// primeFactors returns the distinct prime factors of n in ascending order,
// found by trial division. n itself is not modified. The result is nil when
// n has no prime factors (n <= 1).
//
// Trial division visits divisors in increasing order and strips each one
// completely before moving on, so every divisor that succeeds is prime and is
// recorded exactly once; no de-duplication is needed and the output order is
// deterministic.
func primeFactors(n *big.Int) []*big.Int {
	var factors []*big.Int

	one := big.NewInt(1)
	d := big.NewInt(2)
	num := new(big.Int).Set(n) // working copy; n stays untouched

	// Scratch values reused across iterations to avoid per-step allocations.
	square := new(big.Int)
	quo := new(big.Int)
	rem := new(big.Int)

	for square.Mul(d, d).Cmp(num) <= 0 { // stop once d*d > num
		divides := false
		for {
			quo.DivMod(num, d, rem)
			if rem.Sign() != 0 {
				break
			}
			// d divides num: continue with the quotient. Swapping reuses
			// the old num storage as the next quotient buffer.
			num, quo = quo, num
			divides = true
		}
		if divides {
			factors = append(factors, new(big.Int).Set(d))
		}
		d.Add(d, one)
	}

	// Whatever remains above 1 is a single prime larger than every divisor tried.
	if num.Cmp(one) > 0 {
		factors = append(factors, num)
	}
	return factors
}
// fieldElementsToBytes32 maps each element to its Bytes32 encoding, keeping
// the input order. The result has one entry per input element.
func fieldElementsToBytes32(elements []*FieldElement) [][]byte {
	encoded := make([][]byte, 0, len(elements))
	for _, elem := range elements {
		encoded = append(encoded, elem.Bytes32())
	}
	return encoded
}
// main runs the demo end to end: it generates a proof, prints the two
// commitment roots and the FRI layer count, then verifies the proof and
// prints the outcome.
func main() {
	fmt.Println("Starting STARK Prover...")
	proof := NewProver().Prove()
	fmt.Println("Proof generated.")

	fmt.Printf(" Trace Merkle Root: %x\n", proof.TraceRoot)
	fmt.Printf(" Composition Merkle Root: %x\n", proof.CompositionRoot)
	fmt.Printf(" Number of FRI Layers: %d\n", len(proof.FriProofs))

	fmt.Println("\nStarting STARK Verifier...")
	accepted := NewVerifier().Verify(proof)
	fmt.Printf("\nVerification result: %t\n", accepted)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment