
Problem Statement

LeetCode 2044: Count Number of Maximum Bitwise-OR Subsets

Intuition

The problem asks for the number of non-empty subsets of an array whose bitwise OR equals the maximum bitwise OR any subset can reach. Because OR can only set bits and never clears them, that maximum is simply the OR of the entire array. The first intuition is therefore to compute the OR of the full array and then search for all subsets that match this value.
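As a small worked example of this idea, the sketch below lists every non-empty subset of a two-element array together with its OR. The helper name subset_ors is only illustrative and is not part of the solution that follows.

```python
from functools import reduce
from itertools import combinations
from operator import or_


def subset_ors(nums):
    """Return (subset, OR) pairs for every non-empty subset, chosen by position."""
    pairs = []
    for size in range(1, len(nums) + 1):
        for subset in combinations(nums, size):
            pairs.append((subset, reduce(or_, subset)))
    return pairs


nums = [3, 1]
target = reduce(or_, nums)  # 3 | 1 == 3
matches = [s for s, value in subset_ors(nums) if value == target]
print(target, matches)  # 3 [(3,), (3, 1)]
```

Two subsets reach the target here, so the answer for [3, 1] is 2.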

Approach

  1. Calculate Maximum Bitwise OR:
    First, we calculate the maximum bitwise OR (max_bitwise_or) of the entire array. This will serve as the benchmark for the subsets. The calc helper function takes an array and computes the cumulative bitwise OR for its elements.

  2. Subset Exploration:
    We use a backtracking approach to explore all possible subsets of the given array. For each subset, we compute its bitwise OR and compare it with the max_bitwise_or. If they match, we increment a result counter (res).

  3. Backtracking Algorithm:
    The recursive function find_subsets(i, curr) receives the subset built so far in curr. It first checks whether the OR of curr equals max_bitwise_or and increments the counter if so. It then appends each element nums[j] with j >= i in turn, recurses with j + 1, and pops the element again, so every subset is visited exactly once.

  4. Edge Case Handling:
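    No special handling is needed: subsets whose OR differs from max_bitwise_or are simply not counted. The empty subset has OR 0, so it is counted only when max_bitwise_or is 0, which can happen only for an empty or all-zero array.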
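The same steps can also be sketched without recursion by enumerating subsets as bitmasks. This is a different enumeration technique from the backtracking described above, shown only for comparison; the function name count_max_or_subsets_bitmask is hypothetical.

```python
from typing import List


def count_max_or_subsets_bitmask(nums: List[int]) -> int:
    n = len(nums)

    # Step 1: the maximum OR is the OR of the whole array.
    max_or = 0
    for num in nums:
        max_or |= num

    # Steps 2 and 3: each mask in [1, 2^n) selects one non-empty subset;
    # bit i of the mask decides whether nums[i] is included.
    count = 0
    for mask in range(1, 1 << n):
        current = 0
        for i in range(n):
            if mask & (1 << i):
                current |= nums[i]
        if current == max_or:
            count += 1
    return count


print(count_max_or_subsets_bitmask([3, 2, 1, 5]))  # 6
```

Starting the mask at 1 skips the empty subset, so this variant counts non-empty subsets only, even for an all-zero array.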

Complexity

  • Time complexity:
    The algorithm generates all subsets of the array, and for each subset, it calculates the bitwise OR. The total number of subsets is \(2^n\), and for each subset, calculating the bitwise OR takes \(O(n)\) time in the worst case.
    Therefore, the overall time complexity is:
    \(O(n \cdot 2^n)\)

  • Space complexity:
    The space complexity is determined by the recursive call stack depth and the space needed for the current subset. In the worst case, the depth of recursion and the size of the current subset can be \(O(n)\), so the space complexity is:
    \(O(n)\)
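To make the time bound concrete, the sketch below mirrors the shape of the backtracking and counts how many elements calc would scan across all calls. It only tracks subset sizes, not the elements themselves, and the name count_or_steps is hypothetical. The total comes out to exactly \(n \cdot 2^{n-1}\), which is within the \(O(n \cdot 2^n)\) bound.

```python
def count_or_steps(n: int) -> int:
    """Count the elements scanned by calc over all subsets of an n-element array."""
    steps = 0

    def explore(start: int, size: int) -> None:
        nonlocal steps
        steps += size  # calc(curr) scans `size` elements for this subset
        for j in range(start, n):
            explore(j + 1, size + 1)

    explore(0, 0)
    return steps


for n in (4, 8, 16):
    assert count_or_steps(n) == n * 2 ** (n - 1)

print(count_or_steps(16))  # 524288
```

Carrying a running OR through the recursion instead of recomputing it removes the factor of \(n\), which is what the recursive editorial approach does.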

Code

from typing import List


class Solution:
    def countMaxOrSubsets(self, nums: List[int]) -> int:
        def calc(arr):
            ret = 0
            for num in arr:
                ret |= num
            return ret

        N = len(nums)
        res = 0
        max_bitwise_or = calc(nums)

        def find_subsets(i, curr):
            nonlocal res
            # Each call sees exactly one subset (curr). No base case is needed:
            # the loop below does nothing once i == N.
            if calc(curr) == max_bitwise_or:
                res += 1

            for j in range(i, N):
                curr.append(nums[j])
                find_subsets(j + 1, curr)
                curr.pop()

        find_subsets(0, [])
        return res

Editorial

Approach 1: Recursion

from typing import List


class Solution:
    def countMaxOrSubsets(self, nums: List[int]) -> int:
        max_or_value = 0
        for num in nums:
            max_or_value |= num
        return self._count_subsets(nums, 0, 0, max_or_value)

    def _count_subsets(
        self, nums: List[int], index: int, current_or: int, target_or: int
    ) -> int:
        # Base case: reached the end of the array
        if index == len(nums):
            return 1 if current_or == target_or else 0

        # Don't include the current number
        count_without = self._count_subsets(
            nums, index + 1, current_or, target_or
        )

        # Include the current number
        count_with = self._count_subsets(
            nums, index + 1, current_or | nums[index], target_or
        )

        # Return the sum of both cases
        return count_without + count_with

Approach 2: Memoization

from typing import List


class Solution:
    def countMaxOrSubsets(self, nums: List[int]) -> int:
        max_or_value = 0
        n = len(nums)

        # Calculate the maximum OR value
        for num in nums:
            max_or_value |= num

        # Initialize memo with -1
        memo = [[-1] * (max_or_value + 1) for _ in range(n)]

        return self._count_subsets_recursive(nums, 0, 0, max_or_value, memo)

    def _count_subsets_recursive(
        self,
        nums: List[int],
        index: int,
        current_or: int,
        target_or: int,
        memo: List[List[int]],
    ) -> int:
        # Base case: reached the end of the array
        if index == len(nums):
            return 1 if current_or == target_or else 0

        # Check if the result for this state is already memoized
        if memo[index][current_or] != -1:
            return memo[index][current_or]

        # Don't include the current number
        count_without = self._count_subsets_recursive(
            nums, index + 1, current_or, target_or, memo
        )

        # Include the current number
        count_with = self._count_subsets_recursive(
            nums, index + 1, current_or | nums[index], target_or, memo
        )

        # Memoize and return the result
        memo[index][current_or] = count_without + count_with
        return memo[index][current_or]
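
The memo table above is indexed by (index, current_or). A related idea, sketched here as an assumption and not one of the approaches listed above, is to iterate over the array once while keeping a dictionary that maps each reachable OR value to the number of subsets producing it. The function name count_max_or_subsets_dp is hypothetical.

```python
from collections import defaultdict
from typing import List


def count_max_or_subsets_dp(nums: List[int]) -> int:
    # or_counts[v] = number of subsets (including the empty one) whose OR is v.
    or_counts = defaultdict(int)
    or_counts[0] = 1  # the empty subset

    target = 0
    for num in nums:
        target |= num
        # Iterate over a snapshot so that updates made for `num`
        # are not themselves extended with `num` again.
        for value, count in list(or_counts.items()):
            or_counts[value | num] += count

    # The empty subset has OR 0; exclude it if the target is also 0.
    return or_counts[target] - (1 if target == 0 else 0)


print(count_max_or_subsets_dp([3, 2, 1, 5]))  # 6
```

Only OR values that are actually reachable get an entry, so the dictionary never grows beyond max_or_value + 1 keys.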