@Ifihan
Created April 15, 2025 22:38
Count Good Triplets in an Array

Question

Approach

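I first mapped every number in nums2 to its index, so I could look up the position of any value in O(1). Then I built a new array, pos, by replacing each value in nums1 with its index in nums2. Because pos lists the elements in nums1 order, three values appear in the same relative order in both arrays exactly when their entries in pos are increasing. This turns the problem into counting increasing triplets in pos: index triples i < j < k with pos[i] < pos[j] < pos[k].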
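
A minimal sketch of this transformation, with an O(n^3) brute force that compares the literal definition of a good triplet against the count of increasing triplets in pos. The helpers `to_positions`, `brute_force_good` and `brute_force_increasing` are hypothetical names for illustration, and the brute force is only practical for tiny inputs. It assumes, as the problem does, that nums1 and nums2 are both permutations of 0..n-1.

```python
from itertools import combinations, permutations
from typing import List


def to_positions(nums1: List[int], nums2: List[int]) -> List[int]:
    # pos_in_nums2[v] is the index of value v in nums2
    pos_in_nums2 = [0] * len(nums2)
    for i, val in enumerate(nums2):
        pos_in_nums2[val] = i
    # Replace each value of nums1 with its index in nums2
    return [pos_in_nums2[val] for val in nums1]


def brute_force_good(nums1: List[int], nums2: List[int]) -> int:
    # Literal definition: ordered value triples (x, y, z) whose indices
    # increase in both nums1 and nums2
    idx1 = {v: i for i, v in enumerate(nums1)}
    idx2 = {v: i for i, v in enumerate(nums2)}
    return sum(
        1
        for x, y, z in permutations(range(len(nums1)), 3)
        if idx1[x] < idx1[y] < idx1[z] and idx2[x] < idx2[y] < idx2[z]
    )


def brute_force_increasing(pos: List[int]) -> int:
    # Index triples i < j < k with pos[i] < pos[j] < pos[k]
    return sum(1 for i, j, k in combinations(range(len(pos)), 3) if pos[i] < pos[j] < pos[k])


nums1, nums2 = [4, 0, 1, 3, 2], [4, 1, 0, 2, 3]
pos = to_positions(nums1, nums2)
print(pos)                             # [0, 2, 1, 4, 3]
print(brute_force_good(nums1, nums2))  # 4
print(brute_force_increasing(pos))     # 4
```

Both counts agree, which is the equivalence the approach relies on; only the second form is amenable to a fast counting structure.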

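Next, I needed an efficient way to count, for each index, how many smaller values appear before it in pos and how many larger values appear after it. Counting these directly takes O(n^2) time, so I used a Fenwick Tree (also known as a Binary Indexed Tree), which supports point updates and prefix-sum queries in O(log n) each.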
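
To make the `index & -index` steps in the implementation concrete, here is a small sketch of how a 1-based Fenwick tree assigns ranges to cells and which cells a query or an update visits. The helper names `lowbit`, `covered_range`, `query_path` and `update_path` are hypothetical and exist only for illustration; the gist's `FenwickTree` adds 1 to each 0-based position, so position p lives in cell p + 1.

```python
from typing import List, Tuple


def lowbit(i: int) -> int:
    # Lowest set bit of i via two's complement, e.g. 6 (0b110) -> 2
    return i & -i


def covered_range(i: int) -> Tuple[int, int]:
    # In a 1-based Fenwick tree, cell i holds the sum of positions i - lowbit(i) + 1 .. i
    return (i - lowbit(i) + 1, i)


def query_path(i: int) -> List[int]:
    # Cells added up for the prefix sum of positions 1..i: strip the lowest set bit each step
    path = []
    while i > 0:
        path.append(i)
        i -= lowbit(i)
    return path


def update_path(i: int, size: int) -> List[int]:
    # Cells changed when position i (i >= 1) is updated: add the lowest set bit each step
    path = []
    while i <= size:
        path.append(i)
        i += lowbit(i)
    return path


print(query_path(7))                              # [7, 6, 4]
print([covered_range(c) for c in query_path(7)])  # [(7, 7), (5, 6), (1, 4)]
print(update_path(5, 8))                          # [5, 6, 8]
```

The query path for 7 covers 1..7 with three disjoint ranges, and each path visits at most about log2(size) + 1 cells, which is where the O(log n) cost per operation comes from.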

I scanned pos from left to right, querying the tree for the number of smaller values seen so far before inserting each position, and then scanned from right to left to count the greater values after each position in the same way. Finally, treating each index as the middle of a triplet, I multiplied the number of smaller values before it by the number of greater values after it. Every pairing of one such left element with one such right element forms exactly one good triplet, so summing these products over all indices gives the total number of good triplets.
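
The quantity being summed can be written out with plain O(n^2) loops, which is handy as a reference for what the two Fenwick passes produce. `counts_naive` and `count_by_middle` are hypothetical checking helpers, not part of the gist's solution, and they assume pos has already been built.

```python
from typing import List, Tuple


def counts_naive(pos: List[int]) -> Tuple[List[int], List[int]]:
    # O(n^2) reference for what the two Fenwick passes compute
    n = len(pos)
    left_smaller = [sum(1 for j in range(i) if pos[j] < pos[i]) for i in range(n)]
    right_greater = [sum(1 for k in range(i + 1, n) if pos[k] > pos[i]) for i in range(n)]
    return left_smaller, right_greater


def count_by_middle(pos: List[int]) -> int:
    # Index i is the middle of left_smaller[i] * right_greater[i] increasing triplets
    left_smaller, right_greater = counts_naive(pos)
    return sum(left * right for left, right in zip(left_smaller, right_greater))


pos = [0, 2, 1, 4, 3]
print(counts_naive(pos))     # ([0, 1, 1, 3, 3], [4, 2, 2, 0, 0])
print(count_by_middle(pos))  # 4
```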

Implementation

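from typing import List


class FenwickTree:
    """Binary Indexed Tree over 0-based positions 0..size-1, storing counts."""

    def __init__(self, size):
        # Internal array is 1-based; cell 0 is unused.
        self.tree = [0] * (size + 2)

    def update(self, index, value):
        # Add `value` at 0-based position `index`.
        index += 1
        while index < len(self.tree):
            self.tree[index] += value
            index += index & -index  # move to the next cell whose range covers this position

    def query(self, index):
        # Sum of values at positions 0..index (returns 0 when index == -1).
        index += 1
        result = 0
        while index:
            result += self.tree[index]
            index -= index & -index  # strip the lowest set bit
        return result

    def query_range(self, left, right):
        # Sum of values at positions left..right, inclusive.
        return self.query(right) - self.query(left - 1)


class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        # pos_in_nums2[v] = index of value v in nums2 (both arrays are permutations of 0..n-1).
        pos_in_nums2 = [0] * n
        for i, val in enumerate(nums2):
            pos_in_nums2[val] = i

        # Rewrite nums1 as positions in nums2; good triplets become increasing triplets of pos.
        pos = [pos_in_nums2[val] for val in nums1]

        # Left-to-right pass: count earlier entries smaller than pos[i].
        left_tree = FenwickTree(n)
        left_smaller = [0] * n
        for i in range(n):
            left_smaller[i] = left_tree.query(pos[i] - 1)
            left_tree.update(pos[i], 1)

        # Right-to-left pass: count later entries greater than pos[i].
        right_tree = FenwickTree(n)
        right_greater = [0] * n
        for i in reversed(range(n)):
            right_greater[i] = right_tree.query(n - 1) - right_tree.query(pos[i])
            right_tree.update(pos[i], 1)

        # Index i is the middle of left_smaller[i] * right_greater[i] good triplets.
        result = sum(left_smaller[i] * right_greater[i] for i in range(n))
        return result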
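
As an alternative not used in the gist, the right-to-left pass can be derived from the left counts because pos is a permutation of 0..n-1: of the n - 1 - pos[i] values greater than pos[i], exactly i - left_smaller[i] already appear before index i, so right_greater[i] = (n - 1 - pos[i]) - (i - left_smaller[i]). A sketch of that single-pass variant with an inline Fenwick array, cross-checked against a brute-force count; `good_triplets_one_pass` and `brute_force` are hypothetical names.

```python
from itertools import combinations, permutations
from typing import List


def good_triplets_one_pass(nums1: List[int], nums2: List[int]) -> int:
    n = len(nums1)
    pos_in_nums2 = [0] * n
    for i, val in enumerate(nums2):
        pos_in_nums2[val] = i

    tree = [0] * (n + 1)  # 1-based Fenwick array; position q is stored in cell q + 1
    total = 0
    for i, val in enumerate(nums1):
        p = pos_in_nums2[val]
        # Prefix query over cells 1..p = positions 0..p-1 seen so far
        smaller, j = 0, p
        while j > 0:
            smaller += tree[j]
            j -= j & -j
        # All values greater than p, minus the greater ones already on the left
        greater = (n - 1 - p) - (i - smaller)
        total += smaller * greater
        # Record position p
        j = p + 1
        while j <= n:
            tree[j] += 1
            j += j & -j
    return total


def brute_force(nums1: List[int], nums2: List[int]) -> int:
    # combinations keeps nums1 order, so only the nums2 order needs checking
    idx2 = {v: i for i, v in enumerate(nums2)}
    return sum(1 for a, b, c in combinations(nums1, 3) if idx2[a] < idx2[b] < idx2[c])


# Cross-check on every pair of permutations of 0..4
for a in permutations(range(5)):
    for b in permutations(range(5)):
        assert good_triplets_one_pass(list(a), list(b)) == brute_force(list(a), list(b))

print(good_triplets_one_pass([2, 0, 1, 3], [0, 1, 2, 3]))        # 1
print(good_triplets_one_pass([4, 0, 1, 3, 2], [4, 1, 0, 2, 3]))  # 4
```

The asymptotic cost is unchanged; the variant simply saves the second tree and the second scan.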

Complexities

  • Time: O(n log n)
  • Space: O(n)
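
A rough accounting behind these bounds, assuming each Fenwick query or update visits O(log n) cells. The left scan does one query and one update per index; the right scan does two queries and one update per index.

```latex
\begin{aligned}
T(n) &= \underbrace{O(n)}_{\text{build pos\_in\_nums2, pos}}
      + \underbrace{n\,(T_q + T_u)}_{\text{left scan}}
      + \underbrace{n\,(2T_q + T_u)}_{\text{right scan}}
      + \underbrace{O(n)}_{\text{final sum}} \\
     &= O(n) + 5n \cdot O(\log n) = O(n \log n), \\
S(n) &= \underbrace{4n}_{\text{pos\_in\_nums2, pos, left\_smaller, right\_greater}}
      + \underbrace{2(n + 2)}_{\text{two trees}} = O(n).
\end{aligned}
```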