/*
 * Decompiled with CFR 0.152.
 */
package se.swedenconnect.opensaml.xmlsec.signature.support.provider.padding;

import java.security.MessageDigest;
import java.security.SecureRandom;
import org.bouncycastle.crypto.CryptoServicesRegistrar;
import se.swedenconnect.opensaml.xmlsec.signature.support.provider.padding.MGF;
import se.swedenconnect.opensaml.xmlsec.signature.support.provider.padding.MGF1;

/**
 * EMSA-PSS encoding (the padding used by RSASSA-PSS), following RFC 8017, section 9.1.1.
 *
 * <p>{@link #getPadding(byte[])} produces the encoded message {@code EM = maskedDB || H || 0xBC}, where
 * {@code H = Hash(0x00 * 8 || mHash || salt)}, {@code DB = PS || 0x01 || salt} ({@code PS} being zero
 * octets), and {@code maskedDB = DB XOR MGF(H, len(DB))} with the leftmost
 * {@code 8 * emLen - emBits} bits cleared.</p>
 *
 * <p>NOTE(review): a random salt with the same length as the digest output is generated once, in the
 * constructor, and is reused by every call to {@link #getPadding(byte[])} until {@link #setSalt(byte[])}
 * replaces it. Two encodings of the same hash from the same instance are therefore identical. If a fresh
 * salt per signature is required, use one instance per signature or set a new salt. Confirm this matches
 * the callers' intent.</p>
 *
 * <p>Instances hold a stateful {@link MessageDigest} (the same instance is also handed to the MGF1 mask
 * generator) plus mutable salt state. They should presumably not be shared between threads without
 * external synchronization.</p>
 */
public class SCPSSPadding {

    /** The EMSA-PSS trailer field {@code 0xBC} that terminates every encoded message. */
    private static final byte DEFAULT_END_BYTE = (byte) 0xBC;

    /** Number of leading zero octets in M' (RFC 8017, section 9.1.1, step 5). */
    private static final int M_PRIME_ZERO_PREFIX_LENGTH = 8;

    private final MessageDigest messageDigest;
    private final int messageDigestSize;
    private final MGF maskGenerationFunction;
    /** Encoded message length in octets: ceil(emBits / 8). */
    private final int emLength;
    /** Maximal bit length of the encoded message: modulusBits - 1. */
    private final int emBits;
    private byte[] salt;
    private int saltLength;
    private final SecureRandom rng = CryptoServicesRegistrar.getSecureRandom();

    /**
     * Creates a PSS padding generator for the given digest and RSA modulus size.
     *
     * <p>The salt length defaults to the digest output length, and a random salt of that length is
     * generated immediately. MGF1 is configured with the same {@code MessageDigest} instance.</p>
     *
     * @param messageDigest the digest used both for hashing and for MGF1; must not be {@code null}
     * @param modulusBits the bit length of the RSA modulus
     * @throws NullPointerException if {@code messageDigest} is {@code null}
     */
    public SCPSSPadding(MessageDigest messageDigest, int modulusBits) {
        if (messageDigest == null) {
            throw new NullPointerException("messageDigest must not be null");
        }
        this.messageDigest = messageDigest;
        this.messageDigestSize = messageDigest.getDigestLength();
        this.emBits = modulusBits - 1;
        this.emLength = (int) Math.ceil((double) this.emBits / 8.0);
        this.maskGenerationFunction = new MGF1(messageDigest);
        this.saltLength = messageDigest.getDigestLength();
        this.salt = new byte[this.saltLength];
        this.rng.nextBytes(this.salt);
    }

    /**
     * Replaces the salt used by subsequent calls to {@link #getPadding(byte[])}.
     *
     * <p>The array is copied, so later modifications by the caller do not affect this instance. The salt
     * length becomes the length of the supplied array. Whether it fits the modulus is checked only when
     * padding is generated.</p>
     *
     * @param salt the new salt; must not be {@code null}
     * @throws NullPointerException if {@code salt} is {@code null}
     */
    public void setSalt(byte[] salt) {
        if (salt == null) {
            throw new NullPointerException("salt must not be null");
        }
        // Defensive copy: storing the caller's array would let external mutation change every later padding.
        this.salt = salt.clone();
        this.saltLength = this.salt.length;
    }

    /**
     * Hashes {@code message} with the configured digest and returns its EMSA-PSS encoding.
     *
     * @param message the message to encode
     * @return the encoded message, {@code emLength} octets long
     * @throws IllegalArgumentException if the modulus is too small for the digest and salt lengths
     */
    public byte[] getPaddingFromMessage(byte[] message) {
        return this.getPadding(this.messageDigest.digest(message));
    }

    /**
     * Returns the EMSA-PSS encoding of an already computed message hash (RFC 8017, section 9.1.1,
     * steps 3 and 5-12).
     *
     * @param messageHash the message hash; must be exactly as long as the configured digest output
     * @return the encoded message, {@code emLength} octets long
     * @throws NullPointerException if {@code messageHash} is {@code null}
     * @throws IllegalArgumentException if the modulus is too small for the digest and salt lengths, or if
     *     {@code messageHash} has the wrong length
     */
    public byte[] getPadding(byte[] messageHash) throws IllegalArgumentException {
        // Step 3: emLen must be at least hLen + sLen + 2.
        if (this.emLength < this.messageDigestSize + this.saltLength + 2) {
            throw new IllegalArgumentException("Illegal key modulus length for RSA PSS");
        }
        if (messageHash == null) {
            throw new NullPointerException("messageHash must not be null");
        }
        // A shorter hash would fail inside arraycopy; a longer one would be silently truncated.
        if (messageHash.length != this.messageDigestSize) {
            throw new IllegalArgumentException("messageHash must be " + this.messageDigestSize
                + " bytes long, but was " + messageHash.length + " bytes");
        }

        // Steps 5-6: M' = (0x00 * 8) || mHash || salt, H = Hash(M').
        final byte[] mPrime = new byte[M_PRIME_ZERO_PREFIX_LENGTH + this.messageDigestSize + this.saltLength];
        System.arraycopy(messageHash, 0, mPrime, M_PRIME_ZERO_PREFIX_LENGTH, this.messageDigestSize);
        System.arraycopy(this.salt, 0, mPrime, M_PRIME_ZERO_PREFIX_LENGTH + this.messageDigestSize,
            this.saltLength);
        final byte[] h = this.messageDigest.digest(mPrime);

        // Steps 7-8: DB = PS || 0x01 || salt, where PS is (emLen - sLen - hLen - 2) zero octets.
        final byte[] db = new byte[this.emLength - this.messageDigestSize - 1];
        final int psLength = db.length - this.saltLength - 1;
        db[psLength] = 0x01;
        System.arraycopy(this.salt, 0, db, psLength + 1, this.saltLength);

        // Steps 9-10: maskedDB = DB XOR MGF(H, emLen - hLen - 1), written into the front of EM.
        final byte[] dbMask = this.maskGenerationFunction.getMask(h, db.length);
        final byte[] em = new byte[this.emLength];
        for (int i = 0; i < db.length; ++i) {
            em[i] = (byte) (db[i] ^ dbMask[i]);
        }

        // Step 11: clear the leftmost (8 * emLen - emBits) bits so EM is numerically below the modulus.
        em[0] = (byte) (em[0] & (0xFF >> (8 * this.emLength - this.emBits)));

        // Step 12: EM = maskedDB || H || 0xBC.
        System.arraycopy(h, 0, em, db.length, this.messageDigestSize);
        em[this.emLength - 1] = DEFAULT_END_BYTE;
        return em;
    }
}

