#define _CRT_SECURE_NO_WARNINGS
#include <stdio.h>
#include <stdlib.h>
#include <memory.h>
#include <math.h>
#include <stdint.h>
#include <vector>

using namespace std;

// One signed 16-bit PCM sample; after quantize() the same type also holds a
// quantized level in the range 0..30.
typedef short sample_t;
// Summed (and possibly penalized) squared error returned by table_adpcm_work().
typedef int64_t err_t;

// Decoded PCM output: appended to by decode_to_samples(), written by write_wav().
static vector<int16_t> samples;
// Sample rate read from the input WAV header; reused for output.wav.
static int rate;

// Target RMS amplitude (16-bit PCM units) that normalize() scales the input to.
#define DESIRED_RMS    20000    //can be tweaked slightly for better utilization of dynamic range depending on material
// Bits per encoded sample. Only 1 is fully supported: the bit packing in
// table_adpcm_work() and decode_to_samples() assume k == 2 codes.
#define BPS    1
#define DELTA 10            //try table values +- this every pass
                            //higher values = more exhaustive search
// PRESHAPE: error feedback inside quantize().
// POSTSHAPE: error feedback in table_adpcm_work()'s candidate scoring.
// VERBOSE: per-entry search trace from table_adpcm() on stderr.
#define PRESHAPE
//#define POSTSHAPE
//#define VERBOSE

/*
 * Scale data[0..n) in place so that its RMS amplitude becomes DESIRED_RMS,
 * clamping every scaled sample to the int16 range [-32768, 32767].
 *
 * The RMS measured before scaling is reported on stderr. If n <= 0 or the
 * input is completely silent (RMS == 0) the data is left untouched: there
 * is no meaningful gain to apply, and dividing by zero would otherwise feed
 * inf/NaN into a float-to-integer conversion (undefined behavior).
 */
static void normalize(sample_t *data, int n) {
    /* double: a float accumulator silently drops small terms once the running
       sum of squares becomes large (hundreds of thousands of samples) */
    double sum_sq = 0;
    double rms;
    int x;

    for (x = 0; x < n; x++)
        sum_sq += (double)data[x] * data[x];

    rms = n > 0 ? sqrt(sum_sq / n) : 0.0;
    fprintf(stderr, "RMS amplitude prior to normalization: %f\n", rms);

    if (rms <= 0)
        return;

    for (x = 0; x < n; x++) {
        /* clamp while still floating point: converting an out-of-range
           double to an integer type is undefined behavior */
        double value = data[x] * (double)DESIRED_RMS / rms;
        if (value < -32768) value = -32768;
        if (value > 32767)  value = 32767;
        data[x] = (sample_t)value;
    }
}

static void quantize(const sample_t *input, sample_t *output, int n) {
    int x, y, last_error = 0;
    int counts[31][31];
    memset(counts, 0, sizeof(counts));

    for (x = 0; x < n; x++) {
#ifdef PRESHAPE
        //shape using feedback. not sure how correct this is,
        //but quiet parts appear to receive less noise
        output[x] = ((input[x] + 3 * last_error / 4) / 2048) + 15;
#else
        output[x] = (input[x] / 2048) + 15;
#endif

        if (output[x] < 0)
            output[x] = 0;
        if (output[x] >= 31)
            output[x] = 30;

        last_error = (output[x] - 15) * 2048 - input[x];

        if (x > 0)
            counts[output[x-1]][output[x]]++;
    }

    for (y = 0; y < 31; y++) {
        for (x = 0; x < 31; x++)
            fprintf(stderr, "%3i ", counts[y][x]);
        fprintf(stderr, "\n");
    }
}

