
[LeetCode 1442] Count Triplets That Can Form Two Arrays of Equal XOR


The Question

Given an array of integers arr, count the triplets of indices (i, j, k) with 0 <= i < j <= k < arr.length such that arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1] equals arr[j] ^ arr[j + 1] ^ ... ^ arr[k], where ^ denotes bitwise XOR.
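
To make the condition concrete, here is a small sketch that lists the qualifying triplets for one input. The helper name matching_triplets and the sample array are assumptions chosen for illustration; they are not part of the problem statement.

```python
from functools import reduce
from operator import xor
from typing import List, Tuple


def matching_triplets(arr: List[int]) -> List[Tuple[int, int, int]]:
    """List every (i, j, k) with 0 <= i < j <= k < len(arr) whose two halves have equal xor."""
    found = []
    n = len(arr)
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j, n):
                left = reduce(xor, arr[i:j], 0)       # arr[i] ^ ... ^ arr[j - 1]
                right = reduce(xor, arr[j:k + 1], 0)  # arr[j] ^ ... ^ arr[k]
                if left == right:
                    found.append((i, j, k))
    return found


print(matching_triplets([2, 3, 1, 6, 7]))
# [(0, 1, 2), (0, 2, 2), (2, 3, 4), (2, 4, 4)]
```

For this array the answer to the question is therefore 4.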

O(n^4)-time solution [TLE]

Enumerate i, j and k by brute force, and use one more loop to compute the xor of each of the two subarrays.

from typing import List

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        n = len(arr)
        ans = 0
        for i in range(n):
            for j in range(i+1, n):
                for k in range(j, n):
                    a = b = 0
                    for l in range(i, j):
                        a ^= arr[l]
                    for l in range(j, k+1):
                        b ^= arr[l]
                    ans += 1 if a == b else 0
        return ans

O(n^3)-time solution [AC]

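We can optimize the above solution by precomputing all the prefix-xor values, so that the xor of any subarray can be computed in O(1) time. Here prefix[i] = arr[0] ^ ... ^ arr[i], so the xor of arr[i..j-1] is prefix[j - 1] ^ prefix[i - 1]; the code writes prefix[i - 1] as prefix[i] ^ arr[i] to avoid a special case at i = 0.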

from typing import List

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        n = len(arr)
        ans = 0
        prefix = list(arr)
        for i in range(1, n):
            prefix[i] ^= prefix[i - 1]
        for i in range(n):
            for j in range(i+1, n):
                for k in range(j, n):
                    a = prefix[j - 1] ^ prefix[i] ^ arr[i]
                    b = prefix[k] ^ prefix[j - 1]
                    ans += 1 if a == b else 0
        return ans
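
A common variation, sketched here as an assumption rather than as part of the solution above, stores a leading 0 in the prefix array so that no special case is needed at the left edge. The names build_prefix_xor and range_xor are hypothetical helpers.

```python
from typing import List


def build_prefix_xor(arr: List[int]) -> List[int]:
    """Return p with p[0] = 0 and p[t] = arr[0] ^ ... ^ arr[t - 1]."""
    prefix = [0]
    for value in arr:
        prefix.append(prefix[-1] ^ value)
    return prefix


def range_xor(prefix: List[int], lo: int, hi: int) -> int:
    """Xor of arr[lo..hi] (inclusive) in O(1), using the shifted prefix array."""
    return prefix[hi + 1] ^ prefix[lo]


arr = [2, 3, 1, 6, 7]
prefix = build_prefix_xor(arr)
print(prefix)                   # [0, 2, 1, 0, 6, 1]
print(range_xor(prefix, 2, 4))  # 1 ^ 6 ^ 7 = 0
```

With this layout the two values in the inner loop would be range_xor(prefix, i, j - 1) and range_xor(prefix, j, k).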

O(n^2)-time solution [AC]

Instead of nesting for-loops over all of i, j and k, we can enumerate j first to fix the boundary between the two subarrays. For each j, we enumerate every left boundary i of the left subarray and count how often each xor value occurs, using a dictionary. Then we enumerate every right boundary k of the right subarray and add the number of left subarrays whose xor equals the xor of arr[j..k].

from collections import Counter
from typing import List

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        n = len(arr)
        ans = 0
        for j in range(1, n):
            xor = 0
            counter = Counter()
            for i in range(j - 1, -1, -1):
                xor ^= arr[i]
                counter[xor] += 1
            xor = 0
            for k in range(j, n):
                xor ^= arr[k]
                ans += counter[xor]
        return ans

Yet another O(n^2)-time solution [AC]

The fact that arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1] is equal to arr[j] ^ arr[j + 1] ^ ... ^ arr[k] implies that arr[i] ^ ... ^ arr[k] = 0. Conversely, if arr[i] ^ ... ^ arr[k] = 0, then every split point j with i < j <= k gives two halves with equal xor, so the pair (i, k) produces exactly k - i distinct triplets. Therefore, we can enumerate all the subarrays (of length >= 2) whose xor is 0 and add k - i for each.

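from typing import List

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        n = len(arr)
        ans = 0
        for i in range(n):
            # xor holds arr[i] ^ ... ^ arr[k] as k advances.
            xor = arr[i]
            for k in range(i + 1, n):
                xor ^= arr[k]
                if xor == 0:
                    # Any j in (i, k] is a valid split point.
                    ans += k - i
        return ans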
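
The counting argument above can be checked on a small case. The sketch below takes one zero-xor subarray and shows that every split point yields equal halves; the sample array and the choice of i and k are assumptions made for illustration.

```python
from functools import reduce
from operator import xor

arr = [2, 3, 1, 6, 7]
i, k = 0, 2  # arr[0] ^ arr[1] ^ arr[2] == 0

splits = []
for j in range(i + 1, k + 1):
    left = reduce(xor, arr[i:j], 0)       # arr[i] ^ ... ^ arr[j - 1]
    right = reduce(xor, arr[j:k + 1], 0)  # arr[j] ^ ... ^ arr[k]
    splits.append((j, left, right))

print(splits)  # [(1, 2, 2), (2, 1, 1)]
```

Since left ^ right is the xor of the whole subarray, it is 0 for every j, which forces left == right. The subarray contributes len(splits) == k - i triplets.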

O(n)-time solution [AC, Optimal]

We can further optimize the runtime of the above solution to O(n). Note that arr[i] ^ ... ^ arr[k] = 0 is equivalent to arr[0] ^ ... ^ arr[i-1] = arr[0] ^ ... ^ arr[k], where the left-hand side is the empty xor (0) when i = 0. Writing s = i - 1, each such pair contributes k - i = k - s - 1 triplets, so the number of triplets for a fixed k is sum_s(k - s - 1) over all -1 <= s < k such that arr[0] ^ ... ^ arr[s] = arr[0] ^ ... ^ arr[k], with s = -1 standing for the empty prefix. We can rearrange sum_s(k - s - 1) as cnt * (k - 1) - sum_s(s), where cnt is the number of such s. It follows that we just need to maintain (using two dictionaries) two values, the count and the index sum, for each distinct prefix xor while doing the prefix scan.

from collections import Counter
from typing import List

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        ans = 0
        # Seed the empty prefix: xor = 0 at index -1, with cnt = 1.
        xor_index_sum = Counter({0:-1})
        xor_cnt = Counter({0:1})
        xor = 0
        for i in range(len(arr)):
            xor = xor ^ arr[i]
            if xor_cnt[xor] > 0:
                ans += xor_cnt[xor] * (i - 1) - xor_index_sum[xor]
            xor_index_sum[xor] += i
            xor_cnt[xor] += 1
        return ans
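
One way to gain confidence in the linear scan is to compare it against the quadratic zero-xor count on random inputs. The sketch below does that with standalone functions; the names count_quadratic and count_linear, the seed, and the input sizes are assumptions chosen for this check.

```python
import random
from collections import defaultdict
from typing import List


def count_quadratic(arr: List[int]) -> int:
    """Add k - i for every subarray arr[i..k] of length >= 2 whose xor is 0."""
    total = 0
    for i in range(len(arr)):
        acc = arr[i]
        for k in range(i + 1, len(arr)):
            acc ^= arr[k]
            if acc == 0:
                total += k - i
    return total


def count_linear(arr: List[int]) -> int:
    """Same count in one pass, tracking count and index sum per prefix xor."""
    count = defaultdict(int)
    index_sum = defaultdict(int)
    count[0] = 1       # the empty prefix ...
    index_sum[0] = -1  # ... sits at index -1
    total = 0
    acc = 0
    for k, value in enumerate(arr):
        acc ^= value
        # sum over earlier s with equal prefix xor of (k - s - 1)
        total += count[acc] * (k - 1) - index_sum[acc]
        count[acc] += 1
        index_sum[acc] += k
    return total


random.seed(0)
for _ in range(200):
    sample = [random.randint(0, 7) for _ in range(random.randint(1, 12))]
    assert count_linear(sample) == count_quadratic(sample)

print(count_linear([2, 3, 1, 6, 7]))  # 4
```

Small values (0 to 7) are used so that equal prefix xors occur often and the interesting branch is exercised.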
