#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <limits.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netdb.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <errno.h>

#include <monocypher/monocypher.h>

#include "cryptodf.h"
#include "comcrypt.h"
#include "keyutils.h"
#include "nmobfusc.h"
#include "message.h"
#include "addrport.h"
#include "_version.h"


#define DEFAULT_HOST "nc1.feda.croco.net"
#define DEFAULT_PORT 0xFEDA
#define DEFAULT_TIMEOUT 5000

static void show_version()
{
    fputs(
        "natcheck program: checks your IP connectivity conditions\n"
        "FEDAnet utils vers. " FEDA_VERSION " (compiled " __DATE__ ")\n"
        "Copyright (c) Andrey Vikt. Stolyarov, 2025\n",
        stdout);
}

/* usage text printed by -h (after the version banner) */
static const char helptext[] =
    "\n"
    "Usage: natcheck <options> [hostname [port]]\n"
    "options are:\n"
    "   -t <timeout>  timeout in milliseconds (default 5000)\n"
    "   -b <port>     bind to this LOCAL (source) port; give \"-\" for\n"
    "                 the port to be chosen by the program but still\n"
    "                 explicitly bound with the bind(2) syscall\n"
    "   -v            be verbose; more flags for more messages\n"
    "   -q            be quiet\n"
    "   -h            display this text and exit\n"
    "\n"
    "Hostname specifies the FEDA server to use; the default is\n"
    DEFAULT_HOST ".  The default port is 65242 (0xFEDA)\n";

/* -h handler: the version banner followed by the usage text, on stdout */
static void do_help()
{
    show_version();
    printf("%s", helptext);
}

/* special (non-port) values of cmdline_opts.localport */
enum {
    cmdl_lp_choose_and_bind = -1,   /* "-b -": pick a port, bind(2) it */
    cmdl_lp_no_bind = -2            /* no -b given: don't bind at all */
};

/* everything the command line can set; see set_defaults for defaults */
struct cmdline_opts {
    int verbosity;          /* -v increments it, -q sets it to -1 */
    int help;               /* non-zero when -h was given */
    int localport;          /* a port, or one of the cmdl_lp_* values */
    const char *hostname;   /* the FEDA server to contact */
    int port;               /* the server's port */
    int timeout_msec;       /* receive timeout, milliseconds */
};

/* fills *opts with the values used when the command line doesn't say otherwise */
static void set_defaults(struct cmdline_opts *opts)
{
    opts->hostname = DEFAULT_HOST;
    opts->port = DEFAULT_PORT;
    opts->localport = cmdl_lp_no_bind;
    opts->timeout_msec = DEFAULT_TIMEOUT;
    opts->help = 0;
    opts->verbosity = 0;
}

/* complains on stderr that the option -<optchar> lacks its argument */
static void print_need_param(char optchar)
{
    static const char fmt[] = "-%c needs a parameter; try -h for help\n";

    fprintf(stderr, fmt, optchar);
}

/*
   parses the string s as an integer (strtol with base 0, i.e. decimal,
   0x-hex or 0-octal); the number must take the whole string and lie
   within [lo, hi]; on success stores it into *out and returns 1,
   otherwise returns 0 and leaves *out untouched
 */
static int parse_int_arg(const char *s, long lo, long hi, int *out)
{
    char *end;
    long val;

    errno = 0;
    val = strtol(s, &end, 0);
    if(end == s || *end != '\0' || errno == ERANGE || val < lo || val > hi)
        return 0;
    *out = (int)val;
    return 1;
}

/*
   parses the command line into *opts (which must already hold the
   defaults); options come first, then an optional hostname, optionally
   followed by a port; anything after the port is ignored;
   returns 1 on success, 0 (after printing a diagnostic) on error
 */
