#include <stdlib.h>
#include <string.h>

#include "treebyte.h"


/* Number of slots in every tree node: one per value of a 4-bit key nibble. */
enum { arrsize = 16 };

/* One level of the tree: a table of arrsize slots indexed by a 4-bit nibble
 * of the key.  On the last level (depth 1 in traverse) a slot holds the
 * caller's value; on every other level it holds a pointer to a child
 * treebyte_node.  An unused slot is NULL.
 */
struct treebyte_node {
    void *a[arrsize];
};

/* Allocates a node with every slot empty.
 * Returns the new node, owned by the caller, or NULL when the allocation
 * fails.
 */
static struct treebyte_node *mknode(void)
{
    /* calloc zero-fills the node itself, so a failed allocation is simply
       reported as NULL instead of being handed to memset */
    return calloc(1, sizeof(struct treebyte_node));
}

/* Returns how many slots of the node are occupied (non-NULL). */
static int node_children(const struct treebyte_node *p)
{
    int used = 0;
    int slot = arrsize;

    while(slot-- > 0)
        used += (p->a[slot] != NULL);
    return used;
}

/* Frees a node together with everything reachable through its slots.
 * Every non-NULL slot is followed as if it pointed to a child node, so this
 * must not be given a subtree whose last-level slots still hold caller
 * values (in traverse, slots at depth 1 are values, not nodes).
 */
static void dispose_node(struct treebyte_node *p)
{
    struct treebyte_node *child;
    int slot;

    for(slot = 0; slot < arrsize; slot++) {
        child = p->a[slot];
        if(child != NULL)
            dispose_node(child);
    }
    free(p);
}

/* Walks down the tree along the nibbles of key and returns the address of
 * the slot reached after consuming `depth` nibbles.
 *
 *   p      node to start from (must not be NULL)
 *   key    key bytes; each byte supplies its high nibble first, then its low
 *   depth  number of nibbles to consume; values below 1 behave like 1
 *   half   0 to start with the high nibble of *key, non-zero for the low one
 *   mk     non-zero to create missing intermediate nodes on the way down
 *
 * Returns NULL when an intermediate node is missing and mk is 0, or when
 * mk is set and mknode() returns NULL.  In the latter case the nodes created
 * earlier during the same call stay in the tree.
 */
static void **traverse(struct treebyte_node *p, const unsigned char *key,
                       int depth, int half, int mk)
{
    struct treebyte_node *child;
    int idx;

    for(; depth > 1; depth--) {
        idx = (half ? (*key) : (*key >> 4)) & 0x0f;

        /* above the last level, p->a is an array of node pointers; the
           slot is read and written as the void* it is declared as */
        child = p->a[idx];
        if(!child) {
            if(!mk)
                return NULL;
            child = mknode();
            if(!child)
                return NULL;
            p->a[idx] = child;
        }

        if(half)
            key++;          /* both nibbles of this byte are used up */
        half = !half;
        p = child;
    }

    /* last level: hand out the slot itself, whatever it contains */
    idx = (half ? (*key) : (*key >> 4)) & 0x0f;
    return &(p->a[idx]);
}

/* Empties the slot addressed by key below node p and prunes nodes that
 * become empty on the way back up.
 *
 *   depth  number of nibbles still to consume (1 or less means p is on the
 *          last level)
 *   half   0 for the high nibble of *key, non-zero for the low one
 *
 * Returns 1 if an occupied slot was cleared, 0 if there was nothing stored
 * under that key.  The value that was stored is not freed.
 */
static int traverse_del(struct treebyte_node *p, const unsigned char *key,
                        int depth, int half)
{
    const int nibble = (half ? (*key) : (*key >> 4)) & 0x0f;
    struct treebyte_node *child;

    if(depth <= 1) {
        /* last level: the slot holds the value itself */
        if(p->a[nibble] == NULL)
            return 0;
        p->a[nibble] = NULL;
        return 1;
    }

    /* any other level: the slot holds a child node */
    child = p->a[nibble];
    if(child == NULL)
        return 0;
    if(!traverse_del(child, half ? key + 1 : key, depth - 1, !half))
        return 0;           /* nothing was removed below, leave p alone */

    /* the child lost an entry; drop it if that was its last one */
    if(node_children(child) == 0) {
        dispose_node(child);
        p->a[nibble] = NULL;
    }
    return 1;
}

