import { stack, tidy, browser as tfBrowser, layers, loadLayersModel as tfLoadLayersModel, memory, tensor2d, tensor1d, sequential, train as tfTrain, losses } from '@tensorflow/tfjs';
import { log, logTime, loadImage, getObjectValue } from "./Utils";

/**
 * Classifies every image independently and concurrently.
 *
 * Each image is sent through `detectObjectsInImages` as its own single-image
 * batch. Its position in `images` is forwarded as the instantiation index used
 * in log labels. Only element 0 of each batch result is kept.
 *
 * @param model - Model forwarded to `detectObjectsInImages`.
 * @param {Array} images - Image sources; one detection run is started per entry.
 * @param {string} modelName - Label used in log output.
 * @param {boolean} showLogs - Forwarded to control logging.
 * @returns {Promise<Array>} The first prediction value of each image, in input order.
 */
export const optimalImageObjectDetection = async (model, images, modelName, showLogs) => {
    const runSingleImage = async (singleImage, position) => {
        const batchResult = await detectObjectsInImages(model, [singleImage], modelName, position, showLogs);
        // NOTE(review): only the first value of the flat prediction output is kept;
        // confirm the model emits a single score per image.
        return batchResult[0];
    };

    return Promise.all(images.map(runSingleImage));
};

/**
 * Loads a TensorFlow.js Layers model.
 *
 * Async wrapper around `loadLayersModel` from '@tensorflow/tfjs'. Rejections
 * from the underlying loader propagate unchanged.
 *
 * @param modelUrl - Model location (URL string or IOHandler) accepted by tf.loadLayersModel.
 * @returns {Promise<LayersModel>} The loaded model.
 */
export const loadLayersModel = async (modelUrl) => tfLoadLayersModel(modelUrl);

/**
 * Runs `model` on a batch of images and returns the flat prediction values.
 *
 * Pipeline:
 * 1. Each image is loaded with `loadImage` and decoded by `loadBase64Image`.
 * 2. It is resized to 224x224 (nearest neighbour) and cast to float.
 * 3. The images are stacked into one batch and rescaled with
 *    x * (1 / 127.5) - 1 (i.e. [0, 255] -> [-1, 1]).
 * 4. The batch is passed to `model.predict`.
 *
 * A runtime timer is started and always closed via `logTime`, even when
 * loading or prediction throws.
 *
 * @param model - Model exposing `predict` (e.g. the result of `loadLayersModel`).
 * @param {Array} images - Image sources accepted by `loadImage` from ./Utils.
 * @param {string} modelName - Label used in log and timer messages.
 * @param {number} instantiationIndex - Zero-based index, shown 1-based in logs.
 * @param {boolean} showLogs - When falsy, logging is suppressed.
 * @returns {Promise<TypedArray>} `dataSync()` output of the prediction tensor.
 */
const detectObjectsInImages = async (model, images, modelName, instantiationIndex, showLogs) => {
    const humanReadableIndex = instantiationIndex + 1;
    const hideLogs = !showLogs;
    const timerLabel = `${modelName} ${images.length} Image Batch Runtime (${humanReadableIndex})`;
    logTime(timerLabel, false, hideLogs);

    try {
        const imgs = await Promise.all(images.map(async (img) => {
            const image = await loadImage(img);
            return loadBase64Image(image);
        }));

        // tidy() disposes every intermediate tensor; only the plain TypedArray escapes.
        const predictions = tidy(() => {
            const tfImages = imgs.map((img) =>
                tfBrowser.fromPixels(img).resizeNearestNeighbor([224, 224]).toFloat());
            const tfImagesBatch = stack(tfImages);
            // layers.rescaling is a factory function, so it is called without `new`.
            const normalizationLayer = layers.rescaling({ scale: 1 / 127.5, offset: -1 });
            const normalizedImage = normalizationLayer.call(tfImagesBatch);
            const pred = model.predict(normalizedImage).dataSync();

            // Report GPU usage while the batch tensors are still alive (before tidy cleans up).
            const memoryUsage = memory();
            if (getObjectValue(memoryUsage, "numBytesInGPU")) {
                const megaBytesUsage = memoryUsage.numBytesInGPU / 1000000;
                const megaBytesTotal = memoryUsage.numBytesInGPUAllocated / 1000000;
                log(`${megaBytesUsage.toFixed(0)} MB / ${megaBytesTotal.toFixed(0)} MB = ${((megaBytesUsage / megaBytesTotal) * 100).toFixed(0)}% of GPU`, null, hideLogs);
            }

            return pred;
        });

        log(predictions, `${modelName} ${images.length} Image Batch Predictions (${humanReadableIndex}):`, hideLogs);
        return predictions;
    } finally {
        // Close the timer on both success and failure so it is never left dangling.
        logTime(timerLabel, true, hideLogs);
    }
};

/**
 * Loads `url` (e.g. a base64 data URL) into an `Image` element.
 *
 * The image is requested with `crossOrigin = 'anonymous'`.
 *
 * @param {string} url - Source assigned to the image's `src`.
 * @returns {Promise<HTMLImageElement>} Resolves with the element once it has
 *   loaded. Rejects with an `Error` whose `cause` is the original error event.
 */
const loadBase64Image = (url) => {
    return new Promise((resolve, reject) => {
        const im = new Image();
        // Handlers are attached before `src` so a load/error that fires immediately is never missed.
        im.onload = () => resolve(im);
        im.onerror = (event) => reject(new Error('Failed to load image', { cause: event }));
        im.crossOrigin = 'anonymous';
        im.src = url;
    });
};

/**
 * Slides a window of `n_steps` elements over `sequence`, one position at a time.
 *
 * Window i is `sequence.slice(i, i + n_steps)`. Its target is the LAST element
 * of that same window, `sequence[i + n_steps - 1]`, not the element after it.
 * Windows that would run past the end are not produced, so a sequence shorter
 * than `n_steps` yields empty arrays.
 *
 * @param {Array} sequence - Values to window.
 * @param {number} n_steps - Window length.
 * @returns {Promise<{seq_x: Array<Array>, seq_y: Array}>} The windows and their targets.
 */
export const split_sequences = async (sequence, n_steps) => {
    const seq_x = [];
    const seq_y = [];

    for (let i = 0; i < sequence.length; i++) {
        const end_ix = i + n_steps;
        if (end_ix > sequence.length) break;
        seq_x.push(sequence.slice(i, end_ix));
        seq_y.push(sequence[end_ix - 1]);
    }

    return { seq_x, seq_y };
};
  
/**
 * Builds, compiles and fits a single-layer LSTM regressor on pre-windowed data.
 *
 * Architecture: LSTM(50 units, relu, inputShape [n_steps, n_features]) -> Dense(1).
 * It is compiled with Adam, mean-squared-error loss and an "mse" metric.
 * The input and target tensors created here are disposed once fitting ends,
 * whether it succeeds or throws.
 *
 * @param {number} n_steps - Window length; each entry of `seq_x` holds this many values.
 * @param {number} epochs - Number of epochs passed to `model.fit`.
 * @param {number} n_features - Features per time step. The input is created with
 *   `tensor2d(seq_x, [seq_x.length, n_steps])`, so the later reshape to
 *   `[seq_x.length, n_steps, n_features]` only matches in size when this is 1.
 * @param timeID - Not used by this function; kept for interface compatibility.
 * @param depVar - Not used by this function; kept for interface compatibility.
 * @param {Array<Array<number>>} seq_x - Input windows.
 * @param {Array<number>} seq_y - One target value per window.
 * @param {Function} onEpochEnd - Forwarded to `model.fit` as `callbacks.onEpochEnd`.
 * @returns {Promise<Sequential>} The trained model.
 */
export const train = async (n_steps, epochs, n_features, timeID, depVar, seq_x, seq_y, onEpochEnd) => {
    // model.fit is async, so tidy() cannot wrap it; track owned tensors and dispose them explicitly.
    const ownedTensors = [];
    const track = (t) => {
        ownedTensors.push(t);
        return t;
    };

    try {
        const flatX = track(tensor2d(seq_x, [seq_x.length, n_steps]));
        const targetY = track(tensor1d(seq_y));

        const model = sequential();
        model.add(
            layers.lstm({
                units: 50,
                activation: "relu",
                inputShape: [n_steps, n_features]
            })
        );
        model.add(layers.dense({ units: 1 }));

        model.compile({
            optimizer: tfTrain.adam(),
            loss: losses.meanSquaredError,
            metrics: ["mse"]
        });

        // The reshaped tensor is a separate handle and must be disposed as well.
        const inputX = track(flatX.reshape([flatX.shape[0], n_steps, n_features]));

        await model.fit(inputX, targetY, { epochs, callbacks: { onEpochEnd } });
        return model;
    } finally {
        ownedTensors.forEach((t) => t.dispose());
    }
};
  
/**
 * Runs `model` over every window in `seq_x` and returns one value per window.
 *
 * All windows are predicted in a single batched `model.predict` call inside
 * `tidy`, so the input tensors and the output tensor are released before
 * returning. For each window, the first value of that window's model output is
 * kept.
 *
 * NOTE(review): with windows from `split_sequences`, predictions[i] presumably
 * aligns with the original series at index i + n_steps - 1. Confirm against callers.
 *
 * @param timeID - Not used by this function; kept for interface compatibility.
 * @param {number} n_steps - Window length of each entry in `seq_x`.
 * @param {number} n_features - Features per time step. The input is created with
 *   `tensor2d(seq_x, [seq_x.length, n_steps])`, so the reshape only matches in size when this is 1.
 * @param model - Trained model exposing a synchronous `predict`.
 * @param {Array<Array<number>>} seq_x - Input windows.
 * @returns {Promise<{predictions: number[]}>} One prediction per window, in input order.
 */
export const predict = async (timeID, n_steps, n_features, model, seq_x) => {
    if (seq_x.length === 0) {
        return { predictions: [] };
    }

    const predictions = tidy(() => {
        const flatX = tensor2d(seq_x, [seq_x.length, n_steps]);
        const inputX = flatX.reshape([flatX.shape[0], n_steps, n_features]);
        const values = model.predict(inputX).dataSync();
        // Output values are laid out row-major, one contiguous chunk per window;
        // keep the first value of each chunk.
        const perWindow = values.length / seq_x.length;
        return seq_x.map((_, row) => values[row * perWindow]);
    });

    return { predictions };
};