package io.improbable.keanu.tensor.dbl;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:io/improbable/keanu/tensor/dbl/Nd4jDoubleTensorFactory.class */
/**
 * {@link DoubleTensorFactory} implementation whose tensors are backed by ND4J.
 *
 * <p>Every factory method except {@link #concat} delegates directly to the static factory of the
 * same name on {@link Nd4jDoubleTensor}. {@link #concat} builds its result with {@link Nd4j#concat}.
 */
public class Nd4jDoubleTensorFactory implements DoubleTensorFactory {

    /**
     * Delegates to {@link Nd4jDoubleTensor#create(double, long[])}.
     *
     * @param value the value passed to the delegate (presumably the fill value — confirm)
     * @param shape the requested tensor shape
     */
    @Override
    public DoubleTensor create(double value, long[] shape) {
        return Nd4jDoubleTensor.create(value, shape);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#create(double[], long[])}.
     *
     * @param values the element values passed to the delegate
     * @param shape the requested tensor shape
     */
    @Override
    public DoubleTensor create(double[] values, long[] shape) {
        return Nd4jDoubleTensor.create(values, shape);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#create(double[])}.
     *
     * @param values the element values passed to the delegate
     */
    @Override
    public DoubleTensor create(double[] values) {
        return Nd4jDoubleTensor.create(values);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#ones(long[])}.
     *
     * @param shape the requested tensor shape
     */
    @Override
    public DoubleTensor ones(long[] shape) {
        return Nd4jDoubleTensor.ones(shape);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#zeros(long[])}.
     *
     * @param shape the requested tensor shape
     */
    @Override
    public DoubleTensor zeros(long[] shape) {
        return Nd4jDoubleTensor.zeros(shape);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#eye(long)}.
     *
     * @param n the size argument passed to the delegate (presumably the side length — confirm)
     */
    @Override
    public DoubleTensor eye(long n) {
        return Nd4jDoubleTensor.eye(n);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#linspace(double, double, int)}.
     *
     * @param start the first argument passed to the delegate (presumably the range start)
     * @param end the second argument passed to the delegate (presumably the range end)
     * @param numberOfPoints the point count passed to the delegate
     */
    @Override
    public DoubleTensor linspace(double start, double end, int numberOfPoints) {
        return Nd4jDoubleTensor.linspace(start, end, numberOfPoints);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#arange(double, double)}.
     *
     * @param start the first argument passed to the delegate (presumably the range start)
     * @param end the second argument passed to the delegate (presumably the range end)
     */
    @Override
    public DoubleTensor arange(double start, double end) {
        return Nd4jDoubleTensor.arange(start, end);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#arange(double, double, double)}.
     *
     * @param start the first argument passed to the delegate (presumably the range start)
     * @param end the second argument passed to the delegate (presumably the range end)
     * @param stepSize the third argument passed to the delegate (presumably the step)
     */
    @Override
    public DoubleTensor arange(double start, double end, double stepSize) {
        return Nd4jDoubleTensor.arange(start, end, stepSize);
    }

    /**
     * Delegates to {@link Nd4jDoubleTensor#scalar(double)}.
     *
     * @param scalarValue the value passed to the delegate
     */
    @Override
    public DoubleTensor scalar(double scalarValue) {
        return Nd4jDoubleTensor.scalar(scalarValue);
    }

    /**
     * Concatenates the given tensors along {@code dimension} into a new ND4J-backed tensor.
     *
     * <p>Each input's backing array is copied with {@code dup()} before concatenation. Everything
     * handed to {@link Nd4j#concat} is therefore a private copy, never an input's own array.
     * Rank-0 (scalar) inputs are reshaped to shape {@code [1]} so they can be joined like vectors.
     *
     * @param dimension the dimension to concatenate along
     * @param toConcat the tensors to join, in order; must contain at least one tensor
     * @return a new tensor wrapping the result of {@link Nd4j#concat}
     * @throws IllegalArgumentException if {@code toConcat} is empty
     */
    @Override
    public DoubleTensor concat(int dimension, DoubleTensor... toConcat) {
        // With no inputs there is no shape to build; fail with a clear message instead of
        // whatever ND4J does internally with an empty array list.
        if (toConcat.length == 0) {
            throw new IllegalArgumentException("concat requires at least one tensor");
        }
        INDArray[] arrays = new INDArray[toConcat.length];
        for (int i = 0; i < toConcat.length; i++) {
            arrays[i] = toConcatenableCopy(toConcat[i]);
        }
        return new Nd4jDoubleTensor(Nd4j.concat(dimension, arrays));
    }

    /** Returns a copy of the tensor's ND4J array, with rank 0 promoted to shape {@code [1]}. */
    private static INDArray toConcatenableCopy(DoubleTensor tensor) {
        INDArray copy = Nd4jDoubleTensor.getAsINDArray(tensor).dup();
        return copy.shape().length == 0 ? copy.reshape(new long[]{1}) : copy;
    }
}
