Skip to content

Instantly share code, notes, and snippets.

@klauspost
Created January 13, 2025 12:32
Show Gist options
  • Save klauspost/4b0949c1c32d84d5edb7c21956b4fdc1 to your computer and use it in GitHub Desktop.
Save klauspost/4b0949c1c32d84d5edb7c21956b4fdc1 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"fmt"
"os"
"sync"
"time"
"unsafe"
)
// pooledRowInts is the number of int32 cells in each pooled scratch buffer.
const pooledRowInts = 1024 * 1024 * 4

// thePool recycles scratch buffers for levenshtein. It stores *[]int32 rather
// than []int32 so that Put does not have to box a slice header into an
// interface value on every call.
var thePool = sync.Pool{
	New: func() any {
		buf := make([]int32, pooledRowInts)
		return &buf
	},
}

// levenshtein returns the Levenshtein (edit) distance between s1 and s2 using
// the two-row Wagner-Fischer dynamic program. Inputs are compared byte by
// byte, not rune by rune.
//
// Only two rows of len(shorter)+1 cells are live at any time, and the work is
// O(m*n) where m and n are the input lengths. Distances are held in int32, so
// inputs are assumed to be shorter than 2^31 bytes.
//
// Scratch space normally comes from thePool. When two rows do not fit in a
// pooled buffer, a dedicated buffer is allocated for the call and is not
// returned to the pool.
func levenshtein(s1, s2 []byte) int {
	// Early termination checks.
	if bytes.Equal(s1, s2) {
		return 0
	}
	if len(s1) == 0 {
		return len(s2)
	}
	if len(s2) == 0 {
		return len(s1)
	}
	// Make s1 the shorter input so that rows are as small as possible.
	if len(s1) > len(s2) {
		s1, s2 = s2, s1
	}

	rowLen := len(s1) + 1
	need := 2 * rowLen
	var scratch []int32
	if need <= pooledRowInts {
		bufp := thePool.Get().(*[]int32)
		defer thePool.Put(bufp)
		scratch = *bufp
	} else {
		// Too large for a pooled buffer: slicing it would go out of range.
		scratch = make([]int32, need)
	}
	// The buffer may hold stale values from an earlier call; every cell is
	// written below before it is read.
	prevRow := scratch[:rowLen]
	currRow := scratch[rowLen:need]

	// Initialize the first row: distance from the empty prefix of s2.
	for i := range prevRow {
		prevRow[i] = int32(i)
	}

	// Main computation loop: one row per byte of s2.
	for j := range s2 {
		currRow[0] = int32(j + 1)
		// Re-slice every view to exactly len(s1) elements up front; this is
		// intended to let the compiler drop bounds checks in the inner loop.
		currRowPlus1 := currRow[1:]
		currRowPlus1 = currRowPlus1[:len(s1)]
		prevRowPlus1 := prevRow[1:]
		prevRowPlus1 = prevRowPlus1[:len(s1)]
		currRowL := currRow[:len(s1)]
		prevRowL := prevRow[:len(s1)]
		s2v := int32(s2[j])
		for i, v := range s1 {
			// XOR is zero only for equal bytes, so this yields 0 or 1.
			cost := min(1, int32(v)^s2v)
			// Operand order is kept as the original author left it, having
			// noted that it matters a lot (presumably for speed).
			currRowPlus1[i] = min(prevRowPlus1[i]+1, prevRowL[i]+cost, currRowL[i]+1)
		}
		// Swap rows.
		prevRow, currRow = currRow, prevRow
	}
	return int(prevRow[len(s1)])
}
// main treats every command-line argument as a string, computes the
// levenshtein distance for each ordered pair of distinct arguments, and prints
// the elapsed time, the number of comparisons, and the smallest distance seen.
// With fewer than two arguments it prints a usage hint and returns.
func main() {
	args := os.Args[1:]
	minD := -1
	times := 0
	if len(args) < 2 {
		fmt.Println("Please provide at least two strings as arguments.")
		return
	}
	start := time.Now()
	for i := range args {
		// Zero-copy view of the argument's bytes. unsafe.Slice builds a
		// well-formed slice (len == cap); casting *string to *[]byte would
		// read the capacity word from memory past the string header.
		// levenshtein only reads its inputs, so the bytes are never mutated.
		a := unsafe.Slice(unsafe.StringData(args[i]), len(args[i]))
		for j := range args {
			if i == j {
				continue
			}
			b := unsafe.Slice(unsafe.StringData(args[j]), len(args[j]))
			d := levenshtein(a, b)
			if minD == -1 || d < minD {
				minD = d
			}
			times++
		}
	}
	fmt.Printf("Time elapsed: %s\n", time.Since(start))
	fmt.Printf("times: %d\n", times)
	fmt.Printf("min_distance: %d\n", minD)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment