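/* Link with -z now (e.g. gcc -Wl,-z,now) so every symbol is bound at startup.
   Presumably required because main() remaps the dynamic loader over the
   running copy, after which lazy binding cannot be relied on. */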
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <fcntl.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <elf.h>
#include <libgen.h>
#include <signal.h>
#include <syscall.h>
#include <termios.h>

/* Address at which the dynamic loader is mapped in this process. It is zero
   until main() stores the result of ld_addr() here; run() reads the loader's
   ELF header (entry point) from this address and passes it as AT_BASE. */
Elf64_Addr ldbase = 0;

/* Forward declaration: search_section() is defined further down the file. */
Elf64_Addr search_section(int fd, char* section);

/*
 * Map the PT_LOAD segments of the ELF64 file at `path` into this process.
 *
 * Each segment lands at p_vaddr + rebase with MAP_FIXED, so anything already
 * mapped there is replaced. For every PT_LOAD segment:
 *   - the file-backed bytes (p_filesz) are mapped privately from the file;
 *   - the rest of the last file-backed page is zeroed. ELF requires the bytes
 *     between p_filesz and p_memsz to read as zero, but that page otherwise
 *     still shows whatever follows the segment in the file;
 *   - any p_memsz beyond the file-backed pages is mapped as anonymous memory.
 * Page protections follow p_flags. A 4 KiB page size is assumed.
 *
 * On any failure (open, malformed header, mmap/mprotect) an error is printed
 * and the process exits with EXIT_FAILURE.
 */
void load(char* path, Elf64_Addr rebase)
{
    const Elf64_Addr page_mask = ~(Elf64_Addr) 0xfff;
    Elf64_Ehdr ehdr;
    Elf64_Phdr* phdr = NULL;
    size_t phsize;
    int fd;

    if((fd = open(path, O_RDONLY)) < 0)
    {
        perror("Error in open()");
        exit(EXIT_FAILURE);
    }

    if(pread(fd, &ehdr, sizeof(ehdr), 0) != (ssize_t) sizeof(ehdr) ||
       memcmp(ehdr.e_ident, ELFMAG, SELFMAG) != 0 ||
       ehdr.e_ident[EI_CLASS] != ELFCLASS64 ||
       (ehdr.e_phnum != 0 && ehdr.e_phentsize != sizeof(Elf64_Phdr)))
    {
        fprintf(stderr, "%s: not an ELF64 file\n", path);
        exit(EXIT_FAILURE);
    }

    phsize = sizeof(*phdr) * ehdr.e_phnum;
    if(phsize != 0)
    {
        phdr = malloc(phsize);
        if(phdr == NULL ||
           pread(fd, phdr, phsize, (off_t) ehdr.e_phoff) != (ssize_t) phsize)
        {
            fprintf(stderr, "%s: cannot read program headers\n", path);
            exit(EXIT_FAILURE);
        }
    }

    for(uint16_t i = 0; i < ehdr.e_phnum; ++i)
    {
        if(phdr[i].p_type != PT_LOAD) continue;

        int prot = ((phdr[i].p_flags & PF_R) ? PROT_READ  : 0) |
                   ((phdr[i].p_flags & PF_W) ? PROT_WRITE : 0) |
                   ((phdr[i].p_flags & PF_X) ? PROT_EXEC  : 0);

        /* mmap() wants page-aligned addresses and file offsets: start at the
           page holding p_vaddr and widen the sizes/offset by the same slack. */
        Elf64_Addr page       = phdr[i].p_vaddr & page_mask;
        Elf64_Addr slack      = phdr[i].p_vaddr - page;
        size_t     filesz     = phdr[i].p_filesz + slack;
        size_t     memsz      = phdr[i].p_memsz  + slack;
        off_t      offset     = (off_t) (phdr[i].p_offset - slack);
        size_t     file_pages = (filesz + 0xfff) & page_mask;
        char*      start      = (char*) (rebase + page);

        if(filesz != 0 &&
           mmap(start, filesz, prot, MAP_PRIVATE | MAP_FIXED, fd, offset) == MAP_FAILED)
        {
            perror("Error in mmap()");
            exit(EXIT_FAILURE);
        }

        /* Zero from the end of the file data to the end of its last page,
           temporarily making the pages writable if the segment is not. */
        if(memsz > filesz && file_pages > filesz)
        {
            if(!(prot & PROT_WRITE) && mprotect(start, file_pages, prot | PROT_WRITE) != 0)
            {
                perror("Error in mprotect()");
                exit(EXIT_FAILURE);
            }
            memset(start + filesz, 0, file_pages - filesz);
            if(!(prot & PROT_WRITE) && mprotect(start, file_pages, prot) != 0)
            {
                perror("Error in mprotect()");
                exit(EXIT_FAILURE);
            }
        }

        /* Whatever p_memsz needs beyond the file-backed pages is fresh,
           zero-filled anonymous memory. */
        if(memsz > file_pages &&
           mmap(start + file_pages, memsz - file_pages, prot,
                MAP_PRIVATE | MAP_FIXED | MAP_ANON, -1, 0) == MAP_FAILED)
        {
            perror("Error in mmap()");
            exit(EXIT_FAILURE);
        }
    }

    free(phdr);
    close(fd);
}

/*
 * Look up a section by name in the ELF64 file open on fd.
 *
 * Returns the section header's sh_addr (its virtual address, not its file
 * offset), or 0 if the section is absent or the headers cannot be read.
 * Only pread() is used, so fd's file offset is left unchanged.
 */
Elf64_Addr search_section(int fd, char* section)
{
    Elf64_Ehdr ehdr;
    Elf64_Shdr* shdr = NULL;
    const Elf64_Shdr* strsec;
    char* shstrtab = NULL;
    size_t table_size;
    Elf64_Addr addr = 0;

    if(pread(fd, &ehdr, sizeof(ehdr), 0) != (ssize_t) sizeof(ehdr) ||
       memcmp(ehdr.e_ident, ELFMAG, SELFMAG) != 0 ||
       ehdr.e_shnum == 0 || ehdr.e_shentsize != sizeof(Elf64_Shdr) ||
       ehdr.e_shstrndx >= ehdr.e_shnum)
        return 0;

    table_size = sizeof(*shdr) * ehdr.e_shnum;
    shdr = malloc(table_size);
    if(shdr == NULL ||
       pread(fd, shdr, table_size, (off_t) ehdr.e_shoff) != (ssize_t) table_size)
        goto out;

    strsec = &shdr[ehdr.e_shstrndx];
    /* Reject sizes where the "+ 1" for the terminator would wrap around. */
    if(strsec->sh_size == 0 || strsec->sh_size >= (size_t) -1)
        goto out;
    shstrtab = malloc(strsec->sh_size + 1);
    if(shstrtab == NULL ||
       pread(fd, shstrtab, strsec->sh_size, (off_t) strsec->sh_offset) != (ssize_t) strsec->sh_size)
        goto out;
    /* Terminate the table ourselves so a corrupt file cannot make strcmp overrun. */
    shstrtab[strsec->sh_size] = '\0';

    for(uint16_t i = 0; i < ehdr.e_shnum; ++i)
    {
        if(shdr[i].sh_name < strsec->sh_size &&
           strcmp(&shstrtab[shdr[i].sh_name], section) == 0)
        {
            addr = shdr[i].sh_addr; /* copy out before the table is freed */
            break;
        }
    }

out:
    free(shstrtab);
    free(shdr);
    return addr;
}

/*
 * Locate the dynamic loader mapped into this process.
 *
 * Scans /proc/self/maps and returns the start address of the first mapping
 * whose file name (basename) begins with "ld", e.g. ld-linux-x86-64.so.2.
 * /proc/self/maps lists mappings by increasing address, so this is the
 * loader's lowest mapping. NOTE(review): that mapping is assumed to be the
 * loader's load base.
 *
 * Returns NULL if the maps file cannot be opened or no such mapping exists.
 */
void* ld_addr()
{
    FILE* f = fopen("/proc/self/maps", "rb");
    char buf[1024];
    void* found = NULL;

    if(f == NULL)
        return NULL;

    while(fgets(buf, sizeof buf, f))
    {
        /* Anonymous mappings ([heap], [stack], ...) contain no '/'. */
        char* path = strchr(buf, '/');
        unsigned long start;

        if(path == NULL || strncmp(basename(path), "ld", 2) != 0)
            continue;
        /* Each line starts with "start-end" in hex. */
        if(sscanf(buf, "%lx", &start) != 1)
            continue;
        found = (void*) start;
        break;
    }
    fclose(f);
    return found;
}

/*
 * Resolve a command name to a program path, following execvp()'s rules.
 *
 * A name containing '/' is used as given: the caller's own pointer is
 * returned. Otherwise each directory listed in $PATH is tried in order.
 * "/bin:/usr/bin" is used if PATH is unset, and an empty entry means the
 * current directory. The first candidate that is an executable regular file
 * is returned in a newly malloc()ed string, which the caller must free().
 *
 * Returns NULL for an empty name, when nothing matches, or on allocation
 * failure.
 */
char* search_path(char* cmd)
{
    const char* env;
    char* list;
    char* next;
    char* found = NULL;
    size_t cmdlen;

    if(strchr(cmd, '/') != NULL)
        return cmd;
    if(*cmd == '\0')
        return NULL;

    env = getenv("PATH");
    list = strdup(env != NULL ? env : "/bin:/usr/bin");
    if(list == NULL)
        return NULL;

    cmdlen = strlen(cmd);
    next = list;
    while(next != NULL && found == NULL)
    {
        const char* dir = next;
        size_t dirlen;
        char* candidate;
        struct stat st;

        /* Cut the current entry out of the list in place. */
        next = strchr(next, ':');
        if(next != NULL)
            *next++ = '\0';
        if(*dir == '\0')
            dir = ".";

        dirlen = strlen(dir);
        candidate = malloc(dirlen + 1 + cmdlen + 1);
        if(candidate == NULL)
            break;
        memcpy(candidate, dir, dirlen);
        candidate[dirlen] = '/';
        memcpy(candidate + dirlen + 1, cmd, cmdlen + 1);

        if(fstatat(AT_FDCWD, candidate, &st, 0) == 0 && S_ISREG(st.st_mode) &&
           access(candidate, X_OK) == 0)
            found = candidate;
        else
            free(candidate);
    }

    free(list);
    return found;
}

/* Fill buf with len bytes from /dev/urandom; exits the process on failure. */
static void run_random_bytes(unsigned char* buf, size_t len)
{
    size_t got = 0;
    int fd = open("/dev/urandom", O_RDONLY);

    while(fd >= 0 && got < len)
    {
        ssize_t n = read(fd, buf + got, len - got);
        if(n <= 0)
            break;
        got += (size_t) n;
    }
    if(fd >= 0)
        close(fd);
    if(got < len)
    {
        perror("Error reading /dev/urandom");
        _exit(EXIT_FAILURE);
    }
}

/*
 * Child side of a pipeline stage: load the program, build a fresh initial
 * process stack and jump to the dynamic loader's entry point. Never returns.
 *
 * argv[0] must be the program's path. The program is loaded at base 0x400000
 * and its ELF header is read back from that address. This assumes the first
 * PT_LOAD maps file offset 0 at p_vaddr 0, i.e. a position-independent
 * executable. NOTE(review): non-PIE (ET_EXEC) programs are not supported.
 * The global ldbase must hold the address at which the loader is mapped.
 *
 * The new stack holds argc, argv, a copy of this process's environment and an
 * auxiliary vector whose AT_RANDOM points at 16 fresh random bytes.
 * readfd/writefd, when >= 0, replace stdin/stdout and are then closed.
 */
__attribute__((noreturn))
void run(int argc, char** argv, int readfd, int writefd)
{
    enum
    {
        STACK_SIZE   = 8 << 20, /* this mapping cannot grow, so reserve up front */
        RANDOM_BYTES = 16,      /* AT_RANDOM must point at 16 random bytes */
    };
    extern char** environ;
    const Elf64_Addr base = 0x400000;
    const Elf64_Ehdr* exe;
    uint64_t ldentry, entry;
    int envc = 0;
    size_t stack_size;
    char* stack;
    void** sp;
    unsigned char* rnd;

    load(argv[0], base);

    exe     = (const Elf64_Ehdr*) base;
    entry   = exe->e_entry + base;
    ldentry = ((const Elf64_Ehdr*) ldbase)->e_entry + ldbase;

    while(environ != NULL && environ[envc] != NULL)
        ++envc;

    /* Room for the pointer vectors (32 words covers the fixed entries) on top
       of the usable stack, rounded to whole pages so the top stays aligned. */
    stack_size = STACK_SIZE + ((size_t) argc + (size_t) envc + 32) * sizeof(void*);
    stack_size = (stack_size + 0xfff) & ~(size_t) 0xfff;
    stack = mmap(NULL, stack_size, PROT_READ | PROT_WRITE,
                 MAP_ANONYMOUS | MAP_PRIVATE | MAP_STACK, -1, 0);
    if(stack == MAP_FAILED)
    {
        perror("Error in mmap()");
        _exit(EXIT_FAILURE);
    }

    sp = (void**) (stack + stack_size);
    *--sp = NULL; // End of stack

    sp -= RANDOM_BYTES / sizeof(*sp);
    rnd = (unsigned char*) sp;
    run_random_bytes(rnd, RANDOM_BYTES);

    /* 22 fixed words plus argc + envc pointers are pushed in total. Pad to an
       even count so the final sp (pointing at argc) is 16-byte aligned. */
    if((argc + envc) & 1)
        *--sp = NULL;

    uint64_t auxv[] =
    {
        AT_PAGESZ, 0x1000,
        AT_RANDOM, (uint64_t) rnd,
        AT_ENTRY,  entry,
        AT_BASE,   ldbase,
        AT_PHNUM,  exe->e_phnum,
        AT_PHENT,  exe->e_phentsize,
        AT_PHDR,   exe->e_phoff + base,
        AT_NULL,   0,
    };
    sp -= sizeof(auxv) / sizeof(*sp);
    memcpy(sp, auxv, sizeof(auxv));

    *--sp = NULL; // End of envp
    sp -= envc;
    if(envc > 0)
        memcpy(sp, environ, (size_t) envc * sizeof(*sp));
    *--sp = NULL; // End of argv
    sp -= argc;
    memcpy(sp, argv, (size_t) argc * sizeof(*sp));
    *(size_t*) --sp = (size_t) argc;

    if(readfd >= 0)
    {
        dup2(readfd, STDIN_FILENO);
        if(readfd != STDIN_FILENO)
            close(readfd);
    }
    if(writefd >= 0)
    {
        dup2(writefd, STDOUT_FILENO);
        if(writefd != STDOUT_FILENO)
            close(writefd);
    }

    /* "memory": the loader reads the stack built above, so those stores must
       not be treated as dead. */
    #if defined(__x86_64__)
    asm volatile("mov %0, %%rsp;"
                 "jmp *%1;"
                 : : "r"(sp), "r"(ldentry) : "memory");
    #elif defined(__aarch64__)
    asm volatile("mov sp, %0;"
                 "br  %1;"
                 : : "r"(sp), "r"(ldentry) : "x0", "memory");
    #else
    #error "run(): unsupported architecture"
    #endif
    __builtin_unreachable();
}

/*
 * Run one command line as a pipeline and wait for every stage to finish.
 *
 * cmd is split in place (it is modified): stages on '|', words on spaces and
 * tabs. Blank stages are skipped. Each stage's first word is resolved with
 * search_path(); a stage that cannot be resolved is reported and left out of
 * the pipeline.
 *
 * Every remaining stage is started with a raw clone(SIGCHLD) whose child
 * calls run(). Its stdin comes from the previous stage's pipe and its stdout
 * goes to a new pipe read by the next stage.
 */
void runline(char* cmd)
{
    struct stage
    {
        int    argc;
        char** argv;   /* argv[0] replaced by the resolved program path */
        char*  owned;  /* that path if search_path() allocated it, else NULL */
    };
    struct stage* stages = NULL;
    int count = 0, spawned = 0;
    int readfd = -1;
    char* stage_save = NULL;

    for(char* seg = strtok_r(cmd, "|", &stage_save); seg != NULL;
        seg = strtok_r(NULL, "|", &stage_save))
    {
        char** args = NULL;
        int nargs = 0;
        char* word_save = NULL;

        for(char* word = strtok_r(seg, " \t", &word_save); word != NULL;
            word = strtok_r(NULL, " \t", &word_save))
        {
            char** grown = realloc(args, (nargs + 1) * sizeof(*grown));
            if(grown == NULL)
            {
                perror("Error in realloc()");
                free(args);
                goto cleanup;
            }
            args = grown;
            args[nargs++] = word;
        }
        if(nargs == 0)
            continue; /* e.g. the middle of "a | | b" */

        char* filepath = search_path(args[0]);
        if(filepath == NULL)
        {
            perror("Error in fstatat");
            free(args);
            continue;
        }

        struct stage* more = realloc(stages, (count + 1) * sizeof(*more));
        if(more == NULL)
        {
            perror("Error in realloc()");
            if(filepath != args[0])
                free(filepath);
            free(args);
            goto cleanup;
        }
        stages = more;
        stages[count].argc  = nargs;
        stages[count].owned = (filepath != args[0]) ? filepath : NULL;
        stages[count].argv  = args;
        args[0] = filepath;
        count++;
    }

    for(int i = 0; i < count; ++i)
    {
        int writefd = -1, nextfd = -1;

        if(i < count - 1)
        {
            int pipefds[2];
            if(pipe(pipefds) < 0)
            {
                perror("Error in pipe()");
                break;
            }
            nextfd  = pipefds[0];
            writefd = pipefds[1];
        }

        long pid = syscall(SYS_clone, SIGCHLD, NULL, NULL, NULL, NULL);
        if(pid == 0)
        {
            /* The child only writes into this pipe. Keeping its read end open
               would stop the child getting SIGPIPE once the next stage exits. */
            if(nextfd >= 0)
                close(nextfd);
            run(stages[i].argc, stages[i].argv, readfd, writefd);
        }

        /* The parent keeps only nextfd, to become the next stage's stdin. */
        if(writefd >= 0)
            close(writefd);
        if(readfd >= 0)
            close(readfd);
        readfd = nextfd;

        if(pid < 0)
        {
            perror("Error in clone()");
            break;
        }
        spawned++;
    }
    if(readfd >= 0)
        close(readfd);

    for(int i = 0; i < spawned; ++i)
        wait(NULL);

cleanup:
    for(int i = 0; i < count; ++i)
    {
        free(stages[i].owned);
        free(stages[i].argv);
    }
    free(stages);
}

/*
 * Copy the program-interpreter path (PT_INTERP) of the ELF64 file open on fd
 * into out, which holds outlen bytes. Returns 0 on success, or -1 if the file
 * has no PT_INTERP, cannot be read, or the path does not fit.
 */
static int read_interp(int fd, char* out, size_t outlen)
{
    Elf64_Ehdr ehdr;

    if(pread(fd, &ehdr, sizeof(ehdr), 0) != (ssize_t) sizeof(ehdr) ||
       ehdr.e_phentsize != sizeof(Elf64_Phdr))
        return -1;

    for(uint16_t i = 0; i < ehdr.e_phnum; ++i)
    {
        Elf64_Phdr phdr;
        off_t at = (off_t) (ehdr.e_phoff + (Elf64_Off) i * sizeof(phdr));

        if(pread(fd, &phdr, sizeof(phdr), at) != (ssize_t) sizeof(phdr))
            return -1;
        if(phdr.p_type != PT_INTERP)
            continue;
        if(phdr.p_filesz == 0 || phdr.p_filesz > outlen ||
           pread(fd, out, phdr.p_filesz, (off_t) phdr.p_offset) != (ssize_t) phdr.p_filesz)
            return -1;
        out[phdr.p_filesz - 1] = '\0'; /* p_filesz counts the path's NUL */
        return 0;
    }
    return -1;
}

/*
 * Shell entry point.
 *
 * Records where the dynamic loader is mapped and remaps a fresh copy of it
 * (the interpreter named in this executable's PT_INTERP) at that address.
 * It then reads commands from stdin, one per line, and runs each through
 * runline(). A "> " prompt is shown when stdin is a terminal. At end of
 * input it prints "\nexit" and leaves with _exit(0).
 */
int main()
{
    char interp[128];
    char buf[1024];
    int self, terminal;

    /* Children jump into the loader at this address (see run()). */
    ldbase = (Elf64_Addr) ld_addr();
    if(ldbase == 0)
    {
        fputs("Cannot find the dynamic loader in /proc/self/maps\n", stderr);
        return EXIT_FAILURE;
    }

    if((self = open("/proc/self/exe", O_RDONLY)) < 0)
    {
        perror("Error in open()");
        return EXIT_FAILURE;
    }
    if(read_interp(self, interp, sizeof(interp)) != 0)
    {
        fputs("Cannot read the program interpreter of /proc/self/exe\n", stderr);
        close(self);
        return EXIT_FAILURE;
    }
    close(self);

    /* Remap a pristine copy of the loader over the running one. NOTE(review):
       this also resets the loader's data in this process, which is presumably
       why the shell must be linked with -z now and leaves via _exit(). */
    load(interp, ldbase);

    terminal = isatty(STDIN_FILENO);
    while(1)
    {
        if(terminal)
        {
            printf("> ");
            fflush(stdout);
        }

        /* NULL means EOF or a read error. A last line without '\n' still
           comes back non-NULL and is run. */
        if(fgets(buf, sizeof buf, stdin) == NULL)
            break;

        /* A full buffer without '\n' may be a truncated line: run it only if
           the line really ends here, otherwise discard the rest of it. */
        size_t len = strlen(buf);
        if(len == sizeof buf - 1 && buf[len - 1] != '\n')
        {
            int c = getchar();
            if(c != '\n' && c != EOF)
            {
                while((c = getchar()) != '\n' && c != EOF)
                    ;
                fputs("Error: command line too long\n", stderr);
                continue;
            }
        }

        buf[strcspn(buf, "\n")] = '\0';
        if(buf[0] == '\0') continue;
        runline(buf);
    }
    puts("\nexit");
    fflush(stdout); /* _exit() does not flush stdio buffers */
    _exit(0);
}