static int
parse_cmdline(int argc, const char *const *argv, struct cmdline_opts *opts)
{
    int idx = 1;
    while(idx < argc) {
        if(argv[idx][0] == '-') {
            switch(argv[idx][1]) {
            case 't':
                if(idx+1 >= argc || argv[idx+1][0] == '-') {
                    print_need_param('t');
                    return 0;
                }
                    /* reject empty, trailing garbage, negative, overflow */
                if(!parse_int_arg(argv[idx+1], 0, INT_MAX,
                                  &opts->timeout_msec))
                {
                    fprintf(stderr, "invalid timeout %s\n", argv[idx+1]);
                    return 0;
                }
                idx += 2;
                break;
            case 'b':
                    /* a lone "-" is the documented placeholder; any other
                       word starting with '-' is the next option, so the
                       parameter of -b is missing */
                if(idx+1 >= argc ||
                   (argv[idx+1][0] == '-' && argv[idx+1][1] != '\0'))
                {
                    print_need_param('b');
                    return 0;
                }
                if(argv[idx+1][0] == '-') {
                    opts->localport = cmdl_lp_choose_and_bind;
                } else if(!parse_int_arg(argv[idx+1], 1, 65535,
                                         &opts->localport))
                {
                    fprintf(stderr, "invalid local port %s\n", argv[idx+1]);
                    return 0;
                }
                idx += 2;
                break;
            case 'v':
                opts->verbosity++;
                idx++;
                break;
            case 'q':
                opts->verbosity = -1;
                idx++;
                break;
            case 'h':
                opts->help = 1;
                idx++;
                break;
            default:
                fprintf(stderr, "unknown option ``-%c''\n", argv[idx][1]);
                return 0;
            }
        } else {  /* no ``-''; must be the host/port */
            opts->hostname = argv[idx];
            if(idx + 1 < argc &&
               !parse_int_arg(argv[idx+1], 1, 65535, &opts->port))
            {
                fprintf(stderr, "invalid port %s\n", argv[idx+1]);
                return 0;
            }
            return 1;
        }
    }
    return 1;
}


#if 0  /* both versions work, and both (with glibc) ruin static linking */

/*
   getaddrinfo(3)-based variant of get_the_ip, compiled out by the #if 0
   above; resolves host and stores the first AF_INET address found into
   *saddr (network byte order, copied from sin_addr.s_addr);
   returns 1 on success, 0 on failure.
   NOTE(review): unlike the active variant, it reports errors with fprintf
   instead of message(), and prints nothing when the host has no IPv4
   address.
 */
static int get_the_ip(const char *host, unsigned int *saddr)
{
    struct addrinfo hints;
    struct addrinfo *ainfo, *p;
    int res;

    hints.ai_family = AF_INET;      /* IPv4 only */
    hints.ai_socktype = 0;
    hints.ai_protocol = 0;
    hints.ai_flags = 0;
    hints.ai_addrlen = 0;
    hints.ai_addr = NULL;
    hints.ai_canonname = NULL;
    hints.ai_next = NULL;

    res = getaddrinfo(host, NULL, &hints, &ainfo);
    if(res != 0) {
        fprintf(stderr, "%s: %s\n", host, gai_strerror(res));
        return 0;
    }

        /* take the first IPv4 entry of the result list */
    for(p = ainfo; p; p = p->ai_next) {
        if(p->ai_family == AF_INET) {
            *saddr = ((struct sockaddr_in*)(p->ai_addr))->sin_addr.s_addr;
            res = 1;
            goto quit;
        }
    }

    res = 0;
quit:
    freeaddrinfo(ainfo);    /* the list is freed on both outcomes */
    return res;
}

#else

/*
   resolves host with gethostbyname(3) and copies its first address into
   *saddr as-is (network byte order; the caller applies ntohl);
   returns 1 on success, 0 after printing a diagnostic on failure or
   when the host's address is not IPv4
 */
