/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * A fully-connected (linear) block with a learned {@code weight} parameter and an optional learned
 * {@code bias} parameter.
 *
 * <p>The computation is delegated to the engine's {@code fullyConnected} operator. Instances are
 * created through {@link Builder}.
 */
public class Linear extends ParameterBlock {

    /** Encoding version written by {@link #saveParameters} and required by {@link #loadParameters}. */
    private static final byte VERSION = 1;

    private final long outChannels;
    private Shape inputShape;
    private Shape inChannels;
    private final Parameter weight;
    private final Parameter bias;

    /**
     * Creates the block from a validated builder.
     *
     * @param builder the builder holding {@code outChannels} and whether a bias is used
     */
    Linear(Builder builder) {
        outChannels = builder.outChannels;
        weight = new Parameter("weight", this, ParameterType.WEIGHT);
        // A null bias marks "no bias" for the rest of the block.
        bias = builder.bias ? new Parameter("bias", this, ParameterType.BIAS) : null;
    }

    /**
     * Runs the engine's fully-connected operator on the single input plus this block's weight
     * (and bias, when enabled).
     *
     * @param parameterStore store used to fetch parameter values on the input's device
     * @param inputs exactly one input array
     * @param params extra operator parameters forwarded unchanged
     * @return the operator's output
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one array
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDList opInputs = opInputs(parameterStore, inputs);
        NDArrayEx ex = opInputs.head().getNDArrayInternal();
        // The third argument (presumably "flatten") is fixed to false; the fourth says "no bias".
        return ex.fullyConnected(opInputs, outChannels, false, bias == null, params);
    }

    /**
     * Returns {@code (inputs[0].get(0), outChannels)}. The first dimension of the first input is
     * treated as the batch size.
     *
     * <p>NOTE(review): for inputs with more than two dimensions this assumes all non-batch axes
     * are collapsed, which matches the weight shape {@code (outChannels, inChannels...)}. Confirm
     * that this agrees with {@code fullyConnected} when called with {@code false} as above.
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        return new Shape[] {new Shape(inputs[0].get(0), outChannels)};
    }

    /**
     * Returns the parameters owned directly by this block.
     *
     * @return {@code [weight, bias]} when a bias is used, otherwise {@code [weight]}
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return bias == null
                ? Collections.singletonList(weight)
                : Arrays.asList(weight, bias);
    }

    /**
     * Describes the single input as {@code "linearInput"} with the shape recorded by {@link
     * #beforeInitialize} (or {@link #loadParameters}).
     */
    @Override
    public PairList<String, Shape> describeInput() {
        return new PairList<>(
                Collections.singletonList("linearInput"), Collections.singletonList(inputShape));
    }

    /**
     * Records the input shapes and derives {@code inChannels} and {@code inputShape} from the
     * first one.
     *
     * <ul>
     *   <li>Known layout: every non-{@code BATCH} axis is an input channel, and the
     *       {@code BATCH} axis size is replaced by {@code -1}.
     *   <li>Unknown layout with more than one dimension: axis 0 is treated as the batch and
     *       replaced by a {@code -1} {@code BATCH} axis.
     *   <li>Otherwise: the whole shape is the input channels and is kept as-is.
     * </ul>
     */
    @Override
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Shape input = inputShapes[0];
        if (input.isLayoutKnown()) {
            inChannels = input.filterByLayoutType(t -> t != LayoutType.BATCH);
            inputShape =
                    input.map(
                            pair ->
                                    new Pair<>(
                                            pair.getValue() == LayoutType.BATCH
                                                    ? Long.valueOf(-1L)
                                                    : pair.getKey(),
                                            pair.getValue()));
        } else if (input.dimension() > 1) {
            inChannels = input.slice(1);
            inputShape =
                    new Shape(new long[] {-1L}, new LayoutType[] {LayoutType.BATCH})
                            .addAll(input.slice(1));
        } else {
            inChannels = input.slice(0);
            inputShape = input;
        }
    }

    /**
     * Returns the shape of a named parameter.
     *
     * @param name {@code "weight"} or {@code "bias"}
     * @param inputShapes unused here; shapes come from {@link #beforeInitialize}
     * @return {@code (outChannels, inChannels...)} for weight, {@code (outChannels)} for bias
     * @throws IllegalArgumentException for any other name
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        switch (name) {
            case "weight":
                return new Shape(outChannels).addAll(inChannels);
            case "bias":
                return new Shape(outChannels);
            default:
                throw new IllegalArgumentException("Invalid parameter name: " + name);
        }
    }

    /**
     * Writes the encoding version, {@code inChannels}, {@code inputShape}, the weight and (when
     * present) the bias, in that order.
     *
     * @param os destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.write(inChannels.getEncoded());
        os.write(inputShape.getEncoded());
        weight.save(os);
        if (bias != null) {
            bias.save(os);
        }
    }

    /**
     * Reads data in the order written by {@link #saveParameters}.
     *
     * @param manager manager used to allocate the loaded parameter arrays
     * @param is source stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version is not {@link #VERSION}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        inChannels = Shape.decode(is);
        inputShape = Shape.decode(is);
        weight.load(manager, is);
        if (bias != null) {
            bias.load(manager, is);
        }
    }

    /**
     * Builds the operator input list: the single data input, then the weight, then the bias (if
     * any). All are fetched on the data input's device.
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        if (inputs.size() != 1) {
            throw new IllegalArgumentException("Linear requires exactly 1 NDArray");
        }
        Device device = inputs.head().getDevice();
        NDList result = new NDList(inputs);
        result.add(parameterStore.getValue(weight, device));
        if (bias != null) {
            result.add(parameterStore.getValue(bias, device));
        }
        return result;
    }

    /** Builder for {@link Linear}. A bias is enabled by default. */
    public static final class Builder {
        private long outChannels;
        private boolean bias = true;

        /**
         * Sets the number of output channels (required, must be positive).
         *
         * @param outChannels the number of output channels
         * @return this builder
         */
        public Builder setOutChannels(long outChannels) {
            this.outChannels = outChannels;
            return this;
        }

        /**
         * Sets whether the block has a bias parameter (default {@code true}).
         *
         * @param bias whether to use a bias
         * @return this builder
         */
        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link Linear}
         * @throws IllegalArgumentException if {@code outChannels} was not set or is negative
         */
        public Linear build() {
            if (outChannels == 0L) {
                throw new IllegalArgumentException("You must specify outChannels");
            }
            if (outChannels < 0L) {
                // A negative size would produce invalid weight/bias shapes later on.
                throw new IllegalArgumentException(
                        "outChannels must be positive, but was " + outChannels);
            }
            return new Linear(this);
        }
    }
}

