/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

/**
 * A block whose forward pass delegates to {@link NDArrayEx#fullyConnected}.
 *
 * <p>The block owns a {@code weight} parameter of shape {@code (outChannels, inputDimension)} and,
 * unless disabled with {@link Builder#optBias(boolean)}, a {@code bias} parameter of shape
 * {@code (outChannels)}. When {@code flatten} is set, {@code inputDimension} is the product of the
 * input's non-batch axes; otherwise it is the size of the input's last axis.
 */
public class Linear
extends AbstractBlock {
    /** Encoding version handed to {@link AbstractBlock} and the newest one {@link #loadMetadata} accepts. */
    private static final byte VERSION = 3;
    private long outChannels;
    private long inputDimension;
    private boolean flatten;
    // Shape recorded by beforeInitialize (or restored by loadMetadata); reported by describeInput
    // and persisted by saveMetadata.
    private Shape inputShape;
    private Parameter weight;
    // Null when the block was built with optBias(false).
    private Parameter bias;

    Linear(Builder builder) {
        super(VERSION);
        this.outChannels = builder.outChannels;
        this.flatten = builder.flatten;
        // The weight shape is supplied as a callback so that it reads inputDimension at the time it
        // is invoked, i.e. after beforeInitialize has computed it.
        this.weight = this.addParameter(
                new Parameter("weight", this, ParameterType.WEIGHT),
                (Shape[] inputShapes) -> new Shape(this.outChannels, this.inputDimension));
        if (builder.bias) {
            this.bias = this.addParameter(
                    new Parameter("bias", this, ParameterType.BIAS), new Shape(this.outChannels));
        }
    }

    /**
     * Runs the fully-connected operation on the single input array.
     *
     * @param parameterStore store used to fetch {@code weight}/{@code bias} on the input's device
     * @param inputs exactly one input array
     * @param training unused by this block
     * @param params extra operator parameters forwarded to {@code fullyConnected}
     * @return the output produced by {@link NDArrayEx#fullyConnected}
     * @throws IllegalArgumentException if {@code inputs} does not hold exactly one array
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        inputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        // The last argument is "no bias": true when no bias parameter was registered.
        return ex.fullyConnected(inputs, this.outChannels, this.flatten, this.bias == null, params);
    }

    /**
     * Computes the output shape for the given input shape.
     *
     * <p>With {@code flatten} the result is {@code (inputs[0][0], outChannels)}. Without it, the last
     * input axis is replaced by {@code outChannels}.
     *
     * @param manager unused
     * @param inputs the input shapes; only {@code inputs[0]} is read
     * @return a single-element array holding the output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        Shape input = inputs[0];
        if (this.flatten) {
            return new Shape[]{new Shape(input.get(0), this.outChannels)};
        }
        // Derive from the actual input instead of the shape cached at initialization, so a different
        // leading size is reported correctly and no NPE occurs before inputShape has been set.
        return new Shape[]{input.slice(0, input.dimension() - 1).addAll(new Shape(this.outChannels))};
    }

    /**
     * Describes the single input of this block.
     *
     * @return one entry named {@code "linearInput"} holding the recorded input shape
     */
    @Override
    public PairList<String, Shape> describeInput() {
        return new PairList<String, Shape>(Collections.singletonList("linearInput"), Collections.singletonList(this.inputShape));
    }

    /**
     * Records the input shape and derives {@code inputDimension} from it.
     *
     * <p>In flatten mode, {@code inputDimension} is the element count of all non-batch axes. The
     * batch axis is identified by layout when the layout is known, and otherwise is assumed to be
     * axis 0 for inputs with more than one axis. That axis is stored as {@code -1} in
     * {@code inputShape}. In non-flatten mode, {@code inputDimension} is the size of the last axis
     * and {@code inputShape} holds the remaining leading axes.
     *
     * @param inputShapes the input shapes; only the first one is used
     */
    @Override
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Shape input = inputShapes[0];
        if (this.flatten) {
            Shape inChannels;
            if (input.isLayoutKnown()) {
                inChannels = input.filterByLayoutType(t -> t != LayoutType.BATCH);
                this.inputShape = input.map(pair -> new Pair<>(
                        pair.getValue() == LayoutType.BATCH ? Long.valueOf(-1L) : pair.getKey(),
                        pair.getValue()));
            } else if (input.dimension() > 1) {
                inChannels = input.slice(1);
                this.inputShape = new Shape(new long[]{-1L}, new LayoutType[]{LayoutType.BATCH}).addAll(input.slice(1));
            } else {
                inChannels = input;
                this.inputShape = input;
            }
            this.inputDimension = inChannels.size();
        } else {
            this.inputDimension = input.get(input.dimension() - 1);
            this.inputShape = input.slice(0, input.dimension() - 1);
        }
    }

    /**
     * Writes the block configuration in the current ({@link #VERSION}) layout: outChannels,
     * flatten, inputDimension, then the encoded input shape.
     */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeLong(this.outChannels);
        os.writeBoolean(this.flatten);
        os.writeLong(this.inputDimension);
        os.write(this.inputShape.getEncoded());
    }

    /**
     * Restores configuration written by any supported encoding version (1 through {@link #VERSION}).
     *
     * <p>Version 1 stores no flatten flag (it is treated as {@code false}) and encodes the input
     * dimension as a shape whose size is taken. Version 2 lacks {@code outChannels}, so the value
     * from the builder is kept.
     *
     * @throws MalformedModelException if {@code version} is outside the supported range
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedModelException {
        if (version < 1 || version > VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        if (version == VERSION) {
            this.outChannels = is.readLong();
            this.flatten = is.readBoolean();
            this.inputDimension = is.readLong();
        } else if (version == 2) {
            this.flatten = is.readBoolean();
            this.inputDimension = is.readLong();
        } else {
            this.flatten = false;
            this.inputDimension = Shape.decode(is).size();
        }
        this.inputShape = Shape.decode(is);
    }

    /**
     * Builds the operator input list: the single data array followed by weight and, if present, bias.
     * The parameters are fetched on the data array's device.
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        if (inputs.size() != 1) {
            throw new IllegalArgumentException("Linear requires exactly 1 NDArray");
        }
        Device device = inputs.head().getDevice();
        NDList result = new NDList(inputs);
        result.add(parameterStore.getValue(this.weight, device));
        if (this.bias != null) {
            result.add(parameterStore.getValue(this.bias, device));
        }
        return result;
    }

    /**
     * Creates a builder for {@link Linear}.
     *
     * @return a new builder with bias enabled and flatten disabled
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link Linear}. {@code outChannels} is mandatory and must be positive. */
    public static final class Builder {
        private long outChannels;
        private boolean bias = true;
        private boolean flatten;

        Builder() {
        }

        /**
         * Sets the number of output channels (the size of the output's last axis).
         *
         * @param outChannels a positive channel count
         * @return this builder
         */
        public Builder setOutChannels(long outChannels) {
            this.outChannels = outChannels;
            return this;
        }

        /**
         * Sets whether a bias parameter is created (default {@code true}).
         *
         * @param bias {@code false} to omit the bias
         * @return this builder
         */
        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        /**
         * Sets whether all non-batch input axes are flattened into the input dimension (default
         * {@code false}).
         *
         * @param flatten {@code true} to flatten
         * @return this builder
         */
        public Builder optFlatten(boolean flatten) {
            this.flatten = flatten;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link Linear}
         * @throws IllegalArgumentException if outChannels was not set to a positive value
         */
        public Linear build() {
            if (this.outChannels <= 0L) {
                throw new IllegalArgumentException("You must specify outChannels");
            }
            return new Linear(this);
        }
    }
}

