/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;

/**
 * Base implementation of {@link Block} that tracks a serialization version, the block's input
 * shapes and its input names.
 *
 * <p>It provides lazy initialization inside {@code forward}, recursive parameter traversal, and
 * parameter (de)serialization on top of {@link #getDirectParameters()} and {@link #getChildren()},
 * which this class does not implement itself.
 */
public abstract class AbstractBaseBlock implements Block {

    /** Format version written first by {@link #saveParameters} and checked by {@link #loadMetadata}. */
    protected byte version;

    /** Input shapes recorded by {@link #beforeInitialize} or read back by {@link #readInputShapes}; {@code null} until then. */
    protected Shape[] inputShapes;

    /** Names of the inputs; filled with {@code data0 .. dataN-1} on initialization if still empty. */
    protected List<String> inputNames = Collections.emptyList();

    /** Creates a block with serialization version {@code 1}. */
    public AbstractBaseBlock() {
        // Constructor invocation does not implicitly narrow int constants, so the cast is required.
        this((byte) 1);
    }

    /**
     * Creates a block with the given serialization version.
     *
     * @param version the version byte written by {@link #saveParameters} and expected by
     *     {@link #loadMetadata}
     */
    public AbstractBaseBlock(byte version) {
        this.version = version;
    }

    /**
     * Runs the forward pass, lazily initializing the block first when training.
     *
     * <p>If {@code training} is true and the block is not yet initialized, it is initialized with
     * {@link DataType#FLOAT32} using the shapes of {@code inputs} and the parameter store's
     * manager. The actual computation is delegated to {@link #forwardInternal(ParameterStore,
     * NDList, boolean, PairList)}.
     *
     * @param parameterStore the parameter store supplying the manager used for initialization
     * @param inputs the input arrays
     * @param training whether this is a training pass
     * @param params extra parameters forwarded to {@code forwardInternal}
     * @return the output of {@code forwardInternal}
     */
    @Override
    public final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDManager paramsManager = parameterStore.getManager();
        if (training && !this.isInitialized()) {
            this.initialize(paramsManager, DataType.FLOAT32, inputs.getShapes());
        }
        return this.forwardInternal(parameterStore, inputs, training, params);
    }

    /**
     * Runs a forward pass that also receives labels, initializing the block first if needed.
     *
     * <p>Unlike the training-flag overload, initialization happens whenever the block is not
     * initialized. The shapes come from {@code data} and the data type is {@link DataType#FLOAT32}.
     *
     * @param parameterStore the parameter store supplying the manager used for initialization
     * @param data the input arrays
     * @param labels the label arrays
     * @param params extra parameters forwarded to {@code forwardInternal}
     * @return the output of {@link #forwardInternal(ParameterStore, NDList, NDList, PairList)}
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, PairList<String, Object> params) {
        NDManager paramsManager = parameterStore.getManager();
        if (!this.isInitialized()) {
            this.initialize(paramsManager, DataType.FLOAT32, data.getShapes());
        }
        return this.forwardInternal(parameterStore, data, labels, params);
    }

    /**
     * Performs the block-specific forward computation.
     *
     * @param var1 the parameter store
     * @param var2 the input arrays
     * @param var3 whether this is a training pass
     * @param var4 extra parameters
     * @return the block output
     */
    protected abstract NDList forwardInternal(ParameterStore var1, NDList var2, boolean var3, PairList<String, Object> var4);

    /**
     * Forward computation with labels. The default ignores {@code labels} and delegates to the
     * training-flag variant with {@code training = true}.
     *
     * @param parameterStore the parameter store
     * @param data the input arrays
     * @param labels the label arrays (unused by this default)
     * @param params extra parameters
     * @return the block output
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, PairList<String, Object> params) {
        return this.forwardInternal(parameterStore, data, true, params);
    }

    /**
     * Describes the block inputs as (name, shape) pairs.
     *
     * @return pairs built from {@link #inputNames} and {@link #inputShapes}
     * @throws IllegalStateException if the block is not initialized
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (!this.isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised,please call model.newTrainer and trainer.initialize");
        }
        return new PairList<String, Shape>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Sets {@code initializer} on every parameter whose type equals {@code params}. This includes
     * the parameters of child blocks, because it delegates to the predicate overload.
     *
     * @param initializer the initializer to assign
     * @param params the parameter type to match
     */
    @Override
    public void setInitializer(Initializer initializer, Parameter.Type params) {
        Predicate<Parameter> predicate = parameter -> parameter.getType().equals(params);
        this.setInitializer(initializer, predicate);
    }

    /**
     * Sets {@code initializer} on the first direct parameter (child blocks are not searched) whose
     * name equals {@code paramName}.
     *
     * @param initializer the initializer to assign
     * @param paramName the parameter name to look for
     * @throws IllegalArgumentException if no direct parameter has that name
     */
    @Override
    public void setInitializer(Initializer initializer, String paramName) {
        Parameter parameter = this.getDirectParameters().values().stream().filter(p -> p.getName().equals(paramName)).findFirst().orElseThrow(() -> new IllegalArgumentException("Could not find parameter " + paramName));
        parameter.setInitializer(initializer);
    }

    /**
     * Sets {@code initializer} on every parameter returned by {@link #getParameters()} (including
     * child-block parameters) that satisfies {@code predicate}.
     *
     * @param initializer the initializer to assign
     * @param predicate selects which parameters receive the initializer
     */
    @Override
    public void setInitializer(Initializer initializer, Predicate<Parameter> predicate) {
        for (Parameter param : this.getParameters().values()) {
            if (predicate.test(param)) {
                param.setInitializer(initializer);
            }
        }
    }

    /**
     * Initializes the block for the given input shapes.
     *
     * <p>The method runs these steps in order:
     * <ol>
     *   <li>records shapes and default names via {@link #beforeInitialize};</li>
     *   <li>calls {@link #prepare} if the block is not already initialized;</li>
     *   <li>initializes every direct parameter;</li>
     *   <li>finally calls {@link #initializeChildBlocks}.</li>
     * </ol>
     *
     * @param manager the manager passed to each parameter's {@code initialize}
     * @param dataType the data type passed to each parameter's {@code initialize}
     * @param inputShapes the input shapes of this block
     */
    @Override
    public void initialize(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.beforeInitialize(inputShapes);
        if (!this.isInitialized()) {
            this.prepare(inputShapes);
        }
        for (Parameter parameter : this.getDirectParameters().values()) {
            parameter.initialize(manager, dataType);
        }
        this.initializeChildBlocks(manager, dataType, inputShapes);
    }

    /**
     * Hook run at the start of {@link #initialize}. It assigns default input names
     * {@code data0 .. dataN-1} when none are set, then stores {@code inputShapes} (by reference).
     *
     * @param inputShapes the input shapes of this block
     */
    protected void beforeInitialize(Shape ... inputShapes) {
        if (this.inputNames.isEmpty()) {
            this.inputNames = new ArrayList<String>();
            for (int i = 0; i < inputShapes.length; ++i) {
                this.inputNames.add("data" + i);
            }
        }
        this.inputShapes = inputShapes;
    }

    /**
     * Initializes child blocks. The default supports only leaf blocks: it throws if the block has
     * any children, so container subclasses must override it.
     *
     * @param manager the manager for child initialization
     * @param dataType the data type for child initialization
     * @param inputShapes the input shapes of this block
     * @throws IllegalStateException if {@link #getChildren()} is not empty
     */
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        if (!this.getChildren().isEmpty()) {
            throw new IllegalStateException(this.getClass().getSimpleName() + " has child blocks but initializeChildBlocks is not overwritten.");
        }
    }

    /**
     * Hook called from {@link #initialize} before parameters are initialized, but only if the
     * block is not yet initialized. The default does nothing.
     *
     * @param inputShapes the input shapes of this block
     */
    protected void prepare(Shape[] inputShapes) {
    }

    /**
     * Returns the direct parameters followed by the parameters of each child block. Each child
     * parameter's key is prefixed with {@code "<childName>_"}.
     *
     * <p>NOTE(review): child entries are appended to the list returned by
     * {@link #getDirectParameters()}; this assumes that method returns a fresh list on each call.
     *
     * @return all parameters of this block and its descendants
     */
    @Override
    public ParameterList getParameters() {
        ParameterList allParams = this.getDirectParameters();
        for (Pair<String, Block> childPair : this.getChildren()) {
            String childName = childPair.getKey();
            for (Pair<String, Parameter> paramPair : childPair.getValue().getParameters()) {
                allParams.add(childName + "_" + paramPair.getKey(), paramPair.getValue());
            }
        }
        return allParams;
    }

    /**
     * Returns whether the block is initialized.
     *
     * @return {@code false} if no input shapes are recorded or any parameter (including
     *     child-block parameters) is not initialized; {@code true} otherwise
     */
    @Override
    public boolean isInitialized() {
        if (this.inputShapes == null) {
            return false;
        }
        for (Parameter param : this.getParameters().values()) {
            if (!param.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Closes every parameter returned by {@link #getParameters()}, including child-block ones. */
    @Override
    public void clear() {
        for (Parameter param : this.getParameters().values()) {
            param.close();
        }
    }

    /**
     * Not supported by this base class.
     *
     * @param dataType ignored
     * @throws UnsupportedOperationException always
     */
    @Override
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Writes the parameters to {@code os}, in this order:
     * <ol>
     *   <li>the version byte;</li>
     *   <li>the metadata (see {@link #saveMetadata});</li>
     *   <li>each direct parameter;</li>
     *   <li>each child block, recursively.</li>
     * </ol>
     *
     * <p>{@link #loadParameters} reads in the same order.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        // Single byte, mirrored by readByte() in loadParameters.
        os.writeByte(this.version);
        this.saveMetadata(os);
        for (Parameter parameter : this.getDirectParameters().values()) {
            parameter.save(os);
        }
        for (Block child : this.getChildren().values()) {
            child.saveParameters(os);
        }
    }

    /**
     * Reads parameters written by {@link #saveParameters}, in this order:
     * <ol>
     *   <li>the version byte;</li>
     *   <li>the metadata;</li>
     *   <li>each direct parameter;</li>
     *   <li>each child block, recursively.</li>
     * </ol>
     *
     * @param manager the manager passed to each parameter's {@code load}
     * @param is the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored version does not match (see {@link #loadMetadata})
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte loadVersion = is.readByte();
        this.loadMetadata(loadVersion, is);
        for (Parameter parameter : this.getDirectParameters().values()) {
            parameter.load(manager, is);
        }
        for (Block child : this.getChildren().values()) {
            child.loadParameters(manager, is);
        }
    }

    /**
     * Writes block metadata. The default writes only the input shapes.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    protected void saveMetadata(DataOutputStream os) throws IOException {
        this.saveInputShapes(os);
    }

    /**
     * Reads block metadata after validating the version. The default reads only the input shapes.
     *
     * @param loadVersion the version byte read from the stream
     * @param is the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if {@code loadVersion} differs from {@link #version}
     */
    protected void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException {
        if (loadVersion != this.version) {
            throw new MalformedModelException("Cannot load parameters for " + this.getClass().getCanonicalName() + ", expected version " + this.version + ", got " + loadVersion + ".");
        }
        this.readInputShapes(is);
    }

    /**
     * Writes the number of input shapes as an int, followed by each shape's encoded bytes.
     *
     * <p>NOTE(review): dereferences {@link #inputShapes} without a null check, so it presumably
     * must only be called after initialization.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    protected void saveInputShapes(DataOutputStream os) throws IOException {
        os.writeInt(this.inputShapes.length);
        for (Shape shape : this.inputShapes) {
            os.write(shape.getEncoded());
        }
    }

    /**
     * Reads shapes written by {@link #saveInputShapes}. The shapes are always consumed from the
     * stream, but they are adopted only if {@link #inputShapes} is still {@code null}.
     *
     * @param is the input stream
     * @throws IOException if reading fails
     */
    protected void readInputShapes(DataInputStream is) throws IOException {
        int len = is.readInt();
        Shape[] shapes = new Shape[len];
        for (int i = 0; i < len; ++i) {
            shapes[i] = Shape.decode(is);
        }
        if (this.inputShapes == null) {
            this.inputShapes = shapes;
        }
    }

    /** Returns the description produced by {@code Blocks.describe(this, null, 0)}. */
    @Override
    public String toString() {
        return Blocks.describe(this, null, 0);
    }

    /**
     * Returns the recorded input shapes. The internal array is returned, not a copy.
     *
     * @return the input shapes
     * @throws IllegalStateException if the block is not initialized
     */
    @Override
    public Shape[] getInputShapes() {
        if (!this.isInitialized()) {
            throw new IllegalStateException("getInputShapes() can only be called after the initialization process");
        }
        return this.inputShapes;
    }
}

