Skip to content

Instantly share code, notes, and snippets.

#include <iostream>
#include <vector>
#include <fstream>
#include <array>
#include <algorithm>
#include <cstdlib>
#include <chrono>
using namespace std;
1r2r1k1/2qn1pp1/1npp1bbp/p3p3/P2PP1PP/B1P2N2/2BN1P2/1R1QR1K1 w - - 1 22
1r4k1/5p2/1P1p4/2nP2bp/1R6/3B4/5BP1/6K1 w - - 3 41
8/8/2p1N1kp/6p1/1r1BK3/8/7P/8 b - - 5 63
4r1n1/p2q1ppk/1p3n2/3p1PBp/2rPp2P/P4PPB/1P1RQ3/3R2K1 b - - 4 24
8/4k3/1N2p3/p4p2/3Pp3/1P2n1P1/P2bN3/6K1 b - - 3 42
r3q1k1/5p2/6p1/pR1p1pPp/P2Pr2P/6Q1/1PP5/1K1R4 b - - 4 33
2rr2k1/1b1q1pb1/1Q2p1p1/p3B2p/P1R2P2/1P1P2P1/5RBP/6K1 b - - 0 23
3r3k/4N1pp/8/p1n2p1P/8/2R3P1/6K1/8 b - - 1 47
1q1r2k1/5pp1/1p2p2p/2bnPn2/4NP2/1P4P1/1B2Q1BP/R6K b - - 6 34
b1rr4/3p1pk1/p2Pp1p1/P1p1N1pn/1pP5/1P3P1P/2BR2P1/4RK2 b - - 2 38
import torch
def permute_ft_output(nnue, permutation):
    """Reorder the leading output columns of ``nnue.input.weight.data``.

    ``permutation`` gives the new order for the first half of the leading
    ``l1_size`` columns, where ``l1_size`` is
    ``nnue.layer_stacks.l1.in_features``. The same order, offset by
    ``l1_size // 2``, is applied to the second half. Columns from ``l1_size``
    up to ``nnue.input.num_outputs`` keep their positions. The reindexed
    tensor is assigned back to ``nnue.input.weight.data``.

    Args:
        nnue: Object exposing ``layer_stacks.l1.in_features``,
            ``input.num_outputs`` and ``input.weight.data``.
        permutation: Iterable of column indices for the first half; its
            length must be exactly half of ``l1_size``. It is copied and
            never modified.

    Raises:
        AssertionError: If ``l1_size`` is not twice the permutation length.
    """
    # Work on a copy: extending the caller's list in place would double it,
    # so a second call with the same list would fail the size check below.
    first_half = list(permutation)

    l1_size = nnue.layer_stacks.l1.in_features
    assert l1_size == len(first_half) * 2, (
        f"permutation has {len(first_half)} entries, "
        f"but l1 in_features is {l1_size}"
    )

    # The second half reuses the same ordering, shifted by half of l1_size.
    half_offset = l1_size // 2
    second_half = [index + half_offset for index in first_half]

    # Identity mapping for any columns past the first l1_size.
    trailing = list(range(l1_size, nnue.input.num_outputs))
    ft_permutation = first_half + second_half + trailing

    # NOTE(review): only input.weight is reindexed here; whether the bias or
    # the weights of the layer consuming these outputs need the same
    # reordering cannot be determined from this block -- confirm.
    nnue.input.weight.data = nnue.input.weight.data[:, ft_permutation]