
@m00nlight
Last active April 26, 2025 15:50
Python KMP algorithm
class KMP:

    def partial(self, pattern):
        """ Calculate partial match table: String -> [Int]"""
        ret = [0]
        for i in range(1, len(pattern)):
            j = ret[i - 1]  # longest prefix-suffix of pattern[:i]
            while j > 0 and pattern[j] != pattern[i]:
                j = ret[j - 1]  # fall back to a shorter prefix-suffix
            ret.append(j + 1 if pattern[j] == pattern[i] else j)
        return ret


    def search(self, T, P):
        """
        KMP search main algorithm: String -> String -> [Int]
        Return all the matching positions of pattern string P in T
        """
        partial, ret, j = self.partial(P), [], 0
        for i in range(len(T)):
            while j > 0 and T[i] != P[j]:
                j = partial[j - 1]  # mismatch: reuse the table, never move back in T
            if T[i] == P[j]: j += 1
            if j == len(P):
                ret.append(i - (j - 1))
                j = partial[j - 1]  # keep the overlap so overlapping matches are found
        return ret


def test():
    p1 = "aa"
    t1 = "aaaaaaaa"
    kmp = KMP()
    assert kmp.search(t1, p1) == [0, 1, 2, 3, 4, 5, 6]
    p2 = "abc"
    t2 = "abdabeabfabc"
    assert kmp.search(t2, p2) == [9]
    p3 = "aab"
    t3 = "aaabaacbaab"
    assert kmp.search(t3, p3) == [1, 8]
    p4 = "11"
    t4 = "111"
    assert kmp.search(t4, p4) == [0, 1]
    print("all test pass")


if __name__ == "__main__":
    test()
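To get more confidence than four hand-picked tests give, the class can be cross-checked against a brute-force search on many random inputs. This is only a sketch: `naive_search` and `cross_check` are hypothetical helpers, the class is copied from above so the block runs on its own, and patterns are kept non-empty because `search` assumes `len(P) >= 1` (it indexes `P[0]`).

```python
import random


class KMP:
    # Copied from the gist above so this block runs on its own.
    def partial(self, pattern):
        ret = [0]
        for i in range(1, len(pattern)):
            j = ret[i - 1]
            while j > 0 and pattern[j] != pattern[i]:
                j = ret[j - 1]
            ret.append(j + 1 if pattern[j] == pattern[i] else j)
        return ret

    def search(self, T, P):
        partial, ret, j = self.partial(P), [], 0
        for i in range(len(T)):
            while j > 0 and T[i] != P[j]:
                j = partial[j - 1]
            if T[i] == P[j]:
                j += 1
            if j == len(P):
                ret.append(i - (j - 1))
                j = partial[j - 1]
        return ret


def naive_search(T, P):
    """Hypothetical reference: every start index of P in T, overlaps included."""
    return [i for i in range(len(T) - len(P) + 1) if T[i:i + len(P)] == P]


def cross_check(trials=2000, seed=0):
    """Hypothetical helper: compare KMP.search with naive_search on random strings."""
    rng = random.Random(seed)
    kmp = KMP()
    for _ in range(trials):
        # A two-letter alphabet makes repeated and overlapping matches common.
        T = "".join(rng.choice("ab") for _ in range(rng.randint(0, 20)))
        # Non-empty patterns only: search() indexes P[0].
        P = "".join(rng.choice("ab") for _ in range(rng.randint(1, 5)))
        assert kmp.search(T, P) == naive_search(T, P), (T, P)
    return True


if __name__ == "__main__":
    print(cross_check())
```

The naive search is O(len(T) * len(P)), which is fine for a test oracle; a fixed seed keeps any failure reproducible.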
@steve-thousand

Thanks for this!

@fengs

fengs commented Nov 4, 2018

Here

            if j == len(P): 
                ret.append(i - (j - 1))
                j = 0

Should be

            if j == len(P): 
                ret.append(i - (j - 1))
                j = partial[j - 1]

That's the whole point of KMP: not to start over again.
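A small sketch of the difference: the hypothetical `restart_after_match` flag below switches between the original `j = 0` and the suggested `j = partial[j - 1]`. Resetting to 0 only finds leftmost non-overlapping matches, while falling back through the table also reports overlapping ones. `partial_table` and `search` are illustrative names, not the gist's API.

```python
def partial_table(p):
    """Same recurrence as KMP.partial in the gist."""
    table = [0]
    for i in range(1, len(p)):
        j = table[i - 1]
        while j > 0 and p[j] != p[i]:
            j = table[j - 1]
        table.append(j + 1 if p[j] == p[i] else j)
    return table


def search(t, p, restart_after_match):
    """restart_after_match=True mimics the old `j = 0`;
    False uses `j = partial[j - 1]` as suggested."""
    table, out, j = partial_table(p), [], 0
    for i, ch in enumerate(t):
        while j > 0 and ch != p[j]:
            j = table[j - 1]
        if ch == p[j]:
            j += 1
        if j == len(p):
            out.append(i - (j - 1))
            j = 0 if restart_after_match else table[j - 1]
    return out


if __name__ == "__main__":
    print(search("111", "11", restart_after_match=True))   # [0]
    print(search("111", "11", restart_after_match=False))  # [0, 1]
```

With `j = 0`, the second "11" in "111" is missed because its first character was already consumed by the first match.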

@srikommareddi

Can you share a replace using KMP too?

@srikommareddi

I mean: replace the matched string wherever a match is found.
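One way to answer this, as a sketch rather than anything from the gist: run the KMP scan with `j = 0` after each hit so matches don't overlap, then splice the replacement into the gaps. The helper names (`partial_table`, `find_non_overlapping`, `kmp_replace`) are hypothetical; the intent is to mimic Python's `str.replace`, which also replaces leftmost non-overlapping occurrences.

```python
def partial_table(p):
    """Same recurrence as KMP.partial in the gist."""
    table = [0]
    for i in range(1, len(p)):
        j = table[i - 1]
        while j > 0 and p[j] != p[i]:
            j = table[j - 1]
        table.append(j + 1 if p[j] == p[i] else j)
    return table


def find_non_overlapping(t, p):
    """Leftmost, non-overlapping match positions: reset j to 0 after each hit."""
    table, out, j = partial_table(p), [], 0
    for i, ch in enumerate(t):
        while j > 0 and ch != p[j]:
            j = table[j - 1]
        if ch == p[j]:
            j += 1
        if j == len(p):
            out.append(i - (j - 1))
            j = 0  # the next match may not reuse characters of this one
    return out


def kmp_replace(t, p, repl):
    """Replace every non-overlapping occurrence of p in t with repl."""
    if not p:
        raise ValueError("pattern must be non-empty")
    pieces, last = [], 0
    for start in find_non_overlapping(t, p):
        pieces.append(t[last:start])  # text between the previous match and this one
        pieces.append(repl)
        last = start + len(p)
    pieces.append(t[last:])  # tail after the final match
    return "".join(pieces)


if __name__ == "__main__":
    print(kmp_replace("abcabc", "bc", "X"))  # aXaX
```

For example, `kmp_replace("aaa", "aa", "b")` gives "ba", the same as `"aaa".replace("aa", "b")`.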

@aladine

aladine commented May 7, 2019

That is correct, line 26 should be j = partial[j - 1]
https://gist.github.com/m00nlight/daa6786cc503fde12a77#gistcomment-2750908

@ElonMoon

Here

for i in range(1, len(pattern)):
    j = ret[j - 1]

Should be

for i in range(1, len(pattern)):
    j = ret[-1]

@m00nlight
Author

In reply to a comment: "FYI, this implementation doesn't account for overlapping matches (e.g., for the input (111, 11), the output is [0] instead of [0, 1]). If you want to handle overlapping substrings, change line 26 to j = partial[j - 1]."

Thanks for pointing this out. It was indeed a bug in the original implementation; fixed.

Quoting @fengs's suggestion above: after a full match, set j = partial[j - 1] instead of j = 0, since the whole point of KMP is not to start over again.

Thanks for pointing this out as well. Fixed.

@perkfly

perkfly commented Mar 26, 2020

In one function:

def kmp(t, p):
    """return all matching positions of p in t"""
    next = [0]
    j = 0
    for i in range(1, len(p)):
        while j > 0 and p[j] != p[i]:
            j = next[j - 1]
        if p[j] == p[i]:
            j += 1
        next.append(j)
    # The search part is almost identical to the build part.
    ans = []
    j = 0
    for i in range(len(t)):
        while j > 0 and t[i] != p[j]:
            j = next[j - 1]
        if t[i] == p[j]:
            j += 1
        if j == len(p):
            ans.append(i - (j - 1))
            j = next[j - 1]
    return ans

def test():
    p1 = "aa"
    t1 = "aaaaaaaa"

    assert(kmp(t1, p1) == [0, 1, 2, 3, 4, 5, 6])

    p2 = "abc"
    t2 = "abdabeabfabc"

    assert(kmp(t2, p2) == [9])

    p3 = "aab"
    t3 = "aaabaacbaab"

    assert(kmp(t3, p3) == [1, 8])

    print("all test pass")


if __name__ == "__main__":
    test()
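Because the build and search phases are the same loop, a common variant (not the code above) runs only the prefix-function loop, once, over the pattern, a separator, and the text; wherever the value reaches `len(p)`, a match ends there. A sketch, assuming a separator that can never equal a character; here `None` in a list plays that role, and `prefix_function` / `kmp_concat` are illustrative names.

```python
def prefix_function(s):
    """pi[i] = length of the longest proper prefix of s[:i + 1] that is also its suffix."""
    pi = [0] * len(s)
    for i in range(1, len(s)):
        j = pi[i - 1]
        while j > 0 and s[i] != s[j]:
            j = pi[j - 1]
        if s[i] == s[j]:
            j += 1
        pi[i] = j
    return pi


def kmp_concat(t, p):
    """All match positions of p in t, via one prefix-function pass over p + sep + t."""
    if not p:
        raise ValueError("pattern must be non-empty")
    # None never equals a character, so no prefix-suffix can span the
    # separator and pi never exceeds len(p).
    s = list(p) + [None] + list(t)
    m = len(p)
    # A value of m at index i means a match ends at t-index i - m - 1,
    # so it starts at (i - m - 1) - m + 1 = i - 2 * m.
    return [i - 2 * m for i, v in enumerate(prefix_function(s)) if v == m]


if __name__ == "__main__":
    print(kmp_concat("111", "11"))  # [0, 1]
```

The trade-off is memory: this builds a table of size len(p) + len(t) + 1, while the two-phase version above only stores len(p) entries.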

@HenryPaik1

HenryPaik1 commented May 31, 2020

Hi, would you mind if I ask what j - 1 means in line 9, j = ret[j - 1] (i.e., why j - 1)? I am having difficulty understanding this part. Thanks in advance.

@m00nlight
Author

m00nlight commented Jun 20, 2020

@HenryPaik1

For each position in the pattern string, the partial function calculates the length of the longest prefix that is also a suffix of the pattern up to that position. Take a concrete example: the pattern string p = "ababaa" has the following substrings starting from the beginning:

a
ab
aba
abab
ababa
ababaa

For each 1 <= i <= len(p), partial[i - 1] is the length of the longest prefix of p[0:i] that is also a suffix of p[0:i] (note that the prefix and suffix must be proper, i.e., strictly shorter than p[0:i]). So for the example above: for "a" the partial number is 0; for "ab" it is 0; for "aba" it is 1, since the prefix "a" matches the suffix "a"; for "abab" it is 2, since the prefix "ab" matches the suffix "ab"; for "ababa" it is 3, since "aba" is the longest prefix that matches a suffix; and for "ababaa" it is 1, since only the prefix "a" matches the suffix "a".
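The numbers above can be checked mechanically. A sketch that compares the gist's recurrence (rewritten as a free function `partial`) with a hypothetical brute-force `longest_border`, which tries every proper prefix length from longest to shortest:

```python
def partial(pattern):
    """Same recurrence as KMP.partial in the gist, as a free function."""
    ret = [0]
    for i in range(1, len(pattern)):
        j = ret[i - 1]
        while j > 0 and pattern[j] != pattern[i]:
            j = ret[j - 1]
        ret.append(j + 1 if pattern[j] == pattern[i] else j)
    return ret


def longest_border(s):
    """Brute force: length of the longest proper prefix of s that is also a suffix."""
    for k in range(len(s) - 1, 0, -1):
        if s[:k] == s[-k:]:
            return k
    return 0


if __name__ == "__main__":
    p = "ababaa"
    table = partial(p)
    for i in range(1, len(p) + 1):
        # prefix, value from the recurrence, value from brute force
        print(p[:i], table[i - 1], longest_border(p[:i]))
```

Both columns come out as 0, 0, 1, 2, 3, 1 for the six prefixes, matching the values worked out by hand.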

The idea is to ask, at each position, what the longest prefix is that is also a suffix up to that point. When we calculate partial[i], we first set j to partial[i - 1], the value for the previous substring. We know that the first j characters are both a prefix and a suffix of pattern[0..(i - 1)] (an inclusive substring here, not a Python slice). Now we compare pattern[j] with pattern[i]: if they are the same, the longest prefix that is also a suffix can be extended by one character. That's line 10 of the code. Otherwise, we need to fall back further until we find a match (or j reaches 0). The j - 1 here just means we fall back to the partial match of the substring pattern[0..(j - 1)]: since partial[i] holds the value for pattern[0..i] (inclusive bounds), the value for pattern[0..(j - 1)] is partial[j - 1].
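To see the `j - 1` jump in action, the sketch below records every value of j tried for each position; `partial_with_trace` is a hypothetical instrumented copy of `partial`, not part of the gist.

```python
def partial_with_trace(pattern):
    """Return (table, trace), where trace[i] lists every j tried while computing table[i]."""
    table, trace = [0], [[]]
    for i in range(1, len(pattern)):
        j = table[i - 1]  # longest prefix-suffix of pattern[0..i-1]
        tried = [j]
        while j > 0 and pattern[j] != pattern[i]:
            j = table[j - 1]  # fall back to the prefix-suffix of pattern[0..j-1]
            tried.append(j)
        table.append(j + 1 if pattern[j] == pattern[i] else j)
        trace.append(tried)
    return table, trace


if __name__ == "__main__":
    table, trace = partial_with_trace("ababaa")
    print(table)
    print(trace)
```

For "ababaa", trace[5] is [3, 1, 0]: j starts at 3 ("aba"), pattern[3] = 'b' does not match 'a', so j falls back to partial[2] = 1 ("a"); pattern[1] = 'b' still does not match, so j falls back to partial[0] = 0; then pattern[0] = 'a' matches and partial[5] = 1.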

I hope this helps you understand the algorithm. You can also refer to @igorp1's reply, which uses more meaningful variable names and should also help with understanding the algorithm.

@jianhui-ben

Thank you @m00nlight, super helpful! I wrote my own version based on what I learned. If anyone sees anything wrong or anything that could be improved, please let me know!

class KMP:
    
    def partial(self, s):
        """ 
        Calculate partial match table: String -> [Int]
        
        # arrarra -> [0, 0, 0, 1, 2, 3, 4]
        # amar ->    [0, 0, 1, 0]
        # aaoiaa ->  [0, 1, 0, 0, 1, 2]
        """
        res = [0]
        for x in s[1:]:
            check_indx = res[-1]
            if s[check_indx] == x:
                res += [check_indx + 1]
            else:
                res += [0]
        
        return res
        
        
    def search(self, T, P):
        """ 
        KMP search main algorithm: String, String -> [Int] 
        Return all the matching positions of pattern string P in T
        
        sample run:

        abcxabcdabxabcdabcdabcyabcdabcy     <~~ text T
        abcdabcy                            <~~ pattern P
        00001230                            <~~ mapping array
        ^^^^                                <~~ current comparison
        abcxabcdabxabcdabcdabcyabcdabcy
           abcdabcy
           00001230
           ^
        abcxabcdabxabcdabcdabcyabcdabcy
            abcdabcy
            00001230
            ^^^^^^^
        abcxabcdabxabcdabcdabcyabcdabcy
                abcdabcy
                00001230
                  ^

        abcxabcdabxabcdabcdabcyabcdabcy
                   abcdabcy
                   00001230
                   ^^^^^^^^

        abcxabcdabxabcdabcdabcyabcdabcy
                       abcdabcy
                       00001230
                          ^^^^^
                          
        abcxabcdabxabcdabcdabcyabcdabcy
                       abcdabcy
                       00001230
                              ^

        """
        mapping = self.partial(P)
        result = []
        
        p_pointer = 0
        t_pointer = 0
        
        while t_pointer < len(T):
            
            if P[p_pointer] == T[t_pointer]:
                p_pointer += 1
                t_pointer += 1 
                
                if p_pointer >= len(P):
                    result += [t_pointer-len(P)]
                    p_pointer = 0 if p_pointer == 0 else mapping[p_pointer-1]
                
            else:
                t_pointer += 1 if p_pointer == 0 else 0
                p_pointer = 0 if p_pointer == 0 else mapping[p_pointer-1]
             
        return result

I think the for loop in the partial() function is incorrect: there should be a while loop inside the for loop. A counterexample for your partial function is the input pattern 'aabaaac'.
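A sketch comparing the posted `partial()` (copied as `partial_no_fallback`) with the same loop plus the inner while loop (`partial_with_fallback`, a hypothetical name) on the counterexample:

```python
def partial_no_fallback(s):
    # jianhui-ben's partial() as posted: one comparison, then reset to 0.
    res = [0]
    for x in s[1:]:
        check_indx = res[-1]
        res += [check_indx + 1] if s[check_indx] == x else [0]
    return res


def partial_with_fallback(s):
    # Same loop with the inner while loop added (len(res) == current index).
    res = [0]
    for x in s[1:]:
        j = res[-1]
        while j > 0 and s[j] != x:
            j = res[j - 1]  # try the next shorter prefix-suffix
        res.append(j + 1 if s[j] == x else j)
    return res


if __name__ == "__main__":
    print(partial_no_fallback("aabaaac"))    # [0, 1, 0, 1, 2, 0, 0]
    print(partial_with_fallback("aabaaac"))  # [0, 1, 0, 1, 2, 2, 0]
```

The outputs differ at index 5: "aabaaa" ends with "aa", which is also a prefix, so the correct value is 2. The posted version only compares s[2] = 'b' with 'a', sees a mismatch and resets to 0, without falling back to the shorter prefix-suffix "a" of "aabaa".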

@PMLP-novo

Fails on
p5 = "ABC ABCDAB ABCDABCDABDE"
t5 = "ABCDABD"
assert (kmp.search(t5, p5) == [15])

@honglu2875

You need to swap t5 and p5 (t is the text being searched).
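With the arguments in the right order (text first, pattern second), the example passes. A sketch with the class copied from the gist; the variable names are swapped relative to the report above so that `t5` holds the text:

```python
class KMP:
    # Copied from the gist above so this block runs on its own.
    def partial(self, pattern):
        ret = [0]
        for i in range(1, len(pattern)):
            j = ret[i - 1]
            while j > 0 and pattern[j] != pattern[i]:
                j = ret[j - 1]
            ret.append(j + 1 if pattern[j] == pattern[i] else j)
        return ret

    def search(self, T, P):
        partial, ret, j = self.partial(P), [], 0
        for i in range(len(T)):
            while j > 0 and T[i] != P[j]:
                j = partial[j - 1]
            if T[i] == P[j]:
                j += 1
            if j == len(P):
                ret.append(i - (j - 1))
                j = partial[j - 1]
        return ret


if __name__ == "__main__":
    kmp = KMP()
    t5 = "ABC ABCDAB ABCDABCDABDE"  # text being searched
    p5 = "ABCDABD"                  # pattern
    assert kmp.search(t5, p5) == [15]
    print("ok")
```

The reversed call, `kmp.search("ABCDABD", "ABC ABCDAB ABCDABCDABDE")`, returns [] rather than raising, because the pattern is longer than the text and j never reaches len(P).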