static int get_the_ip(const char *host, unsigned int *saddr)
{
    struct hostent *entry = gethostbyname(host);

    if(entry == NULL) {
        message(mlv_alert, "%s: %s\n", host, hstrerror(h_errno));
        return 0;
    }
    if(entry->h_addrtype == AF_INET && entry->h_length == 4) {
        memcpy(saddr, entry->h_addr_list[0], 4);
        return 1;
    }
    message(mlv_alert, "can't work with non-IPv4 addresses\n");
    return 0;
}

#endif

    /* this is a part of the protocol, thus can't change */
    /* (number of sibling-server slots in the server's test reply) */
enum { max_sibling_servers = 10 };

/* state of one natcheck run */
struct session {
    int sockfd;                         /* our UDP socket */
    struct addrport my_local;           /* local address (getsockname) */
    struct addrport my_visible;         /* how the server sees us */
    struct addrport serv;               /* the server we talk to */
    struct crypto_comm_context comctx;  /* our kex key pair etc. */
    struct addrport expected[max_sibling_servers];  /* sibling servers */
    int expected_count;                 /* valid entries in expected[] */
};

/*
   sends the len-byte datagram dgram to peer (whose addr/port are in host
   byte order) through the session's socket;
   returns 1 if the whole datagram was sent, 0 (diagnostics printed) if
   sendto(2) failed or sent fewer than len bytes
 */
static int sess_send_dgram(struct session *sess, const struct addrport *peer,
                           const unsigned char *dgram, int len)
{
    ssize_t res;
    struct sockaddr_in saddr;

    memset(&saddr, 0, sizeof(saddr));
    saddr.sin_family = AF_INET;
    saddr.sin_addr.s_addr = htonl(peer->addr);
    saddr.sin_port = htons(peer->port);

    res = sendto(sess->sockfd, dgram, len, 0,
                 (struct sockaddr*)&saddr, sizeof(saddr));
        /* check for -1 separately: compared against an unsigned size,
           -1 would turn into a huge value and look like success */
    if(res < 0) {
        message_perror(mlv_alert, NULL, "sendto");
        return 0;
    }
    if(res != len) {    /* errno isn't meaningful here, so no perror */
        message(mlv_alert, "sendto: short write\n");
        message(mlv_debug, "sendto returned %d, expected %d\n",
                           (int)res, len);
        return 0;
    }

    return 1;
}

/*
   sends the server a 128-byte fedaprot_stub dgram: the plain 2-byte
   header followed by random noise; returns what sess_send_dgram returns
 */
static int sess_send_dgram_stub(struct session *sess)
{
    unsigned char pkt[128];
    const int hdr = 2;

    set_plain_dgram_head(pkt, fedaprot_stub);
    fill_noise(pkt + hdr, sizeof(pkt) - hdr);
    return sess_send_dgram(sess, &sess->serv, pkt, sizeof(pkt));
}

/*
   sends the server a 128-byte fedaprot_test dgram: the plain 2-byte header,
   our kex public key, then random noise up to the end;
   returns what sess_send_dgram returns
 */
static int sess_send_dgram_test(struct session *sess)
{
    unsigned char pkt[128];
    unsigned char *pos = pkt;

    set_plain_dgram_head(pos, fedaprot_test);
    pos += 2;
    memcpy(pos, sess->comctx.kex_public, kex_public_size);
    pos += kex_public_size;
    fill_noise(pos, sizeof(pkt) - (pos - pkt));

    return sess_send_dgram(sess, &sess->serv, pkt, sizeof(pkt));
}

    /*
       waits up to timeout_msec milliseconds for a datagram on sockfd and
       reads it into buf; the sender (host byte order) goes to *ap;
       *len is inout: must be sizeof(buf) on call, will contain the actual
       length of the datagram on exit;
       returns -1 for error, 0 for timeout, > 0 for received dgram;
       even if the dgram is of zero length, still returns 1 (but *len
       will contain 0)
     */
