/* This file is Copyright (C) 2019 Jason Pepas. */
/* This file is released under the terms of the MIT License. */
/* See https://opensource.org/licenses/MIT */

#include "printer.h"
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/errno.h>


/* Prints the Symbol in symp into fp, as its raw (unquoted, unescaped) text.
Returns 0 on success, or errno on failure (EIO if the failed write left
errno unset, so that a failure is never reported as 0). errno is cleared
after being consumed, matching the rest of this file. */
static int print_symbol(Symbol* symp, FILE* fp) {
    if (fprintf(fp, "%s", symp->valuep) < 0) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Prints the CLong in lp into fp as a decimal integer.
Returns 0 on success, or errno on failure (EIO if the failed write left
errno unset, so that a failure is never reported as 0). errno is cleared
after being consumed, matching the rest of this file. */
static int print_clong(CLong* lp, FILE* fp) {
    if (fprintf(fp, "%li", lp->value) < 0) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Prints the CDouble in dp into fp using "%f" (fixed notation, six decimals).
Returns 0 on success, or errno on failure (EIO if the failed write left
errno unset, so that a failure is never reported as 0). errno is cleared
after being consumed, matching the rest of this file. */
static int print_cdouble(CDouble* dp, FILE* fp) {
    /* NOTE(review): "%f" does not round-trip doubles (e.g. 1e-9 prints as
    0.000000); consider "%g"/"%.17g" if printed forms must be re-readable. */
    if (fprintf(fp, "%f", dp->value) < 0) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Is ch an unescaped char, i.e. one that must be written as a backslash
escape sequence (\a \b \e \f \n \r \t \v \\ \") when printing a string?
The NUL char is never considered unescaped. */
bool is_unescaped(char ch) {
    /* strchr() also matches the literal's own terminator, so '\0' must be
    rejected explicitly. ESC is spelled '\033' because '\e' is a GNU extension. */
    if (ch == '\0') {
        return false;
    }
    return strchr("\a\b\033\f\n\r\t\v\\\"", (int)ch) != NULL;
}


/* Returns the escape letter for the unescaped char unesc, i.e. the char that
follows the backslash in its escape sequence.
For example, if unesc is a newline, 'n' is returned; '\\' and '"' map to
themselves. Asserts false if unesc is not a valid escape char; with NDEBUG
defined, unesc is returned unchanged instead. */
static char escape_char(char unesc) {
    switch (unesc) {
        case '\a': return 'a';
        case '\b': return 'b';
        case '\033': return 'e'; /* ESC; '\e' is a non-standard GNU escape. */
        case '\f': return 'f';
        case '\n': return 'n';
        case '\r': return 'r';
        case '\t': return 't';
        case '\v': return 'v';
        case '\\': return '\\';
        case '"': return '"';
        default:
            assert(false);
            /* Without this return, an NDEBUG build would fall off the end of
            a non-void function, which is undefined behavior. */
            return unesc;
    }
}


/* Escapes the NUL-terminated string srcp into a newly malloc'ed,
NUL-terminated string stored in *dstpp.
Each char for which is_unescaped() is true is written as a backslash
followed by escape_char() of it; all other chars are copied unchanged.
On success, returns 0 and the caller owns (and must free) *dstpp.
On failure, returns errno (ENOMEM if the allocator left errno unset), and
*dstpp is left untouched. */
static int escape_str(char* srcp, char** dstpp) {
    size_t src_len = strlen(srcp);
    /* Worst case every byte becomes two bytes, plus the NUL terminator.
    Guard the size computation against overflow. */
    if (src_len > (SIZE_MAX - 1) / 2) {
        return ENOMEM;
    }
    size_t dst_size = src_len * 2 + 1;

    char* dstp = malloc(dst_size);
    if (dstp == NULL) {
        int err = errno ? errno : ENOMEM;
        errno = 0;
        return err;
    }

    /* dst_len counts bytes actually written, which exceeds the number of
    source bytes whenever an escape sequence is emitted. */
    size_t dst_len = 0;
    for (size_t i = 0; i < src_len; i++) {
        char ch = srcp[i];
        if (is_unescaped(ch)) {
            dstp[dst_len++] = '\\';
            dstp[dst_len++] = escape_char(ch);
        } else {
            dstp[dst_len++] = ch;
        }
    }
    dstp[dst_len] = '\0';

    /* shrink-to-fit. If the shrink fails, the original (larger) buffer is
    still valid and intact, so keep it rather than failing the whole call;
    just discard the errno the failed realloc may have set. */
    char* newdstp = realloc(dstp, dst_len + 1);
    if (newdstp != NULL) {
        dstp = newdstp;
    } else {
        errno = 0;
    }

    *dstpp = dstp;
    return 0;
}


/* Prints the CString in csp into fp as a double-quoted literal, with special
chars backslash-escaped (see escape_str).
Returns 0 on success, or errno on failure: either the error from escaping,
or the write error (EIO if the failed write left errno unset).
The temporary escaped copy is always freed. */
static int print_cstring(CString* csp, FILE* fp) {
    char* esc;
    int err = escape_str(csp->valuep, &esc);
    if (err) {
        return err;
    }
    if (fprintf(fp, "\"%s\"", esc) < 0) {
        /* capture errno before free(), which may not preserve it. */
        err = errno ? errno : EIO;
        errno = 0;
    }
    free(esc);
    return err;
}


/* Prints the List in lp into fp as "(a b c)": each element is printed with
print_form(), separated by single spaces, and "()" is printed for an empty list.
Returns 0 on success; an element's print_form() error is passed through;
a failed write returns errno (EIO if errno was left unset). */
static int print_list(List* lp, FILE* fp) {
    int err;
    if (fputs("(", fp) == EOF) {
        goto write_failed;
    }
    for (List* i = lp; !is_list_empty(i); i = i->nextp) {
        /* separator goes before every element except the first. */
        if (i != lp && fputs(" ", fp) == EOF) {
            goto write_failed;
        }
        err = print_form(i->datap, fp);
        if (err) {
            return err;
        }
    }
    if (fputs(")", fp) == EOF) {
        goto write_failed;
    }
    return 0;

write_failed:
    /* never report a failed write as 0 (success), even if errno is unset. */
    err = errno ? errno : EIO;
    errno = 0;
    return err;
}


/* Prints the CBool in cbp into fp as "true" (for g_true) or "false" (for g_false).
Returns 0 on success, or errno on a failed write (EIO if errno was left unset).
Any other cbp trips an assertion; with NDEBUG defined it returns EINVAL. */
static int print_cbool(CBool* cbp, FILE* fp) {
    const char* text;
    if (cbp == g_true) {
        text = "true";
    } else if (cbp == g_false) {
        text = "false";
    } else {
        assert(false);
        /* Under NDEBUG the assert vanishes; bail out instead of going on
        with an uninitialized value. */
        return EINVAL;
    }
    if (fputs(text, fp) == EOF) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Prints "nil" into fp.
Returns 0 on success, or errno on a failed write (EIO if the failed write
left errno unset, so that a failure is never reported as 0). errno is
cleared after being consumed. */
static int print_nil(FILE* fp) {
    if (fputs("nil", fp) == EOF) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Prints the CFunc in cfp into fp as "<C function @ADDR>", where ADDR is the
"%p" rendering of cfp->f.
Returns 0 on success, or errno on failure (EIO if the failed write left
errno unset). errno is cleared after being consumed. */
static int print_cfunc(CFunc* cfp, FILE* fp) {
    /* "%p" requires a void* argument; passing any other pointer type through
    varargs is undefined. cfp->f is presumably a function pointer, whose
    conversion to void* is guaranteed by POSIX (not by ISO C). */
    if (fprintf(fp, "<C function @%p>", (void*)cfp->f) < 0) {
        int err = errno ? errno : EIO;
        errno = 0;
        return err;
    }
    return 0;
}


/* Prints the Form in formp into fp, dispatching on its runtime type to the
matching print_* helper above.
Returns 0 on success or an errno value on failure.
An unrecognized Form type trips an assertion; with NDEBUG defined it returns
EINVAL instead of falling off the end of the function. */
int print_form(Form* formp, FILE* fp) {
    if (is_symbol(formp)) {
        return print_symbol((Symbol*)formp, fp);
    } else if (is_clong(formp)) {
        return print_clong((CLong*)formp, fp);
    } else if (is_cdouble(formp)) {
        return print_cdouble((CDouble*)formp, fp);
    } else if (is_cstring(formp)) {
        return print_cstring((CString*)formp, fp);
    } else if (is_list(formp)) {
        return print_list((List*)formp, fp);
    } else if (is_cbool(formp)) {
        return print_cbool((CBool*)formp, fp);
    } else if (is_nil(formp)) {
        return print_nil(fp);
    } else if (is_cfunc(formp)) {
        return print_cfunc((CFunc*)formp, fp);
    }
    assert(false);
    /* Reaching the end of a non-void function without a return is undefined
    behavior when the caller uses the result; report a bad argument instead. */
    return EINVAL;
}