/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorParser;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Random;
import com.yahoo.tensor.functions.Range;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.XwPlusB;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;

public interface Tensor {
    public TensorType type();

    default public boolean isEmpty() {
        return this.size() == 0L;
    }

    public long size();

    public double get(TensorAddress var1);

    public Iterator<Cell> cellIterator();

    public Iterator<Double> valueIterator();

    public java.util.Map<TensorAddress, Double> cells();

    default public double asDouble() {
        if (this.type().dimensions().size() > 0) {
            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + this.type().dimensions().size());
        }
        if (this.size() == 0L) {
            return Double.NaN;
        }
        return this.valueIterator().next();
    }

    public Tensor withType(TensorType var1);

    default public Tensor modify(DoubleBinaryOperator op, java.util.Map<TensorAddress, Double> cells) {
        Builder builder = Builder.of(this.type());
        Iterator<Cell> i = this.cellIterator();
        while (i.hasNext()) {
            Cell cell = i.next();
            TensorAddress address = cell.getKey();
            double value = cell.getValue();
            builder.cell(address, cells.containsKey(address) ? op.applyAsDouble(value, cells.get(address)) : value);
        }
        return builder.build();
    }

    default public Tensor map(DoubleUnaryOperator mapper) {
        return new Map(new ConstantTensor(this), mapper).evaluate();
    }

    default public Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
        return new Reduce((TensorFunction)new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate();
    }

    default public Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
        return new Reduce((TensorFunction)new ConstantTensor(this), aggregator, dimensions).evaluate();
    }

    default public Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
        return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
    }

    default public Tensor rename(String fromDimension, String toDimension) {
        return new Rename((TensorFunction)new ConstantTensor(this), Collections.singletonList(fromDimension), Collections.singletonList(toDimension)).evaluate();
    }

    default public Tensor concat(double argument, String dimension) {
        return this.concat(Builder.of(TensorType.empty).cell(argument, new long[0]).build(), dimension);
    }

    default public Tensor concat(Tensor argument, String dimension) {
        return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
    }

    default public Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
        return new Rename((TensorFunction)new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
    }

    public static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) {
        return new Generate(type, valueSupplier).evaluate();
    }

    default public Tensor l1Normalize(String dimension) {
        return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
    }

    default public Tensor l2Normalize(String dimension) {
        return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
    }

    default public Tensor matmul(Tensor argument, String dimension) {
        return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
    }

    default public Tensor softmax(String dimension) {
        return new Softmax(new ConstantTensor(this), dimension).evaluate();
    }

    default public Tensor xwPlusB(Tensor w, Tensor b, String dimension) {
        return new XwPlusB(new ConstantTensor(this), new ConstantTensor(w), new ConstantTensor(b), dimension).evaluate();
    }

    default public Tensor argmax(String dimension) {
        return new Argmax(new ConstantTensor(this), dimension).evaluate();
    }

    default public Tensor argmin(String dimension) {
        return new Argmin(new ConstantTensor(this), dimension).evaluate();
    }

    public static Tensor diag(TensorType type) {
        return new Diag(type).evaluate();
    }

    public static Tensor random(TensorType type) {
        return new Random(type).evaluate();
    }

    public static Tensor range(TensorType type) {
        return new Range(type).evaluate();
    }

    default public Tensor multiply(Tensor argument) {
        return this.join(argument, (a, b) -> a * b);
    }

    default public Tensor add(Tensor argument) {
        return this.join(argument, (a, b) -> a + b);
    }

    default public Tensor divide(Tensor argument) {
        return this.join(argument, (a, b) -> a / b);
    }

    default public Tensor subtract(Tensor argument) {
        return this.join(argument, (a, b) -> a - b);
    }

    default public Tensor max(Tensor argument) {
        return this.join(argument, (a, b) -> a > b ? a : b);
    }

    default public Tensor min(Tensor argument) {
        return this.join(argument, (a, b) -> a < b ? a : b);
    }

    default public Tensor atan2(Tensor argument) {
        return this.join(argument, Math::atan2);
    }

    default public Tensor pow(Tensor argument) {
        return this.join(argument, Math::pow);
    }

    default public Tensor fmod(Tensor argument) {
        return this.join(argument, (a, b) -> a % b);
    }

    default public Tensor ldexp(Tensor argument) {
        return this.join(argument, (a, b) -> a * Math.pow(2.0, (int)b));
    }

    default public Tensor larger(Tensor argument) {
        return this.join(argument, (a, b) -> a > b ? 1.0 : 0.0);
    }

    default public Tensor largerOrEqual(Tensor argument) {
        return this.join(argument, (a, b) -> a >= b ? 1.0 : 0.0);
    }

    default public Tensor smaller(Tensor argument) {
        return this.join(argument, (a, b) -> a < b ? 1.0 : 0.0);
    }

    default public Tensor smallerOrEqual(Tensor argument) {
        return this.join(argument, (a, b) -> a <= b ? 1.0 : 0.0);
    }

    default public Tensor equal(Tensor argument) {
        return this.join(argument, (a, b) -> a == b ? 1.0 : 0.0);
    }

    default public Tensor notEqual(Tensor argument) {
        return this.join(argument, (a, b) -> a != b ? 1.0 : 0.0);
    }

    default public Tensor approxEqual(Tensor argument) {
        return this.join(argument, (a, b) -> Tensor.approxEquals(a, b) ? 1.0 : 0.0);
    }

    default public Tensor avg() {
        return this.avg(Collections.emptyList());
    }

    default public Tensor avg(String dimension) {
        return this.avg(Collections.singletonList(dimension));
    }

    default public Tensor avg(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.avg, dimensions);
    }

    default public Tensor count() {
        return this.count(Collections.emptyList());
    }

    default public Tensor count(String dimension) {
        return this.count(Collections.singletonList(dimension));
    }

    default public Tensor count(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.count, dimensions);
    }

    default public Tensor max() {
        return this.max(Collections.emptyList());
    }

    default public Tensor max(String dimension) {
        return this.max(Collections.singletonList(dimension));
    }

    default public Tensor max(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.max, dimensions);
    }

    default public Tensor min() {
        return this.min(Collections.emptyList());
    }

    default public Tensor min(String dimension) {
        return this.min(Collections.singletonList(dimension));
    }

    default public Tensor min(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.min, dimensions);
    }

    default public Tensor prod() {
        return this.prod(Collections.emptyList());
    }

    default public Tensor prod(String dimension) {
        return this.prod(Collections.singletonList(dimension));
    }

    default public Tensor prod(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.prod, dimensions);
    }

    default public Tensor sum() {
        return this.sum(Collections.emptyList());
    }

    default public Tensor sum(String dimension) {
        return this.sum(Collections.singletonList(dimension));
    }

    default public Tensor sum(List<String> dimensions) {
        return this.reduce(Reduce.Aggregator.sum, dimensions);
    }

    public String toString();

    public static String toStandardString(Tensor tensor) {
        return tensor.type() + ":" + Tensor.contentToString(tensor);
    }

    public static String contentToString(Tensor tensor) {
        ArrayList<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<Map.Entry<TensorAddress, Double>>(tensor.cells().entrySet());
        if (tensor.type().dimensions().isEmpty()) {
            if (cellEntries.isEmpty()) {
                return "{}";
            }
            return "{" + ((Map.Entry)cellEntries.get(0)).getValue() + "}";
        }
        Collections.sort(cellEntries, Map.Entry.comparingByKey());
        StringBuilder b = new StringBuilder("{");
        for (Map.Entry entry : cellEntries) {
            b.append(((TensorAddress)entry.getKey()).toString(tensor.type())).append(":").append(entry.getValue());
            b.append(",");
        }
        if (b.length() > 1) {
            b.setLength(b.length() - 1);
        }
        b.append("}");
        return b.toString();
    }

    public boolean equals(Object var1);

    public static boolean equals(Tensor a, Tensor b) {
        if (a == b) {
            return true;
        }
        if (!a.type().mathematicallyEquals(b.type())) {
            return false;
        }
        if (a.size() != b.size()) {
            return false;
        }
        Iterator<Cell> aIterator = a.cellIterator();
        while (aIterator.hasNext()) {
            double bValue;
            Cell aCell = aIterator.next();
            double aValue = aCell.getValue();
            if (Tensor.approxEquals(aValue, bValue = b.get(aCell.getKey()), 1.0E-5)) continue;
            return false;
        }
        return true;
    }

    public static boolean approxEquals(double x, double y, double tolerance) {
        return Math.abs(x - y) < tolerance;
    }

    public static boolean approxEquals(double x, double y) {
        if (y < -1.0 || y > 1.0) {
            x = Math.nextAfter(x / y, 1.0);
            y = 1.0;
        } else {
            x = Math.nextAfter(x, y);
        }
        return x == y;
    }

    public static Tensor from(TensorType type, String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.of(type));
    }

    public static Tensor from(String tensorType, String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.of(TensorType.fromSpec(tensorType)));
    }

    public static Tensor from(String tensorString) {
        return TensorParser.tensorFrom(tensorString, Optional.empty());
    }

    public static interface Builder {
        public static Builder of(TensorType type) {
            boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean containsMapped = type.dimensions().stream().anyMatch(d -> !d.isIndexed());
            if (containsIndexed && containsMapped) {
                return MixedTensor.Builder.of(type);
            }
            if (containsMapped) {
                return MappedTensor.Builder.of(type);
            }
            return IndexedTensor.Builder.of(type);
        }

        public static Builder of(TensorType type, DimensionSizes dimensionSizes) {
            boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean containsMapped = type.dimensions().stream().anyMatch(d -> !d.isIndexed());
            if (containsIndexed && containsMapped) {
                return MixedTensor.Builder.of(type);
            }
            if (containsMapped) {
                return MappedTensor.Builder.of(type);
            }
            return IndexedTensor.Builder.of(type, dimensionSizes);
        }

        public TensorType type();

        public CellBuilder cell();

        public Builder cell(TensorAddress var1, double var2);

        public Builder cell(double var1, long ... var3);

        default public Builder cell(Cell cell, double value) {
            return this.cell(cell.getKey(), value);
        }

        public Tensor build();

        public static class CellBuilder {
            private final TensorAddress.Builder addressBuilder;
            private final Builder tensorBuilder;

            CellBuilder(TensorType type, Builder tensorBuilder) {
                this.addressBuilder = new TensorAddress.Builder(type);
                this.tensorBuilder = tensorBuilder;
            }

            public CellBuilder label(String dimension, String label) {
                this.addressBuilder.add(dimension, label);
                return this;
            }

            public CellBuilder label(String dimension, long label) {
                return this.label(dimension, String.valueOf(label));
            }

            public Builder value(double cellValue) {
                return this.tensorBuilder.cell(this.addressBuilder.build(), cellValue);
            }
        }
    }

    public static class Cell
    implements Map.Entry<TensorAddress, Double> {
        private final TensorAddress address;
        private final Double value;

        Cell(TensorAddress address, Double value) {
            this.address = address;
            this.value = value;
        }

        @Override
        public TensorAddress getKey() {
            return this.address;
        }

        long getDirectIndex() {
            return -1L;
        }

        @Override
        public Double getValue() {
            return this.value;
        }

        @Override
        public Double setValue(Double value) {
            throw new UnsupportedOperationException("A tensor cannot be modified");
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Map.Entry)) {
                return false;
            }
            Map.Entry other = (Map.Entry)o;
            if (!this.getValue().equals(other.getValue())) {
                return false;
            }
            return this.getKey().equals(other.getKey());
        }

        @Override
        public int hashCode() {
            return this.getKey().hashCode() ^ this.getValue().hashCode();
        }
    }
}