static int sess_receive_dgram(int sockfd, int timeout_msec,
                              void *buf, int *len, struct addrport *ap)
{
    int res;
    struct sockaddr_in saddr;
    socklen_t addrlen;
    fd_set readset;
    struct timeval tmo;

    FD_ZERO(&readset);
    FD_SET(sockfd, &readset);
    tmo.tv_sec = timeout_msec / 1000;
    tmo.tv_usec = (timeout_msec % 1000) * 1000;
    res = select(sockfd+1, &readset, NULL, NULL, &tmo);
    if(res == 0)
        return 0;       /* timed out */
    if(res < 0) {
            /* a select(2) failure is an error, not a timeout */
        message_perror(mlv_alert, "TROUBLES", "select");
        return -1;
    }

    addrlen = sizeof(saddr);
    res = recvfrom(sockfd, buf, *len, 0,
                   (struct sockaddr*)&saddr, &addrlen);
    if(res < 0) {
        message_perror(mlv_alert, "TROUBLES", "recvfrom");
        return -1;
    }

    ap->addr = ntohl(saddr.sin_addr.s_addr);
    ap->port = ntohs(saddr.sin_port);
    *len = res;

    message(mlv_debug, "received %d bytes from %s\n", res, addrport2a(ap));

    return res > 0 ? res : 1;
}

    /*
       binds sockfd to port (host byte order) on all local IPv4 addresses;
       returns -1 or 0 just like bind(2), leaving errno as bind(2) set it
     */
static int do_bind(int sockfd, int port)
{
    struct sockaddr_in saddr;

        /* zero everything first so sin_zero (and any platform-specific
           fields) don't carry stack garbage into bind(2) */
    memset(&saddr, 0, sizeof(saddr));
    saddr.sin_family = AF_INET;
    saddr.sin_addr.s_addr = htonl(INADDR_ANY);
    saddr.sin_port = htons(port);
    return bind(sockfd, (struct sockaddr*)&saddr, sizeof(saddr));
}

/*
   binds sockfd to the first port that isn't in use, scanning upwards from
   a random start given by rand_from_range(50000, 59999) and stopping
   before 0xFEDA; returns 1 on success, 0 (diagnostics printed) when bind
   fails for a reason other than EADDRINUSE or no port is left
 */
static int trybind(int sockfd)
{
    int candidate = rand_from_range(50000, 59999);

    while(candidate < 0xFEDA) {
        if(do_bind(sockfd, candidate) != -1) {     /* bound, we're done */
            message(mlv_info, "bound to local port %d\n", candidate);
            return 1;
        }
        if(errno != EADDRINUSE) {
            message_perror(mlv_alert, NULL, "bind");
            return 0;
        }
        candidate++;    /* that one is busy, try the next */
    }
    message(mlv_alert, "couldn't pick a suitable port\n");
    return 0;
}

/*
   creates the UDP socket into *sockfd and binds it according to localport:
   cmdl_lp_no_bind leaves it unbound, cmdl_lp_choose_and_bind picks a port
   with trybind, any other value is taken as the port to bind to;
   returns 1 on success; on failure prints a diagnostic, closes the socket
   (if it was created), sets *sockfd to -1 and returns 0
 */
static int prepare_local_port(int *sockfd, int localport)
{
    int res;

    *sockfd = socket(AF_INET, SOCK_DGRAM, 0);
    if(*sockfd == -1) {
        message_perror(mlv_alert, NULL, "socket");
        return 0;
    }
    switch(localport) {
    case cmdl_lp_no_bind:
        return 1;
    case cmdl_lp_choose_and_bind:
        res = trybind(*sockfd);     /* prints its own diagnostics */
        if(res)
            return 1;
        break;
    default:
        message(mlv_debug, "localport value is set to %d\n", localport);
        res = do_bind(*sockfd, localport);
        if(res != -1)
            return 1;
        message_perror(mlv_alert, NULL, "bind");
        break;
    }
        /* binding failed: don't leak the descriptor */
    close(*sockfd);
    *sockfd = -1;
    return 0;
}


