#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <unistd.h>
#include <fcntl.h>

#include "cryptodf.h"
#include "keyutils.h"
#include "xyespwr.h"
#include "_version.h"

/* Size in bytes of a single table record: the first public_key_size bytes
   followed by yespower_hash_size bytes (do_fill stores the yespower hash
   right after the first public_key_size bytes of a record).
 */
enum {
    table_recsize = public_key_size + yespower_hash_size,
};

/* Print the program banner (name, version, compilation date, copyright)
   to the standard output, one call per banner line.
 */
static void show_version()
{
    fputs("feda-ct program: challenge table manipulation\n", stdout);
    fputs("vers. " FEDA_VERSION " (compiled " __DATE__ ")\n", stdout);
    fputs("Copyright (c) Andrey Vikt. Stolyarov, 2024\n", stdout);
}

/* Print the version banner followed by the usage summary that lists
   every recognized command; everything goes to the standard output.
 */
static void do_help()
{
    static const char usage_text[] =
        "\n"
        "Usage: feda-ct <command> <table_file> [...]\n"
        "The following commands are recognized:\n"
        "\n"
        "    scan      scan the given table for its fill rate\n"
        "    poscheck  check the given table whether records are positioned\n"
        "              properly (i.e., the first 8 bytes MOD count of recs\n"
        "              always matches the record's index)\n"
        "    cannibal  cannibalize another table to fill more records in\n"
        "              the given one; takes exactly two args, the first\n"
        "              for the table to be filled, and the second for the\n"
        "              table to be cannibalized; the cannibalized records\n"
        "              get zeroed in the second table\n"
        "    fill      fill the empty records with random challenges and\n"
        "              their respective hashes; this is WAY TOO SLOW so it\n"
        "              is only recommended for 99%-filled tables\n"
        "    help      show this help and exit\n";

    show_version();
    fputs(usage_text, stdout);
}

/* Map the whole table file into memory with a shared mapping.

   fname       path of the table file;
   need_write  non-zero to open the file read-write and to map it with
               PROT_READ|PROT_WRITE, zero for a read-only mapping;
   recs        on success receives the number of complete records, that
               is, the file size divided by table_recsize.

   Returns the address of the mapping, or NULL on failure (a diagnostic
   is printed to stderr in that case).  A file size that is not a multiple
   of the record size only triggers a warning; a file shorter than one
   record is refused, so on success *recs is always at least 1.  Files
   whose record count doesn't fit in an int, or whose size doesn't fit in
   a size_t, are refused as well instead of being silently truncated.
   The descriptor is closed before returning; the mapping stays valid.
 */
unsigned char *mmap_table(const char *fname, int need_write, int *recs)
{
    int fd, prot;
    long long size, count;
    size_t maplen;
    void *p;

    fd = open(fname, need_write ? O_RDWR : O_RDONLY);
    if(fd == -1) {
        perror(fname);
        return NULL;
    }

    size = lseek(fd, 0, SEEK_END);
    if(size == -1) {
        perror("lseek");
        close(fd);
        return NULL;
    }
    if(size % table_recsize != 0) {
        fprintf(stderr,
            "WARNING: the file size is not a multiple of the record size\n");
    }
    if(size < table_recsize) {
        fprintf(stderr, "FATAL: the table is empty, nothing to work with\n");
        close(fd);
        return NULL;
    }

    /* the callers keep the record count in an int and mmap takes a
       size_t length, so make sure neither conversion loses anything */
    count = size / table_recsize;
    maplen = (size_t)size;
    if(count > INT_MAX || (long long)maplen != size) {
        fprintf(stderr, "FATAL: the table is too large to be handled\n");
        close(fd);
        return NULL;
    }

    prot = need_write ? PROT_READ|PROT_WRITE : PROT_READ;
    p = mmap(NULL, maplen, prot, MAP_SHARED, fd, 0);
    if(p == MAP_FAILED) {
        perror("mmap");
        close(fd);
        return NULL;
    }

    close(fd);
    *recs = (int)count;
    return p;
}

/* The ``scan'' command: count the records of the table whose first
   public_key_size bytes are not all zeroes and print the totals together
   with the fill rate in percent.

   fname  path of the table file; NULL or an empty string is reported
          as a missing argument.

   Returns the exit code: 0 on success, 1 on any failure.
 */
