/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.Arrays;

public class PartialDerivative {
    public static final PartialDerivative EMPTY = new PartialDerivative(null);
    private final DoubleTensor partial;

    public PartialDerivative(DoubleTensor partial) {
        this.partial = partial;
    }

    public boolean isPresent() {
        return this.partial != null;
    }

    public DoubleTensor get() {
        return this.partial;
    }

    public long[] getOfShape(long[] wrtShape) {
        return Arrays.copyOfRange(this.partial.getShape(), 0, this.partial.getRank() - wrtShape.length);
    }

    public long[] getWrtShape(long[] ofShape) {
        return Arrays.copyOfRange(this.partial.getShape(), ofShape.length, this.partial.getRank());
    }

    public PartialDerivative add(PartialDerivative addition) {
        if (this.isPresent() && addition.isPresent()) {
            return new PartialDerivative(this.partial.plus(addition.partial));
        }
        if (this.isPresent() && !addition.isPresent()) {
            return new PartialDerivative(this.get());
        }
        if (!this.isPresent() && addition.isPresent()) {
            return new PartialDerivative(addition.partial);
        }
        return EMPTY;
    }

    public PartialDerivative subtract(PartialDerivative subtraction) {
        if (this.isPresent() && subtraction.isPresent()) {
            return new PartialDerivative(this.partial.minus(subtraction.partial));
        }
        if (this.isPresent() && !subtraction.isPresent()) {
            return new PartialDerivative(this.get());
        }
        if (!this.isPresent() && subtraction.isPresent()) {
            return new PartialDerivative((DoubleTensor)subtraction.partial.unaryMinus());
        }
        return EMPTY;
    }

    public PartialDerivative multiplyBy(double multiplier) {
        if (!this.isPresent()) {
            return this;
        }
        return new PartialDerivative(this.partial.times(multiplier));
    }

    public PartialDerivative multiplyAlongOfDimensions(DoubleTensor multiplier) {
        return this.multiplyAlongOfDimensions(multiplier, multiplier.getRank());
    }

    public PartialDerivative multiplyAlongOfDimensions(DoubleTensor multiplier, int partialOfRank) {
        if (!this.isPresent()) {
            return this;
        }
        DoubleTensor multiplierAlignedAlongOf = PartialDerivative.alignAlongOf(multiplier, this.partial.getShape(), partialOfRank);
        DoubleTensor result = this.partial.times(multiplierAlignedAlongOf);
        return new PartialDerivative(result);
    }

    public PartialDerivative divideByAlongOfDimensions(DoubleTensor divisor) {
        return this.divideByAlongOfDimensions(divisor, divisor.getRank());
    }

    public PartialDerivative divideByAlongOfDimensions(DoubleTensor divisor, int partialOfRank) {
        if (!this.isPresent()) {
            return this;
        }
        DoubleTensor divisorAlignedAlongOf = PartialDerivative.alignAlongOf(divisor, this.partial.getShape(), partialOfRank);
        DoubleTensor result = this.partial.div(divisorAlignedAlongOf);
        return new PartialDerivative(result);
    }

    public PartialDerivative multiplyAlongWrtDimensions(DoubleTensor multiplier) {
        if (!this.isPresent()) {
            return this;
        }
        DoubleTensor multiplierAlignedAlongWrt = PartialDerivative.alignAlongWrt(multiplier, this.partial.getRank());
        DoubleTensor result = this.partial.times(multiplierAlignedAlongWrt);
        return new PartialDerivative(result);
    }

    public static PartialDerivative matrixMultiplyAlongOfDimensions(PartialDerivative partial, DoubleTensor multiplier, boolean partialIsLeft) {
        DoubleTensor result;
        if (!partial.isPresent()) {
            return partial;
        }
        DoubleTensor partialValue = partial.get();
        int partialRank = partialValue.getRank();
        if (partialIsLeft) {
            int[] rearrange = TensorShape.dimensionRange(-1, partialRank - 1);
            rearrange[0] = 0;
            rearrange[1] = partialRank - 1;
            result = (DoubleTensor)partialValue.tensorMultiply(multiplier, new int[]{1}, new int[]{0}).permute(rearrange);
        } else {
            result = multiplier.tensorMultiply(partialValue, new int[]{1}, new int[]{0});
        }
        return new PartialDerivative(result);
    }

    public static PartialDerivative matrixMultiplyAlongWrtDimensions(PartialDerivative partial, DoubleTensor multiplier, boolean partialIsLeft) {
        DoubleTensor result;
        if (!partial.isPresent()) {
            return partial;
        }
        DoubleTensor partialValue = partial.get();
        int partialRank = partialValue.getRank();
        int wrtRightDimension = partialRank - 1;
        if (partialIsLeft) {
            result = partialValue.tensorMultiply(multiplier, new int[]{wrtRightDimension}, new int[]{1});
        } else {
            int wrtLeftDimension = partialRank - 2;
            int[] transposeWrt = TensorShape.dimensionRange(0, partialRank);
            transposeWrt[wrtRightDimension] = wrtLeftDimension;
            transposeWrt[wrtLeftDimension] = wrtRightDimension;
            result = (DoubleTensor)partialValue.tensorMultiply(multiplier, new int[]{wrtLeftDimension}, new int[]{0}).permute(transposeWrt);
        }
        return new PartialDerivative(result);
    }

    private static DoubleTensor alignAlongOf(DoubleTensor tensor, long[] partialShape, int partialOfRank) {
        long[] alongOfShape = new long[partialShape.length];
        Arrays.fill(alongOfShape, 1L);
        int tensorRank = tensor.getRank();
        System.arraycopy(tensor.getShape(), 0, alongOfShape, partialOfRank - tensorRank, tensorRank);
        return (DoubleTensor)tensor.reshape(alongOfShape);
    }

    private static DoubleTensor alignAlongWrt(DoubleTensor tensor, int partialRank) {
        long[] alongWrtShape = TensorShape.shapeToDesiredRankByPrependingOnes(tensor.getShape(), partialRank);
        return (DoubleTensor)tensor.reshape(alongWrtShape);
    }
}

