Maximum number of subarrays having equal sum

I was asked this question in one of my interviews. Given an array of integers (with both positive and negative values), we need to find the maximum number of disjoint subarrays having equal sum.

Example :

Input : [1, 2, 3] Output : 2 {since we have at most 2 subarrays with sum = 3 i.e. [1, 2],[3]}

Input : [2, 2, 2, -2] Output : 2 {two subarrays each with sum = 2 i.e. [2],[2, 2, -2]}

My Approach

The first approach that came to my mind was to compute the prefix sum array and then, taking each element prefix[i] as the target, count the number of subarrays with sum = prefix[i].

But this fails when the array contains negative numbers. How can the negative cases be handled?

EDIT: The complete array must be covered by these subarrays. That is why in example 2 we get 2 as output and not 3 ([2],[2],[2]).
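To pin down the problem as clarified in the edit, here is a brute-force reference in Python. It is only a sketch for checking small inputs (it is exponential in the array length), it assumes a non-empty array, and the name brute_force_max_parts is made up for illustration.

```python
from itertools import combinations
from typing import List


def brute_force_max_parts(nums: List[int]) -> int:
    """Try every way of cutting nums into contiguous parts that cover the
    whole array, and return the largest number of equal-sum parts."""
    n = len(nums)
    # Try the largest number of parts first, so the first hit is the answer.
    for parts in range(n, 1, -1):
        # Choose parts - 1 cut positions strictly inside the array.
        for cuts in combinations(range(1, n), parts - 1):
            bounds = (0, *cuts, n)
            sums = {sum(nums[lo:hi]) for lo, hi in zip(bounds, bounds[1:])}
            if len(sums) == 1:
                return parts
    # The whole array as a single part always works.
    return 1


print(brute_force_max_parts([1, 2, 3]))
print(brute_force_max_parts([2, 2, 2, -2]))
```

On the two examples from the question this returns 2 in both cases, matching the expected outputs.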



Solution 1:[1]

First, assuming that the groups may consist of non-consecutive elements

There is probably no efficient algorithm for this variant (an efficient one would imply P=NP), so we simply check all possible partitions.

If partitions is a function returning all partitions of a given set, then the problem is solved with:

static Optional<List<List<Integer>>> maximumSubarraysEqSum(List<Integer> xs) {
  return
    // all indexes partitions
    partitions(IntStream.range(0, xs.size()).mapToObj(x -> x).collect(toList()))
    // with all groups with same sum
    .stream().filter(zs -> zs.stream().mapToInt(ys -> ys.stream().mapToInt(xs::get).sum()).distinct().count() == 1L)
    // sorting by size desc
    .sorted((zs, ys) -> Integer.compare(ys.size(), zs.size()))
    // first (or all with same size if you want, ...)
    .findFirst()
    // map indexes to values
    .map(zs -> zs.stream().map(ys -> ys.stream().map(xs::get).collect(toList())).collect(toList()));
}

running

public static void main(String... args) {
    System.out.println("    " + maximumSubarraysEqSum(List.of(1, 2, 3)));
    System.out.println("    " + maximumSubarraysEqSum(List.of(2, 2, 2, -2)));
    System.out.println("    " + maximumSubarraysEqSum(List.of(3, 4)));
}

the output is

Optional[[[1, 2], [3]]]
Optional[[[2, 2, -2], [2]]]
Optional[[[3, 4]]]

Second, assuming they must be consecutive elements

The first subarray determines which sum is possible, so for every possible "first sum" we count the number of groups it yields

static int maxSubarrayEqSum(List<Integer> xs) {
    // possible sizes for the first subarray
    return IntStream.range(1, xs.size() + 1)
            // with that sum count (if exists) how many subarrays match
            .map(sz -> countMaxSubArrayEqSum(0, xs, xs.stream().limit(sz).mapToInt(x -> x).sum()))
            // get the maximum number
            .max()
            .getAsInt();
}

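counting the subarrays for a given sum (passed as the parameter sz) is a backtracking search: each time the running sum reaches the target we try to cut there, and fall back to a longer group if the rest cannot be split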

static int countMaxSubArrayEqSum(int from, List<Integer> xs, int sz) {
    int acc = 0;
    for(int i = from; i < xs.size(); i++) {
        acc += xs.get(i);
        if(acc == sz) {
            if(i == xs.size() - 1)
                return 1;
            int count = countMaxSubArrayEqSum(i + 1, xs, sz);
            if(count > 0)
                return count + 1;
        }
    }
    return 0;
}

Optimal solution for consecutive sequences O(n^2)

Instead of backtracking, we can get the maxSplit value (the maximum number of groups) by checking whether xs can be divided into |xs|, |xs|-1, |xs|-2, ..., 1 groups.

static int maxSplit(int [] xs) {

    // sum all
    int S = 0;
    for(int i = 0; i < xs.length; i++)
        S += xs[i];

    // try every possible number of groups
    for(int i = xs.length; i > 1; i--)
        if(S % i == 0 && maxSplitFit(xs, i, S / i))
            return i;

    return 1;
}

xs can be divided into k groups if a greedy strategy succeeds

static boolean maxSplitFit(int [] xs, int k, int s) {

    int groupsCount = 1;
    int acc = 0;

    // look for (not final) groups summing S/k
    int i = 0;
    while(i < xs.length - 1 && groupsCount < k) {
        acc += xs[i];
        if ((s == 0 && acc == 0) // divide by 0
         || (s != 0 && s * groupsCount == acc)) // S/k
            groupsCount++;
        i++;
    }

    // if there are not enough groups then it is not possible
    if(groupsCount != k)
        return false;

    // the last group must contain all remaining elements
    while(i < xs.length)
        acc += xs[i++];

    // S/k for every group
    return s * k == acc;
}
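For readers following along in Python, the same greedy idea can be sketched as below. This is an illustrative port rather than the answer's own code; the names max_split and can_split are hypothetical, and a non-empty input is assumed.

```python
from typing import List


def can_split(nums: List[int], k: int, target: int) -> bool:
    """Greedily look for the prefix sums target, 2*target, ..., (k-1)*target
    in order. The final element is excluded because a cut cannot be placed
    after the end of the array."""
    found = 0  # number of cut points found so far
    acc = 0    # running prefix sum
    for x in nums[:-1]:
        if found == k - 1:
            break
        acc += x
        if acc == target * (found + 1):
            found += 1
    return found == k - 1


def max_split(nums: List[int]) -> int:
    total = sum(nums)
    # Try the largest number of groups first.
    for k in range(len(nums), 1, -1):
        if total % k == 0 and can_split(nums, k, total // k):
            return k
    return 1


print(max_split([1, 2, 3]))
print(max_split([2, 2, 2, -2]))
```

Taking the earliest position for each cut is safe here because the required prefix sums form a fixed sequence, and matching a subsequence as early as possible never rules out a later match.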

Aside

A possible function to get all partitions could be

static <T> List<List<List<T>>> partitions(List<T> set) {
    // last element
    T element = set.get(set.size() - 1);

    if (set.size() == 1)
        return singletonList(singletonList(singletonList(element)));

    List<List<List<T>>> current = new ArrayList<>();

    // it could be in any of previous groups or in a new one

    // add on every previous
    for (List<List<T>> p : partitions(set.subList(0, set.size() - 1))) {
        // every candidate group to place
        for (int i = 0; i < p.size(); i++) {
            List<List<T>> b = new ArrayList<>(p);
            b.add(i, new ArrayList<>(b.remove(i)));
            b.get(i).add(element);
            current.add(b);
        }
        // add singleton
        List<List<T>> b = new ArrayList<>(p);
        b.add(singletonList(element));
        current.add(b);
    }

    return current;
}

Solution 2:[2]

Update: OP has clarified that the entire array must be partitioned into complete subarrays. As such, the first half of this post constitutes a full answer, but the second half may still be useful as a solution to a generalization of this question.


If we're using the whole array, the problem can be solved in O(n) time, which is optimal. We get to use the fact that splitting A into k subarrays means all subarrays will have a sum of sum(A)/k. After computing the array sum, we can solve the problem in a single pass through the array, using just prefix sums and a hashmap.

If we're not using the whole array, the problem is harder. There's an O(n^2) solution where you compute all subarray sums, sort equal-sum subarrays by right endpoint, and use greedy interval scheduling. I have no idea if this is optimal, but I'm interested to see if there's a better solution.


Full array partition case

The precise problem definition here is to find the longest sequence of cut points [x_0, x_1, ... x_m] with 0 < x_0 < x_1 < ... < x_m < length(A) such that sum(A[0:x_0]) == sum(A[x_0:x_1]) == ... == sum(A[x_m:length(A)]); the answer is then the number of parts, m + 2.

Suppose we've computed the sum of the array to be S. Then we can split A into k parts of equal sum if and only if k divides S and S/k, 2S/k, ..., (k-1)S/k appear in that order as a subsequence of the prefix sums of A (the final prefix sum, kS/k = S, is always present). One easy way to check this is to keep a running sum: if our running sum r divides S, then save 2r in a hashmap as a 'sum we're searching for'. If our current sum is one we've been searching for, save the next number from that subsequence as a 'sum we're searching for', unless we've reached the end of that sequence.

For example, say that S is 32. Then, A can be partitioned into 8 equal sum subarrays if and only if 4, 8, 12, 16, 20, 24, 28 appear as a subsequence of the prefix sums of A (32 will always be present at the end). So as soon as we see 4 as a prefix sum, we check that it divides S, and then save 8 to our search set for prefix sums. We also keep a helper dictionary, mapping 8 to 4, so that after finding 8 as a prefix sum, we know that 8+4 is the next prefix sum to look for.
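The subsequence condition can be checked directly for a single k, which may help when reading the single-pass version. The following is a small sketch under that reading, not the answer's algorithm itself; the helper names and the sample array (chosen so that S is 32) are made up for illustration.

```python
from itertools import accumulate
from typing import List


def needed_prefix_sums(total: int, k: int) -> List[int]:
    """The prefix sums S/k, 2S/k, ..., (k-1)S/k required for a k-way split."""
    step = total // k
    return [step * i for i in range(1, k)]


def splits_into(nums: List[int], k: int) -> bool:
    total = sum(nums)
    if total % k != 0:
        return False
    # Prefix sums before the last element; the last one is always the total.
    prefix = list(accumulate(nums))[:-1]
    it = iter(prefix)
    # 't in it' advances the iterator, so targets must appear in order.
    return all(t in it for t in needed_prefix_sums(total, k))


nums = [4, 4, 3, 1, 4, 4, 4, 8]  # prefix sums: 4, 8, 11, 12, 16, 20, 24, 32
print(needed_prefix_sums(32, 8))
print(splits_into(nums, 8), splits_into(nums, 4), splits_into(nums, 2))
```

Here 28 never occurs as a prefix sum, so an 8-way split fails, while 8, 16, 24 do occur in order, so a 4-way split is possible.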

Python code:

import collections
from typing import List

def best_full_partition(nums: List[int]) -> int:
    """Given a list of integers (positive or negative) 'nums',
    return the maximum number of disjoint equal sum subarrays we can
    partition nums into (using all elements)"""

    total = sum(nums)

    # Special case where total sum is 0
    if total == 0:
        # Count the number of times 0 is a partial sum
        answer = 0
        current_sum = 0
        for x in nums:
            current_sum += x
            if current_sum == 0:
                answer += 1
        return answer

    best_found = 1

    # Prefix sums we're trying to find; all share common factor with total
    looking_for = set()

    """ Map from prefix sums to the common factor/original prefix sum.
        There may be several: e.g. total/6 and total/3 may be targets
        of 2*total/3. """
    looking_to_original_sums = collections.defaultdict(set)

    current_sum = 0

    for x in nums:
        current_sum += x
        if current_sum == 0:
            continue

        if current_sum in looking_for:
            for original_sum in looking_to_original_sums[current_sum]:
                new_target = current_sum + original_sum

                # If we've found all matches in this chain
                if new_target == total:
                    best_found = max(best_found, total // original_sum)
                    continue

                looking_for.add(new_target)
                looking_to_original_sums[new_target].add(original_sum)

            looking_to_original_sums.pop(current_sum)
            looking_for.discard(current_sum)

        # Check if current sum is a divisor of full array sum
        if total % current_sum == 0:
            # If this splits array in half by sum, we've reached its end
            if 2 * current_sum == total:
                best_found = max(best_found, 2)
            else:
                # Add the next multiple of this sum to our search set
                looking_for.add(2 * current_sum)
                looking_to_original_sums[2 * current_sum].add(current_sum)

    return best_found

This takes O(n) time, which is optimal, and O(n) space.


Partial array partition

This case is harder, because there are fewer conditions on what the valid subarray sums can be. The trick is to just compute the sum of all subarrays. We make a hashmap, mapping each sum to the index bounds of its subarray, so sum(A[L, L+1, ... R]) maps to [L, R]. Since there are duplicates, we keep a list of all intervals which produced that sum, and we generate that list so that it is sorted by the right endpoint.

Now, we can use earliest deadline first scheduling, aka greedy scheduling, to find the maximum number of intervals we can take from that list without overlap. Both steps take quadratic time. It may be possible to improve this, but I have no ideas for how to do so.
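The greedy scheduling step on its own looks roughly like this. It is a generic sketch of earliest-deadline-first interval selection, with a hypothetical function name, assuming inclusive, non-negative index bounds.

```python
from typing import List, Tuple


def max_disjoint(intervals: List[Tuple[int, int]]) -> int:
    """Return the maximum number of pairwise non-overlapping intervals.
    Both endpoints are inclusive array indices."""
    count = 0
    last_right = -1  # right end of the most recently chosen interval
    # Always take the interval that ends first among those still compatible.
    for left, right in sorted(intervals, key=lambda iv: iv[1]):
        if left > last_right:
            count += 1
            last_right = right
    return count


# Subarrays of [2, 2, 2, -2] with sum 2, as (left, right) index pairs.
print(max_disjoint([(0, 0), (1, 1), (2, 2), (1, 3)]))
```

For the question's second example this selects the three single-element subarrays, which is the answer of 3 mentioned in the question for the case where the array need not be fully covered.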

Python code:

def best_partial_partition(nums: List[int]) -> int:
    """Given a list of integers (positive or negative) 'nums',
    return the maximum number of disjoint equal sum subarrays we can
    create from nums (using all elements is not required)"""

    n = len(nums)

    best_found = 1

    # For each subarray sum, stores a list of all subarrays with that sum
    # Sorted by right endpoint, both ends inclusive
    sum_to_intervals = collections.defaultdict(list)

    for right_end in range(n):
        curr_sum = 0
        for left_end in reversed(range(right_end+1)):
            curr_sum += nums[left_end]
            sum_to_intervals[curr_sum].append([left_end, right_end])

    # Use greedy interval scheduling to get most intervals

    for interval_list in sum_to_intervals.values():
        # Can skip if we know no improvement is possible here
        if len(interval_list) <= best_found:
            continue

        curr_len = 0
        curr_right_end = -1
        for left, right in interval_list:
            if left > curr_right_end:
                curr_len += 1
                curr_right_end = right

        best_found = max(best_found, curr_len)

    return best_found

This runs in O(n^2) time.

Edit: Fixed a bug in full partition solver when the array sum was 0; thanks to @josejuan for pointing this out. This case needs to be treated separately to avoid dividing by zero.

Solution 3:[3]

I assume the whole array needs to be covered with sub-arrays.

After thinking about it for some time, I am pretty confident that the following O(n^2) algorithm is valid:

public static int maxSplitCount(int[] a) {

    int res = 0; // highest sub-array count found so far

    int sum = 0; // sum within each sub-array
    for (int i = 0; i != a.length; ++i) { // iterate over any possible front sub-array...
        sum += a[i]; // ...and determine its element-sum

        int arrayCount = 1; // count the sub-arrays with this element-sum
        int currentSum = 0; // sum in the currently traversed sub-array
        for (int j = i + 1; j != a.length; ++j) {
            currentSum += a[j];

            // if sum matches, we found the end of the currently traversed sub-array
            if (currentSum == sum) {
                arrayCount += 1;
                currentSum = 0;
            }
        }

        // tricky part (see below)
        if (currentSum > 0 || currentSum % sum != 0) {
            continue;
        }

        arrayCount += currentSum / sum;
        if (arrayCount > res) {
            res = arrayCount;
        }
    }

    return res;
}

Up until the "tricky part", this should be the naive implementation you'd probably come up with on your first attempt. There is no issue with having negative numbers per se. They will simply reduce currentSum, requiring later elements to again increase it appropriately until eventually - by collecting enough elements - currentSum == sum. At some point (beyond the initialization) currentSum may also be 0 or even something negative. This is no problem either. Particularly, currentSum being 0 only indicates that we're currently traversing through a sub-array that itself has a leading 0-sum sub-array.

The only situation in which a problem occurs is when we have remaining elements at the end of the array that together did not form a whole sub-array with element-sum sum, i. e. currentSum is not 0 after the last inner loop iteration. There starts the "tricky part" where we have to distinguish three cases. The remaining elements' sum (i. e. currentSum) is

  1. positive. However, apparently the remaining elements themselves could not be split in a way that yielded sub-arrays with element-sum sum (otherwise that would have happened in the just finished inner for-loop already).
  2. negative with currentSum % sum != 0. Since all elements up to the remaining elements in total have a multiple of sum as the element-sum, the entire array does not have such a multiple as total element-sum.
  3. negative with currentSum % sum == 0.

Each for their individual reason, 1) and 2) are cases in which we cannot have an appropriate partition with the current value of sum at all. In 3) however, we have some negative "overhead" that we can perfectly even-out with some (one or more) of the previously found sub-arrays with element-sum sum each, in total forming a 0-sum sub-array. That sub-array can then be simply seen as appended to the sub-array in front of that (again with element-sum sum), building a joint sub-array that still has element-sum sum.

Let's take your second example and look at when we're in outer iteration i = 0 and have just left the inner for-loop. We have arrayCount = 3 (the three individual [2]-arrays), and currentSum = -2, now we can join the "remaining elements" (only the -2) with the last [2] array to form a 0-sum-array and append that to the second-last [2] array. We have in total lost one array to the merge that we previously counted in arrayCount. This is taken into account by arrayCount += currentSum / sum; (note that at this point currentSum is non-positive so the division is non-positive and += is actually correctly decrementing arrayCount). Note that after that line arrayCount could even be negative (or 0) if the remaining elements had a higher negative sum - like -10 (or -6) - but that only tells us, that the whole array's total element-sum had the opposite sign of the current value of sum (or is 0), which again tells us that there is no appropriate partition. In that case however, arrayCount > res just fails and we disregard that iteration.


If you're particularly attentive, you might have wondered about what happens when in some outer iteration sum = 0. This would currently cause an exception but you can easily add a workaround just before the "tricky part":

if (sum == 0) {
    if (currentSum == 0 && arrayCount > res) {
        res = arrayCount;
    }
    continue;
}

or something equivalent. Indeed, in that case, we are in fact not able to even-out the negative "overhead" with any number of previous sum-sum arrays (since sum is 0), since - just like in case 2) above - currentSum is not a multiple of sum.

And since we're already thinking of sum = 0, what about sum < 0? To also handle that case properly, we only have to change currentSum > 0 to currentSum != 0 && ((currentSum > 0) == (sum > 0)) (see footnote 1), since the rest of the code already is sign-unaware regarding currentSum and sum. Go figure!


1: Certainly, at this point we can simplify some logic expressions and if-statements when also regarding the sum = 0 and sum > 0 case.
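Putting the pieces of this answer together (the base algorithm, the sum = 0 workaround and the sign-aware condition), a Python port might look like the following. This is a sketch of how the described changes could be combined, not code from the answer; the name max_split_count is reused only for readability.

```python
from typing import List


def max_split_count(a: List[int]) -> int:
    best = 0       # highest sub-array count found so far
    front_sum = 0  # element-sum of the candidate front sub-array
    for i in range(len(a)):
        front_sum += a[i]
        count = 1    # sub-arrays found with this element-sum
        current = 0  # sum of the sub-array currently being traversed
        for j in range(i + 1, len(a)):
            current += a[j]
            if current == front_sum:
                count += 1
                current = 0
        # Workaround for front_sum == 0: a leftover cannot be evened out.
        if front_sum == 0:
            if current == 0:
                best = max(best, count)
            continue
        # Sign-aware "tricky part": skip if the leftover has the same sign
        # as front_sum, or is not a multiple of it.
        same_sign = (current > 0) == (front_sum > 0)
        if (current != 0 and same_sign) or current % front_sum != 0:
            continue
        # The leftover has the opposite sign (or is 0), so this merges
        # previously counted sub-arrays and never increases the count.
        count += current // front_sum
        best = max(best, count)
    return best


print(max_split_count([2, 2, 2, -2]))
print(max_split_count([-2, -2, -2, 2]))
```

Python's % and // round differently from Java's for negative operands, but the divisibility test and the exact division used here give the same results in both languages.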

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2
Solution 3 Reizo