// GitHub Gist: SuryaPratapK/be626e34da8f973b4e97403fd488b226
// C++ solution below; Java and Python ports follow in a block comment.
// Counts balanced permutations of a digit string modulo 1e9+7.
//
// Idea: the floor(n/2) odd-index positions must hold digits summing to half of
// the total digit sum. For every way of choosing how many copies i_d of each
// digit d go to that half (sum of i_d == n/2, sum of d*i_d == total/2), the
// number of distinct arrangements is
//   (n/2)! * (n - n/2)! / prod_d( i_d! * (freq[d] - i_d)! ).
// The recursion enumerates the i_d digit by digit and accumulates the
// 1 / (i_d! * (freq[d] - i_d)!) factors; the numerator is multiplied in at the end.
class Solution {
    using ll = long long;

    // Modulus for every result. It is prime, so Fermat's little theorem gives inverses.
    static constexpr int mod = 1000000007;

    // (n/2)! * (n - n/2)!: arrangements of the two halves before dividing out
    // repeated digits (the recursion supplies those divisors).
    int tot_ways_to_permute = 0;
    vector<int> fact;          // fact[i] = i! mod p, for i in [0, n]
    vector<int> inverse_fact;  // inverse_fact[i] = (i!)^-1 mod p, for i in [0, n]
    vector<int> freq;          // freq[d] = number of occurrences of digit d
    // mem[digit][leftover][target]; -1 means "not computed yet".
    // Sized per input in countBalancedPermutations, so no length or digit-sum
    // limit is baked into the table.
    vector<vector<vector<int>>> mem;

    // (a * b) mod p. Widening to 64 bits keeps the product from overflowing.
    // Parameters are taken by value so temporaries and literals can be passed.
    int modMul(ll a, ll b) const {
        return (int)((a % mod) * (b % mod) % mod);
    }

    // Fills fact[0..n] with factorials mod p.
    void computeFactorial(int n) {
        fact.assign(n + 1, 1);
        for (int i = 1; i <= n; ++i)
            fact[i] = modMul(fact[i - 1], i);
    }

    // a^b mod p by square-and-multiply (returns 1 when b <= 0).
    int binaryExponentiation(int a, int b) const {
        int res = 1;
        while (b > 0) {
            if (b & 1)
                res = modMul(res, a);
            a = modMul(a, a);
            b >>= 1;
        }
        return res;
    }

    // Fills inverse_fact[0..n]. Requires fact[0..n] to be computed first.
    // Uses one Fermat inverse of n!, then (i-1)!^-1 = i!^-1 * i walking down to 0.
    void computeInverseFactorial(int n) {
        inverse_fact.assign(n + 1, 1);
        inverse_fact[n] = binaryExponentiation(fact[n], mod - 2);
        for (int i = n; i > 0; --i)
            inverse_fact[i - 1] = modMul(inverse_fact[i], i);
    }

