/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.training;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.ServiceLoader;
import lombok.NonNull;
import org.nd4j.common.config.ND4JClassLoading;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;

/**
 * Singleton registry that dispatches {@link TrainingMessage}s to the {@link TrainingDriver}
 * registered for their type.
 *
 * <p>Drivers are discovered with the {@link ServiceLoader} mechanism when the singleton is
 * created. Each driver is keyed by the value of {@link TrainingDriver#targetMessageClass()}.
 * Incoming messages are looked up by {@code message.getClass().getSimpleName()}, so a driver is
 * only reachable if its target string equals the simple (unqualified) name of the message class.
 */
@Deprecated
public class TrainerProvider {
    private static final TrainerProvider INSTANCE = new TrainerProvider();

    /** Discovered drivers, keyed by the target message-class name each driver reports. */
    protected Map<String, TrainingDriver<?>> trainers = new HashMap<>();
    protected VoidConfiguration voidConfiguration;
    protected Transport transport;
    protected Clipboard clipboard;
    protected Storage storage;

    private TrainerProvider() {
        loadProviders();
    }

    /**
     * Returns the singleton provider instance.
     *
     * @return the shared {@code TrainerProvider}
     */
    public static TrainerProvider getInstance() {
        return INSTANCE;
    }

    /**
     * Discovers every {@link TrainingDriver} available to {@link ServiceLoader} and registers it
     * under its {@link TrainingDriver#targetMessageClass()} name.
     *
     * @throws ND4JIllegalStateException if no driver could be discovered
     */
    protected void loadProviders() {
        // A class literal of a generic type is necessarily raw, so the loader's type argument is raw.
        @SuppressWarnings("rawtypes")
        ServiceLoader<TrainingDriver> serviceLoader = ND4JClassLoading.loadService(TrainingDriver.class);
        for (TrainingDriver<?> driver : serviceLoader) {
            // NOTE(review): if two drivers report the same target, the later one silently replaces
            // the earlier one - confirm this is intended.
            trainers.put(driver.targetMessageClass(), driver);
        }
        if (trainers.isEmpty()) {
            throw new ND4JIllegalStateException("No TrainingDrivers were found via ServiceLoader mechanism");
        }
    }

    /**
     * Stores the shared collaborators and initializes every registered driver with them.
     *
     * @param voidConfiguration configuration passed to each driver
     * @param transport transport passed to each driver
     * @param storage storage passed to each driver
     * @param clipboard clipboard passed to each driver
     * @throws NullPointerException if any argument is {@code null}
     */
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @NonNull Storage storage, @NonNull Clipboard clipboard) {
        // Explicit checks keep the documented NPE contract whether or not Lombok processing runs.
        Objects.requireNonNull(voidConfiguration, "voidConfiguration is marked non-null but is null");
        Objects.requireNonNull(transport, "transport is marked non-null but is null");
        Objects.requireNonNull(storage, "storage is marked non-null but is null");
        Objects.requireNonNull(clipboard, "clipboard is marked non-null but is null");
        this.voidConfiguration = voidConfiguration;
        this.transport = transport;
        this.clipboard = clipboard;
        this.storage = storage;
        for (TrainingDriver<?> trainer : trainers.values()) {
            trainer.init(voidConfiguration, transport, storage, clipboard);
        }
    }

    /**
     * Looks up the driver registered for the runtime type of {@code message}.
     *
     * @param message the message whose simple class name selects the driver
     * @param <T> the message type
     * @return the driver registered under {@code message.getClass().getSimpleName()}
     * @throws ND4JIllegalStateException if no driver is registered for that name
     */
    @SuppressWarnings("unchecked")
    protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T message) {
        String messageType = message.getClass().getSimpleName();
        TrainingDriver<?> driver = trainers.get(messageType);
        if (driver == null) {
            throw new ND4JIllegalStateException("Can't find trainer for [" + messageType + "]");
        }
        // Unchecked: relies on each driver being registered under the name of the message type
        // it actually accepts.
        return (TrainingDriver<T>) driver;
    }

    /**
     * Dispatches {@code message} to its registered driver and starts training on it.
     *
     * @param message the training message to process
     * @param <T> the message type
     * @throws ND4JIllegalStateException if no driver is registered for the message type
     */
    public <T extends TrainingMessage> void doTraining(T message) {
        TrainingDriver<T> trainer = getTrainer(message);
        trainer.startTraining(message);
    }
}