/*
 * Encode (or just score) data[0..n) with a table-driven ADPCM predictor.
 *
 * table holds k*31 entries, where k = 2^N codes per sample. Row r,
 * table[r*k .. r*k + k-1], lists the k levels reachable from level r. For
 * each sample the candidate with the smallest squared error against data[x]
 * is chosen, and it becomes the row for the next sample. Encoding starts
 * at level 15.
 *
 * Returns the summed squared error multiplied by a penalty factor >= 1.
 * The factor grows when the code stream contains far more 0x55/0xAA bytes
 * (alternating bit patterns) than a random stream would. The intent is
 * presumably to discourage audible idle alternation; the code does not say.
 *
 * If do_output is nonzero, data[x] is overwritten with the chosen level.
 * When bits is non-NULL, bit (x & 7) of bits[x >> 3] is also set for every
 * nonzero code. Bits are only ever OR-ed in, so the caller must zero bits
 * first. Only k == 2 (1 bit per sample) is packed correctly.
 */
static err_t table_adpcm_work(sample_t *data, int n, int *table, int k, int do_output, unsigned char *bits) {
    int x, last = 15, last_error = 0;
    err_t ret = 0;
    int hist[256] = {0};
    int expected;
    float factor;
    int byte = 0;

    for (x = 0; x < n; x++) {
        int y;
        int best = 0, best_error = 0, best_value = 0;

        for (y = 0; y < k; y++) {
            int value = table[y + last*k];
            int error;

#ifdef POSTSHAPE
            //again, may not be entirely correct,
            //but the output sounds roughly right
            error = value - data[x] - last_error / 2;
#else
            error = value - data[x];
#endif
            error *= error;

            if (y == 0 || error < best_error) {
                best = y;
                best_error = error;
                best_value = value;
            }
        }

        last_error = best_value - data[x];
        last = best_value;
        ret += best_error;

        /* assemble the codes of each group of 8 samples into a byte (same
           bit order as the output packing) and histogram complete bytes */
        if ((x & 7) == 0)
            byte = 0;
        byte |= best << (x & 7);
        if ((x & 7) == 7)
            hist[byte]++;

        if (do_output) {
            data[x] = best_value;

            //NOTE: Only supports k == 2 (1-bit) properly atm
            if (bits && best)
                bits[x >> 3] |= 1 << (x & 7);
        }
    }

    /* A random 1-bit stream has n/8 complete bytes, each equal to 0x55 or
       0xAA with probability 2/256, so about n/8/128 of them are expected. */
    expected = n / 8 / 128;
    if (expected > 0) {
        factor = (hist[0x55] + hist[0xAA]) / (float)expected;
        /* no penalty until the count reaches 10x the random expectation */
        factor -= 9;
        if (factor < 1)
            factor = 1;
    } else {
        /* fewer than 1024 samples: too short to judge, and dividing by zero
           here would produce inf/NaN and an undefined float->integer cast */
        factor = 1;
    }
    //fprintf(stderr, "factor = %.2f\tn = %i\n", factor, n);

    return ret * factor;
}

/*
 * Decode `bytes` bytes of 1-bit ADPCM codes, least-significant bit first
 * within each byte. table uses the k == 2 layout: next level =
 * table[level * 2 + code]. The reconstructed PCM values, (level - 15) * 2048,
 * are appended to the global `samples` vector. Decoding starts from level 15,
 * the same starting level table_adpcm_work() encodes from.
 */
static void decode_to_samples(int *table, unsigned char *bits, int bytes) {
    int level = 15;
    int i, bit;

    for (i = 0; i < bytes; i++) {
        for (bit = 0; bit < 8; bit++) {
            int code = (bits[i] >> bit) & 1;
            level = table[level * 2 + code];
            samples.push_back((level - 15) * 2048);
        }
    }
}

/*
 * Search for an `entries`-entry (31 rows of k) ADPCM table that minimises
 * table_adpcm_work()'s error on data, then encode data with it.
 *
 * The search starts from a heuristic table and proceeds in passes. Each pass
 * tries every entry at every value in [current - DELTA, current + DELTA],
 * clamped to the legal range 0..30, and keeps any change that lowers the
 * error. Passes repeat until one makes no change (greedy coordinate descent).
 *
 * On return:
 *   - best_table holds the chosen table;
 *   - data holds the decoded levels;
 *   - bits holds the packed codes (the caller must zero bits first).
 * Progress is reported on stderr. Exits the process if allocation fails.
 */
