/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.crypto.kems;

import java.math.BigInteger;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;
import org.bouncycastle.crypto.kems.SAKKEKEMSGenerator;
import org.bouncycastle.crypto.params.SAKKEPrivateKeyParameters;
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
import org.bouncycastle.math.ec.ECAlgorithms;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.BigIntegers;

/**
 * Extracts a shared secret value (SSV) from a SAKKE encapsulation using the
 * receiver's private key parameters.
 *
 * <p>An encapsulation is parsed as a 257-byte encoded curve point (named
 * {@code R_bS} in the code) followed by a 16-byte unsigned big-endian integer
 * {@code H}. NOTE(review): the structure appears to follow the receiver-side
 * processing of RFC 6508 (SAKKE) -- confirm against the specification.</p>
 *
 * <p>NOTE(review): a single {@link Digest} instance obtained from the public
 * parameters is reused for every call; whether that is safe for concurrent use
 * depends on the digest implementation and is not established here.</p>
 */
public class SAKKEKEMExtractor
implements EncapsulatedSecretExtractor {
    /** Number of leading encapsulation bytes handed to the curve's point decoder. */
    private static final int POINT_ENCODING_LENGTH = 257;
    /** Number of bytes, directly after the encoded point, that hold the masked value H. */
    private static final int MASKED_SECRET_LENGTH = 16;
    /** Total number of encapsulation bytes that {@link #extractSecret(byte[])} reads. */
    private static final int ENCAPSULATION_LENGTH = POINT_ENCODING_LENGTH + MASKED_SECRET_LENGTH;

    private final ECCurve curve;
    private final BigInteger p;
    private final BigInteger q;
    private final ECPoint P;
    private final ECPoint Z_S;
    private final ECPoint K_bs;
    private final int n;
    private final BigInteger identifier;
    private final Digest digest;

    /**
     * Builds an extractor from a private key.
     *
     * <p>All domain values (curve, q, P, p, Z_S, identifier, n, digest) are read
     * from the key's public parameters. The point {@code K_bs} is derived here as
     * {@code P * ((identifier + masterSecret)^-1 mod q)} and stored normalized.</p>
     *
     * @param privateKey private key parameters supplying the public parameters
     *                   and the master secret
     */
    public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
        SAKKEPublicKeyParameters publicKey = privateKey.getPublicParams();
        this.curve = publicKey.getCurve();
        this.q = publicKey.getQ();
        this.P = publicKey.getPoint();
        this.p = publicKey.getPrime();
        this.Z_S = publicKey.getZ();
        this.identifier = publicKey.getIdentifier();
        this.K_bs = this.P.multiply(this.identifier.add(privateKey.getMasterSecret()).modInverse(this.q)).normalize();
        this.n = publicKey.getN();
        this.digest = publicKey.getDigest();
    }

    /**
     * Recovers the secret value from an encapsulation and verifies it.
     *
     * <p>Only the first {@value #ENCAPSULATION_LENGTH} bytes are read; any
     * trailing bytes are ignored.</p>
     *
     * @param encapsulation encoded point (257 bytes) followed by H (16 bytes)
     * @return the recovered secret as an unsigned byte array of {@code n / 8} bytes
     * @throws NullPointerException     if {@code encapsulation} is {@code null}
     * @throws IllegalArgumentException if {@code encapsulation} is shorter than
     *                                  {@value #ENCAPSULATION_LENGTH} bytes
     * @throws IllegalStateException    if the point recomputed from the recovered
     *                                  secret does not match the received point
     */
    public byte[] extractSecret(byte[] encapsulation) {
        // Fail fast with a clear message instead of relying on the downstream
        // decoding helpers to reject missing or truncated input.
        if (encapsulation == null) {
            throw new NullPointerException("encapsulation cannot be null");
        }
        if (encapsulation.length < ENCAPSULATION_LENGTH) {
            throw new IllegalArgumentException("encapsulation must be at least " + ENCAPSULATION_LENGTH
                + " bytes, got " + encapsulation.length);
        }

        // Split the encapsulation into the point R_bS and the masked value H.
        ECPoint R_bS = this.curve.decodePoint(Arrays.copyOfRange(encapsulation, 0, POINT_ENCODING_LENGTH));
        BigInteger H = BigIntegers.fromUnsignedByteArray(encapsulation, POINT_ENCODING_LENGTH, MASKED_SECRET_LENGTH);

        // Derive the mask from the pairing value and unmask H to obtain the secret.
        BigInteger w = SAKKEKEMExtractor.computePairing(R_bS, this.K_bs, this.p, this.q);
        BigInteger twoToN = BigInteger.ONE.shiftLeft(this.n);
        BigInteger mask = SAKKEKEMSGenerator.hashToIntegerRange(w.toByteArray(), twoToN, this.digest);
        BigInteger ssv = H.xor(mask).mod(this.p);

        // Recompute the scalar r from (ssv || identifier), hashed into the range given by q.
        BigInteger b = this.identifier;
        BigInteger r = SAKKEKEMSGenerator.hashToIntegerRange(Arrays.concatenate(ssv.toByteArray(), b.toByteArray()), this.q, this.digest);

        // Rebuild the point r * (b * P + Z_S). When the curve reports an order,
        // the equivalent form (b * r mod order) * P + r * Z_S is evaluated with a
        // single combined multiplication.
        ECPoint expected;
        BigInteger order = this.curve.getOrder();
        if (order == null) {
            expected = this.P.multiply(b).add(this.Z_S).multiply(r);
        } else {
            BigInteger a = b.multiply(r).mod(order);
            expected = ECAlgorithms.sumOfTwoMultiplies(this.P, a, this.Z_S, r);
        }

        // The received point is accepted only if the difference is the point at infinity.
        ECPoint difference = expected.subtract(R_bS);
        if (!difference.isInfinity()) {
            throw new IllegalStateException("Validation of R_bS failed");
        }
        return BigIntegers.asUnsignedByteArray(this.n / 8, ssv);
    }

    /**
     * Returns the number of encapsulation bytes this extractor reads.
     *
     * @return 273 (257 bytes of encoded point plus 16 bytes of H)
     */
    public int getEncapsulationLength() {
        return ENCAPSULATION_LENGTH;
    }

    /**
     * Computes the pairing value of {@code R} and {@code Q}.
     *
     * <p>The method walks the bits of {@code q - 1} from the second most
     * significant bit down to bit 0. Each iteration squares the accumulator,
     * multiplies in a tangent-line term and doubles the running point {@code C};
     * when the current bit is set it also multiplies in a chord-line term and
     * (except on the final bit) adds {@code R} to {@code C}. After the loop the
     * accumulator is squared twice and the quotient of its two components is
     * returned. NOTE(review): this looks like the pairing procedure of RFC 6508
     * -- confirm against the specification.</p>
     *
     * @param R point whose multiples are accumulated; must have affine coordinates available
     * @param Q point at which the line functions are evaluated; must have affine coordinates available
     * @param p modulus for all field arithmetic
     * @param q value whose predecessor {@code q - 1} supplies the loop bits
     * @return {@code v[1] * v[0]^-1 mod p} for the final accumulator {@code v}
     */
    static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger q) {
        // Accumulator v = (v[0], v[1]) starts at (1, 0).
        BigInteger[] v = new BigInteger[]{BigInteger.ONE, BigInteger.ZERO};
        ECPoint C = R;
        BigInteger qMinusOne = q.subtract(BigInteger.ONE);
        int numBits = qMinusOne.bitLength();
        BigInteger Qx = Q.getAffineXCoord().toBigInteger();
        BigInteger Qy = Q.getAffineYCoord().toBigInteger();
        BigInteger Rx = R.getAffineXCoord().toBigInteger();
        BigInteger Ry = R.getAffineYCoord().toBigInteger();
        BigInteger three = BigInteger.valueOf(3L);
        for (int i = numBits - 2; i >= 0; --i) {
            BigInteger Cx = C.getAffineXCoord().toBigInteger();
            BigInteger Cy = C.getAffineYCoord().toBigInteger();

            // Doubling step: l = 3 * (Cx^2 - 1) / (2 * Cy) mod p.
            BigInteger l = Cx.multiply(Cx).mod(p).subtract(BigInteger.ONE).multiply(three).multiply(BigIntegers.modOddInverse(p, Cy.shiftLeft(1))).mod(p);
            v = SAKKEKEMExtractor.fp2PointSquare(v[0], v[1], p);
            v = SAKKEKEMExtractor.fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy).mod(p), Qy, p);
            C = C.twice().normalize();

            if (qMinusOne.testBit(i)) {
                Cx = C.getAffineXCoord().toBigInteger();
                Cy = C.getAffineYCoord().toBigInteger();

                // Addition step: l = (Cy - Ry) / (Cx - Rx) mod p.
                l = Cy.subtract(Ry).multiply(BigIntegers.modOddInverse(p, Cx.subtract(Rx))).mod(p);
                v = SAKKEKEMExtractor.fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy).mod(p), Qy, p);

                // The running point is not advanced after the last bit.
                if (i > 0) {
                    C = C.add(R).normalize();
                }
            }
        }
        v = SAKKEKEMExtractor.fp2PointSquare(v[0], v[1], p);
        v = SAKKEKEMExtractor.fp2PointSquare(v[0], v[1], p);
        BigInteger v0Inv = BigIntegers.modOddInverse(p, v[0]);
        return v[1].multiply(v0Inv).mod(p);
    }

    /**
     * Multiplies two pairs {@code (a0, b0)} and {@code (a1, b1)} modulo {@code p}
     * using the rule {@code (a0*a1 - b0*b1, a0*b1 + b0*a1)}, i.e. multiplication
     * of {@code a0 + b0*i} by {@code a1 + b1*i} with {@code i^2 = -1}.
     *
     * @param a0 first component of the left operand
     * @param b0 second component of the left operand
     * @param a1 first component of the right operand
     * @param b1 second component of the right operand
     * @param p  modulus
     * @return two-element array holding the reduced product components
     */
    static BigInteger[] fp2Multiply(BigInteger a0, BigInteger b0, BigInteger a1, BigInteger b1, BigInteger p) {
        BigInteger first = a0.multiply(a1).subtract(b0.multiply(b1)).mod(p);
        BigInteger second = a0.multiply(b1).add(b0.multiply(a1)).mod(p);
        return new BigInteger[]{first, second};
    }

    /**
     * Squares the pair {@code (a, b)} modulo {@code p}, producing
     * {@code ((a + b)(a - b), 2ab)}, which equals {@code (a^2 - b^2, 2ab)}.
     *
     * @param a first component
     * @param b second component
     * @param p modulus
     * @return two-element array holding the reduced square components
     */
    static BigInteger[] fp2PointSquare(BigInteger a, BigInteger b, BigInteger p) {
        BigInteger first = a.add(b).multiply(a.subtract(b)).mod(p);
        BigInteger second = a.multiply(b).shiftLeft(1).mod(p);
        return new BigInteger[]{first, second};
    }
}

