/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.DataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import java.util.Objects;
import java.util.Optional;

/**
 * Indexing expression that replaces every cell of the current tensor value
 * with 1.0 when the cell is strictly greater than a threshold, and 0.0 otherwise.
 */
public class BinarizeExpression extends Expression {

    /** Cells strictly greater than this value become 1.0; all others become 0.0. */
    private final double threshold;

    /**
     * The value type observed during verification; also reported as the output type.
     * Remains {@code null} until {@link #doVerify(VerificationContext)} has run.
     */
    private DataType type;

    /**
     * Creates a binarize expression.
     *
     * @param threshold the value a cell must strictly exceed to be mapped to 1.0
     */
    public BinarizeExpression(double threshold) {
        // NOTE(review): TensorDataType.any() is presumably the required input type
        // expected by Expression's constructor - confirm against Expression.
        super(TensorDataType.any());
        this.threshold = threshold;
    }

    /**
     * Binarizes the tensor held by the current value in the context.
     * If the field value holds no tensor, the context is left untouched.
     *
     * @param context the execution context whose value is read and replaced
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        Optional<Tensor> tensor = ((TensorFieldValue) context.getValue()).getTensor();
        if (tensor.isEmpty()) {
            return;
        }
        Tensor binarized = tensor.get().map(v -> v > threshold ? 1.0 : 0.0);
        context.setValue(new TensorFieldValue(binarized));
    }

    /**
     * Records the incoming value type and checks that it is a tensor type.
     *
     * @param context the verification context supplying the current value type
     * @throws IllegalArgumentException if the current value type is not a {@link TensorDataType}
     */
    @Override
    protected void doVerify(VerificationContext context) {
        type = context.getValueType();
        if (!(type instanceof TensorDataType)) {
            throw new IllegalArgumentException("The 'binarize' function requires a tensor, but got " + type);
        }
    }

    /**
     * Returns the type recorded by {@link #doVerify(VerificationContext)}.
     *
     * @return the recorded type, or {@code null} if verification has not run yet
     */
    @Override
    public DataType createdOutputType() {
        return type;
    }

    /** Returns "binarize", followed by a space and the threshold unless the threshold is zero. */
    @Override
    public String toString() {
        return "binarize" + (threshold == 0.0 ? "" : " " + threshold);
    }

    @Override
    public int hashCode() {
        return Objects.hash(threshold, toString().hashCode());
    }

    /**
     * Two binarize expressions are equal only when their thresholds are equal.
     * The threshold must take part in the comparison because it is part of
     * {@link #hashCode()}; ignoring it would make equal objects hash differently.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof BinarizeExpression)) {
            return false;
        }
        BinarizeExpression other = (BinarizeExpression) o;
        // Double.compare matches the boxed Double hashing used by Objects.hash.
        return Double.compare(threshold, other.threshold) == 0;
    }
}

