LeetCode 18. 4Sum
Given an array S of n integers, are there elements a, b, c, and d in S such that a + b + c + d = target? Find all unique quadruplets in the array whose elements sum to target.
Note: The solution set must not contain duplicate quadruplets.
Example
For example, given array S = [1, 0, -1, 0, -2, 2], and target = 0.
A solution set is:
[
[-1, 0, 0, 1],
[-2, -1, 1, 2],
[-2, 0, 0, 2]
]
Solution
Hash Table
// By Hash table:
// 1. Store all two-sum pairs in a hash table, where the key is the sum and
// the value is a vector of all index pairs with that sum.
// 2. Iterate over all two-sum pairs and look up their
// complement (target - sum of the two elements) in the hash table.
// Time Complexity:
// - Store all two-sum pair in a hash table: O(n^2)
// - Iterate all two-elements pair: O(n^2)
// - Find their complement in the hash table: O(1)
// - Form quartets by complement pairs:
// - worst case: nearly O(n^2) pairs for one two-sum value
// - Check whether this quartet was already recorded (when a complement is found)
// - by red-black BST (std::set): O(log n^4) = O(log n) per quartet
// - by hash set (e.g. std::unordered_set with a custom hash): O(1) on average
// => worst case: O(n^2) + O(n^2 * n^2 * log n^4) = O(n^4 log n) with std::set,
// or O(n^2) + O(n^2 * n^2 * 1) = O(n^4) with a hash set
vector<vector<int>> fourSum(vector<int>& nums, int target) {
vector<vector<int>> tuples;
if (nums.size() < 4) {
return tuples;
}
vector<int> values = removeRedundancy(nums, target);
unordered_map<int, vector<pair<int, int>>> twoSumPairs =
getTwoSumPairs(values);
tuples = getQuartet(values, twoSumPairs, target);
return tuples;
}
// Returns a copy of nums (relative order preserved) in which every value
// appears at most as often as one quadruplet could use it: three copies
// normally, four copies only when four copies of that value alone sum to
// target. Further copies would only create redundant index pairs later.
vector<int> removeRedundancy(vector<int>& nums, int target) {
    vector<int> trimmed;
    trimmed.reserve(nums.size());
    unordered_map<int, int> valueCount;
    for (int n : nums) {
        int& count = valueCount[n];  // value-initialized to 0 on first sight
        // 64-bit product: in int, 4 * n overflows once |n| > INT_MAX / 4
        // (e.g. 4 * 1000000000 would wrap to -294967296).
        const int limit = (4LL * n == target) ? 4 : 3;
        if (count >= limit) {
            continue;
        }
        ++count;
        trimmed.push_back(n);
    }
    return trimmed;
}
// Buckets every index pair (i, j) with i < j by the sum values[i] + values[j].
// Inside each bucket the pairs appear in ascending (i, j) order.
// NOTE(review): the sum is computed in int, so callers must ensure no two
// values add up past the int range.
unordered_map<int, vector<pair<int, int>>> getTwoSumPairs(vector<int>& values) {
    unordered_map<int, vector<pair<int, int>>> bySum;
    const int count = static_cast<int>(values.size());
    for (int first = 0; first < count; ++first) {
        const int lhs = values[first];
        for (int second = first + 1; second < count; ++second) {
            bySum[lhs + values[second]].emplace_back(first, second);
        }
    }
    return bySum;
}
// Joins complementary two-sum buckets into unique value quadruplets.
//
// values:      the numbers that the index pairs refer to.
// twoSumPairs: pair sum -> every index pair (i, j), i < j, with that sum.
// target:      the required sum of the four values.
// Returns each distinct quadruplet once, with its values sorted ascending. The
// list itself is in lexicographic order because it is read out of a std::set.
//
// Why an index quadruplet a < b < c < d is produced only once: of all ways to
// split it into two index pairs, only {a, b} + {c, d} passes the
// p.second < cp.first test below. Different index quadruplets can still hold
// the same values (repeated numbers), so quadruplets are sorted and
// deduplicated through the set.
vector<vector<int>> getQuartet(vector<int>& values,
                               unordered_map<int, vector<pair<int, int>>>& twoSumPairs,
                               int target) {
    set<vector<int>> quarSet;
    for (const auto& bucket : twoSumPairs) {
        const int sum = bucket.first;
        // target - sum can leave the int range (e.g. -1e9 - 2e9). In int that is
        // undefined behaviour and, if it wraps, can hit an unrelated bucket.
        // Compute it in 64 bits and keep it only if it round-trips through int:
        // a narrowing cast preserves exactly the representable values, so any
        // other value cannot equal an int key.
        const long long complement = static_cast<long long>(target) - sum;
        const int complementKey = static_cast<int>(complement);
        if (complementKey != complement) {
            continue;
        }
        auto complementIt = twoSumPairs.find(complementKey);
        if (complementIt == twoSumPairs.end()) {
            continue;
        }
        const vector<pair<int, int>>& pairs = bucket.second;
        const vector<pair<int, int>>& complementaryPairs = complementIt->second;
        for (const pair<int, int>& p : pairs) {                    // indices a < b
            for (const pair<int, int>& cp : complementaryPairs) {  // indices c < d
                if (p.second >= cp.first) {
                    continue;  // shared index, or not the {a, b} + {c, d} split
                }
                vector<int> quartet{values[p.first], values[p.second],
                                    values[cp.first], values[cp.second]};
                sort(quartet.begin(), quartet.end());
                quarSet.insert(quartet);
            }
        }
    }
    return vector<vector<int>>(quarSet.begin(), quarSet.end());
}
Two Pointers
// By leveraging the approach of LeetCode 15. 3Sum and
// LeetCode 167. Two Sum II - Input array is sorted:
//
// {a, b, c, d} such that a + b + c + d = target
// = a + { b, c, d }, where b + c + d = target - a
//
// { b, c, d } such that b + c + d = target - a
// = b + { c, d }, where c + d = target - a - b
//
// 1. Sort the array
// 2. Iterate all possible a and find (b, c, d) tuples
// such that b + c + d = target - a
// - Iterate all possible b and find (c, d) tuples such that
// c + d = target - a - b
// - Search c, d from both sides
//
// Time Complexity:
// - Sort the array: O(n log n)
// - Iterate all possible a: O(n)
// - Iterate all possible b: O(n)
// - Search matched (c, d) from both sides: O(n)
// => Total: O(n log n) + O(n * n * n) = O(n^3)
// Two-pointer 4Sum. Sorts nums in place, fixes the smallest value of each
// quadruplet (skipping repeats), and delegates the remaining three-sum search
// to threeSum. Returns an empty result when nums holds fewer than four numbers.
vector<vector<int>> fourSum(vector<int>& nums, int target) {
    vector<vector<int>> tuples;
    if (nums.size() < 4) {
        return tuples;
    }
    sort(nums.begin(), nums.end());
    const int n = static_cast<int>(nums.size());
    for (int smallest = 0; smallest + 3 < n; ++smallest) {
        // Skip the duplicate smallest values.
        if (smallest > 0 && nums[smallest] == nums[smallest - 1]) {
            continue;
        }
        // target - nums[smallest] can overflow int, so compute it in 64 bits.
        // threeSum takes an int target: skip a residual that int cannot hold
        // rather than pass a wrapped value that could match wrong triples.
        const long long residual = static_cast<long long>(target) - nums[smallest];
        const int targetForThreeSum = static_cast<int>(residual);
        if (targetForThreeSum != residual) {
            continue;
        }
        threeSum(nums, targetForThreeSum, smallest + 1, tuples, nums[smallest]);
    }
    return tuples;
}
// Finds every unique triple (b, c, d) from the sorted suffix nums[start..]
// with b + c + d == target and appends {prefix, b, c, d} to tuples.
// nums must be sorted ascending, and prefix must not exceed nums[start].
void threeSum(vector<int>& nums, int target, int start,
              vector<vector<int>>& tuples, int prefix) {
    assert(start > 0);
    const int n = static_cast<int>(nums.size());
    if (n - start < 3) {
        return;
    }
    // Checked only after the size test so that nums[start] is in bounds.
    assert(prefix <= nums[start]);
    for (int s = start; s + 2 < n; ++s) {
        // Skip the duplicate values.
        if (s > start && nums[s] == nums[s - 1]) {
            continue;
        }
        // target - nums[s] can overflow int. For example, nums = {1e9 x4} with
        // a 4Sum target of -294967296 gives -1294967296 - 1e9 = -2294967296
        // here, which wraps to 2e9 on typical builds and falsely matches
        // (1e9, 1e9). Compute in 64 bits and skip residuals twoSum's int target
        // cannot express. This assumes pair sums themselves fit in int, as
        // twoSum's int target already requires; e.g. |nums[i]| <= 1e9 suffices.
        const long long residual = static_cast<long long>(target) - nums[s];
        const int targetForTwoSum = static_cast<int>(residual);
        if (targetForTwoSum != residual) {
            continue;
        }
        twoSum(nums, targetForTwoSum, s + 1, tuples, {prefix, nums[s]});
    }
}
// Two-pointer scan of the sorted suffix nums[start..]: appends prefix + {c, d}
// to tuples for every unique value pair with c + d == target.
// nums must be sorted ascending. The pair sum is computed in 64 bits because
// two ints can add up past the int range.
void twoSum(vector<int>& nums, int target, int start,
            vector<vector<int>>& tuples, vector<int> prefix) {
    assert(start > 0);
    const int n = static_cast<int>(nums.size());
    if (n - start < 2) {
        return;
    }
    // Checked only after the size test so that nums[start] is in bounds.
    assert(prefix.empty() || prefix.back() <= nums[start]);
    int left = start;
    int right = n - 1;
    while (left < right) {
        const long long sum = static_cast<long long>(nums[left]) + nums[right];
        if (sum < target) {
            ++left;
        } else if (sum > target) {
            --right;
        } else {
            tuples.push_back(prefix);
            tuples.back().push_back(nums[left]);
            tuples.back().push_back(nums[right]);
            ++left;
            --right;
            // Step past repeats of the values just used so each pair is reported once.
            while (left < right && nums[left] == nums[left - 1]) {
                ++left;
            }
            while (left < right && nums[right] == nums[right + 1]) {
                --right;
            }
        }
    }
}
Improvement: Two Pointers with Pruning
// Two-pointer 4Sum with pruning. Sorts nums in place, fixes the smallest value
// k of each quadruplet (skipping repeats), and hands the rest to threeSum.
// All bound arithmetic is done in 64 bits: in int, 4 * x overflows once
// |x| > INT_MAX / 4. On a wrapping build, four copies of 1000000000 would then
// "sum" to -294967296 and be reported for that target.
vector<vector<int>> fourSum(vector<int>& nums, int target) {
    vector<vector<int>> tuples;
    if (nums.size() < 4) {
        return tuples;
    }
    sort(nums.begin(), nums.end());
    const int n = static_cast<int>(nums.size());
    const int minValue = nums.front();
    const int maxValue = nums.back();
    // Every quadruplet sum lies in [4 * minValue, 4 * maxValue].
    if (target < 4LL * minValue || target > 4LL * maxValue) {
        return tuples;
    }
    for (int smallest = 0; smallest + 3 < n; ++smallest) {
        const int k = nums[smallest];
        // Skip the duplicate k.
        if (smallest > 0 && k == nums[smallest - 1]) {
            continue;
        }
        // Even with the three largest values, k cannot reach target.
        if (k + 3LL * maxValue < target) {
            continue;
        }
        // Every remaining candidate sums to at least 4 * k, so only
        // {k, k, k, k} can still match, and only when 4 * k == target.
        if (4LL * k >= target) {
            if (4LL * k == target && k == nums[smallest + 3]) {
                tuples.push_back({k, k, k, k});
            }
            break;
        }
        // threeSum takes an int target: skip a residual that int cannot hold
        // rather than pass a wrapped value that could match wrong triples.
        const long long residual = static_cast<long long>(target) - k;
        const int targetForThreeSum = static_cast<int>(residual);
        if (targetForThreeSum != residual) {
            continue;
        }
        threeSum(nums, targetForThreeSum, smallest + 1, tuples, k);
    }
    return tuples;
}
// Finds every unique triple (b, c, d) from the sorted suffix nums[start..]
// with b + c + d == target and appends {prefix, b, c, d} to tuples, pruning
// by value bounds. Bound arithmetic is 64-bit because in int 3 * x overflows
// once |x| > INT_MAX / 3.
void threeSum(vector<int>& nums, int target, int start,
              vector<vector<int>>& tuples, int prefix) {
    assert(start > 0);
    const int n = static_cast<int>(nums.size());
    if (n - start < 3) {
        return;
    }
    // Checked only after the size test so that nums[start] is in bounds.
    assert(prefix <= nums[start]);
    const int minValue = nums[start];
    const int maxValue = nums.back();
    // Every triple sum lies in [3 * minValue, 3 * maxValue].
    if (target < 3LL * minValue || target > 3LL * maxValue) {
        return;
    }
    for (int s = start; s + 2 < n; ++s) {
        const int k = nums[s];
        // Skip the duplicate k.
        if (s > start && k == nums[s - 1]) {
            continue;
        }
        // Even with the two largest values, k cannot reach target.
        if (k + 2LL * maxValue < target) {
            continue;
        }
        // Every remaining candidate sums to at least 3 * k, so only
        // {k, k, k} can still match, and only when 3 * k == target.
        if (3LL * k >= target) {
            if (3LL * k == target && k == nums[s + 2]) {
                tuples.push_back({prefix, k, k, k});
            }
            break;
        }
        // twoSum takes an int target: skip a residual that int cannot hold
        // rather than pass a wrapped value that could match wrong pairs.
        const long long residual = static_cast<long long>(target) - k;
        const int targetForTwoSum = static_cast<int>(residual);
        if (targetForTwoSum != residual) {
            continue;
        }
        twoSum(nums, targetForTwoSum, s + 1, tuples, {prefix, k});
    }
}
// Two-pointer scan of the sorted suffix nums[start..]: appends prefix + {c, d}
// to tuples for every unique value pair with c + d == target, after a quick
// range check. Pair sums and the 2 * x bounds are computed in 64 bits because
// they can exceed the int range.
void twoSum(vector<int>& nums, int target, int start,
            vector<vector<int>>& tuples, vector<int> prefix) {
    assert(start > 0);
    const int n = static_cast<int>(nums.size());
    if (n - start < 2) {
        return;
    }
    // Checked only after the size test so that nums[start] is in bounds.
    assert(prefix.empty() || prefix.back() <= nums[start]);
    // Every pair sum lies in [2 * nums[start], 2 * nums.back()].
    if (target < 2LL * nums[start] || target > 2LL * nums.back()) {
        return;
    }
    int left = start;
    int right = n - 1;
    while (left < right) {
        const long long sum = static_cast<long long>(nums[left]) + nums[right];
        if (sum < target) {
            ++left;
        } else if (sum > target) {
            --right;
        } else {
            tuples.push_back(prefix);
            tuples.back().push_back(nums[left]);
            tuples.back().push_back(nums[right]);
            ++left;
            --right;
            // Step past repeats of the values just used so each pair is reported once.
            while (left < right && nums[left] == nums[left - 1]) {
                ++left;
            }
            while (left < right && nums[right] == nums[right + 1]) {
                --right;
            }
        }
    }
}