Maximum number of subarrays having equal sum
I was asked this question in one of my interviews. Given an array of integers (with both positive and negative values), we need to find the maximum number of disjoint subarrays having equal sum.
Examples:
Input: [1, 2, 3] Output: 2 {since we have at most 2 subarrays with sum = 3, i.e. [1, 2], [3]}
Input: [2, 2, 2, -2] Output: 2 {two subarrays each with sum = 2, i.e. [2], [2, 2, -2]}
My Approach
The first approach that came to my mind was to compute the prefix sum array and then, taking each element prefix[i] as the target, count the number of subarrays with sum = prefix[i].
But this failed when the array contains negative numbers. How should the negative cases be handled?
EDIT: The complete array must be covered by these subarrays. That is why in example 2 we get 2 as output and not 3 ([2], [2], [2]).
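To pin down the clarified statement, a brute-force reference can try every set of cut positions and keep the largest split whose pieces all have the same sum. This is only a sketch for checking small inputs (it is exponential in the array length), and the name brute_force_max_parts is made up here for illustration.

```python
from itertools import combinations


def brute_force_max_parts(nums):
    """Largest number of contiguous pieces, covering all of nums,
    whose sums are all equal. Exponential; for small inputs only."""
    n = len(nums)
    # A non-empty array can always be taken as a single piece.
    best = 1 if n else 0
    for k in range(1, n):  # k = number of cuts
        for cuts in combinations(range(1, n), k):
            bounds = (0, *cuts, n)
            # Sum of each piece nums[a:b] between consecutive bounds.
            sums = {sum(nums[a:b]) for a, b in zip(bounds, bounds[1:])}
            if len(sums) == 1:
                best = max(best, k + 1)
    return best


print(brute_force_max_parts([1, 2, 3]))      # [1, 2], [3]
print(brute_force_max_parts([2, 2, 2, -2]))  # [2], [2, 2, -2]
```

Both examples from the question come out as 2 with this reference, which matches the expected outputs.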
Solution 1:[1]
First, assuming that they can be non-consecutive elements
There is probably no efficient algorithm for this variant (one would imply P = NP), so we simply check all possible partitions.
If partitions is a function returning all partitions of a given set, then your problem is solved with:
    // Assumes: import java.util.*; import java.util.stream.IntStream;
    //          import static java.util.stream.Collectors.toList;
    static Optional<List<List<Integer>>> maximumSubarraysEqSum(List<Integer> xs) {
        return
            // all partitions of the index set
            partitions(IntStream.range(0, xs.size()).mapToObj(x -> x).collect(toList()))
                // keep those whose groups all have the same sum
                .stream().filter(zs -> zs.stream().mapToInt(ys -> ys.stream().mapToInt(xs::get).sum()).distinct().count() == 1L)
                // sort by number of groups, descending
                .sorted((zs, ys) -> Integer.compare(ys.size(), zs.size()))
                // first (or all with the same size if you want, ...)
                .findFirst()
                // map indexes back to values
                .map(zs -> zs.stream().map(ys -> ys.stream().map(xs::get).collect(toList())).collect(toList()));
    }
running
    public static void main(String... args) {
        System.out.println(" " + maximumSubarraysEqSum(List.of(1, 2, 3)));
        System.out.println(" " + maximumSubarraysEqSum(List.of(2, 2, 2, -2)));
        System.out.println(" " + maximumSubarraysEqSum(List.of(3, 4)));
    }
the output is
Optional[[[1, 2], [3]]]
Optional[[[2, 2, -2], [2]]]
Optional[[[3, 4]]]
Second, assuming they must be consecutive elements
The first subarray determines which sum every other subarray must have, so for every possible "first sum" we count the number of groups
    static int maxSubarrayEqSum(List<Integer> xs) {
        // possible sizes for the first subarray
        return IntStream.range(1, xs.size() + 1)
            // with that sum count (if exists) how many subarrays match
            .map(sz -> countMaxSubArrayEqSum(0, xs, xs.stream().limit(sz).mapToInt(x -> x).sum()))
            // get the maximum number
            .max()
            .getAsInt();
    }
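To count the subarrays for a given target sum, cut at the first position where the running sum reaches the target and recurse on the rest; if the rest cannot be split, keep extending the current subarray
    static int countMaxSubArrayEqSum(int from, List<Integer> xs, int sz) {
        // note: sz is the target sum of each subarray, not a size
        int acc = 0;
        for (int i = from; i < xs.size(); i++) {
            acc += xs.get(i);
            if (acc == sz) {
                // a subarray ending at the last element completes a valid split
                if (i == xs.size() - 1)
                    return 1;
                int count = countMaxSubArrayEqSum(i + 1, xs, sz);
                if (count > 0)
                    return count + 1;
            }
        }
        // no valid split of xs[from..] with this target sum
        return 0;
    }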
Optimal solution for consecutive sequences O(n^2)
Instead of backtracking, we can get the maxSplit value (the maximum number of groups) by checking whether xs can be divided into |xs|, |xs|-1, |xs|-2, ..., 1 groups.
    static int maxSplit(int[] xs) {
        // sum all
        int S = 0;
        for (int i = 0; i < xs.length; i++)
            S += xs[i];
        // try every possible number of groups
        for (int i = xs.length; i > 1; i--)
            if (S % i == 0 && maxSplitFit(xs, i, S / i))
                return i;
        return 1;
    }
xs can be divided into k groups if a greedy strategy succeeds
    static boolean maxSplitFit(int[] xs, int k, int s) {
        int groupsCount = 1;
        int acc = 0;
        // look for (not final) groups summing S/k
        int i = 0;
        while (i < xs.length - 1 && groupsCount < k) {
            acc += xs[i];
            if ((s == 0 && acc == 0) // target sum is 0: every zero prefix closes a group
                    || (s != 0 && s * groupsCount == acc)) // prefix reached groupsCount * S/k
                groupsCount++;
            i++;
        }
        // if there are not enough groups then it is not possible
        if (groupsCount != k)
            return false;
        // the last group must contain all remaining elements
        while (i < xs.length)
            acc += xs[i++];
        // S/k for every group
        return s * k == acc;
    }
Aside
A possible function to get all partitions could be
    // Assumes: import static java.util.Collections.singletonList;
    static <T> List<List<List<T>>> partitions(List<T> set) {
        // last element
        T element = set.get(set.size() - 1);
        if (set.size() == 1)
            return singletonList(singletonList(singletonList(element)));
        List<List<List<T>>> current = new ArrayList<>();
        // it could be in any of previous groups or in a new one
        // add on every previous
        for (List<List<T>> p : partitions(set.subList(0, set.size() - 1))) {
            // every candidate group to place
            for (int i = 0; i < p.size(); i++) {
                List<List<T>> b = new ArrayList<>(p);
                b.add(i, new ArrayList<>(b.remove(i)));
                b.get(i).add(element);
                current.add(b);
            }
            // add singleton
            List<List<T>> b = new ArrayList<>(p);
            b.add(singletonList(element));
            current.add(b);
        }
        return current;
    }
Solution 2:[2]
Update: OP has clarified that the entire array must be partitioned into complete subarrays. As such, the first half of this post constitutes a full answer, but the second half may still be useful as a solution to a generalization of this question.
If we're using the whole array, the problem can be solved in O(n) time, which is optimal. We get to use the fact that splitting A into k subarrays means all subarrays will have a sum of sum(A)/k. After computing the array sum, we can solve the problem in a single pass through the array, using just prefix sums and a hashmap.
If we're not using the whole array, the problem is harder. There's an O(n^2) solution where you compute all subarray sums, sort equal-sum subarrays by right endpoint, and use greedy interval scheduling. I have no idea if this is optimal, but I'm interested to see if there's a better solution.
Full array partition case
The precise problem definition here is to find the longest sequence of cut points [x_0, x_1, ... x_m] with 0 < x_0 < x_1 < ... < x_m < length(A) such that sum(A[0:x_0]) == sum(A[x_0:x_1]) == ... == sum(A[x_m:length(A)]). The number of subarrays is then one more than the number of cut points.
Suppose we've computed the sum of the array to be S. Then, we can split A into k parts of equal sum if and only if k divides S and S/k, 2S/k, ... , kS/k appear, in that order, as a subsequence of the prefix sums of A. One easy way to check that for all k at once is to keep a running sum: if our running sum r divides S, then save 2r in a hashmap as a 'sum we're searching for'. If our current sum is one we've been searching for, save the next number from that subsequence as a 'sum we're searching for', unless we've reached the end of that sequence.
For example, say that S is 32. Then, A can be partitioned into 8 equal sum subarrays if and only if 4, 8, 12, 16, 20, 24, 28 appears as a subsequence of the prefix sums of A (32 will always be present at the end). So as soon as we see 4 as a prefix sum, we check that it divides S, and then save 8 to our search set for prefix sums. We also keep a helper dictionary, mapping 8 to 4, so that after finding 8 as a prefix sum, we know that 8+4 is the next prefix sum to look for.
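The subsequence condition is perhaps easiest to see for a single fixed k. The sketch below checks only that one k by scanning the prefix sums for S/k, 2S/k, ... in order; it is not the single-pass hashmap method described above, and the name can_split is an assumption made for illustration.

```python
from itertools import accumulate


def can_split(nums, k):
    """True if nums can be cut into exactly k contiguous parts of equal sum."""
    total = sum(nums)
    if k <= 0 or total % k != 0:
        return False
    part = total // k
    found = 0  # how many of the targets part, 2*part, ... are matched so far
    prefix = list(accumulate(nums))
    # The last prefix sum is always `total`, so only the first k-1 targets
    # need to be located, in order, strictly before the final position.
    for p in prefix[:-1]:
        if found < k - 1 and p == part * (found + 1):
            found += 1
    return found == k - 1


print(can_split([4] * 8, 8))        # prefix sums 4, 8, ..., 28 all appear
print(can_split([2, 2, 2, -2], 2))  # prefix sum 2 appears before the end
print(can_split([2, 2, 2, -2], 4))  # would need prefix sums 1, 2, 3
```

Calling this for every k from len(nums) down to 1 gives the answer in quadratic time; the hashmap version above avoids that by tracking all candidate chains in one pass.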
Python code:
    import collections
    from typing import List

    def best_full_partition(nums: List[int]) -> int:
        """Given a list of integers (positive or negative) 'nums',
        return the maximum number of disjoint equal sum subarrays we can
        partition nums into (using all elements)"""
        total = sum(nums)
        # Special case where total sum is 0
        if total == 0:
            # Count the number of times 0 is a partial sum
            answer = 0
            current_sum = 0
            for x in nums:
                current_sum += x
                if current_sum == 0:
                    answer += 1
            return answer
        best_found = 1
        # Prefix sums we're trying to find; all share common factor with total
        looking_for = set()
        # Map from prefix sums to the common factor/original prefix sum.
        # There may be several: e.g. total/6 and total/3 may be targets
        # of 2*total/3.
        looking_to_original_sums = collections.defaultdict(set)
        current_sum = 0
        for x in nums:
            current_sum += x
            if current_sum == 0:
                continue
            if current_sum in looking_for:
                for original_sum in looking_to_original_sums[current_sum]:
                    new_target = current_sum + original_sum
                    # If we've found all matches in this chain
                    if new_target == total:
                        best_found = max(best_found, total // original_sum)
                        continue
                    looking_for.add(new_target)
                    looking_to_original_sums[new_target].add(original_sum)
                looking_to_original_sums.pop(current_sum)
                looking_for.discard(current_sum)
            # Check if current sum is a divisor of full array sum
            if total % current_sum == 0:
                # If this splits array in half by sum, we've reached its end
                if 2 * current_sum == total:
                    best_found = max(best_found, 2)
                else:
                    # Add the next multiple of this sum to our search set
                    looking_for.add(2 * current_sum)
                    looking_to_original_sums[2 * current_sum].add(current_sum)
        return best_found
This takes O(n) time, which is optimal, and O(n) space.
Partial array partition
This case is harder, because there are fewer conditions on what the valid subarray sums can be. The trick is to just compute the sum of all subarrays. We make a hashmap, mapping each sum to the index bounds of its subarray, so sum(A[L, L+1, ... R]) maps to [L, R]. Since several subarrays can share a sum, we keep a list of all intervals which produced that sum, and we build that list so that it is already sorted by right endpoint.
Now, we can use earliest deadline first scheduling, aka greedy scheduling, to find the maximum number of intervals we can take from that list without overlap. Both steps take quadratic time. It may be possible to improve this, but I have no ideas for how to do so.
Python code:
    import collections
    from typing import List

    def best_partial_partition(nums: List[int]) -> int:
        """Given a list of integers (positive or negative) 'nums',
        return the maximum number of disjoint equal sum subarrays we can
        create from nums (using all elements is not required)"""
        n = len(nums)
        best_found = 1
        # For each subarray sum, stores a list of all subarrays with that sum
        # Sorted by right endpoint, both ends inclusive
        sum_to_intervals = collections.defaultdict(list)
        for right_end in range(n):
            curr_sum = 0
            for left_end in reversed(range(right_end+1)):
                curr_sum += nums[left_end]
                sum_to_intervals[curr_sum].append([left_end, right_end])
        # Use greedy interval scheduling to get most intervals
        for interval_list in sum_to_intervals.values():
            # Can skip if we know no improvement is possible here
            if len(interval_list) <= best_found:
                continue
            curr_len = 0
            curr_right_end = -1
            for left, right in interval_list:
                if left > curr_right_end:
                    curr_len += 1
                    curr_right_end = right
            best_found = max(best_found, curr_len)
        return best_found
This runs in O(n^2) time.
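The scheduling step can also be looked at in isolation. The sketch below is a generic earliest-right-endpoint greedy over inclusive [left, right] intervals; the name max_disjoint_intervals is made up for illustration, and unlike the code above it sorts its input instead of assuming the list is already ordered by right endpoint.

```python
def max_disjoint_intervals(intervals):
    """Maximum number of pairwise non-overlapping intervals.
    Intervals are (left, right) pairs with both ends inclusive."""
    chosen = 0
    last_right = float("-inf")
    # Always take the interval that finishes first among those still compatible.
    for left, right in sorted(intervals, key=lambda iv: iv[1]):
        if left > last_right:
            chosen += 1
            last_right = right
    return chosen


# Subarrays of [2, 2, 2, -2] with sum 2, as inclusive index ranges.
print(max_disjoint_intervals([(0, 0), (1, 1), (2, 2), (1, 3)]))
```

For the second example from the question this selects the three single-element subarrays [2], [2], [2], which is the answer when covering the whole array is not required.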
Edit: Fixed a bug in full partition solver when the array sum was 0; thanks to @josejuan for pointing this out. This case needs to be treated separately to avoid dividing by zero.
Solution 3:[3]
I assume the whole array needs to be covered with sub-arrays.
After thinking about it for some time, I am pretty confident that the following O(n^2) algorithm should be valid:
    public static int maxSplitCount(int[] a) {
        int res = 0; // highest sub-array count found so far
        int sum = 0; // sum within each sub-array
        for (int i = 0; i != a.length; ++i) { // iterate over any possible front sub-array...
            sum += a[i]; // ...and determine its element-sum
            int arrayCount = 1; // count the sub-arrays with this element-sum
            int currentSum = 0; // sum in the currently traversed sub-array
            for (int j = i + 1; j != a.length; ++j) {
                currentSum += a[j];
                // if sum matches, we found the end of the currently traversed sub-array
                if (currentSum == sum) {
                    arrayCount += 1;
                    currentSum = 0;
                }
            }
            // tricky part (see below)
            if (currentSum > 0 || currentSum % sum != 0) {
                continue;
            }
            arrayCount += currentSum / sum;
            if (arrayCount > res) {
                res = arrayCount;
            }
        }
        return res;
    }
Up until the "tricky part", this should be the naive implementation you'd probably come up with on your first attempt. There is no issue with having negative numbers per se. They will simply reduce currentSum, requiring later elements to again increase it appropriately until eventually - by collecting enough elements - currentSum == sum. At some point (beyond the initialization) currentSum may also be 0 or even something negative. This is no problem either. Particularly, currentSum being 0 only indicates that we're currently traversing through a sub-array that itself has a leading 0-sum sub-array.
The only situation in which a problem occurs is when we have remaining elements at the end of the array that together did not form a whole sub-array with element-sum sum, i. e. currentSum is not 0 after the last inner loop iteration. This is where the "tricky part" starts, and we have to distinguish three cases. The remaining elements' sum (i. e. currentSum) is
1) positive. However, apparently the remaining elements themselves could not be split in a way that yielded sub-arrays with element-sum sum (otherwise that would have happened in the just finished inner for-loop already).
2) negative with currentSum % sum != 0. Since all elements up to the remaining elements in total have a multiple of sum as the element-sum, the entire array does not have such a multiple as total element-sum.
3) negative with currentSum % sum == 0.
Each for their individual reason, 1) and 2) render cases in which we cannot have an appropriate partition with the current value of sum at all. In 3) however, we have some negative "overhead" that we can perfectly even-out with some (one or more) of the previously found sub-arrays with element-sum sum each, in total forming a 0-sum sub-array. That sub-array can then be simply seen as appended to the sub-array in front of that (again with element-sum sum), building a joint sub-array that still has element-sum sum.
Let's take your second example and look at when we're in outer iteration i = 0 and have just left the inner for-loop. We have arrayCount = 3 (the three individual [2]-arrays), and currentSum = -2, now we can join the "remaining elements" (only the -2) with the last [2] array to form a 0-sum-array and append that to the second-last [2] array. We have in total lost one array to the merge that we previously counted in arrayCount. This is taken into account by arrayCount += currentSum / sum; (note that at this point currentSum is non-positive so the division is non-positive and += is actually correctly decrementing arrayCount). Note that after that line arrayCount could even be negative (or 0) if the remaining elements had a higher negative sum - like -10 (or -6) - but that only tells us, that the whole array's total element-sum had the opposite sign of the current value of sum (or is 0), which again tells us that there is no appropriate partition. In that case however, arrayCount > res just fails and we disregard that iteration.
If you're particularly attentive, you might have wondered about what happens when in some outer iteration sum = 0. This would currently cause an exception but you can easily add a workaround just before the "tricky part":
    if (sum == 0) {
        if (currentSum == 0 && arrayCount > res) {
            res = arrayCount;
        }
        continue;
    }
or something equivalent. Indeed, in that case, we are in fact not able to even-out the negative "overhead" with any number of previous sum-sum arrays (since sum is 0), since - just like in case 2) above - currentSum is not a multiple of sum.
And since we're already thinking of sum = 0, what about sum < 0? To also handle that case properly, we only have to change currentSum > 0 to currentSum != 0 && ((currentSum > 0) == (sum > 0)) (see footnote 1), since the rest of the code already is sign-unaware regarding currentSum and sum. Go figure!
Footnote 1: Certainly, at this point we can simplify some logic expressions and if-statements when also regarding the sum = 0 and sum > 0 case.
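For reference, here is how the two adjustments might look when folded into the loop. This is a Python transcription written for illustration (the answer's own code is Java, and the name max_split_count is an assumption); Python's % and // round differently from Java's for negative operands, but the code only tests divisibility and only divides when the remainder is zero, so the result should be the same.

```python
def max_split_count(a):
    """Sketch of the O(n^2) approach with the sum == 0 and sum < 0 cases handled."""
    res = 0
    front_sum = 0  # element-sum of the candidate front sub-array a[0..i]
    for i in range(len(a)):
        front_sum += a[i]
        array_count = 1
        current_sum = 0
        for j in range(i + 1, len(a)):
            current_sum += a[j]
            if current_sum == front_sum:
                array_count += 1
                current_sum = 0
        if front_sum == 0:
            # Workaround for sum == 0: a leftover can never be evened out.
            if current_sum == 0:
                res = max(res, array_count)
            continue
        # Case 1): leftover has the same sign as front_sum.
        if current_sum != 0 and (current_sum > 0) == (front_sum > 0):
            continue
        # Case 2): leftover is not a multiple of front_sum.
        if current_sum % front_sum != 0:
            continue
        # Case 3): merge sub-arrays to absorb the leftover (quotient is <= 0).
        array_count += current_sum // front_sum
        res = max(res, array_count)
    return res


print(max_split_count([2, 2, 2, -2]))   # [2], [2, 2, -2]
print(max_split_count([-2, -2, 2, -2])) # [-2], [-2, 2, -2]
```

On the second example from the question, the iteration i = 0 counts three [2] arrays, is left with currentSum = -2, and the merge step brings the count down to 2, as described above.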
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
| Solution 3 | Reizo |
