/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.rule;

import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Ranking expression node for the {@code unpack_bits} operation.
 *
 * <p>The input must evaluate to a tensor with {@code int8} cells and at least one indexed
 * dimension whose last indexed dimension has a bound size. The result has the configured
 * target cell type and the same dimensions, except that the last indexed dimension is eight
 * times larger: each input cell is expanded into eight cells holding one bit (0 or 1) each.
 *
 * <p>With {@code "big"} endianness output index {@code 8*k + b} holds bit {@code 7 - b} of
 * input cell {@code k}; with {@code "little"} endianness it holds bit {@code b}.
 */
@Beta
public class UnpackBitsNode extends CompositeNode {
    private static final String operationName = "unpack_bits";

    /** The expression producing the tensor to unpack. */
    final ExpressionNode input;
    /** Cell type of the produced tensor. */
    final TensorType.Value targetCellType;
    /** Bit order used when expanding each input cell. */
    final EndianNess endian;

    /**
     * Creates a node unpacking the bits of {@code input}.
     *
     * @param input          the expression producing the tensor to unpack
     * @param targetCellType the cell type of the resulting tensor
     * @param endianNess     either {@code "big"} or {@code "little"}
     * @throws IllegalArgumentException if {@code endianNess} is neither of the two accepted ids
     */
    public UnpackBitsNode(ExpressionNode input, TensorType.Value targetCellType, String endianNess) {
        this.input = input;
        this.targetCellType = targetCellType;
        this.endian = EndianNess.fromId(endianNess);
    }

    /** Returns a single-element list holding the input expression. */
    @Override
    public List<ExpressionNode> children() {
        return Collections.singletonList(this.input);
    }

    /**
     * Serializes this node into {@code string}.
     *
     * <p>Without a type context the node is written as
     * {@code unpack_bits(<input>,<targetCellType>,<endian>)}. With a type context the input type
     * is resolved and the node is written as an expression of the form
     * {@code map_subspaces(<input>, f(denseSubspaceInput)(<denseType>(bit(denseSubspaceInput{d:(d), u:(u/8)}, 7-(u % 8)))))}
     * where {@code u} is the unpacked dimension and the {@code 7-} prefix is only emitted for
     * big endian.
     *
     * @return the same {@code string} builder, for chaining
     * @throws IllegalArgumentException if a type context is present and the input type cannot be unpacked
     */
    @Override
    public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
        Optional<TypeContext<Reference>> optTC = context.typeContext();
        if (optTC.isEmpty()) {
            string.append(operationName);
            string.append("(");
            this.input.toString(string, context, path, this);
            string.append(",");
            string.append(this.targetCellType);
            string.append(",");
            string.append(this.endian);
            string.append(")");
            return string;
        }

        // Resolve and validate the type before writing anything, so a failure leaves 'string' untouched.
        TensorType inputType = this.input.type(optTC.get());
        Meta meta = this.analyze(inputType);
        String unpackDim = meta.unpackDimension();

