/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

/**
 * A primitive tensor function which produces a tensor of a given type by computing
 * the value of every cell from its indexes.
 * <p>
 * The cell values come either from a "free" Java {@link Function} of the cell's index list,
 * or from a "bound" {@link ScalarFunction}. In the bound case the dimension names of the
 * generated type can be referenced by the scalar function and resolve to the current index.
 * All dimensions of the generated type must be indexed and bound.
 *
 * @param <NAMETYPE> the type of names used in the evaluation context
 */
public class Generate<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorType type;
    /** Non-null when this was created with a free generator; null otherwise. */
    private final Function<List<Long>, Double> freeGenerator;
    /** Non-null when this was created with a bound generator; null otherwise. */
    private final ScalarFunction<NAMETYPE> boundGenerator;

    /**
     * Creates a generate function using a free generator.
     *
     * @param type the type of the tensor to generate; all dimensions must be indexed bound
     * @param generator computes a cell value from the list of the cell's indexes
     * @throws NullPointerException if type or generator is null
     * @throws IllegalArgumentException if type has a dimension which is not indexed bound
     */
    public Generate(TensorType type, Function<List<Long>, Double> generator) {
        this(type, Objects.requireNonNull(generator), null);
    }

    /**
     * Creates a generate function whose cell values are computed by a Java function of the cell indexes.
     *
     * @param type the type of the tensor to generate; all dimensions must be indexed bound
     * @param generator computes a cell value from the list of the cell's indexes
     * @throws NullPointerException if type or generator is null
     * @throws IllegalArgumentException if type has a dimension which is not indexed bound
     */
    public static <NAMETYPE extends Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) {
        return new Generate<NAMETYPE>(type, Objects.requireNonNull(generator), null);
    }

    /**
     * Creates a generate function whose cell values are computed by a scalar function.
     * The scalar function may refer to the dimension names of the type to get the current index.
     *
     * @param type the type of the tensor to generate; all dimensions must be indexed bound
     * @param generator the scalar function computing each cell value
     * @throws NullPointerException if type or generator is null
     * @throws IllegalArgumentException if type has a dimension which is not indexed bound
     */
    public static <NAMETYPE extends Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) {
        return new Generate<NAMETYPE>(type, null, Objects.requireNonNull(generator));
    }

    /** Exactly one of freeGenerator and boundGenerator is expected to be non-null (as ensured by the public factories). */
    private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) {
        Objects.requireNonNull(type, "The argument tensor type cannot be null");
        this.validateType(type);
        this.type = type;
        this.freeGenerator = freeGenerator;
        this.boundGenerator = boundGenerator;
    }

    /** Throws IllegalArgumentException unless every dimension of the given type is indexed bound. */
    private void validateType(TensorType type) {
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (dimension.type() != TensorType.Dimension.Type.indexedBound) {
                throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
            }
        }
    }

    /**
     * Returns the bound generator as a single argument if it is convertible to a tensor function,
     * and an empty list otherwise (including when this uses a free generator).
     */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        if (this.boundGenerator == null || this.boundGenerator.asTensorFunction().isEmpty()) {
            return List.of();
        }
        return List.of(this.boundGenerator.asTensorFunction().get());
    }

    /**
     * Returns a copy of this using the given argument as the bound generator, or this if no arguments are given.
     *
     * @throws IllegalArgumentException if more than one argument is given, or the argument
     *                                  cannot be converted to a scalar function
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() > 1) {
            throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size());
        }
        if (arguments.isEmpty()) {
            return this;
        }
        var scalarFunction = arguments.get(0).asScalarFunction();
        if (scalarFunction.isEmpty()) {
            // The check is on scalar convertibility, so the message names a scalar function
            throw new IllegalArgumentException("The argument to generate must be convertible to a scalar function, but got " + arguments.get(0));
        }
        return new Generate<NAMETYPE>(this.type, null, scalarFunction.get());
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return this;
    }

    /** Returns the type given at construction, independent of the context. */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return this.type;
    }

    /** Builds the tensor by visiting every index combination of the type and computing each cell value. */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor.Builder builder = Tensor.Builder.of(this.type);
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(this.dimensionSizes(this.type));
        GenerateEvaluationContext generateContext = new GenerateEvaluationContext(this.type, context);
        // A long counter: the cell count is a long and may exceed Integer.MAX_VALUE
        for (long cell = 0; cell < indexes.size(); cell++) {
            indexes.next();
            builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
        }
        return builder.build();
    }

    /** Returns the sizes of the dimensions of the given type, which must all be bound. */
    private DimensionSizes dimensionSizes(TensorType type) {
        DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
        for (int i = 0; i < b.dimensions(); ++i) {
            b.set(i, type.dimensions().get(i).size().get());
        }
        return b.build();
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return this.type + "(" + this.generatorToString(context) + ")";
    }

    private String generatorToString(ToStringContext<NAMETYPE> context) {
        if (this.freeGenerator != null) {
            return this.freeGenerator.toString();
        }
        return this.boundGenerator.toString(new GenerateToStringContext(context));
    }

    // NOTE(review): equals is not overridden in this class; confirm a superclass defines it consistently with this.
    @Override
    public int hashCode() {
        return Objects.hash("generate", this.type, this.freeGenerator, this.boundGenerator);
    }

    /** A to-string context which leaves the dimension names of the generated type unresolved. */
    private class GenerateToStringContext
    implements ToStringContext<NAMETYPE> {

        private final ToStringContext<NAMETYPE> context;

        public GenerateToStringContext(ToStringContext<NAMETYPE> context) {
            this.context = context;
        }

        @Override
        public String getBinding(String identifier) {
            if (Generate.this.type.dimension(identifier).isPresent()) {
                return identifier; // a dimension name: not bound to anything in the outer context
            }
            return this.context.getBinding(identifier);
        }

        @Override
        public ToStringContext<NAMETYPE> parent() {
            return this.context;
        }
    }

    /**
     * An evaluation context which resolves the dimension names of the generated type to the index
     * of the cell currently being computed, and delegates all other names to the wrapped context.
     */
    private class GenerateEvaluationContext
    implements EvaluationContext<NAMETYPE> {

        private final TensorType type;
        private final EvaluationContext<NAMETYPE> context;
        /** The indexes of the cell currently being computed; set by apply. */
        private IndexedTensor.Indexes indexes;

        GenerateEvaluationContext(TensorType type, EvaluationContext<NAMETYPE> context) {
            this.type = type;
            this.context = context;
        }

        /** Computes the value of the cell at the given indexes using whichever generator is set. */
        double apply(IndexedTensor.Indexes indexes) {
            if (Generate.this.freeGenerator != null) {
                return Generate.this.freeGenerator.apply(indexes.toList());
            }
            this.indexes = indexes;
            return Generate.this.boundGenerator.apply(this);
        }

        @Override
        public Tensor getTensor(String name) {
            Optional<Integer> index = this.type.indexOfDimension(name);
            if (index.isPresent()) {
                return Tensor.from(this.indexes.indexesForReading()[index.get()]);
            }
            return this.context.getTensor(name);
        }

        @Override
        public TensorType getType(NAMETYPE name) {
            Optional<Integer> index = this.type.indexOfDimension(name.name());
            if (index.isPresent()) {
                return TensorType.empty;
            }
            return this.context.getType(name);
        }

        @Override
        public TensorType getType(String name) {
            Optional<Integer> index = this.type.indexOfDimension(name);
            if (index.isPresent()) {
                return TensorType.empty;
            }
            return this.context.getType(name);
        }
    }
}