static int do_scan(const char *fname)
{
    unsigned char *p;
    int records, i, cnt;

    if(!fname || !*fname) {
        fprintf(stderr, "Please tell me what to scan\n");
        return 1;
    }

    p = mmap_table(fname, 0, &records);
    if(!p)
        return 1;

    cnt = 0;
    for(i = 0; i < records; i++) {
        /* the offset is computed as size_t: with plain int arithmetic
           it overflows for tables of 2 GiB and larger */
        const unsigned char *start = p + (size_t)i * table_recsize;
        if(!all_zeroes(start, public_key_size))
            cnt++;
    }

    /* cnt * 100 doesn't fit in an int for large tables, so the rate is
       computed as long long; mmap_table never returns zero records,
       hence the division is safe */
    printf("total: %d     filled: %d     empty: %d    rate: %d%%\n",
           records, cnt, records - cnt,
           (int)((long long)cnt * 100 / records));

    return 0;
}

/* Assemble the first 8 bytes at data into a 64-bit value, data[0] being
   the least significant byte (little-endian order).
 */
static unsigned long long pack_index(const unsigned char *data)
{
    unsigned long long value = 0;
    const unsigned char *cur = data + 8;

    /* walk from the most significant byte down to the least one */
    while(cur != data) {
        cur--;
        value = (value << 8) | *cur;
    }
    return value;
}

/* Store the 64-bit value index into the 8 bytes at data, the least
   significant byte first (the inverse of pack_index).
 */
static void unpack_index(unsigned char *data, unsigned long long index)
{
    unsigned int shift;

    for(shift = 0; shift < 64; shift += 8)
        data[shift / 8] = (unsigned char)((index >> shift) & 0xFF);
}

/* The ``poscheck'' command: verify that every non-empty record sits at
   its proper position, i.e. that the first 8 bytes of the record, taken
   as a little-endian number modulo the record count, equal the record's
   index.  Records consisting of zero bytes only are skipped.  The check
   stops at the first misplaced record, which is reported to stderr.

   fname  path of the table file; NULL or an empty string is reported
          as a missing argument.

   Returns the exit code: 0 if no misplaced record was found, 1 otherwise
   (including the failures to open or map the file).
 */
static int do_poscheck(const char *fname)
{
    unsigned char *p;
    int records, i, ok;

    if(!fname || !*fname) {
        fprintf(stderr, "Please tell me what to check\n");
        return 1;
    }

    p = mmap_table(fname, 0, &records);
    if(!p)
        return 1;

    ok = 1;
    for(i = 0; i < records; i++) {
        unsigned long long index;
        /* the offset is computed as size_t: with plain int arithmetic
           it overflows for tables of 2 GiB and larger */
        const unsigned char *start = p + (size_t)i * table_recsize;
        if(all_zeroes(start, table_recsize))
            continue;
        index = pack_index(start);
        index %= records;
        if(index != (unsigned long long)i) {
            fprintf(stderr, "ERROR: recnum %d index value %llu\n",
                            i, index);
            ok = 0;
            break;
        }
    }
    return ok ? 0 : 1;
        /* we don't use NOT here because the returned value is not
           a boolean, although it looks much like one: it is actually
           the exit code.
         */
}


/* the code of the do_cannibal function was graciously donated by Y. Kotov */

/* The ``cannibal'' command: move records from the table dinner into the
   empty slots of the table cannibal.  Both files are mapped read-write
   and must contain the same number of records.  For every index where
   the cannibal's record is all zeroes and the dinner's record is not,
   the dinner's record is copied into the cannibal and then zeroed in
   the dinner.  Before a record is moved, its position is verified the
   same way poscheck does it; the first misplaced record is reported to
   stderr and stops the processing (the records moved so far stay moved).

   cannibal  path of the table to be filled;
   dinner    path of the table to take the records from.
   NULL or an empty string for either is reported as a missing argument.

   Returns the exit code: 0 on success, 1 on any failure.
 */