/*
   initializes *sess: clears our addresses and the sibling list, stores the
   server address, initializes the crypto context (comctx_init) and creates
   the socket as requested by localport (see prepare_local_port);
   returns 1 on success, 0 (diagnostics printed) on failure
 */
static int session_start(struct session *sess,
                         const struct addrport *serv, int localport)
{
    int res;

    memset(&sess->my_local, 0, sizeof(sess->my_local));
    memset(&sess->my_visible, 0, sizeof(sess->my_visible));
        /* the sibling list is only filled by a test reply; without this
           the second stage could iterate over garbage */
    memset(sess->expected, 0, sizeof(sess->expected));
    sess->expected_count = 0;
    sess->sockfd = -1;
    memcpy(&sess->serv, serv, sizeof(sess->serv));
    res = comctx_init(&sess->comctx);
    if(!res) {
        message(mlv_alert, "can't init keys (random numbers problem?)\n");
        return 0;
    }

    return prepare_local_port(&sess->sockfd, localport);
}



/*
   decrypts, in place, a "semi-encrypted" dgram of dl bytes that starts at
   buf + cipher_nonce_offset (where the callers receive it); the layout is
       2-byte plain header | sender's public key (public_key_size) |
       obfuscated nonce part (cipher_nonce_used) | MAC (cipher_mac_size) |
       ciphertext (the rest of the dgram)
   the key is derived from our kex key pair and the sender's public key;
   on success *ctext points to the decrypted text (it replaces the
   ciphertext) and 1 is returned; 0 is returned when the dgram is too short
   or fails authentication
 */
static int decrypt_semiencrypted(struct session *sess,
                                 unsigned char *buf, int dl,
                                 unsigned char **ctext)
{
    int r, hdr_len;
    unsigned char *dgram = buf + cipher_nonce_offset;
    unsigned char nonce[cipher_nonce_total];
    unsigned char decrypt_key[cipher_key_size];
    unsigned char *pubkey, *noncepart, *mac;

        /* dl comes from the network: without this check a short dgram
           makes the text length negative (a huge size_t for the cipher) */
    hdr_len = 2 + public_key_size + cipher_nonce_used + cipher_mac_size;
    if(dl < hdr_len) {
        message(mlv_normal, "dgram too short (%d bytes), ignored\n", dl);
        return 0;
    }

    pubkey = dgram + 2;
    noncepart = pubkey + public_key_size;
    mac = noncepart + cipher_nonce_used;
    *ctext = mac + cipher_mac_size;

        /* the nonce is zero-padded in front of the transmitted part */
    memset(nonce, 0, cipher_nonce_offset);
    memcpy(nonce + cipher_nonce_offset, noncepart, cipher_nonce_used);

    derive_cipher_keys(sess->comctx.kex_secret, sess->comctx.kex_public,
                       pubkey, NULL, decrypt_key);

    deobfuscate_buffer(nonce + cipher_nonce_offset, cipher_nonce_used);
    r = crypto_aead_unlock(*ctext, mac, decrypt_key, nonce, NULL, 0,
                           *ctext, dl - hdr_len);
    crypto_wipe(decrypt_key, sizeof(decrypt_key));  /* don't leave it on the stack */
    if(r == -1) {
        message(mlv_normal, "decrypt failed, dgram ignored\n");
        return 0;
    }
    return 1;
}


static int handle_test_reply(struct session *sess, unsigned char *buf, int dl)
{
    int r, i;
    unsigned char *ctext;

    r = decrypt_semiencrypted(sess, buf, dl, &ctext);
    if(!r)
        return 0;

    mem2addrport(ctext, &sess->my_visible);
    message(mlv_normal, "they see us as %s\n", addrport2a(&sess->my_visible));
    for(i = 0; i < max_sibling_servers; i++) {
        const unsigned char *siblinfo = ctext + 8 + 6 * i;
        if(all_zeroes(siblinfo, 6)) {
            sess->expected_count = i;
            message(mlv_info, "total: %d sibling servers\n", i);
            break;
        }
        mem2addrport(siblinfo, &sess->expected[i]);
        message(mlv_info, "suggested to expect: sibling server %s\n",
                          addrport2a(&sess->expected[i]));
    }

    return 1;
}

