/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.security.hpke;

import com.yahoo.security.ArrayUtils;
import com.yahoo.security.hpke.Aead;
import com.yahoo.security.hpke.Ciphersuite;
import com.yahoo.security.hpke.Constants;
import com.yahoo.security.hpke.Kdf;
import com.yahoo.security.hpke.Kem;
import com.yahoo.security.hpke.LabeledKdfUtils;
import java.security.interfaces.XECPrivateKey;
import java.security.interfaces.XECPublicKey;
import java.util.Arrays;

public final class Hpke {
    private final Kem kem;
    private final Kdf kdf;
    private final Aead aead;
    private final byte[] hpkeSuiteId;
    private static final byte MODE_BASE = 0;
    private static final byte MODE_PSK = 1;
    private static final byte MODE_AUTH = 2;
    private static final byte MODE_AUTH_PSK = 3;
    private static final int MAX_INPUT_LENGTH = 64;

    private Hpke(Ciphersuite ciphersuite) {
        this.kem = ciphersuite.kem();
        this.kdf = ciphersuite.kdf();
        this.aead = ciphersuite.aead();
        this.hpkeSuiteId = this.makeHpkeSuiteId();
    }

    public static Hpke of(Ciphersuite ciphersuite) {
        return new Hpke(ciphersuite);
    }

    private byte[] makeHpkeSuiteId() {
        byte[] hpkePrefix = new byte[]{72, 80, 75, 69};
        return ArrayUtils.concat(hpkePrefix, LabeledKdfUtils.i2osp2(this.kem.kemId()), LabeledKdfUtils.i2osp2(this.kdf.kdfId()), LabeledKdfUtils.i2osp2(this.aead.aeadId()));
    }

    byte[] labeledExtractHpke(byte[] salt, byte[] label, byte[] ikm) {
        return LabeledKdfUtils.labeledExtractForSuite(this.kdf, this.hpkeSuiteId, salt, label, ikm);
    }

    byte[] labeledExpandHpke(byte[] prk, byte[] label, byte[] info, int nBytesToExpand) {
        return LabeledKdfUtils.labeledExpandForSuite(this.kdf, prk, this.hpkeSuiteId, label, info, nBytesToExpand);
    }

    static void verifyPskInputs(byte mode, byte[] psk, byte[] pskId) {
        boolean gotPskId;
        boolean gotPsk = !Arrays.equals(psk, Constants.DEFAULT_PSK);
        boolean bl = gotPskId = !Arrays.equals(pskId, Constants.DEFAULT_PSK_ID);
        if (gotPsk != gotPskId) {
            throw new IllegalArgumentException("Inconsistent PSK inputs");
        }
        if (gotPsk && (mode == 0 || mode == 2)) {
            throw new IllegalArgumentException("PSK input provided when not needed");
        }
        if (!(gotPsk || mode != 1 && mode != 3)) {
            throw new IllegalArgumentException("Missing required PSK input");
        }
    }

    static void verifyInputLengthRestrictions(byte[] psk, byte[] pskId, byte[] info) {
        if (psk.length > 64) {
            throw new IllegalArgumentException("Input PSK length (%d) greater than max length (%d)".formatted(psk.length, 64));
        }
        if (pskId.length > 64) {
            throw new IllegalArgumentException("Input PSK ID length (%d) greater than max length (%d)".formatted(pskId.length, 64));
        }
        if (info.length > 64) {
            throw new IllegalArgumentException("Input info length (%d) greater than max length (%d)".formatted(info.length, 64));
        }
    }