/* ---- public functions ------------------------------------------ */

/* Sets *p up as an empty tree.  levels is stored as given; lookups consume
 * 2 * levels nibbles of the key, i.e. `levels` key bytes.
 */
void treebyte_init(struct treebyte *p, int levels)
{
    p->levels = levels;
    p->root = NULL;
}

void treebyte_clear(struct treebyte *p)
{
    if(p->root)
        dispose_node(p->root);
    p->root = NULL;
}

/* Looks up key (2 * p->levels nibbles are read from it) without modifying
 * the tree.  Returns the stored value, or NULL when the key is absent.
 */
void *treebyte_get(const struct treebyte *p, const void *key)
{
    void **slot = NULL;

    if(p->root)
        slot = traverse(p->root, key, 2 * p->levels, 0, 0);
    if(!slot)
        return NULL;
    return *slot;
}

/* Returns the address of the value slot for key, creating the root and any
 * missing intermediate nodes.  The slot of a key that was not present
 * before is NULL; the caller stores the value through the returned pointer.
 *
 * Returns NULL if mknode() returns NULL for the root, or if traverse()
 * returns NULL.
 */
void **treebyte_provide(struct treebyte *p, const void *key)
{
    if(!p->root) {
        p->root = mknode();
        if(!p->root)
            return NULL;    /* no root could be made; tree stays empty */
    }
    return traverse(p->root, key, 2 * p->levels, 0, 1);
}

/* Removes key from the tree; the value that was stored is not freed.
 * Returns 1 if the key was present, 0 otherwise.  When the last key goes
 * away the root node is released as well.
 */
int treebyte_delete(struct treebyte *p, const void *key)
{
    if(!p->root || !traverse_del(p->root, key, 2 * p->levels, 0))
        return 0;

    if(node_children(p->root) == 0) {
        dispose_node(p->root);
        p->root = NULL;
    }
    return 1;
}


/* -------------- the iterator ----------------- */

/* Iteration state over the entries below a key prefix; filled in by
 * treebyte_make_iter and advanced by treebyte_iter_next.
 */
struct treebyte_iter
{
    struct treebyte_node *iter_root;  /* the node determined by the keypref */
    int restlevels;                   /* 2*levels - kplen */
    char *pos;    /* restlevels elems (as allocated by treebyte_make_iter):
                     the current slot index on each remaining level;
                     step_iter stores -1 in an element once that level is
                     exhausted.  NULL when there is nothing to step through */
    void **next_result;             /* slot handed out by the next call to
                                       treebyte_iter_next, NULL when no
                                       results are left; in the single-result
                                       case iter_root and pos are NULL */
};

    /* Moves the iteration state below `node` to an occupied last-level slot
       and returns that slot's address, or NULL when nothing is left.

         node      node whose slots are indexed by pos[0]
         pos       slot index for this level followed by the indices for
                   the levels below it (leftlevs elements are used);
                   -1 in an element marks that level as exhausted
         leftlevs  number of levels from node down to the values, node
                   itself included
         force     0: the current position may be returned as is if it is
                   occupied; non-zero: the position must advance first

       NOTE(review): pos is plain char; where char is unsigned the stored -1
       reads back as a large positive value and is then rejected by the
       range check below rather than by the == -1 test -- confirm this is
       acceptable on such targets.
     */
