/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.listeners.checkpoint;

import com.google.common.io.Files;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.optimize.listeners.checkpoint.Checkpoint;
import org.deeplearning4j.util.ModelSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link BaseTrainingListener} that periodically saves the model being trained into a directory,
 * appending one line per saved checkpoint to the record file {@code checkpointInfo.txt} in that
 * directory.
 * <p>
 * Saving can be triggered every N epochs, every N iterations and/or every fixed time period (see
 * {@link Builder}). Optionally only a subset of the checkpoint model files is retained
 * ({@link Builder#keepLast(int)}, {@link Builder#keepLastAndEvery(int, int)}); the record file
 * itself is only ever appended to.
 * <p>
 * If the directory already contains a record file (e.g. from an earlier run), checkpoint numbering
 * continues after the highest checkpoint number recorded there, so earlier model files are not
 * overwritten.
 * <p>
 * NOTE(review): instances hold mutable, unsynchronized state; presumably intended to be driven
 * from a single training thread.
 */
public class CheckpointListener
extends BaseTrainingListener {
    private static final Logger log = LoggerFactory.getLogger(CheckpointListener.class);
    /** Every model-type name that {@link #getModelType(Model)} can put into a checkpoint file name. */
    private static final String[] MODEL_TYPES = new String[]{"MultiLayerNetwork", "ComputationGraph", "Model"};
    /** Name of the record file (header line followed by one line per checkpoint) inside {@code rootDir}. */
    private static final String CHECKPOINT_RECORD_FILE_NAME = "checkpointInfo.txt";

    private File rootDir;
    private KeepMode keepMode;
    private int keepLast;
    private int keepEvery;
    private boolean logSaving;
    private Integer saveEveryNEpochs;
    private Integer saveEveryNIterations;
    private boolean saveEveryNIterSinceLast;
    private Long saveEveryAmount;
    private TimeUnit saveEveryUnit;
    /** {@code saveEveryAmount} converted to milliseconds; null when no time-based schedule is set. */
    private Long saveEveryMs;
    private boolean saveEverySinceLast;
    /** Number assigned to the most recently recorded checkpoint; -1 when none has been recorded yet. */
    private int lastCheckpointNum = -1;
    private File checkpointRecordFile;
    /** Last checkpoint saved by this instance (not loaded from disk); null until the first save. */
    private Checkpoint lastCheckpoint;
    /** Wall-clock time (ms) of the first {@link #iterationDone} call; -1 until then. */
    private long startTime = -1L;
    /** Iteration number passed to the first {@link #iterationDone} call; -1 until then. */
    private int startIter = -1;
    /** Time (ms) of the last save triggered by the fixed-period (not "since last") time schedule. */
    private Long lastSaveEveryMsNoSinceLast;

    private CheckpointListener(Builder builder) {
        this.rootDir = builder.rootDir;
        this.keepMode = builder.keepMode;
        this.keepLast = builder.keepLast;
        this.keepEvery = builder.keepEvery;
        this.logSaving = builder.logSaving;
        this.saveEveryNEpochs = builder.saveEveryNEpochs;
        this.saveEveryNIterations = builder.saveEveryNIterations;
        this.saveEveryNIterSinceLast = builder.saveEveryNIterSinceLast;
        this.saveEveryAmount = builder.saveEveryAmount;
        this.saveEveryUnit = builder.saveEveryUnit;
        this.saveEverySinceLast = builder.saveEverySinceLast;
        if (this.saveEveryAmount != null) {
            this.saveEveryMs = TimeUnit.MILLISECONDS.convert(this.saveEveryAmount, this.saveEveryUnit);
        }
        this.checkpointRecordFile = new File(this.rootDir, CHECKPOINT_RECORD_FILE_NAME);
        // Resume numbering after any checkpoints already recorded in this directory. Restarting at 0
        // would overwrite those model files and duplicate checkpoint numbers in the record file, which
        // in the LAST keep mode could even delete the checkpoint that was just written.
        this.lastCheckpointNum = this.findLastRecordedCheckpointNum();
    }

    /**
     * Saves a checkpoint when the number of completed epochs (the model's epoch count + 1) is a
     * multiple of the configured "save every N epochs" value.
     */
    @Override
    public void onEpochEnd(Model model) {
        int epochsDone = CheckpointListener.getEpoch(model) + 1;
        if (this.saveEveryNEpochs != null && epochsDone > 0 && epochsDone % this.saveEveryNEpochs == 0) {
            this.saveCheckpoint(model);
        }
    }

    /**
     * Evaluates the iteration-based and then the time-based schedules, saving at most one checkpoint
     * per call. The first call only records the reference start time and iteration and never saves.
     */
    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        if (this.startTime < 0L) {
            this.startTime = System.currentTimeMillis();
            this.startIter = iteration;
            return;
        }
        if (this.saveEveryNIterations != null) {
            if (this.saveEveryNIterSinceLast) {
                // Count iterations since the last checkpoint saved by this listener (any trigger)
                long lastSaveIter = this.lastCheckpoint != null ? this.lastCheckpoint.getIteration() : this.startIter;
                if ((long) iteration - lastSaveIter >= this.saveEveryNIterations) {
                    this.saveCheckpoint(model);
                    return;
                }
            } else if (iteration > 0 && iteration % this.saveEveryNIterations == 0) {
                this.saveCheckpoint(model);
                return;
            }
        }
        if (this.saveEveryUnit != null) {
            long time = System.currentTimeMillis();
            if (this.saveEverySinceLast) {
                // Measure from the last checkpoint saved by this listener (any trigger)
                long lastSaveTime = this.lastCheckpoint != null ? this.lastCheckpoint.getTimestamp() : this.startTime;
                if (time - lastSaveTime >= this.saveEveryMs) {
                    this.saveCheckpoint(model);
                }
            } else {
                // Measure only from the last save made by this fixed-period schedule
                long lastSave = this.lastSaveEveryMsNoSinceLast != null ? this.lastSaveEveryMsNoSinceLast : this.startTime;
                if (time - lastSave > this.saveEveryMs) {
                    this.saveCheckpoint(model);
                    this.lastSaveEveryMsNoSinceLast = time;
                }
            }
        }
    }

    /** Saves a checkpoint, wrapping any failure in a RuntimeException. */
    private void saveCheckpoint(Model model) {
        try {
            this.saveCheckpointHelper(model);
        }
        catch (Exception e) {
            throw new RuntimeException("Error saving checkpoint", e);
        }
    }

    /**
     * Writes the model file for the next checkpoint number, appends its record line, optionally
     * logs it, and then prunes old checkpoint files according to the keep mode.
     */
    private void saveCheckpointHelper(Model model) throws Exception {
        if (!this.rootDir.exists() && !this.rootDir.mkdirs() && !this.rootDir.isDirectory()) {
            throw new IOException("Could not create checkpoint directory: " + this.rootDir.getAbsolutePath());
        }
        // Also (re)write the header into an existing but empty record file; otherwise the first
        // record would later be skipped as if it were the header line.
        if (!this.checkpointRecordFile.exists() || this.checkpointRecordFile.length() == 0L) {
            CheckpointListener.write(Checkpoint.getFileHeader() + "\n", this.checkpointRecordFile);
        }
        Checkpoint c = new Checkpoint(++this.lastCheckpointNum, System.currentTimeMillis(), CheckpointListener.getIter(model), CheckpointListener.getEpoch(model), CheckpointListener.getModelType(model), null);
        CheckpointListener.setFileName(c);
        File modelFile = new File(this.rootDir, c.getFilename());
        ModelSerializer.writeModel(model, modelFile, true);
        CheckpointListener.write(c.toFileString() + "\n", this.checkpointRecordFile);
        if (this.logSaving) {
            log.info("Model checkpoint saved: epoch {}, iteration {}, path: {}", new Object[]{c.getEpoch(), c.getIteration(), modelFile.getPath()});
        }
        this.lastCheckpoint = c;
        this.removeOldCheckpoints();
    }

    /**
     * Deletes model files that the keep mode no longer retains. ALL (or no keep mode) deletes
     * nothing. LAST keeps the {@code keepLast} most recent available files. LAST_AND_EVERY keeps
     * checkpoint numbers {@code num} with {@code (num + 1) % keepEvery == 0}, plus the
     * {@code keepLast} highest numbers up to the latest one.
     */
    private void removeOldCheckpoints() {
        if (this.keepMode == null || this.keepMode == KeepMode.ALL) {
            return;
        }
        List<Checkpoint> checkpoints = this.availableCheckpoints();
        if (this.keepMode == KeepMode.LAST) {
            // Records are appended in save order, so the oldest checkpoints come first
            int numToDelete = checkpoints.size() - this.keepLast;
            for (int i = 0; i < numToDelete; ++i) {
                this.deleteCheckpointFile(checkpoints.get(i));
            }
        } else {
            for (Checkpoint cp : checkpoints) {
                int num = cp.getCheckpointNum();
                boolean keepAsEveryN = (num + 1) % this.keepEvery == 0;
                // Example: latest is 5, keepLast = 2 -> keep checkpoints 4 and 5
                boolean keepAsRecent = num > this.lastCheckpointNum - this.keepLast;
                if (!keepAsEveryN && !keepAsRecent) {
                    this.deleteCheckpointFile(cp);
                }
            }
        }
    }

    /** Deletes the model file of the given checkpoint if it is still present, logging failures. */
    private void deleteCheckpointFile(Checkpoint cp) {
        File f = new File(this.rootDir, cp.getFilename());
        if (f.exists() && !f.delete()) {
            log.warn("Failed to delete old checkpoint file: {}", f.getAbsolutePath());
        }
    }

    private static void setFileName(Checkpoint c) {
        c.setFilename(CheckpointListener.getFileName(c.getCheckpointNum(), c.getModelType()));
    }

    /** Builds the model file name {@code checkpoint_<num>_<modelType>.zip}. */
    private static String getFileName(int checkpointNum, String modelType) {
        return "checkpoint_" + checkpointNum + "_" + modelType + ".zip";
    }

    /** Appends {@code str} to {@code f} (creating it if needed) using the platform default charset. */
    private static void write(String str, File f) {
        try {
            if (!f.exists()) {
                f.createNewFile();
            }
            Files.append((CharSequence) str, f, Charset.defaultCharset());
        }
        catch (IOException e) {
            throw new RuntimeException("Error writing to file: " + f.getAbsolutePath(), e);
        }
    }

    /** Returns the iteration count stored in the model's configuration. */
    protected static int getIter(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getIterationCount();
        }
        return model.conf().getIterationCount();
    }

    /** Returns the epoch count stored in the model's configuration. */
    protected static int getEpoch(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getEpochCount();
        }
        return model.conf().getEpochCount();
    }

    /**
     * Returns the type name used in checkpoint file names. Only an exact MultiLayerNetwork or
     * ComputationGraph class maps to its own name; every other class (including subclasses) maps
     * to "Model".
     */
    protected static String getModelType(Model model) {
        if (model.getClass() == MultiLayerNetwork.class) {
            return "MultiLayerNetwork";
        }
        if (model.getClass() == ComputationGraph.class) {
            return "ComputationGraph";
        }
        return "Model";
    }

    /**
     * Reads every checkpoint record in the record file, in file order, regardless of whether the
     * corresponding model file still exists. The first line (header) and blank lines are skipped.
     *
     * @return the parsed records; empty if the record file does not exist
     * @throws RuntimeException if the record file cannot be read
     */
    private List<Checkpoint> readCheckpointRecords() {
        if (!this.checkpointRecordFile.exists()) {
            return Collections.emptyList();
        }
        List<String> lines;
        try (InputStream is = new BufferedInputStream(new FileInputStream(this.checkpointRecordFile))) {
            // Same (default) charset that write(...) uses
            lines = IOUtils.readLines(is, Charset.defaultCharset());
        }
        catch (IOException e) {
            throw new RuntimeException("Error loading checkpoint data from file: " + this.checkpointRecordFile.getAbsolutePath(), e);
        }
        List<Checkpoint> records = new ArrayList<Checkpoint>(Math.max(0, lines.size() - 1));
        for (int i = 1; i < lines.size(); ++i) {
            String line = lines.get(i);
            if (line.trim().isEmpty()) {
                continue;
            }
            records.add(Checkpoint.fromFileString(line));
        }
        return records;
    }

    /** Highest checkpoint number present in the record file, or -1 if there is none. */
    private int findLastRecordedCheckpointNum() {
        int last = -1;
        for (Checkpoint c : this.readCheckpointRecords()) {
            last = Math.max(last, c.getCheckpointNum());
        }
        return last;
    }

    /**
     * Lists recorded checkpoints whose model file still exists in the root directory, in the order
     * they were recorded (oldest first).
     *
     * @return the available checkpoints; empty if nothing has been recorded
     */
    public List<Checkpoint> availableCheckpoints() {
        List<Checkpoint> records = this.readCheckpointRecords();
        List<Checkpoint> out = new ArrayList<Checkpoint>(records.size());
        for (Checkpoint c : records) {
            if (new File(this.rootDir, c.getFilename()).exists()) {
                out.add(c);
            }
        }
        return out;
    }

    /** Returns the most recently recorded checkpoint whose model file still exists, or null if none. */
    public Checkpoint lastCheckpoint() {
        List<Checkpoint> all = this.availableCheckpoints();
        if (all.isEmpty()) {
            return null;
        }
        return all.get(all.size() - 1);
    }

    /** Returns the existing model file for the given checkpoint; see {@link #getFileForCheckpoint(int)}. */
    public File getFileForCheckpoint(Checkpoint checkpoint) {
        return this.getFileForCheckpoint(checkpoint.getCheckpointNum());
    }

    /**
     * Returns the existing model file for the given checkpoint number, trying each known model type.
     *
     * @throws IllegalArgumentException if {@code checkpointNum} is negative
     * @throws IllegalStateException    if no model file exists for that number
     */
    public File getFileForCheckpoint(int checkpointNum) {
        if (checkpointNum < 0) {
            throw new IllegalArgumentException("Invalid checkpoint number: " + checkpointNum);
        }
        for (String modelType : MODEL_TYPES) {
            File candidate = new File(this.rootDir, CheckpointListener.getFileName(checkpointNum, modelType));
            if (candidate.exists()) {
                return candidate;
            }
        }
        throw new IllegalStateException("Model file for checkpoint " + checkpointNum + " does not exist");
    }

    /** Loads the given checkpoint (with updater state) as a MultiLayerNetwork. */
    public MultiLayerNetwork loadCheckpointMLN(Checkpoint checkpoint) {
        return this.loadCheckpointMLN(checkpoint.getCheckpointNum());
    }

    /** Loads the given checkpoint number (with updater state) as a MultiLayerNetwork. */
    public MultiLayerNetwork loadCheckpointMLN(int checkpointNum) {
        File f = this.getFileForCheckpoint(checkpointNum);
        try {
            return ModelSerializer.restoreMultiLayerNetwork(f, true);
        }
        catch (IOException e) {
            throw new RuntimeException("Error loading checkpoint " + checkpointNum + " from file: " + f.getAbsolutePath(), e);
        }
    }

    /** Loads the given checkpoint (with updater state) as a ComputationGraph. */
    public ComputationGraph loadCheckpointCG(Checkpoint checkpoint) {
        return this.loadCheckpointCG(checkpoint.getCheckpointNum());
    }

    /** Loads the given checkpoint number (with updater state) as a ComputationGraph. */
    public ComputationGraph loadCheckpointCG(int checkpointNum) {
        File f = this.getFileForCheckpoint(checkpointNum);
        try {
            return ModelSerializer.restoreComputationGraph(f, true);
        }
        catch (IOException e) {
            throw new RuntimeException("Error loading checkpoint " + checkpointNum + " from file: " + f.getAbsolutePath(), e);
        }
    }

    /**
     * Builder for {@link CheckpointListener}. At least one save schedule (epochs, iterations or time)
     * must be configured before {@link #build()}. Model saves are logged by default.
     */
    public static class Builder {
        private File rootDir;
        private KeepMode keepMode;
        private int keepLast;
        private int keepEvery;
        private boolean logSaving = true;
        private Integer saveEveryNEpochs;
        private Integer saveEveryNIterations;
        private boolean saveEveryNIterSinceLast;
        private Long saveEveryAmount;
        private TimeUnit saveEveryUnit;
        private boolean saveEverySinceLast;

        /** @param rootDir directory in which checkpoints are saved; must not be null */
        public Builder(@NonNull String rootDir) {
            // Check before new File(...): new File((String) null) would throw a message-less NPE first
            this(new File(Builder.requireRootDir(rootDir)));
        }

        /** @param rootDir directory in which checkpoints are saved; must not be null */
        public Builder(@NonNull File rootDir) {
            this.rootDir = Builder.requireRootDir(rootDir);
        }

        private static <T> T requireRootDir(T rootDir) {
            if (rootDir == null) {
                throw new NullPointerException("rootDir");
            }
            return rootDir;
        }

        /** Save a checkpoint at the end of every epoch. */
        public Builder saveEveryEpoch() {
            return this.saveEveryNEpochs(1);
        }

        /**
         * Save a checkpoint at the end of every {@code n}-th epoch.
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder saveEveryNEpochs(int n) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of epochs between checkpoints should be > 0 (got: " + n + ")");
            }
            this.saveEveryNEpochs = n;
            return this;
        }

        /** Save a checkpoint whenever the iteration number is a positive multiple of {@code n}. */
        public Builder saveEveryNIterations(int n) {
            return this.saveEveryNIterations(n, false);
        }

        /**
         * Save a checkpoint every {@code n} iterations.
         *
         * @param sinceLast if true, count from the last checkpoint saved by this listener (any
         *                  trigger); otherwise save when the iteration number is a multiple of {@code n}
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder saveEveryNIterations(int n, boolean sinceLast) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of iterations between checkpoints should be > 0 (got: " + n + ")");
            }
            this.saveEveryNIterations = n;
            this.saveEveryNIterSinceLast = sinceLast;
            return this;
        }

        /** Save a checkpoint every fixed time period, independent of saves from other triggers. */
        public Builder saveEvery(long amount, TimeUnit timeUnit) {
            return this.saveEvery(amount, timeUnit, false);
        }

        /**
         * Save a checkpoint every {@code amount} of {@code timeUnit}.
         *
         * @param sinceLast if true, measure from the last checkpoint saved by this listener (any
         *                  trigger); otherwise measure only from the last time-triggered save
         * @throws IllegalArgumentException if {@code amount <= 0} or {@code timeUnit} is null
         */
        public Builder saveEvery(long amount, TimeUnit timeUnit, boolean sinceLast) {
            if (timeUnit == null) {
                throw new IllegalArgumentException("Time unit for checkpoint saving must not be null");
            }
            if (amount <= 0L) {
                throw new IllegalArgumentException("Time between checkpoints should be > 0 (got: " + amount + " " + timeUnit + ")");
            }
            this.saveEveryAmount = amount;
            this.saveEveryUnit = timeUnit;
            this.saveEverySinceLast = sinceLast;
            return this;
        }

        /** Keep every checkpoint model file (no deletion). */
        public Builder keepAll() {
            this.keepMode = KeepMode.ALL;
            return this;
        }

        /**
         * Keep only the {@code n} most recent checkpoint model files.
         *
         * @throws IllegalArgumentException if {@code n <= 0}
         */
        public Builder keepLast(int n) {
            if (n <= 0) {
                throw new IllegalArgumentException("Number of model files to keep should be > 0 (got: " + n + ")");
            }
            this.keepMode = KeepMode.LAST;
            this.keepLast = n;
            return this;
        }

        /**
         * Keep the {@code nLast} most recent checkpoints plus every {@code everyN}-th checkpoint
         * (checkpoint numbers {@code num} with {@code (num + 1) % everyN == 0}).
         *
         * @throws IllegalArgumentException if either argument is {@code <= 0}
         */
        public Builder keepLastAndEvery(int nLast, int everyN) {
            if (nLast <= 0) {
                throw new IllegalArgumentException("Most recent number of model files to keep should be > 0 (got: " + nLast + ")");
            }
            if (everyN <= 0) {
                throw new IllegalArgumentException("Every n model files to keep should be > 0 (got: " + everyN + ")");
            }
            this.keepMode = KeepMode.LAST_AND_EVERY;
            this.keepLast = nLast;
            this.keepEvery = everyN;
            return this;
        }

        /** Whether to log each saved checkpoint at INFO level (default: true). */
        public Builder logSaving(boolean logSaving) {
            this.logSaving = logSaving;
            return this;
        }

        /**
         * @throws IllegalStateException if no save schedule (epochs, iterations or time) is configured
         */
        public CheckpointListener build() {
            if (this.saveEveryNEpochs == null && this.saveEveryAmount == null && this.saveEveryNIterations == null) {
                throw new IllegalStateException("Cannot construct listener: no models will be saved (must use at least one of: save every N epochs, every N iterations, or every T time periods)");
            }
            return new CheckpointListener(this);
        }
    }

    /** Policy for which checkpoint model files are retained after each save. */
    private static enum KeepMode {
        ALL,
        LAST,
        LAST_AND_EVERY;

    }
}

