/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.ntru;

import org.bouncycastle.pqc.crypto.ntru.NTRUSampling;
import org.bouncycastle.pqc.crypto.ntru.OWCPADecryptResult;
import org.bouncycastle.pqc.crypto.ntru.OWCPAKeyPair;
import org.bouncycastle.pqc.crypto.ntru.PolynomialPair;
import org.bouncycastle.pqc.math.ntru.HPSPolynomial;
import org.bouncycastle.pqc.math.ntru.Polynomial;
import org.bouncycastle.pqc.math.ntru.parameters.NTRUHPSParameterSet;
import org.bouncycastle.pqc.math.ntru.parameters.NTRUHRSSParameterSet;
import org.bouncycastle.pqc.math.ntru.parameters.NTRUParameterSet;
import org.bouncycastle.util.Arrays;

class NTRUOWCPA {
    private final NTRUParameterSet params;
    private final NTRUSampling sampling;

    public NTRUOWCPA(NTRUParameterSet params) {
        this.params = params;
        this.sampling = new NTRUSampling(params);
    }

    public OWCPAKeyPair keypair(byte[] seed) {
        byte[] privateKey = new byte[this.params.owcpaSecretKeyBytes()];
        int n = this.params.n();
        int q = this.params.q();
        Polynomial x3 = this.params.createPolynomial();
        Polynomial x4 = this.params.createPolynomial();
        Polynomial x5 = this.params.createPolynomial();
        Polynomial invfMod3 = x3;
        Polynomial gf = x3;
        Polynomial invgf = x4;
        Polynomial tmp = x5;
        Polynomial invh = x3;
        Polynomial h = x3;
        PolynomialPair pair = this.sampling.sampleFg(seed);
        Polynomial f = pair.f();
        Polynomial g = pair.g();
        invfMod3.s3Inv(f);
        f.s3ToBytes(privateKey, 0);
        invfMod3.s3ToBytes(privateKey, this.params.packTrinaryBytes());
        f.z3ToZq();
        g.z3ToZq();
        if (this.params instanceof NTRUHRSSParameterSet) {
            for (int i = n - 1; i > 0; --i) {
                g.coeffs[i] = (short)(3 * (g.coeffs[i - 1] - g.coeffs[i]));
            }
            g.coeffs[0] = (short)(-(3 * g.coeffs[0]));
        } else {
            for (int i = 0; i < n; ++i) {
                g.coeffs[i] = (short)(3 * g.coeffs[i]);
            }
        }
        gf.rqMul(g, f);
        invgf.rqInv(gf);
        tmp.rqMul(invgf, f);
        invh.sqMul(tmp, f);
        byte[] sqRes = invh.sqToBytes(privateKey.length - 2 * this.params.packTrinaryBytes());
        System.arraycopy(sqRes, 0, privateKey, 2 * this.params.packTrinaryBytes(), sqRes.length);
        tmp.rqMul(invgf, g);
        h.rqMul(tmp, g);
        byte[] publicKey = h.rqSumZeroToBytes(this.params.owcpaPublicKeyBytes());
        return new OWCPAKeyPair(publicKey, privateKey);
    }

    public byte[] encrypt(Polynomial r, Polynomial m, byte[] publicKey) {
        Polynomial x1 = this.params.createPolynomial();
        Polynomial x2 = this.params.createPolynomial();
        Polynomial h = x1;
        Polynomial liftm = x1;
        Polynomial ct = x2;
        h.rqSumZeroFromBytes(publicKey);
        ct.rqMul(r, h);
        liftm.lift(m);
        for (int i = 0; i < this.params.n(); ++i) {
            int n = i;
            ct.coeffs[n] = (short)(ct.coeffs[n] + liftm.coeffs[i]);
        }
        return ct.rqSumZeroToBytes(this.params.ntruCiphertextBytes());
    }

    public OWCPADecryptResult decrypt(byte[] ciphertext, byte[] privateKey) {
        byte[] sk = privateKey;
        byte[] rm = new byte[this.params.owcpaMsgBytes()];
        Polynomial x1 = this.params.createPolynomial();
        Polynomial x2 = this.params.createPolynomial();
        Polynomial x3 = this.params.createPolynomial();
        Polynomial x4 = this.params.createPolynomial();
        Polynomial c = x1;
        Polynomial f = x2;
        Polynomial cf = x3;
        Polynomial mf = x2;
        Polynomial finv3 = x3;
        Polynomial m = x4;
        Polynomial liftm = x2;
        Polynomial invh = x3;
        Polynomial r = x4;
        Polynomial b = x1;
        c.rqSumZeroFromBytes(ciphertext);
        f.s3FromBytes(sk);
        f.z3ToZq();
        cf.rqMul(c, f);
        mf.rqToS3(cf);
        finv3.s3FromBytes(Arrays.copyOfRange(sk, this.params.packTrinaryBytes(), sk.length));
        m.s3Mul(mf, finv3);
        m.s3ToBytes(rm, this.params.packTrinaryBytes());
        int fail = 0;
        fail |= this.checkCiphertext(ciphertext);
        if (this.params instanceof NTRUHPSParameterSet) {
            fail |= this.checkM((HPSPolynomial)m);
        }
        liftm.lift(m);
        for (int i = 0; i < this.params.n(); ++i) {
            b.coeffs[i] = (short)(c.coeffs[i] - liftm.coeffs[i]);
        }
        invh.sqFromBytes(Arrays.copyOfRange(sk, 2 * this.params.packTrinaryBytes(), sk.length));
        r.sqMul(b, invh);
        r.trinaryZqToZ3();
        r.s3ToBytes(rm, 0);
        return new OWCPADecryptResult(rm, fail |= this.checkR(r));
    }

    private int checkCiphertext(byte[] ciphertext) {
        short t = ciphertext[this.params.ntruCiphertextBytes() - 1];
        t = (short)(t & 255 << 8 - (7 & this.params.logQ() * this.params.packDegree()));
        return 1 & ~t + 1 >>> 15;
    }

    private int checkR(Polynomial r) {
        int t = 0;
        for (int i = 0; i < this.params.n() - 1; ++i) {
            short c = r.coeffs[i];
            t |= c + 1 & this.params.q() - 4;
            t |= c + 2 & 4;
        }
        return 1 & ~(t |= r.coeffs[this.params.n() - 1]) + 1 >>> 31;
    }

    private int checkM(HPSPolynomial m) {
        int t = 0;
        int ps = 0;
        int ms = 0;
        for (int i = 0; i < this.params.n() - 1; ++i) {
            ps = (short)(ps + (m.coeffs[i] & 1));
            ms = (short)(ms + (m.coeffs[i] & 2));
        }
        t |= ps ^ ms >>> 1;
        return 1 & ~(t |= ms ^ ((NTRUHPSParameterSet)this.params).weight()) + 1 >>> 31;
    }
}