static void **
step_iter(struct treebyte_node *node, char *pos, int leftlevs, int force)
{
    void **vpp;
    int i;

    if(leftlevs < 1 || *pos < -1 || *pos >= arrsize)   /* must not happen */
        return 0;
    if(*pos == -1)        /* we can't do anything, no more positions here */
        return 0;
    vpp = &(node->a[(int)*pos]);
    if(leftlevs == 1) {   /* the recursion base case: slots hold the values */
        if(*vpp && !force)  /* just stay here */
            return vpp;
        /* otherwise advance to the next occupied slot of this node */
        (*pos)++;
        while(*pos < arrsize && !node->a[(int)*pos])
            (*pos)++;
        if(*pos >= arrsize) {  /* no more objects here */
            *pos = -1;
            return NULL;
        }
        /* found the object */
        return &(node->a[(int)*pos]);
    }
    /* if we are here, it means we're not on the final level */
    /* if we can stay here on the current level, just check the rest
       of the levels if they can stay (if we're allowed to stay) or
       if they move successfully (if we're forced to move)
     */
    if(*vpp) {
        vpp = step_iter(*vpp, pos + 1, leftlevs - 1, force);
        if(vpp)
            return vpp;
    }
    /* no way, we have to move on the current level */
    (*pos)++;
    while(*pos < arrsize) {
        vpp = &(node->a[(int)*pos]);
        if(*vpp) {   /* let's try this position */
            /* first of all, reset the rest because we've just moved;
               this also clears any -1 left by an exhausted lower level */
            for(i = 1; i < leftlevs; i++)
                pos[i] = 0;
            /* we've just moved already, no forcing the rest to move */
            vpp = step_iter(*vpp, pos + 1, leftlevs - 1, 0);
            if(vpp)
                return vpp;
        }
        /* the position isn't suitable, this way or the other */
        (*pos)++;
    }
    /* if we're here, it means *pos >= arrsize so we failed on this level */
    *pos = -1;
    return NULL;
}

/* Creates an iterator over every entry whose key starts with the first
 * kplen nibbles of keypref.
 *
 *   p        the tree; NULL or an empty tree gives an iterator with no
 *            results
 *   keypref  key prefix bytes; not read when kplen is 0 or less
 *   kplen    prefix length in nibbles (not bytes); negative values are
 *            treated as 0 and values above 2 * p->levels are clamped to it
 *
 * Returns an iterator to be released with treebyte_dispose_iter, or NULL
 * when memory for it cannot be allocated.  The first result is located
 * here already, so treebyte_iter_next only hands it out.
 */
struct treebyte_iter *
treebyte_make_iter(struct treebyte *p, const void *keypref, int kplen)
{
    void **vpp;
    struct treebyte_iter *res;

    res = malloc(sizeof(*res));
    if(!res)
        return NULL;

    res->iter_root = NULL;
    res->restlevels = 0;
    res->pos = NULL;
    res->next_result = NULL;

    if(!p || !p->root || node_children(p->root) == 0)
        return res;
    /* a negative length would make restlevels exceed the real depth of
       the tree and send step_iter below the value level */
    if(kplen < 0)
        kplen = 0;
    if(kplen > 2 * p->levels)
        kplen = 2 * p->levels;
    vpp = kplen > 0 ?
        traverse(p->root, keypref, kplen, 0, 0) : (void**)&(p->root);
    if(!vpp || !*vpp)
        return res;

    res->restlevels = 2 * p->levels - kplen;
    if(!res->restlevels) {
        /* the prefix is a complete key: vpp is the only result, and
           iter_root/pos stay NULL */
        res->next_result = vpp;
        return res;
    }

    /* one position per remaining level, all starting at slot 0 */
    res->pos = calloc(res->restlevels, sizeof(*res->pos));
    if(!res->pos) {
        free(res);
        return NULL;
    }
    res->iter_root = (struct treebyte_node *)*vpp;
    res->next_result =
        step_iter(res->iter_root, res->pos, res->restlevels, 0);

    return res;
}

/* Releases an iterator and its position array.  Passing NULL is allowed and
 * does nothing, mirroring free().
 */
void treebyte_dispose_iter(struct treebyte_iter *p)
{
    if(!p)
        return;
    free(p->pos);   /* free(NULL) is a no-op, no guard needed */
    free(p);
}

/* Returns the address of the next entry's value slot, or NULL once the
 * iterator is exhausted.  The iterator works one step ahead: the result
 * handed out was found earlier, and the following one is looked up before
 * returning.
 */
void **treebyte_iter_next(struct treebyte_iter *p)
{
    void **current = p->next_result;

    if(p->pos)
        p->next_result = step_iter(p->iter_root, p->pos, p->restlevels, 1);
    else    /* no position array: a single result at most, now consumed */
        p->next_result = NULL;
    return current;
}