/*
   handles an echo reply (dl bytes at buf + cipher_nonce_offset): decrypts
   it and stores the address the replying peer sees us at into *seen;
   returns 1 on success, 0 if decryption fails
 */
static int handle_echo_reply(struct session *sess,
                             unsigned char *buf, int dl,
                             struct addrport *seen)
{
    unsigned char *plain;

    if(!decrypt_semiencrypted(sess, buf, dl, &plain))
        return 0;
        /* the address is all we need: no cryptographic association is
           going to be established, so the signature and the remaining
           fields of the datagram are deliberately not checked */
    mem2addrport(plain, seen);
    return 1;
}


/* accepts any stub dgram; its contents are not inspected (always returns 1) */
static int handle_stub_dgram(struct session *sess, unsigned char *buf)
{
    (void)sess;
    (void)buf;
    return 1;
}

/*
   first-stage burst: 5 stub dgrams followed by one test dgram, all to the
   server; returns 1 if everything was sent, 0 on the first failure
 */
static int send_stage1_dgrams(struct session *sess)
{
    int left;

    for(left = 5; left > 0; left--) {
        if(!sess_send_dgram_stub(sess))
            return 0;
    }
    return sess_send_dgram_test(sess) ? 1 : 0;
}

/*
   records the socket's local address (getsockname) into sess->my_local;
   on failure only prints a warning and leaves my_local unchanged
 */
static void check_local_port(struct session *sess)
{
    struct sockaddr_in local;
    socklen_t local_len = sizeof(local);

    if(getsockname(sess->sockfd, (struct sockaddr*)&local, &local_len) != -1) {
        sess->my_local.addr = ntohl(local.sin_addr.s_addr);
        sess->my_local.port = ntohs(local.sin_port);
        return;
    }
    message_perror(mlv_normal, "WARNING", "getsockname");
}


/*
   first stage: sends the stage-1 dgrams to the server, then collects
   incoming dgrams until tmms milliseconds pass without one;
   *srv_cnt gets the number of accepted test replies, *siblings_cnt the
   number of stub dgrams received (presumably from the sibling servers);
   returns early on a receive error or an unknown command
 */
static void run_first_stage(struct session *sess, int tmms,
                            int *srv_cnt, int *siblings_cnt)
{
    int rc, r, len, cmd;
    struct addrport ap;
    unsigned char buf[2048];
    unsigned char *dgram = buf + cipher_nonce_offset;

    *siblings_cnt = 0;   /* both are counted by the receive loop below */
    *srv_cnt = 0;

    rc = send_stage1_dgrams(sess);
    if(!rc)
        return;
    check_local_port(sess);
    message(mlv_info, "my local socket address: %s\n",
                      addrport2a(&sess->my_local));

    for(;;) {
        len = sizeof(buf) - cipher_nonce_offset;
        rc = sess_receive_dgram(sess->sockfd, tmms, dgram, &len, &ap);
        if(rc < 0) {
            message_perror(mlv_normal, NULL, "receiving");
            return;
        }
        if(rc == 0) {
            message(mlv_info, "timed out\n");
            return;
        }
        cmd = get_plain_dgram_cmd(dgram);
        message(mlv_debug, "got %d bytes from %s (cmd %02x)\n",
                           rc, addrport2a(&ap), cmd);

        switch(cmd) {
        case -1:
            message(mlv_normal, "we didn't expect encrypted dgram!\n");
            break;
        case fedaprot_stub:
            r = handle_stub_dgram(sess, buf);
            if(r)
                (*siblings_cnt)++;
            break;
        case fedaprot_test_reply:
            r = handle_test_reply(sess, buf, rc);
            if(r)
                (*srv_cnt)++;
            break;
        default:
                /* the format string consumes cmd, so it must be passed */
            message(mlv_normal, "cmd %02x; don't know how to handle\n", cmd);
            return;
        }
    }
}




/*
   sends a 128-byte fedaprot_echo_req dgram to ap: the plain 2-byte header,
   our kex public key, a 4-byte time mark of time(NULL) / 60 (minutes),
   then random noise; returns what sess_send_dgram returns
 */
static int sess_send_echoreq(struct session *sess, struct addrport *ap)
{
    unsigned char pkt[128];
    int pos = 2;
    long long minutes;

    set_plain_dgram_head(pkt, fedaprot_echo_req);
    memcpy(pkt + pos, sess->comctx.kex_public, kex_public_size);
    pos += kex_public_size;
    minutes = time(NULL);
    minutes /= 60;
    place_timemark(minutes, pkt + pos);
    pos += 4;
    fill_noise(pkt + pos, sizeof(pkt) - pos);
        /* using non-cryptographic random as the nonce may look bad, but
           the key is ephemeral, is used at most once per peer (no more
           than 10 of them), and nothing "serious" is done with it; this
           whole exchange might as well be done in plain text */

    return sess_send_dgram(sess, ap, pkt, sizeof(pkt));
}

/* 1 if a and b have the same address and the same port, 0 otherwise */
static int addrport_eq(const struct addrport *a, const struct addrport *b)
{
    if(a->addr != b->addr)
        return 0;
    return a->port == b->port;
}

/*
   second stage: sends an echo request to every expected sibling server,
   then collects echo replies until each sibling has been accounted for or
   tmms milliseconds pass without a dgram; a reply that sees us at
   sess->my_visible counts into *good_cnt, any other into *bad_cnt
 */
static void run_second_stage(struct session *sess, int tmms,
                             int *good_cnt, int *bad_cnt)
{
    int rc, r, len, cmd, i;
    struct addrport ap, seen;
    unsigned char buf[2048];
    unsigned char *dgram = buf + cipher_nonce_offset;

    *good_cnt = 0;
    *bad_cnt = 0;

    for(i = 0; i < sess->expected_count; i++) {
        message(mlv_debug, "sending feda echo req. to %s\n",
                           addrport2a(&sess->expected[i]));
        sess_send_echoreq(sess, &sess->expected[i]);
            /* a failed send is already reported by sess_send_dgram; we
               then simply time out waiting for that reply */
    }

    while(*good_cnt + *bad_cnt < sess->expected_count) {
        len = sizeof(buf) - cipher_nonce_offset;
        rc = sess_receive_dgram(sess->sockfd, tmms, dgram, &len, &ap);
        if(rc < 0) {
            message_perror(mlv_normal, NULL, "receiving");
            return;
        }
        if(rc == 0) {
            message(mlv_info, "timed out\n");
            return;
        }
        message(mlv_debug, "got %d bytes from %s\n", rc, addrport2a(&ap));

        cmd = get_plain_dgram_cmd(dgram);
        switch(cmd) {
        case -1:
            message(mlv_normal, "we didn't expect encrypted dgram!\n");
            break;
        case fedaprot_echo_reply:
            r = handle_echo_reply(sess, buf, rc, &seen);
            if(!r) {
                message(mlv_normal, "broken echo reply, ignored\n");
                break;
            }
            message(mlv_debug, "they see us as %s\n", addrport2a(&seen));
            if(addrport_eq(&seen, &sess->my_visible))
                ++*good_cnt;
            else
                ++*bad_cnt;
            break;
        default:
                /* the format string consumes cmd, so it must be passed */
            message(mlv_normal, "cmd %02x; don't know how to handle\n", cmd);
            return;
        }
    }
}

/*
   prints the final verdict "Looks like <kind> type of NAT<remark>" between
   a separator line and a pointer to the NAT_TYPES file
 */
static void print_result(const char *kind, const char *remark)
{
    message(mlv_info, "----------------------\n");
    message(mlv_alert, "Looks like %s type of NAT%s\n", kind, remark);
    message(mlv_normal, "Check the file NAT_TYPES for explanation\n");
}

/*
   reports the first stage given the number of accepted server replies
   (serv_rep) and of dgrams received from others (peers_rep);
   returns 1 when nothing can be concluded, 0 when a verdict was printed,
   -1 when the second stage is needed
 */
static int first_stage_report(int serv_rep, int peers_rep)
{
    if(serv_rep <= 0) {
        if(peers_rep <= 0) {
            message(mlv_alert, "timed out waiting for the server's reply\n");
        } else {
            message(mlv_alert,
                "no reply from the server but got dgrams from others\n");
            message(mlv_info,
                "this looks very strange; you're advised to try again\n");
        }
        return 1;
    }
    if(peers_rep > 0) {
        message(mlv_info,
                "got the server's reply and %d dgrams from others\n",
                peers_rep);
        print_result("FULL CONE", " (or even no NAT at all). Excellent!");
        return 0;
    }
    message(mlv_info,
        "got the server's reply but no others, second stage is needed\n");
    return -1;
}

/*
   reports the second stage given how many peers saw us at the same
   address as the server did (same_addr) and at a different one
   (diff_addr); returns 0 when a verdict was printed, 1 otherwise
 */
static int second_stage_report(int same_addr, int diff_addr)
{
    if(same_addr <= 0 && diff_addr <= 0) {
        message(mlv_info, "got no replies at all, timed out; try again\n");
        return 1;
    }
    if(diff_addr <= 0) {    /* hence same_addr > 0 */
        message(mlv_info,
            "got %d replies, all peers see us at the same address\n",
            same_addr);
        print_result("RESTRICTED CONE", "; not too bad.");
        return 0;
    }
    if(same_addr <= 0) {    /* hence diff_addr > 0 */
        message(mlv_info,
            "got %d replies, they see us as DIFFERENT addresses; too bad\n",
            diff_addr);
            /* print_result appends " type of NAT" itself */
        print_result("SYMMETRIC", "; too bad.");
        return 0;
    }
    message(mlv_info,
            "%d vs. %d replies to see us at the same vs. different addrs\n",
            same_addr, diff_addr);
    message(mlv_alert, "strange results, try again\n");
    return 1;
}

int main(int argc, const char * const *argv)
{
    unsigned int saddr;
    struct cmdline_opts opts;
    struct addrport servaddr;
    struct session sess;
    int r, servrep, peersrep, badpeers;

    set_defaults(&opts);
    r = parse_cmdline(argc, argv, &opts);
    if(!r) {
        message(mlv_normal, "try -h for help\n");
        return 1;
    }
    if(opts.help) {
        do_help();
        return 0;
    }

    message_set_verbosity(opts.verbosity);
    message(mlv_normal, "normal messages enabled\n");
    message(mlv_info, "info messages enabled\n");
    message(mlv_debug, "debug messages enabled\n");


    r = get_the_ip(opts.hostname, &saddr);
    if(!r)  /* diags already printed */
        return 1;

    servaddr.addr = ntohl(saddr);
    servaddr.port = opts.port;
    message(mlv_info, "will contact %s\n", addrport2a(&servaddr));
    r = session_start(&sess, &servaddr, opts.localport);
    if(!r) {
        message(mlv_normal, "couldn't start, abort\n");
        return 1;
    }

    run_first_stage(&sess, opts.timeout_msec, &servrep, &peersrep);
    r = first_stage_report(servrep, peersrep);
    if(r >= 0)
        return r;
    run_second_stage(&sess, opts.timeout_msec, &peersrep, &badpeers);
    return second_stage_report(peersrep, badpeers);
}
