Special Permutations - DP with bitmask
Problem Statement:
You are given a 0-indexed integer array `nums` containing `n` distinct positive integers. A permutation of `nums` is called special if:
- For all indexes `0 <= i < n - 1`, either `nums[i] % nums[i+1] == 0` or `nums[i+1] % nums[i] == 0`.

Return the total number of special permutations. As the answer could be large, return it modulo $10^9 + 7$.
Example 1:
Input: nums = [2,3,6]
Output: 2
Explanation: [3,6,2] and [2,6,3] are the two special permutations of nums.

Example 2:
Input: nums = [1,4,3]
Output: 2
Explanation: [3,1,4] and [4,1,3] are the two special permutations of nums.
Constraints:
- `2 <= nums.length <= 14`
- `1 <= nums[i] <= 10^9`
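To sanity-check the examples, or any faster answer, on small inputs, it can help to have a brute-force reference that applies the definition literally. This is a minimal sketch, not part of the original solution. The names `isSpecial` and `countSpecialBruteForce` are hypothetical, and enumerating all n! orderings is practical only for small `n`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: true if every adjacent pair divides one way or the other.
bool isSpecial(const std::vector<int>& p) {
    for (std::size_t i = 0; i + 1 < p.size(); ++i) {
        if (p[i] % p[i + 1] != 0 && p[i + 1] % p[i] != 0) return false;
    }
    return true;
}

// Hypothetical brute-force reference: enumerates all n! orderings of nums.
// Only usable for small n (14! is about 8.7e10 orderings).
long long countSpecialBruteForce(std::vector<int> nums) {
    std::sort(nums.begin(), nums.end());  // next_permutation must start from sorted order
    long long count = 0;
    do {
        if (isSpecial(nums)) ++count;
    } while (std::next_permutation(nums.begin(), nums.end()));
    return count;
}
```

For the two examples, `countSpecialBruteForce({2, 3, 6})` and `countSpecialBruteForce({1, 4, 3})` should both return 2. Because the elements are distinct, `std::next_permutation` visits each ordering exactly once.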
Solution:
Here is the baseline solution, a plain recursive search over bitmask states with no memoization yet:
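```cpp
#include <vector>
using namespace std;

class Solution {
public:
    // Counts the ways to finish a permutation whose last element is nums[cur],
    // where mask marks the indexes already placed.
    int helper(vector<int>& nums, int cur, int mask, int n)
    {
        if (mask + 1 == (1 << n)) return 1;  // every index used: one complete permutation
        int res = 0;
        for (int i = 0; i < n; i++)
            if (((mask & (1 << i)) == 0) && ((nums[i] % nums[cur] == 0) || (nums[cur] % nums[i] == 0)))
                res += helper(nums, i, (mask | (1 << i)), n);
        return res;  // no modulo yet: may overflow for large counts
    }
    int specialPerm(vector<int>& nums)
    {
        int n = nums.size(), res = 0;
        for (int i = 0; i < n; i++) res += helper(nums, i, (1 << i), n);
        return res;
    }
};
```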
We start counting at each index `i` with the bitmask `(1<<i)`. The bitmask records which indexes have already been included, so this counts the permutations that start at each index, one starting index at a time.

When every index has been included, the mask reaches the state `11...1` (all `n` bits set, which is exactly when `mask + 1 == (1 << n)`). The answer for this state is 1, because we have successfully built one complete permutation.
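Here is a small sketch of the mask bookkeeping described above, assuming bit `i` stands for index `i`. The helper names `fullMask`, `allIncluded`, `startMask` and `withIndex` are hypothetical and do not appear in the original solution.

```cpp
#include <cassert>

// All n indexes included: the low n bits set, 0b11...1.
constexpr int fullMask(int n) { return (1 << n) - 1; }

// Same test as the solution's base case: adding 1 to 0b11...1 carries into bit n.
constexpr bool allIncluded(int mask, int n) { return mask + 1 == (1 << n); }

// A permutation that starts at index i has only bit i set.
constexpr int startMask(int i) { return 1 << i; }

// Marking index i as included, as in the recursive call mask | (1 << i).
constexpr int withIndex(int mask, int i) { return mask | (1 << i); }
```

For `n = 3`, `fullMask(3)` is `7` (`0b111`), and `allIncluded(7, 3)` holds because `7 + 1 == 8 == 1 << 3`.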
Given a current index `cur` and a `mask` denoting the included indexes, a new index `i` can be appended to the permutation only if both conditions hold:
- It must not have been included already. The condition for this is `(mask & (1<<i)) == 0`.
- One of the two elements must divide the other, i.e. the mod must be zero in one direction: `(nums[i] % nums[cur] == 0) || (nums[cur] % nums[i] == 0)`.
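The two conditions can be read as a single predicate. This sketch pulls them into a hypothetical helper `canPlaceNext` (the solution itself inlines the check in its loop) so each condition can be checked in isolation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: can index i be placed right after index cur,
// given that mask marks the indexes already used?
bool canPlaceNext(const std::vector<int>& nums, int cur, int mask, int i) {
    bool notUsed = (mask & (1 << i)) == 0;                 // condition 1
    bool divides = nums[i] % nums[cur] == 0 ||             // condition 2
                   nums[cur] % nums[i] == 0;
    return notUsed && divides;
}
```

With `nums = {2, 3, 6}`, starting at index 0 (`mask = 1`), index 2 (value 6) can follow, index 1 (value 3) cannot because neither of 2 and 3 divides the other, and index 0 cannot because it is already used.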
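Finally, let us add memoization (caching each `(cur, mask)` result in `dp`) and reduce every sum modulo $10^9 + 7$ to avoid overflow. With these two changes the solution gets AC: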
```cpp
#include <vector>
using namespace std;

class Solution {
    vector<vector<int>> dp;  // dp[cur][mask] caches helper(cur, mask); -1 means not computed
    int mod = 1e9 + 7;
public:
    int helper(vector<int>& nums, int cur, int mask, int n)
    {
        if (dp[cur][mask] != -1) return dp[cur][mask];
        if (mask + 1 == (1 << n)) return dp[cur][mask] = 1;
        int res = 0;
        for (int i = 0; i < n; i++)
        {
            // not included till now and condition is met
            if (((mask & (1 << i)) == 0) && ((nums[i] % nums[cur] == 0) || (nums[cur] % nums[i] == 0)))
            {
                res += helper(nums, i, (mask | (1 << i)), n) % mod;
                res %= mod;  // res < 2 * mod before this line, which still fits in an int
            }
        }
        return dp[cur][mask] = res;
    }
    int specialPerm(vector<int>& nums)
    {
        int n = nums.size(), res = 0;
        dp = vector<vector<int>>(n + 1, vector<int>((1 << n) + 1, -1));
        for (int i = 0; i < n; i++)
        {
            res += helper(nums, i, (1 << i), n) % mod;
            res %= mod;
        }
        return res;
    }
};
```
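The same recurrence can also be filled iteratively (bottom-up tabulation) instead of through recursion. This is a swapped-in variant, not the author's code. Here `ways[mask][last]` (a hypothetical name) counts the valid orderings of exactly the indexes in `mask` that end at index `last`. It therefore builds permutations from the front, whereas `helper` counts the ways to complete them.

```cpp
#include <cassert>
#include <vector>

// Sketch of a bottom-up bitmask DP (hypothetical function name).
// ways[mask][last] = number of valid orderings of the indexes in mask ending at last.
int specialPermBottomUp(const std::vector<int>& nums) {
    const int MOD = 1'000'000'007;
    const int n = static_cast<int>(nums.size());
    const int full = (1 << n) - 1;
    std::vector<std::vector<int>> ways(1 << n, std::vector<int>(n, 0));
    for (int i = 0; i < n; ++i) ways[1 << i][i] = 1;  // one-element prefixes
    for (int mask = 1; mask <= full; ++mask) {
        for (int last = 0; last < n; ++last) {
            int cur = ways[mask][last];
            if (cur == 0) continue;
            for (int next = 0; next < n; ++next) {
                if (mask & (1 << next)) continue;  // already used
                if (nums[next] % nums[last] != 0 && nums[last] % nums[next] != 0) continue;
                int& slot = ways[mask | (1 << next)][next];
                slot = (slot + cur) % MOD;  // both operands < MOD, so the sum fits in int
            }
        }
    }
    int total = 0;
    for (int last = 0; last < n; ++last) total = (total + ways[full][last]) % MOD;
    return total;
}
```

Adding an index only ever sets a bit, so `mask | (1 << next)` is always larger than `mask`. Visiting masks in increasing order therefore finishes every state before it is extended. On the examples, `specialPermBottomUp({2, 3, 6})` and `specialPermBottomUp({1, 4, 3})` both return 2.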
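TC: $O(n^2 \cdot 2^n)$, since there are $n \cdot 2^n$ `(cur, mask)` states and each one loops over $n$ candidate indexes. SC: $O(n \cdot 2^n)$ for the memo table.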
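To put these bounds in perspective at the largest allowed input, $n = 14$, here is a worked upper bound. It assumes each of the $n \cdot 2^n$ memo states runs the full $n$-iteration loop once and compares the result with enumerating every permutation:

$$
\underbrace{n \cdot 2^n}_{\text{states}} \cdot \underbrace{n}_{\text{loop per state}} = 14 \cdot 2^{14} \cdot 14 = 3{,}211{,}264
\qquad \text{vs.} \qquad
14! = 87{,}178{,}291{,}200
$$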