/* Gives read-only access to the iterator's per-level position array; NULL
 * when the iterator has none.  Since treebyte_iter_next already advances
 * the positions while looking up the following result, the array describes
 * the upcoming result rather than the one returned last.
 */
const char *treebyte_iter_getpos(const struct treebyte_iter *p)
{
    const char *positions = p->pos;

    return positions;
}



/* -------------- tests ------------------------ */

#ifdef TREEBYTE_TESTS

#include <stdio.h>

/* Keys that main() stores in the tree; the list ends with NULL.
 * Short names carry '=' padding.
 */
const char *const members[] = {
    "Alpha",   "Beta",    "Gamma",   "Delta",
    "Epsilon", "Zeta",    "Eta===",  "Theta",
    "Iota",    "Kappa",   "Lambda",  "Mu===",
    "Nu===",   "Xi===",   "Omicron", "Pi===",
    "Rho===",  "Sigma",   "Tau===",  "Upsilon",
    "Phi===",  "Chi===",  "Psi===",  "Omega",
    NULL
};

/* Lower-case counterparts of the member keys; main() never stores them and
 * expects lookups for them to fail.  The list ends with NULL.
 */
const char *const nonmembers[] = {
    "alpha",   "beta",    "gamma",   "delta",
    "epsilon", "zeta",    "eta===",  "theta",
    "iota",    "kappa",   "lambda",  "mu===",
    "nu===",   "xi===",   "omicron", "pi===",
    "rho===",  "sigma",   "tau===",  "upsilon",
    "phi===",  "chi===",  "psi===",  "omega",
    NULL
};

/* shorthand used by main() when casting string keys */
typedef unsigned char uchr;

/* Prints every entry of tb found under the first lev nibbles of key, one
 * per line, treating the stored values as C strings.  Prints nothing when
 * no iterator could be created.
 */
void test2(struct treebyte *tb, void *key, int lev)
{
    void **vpp;
    struct treebyte_iter *iter;

    iter = treebyte_make_iter(tb, key, lev);
    if(!iter)
        return;
    while((vpp = treebyte_iter_next(iter)) != NULL) {
        if(*vpp)
            printf("    ->   %s\n", (char*)*vpp);
        else
            printf("    -    NULL\n");
    }
    /* the iterator is heap-allocated; release it once the walk is done */
    treebyte_dispose_iter(iter);
}

/* Self-test driver: fills a 4-level tree with the member keys, checks that
 * members are found and non-members are not, then prints the results of a
 * few prefix iterations.  Always returns 0; the outcome is read from the
 * printed output.
 */
int main()
{
    /* prefix queries: heading to print, key prefix, prefix length in
       nibbles as passed to test2 */
    static const struct {
        const char *heading;
        char *prefix;
        int nibbles;
    } queries[] = {
        { "[T] (expect Theta, Tau===)\n",   "T",    2 },
        { "[Sigm] (expect Sigma)\n",        "Sigm", 8 },
        { "[] (expect 24 results)\n",       "",     0 },
        { "[T]/1 (expect some results)\n",  "T",    1 },
        { "[E]/1 (expect some results)\n",  "E",    1 }
    };
    const int nqueries = (int)(sizeof(queries) / sizeof(queries[0]));
    struct treebyte tb;
    int i, q;

    treebyte_init(&tb, 4);

    /* each key is stored with the key string itself as its value */
    for(i = 0; members[i]; i++) {
        void **slot = treebyte_provide(&tb, (const uchr*)members[i]);
        *slot = (void*)(members[i]);
    }

    for(i = 0; members[i] && nonmembers[i]; i++) {
        const char *present = members[i];
        const char *absent = nonmembers[i];
        void *found;

        found = treebyte_get(&tb, (const uchr*)present);
        printf("%s: %s\n", present, found == present ? "ok" : "fail");
        found = treebyte_get(&tb, (const uchr*)absent);
        printf("%s: %s\n", absent, !found ? "ok" : "fail");
    }

    for(q = 0; q < nqueries; q++) {
        fputs(queries[q].heading, stdout);
        test2(&tb, queries[q].prefix, queries[q].nibbles);
    }

    return 0;
}


#endif
