Skip to content

Instantly share code, notes, and snippets.

@jin-x
Last active June 6, 2021 09:50
Show Gist options
  • Save jin-x/e5337b40154d13cb0b37681e679518f1 to your computer and use it in GitHub Desktop.
Save jin-x/e5337b40154d13cb0b37681e679518f1 to your computer and use it in GitHub Desktop.
Magic Square Finder
#include <iostream>
#include <iomanip>
#include <array>
#include <algorithm>
#include <numeric>
#define SQUARE_SIZE 4
//#define PRINT_SQUARES
#define DEBUG
using std::cout;
using std::endl;
////////////////////////////////////////////////////////////////////////////////
// Brute-force enumerator of S x S magic and semimagic squares filled with 1..S*S.
template<size_t S>
class MagicSquareFinder
{
public:
    // One candidate square, stored row-major: cell (row, col) is magic[row*S + col].
    using MagicSquare = std::array<int,S*S>;

    // Virtual because derived finders override found(); this makes destroying
    // a derived finder through a base pointer well-defined.
    virtual ~MagicSquareFinder() = default;

    // Enumerates the arrangements of 1..S*S (with pruning) and updates the
    // counters below. With semimagic == true, squares whose rows and columns
    // are correct but whose diagonals are not are counted as semimagic.
    void find(bool semimagic = false);
    size_t count_magic = 0;      // magic squares found by the last find()
    size_t count_semimagic = 0;  // semimagic (but not magic) squares found by the last find()
    uint64_t iterations = 0;     // arrangements examined by the last find()
protected:
    // Returns 0 for magic, -1 for semimagic, > 0 (index of a deciding cell) otherwise.
    int check(bool semimagic);
    // Advances `magic` to the next arrangement whose prefix magic[0..n] differs.
    bool next_permutation_from(size_t n);
    // Hook invoked for every magic/semimagic square; the default does nothing.
    virtual void found(bool /*semimagic*/) {}
    MagicSquare magic{};         // arrangement currently under test
};
// Finder that prints every magic / semimagic square it encounters.
template<size_t S>
class PrintMagicSquares : public MagicSquareFinder<S>
{
    using Base = MagicSquareFinder<S>;
public:
    // Counters re-exported from the (dependent) base class.
    using Base::count_magic;
    using Base::count_semimagic;
protected:
    // Square under test, needed by found() for printing.
    using Base::magic;
    void found(bool semimagic) override;
};
// MagicSquareFinder class /////////////////////////////////////////////////////
// Walks through all arrangements of 1..S*S in lexicographic order, skipping
// whole ranges of arrangements whenever check() reports a deciding cell, and
// reports every magic / semimagic square through found().
template<size_t S>
void MagicSquareFinder<S>::find(bool semimagic)
{
    // Forget the statistics of any previous run.
    count_magic = count_semimagic = 0;
    iterations = 0;

    // Lexicographically smallest arrangement: 1, 2, ..., S*S.
    std::iota(magic.begin(), magic.end(), 1);

    bool more = true;
    while (more) {
        ++iterations;
#ifdef DEBUG
        // Progress indicator every 2^20 iterations.
        if (iterations % (1024*1024) == 0) { cout << (iterations / (1024*1024)) << "M iters\r"; }
#endif
        const int verdict = check(semimagic);
        size_t resume = S*S-1;  // plain "next permutation" unless check() says otherwise
        if (verdict > 0) {
            resume = static_cast<size_t>(verdict);
        } else {
            const bool is_semi = verdict < 0;
            if (is_semi) {
                ++count_semimagic;
            } else {
                ++count_magic;
            }
            found(is_semi);
        }
        more = next_permutation_from(resume);
    }
}
// Classifies the arrangement currently stored in `magic` (row-major, S*S cells).
// Returns:
//    0 - magic: every row, column and both diagonals sum to S*(S*S+1)/2;
//   -1 - semimagic: rows and columns are correct but a diagonal is not
//        (only returned when `semimagic` is true);
//   >0 - improper: index of the cell at which the failure became certain.
//        Every arrangement sharing magic[0..index] fails the same way, so the
//        caller can skip all of them via next_permutation_from(index).
template<size_t S>
int MagicSquareFinder<S>::check(bool semimagic)
{
    const int correct_sum = static_cast<int>(S*(S*S+1)/2);
    // Horizontal sums
    for (size_t start = 0; start < S*S; start += S) {
        int local_sum = 0;
        for (size_t el = start; el < start+S; ++el) {
            local_sum += magic[el];
            // Cells hold positive values, so an exceeded sum can never recover.
            if (local_sum > correct_sum) { return static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return static_cast<int>(start+(S-1)); }
    }
    // Vertical sums
    for (size_t start = 0; start < S; ++start) {
        int local_sum = 0;
        for (size_t el = start; el < S*S; el += S) {
            local_sum += magic[el];
            if (local_sum > correct_sum) { return static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return static_cast<int>(start+S*(S-1)); }
    }
    // Anti-diagonal sum: cells S-1, 2(S-1), ..., S(S-1).
    // Iterate by cell count rather than by an index bound: for S == 1 the step
    // S-1 is zero, and a bound such as `el < S*S-1` would skip the only cell,
    // making the trivial 1x1 square fail this check.
    {
        int local_sum = 0;
        for (size_t k = 1; k <= S; ++k) {
            const size_t el = k*(S-1);
            local_sum += magic[el];
            if (local_sum > correct_sum) { return semimagic ? -1 : static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return semimagic ? -1 : static_cast<int>(S*(S-1)); }
    }
    // Main diagonal sum: cells 0, S+1, ..., S*S-1.
    {
        int local_sum = 0;
        for (size_t el = 0; el < S*S; el += S+1) {
            local_sum += magic[el];
            if (local_sum > correct_sum) { return semimagic ? -1 : static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return semimagic ? -1 : static_cast<int>(S*S-1); }
    }
    return 0;
}
// Moves `magic` to the lexicographically next arrangement whose prefix
// magic[0..n] differs from the current one. Returns false once the last
// arrangement has been passed (std::next_permutation wrapped around).
template<size_t S>
bool MagicSquareFinder<S>::next_permutation_from(size_t n)
{
    constexpr size_t last = S*S - 1;
    if (n < last) {
        // Putting the tail in descending order makes the current arrangement
        // the last one with this prefix, so the next permutation must change
        // some cell at index <= n.
        auto tail = magic.begin() + (n + 1);
        std::sort(tail, magic.end(), [](int a, int b) { return a > b; });
    }
    return std::next_permutation(magic.begin(), magic.end());
}
// PrintMagicSquares class /////////////////////////////////////////////////////
// Prints the square just found: a numbered header line, then S rows of S
// right-aligned values, then a blank line.
template<size_t S>
void PrintMagicSquares<S>::found(bool semimagic)
{
    cout << (semimagic ? "Semimagic square #" : "Magic square #")
         << (semimagic ? count_semimagic : count_magic) << '\n';
    for (size_t cell = 0; cell < magic.size(); ++cell) {
        cout << std::setw(3) << magic[cell];
        if ((cell + 1) % S == 0) { cout << "\n"; }  // end of a row
    }
    cout << endl;
}
// Main function ///////////////////////////////////////////////////////////////
int main()
{
#ifdef PRINT_SQUARES
PrintMagicSquares<SQUARE_SIZE> ms;
#else
MagicSquareFinder<SQUARE_SIZE> ms;
#endif
ms.find(true);
cout << "Total magic squares = " << ms.count_magic << ", semimagic = " << ms.count_semimagic << endl;
cout << "Number of iterations = " << ms.iterations << endl;
return 0;
}
#include <iostream>
#include <iomanip>
#include <array>
#include <algorithm>
#include <numeric>
#include <cstring>
#define SQUARE_SIZE 4
//#define PRINT_SQUARES
#define PRINT_LIMIT 4 // only if PRINT_SQUARES is defined
#define PRINT_FLAG_MASK (1<<1) // only if PRINT_SQUARES is defined
#define PRINT_FLAG_VALUE (1<<1) // only if PRINT_SQUARES is defined
#define NO_PRINT_SEMIMAGIC // only if PRINT_SQUARES is defined
#define PRINT_ITERS
//#define NO_PRINT_STAT
using std::cout;
using std::endl;
////////////////////////////////////////////////////////////////////////////////
// Brute-force enumerator of S x S magic and semimagic squares filled with 1..S*S,
// optionally gathering statistics on extra 4-cell group sums (4x4 only).
template<size_t S>
class MagicSquareFinder
{
public:
    // One candidate square, stored row-major: cell (row, col) is magic[row*S + col].
    using MagicSquare = std::array<int,S*S>;

    // Virtual because derived finders override found(); this makes destroying
    // a derived finder through a base pointer well-defined.
    virtual ~MagicSquareFinder() = default;

    // Enumerates the arrangements of 1..S*S (with pruning) and updates the
    // counters below; found() may stop the enumeration early.
    void find(bool semimagic = false);
    size_t count_magic = 0;      // magic squares found by the last find()
    size_t count_semimagic = 0;  // semimagic (but not magic) squares found by the last find()
#if SQUARE_SIZE == 4
    // Failed extra checks of the last square examined by check(): a set bit
    // means that 4-cell group does not sum to the magic constant. All bits are
    // set when check() returned before reaching the extra checks.
    size_t xchk_flags = 0;
    // count_xchk[x]: magic squares for which extra-check bit x was clear.
    size_t count_xchk[(S-1)*(S-1)+5] {};
    // Counters for combinations of extra checks that all passed.
    size_t count_xchkTB = 0, count_xchkLR = 0, count_xchkTBLR = 0, count_xchk1379 = 0, count_xchk2468 = 0, count_xchk_allminis = 0,
    count_xchk10_13 = 0, count_xchk_allminis10_13 = 0, count_xchk_allWS = 0, count_xchk_all = 0;
#endif
    uint64_t iterations = 0;     // arrangements examined by the last find()
protected:
    int check(bool semimagic); // return index of value to change or 0 for magic square, -1 for semimagic square
    // Advances `magic` to the next arrangement whose prefix magic[0..n] differs.
    bool next_permutation_from(size_t n);
    virtual bool found(bool /*semimagic*/) { return true; } // return true to continue
    MagicSquare magic{};         // arrangement currently under test
};
// Finder that prints the magic / semimagic squares it encounters, subject to
// the compile-time filters and print limit.
template<size_t S>
class PrintMagicSquares : public MagicSquareFinder<S>
{
    using Base = MagicSquareFinder<S>;
public:
    // Members re-exported from the (dependent) base class.
    using Base::count_magic;
    using Base::count_semimagic;
    using Base::xchk_flags;
protected:
    // Square under test, needed by found() for printing.
    using Base::magic;
    bool found(bool semimagic) override;
#ifdef PRINT_LIMIT
    // Squares printed so far in the current run; reset by found() on the first hit.
    size_t print_count;
#endif
};
// MagicSquareFinder class /////////////////////////////////////////////////////
// Walks through all arrangements of 1..S*S in lexicographic order, skipping
// whole ranges whenever check() reports a deciding cell, and reports every
// magic / semimagic square through found(). Stops early if found() returns false.
template<size_t S>
void MagicSquareFinder<S>::find(bool semimagic)
{
    // Forget the statistics of any previous run.
    count_magic = count_semimagic = 0;
#if SQUARE_SIZE == 4
    for (auto& counter : count_xchk) { counter = 0; }
    count_xchkTB = count_xchkLR = count_xchkTBLR = count_xchk1379 = count_xchk2468 = count_xchk_allminis = count_xchk10_13 = count_xchk_allminis10_13 = count_xchk_allWS = count_xchk_all = 0;
#endif
    iterations = 0;

    // Lexicographically smallest arrangement: 1, 2, ..., S*S.
    std::iota(magic.begin(), magic.end(), 1);

    bool more = true;
    while (more) {
        ++iterations;
#ifdef PRINT_ITERS
        // Progress indicator every 2^20 iterations.
        if (iterations % (1024*1024) == 0) { cout << (iterations / (1024*1024)) << "M iters\r"; }
#endif
        const int verdict = check(semimagic);
        size_t resume = S*S-1;  // plain "next permutation" unless check() says otherwise
        if (verdict > 0) {
            resume = static_cast<size_t>(verdict);
        } else {
            const bool is_semi = verdict < 0;
            if (is_semi) {
                ++count_semimagic;
            } else {
                ++count_magic;
            }
            if (!found(is_semi)) { break; }  // the hook asked to stop
        }
        more = next_permutation_from(resume);
    }
}
// Classifies the arrangement currently stored in `magic` (row-major, S*S cells).
// Returns:
//    0 - magic: every row, column and both diagonals sum to S*(S*S+1)/2;
//   -1 - semimagic: rows and columns are correct but a diagonal is not
//        (only returned when `semimagic` is true);
//   >0 - improper: index of the cell at which the failure became certain.
//        Every arrangement sharing magic[0..index] fails the same way, so the
//        caller can skip all of them via next_permutation_from(index).
// For SQUARE_SIZE == 4 it also fills xchk_flags and the extra-check counters
// (for magic squares only; xchk_flags stays all-ones otherwise).
template<size_t S>
int MagicSquareFinder<S>::check(bool semimagic)
{
    constexpr int correct_sum = static_cast<int>(S*(S*S+1)/2);
#if SQUARE_SIZE == 4
    xchk_flags = -1;
#endif
    // Horizontal sums
    for (size_t start = 0; start < S*S; start += S) {
        int local_sum = 0;
        for (size_t el = start; el < start+S; ++el) {
            local_sum += magic[el];
            // Cells hold positive values, so an exceeded sum can never recover.
            if (local_sum > correct_sum) { return static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return static_cast<int>(start+(S-1)); }
    }
    // Vertical sums
    for (size_t start = 0; start < S; ++start) {
        int local_sum = 0;
        for (size_t el = start; el < S*S; el += S) {
            local_sum += magic[el];
            if (local_sum > correct_sum) { return static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return static_cast<int>(start+S*(S-1)); }
    }
    // Anti-diagonal sum: cells S-1, 2(S-1), ..., S(S-1).
    // Iterate by cell count rather than by an index bound: for S == 1 the step
    // S-1 is zero, and a bound such as `el < S*S-1` would skip the only cell,
    // making the trivial 1x1 square fail this check.
    {
        int local_sum = 0;
        for (size_t k = 1; k <= S; ++k) {
            const size_t el = k*(S-1);
            local_sum += magic[el];
            if (local_sum > correct_sum) { return semimagic ? -1 : static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return semimagic ? -1 : static_cast<int>(S*(S-1)); }
    }
    // Main diagonal sum: cells 0, S+1, ..., S*S-1.
    {
        int local_sum = 0;
        for (size_t el = 0; el < S*S; el += S+1) {
            local_sum += magic[el];
            if (local_sum > correct_sum) { return semimagic ? -1 : static_cast<int>(el); }
        }
        if (local_sum != correct_sum) { return semimagic ? -1 : static_cast<int>(S*S-1); }
    }
    // Extra checks (4x4 only). A bit of xchk_flags is set when the
    // corresponding 4-cell group does NOT sum to correct_sum:
    //   bit 0       - the four corners;
    //   bits 1..9   - contiguous 2x2 sub-squares with top-left cell at
    //                 (column i, row j), bit = 1 + i*(S-1) + j;
    //   bits 10..13 - "sparse" 2x2 groups (stride 2) with top-left cell at
    //                 (column i, row j), i, j in {0, 1}, bit = 10 + 2*i + j;
    //   bit 14      - middle two cells of the top and bottom rows;
    //   bit 15      - middle two cells of the left and right columns.
#if SQUARE_SIZE == 4
    xchk_flags = 0;
    int bit = 1;
    if (magic[0]+magic[S-1]+magic[S*(S-1)]+magic[S*S-1] != correct_sum) { xchk_flags |= bit; }
    for (size_t i = 0; i < S-1; ++i) {
        for (size_t j = 0; j < S-1; ++j) {
            bit <<= 1;
            if (magic[i+j*S]+magic[i+1+j*S]+magic[i+(j+1)*S]+magic[i+1+(j+1)*S] != correct_sum) { xchk_flags |= bit; }
        }
    }
    for (size_t i = 0; i < 2; ++i) {
        for (size_t j = 0; j < 2; ++j) {
            bit <<= 1;
            if (magic[i+j*S]+magic[i+2+j*S]+magic[i+(j+2)*S]+magic[i+2+(j+2)*S] != correct_sum) { xchk_flags |= bit; }
        }
    }
    bit <<= 1;
    if (magic[1]+magic[2]+magic[S*(S-1)+1]+magic[S*(S-1)+2] != correct_sum) { xchk_flags |= bit; }
    bit <<= 1;
    if (magic[S]+magic[S*2]+magic[S*2-1]+magic[S*3-1] != correct_sum) { xchk_flags |= bit; }
    // Per-bit pass counters for bits 0..13 (bits 14/15 have dedicated counters below).
    for (size_t x = 0; x < (S-1)*(S-1)+5; ++x) {
        if ((xchk_flags & 1<<x) == 0) { ++count_xchk[x]; }
    }
    if ((xchk_flags & ((1<<1)+(1<<3)+(1<<7)+(1<<9))) == 0) { ++count_xchk1379; }
    if ((xchk_flags & ((1<<2)+(1<<4)+(1<<6)+(1<<8))) == 0) { ++count_xchk2468; }
    if ((xchk_flags & 0x3FE) == 0) { ++count_xchk_allminis; }
    if ((xchk_flags & ((1<<10)+(1<<11)+(1<<12)+(1<<13))) == 0) { ++count_xchk10_13; }
    if ((xchk_flags & 0x3FFE) == 0) { ++count_xchk_allminis10_13; }
    if ((xchk_flags & (1<<14)) == 0) { ++count_xchkTB; }
    if ((xchk_flags & (1<<15)) == 0) { ++count_xchkLR; }
    if ((xchk_flags & ((1<<14)+(1<<15))) == 0) { ++count_xchkTBLR; }
    if ((xchk_flags & ~((1<<10)+(1<<11)+(1<<12)+(1<<13))) == 0) { ++count_xchk_allWS; }
    if (xchk_flags == 0) { ++count_xchk_all; }
#endif
    return 0;
}
// Moves `magic` to the lexicographically next arrangement whose prefix
// magic[0..n] differs from the current one. Returns false once the last
// arrangement has been passed (std::next_permutation wrapped around).
template<size_t S>
bool MagicSquareFinder<S>::next_permutation_from(size_t n)
{
    constexpr size_t last = S*S - 1;
    if (n < last) {
        // Putting the tail in descending order makes the current arrangement
        // the last one with this prefix, so the next permutation must change
        // some cell at index <= n.
        auto tail = magic.begin() + (n + 1);
        std::sort(tail, magic.end(), [](int a, int b) { return a > b; });
    }
    return std::next_permutation(magic.begin(), magic.end());
}
// PrintMagicSquares class /////////////////////////////////////////////////////
// Prints the square just found (header line, S rows of S right-aligned values,
// blank line), subject to the compile-time filters NO_PRINT_SEMIMAGIC and
// PRINT_FLAG_MASK/PRINT_FLAG_VALUE. Returns false to stop the search once
// PRINT_LIMIT squares have been printed, true to continue.
template<size_t S>
bool PrintMagicSquares<S>::found(bool semimagic)
{
#ifdef PRINT_LIMIT
#if PRINT_LIMIT <= 0
    // A non-positive limit stops the search at the first square found.
    return false;
#endif
    // find() resets both counters and increments one of them before calling
    // this hook, so a total of 1 marks the first square of a new run.
    if (count_magic + count_semimagic == 1) { print_count = 0; }
#endif
#ifdef NO_PRINT_SEMIMAGIC
    if (semimagic) { return true; }  // not printed, search goes on
#endif
#ifdef PRINT_FLAG_MASK
    // Print only squares whose extra-check flags match the selected pattern.
    const bool flags_selected = (xchk_flags & (PRINT_FLAG_MASK)) == (PRINT_FLAG_VALUE);
    if (!flags_selected) { return true; }
#endif
    cout << (semimagic ? "Semimagic square #" : "Magic square #")
         << (semimagic ? count_semimagic : count_magic) << '\n';
    for (size_t cell = 0; cell < magic.size(); ++cell) {
        cout << std::setw(3) << magic[cell];
        if ((cell + 1) % S == 0) { cout << "\n"; }  // end of a row
    }
    cout << endl;
#ifdef PRINT_LIMIT
    ++print_count;
    return print_count < PRINT_LIMIT;
#else
    return true;
#endif
}
// Main function ///////////////////////////////////////////////////////////////
int main()
{
#ifdef PRINT_SQUARES
PrintMagicSquares<SQUARE_SIZE> ms;
#else
MagicSquareFinder<SQUARE_SIZE> ms;
#endif
ms.find(true);
cout << "Total magic squares = " << ms.count_magic << ", semimagic = " << ms.count_semimagic << endl;
cout << "Number of iterations = " << ms.iterations << endl;
#if !defined(NO_PRINT_STAT) && SQUARE_SIZE == 4
cout << "Extra checks 2x2:" << endl;
cout << "- corners: " << ms.count_xchk[0] << endl;
for (size_t x = 1; x < 10; ++x) {
cout << "- mini-square #" << x << ": " << ms.count_xchk[x] << endl;
}
cout << "- mini-squares #1+3+7+9: " << ms.count_xchk1379 << endl;
cout << "- mini-squares #2+4+6+8: " << ms.count_xchk2468 << endl;
cout << "- all mini-squares: " << ms.count_xchk_allminis << endl;
for (size_t x = 10; x < 14; ++x) {
cout << "- sparsed mini-square #" << x-9 << ": " << ms.count_xchk[x] << endl;
}
cout << "- all sparsed mini-squares: " << ms.count_xchk10_13 << endl;
cout << "- all mini-squares including sparsed: " << ms.count_xchk_allminis10_13 << endl;
cout << "- top2 + bottom2: " << ms.count_xchkTB << endl;
cout << "- left2 + right2: " << ms.count_xchkLR << endl;
cout << "- top2+bottom2+left2+right2: " << ms.count_xchkTBLR << endl;
cout << "- all extra checks together without sparsed mini-squares: " << ms.count_xchk_allWS << endl;
cout << "- all extra checks together: " << ms.count_xchk_all << endl;
#endif
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment