/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.ContributionsWithBackgroundFrameTask;
import hex.DataInfo;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.Storage;
import java.util.Arrays;
import water.H2O;
import water.Key;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;
import water.util.fp.Function;

/**
 * Computes DeepSHAP feature contributions for a {@link DeepLearningModel} relative to the rows of a
 * background frame.
 *
 * <p>For every (row, background row) pair, {@link #map} does four things:
 * <ul>
 *   <li>runs a forward pass for both rows;</li>
 *   <li>propagates DeepLIFT-style multipliers from the output back to the inputs;</li>
 *   <li>emits one contribution per output column except the last;</li>
 *   <li>emits the model output for the background row in the last column.</li>
 * </ul>
 *
 * <p>Forward-pass buffers use this layout:
 * <ul>
 *   <li>{@code acts[2 * l]}: the pre-activation (linear) values of layer {@code l};</li>
 *   <li>{@code acts[2 * l + 1]}: its post-activation values;</li>
 *   <li>the final pair: the output layer's linear values and its predictions.</li>
 * </ul>
 *
 * <p>The predictions in that final pair depend on the number of output units:
 * <ul>
 *   <li>one output unit: de-normalized and passed through the inverse link;</li>
 *   <li>two output units: passed through a softmax.</li>
 * </ul>
 *
 * <p>Supported activations: Tanh, Rectifier and Maxout, each with or without dropout. Any other
 * activation is rejected in {@link #setupLocal()}.
 */
class DeepSHAPContributionsWithBackground
extends ContributionsWithBackgroundFrameTask<DeepSHAPContributionsWithBackground> {
    private final DeepLearningModel deepLearningModel;
    /**
     * Hidden-layer activation, set in {@link #setupLocal()}. Unused for Maxout, where the max of
     * each pre-activation pair is taken instead.
     */
    transient Function<Double, Double> _activation;
    /**
     * Derivative of the activation. It is used when a pre-activation delta is too small to divide by.
     */
    transient Function<Double, Double> _activationDiff;
    /**
     * When non-null, expanded input column {@code i} contributes to output column {@code _origIndices[i]}.
     */
    final int[] _origIndices;
    /** Number of pre-activation units per hidden neuron: 1, or 2 for Maxout. */
    int _hiddenLayerMultiplier;
    /**
     * When true and the model has a single output, contributions are rescaled to sum to the
     * difference between the row's and the background row's prediction.
     */
    final boolean _outputSpace;

    /**
     * @param deepLearningModel  model to explain
     * @param frKey              frame holding the rows to explain (forwarded to the superclass)
     * @param backgroundFrameKey background frame (forwarded to the superclass)
     * @param perReference       forwarded to the superclass
     * @param origIndices        optional mapping from expanded input columns to output columns; may be null
     * @param outputSpace        whether to rescale single-output contributions to prediction space
     */
    public DeepSHAPContributionsWithBackground(DeepLearningModel deepLearningModel, Key<Frame> frKey, Key<Frame> backgroundFrameKey, boolean perReference, int[] origIndices, boolean outputSpace) {
        super(frKey, backgroundFrameKey, perReference);
        this.deepLearningModel = deepLearningModel;
        this._origIndices = origIndices;
        this._outputSpace = outputSpace;
    }

    /** Returns the model's parameters typed as DeepLearning parameters. */
    private DeepLearningModel.DeepLearningParameters dlParams() {
        return (DeepLearningModel.DeepLearningParameters) this.deepLearningModel._parms;
    }

    /** Returns the per-hidden-layer dropout ratios from the model info, or null when none are set. */
    private double[] hiddenDropoutRatios() {
        return this.deepLearningModel.model_info().get_params()._hidden_dropout_ratios;
    }

    /**
     * Selects the activation function, its derivative and the hidden-layer width multiplier for the
     * model's activation type.
     *
     * @throws RuntimeException if the activation is not supported by DeepSHAP
     */
    @Override
    protected void setupLocal() {
        super.setupLocal();
        switch (this.dlParams()._activation) {
            case Tanh:
            case TanhWithDropout:
                this._activation = this::tanhActivation;
                this._activationDiff = this::tanhActivationDiff;
                this._hiddenLayerMultiplier = 1;
                break;
            case Rectifier:
            case RectifierWithDropout:
                this._activation = this::rectifierActivation;
                this._activationDiff = this::rectifierActivationDiff;
                this._hiddenLayerMultiplier = 1;
                break;
            case Maxout:
            case MaxoutWithDropout:
                this._activation = this::identity;
                // NOTE(review): identity returns v, not the derivative 1. This is only reached in the
                // output-layer fallback of nonLinearActivationSHAP. Confirm this is intended.
                this._activationDiff = this::identity;
                this._hiddenLayerMultiplier = 2;
                break;
            default:
                // H2O.unimpl returns the exception rather than throwing it. Without the throw,
                // _activation stays null and map() fails later with an unhelpful NPE.
                throw H2O.unimpl("Activation " + this.dlParams()._activation + " is not supported in DeepSHAP.");
        }
    }

    /** Returns {@code v} unchanged. */
    protected double identity(double v) {
        return v;
    }

    /** Hyperbolic tangent, written as {@code 1 - 2 / (1 + e^(2v))}. */
    protected double tanhActivation(double v) {
        return 1.0 - 2.0 / (1.0 + Math.exp(2.0 * v));
    }

    /** Derivative of tanh: {@code 1 - tanh(v)^2}. */
    protected double tanhActivationDiff(double v) {
        return 1.0 - Math.pow(1.0 - 2.0 / (1.0 + Math.exp(2.0 * v)), 2.0);
    }

    /** Rectified linear unit: {@code max(v, 0)}. */
    protected double rectifierActivation(double v) {
        return 0.5 * (v + Math.abs(v));
    }

    /** Derivative of the rectifier: 1 for positive inputs, 0 otherwise. */
    protected double rectifierActivationDiff(double v) {
        return v > 0.0 ? 1.0 : 0.0;
    }

    /** Returns {@code a / b}, or 0 when {@code |b| < 1e-10} to avoid dividing by a near-zero value. */
    protected double div(double a, double b) {
        if (Math.abs(b) < 1.0E-10) {
            return 0.0;
        }
        return a / b;
    }

    /**
     * Computes {@code bias[index] + sum_i W[index][i] * input[i]} for one unit of a layer.
     *
     * @param outputLayer when true, weights are read with the plain row/column layout; otherwise
     *                    through {@link #getWeight}, which handles the interleaved Maxout layout
     */
    protected double linearPred(Storage.DenseRowMatrix weights, Storage.DenseVector bias, double[] input, int index, boolean outputLayer) {
        double acc = bias.get(index);
        for (int i = 0; i < input.length; ++i) {
            float weight = outputLayer ? weights.get(index, i) : this.getWeight(weights, index, i);
            acc += (double) weight * input[i];
        }
        return acc;
    }

    /** Applies a softmax to {@code x} in place, subtracting the maximum first for numerical stability. */
    protected void softMax(double[] x) {
        double max = ArrayUtils.maxValue(x);
        double scaling = 0.0;
        for (int i = 0; i < x.length; ++i) {
            x[i] = Math.exp(x[i] - max);
            scaling += x[i];
        }
        for (int i = 0; i < x.length; ++i) {
            x[i] /= scaling;
        }
    }

    /**
     * Reads a hidden-layer weight.
     *
     * <p>For Maxout, {@code row} addresses a pre-activation unit. Neuron {@code row / 2} and channel
     * {@code row % 2} are read from a raw array whose two channels are interleaved per (neuron, col).
     */
    protected float getWeight(Storage.DenseRowMatrix w, int row, int col) {
        if (this._hiddenLayerMultiplier == 1) {
            return w.get(row, col);
        }
        assert (this._hiddenLayerMultiplier == 2);
        return w.raw()[2 * (row / 2 * w.cols() + col) + row % 2];
    }

    /**
     * Runs the network on {@code row} and writes every layer's activations into
     * {@code forwardPassActivations}, using the layout described on the class. All buffers are
     * zeroed first.
     */
    protected void forwardPass(DataInfo.Row row, double[][] forwardPassActivations) {
        for (double[] layer : forwardPassActivations) {
            Arrays.fill(layer, 0.0);
        }
        double[] dropout = this.hiddenDropoutRatios();
        int nHidden = this.dlParams()._hidden.length;

        // First hidden layer reads directly from the input row: accumulate input * weight, then add the bias.
        Storage.DenseRowMatrix w = this.deepLearningModel.model_info().get_weights(0);
        Storage.DenseVector b = this.deepLearningModel.model_info().get_biases(0);
        double[] firstPre = forwardPassActivations[0];
        for (int unit = 0; unit < w.rows(); ++unit) {
            for (int col = 0; col < w.cols(); ++col) {
                firstPre[unit] += row.get(col) * (double) this.getWeight(w, unit, col);
            }
            firstPre[unit] += b.get(unit);
        }
        this.applyHiddenActivation(forwardPassActivations, 0, dropout);

        // Remaining hidden layers read the previous layer's post-activation values.
        for (int layer = 1; layer < nHidden; ++layer) {
            w = this.deepLearningModel.model_info().get_weights(layer);
            b = this.deepLearningModel.model_info().get_biases(layer);
            for (int unit = 0; unit < w.rows(); ++unit) {
                forwardPassActivations[2 * layer][unit] = this.linearPred(w, b, forwardPassActivations[2 * layer - 1], unit, false);
            }
            this.applyHiddenActivation(forwardPassActivations, layer, dropout);
        }

        // Output layer.
        w = this.deepLearningModel.model_info().get_weights(nHidden);
        b = this.deepLearningModel.model_info().get_biases(nHidden);
        double[] logits = forwardPassActivations[2 * nHidden];
        double[] predictions = forwardPassActivations[2 * nHidden + 1];
        DataInfo dinfo = this.deepLearningModel.model_info().data_info();
        for (int unit = 0; unit < w.rows(); ++unit) {
            logits[unit] = this.linearPred(w, b, forwardPassActivations[2 * nHidden - 1], unit, true);
            predictions[unit] = logits[unit];
            if (w.rows() == 1) {
                // Single output: undo response normalization, then apply the distribution's inverse link.
                if (dinfo._normRespMul != null) {
                    predictions[unit] = predictions[unit] / dinfo._normRespMul[0] + dinfo._normRespSub[0];
                }
                predictions[unit] = this.deepLearningModel._dist.linkInv(predictions[unit]);
            }
        }
        if (w.rows() == 2) {
            this.softMax(predictions);
        }
    }

    /**
     * Fills the post-activation buffer of hidden layer {@code layer} from its pre-activation buffer.
     * When hidden dropout ratios are set, the result is scaled by {@code 1 - dropout[layer]}.
     */
    private void applyHiddenActivation(double[][] acts, int layer, double[] dropout) {
        double[] pre = acts[2 * layer];
        double[] post = acts[2 * layer + 1];
        for (int unit = 0; unit < post.length; ++unit) {
            post[unit] = this._hiddenLayerMultiplier == 1
                    ? this._activation.apply(pre[unit])
                    : Math.max(pre[2 * unit], pre[2 * unit + 1]);
            if (dropout != null) {
                post[unit] *= 1.0 - dropout[layer];
            }
        }
    }

    /**
     * Computes the exact Shapley values of the two-input game {@code max(a, b)}. The inputs move
     * from the background values {@code bg[i], bg[j]} to {@code x[i], x[j]}.
     *
     * @param contributions receives the share of {@code x[i]} at index 0 and of {@code x[j]} at
     *                      index 1; the two values sum to {@code max(x[i], x[j]) - max(bg[i], bg[j])}
     */
    protected void maxSHAP(double[] x, double[] bg, float[] contributions, int i, int j) {
        double maxBB = Math.max(bg[i], bg[j]);
        double maxBX = Math.max(bg[i], x[j]);
        double maxXB = Math.max(x[i], bg[j]);
        double maxXX = Math.max(x[i], x[j]);
        double maxXXmBB = maxXX - maxBB;
        double maxXBmBX = maxXB - maxBX;
        contributions[0] = (float) (0.5 * (maxXXmBB + maxXBmBX));
        contributions[1] = (float) (0.5 * (maxXXmBB - maxXBmBX));
    }

    /** Copies row {@code index} of {@code weights} into {@code contributions}; linear-layer multipliers are the weights themselves. */
    protected void linearSHAP(Storage.DenseRowMatrix weights, double[] contributions, int index) {
        for (int i = 0; i < contributions.length; ++i) {
            contributions[i] = weights.get(index, i);
        }
    }

    /**
     * Fills {@code contributions[row][col]} with the multiplier from input {@code col} of layer
     * {@code currLayer} to output {@code row}.
     *
     * <ul>
     *   <li>For Maxout hidden layers, the pair of pre-activations is split with {@link #maxSHAP},
     *       divided by each pre-activation delta, and scaled by the dropout keep fraction.</li>
     *   <li>Otherwise the rescale rule {@code deltaOut / deltaIn} is used. When
     *       {@code |deltaIn| <= 1e-6} it falls back to {@link #_activationDiff}.</li>
     * </ul>
     */
    protected void nonLinearActivationSHAP(Storage.DenseRowMatrix weights, double[][] forwardPass, double[][] forwardBgPass, int currLayer, Storage.DenseRowMatrix contributions) {
        double[] pre = forwardPass[2 * currLayer];
        double[] preBg = forwardBgPass[2 * currLayer];
        // The output layer is never treated as Maxout: it is the last pair of forward buffers.
        boolean maxoutHiddenLayer = this._hiddenLayerMultiplier > 1 && forwardPass.length > 2 * currLayer + 2;
        if (maxoutHiddenLayer) {
            double[] dropout = this.hiddenDropoutRatios();
            double keepFraction = null == dropout ? 1.0 : 1.0 - dropout[currLayer];
            for (int row = 0; row < contributions.rows(); ++row) {
                int first = 2 * row;
                int second = 2 * row + 1;
                float[] deltaIn = new float[]{(float) (pre[first] - preBg[first]), (float) (pre[second] - preBg[second])};
                float[] maxOutContr = new float[2];
                this.maxSHAP(pre, preBg, maxOutContr, first, second);
                for (int col = 0; col < contributions.cols(); ++col) {
                    double viaFirst = this.div(this.getWeight(weights, first, col) * maxOutContr[0], deltaIn[0]);
                    double viaSecond = this.div(this.getWeight(weights, second, col) * maxOutContr[1], deltaIn[1]);
                    contributions.set(row, col, (float) (keepFraction * (viaFirst + viaSecond)));
                }
            }
        } else {
            double[] post = forwardPass[2 * currLayer + 1];
            double[] postBg = forwardBgPass[2 * currLayer + 1];
            for (int row = 0; row < contributions.rows(); ++row) {
                double deltaOut = post[row] - postBg[row];
                double deltaIn = pre[row] - preBg[row];
                float ratio = (float) (Math.abs(deltaIn) > 1.0E-6
                        ? this.div(deltaOut, deltaIn)
                        : this._activationDiff.apply(pre[row]).doubleValue());
                for (int col = 0; col < contributions.cols(); ++col) {
                    contributions.set(row, col, weights.get(row, col) * ratio);
                }
            }
        }
    }

    /**
     * Chain rule: {@code contributions[currentLayer] = m^T * contributions[currentLayer + 1]}. The
     * target entry is overwritten.
     */
    protected void combineMultiplicators(Storage.DenseRowMatrix m, double[][] contributions, int currentLayer) {
        double[] target = contributions[currentLayer];
        double[] upstream = contributions[currentLayer + 1];
        Arrays.fill(target, 0.0);
        for (int i = 0; i < m.rows(); ++i) {
            for (int j = 0; j < m.cols(); ++j) {
                target[j] += (double) m.get(i, j) * upstream[i];
            }
        }
    }

    /**
     * Propagates multipliers from the output back to the inputs. The final contributions
     * (multiplier times {@code row - bgRow}) are stored in {@code backwardPass[0]}.
     *
     * <p>Let {@code offset} be 1 when {@link #_origIndices} is set and 0 otherwise. Then
     * {@code backwardPass[k]} for {@code k >= offset} holds multipliers with respect to the inputs of
     * weight layer {@code k - offset}.
     *
     * <p>The last entry is seeded from the output layer:
     * <ul>
     *   <li>single output: the output weights divided by the response normalization factor. The
     *       inverse link is not applied here; see {@link #_outputSpace}.</li>
     *   <li>otherwise: the multipliers of the last output neuron.</li>
     * </ul>
     */
    protected void backwardPass(double[][] forwardPass, double[][] forwardBgPass, double[][] backwardPass, DataInfo.Row row, DataInfo.Row bgRow) {
        for (double[] layer : backwardPass) {
            Arrays.fill(layer, 0.0);
        }
        int last = backwardPass.length - 1;
        int backwardPassOffset = this._origIndices == null ? 0 : 1;
        Storage.DenseRowMatrix outputWeights = this.deepLearningModel.model_info().get_weights(last - backwardPassOffset);
        int outputNeuron = outputWeights.rows() - 1;
        double[] seed = backwardPass[last];
        if (outputNeuron == 0) {
            double[] normRespMul = this.deepLearningModel.model_info().data_info._normRespMul;
            float[] outWeight = new float[seed.length];
            for (int j = 0; j < outWeight.length; ++j) {
                outWeight[j] = normRespMul != null
                        ? (float) ((double) outputWeights.get(outputNeuron, j) / normRespMul[outputNeuron])
                        : outputWeights.get(outputNeuron, j);
            }
            this.linearSHAP(new Storage.DenseRowMatrix(outWeight, 1, seed.length), seed, 0);
        } else {
            // NOTE(review): only 2 rows are allocated, so more than two output units would index out of range.
            Storage.DenseRowMatrix m = new Storage.DenseRowMatrix(2, seed.length);
            this.nonLinearActivationSHAP(outputWeights, forwardPass, forwardBgPass, last - backwardPassOffset, m);
            for (int j = 0; j < m.cols(); ++j) {
                seed[j] = m.get(outputNeuron, j);
            }
        }
        for (int i = last - 1; i >= backwardPassOffset; --i) {
            Storage.DenseRowMatrix m = new Storage.DenseRowMatrix(backwardPass[i + 1].length, backwardPass[i].length);
            this.nonLinearActivationSHAP(this.deepLearningModel.model_info().get_weights(i - backwardPassOffset), forwardPass, forwardBgPass, i - backwardPassOffset, m);
            this.combineMultiplicators(m, backwardPass, i);
        }
        double[] result = backwardPass[0];
        if (this._origIndices != null) {
            // Sum per-expanded-column contributions into the column given by _origIndices.
            Arrays.fill(result, 0.0);
            double[] expanded = backwardPass[1];
            for (int i = 0; i < this._origIndices.length; ++i) {
                result[this._origIndices[i]] += expanded[i] * (row.get(i) - bgRow.get(i));
            }
        } else {
            for (int i = 0; i < result.length; ++i) {
                result[i] *= row.get(i) - bgRow.get(i);
            }
        }
    }

    /** Allocates zeroed forward-pass buffers in the per-layer layout described on the class. */
    private double[][] newForwardBuffers() {
        int[] hidden = this.dlParams()._hidden;
        double[][] acts = new double[2 * (hidden.length + 1)][];
        for (int i = 0; i < hidden.length; ++i) {
            acts[2 * i] = MemoryManager.malloc8d(this._hiddenLayerMultiplier * hidden[i]);
            acts[2 * i + 1] = MemoryManager.malloc8d(hidden[i]);
        }
        int outputUnits = this.deepLearningModel.model_info().get_weights(hidden.length).rows();
        acts[2 * hidden.length] = new double[outputUnits];
        acts[2 * hidden.length + 1] = new double[outputUnits];
        return acts;
    }

    /**
     * For each row of {@code cs} and each background row of {@code bgCs}, appends one value to every
     * chunk in {@code ncs}:
     * <ul>
     *   <li>{@code ncs[0 .. ncs.length - 2]}: the per-column contributions;</li>
     *   <li>{@code ncs[ncs.length - 1]}: the background row's last model output.</li>
     * </ul>
     */
    @Override
    protected void map(Chunk[] cs, Chunk[] bgCs, NewChunk[] ncs) {
        int[] hidden = this.dlParams()._hidden;
        double[][] forwardPass = this.newForwardBuffers();
        double[][] forwardBgPass = this.newForwardBuffers();
        // backwardPass[0] receives the final contributions. When _origIndices is set, backwardPass[1]
        // holds multipliers for every input column of the first weight layer.
        int backwardPassOffset = this._origIndices == null ? 1 : 2;
        double[][] backwardPass = new double[hidden.length + backwardPassOffset][];
        backwardPass[0] = MemoryManager.malloc8d(ncs.length - 1);
        if (backwardPassOffset > 1) {
            backwardPass[1] = MemoryManager.malloc8d(this.deepLearningModel.model_info().get_weights(0).cols());
        }
        for (int i = 0; i < hidden.length; ++i) {
            backwardPass[i + backwardPassOffset] = MemoryManager.malloc8d(hidden[i]);
        }

        DataInfo dinfo = this.deepLearningModel.model_info().data_info;
        DataInfo.Row row = dinfo.newDenseRow();
        DataInfo.Row bgRow = dinfo.newDenseRow();
        NewChunk biasChunk = ncs[ncs.length - 1];
        double[] contributions = backwardPass[0];
        for (int j = 0; j < cs[0]._len; ++j) {
            dinfo.extractDenseRow(cs, j, row);
            this.forwardPass(row, forwardPass);
            double[] prediction = forwardPass[forwardPass.length - 1];
            for (int k = 0; k < bgCs[0]._len; ++k) {
                dinfo.extractDenseRow(bgCs, k, bgRow);
                this.forwardPass(bgRow, forwardBgPass);
                double[] bgPrediction = forwardBgPass[forwardBgPass.length - 1];
                biasChunk.addNum(bgPrediction[bgPrediction.length - 1]);
                this.backwardPass(forwardPass, forwardBgPass, backwardPass, row, bgRow);
                // Optionally rescale single-output contributions so they sum to the prediction difference.
                double multiplier = this._outputSpace && prediction.length == 1
                        ? this.div(prediction[0] - bgPrediction[0], Arrays.stream(contributions).sum())
                        : 1.0;
                for (int i = 0; i < contributions.length; ++i) {
                    ncs[i].addNum(multiplier * contributions[i]);
                }
            }
        }
    }
}