static int do_cannibal(const char *cannibal, const char *dinner)
{
    unsigned char *ptr_c, *ptr_d;
    int records, records_tmp, i, ok;

    if(!cannibal || !*cannibal) {
        fprintf(stderr, "Please tell me what to fill\n");
        return 1;
    }
    if(!dinner || !*dinner) {
        fprintf(stderr, "Please tell me what to cannibalize\n");
        return 1;
    }

    ptr_c = mmap_table(cannibal, 1, &records);
    if(!ptr_c)
        return 1;
    ptr_d = mmap_table(dinner, 1, &records_tmp);
    if(!ptr_d)
        return 1;

    if(records != records_tmp) {
        fprintf(stderr, "ERROR: records count don't match: %d != %d\n",
                        records, records_tmp);
        return 1;
    }

    ok = 1;
    for(i = 0; i < records; i++) {
        unsigned long long index;
        /* the offset is computed as size_t: with plain int arithmetic
           it overflows for tables of 2 GiB and larger */
        size_t offset = (size_t)i * table_recsize;
        unsigned char *start_c = ptr_c + offset;
        unsigned char *start_d = ptr_d + offset;
        if(!all_zeroes(start_c, table_recsize)) /* already exists */
            continue;
        if(all_zeroes(start_d, table_recsize))  /* nothing to eat */
            continue;
        index = pack_index(start_d);
        index %= records;
        if(index != (unsigned long long)i) {
            fprintf(stderr, "ERROR: recnum %d index value %llu\n",
                            i, index);
            ok = 0;
            break;
        }
        memcpy(start_c, start_d, table_recsize);
        memset(start_d, 0, table_recsize);
    }
    return ok ? 0 : 1;
        /* we don't use NOT here because the returned value is not
           a boolean, although it looks much like one: it is actually
           the exit code.
         */
}

/* The ``fill'' command: fill every empty (all-zero) record of the table
   with freshly generated data.  For each such record, public_key_size
   random bytes are obtained, their first 8 bytes (a little-endian
   number) are adjusted so that the number modulo the record count equals
   the record's index, the yespower hash of the result is taken, and both
   are stored into the record.

   fname  path of the table file; NULL or an empty string is reported
          as a missing argument.

   Returns the exit code: 0 on success, 1 on any failure (records filled
   before the failure remain filled).
 */
static int do_fill(const char *fname)
{
    unsigned char *p;
    int records, i;

    if(!fname || !*fname) {
        fprintf(stderr, "Please tell me what to fill\n");
        return 1;
    }

    p = mmap_table(fname, 1, &records);
    if(!p)
        return 1;

    for(i = 0; i < records; i++) {
        unsigned long long index;
        /* the offset is computed as size_t: with plain int arithmetic
           it overflows for tables of 2 GiB and larger */
        unsigned char *start = p + (size_t)i * table_recsize;
        unsigned char buf[public_key_size];
        unsigned char yespower_hash[yespower_hash_size];

        if(!all_zeroes(start, table_recsize))
            continue;

        if(!get_random(buf, sizeof(buf))) {
            fprintf(stderr, "ERROR: unable to get random\n");
            return 1;
        }

        /* Fix first bytes of random data to match record position:
           round the random number down to a multiple of the record
           count, then add the record's index. */
        index = pack_index(buf);
        index -= index % records;
        /* if the random number fell into the topmost, incomplete block
           of values, adding i would wrap around 2^64 and break the
           ``index MOD records == i'' property; step one block down */
        if(index > ~0ULL - (unsigned long long)i)
            index -= records;
        index += i;
        unpack_index(buf, index);

        if(!take_yespower_hash(yespower_hash, buf)) {
            fprintf(stderr, "ERROR: couldn't take yespower hash\n");
            return 1;
        }

        /* the first 8 bytes of the record are stored last; presumably
           this is done so that the positioning bytes never appear before
           the rest of the record is in place -- to be confirmed */
        memcpy(start + public_key_size, yespower_hash, yespower_hash_size);
        memcpy(start + 8, buf + 8, public_key_size - 8);
        memcpy(start, buf, 8);
    }
    return 0;
}

/* Entry point: dispatch on the command name given as the first argument.
   Without any arguments the version banner and a hint are printed and
   the exit code is 1.  The table-processing commands return whatever
   their do_* handlers return; an unknown command yields 1.
 */
int main(int argc, const char * const *argv)
{
    const char *cmd;

    if(argc < 2) {
        show_version();
        fputs("\nTry ``feda-ct help'' for help\n", stdout);
        return 1;
    }
    cmd = argv[1];

    if(strcmp(cmd, "help") == 0) {
        do_help();
        return 0;
    } else if(strcmp(cmd, "scan") == 0) {
        /* argv[2] may be the terminating NULL; the handler checks it */
        return do_scan(argv[2]);
    } else if(strcmp(cmd, "poscheck") == 0) {
        return do_poscheck(argv[2]);
    } else if(strcmp(cmd, "cannibal") == 0) {
        if(argc >= 4)
            return do_cannibal(argv[2], argv[3]);
        fputs("The ``cannibal'' command needs exactly two arguments\n",
              stderr);
        return 1;
    } else if(strcmp(cmd, "fill") == 0) {
        return do_fill(argv[2]);
    }

    fprintf(stderr, "command ``%s'' not recognized, try ``help''\n", cmd);
    return 1;
}