    ContextBase keySchedule(byte mode, byte[] sharedSecret, byte[] info, byte[] psk, byte[] pskId) {
        Hpke.verifyPskInputs(mode, psk, pskId);
        Hpke.verifyInputLengthRestrictions(psk, pskId, info);
        byte[] pskIdHash = this.labeledExtractHpke(Constants.EMPTY_LABEL, Constants.PSK_ID_HASH_LABEL, pskId);
        byte[] infoHash = this.labeledExtractHpke(Constants.EMPTY_LABEL, Constants.INFO_HASH_LABEL, info);
        byte[] keyScheduleContext = ArrayUtils.concat({mode}, pskIdHash, infoHash);
        byte[] secret = this.labeledExtractHpke(sharedSecret, Constants.SECRET_LABEL, psk);
        byte[] key = this.labeledExpandHpke(secret, Constants.KEY_LABEL, keyScheduleContext, this.aead.nK());
        byte[] baseNonce = this.labeledExpandHpke(secret, Constants.BASE_NONCE_LABEL, keyScheduleContext, this.aead.nN());
        byte[] exporterSecret = this.labeledExpandHpke(secret, Constants.EXP_LABEL, keyScheduleContext, this.kdf.nH());
        return new ContextBase(key, baseNonce, 0L, exporterSecret);
    }

    ContextS setupBaseS(XECPublicKey pkR, byte[] info) {
        Kem.EncapResult encapped = this.kem.encap(pkR);
        return new ContextS(encapped.enc(), this.keySchedule((byte)0, encapped.sharedSecret(), info, Constants.DEFAULT_PSK, Constants.DEFAULT_PSK_ID));
    }

    ContextR setupBaseR(byte[] enc, XECPrivateKey skR, byte[] info) {
        byte[] sharedSecret = this.kem.decap(enc, skR);
        return new ContextR(this.keySchedule((byte)0, sharedSecret, info, Constants.DEFAULT_PSK, Constants.DEFAULT_PSK_ID));
    }

    ContextS setupAuthS(XECPublicKey pkR, byte[] info, XECPrivateKey skS) {
        Kem.EncapResult encapped = this.kem.authEncap(pkR, skS);
        return new ContextS(encapped.enc(), this.keySchedule((byte)2, encapped.sharedSecret(), info, Constants.DEFAULT_PSK, Constants.DEFAULT_PSK_ID));
    }

    ContextR setupAuthR(byte[] enc, XECPrivateKey skR, byte[] info, XECPublicKey pkS) {
        byte[] sharedSecret = this.kem.authDecap(enc, skR, pkS);
        return new ContextR(this.keySchedule((byte)2, sharedSecret, info, Constants.DEFAULT_PSK, Constants.DEFAULT_PSK_ID));
    }

    public Sealed sealBase(XECPublicKey pkR, byte[] info, byte[] aad, byte[] pt) {
        ContextS encAndCtx = this.setupBaseS(pkR, info);
        ContextBase base = encAndCtx.base;
        byte[] ct = this.aead.seal(base.key(), base.nonce(), aad, pt);
        return new Sealed(encAndCtx.enc, ct);
    }

    public Sealed sealAuth(XECPublicKey pkR, byte[] info, byte[] aad, byte[] pt, XECPrivateKey skS) {
        ContextS encAndCtx = this.setupAuthS(pkR, info, skS);
        ContextBase base = encAndCtx.base;
        byte[] ct = this.aead.seal(base.key(), base.nonce(), aad, pt);
        return new Sealed(encAndCtx.enc, ct);
    }

    public byte[] openBase(byte[] enc, XECPrivateKey skR, byte[] info, byte[] aad, byte[] ct) {
        ContextR ctx = this.setupBaseR(enc, skR, info);
        ContextBase base = ctx.base;
        return this.aead.open(base.key(), base.nonce(), aad, ct);
    }

    public byte[] openAuth(byte[] enc, XECPrivateKey skR, byte[] info, byte[] aad, byte[] ct, XECPublicKey pkS) {
        ContextR ctx = this.setupAuthR(enc, skR, info, pkS);
        ContextBase base = ctx.base;
        return this.aead.open(base.key(), base.nonce(), aad, ct);
    }

    private record ContextBase(byte[] key, byte[] nonce, long seqNum, byte[] exporterSecret) {
    }

    private record ContextS(byte[] enc, ContextBase base) {
    }

    private record ContextR(ContextBase base) {
    }

    public record Sealed(byte[] enc, byte[] ciphertext) {
    }
}

