/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.nd4j.linalg.learning.config.IUpdater;

/**
 * Static helpers for deciding whether two parameters (possibly in different layers) have
 * equivalent updater and learning-rate configurations.
 */
public class UpdaterUtils {

    /**
     * Returns whether parameter {@code param1} of {@code layer1} and parameter {@code param2} of
     * {@code layer2} have equivalent updater configurations.
     *
     * <p>The two are considered equivalent only when all of the following hold:
     * <ul>
     *   <li>the {@link IUpdater}s returned by {@code getIUpdaterByParam} are equal (compared with
     *       {@link Objects#equals}, so two {@code null} updaters compare equal instead of throwing);</li>
     *   <li>{@link #lrSchedulesEqual} reports matching learning-rate schedules; and</li>
     *   <li>either neither parameter is a pretrain parameter, or both are and both belong to the
     *       very same layer instance (reference equality).</li>
     * </ul>
     *
     * @param layer1 layer owning the first parameter
     * @param param1 name of the first parameter
     * @param layer2 layer owning the second parameter
     * @param param2 name of the second parameter
     * @return {@code true} if the two parameters' updater configurations are equivalent
     */
    public static boolean updaterConfigurationsEquals(Layer layer1, String param1, Layer layer2, String param2) {
        org.deeplearning4j.nn.conf.layers.Layer l1 = layer1.conf().getLayer();
        org.deeplearning4j.nn.conf.layers.Layer l2 = layer2.conf().getLayer();

        // Null-safe: a layer may report no updater for a parameter; don't NPE in that case.
        if (!Objects.equals(l1.getIUpdaterByParam(param1), l2.getIUpdaterByParam(param2))) {
            return false;
        }
        if (!lrSchedulesEqual(layer1, param1, layer2, param2)) {
            return false;
        }

        boolean isPretrainParam1 = l1.isPretrainParam(param1);
        boolean isPretrainParam2 = l2.isPretrainParam(param2);
        if (isPretrainParam1 || isPretrainParam2) {
            // Pretrain params only ever match other pretrain params of the same layer instance.
            return layer1 == layer2 && isPretrainParam1 && isPretrainParam2;
        }
        return true;
    }

    /**
     * Returns whether the learning-rate schedule for {@code param1} of {@code layer1} matches the
     * schedule for {@code param2} of {@code layer2}.
     *
     * <p>The policies and per-parameter base learning rates must be identical. Beyond that, only
     * the settings relevant to the shared policy are compared:
     * <ul>
     *   <li>{@code None}: nothing further;</li>
     *   <li>{@code Exponential}: decay rate;</li>
     *   <li>{@code Inverse}: decay rate and power;</li>
     *   <li>{@code Poly}: power;</li>
     *   <li>{@code Sigmoid}, {@code Step}, {@code TorchStep}: decay rate and steps;</li>
     *   <li>{@code Schedule}: the layers' learning-rate schedule maps;</li>
     *   <li>{@code Score}: never considered equal.</li>
     * </ul>
     *
     * @param layer1 layer owning the first parameter
     * @param param1 name of the first parameter
     * @param layer2 layer owning the second parameter
     * @param param2 name of the second parameter
     * @return {@code true} if the learning-rate schedules are equivalent
     * @throws UnsupportedOperationException if the shared policy is not one of the handled values
     */
    public static boolean lrSchedulesEqual(Layer layer1, String param1, Layer layer2, String param2) {
        LearningRatePolicy lp1 = layer1.conf().getLearningRatePolicy();
        LearningRatePolicy lp2 = layer2.conf().getLearningRatePolicy();
        if (lp1 != lp2) {
            return false;
        }

        double lr1 = layer1.conf().getLearningRateByParam(param1);
        double lr2 = layer2.conf().getLearningRateByParam(param2);
        if (lr1 != lr2) {
            return false;
        }

        double dr1 = layer1.conf().getLrPolicyDecayRate();
        double dr2 = layer2.conf().getLrPolicyDecayRate();

        switch (lp1) {
            case None:
                return true;
            case Exponential:
                return dr1 == dr2;
            case Inverse:
                return dr1 == dr2 && sameLrPolicyPower(layer1, layer2);
            case Poly:
                return sameLrPolicyPower(layer1, layer2);
            case Sigmoid:
            case Step:
                return dr1 == dr2 && sameLrPolicySteps(layer1, layer2);
            case TorchStep:
                // TorchStep is a step-wise decay (decay-rate multiplier applied at step boundaries),
                // so it depends on the same settings as Step. Previously only the policy power was
                // compared, which let layers with different TorchStep decay rates/steps be
                // reported as having the same schedule.
                return dr1 == dr2 && sameLrPolicySteps(layer1, layer2);
            case Schedule: {
                BaseLayer bl1 = (BaseLayer) layer1.conf().getLayer();
                BaseLayer bl2 = (BaseLayer) layer2.conf().getLayer();
                return Objects.equals(bl1.getLearningRateSchedule(), bl2.getLearningRateSchedule());
            }
            case Score:
                // Conservatively never treat score-based policies as shareable.
                return false;
            default:
                throw new UnsupportedOperationException("Unknown learning rate schedule: " + lp1);
        }
    }

    /** Returns whether both layers' configurations have the same learning-rate policy power. */
    private static boolean sameLrPolicyPower(Layer layer1, Layer layer2) {
        return layer1.conf().getLrPolicyPower() == layer2.conf().getLrPolicyPower();
    }

    /** Returns whether both layers' configurations have the same learning-rate policy steps. */
    private static boolean sameLrPolicySteps(Layer layer1, Layer layer2) {
        return layer1.conf().getLrPolicySteps() == layer2.conf().getLrPolicySteps();
    }
}