static void table_adpcm(sample_t *data, int n, int bps, int *best_table, unsigned char *bits, int entries, int k) {
    int x, y;
    int *table = (int*)malloc(entries*sizeof(int));
    err_t e, ebest;
    int pass, changes;

    /* Number of samples scored during the search. Lowering it (e.g. to
       n / 10) speeds the search up at some cost in table quality. */
    int n2 = n;

    if (!table) {
        fprintf(stderr, "out of memory\n");
        exit(1);
    }

    //initialize table with reasonable values
    for (y = 0; y < 31; y++)
        for (x = 0; x < k; x++) {
            if (bps == 1)
                table[x + y*k] = y + 7*(x - k/2) + 4;
            else
                table[x + y*k] = y + 3*(x - k/2) + 3;

            if (table[x + y*k] < 0)  table[x + y*k] = 0;
            if (table[x + y*k] > 30) table[x + y*k] = 30;
        }

    ebest = table_adpcm_work(data, n2, table, k, 0, NULL);
    memcpy(best_table, table, entries*sizeof(int));

    /* err_t is int64_t, which is not `long` everywhere (e.g. 32-bit long on
       Windows), so print it through long long */
    fprintf(stderr, "initial error: %lld\n", (long long)ebest);
    fprintf(stderr, "initial rms: %.2f\n", sqrtf((float)ebest/n2));

    changes = 1;
    for (pass = 1; changes; pass++) {
        changes = 0;

        for (y = 0; y < entries; y++) {
            int delta = DELTA;
            int min = best_table[y] - delta, max = best_table[y] + delta;
            memcpy(table, best_table, entries*sizeof(int));

#ifdef VERBOSE
            fprintf(stderr, "table[% 4i/% 4i] = % 3i\n", y, entries, best_table[y]);
#endif

            if (min < 0)
                min = 0;
            if (max > 30)    //31 isn't a legal value, but 30 is
                max = 30;

            /* inclusive upper bound: try current+DELTA (and level 30) too */
            for (x = min; x <= max; x++) {
                /* table == best_table at this value, so it cannot improve */
                if (x == best_table[y])
                    continue;

                table[y] = x;
                e = table_adpcm_work(data, n2, table, k, 0, NULL);

                if (e < ebest) {
                    memcpy(best_table, table, entries*sizeof(int));
                    ebest = e;
                    changes++;

#ifdef VERBOSE
                    /* report the rms of the NEW best, after updating ebest */
                    fprintf(stderr, "                -> % 3i -> %lld (%.2f)\n",
                            x, (long long)ebest, sqrtf((float)ebest/n2));
#endif
                }
            }
        }

        fprintf(stderr, "pass %i: %i changes, rms = %.2f\n", pass, changes, sqrtf((float)ebest/n2));
    }

    e = table_adpcm_work(data, n, best_table, k, 1, bits);
    fprintf(stderr, "final rms: %.2f\n", sqrtf((float)e/n));

    free(table);
}

// Bytes reserved in the 8 output pages for non-sample content. Every page
// loses EFFECTS bytes; page 0 additionally loses BOOT, HARMONY and IMAGE.
// What the player stores in these regions is not visible in this file.
#define BOOT    26
#define HARMONY 900 //space needed for Harmony's F4 driver. would be available on a real F4 cart
#define EFFECTS 100
#define IMAGE   1536

//<= 4 KiB per page, we need space for the player
// Number of encoded sample-data bytes main() emits into each page (in order).
// With the values above the pages hold 1358 + 6*3820 + 3819 = 28097 bytes in
// total. Page 7 is one byte smaller than pages 1-6; the reason is not stated.
static int pagesizes[8] = {
    3920 - BOOT - HARMONY - EFFECTS - IMAGE,
    3920 - EFFECTS,
    3920 - EFFECTS,
    3920 - EFFECTS,
    3920 - EFFECTS,
    3920 - EFFECTS,
    3920 - EFFECTS,
    3919 - EFFECTS,
};

/* Write `a` to `f` as a 4-byte little-endian integer (low byte first). */
static void write_l32(FILE *f, uint32_t a) {
    int shift;

    for (shift = 0; shift < 32; shift += 8)
        putc((int)((a >> shift) & 0xFFu), f);
}