    // Sum, over all ways to place copies of digits [digit, 9] into the half
    // that still has `leftover` free slots and must still reach `target`, of
    // prod 1 / (i! * (freq[d] - i)!). The sum is multiplied by
    // tot_ways_to_permute once both budgets hit exactly zero.
    int countPermutation(int digit, int leftover, int target) {
        if (digit == 10)
            return (leftover == 0 && target == 0) ? tot_ways_to_permute : 0;

        int& cached = mem[digit][leftover][target];  // mem is never resized during recursion
        if (cached != -1)
            return cached;

        // The count is capped by the free slots, by the copies available, and
        // (for non-zero digits) by how many copies still fit under target.
        int include_count = min(leftover, freq[digit]);
        if (digit > 0)
            include_count = min(include_count, target / digit);

        ll ans = 0;
        for (int i = 0; i <= include_count; ++i) {
            // i copies go to this half and freq[digit] - i to the other half.
            ll ways_for_current_digit = modMul(inverse_fact[i], inverse_fact[freq[digit] - i]);
            ans = (ans + ways_for_current_digit *
                             countPermutation(digit + 1, leftover - i, target - digit * i)) % mod;
        }
        return cached = (int)ans;
    }

public:
    // Returns, mod 1e9+7, the number of distinct permutations of `num` in which
    // the floor(n/2) digits at odd indices and the digits at even indices have
    // equal sums. `num` is expected to contain only the characters '0'..'9'.
    int countBalancedPermutations(string num) {
        const int n = num.size();
        int sum = 0;
        freq.assign(10, 0);
        for (char ch : num) {
            sum += ch - '0';
            ++freq[ch - '0'];
        }
        if (sum % 2 != 0)
            return 0;  // an odd total cannot be split into two equal halves

        const int target = sum / 2;
        const int half = n / 2;  // size of the odd-index half; the other half holds n - half digits

        computeFactorial(n);
        computeInverseFactorial(n);
        tot_ways_to_permute = modMul(fact[half], fact[n - half]);

        mem.assign(10, vector<vector<int>>(half + 1, vector<int>(target + 1, -1)));
        return countPermutation(0, half, target);
    }
};
/*
//JAVA
import java.util.Arrays;
class Solution {
    // Modulus for every result; it is prime, so Fermat inverses exist.
    private static final int MOD = (int) 1e9 + 7;
    // (n/2)! * (n - n/2)!: arrangements of the two halves before dividing out
    // repeated digits (the recursion supplies those divisors).
    private int totWaysToPermute;
    private int[] fact;        // fact[i] = i! mod MOD
    private int[] inverseFact; // inverseFact[i] = (i!)^-1 mod MOD
    private int[] freq;        // freq[d] = number of occurrences of digit d
    // mem[digit][leftover][target], -1 = not computed yet; sized per input.
    private int[][][] mem;

    // (a * b) mod MOD, computed in long so the product cannot overflow.
    private int modMul(int a, int b) {
        return (int) (((long) (a % MOD) * (b % MOD)) % MOD);
    }

    // Fills fact[0..n] with factorials mod MOD.
    private void computeFactorial(int n) {
        fact[0] = 1;
        for (int i = 1; i <= n; ++i) {
            fact[i] = modMul(fact[i - 1], i);
        }
    }

    // a^b mod MOD by square-and-multiply (returns 1 when b <= 0).
    private int binaryExponentiation(int a, int b) {
        int res = 1;
        while (b > 0) {
            if ((b & 1) == 1) {
                res = modMul(res, a);
            }
            a = modMul(a, a);
            b >>= 1;
        }
        return res;
    }

    // Fills inverseFact[0..n]; requires fact[0..n]. One Fermat inverse of n!,
    // then (i-1)!^-1 = i!^-1 * i walks back down to index 0.
    private void computeInverseFactorial(int n) {
        inverseFact[n] = binaryExponentiation(fact[n], MOD - 2);
        for (int i = n; i > 0; --i) {
            inverseFact[i - 1] = modMul(inverseFact[i], i);
        }
    }

    // Sum, over all ways to put copies of digits [digit, 9] into the half with
    // `leftover` free slots that must still reach `target`, of the product of
    // 1 / (i! * (freq[d] - i)!). The sum is scaled by totWaysToPermute when both
    // budgets reach exactly zero.
    private int countPermutation(int digit, int leftover, int target) {
        if (digit == 10) {
            return (leftover == 0 && target == 0) ? totWaysToPermute : 0;
        }
        if (mem[digit][leftover][target] != -1) {
            return mem[digit][leftover][target];
        }
        // Capped by free slots, available copies and (for d > 0) the remaining sum.
        int includeCount = Math.min(leftover, freq[digit]);
        if (digit > 0) {
            includeCount = Math.min(includeCount, target / digit);
        }
        long ans = 0;
        for (int i = 0; i <= includeCount; ++i) {
            // i copies go to this half and freq[digit] - i to the other half.
            long waysForCurrentDigit = modMul(inverseFact[i], inverseFact[freq[digit] - i]);
            ans += (waysForCurrentDigit * countPermutation(digit + 1, leftover - i, target - digit * i)) % MOD;
            ans %= MOD;
        }
        return mem[digit][leftover][target] = (int) ans;
    }

    // Returns, mod 1e9+7, the number of distinct permutations of num in which the
    // floor(n/2) digits at odd indices and the digits at even indices have equal
    // sums. num is expected to contain only the characters '0'..'9'.
    public int countBalancedPermutations(String num) {
        int n = num.length();
        int sum = 0;
        freq = new int[10];
        for (int i = 0; i < n; ++i) {
            int digit = num.charAt(i) - '0';
            sum += digit;
            freq[digit]++;
        }
        if (sum % 2 != 0) {
            return 0; // an odd total cannot be split into two equal halves
        }
        int target = sum / 2;
        int halfLen = n / 2;
        fact = new int[n + 1];
        computeFactorial(n);
        inverseFact = new int[n + 1];
        computeInverseFactorial(n);
        totWaysToPermute = modMul(fact[halfLen], fact[n - halfLen]);
        // The target axis is sized from the actual target instead of a fixed bound,
        // so large digit sums no longer index past the end of the table.
        mem = new int[10][halfLen + 1][target + 1];
        for (int[][] byDigit : mem) {
            for (int[] row : byDigit) {
                Arrays.fill(row, -1);
            }
        }
        return countPermutation(0, halfLen, target);
    }
}
#Python
import math
class Solution:
    # Modulus for every result; it is prime, so Fermat inverses exist.
    MOD = 10**9 + 7

    def modMul(self, a, b):
        """Return (a * b) mod MOD."""
        return (a % self.MOD) * (b % self.MOD) % self.MOD

    def computeFactorial(self, n):
        """Fill self.fact[0..n] with i! mod MOD."""
        self.fact = [1] * (n + 1)
        for i in range(1, n + 1):
            self.fact[i] = self.modMul(self.fact[i - 1], i)

    def binaryExponentiation(self, a, b):
        """Return a**b mod MOD by square-and-multiply (1 when b <= 0)."""
        res = 1
        while b > 0:
            if b & 1:
                res = self.modMul(res, a)
            a = self.modMul(a, a)
            b >>= 1
        return res

    def computeInverseFactorial(self, n):
        """Fill self.inverse_fact[0..n] with (i!)^-1 mod MOD.

        Requires self.fact to be filled first. Takes one Fermat inverse of n!,
        then walks down with (i-1)!^-1 = i!^-1 * i.
        """
        self.inverse_fact = [1] * (n + 1)
        self.inverse_fact[n] = self.binaryExponentiation(self.fact[n], self.MOD - 2)
        for i in range(n, 0, -1):
            self.inverse_fact[i - 1] = self.modMul(self.inverse_fact[i], i)

    def countPermutation(self, digit, leftover, target):
        """Sum of prod 1 / (i! * (freq[d] - i)!) over placements of digits [digit, 9].

        Counts the ways to put copies of each remaining digit into the half with
        `leftover` free slots that must still reach `target`. The sum is scaled
        by self.tot_ways_to_permute once both budgets reach exactly zero.
        """
        if digit == 10:
            return self.tot_ways_to_permute if (leftover == 0 and target == 0) else 0
        if self.mem[digit][leftover][target] != -1:
            return self.mem[digit][leftover][target]
        # Capped by free slots, available copies and (for d > 0) the remaining sum.
        include_count = min(leftover, self.freq[digit])
        if digit > 0:
            include_count = min(include_count, target // digit)
        ans = 0
        for i in range(include_count + 1):
            # i copies go to this half and freq[digit] - i to the other half.
            ways_for_current_digit = self.modMul(self.inverse_fact[i], self.inverse_fact[self.freq[digit] - i])
            ans += ways_for_current_digit * self.countPermutation(digit + 1, leftover - i, target - digit * i)
            ans %= self.MOD
        self.mem[digit][leftover][target] = ans
        return ans

    def countBalancedPermutations(self, num: str) -> int:
        """Count balanced permutations of ``num`` modulo 1e9+7.

        A permutation is counted when the floor(n/2) digits at odd indices and
        the digits at even indices have equal sums; duplicate digits are
        counted once per distinct arrangement. ``num`` is expected to contain
        only decimal digits.
        """
        n = len(num)
        sum_digits = 0
        self.freq = [0] * 10
        for ch in num:
            digit = int(ch)
            sum_digits += digit
            self.freq[digit] += 1
        if sum_digits % 2 == 1:
            return 0  # an odd total cannot be split into two equal halves
        target = sum_digits // 2
        self.computeFactorial(n)
        self.computeInverseFactorial(n)
        half_len = n // 2
        self.tot_ways_to_permute = self.modMul(self.fact[half_len], self.fact[n - half_len])
        # Size the target axis from the actual target rather than a fixed bound,
        # so large digit sums no longer raise IndexError.
        self.mem = [[[-1] * (target + 1) for _ in range(half_len + 1)] for _ in range(10)]
        return self.countPermutation(0, half_len, target)
*/
// End of gist.