/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.NDSerializer;
import ai.djl.ndarray.types.Shape;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PushbackInputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

public class NDList
extends ArrayList<NDArray>
implements NDResource,
BytesSupplier {
    private static final long serialVersionUID = 1L;

    public NDList() {
    }

    public NDList(int initialCapacity) {
        super(initialCapacity);
    }

    public NDList(NDArray ... arrays) {
        super(Arrays.asList(arrays));
    }

    public NDList(Collection<NDArray> other) {
        super(other);
    }

    public static NDList decode(NDManager manager, byte[] byteArray) {
        return NDList.decode(manager, new ByteArrayInputStream(byteArray));
    }

    public static NDList decode(NDManager manager, InputStream is) {
        try {
            DataInputStream dis = new DataInputStream(is);
            byte[] magic = new byte[4];
            dis.readFully(magic);
            PushbackInputStream pis = new PushbackInputStream(is, 4);
            pis.unread(magic);
            if (magic[0] == 80 && magic[1] == 75) {
                return NDList.decodeNumpy(manager, pis);
            }
            if (magic[0] == 57 && magic[1] == 78 && magic[2] == 85 && magic[3] == 77) {
                return new NDList(NDSerializer.decode(manager, pis));
            }
            dis = new DataInputStream(pis);
            int size = dis.readInt();
            if (size < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + size);
            }
            NDList list = new NDList();
            for (int i = 0; i < size; ++i) {
                list.add(i, manager.decode(dis));
            }
            return list;
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    private static NDList decodeNumpy(NDManager manager, InputStream is) throws IOException {
        ZipEntry entry;
        NDList list = new NDList();
        ZipInputStream zis = new ZipInputStream(is);
        while ((entry = zis.getNextEntry()) != null) {
            String name = entry.getName();
            NDArray array = NDSerializer.decodeNumpy(manager, zis);
            if (!name.startsWith("arr_") && name.endsWith(".npy")) {
                array.setName(name.substring(0, name.length() - 4));
            }
            list.add(array);
        }
        return list;
    }

    public NDArray get(String name) {
        for (NDArray array : this) {
            if (!name.equals(array.getName())) continue;
            return array;
        }
        return null;
    }

    public NDArray remove(String name) {
        int index = 0;
        for (NDArray array : this) {
            if (name.equals(array.getName())) {
                this.remove(index);
                return array;
            }
            ++index;
        }
        return null;
    }

    public boolean contains(String name) {
        for (NDArray array : this) {
            if (!name.equals(array.getName())) continue;
            return true;
        }
        return false;
    }

    public NDArray head() {
        return (NDArray)this.get(0);
    }

    public NDArray singletonOrThrow() {
        if (this.size() != 1) {
            throw new IndexOutOfBoundsException("Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + this.size());
        }
        return (NDArray)this.get(0);
    }

    public NDList addAll(NDList other) {
        for (NDArray array : other) {
            this.add(array);
        }
        return this;
    }

    public NDList subNDList(int fromIndex) {
        return new NDList((Collection<NDArray>)this.subList(fromIndex, this.size()));
    }

    public NDList toDevice(Device device, boolean copy) {
        if (!copy && this.stream().allMatch(array -> array.getDevice() == device)) {
            return this;
        }
        NDList newNDList = new NDList(this.size());
        this.forEach((? super E a) -> newNDList.add(a.toDevice(device, copy)));
        return newNDList;
    }

    @Override
    public NDManager getManager() {
        return this.head().getManager();
    }

    @Override
    public void attach(NDManager manager) {
        this.forEach((? super E array) -> array.attach(manager));
    }

    @Override
    public void tempAttach(NDManager manager) {
        this.forEach((? super E array) -> array.tempAttach(manager));
    }

    @Override
    public void detach() {
        this.forEach(NDResource::detach);
    }

    public byte[] encode() {
        byte[] byArray;
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try {
            this.encode(baos);
            byArray = baos.toByteArray();
        }
        catch (Throwable throwable) {
            try {
                try {
                    baos.close();
                }
                catch (Throwable throwable2) {
                    throwable.addSuppressed(throwable2);
                }
                throw throwable;
            }
            catch (IOException e) {
                throw new AssertionError("NDList is not writable", e);
            }
        }
        baos.close();
        return byArray;
    }

    public void encode(OutputStream os) throws IOException {
        this.encode(os, false);
    }

    public void encode(OutputStream os, boolean numpy) throws IOException {
        if (numpy) {
            ZipOutputStream zos = new ZipOutputStream(os);
            int i = 0;
            for (NDArray nd : this) {
                String name = nd.getName();
                if (name == null) {
                    zos.putNextEntry(new ZipEntry("arr_" + i + ".npy"));
                    ++i;
                } else {
                    zos.putNextEntry(new ZipEntry(name + ".npy"));
                }
                NDSerializer.encodeAsNumpy(nd, zos);
            }
            zos.finish();
            zos.flush();
            return;
        }
        DataOutputStream dos = new DataOutputStream(os);
        dos.writeInt(this.size());
        for (NDArray nd : this) {
            dos.write(nd.encode());
        }
        dos.flush();
    }

    @Override
    public byte[] getAsBytes() {
        return this.encode();
    }

    @Override
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(this.encode());
    }

    public Shape[] getShapes() {
        return (Shape[])this.stream().map(NDArray::getShape).toArray(Shape[]::new);
    }

    @Override
    public void close() {
        this.forEach(NDArray::close);
        this.clear();
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder(200);
        builder.append("NDList size: ").append(this.size()).append('\n');
        int index = 0;
        for (NDArray array : this) {
            String name = array.getName();
            builder.append(index++).append(' ');
            if (name != null) {
                builder.append(name);
            }
            builder.append(": ").append(array.getShape()).append(' ').append((Object)array.getDataType()).append('\n');
        }
        return builder.toString();
    }
}