/*
 * Write the global `samples` vector to output.wav as a mono 16-bit PCM WAV
 * file at the global sample `rate`.
 *
 * Sample data is written in host byte order, so the file is only correct on
 * little-endian hosts (WAV sample data is little-endian). I/O failures are
 * reported on stderr; they are not fatal, because this file is only a
 * listening aid.
 */
static void write_wav() {
    size_t count = samples.size();
    FILE *wav;

    /* %zu: size_t is not `long` on every platform */
    fprintf(stderr, "Writing %zu samples to output.wav\n", count);

    wav = fopen("output.wav", "wb");
    if (!wav) {
        perror("output.wav");
        return;
    }

    fprintf(wav, "RIFF");
    write_l32(wav, count*2 + 36);   /* RIFF size: everything after this field */
    fprintf(wav, "WAVEfmt ");
    write_l32(wav, 16);             /* fmt chunk size */
    write_l32(wav, 0x00010001);     /* format tag 1 (PCM), 1 channel */
    write_l32(wav, rate);           /* sample rate */
    write_l32(wav, rate*2);         /* byte rate = rate * block align */
    write_l32(wav, 0x00100002);     /* block align 2, 16 bits per sample */
    fprintf(wav, "data");
    write_l32(wav, count*2);

    /* &samples[0] on an empty vector is undefined, so skip the write then */
    if (count > 0 && fwrite(&samples[0], count*2, 1, wav) != 1)
        perror("output.wav");
    if (fclose(wav) != 0)
        perror("output.wav");
}

int main(int argc, char **argv) {
    FILE *f;
    unsigned char header[44];
    int size, samples;
    sample_t *input, *output;
    int bps;
    uint8_t pagedata[8][4096];
    int x, ofs, bytes, k, entries, iter;
    int *best_table;
    unsigned char *bits;

    if (!(f = fopen(argv[1], "rb")))
        return 1;

    fread(header, 44, 1, f);
    rate = *(int*)&header[24];
    size = *(int*)&header[40];

    /* HACK: don't change output just because pagesizes change */
    samples = 237136;

    fprintf(stderr, "length: %.2f seconds\n", (float)samples / rate);

    if (size / 2 > samples)
        samples = size / 2;

    fprintf(stderr, "rate: %i\n", rate);
    fprintf(stderr, "size: %i = %i samples = %f seconds\n", size, samples, (float)samples/rate);

    input  = (sample_t*)malloc(samples*sizeof(sample_t));
    memset(input, 0, samples*2);
    output = (sample_t*)malloc(samples*sizeof(sample_t));
    memset(output, 0, samples*2);

    fread(input, size, 1, f);

    bps = BPS;
    k = 1 << bps;
    entries = 31 * k;
    bytes = (samples + 7) / 8;
    best_table = (int*)malloc(entries*sizeof(int));
    bits = (unsigned char*)malloc(bytes);
    memset(bits, 0, bytes);

    normalize(input, samples);
    quantize(input, output, samples);

    table_adpcm(output, samples, bps, best_table, bits, entries, k);

    for (iter = ofs = 0; iter < 8; iter++) {
        int bytes = pagesizes[iter];

        printf("\tMAC TABLE%i\n", iter);
        printf("\t;%i entries\n", entries);
        printf("\t;%i bits per sample\n", bps);
        printf("ADPCMTable%i\n", iter);
        for (x = 0; x < entries; x++)
            printf("\t.byte %i\n", best_table[x]);
        printf("\tENDM\n");

        printf("\tMAC SAMPLES%i\n", iter);
        printf("SampleData%i\t;%i bytes\n", iter, bytes);

        for (x = ofs; x < ofs+bytes; x++)
            printf("\t.byte %i\n", bits[x]);

        printf("SampleEnd%i\n", iter);
        printf("\tENDM\n");

        ofs += bytes;
    }

    decode_to_samples(best_table, bits, bytes);

    fprintf(stderr, "encoded size: %.2f KiB (%i bps)\n", (samples * bps) / 8192.f, bps);

    free(input);
    free(output);
    fclose(f);
    write_wav();
    free(best_table);
    free(bits);

    return 0;
}