        string.append("map_subspaces").append("(");
        this.input.toString(string, context, path, this);
        string.append(", f(denseSubspaceInput)(");
        string.append(meta.outputDenseType()).append("(");
        string.append("bit(denseSubspaceInput{");
        for (TensorType.Dimension dim : meta.outputDenseType().dimensions()) {
            String dimName = dim.name();
            boolean isUnpackDim = dimName.equals(unpackDim);
            string.append(dimName).append(":(").append(dimName);
            if (isUnpackDim) {
                // Eight output cells map back onto one input cell.
                string.append("/8");
            }
            string.append(")");
            if (!isUnpackDim) {
                // The unpacked dimension is written without a trailing separator.
                string.append(", ");
            }
        }
        string.append(this.endian == EndianNess.BIG_ENDIAN ? "}, 7-(" : "}, (");
        string.append(unpackDim);
        string.append(" % 8)");
        string.append("))))");
        return string;
    }

    /**
     * Evaluates the input and expands every cell into eight single-bit cells.
     *
     * @return a {@link TensorValue} wrapping the unpacked tensor
     * @throws IllegalArgumentException if the input tensor type cannot be unpacked
     */
    @Override
    public Value evaluate(Context context) {
        Tensor inputTensor = this.input.evaluate(context).asTensor();
        TensorType inputType = inputTensor.type();
        Meta meta = this.analyze(inputType);
        List<TensorType.Dimension> inputDims = inputType.dimensions();
        String unpackDim = meta.unpackDimension();
        boolean bigEndian = this.endian == EndianNess.BIG_ENDIAN;

        Tensor.Builder builder = Tensor.Builder.of(meta.outputType());
        Iterator<Tensor.Cell> cells = inputTensor.cellIterator();
        while (cells.hasNext()) {
            Tensor.Cell cell = cells.next();
            TensorAddress oldAddr = cell.getKey();
            // Only the low eight bits of the truncated value are inspected below.
            int oldValue = (int) cell.getValue().doubleValue();
            for (int bitIdx = 0; bitIdx < 8; ++bitIdx) {
                TensorAddress.Builder addrBuilder = new TensorAddress.Builder(meta.outputType());
                // NOTE(review): every label is read with numericLabel, also for dimensions that
                // are not indexed — confirm this is valid for mapped dimensions.
                for (int i = 0; i < inputDims.size(); ++i) {
                    String dimName = inputDims.get(i).name();
                    long label = oldAddr.numericLabel(i);
                    if (dimName.equals(unpackDim)) {
                        label = label * 8L + (long) bitIdx;
                    }
                    addrBuilder.add(dimName, label);
                }
                int shift = bigEndian ? 7 - bitIdx : bitIdx;
                float newCellValue = 1 & (oldValue >>> shift);
                builder.cell(addrBuilder.build(), newCellValue);
            }
        }
        return new TensorValue(builder.build());
    }

    /**
     * Validates {@code inputType} and derives the output type.
     *
     * @return the full output type, its indexed subtype, and the name of the unpacked dimension
     * @throws IllegalArgumentException if the cell type is not int8, there is no indexed
     *         dimension, or the last indexed dimension has no size
     */
    private Meta analyze(TensorType inputType) {
        if (inputType.valueType() != TensorType.Value.INT8) {
            throw new IllegalArgumentException("bad unpack_bits; input must have cell-type int8, but it was: " + inputType.valueType());
        }
        TensorType inputDenseType = inputType.indexedSubtype();
        if (inputDenseType.rank() == 0) {
            throw new IllegalArgumentException("bad unpack_bits; input must have indexed dimension, but type was: " + inputType);
        }
        TensorType.Dimension lastDim = inputDenseType.dimensions().get(inputDenseType.rank() - 1);
        if (lastDim.size().isEmpty()) {
            throw new IllegalArgumentException("bad unpack_bits; last indexed dimension must be bound, but type was: " + inputType);
        }
        String unpackDim = lastDim.name();
        TensorType.Builder ttBuilder = new TensorType.Builder(this.targetCellType);
        for (TensorType.Dimension dim : inputType.dimensions()) {
            if (dim.name().equals(unpackDim)) {
                long sz = dim.size().get();
                ttBuilder.indexed(dim.name(), sz * 8L);
            } else {
                ttBuilder.set(dim);
            }
        }
        TensorType outputType = ttBuilder.build();
        return new Meta(outputType, outputType.indexedSubtype(), unpackDim);
    }

    /**
     * Returns the type of the unpacked tensor given the input's type in {@code context}.
     *
     * @throws IllegalArgumentException if the input type cannot be unpacked
     */
    @Override
    public TensorType type(TypeContext<Reference> context) {
        return this.analyze(this.input.type(context)).outputType();
    }

    /**
     * Returns a copy of this node with its input replaced by the single given child.
     *
     * @throws IllegalArgumentException if {@code newChildren} does not hold exactly one node
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> newChildren) {
        if (newChildren.size() != 1) {
            throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size());
        }
        return new UnpackBitsNode(newChildren.get(0), this.targetCellType, this.endian.toString());
    }

    /**
     * Hashes everything that determines what this node computes, including the bit order,
     * so that nodes differing only in endianness do not always collide.
     */
    @Override
    public int hashCode() {
        return Objects.hash(operationName, this.input, this.targetCellType, this.endian);
    }

    /** Bit order of the unpacking, identified in expressions by {@code "big"} or {@code "little"}. */
    private enum EndianNess {
        BIG_ENDIAN("big"),
        LITTLE_ENDIAN("little");

        private final String id;

        EndianNess(String id) {
            this.id = id;
        }

        /** Returns the textual id, which {@link #fromId(String)} accepts back. */
        @Override
        public String toString() {
            return this.id;
        }

        /**
         * Looks up the constant with the given id.
         *
         * @throws IllegalArgumentException if no constant has that id
         */
        public static EndianNess fromId(String id) {
            for (EndianNess value : EndianNess.values()) {
                if (value.id.equals(id)) {
                    return value;
                }
            }
            throw new IllegalArgumentException("EndianNess must be either 'big' or 'little', but was '" + id + "'");
        }
    }

    /** Result of analyzing an input type: output type, its indexed subtype, and the unpacked dimension's name. */
    private record Meta(TensorType outputType, TensorType outputDenseType, String unpackDimension) {
    }
}

