/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.security.util.crypto;

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.concurrent.TimeUnit;
import org.apache.nifi.security.util.crypto.SecureHasher;
import org.bouncycastle.crypto.generators.Argon2BytesGenerator;
import org.bouncycastle.crypto.params.Argon2Parameters;
import org.bouncycastle.util.encoders.Base64;
import org.bouncycastle.util.encoders.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SecureHasher} implementation that derives hashes with the Bouncy Castle
 * Argon2 generator, configured for the Argon2id variant.
 *
 * <p>The cost parameters (hash length, memory in KiB, parallelism, iterations) and the
 * salt length are validated once at construction time and are immutable afterwards.
 * A non-positive salt length selects a fixed, class-wide static salt, which makes the
 * output deterministic for a given input; a positive salt length causes a fresh random
 * salt to be generated for every hash operation.
 */
public class Argon2SecureHasher implements SecureHasher {
    private static final Logger logger = LoggerFactory.getLogger(Argon2SecureHasher.class);

    private static final int DEFAULT_HASH_LENGTH = 32;
    private static final int DEFAULT_PARALLELISM = 1;
    private static final int DEFAULT_MEMORY = 4096;
    private static final int DEFAULT_ITERATIONS = 3;
    private static final int DEFAULT_SALT_LENGTH = 16;

    private static final int MIN_MEMORY_SIZE_KB = 8;
    private static final int MIN_PARALLELISM = 1;
    /** 2^24 - 1. */
    private static final long MAX_PARALLELISM = (1L << 24) - 1L;
    private static final int MIN_HASH_LENGTH = 4;
    private static final int MIN_ITERATIONS = 1;
    private static final int MIN_SALT_LENGTH = 8;
    /** 2^32 - 1, the upper bound applied to hash length, memory, iterations and salt length. */
    private static final long UPPER_BOUNDARY = (1L << 32) - 1L;

    private static final byte[] STATIC_SALT = "NiFi Static Salt".getBytes(StandardCharsets.UTF_8);
    private static final String NULL_INPUT_WARNING = "Attempting to generate an Argon2 hash of null input; using empty input";
    /** Shared source of randomness for dynamic salts; avoids re-seeding a new instance per hash. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    private final Integer hashLength;
    private final Integer memory;
    private final int parallelism;
    private final Integer iterations;
    private final Integer saltLength;
    private final boolean usingStaticSalt;

    /**
     * Creates a hasher with the default cost parameters and the static salt.
     */
    public Argon2SecureHasher() {
        this(DEFAULT_HASH_LENGTH, DEFAULT_MEMORY, DEFAULT_PARALLELISM, DEFAULT_ITERATIONS, 0);
    }

    /**
     * Creates a hasher with the given cost parameters and the static salt.
     *
     * @param hashLength  the output length in bytes
     * @param memory      the memory cost in KiB
     * @param parallelism the parallelization factor
     * @param iterations  the number of iterations
     * @throws IllegalArgumentException if any parameter is null or outside its boundary
     */
    public Argon2SecureHasher(Integer hashLength, Integer memory, int parallelism, Integer iterations) {
        this(hashLength, memory, parallelism, iterations, 0);
    }

    /**
     * Creates a hasher with the given cost parameters and salt length.
     *
     * @param hashLength  the output length in bytes
     * @param memory      the memory cost in KiB
     * @param parallelism the parallelization factor
     * @param iterations  the number of iterations
     * @param saltLength  the length in bytes of the random salt generated per hash; a value
     *                    of {@code 0} (or any non-positive value) selects the static salt
     * @throws IllegalArgumentException if any parameter is null or outside its boundary
     */
    public Argon2SecureHasher(Integer hashLength, Integer memory, int parallelism, Integer iterations, Integer saltLength) {
        validateParameters(hashLength, memory, parallelism, iterations, saltLength);
        this.hashLength = hashLength;
        this.memory = memory;
        this.parallelism = parallelism;
        this.iterations = iterations;
        this.saltLength = saltLength;
        // Validation guarantees saltLength is non-null here. Non-positive values (including
        // negatives) have always meant "use the static salt"; that contract is preserved.
        this.usingStaticSalt = saltLength <= 0;
        if (this.usingStaticSalt) {
            logger.debug("Configured to use static salt");
        }
    }

    /**
     * Checks every constructor argument, logging an error and throwing on the first violation.
     * The salt length is only range-checked when it is positive.
     */
    private static void validateParameters(Integer hashLength, Integer memory, int parallelism, Integer iterations, Integer saltLength) {
        if (!isHashLengthValid(hashLength)) {
            logger.error("The provided hash length {} is outside the boundary of 4 to 2^32 - 1.", hashLength);
            throw new IllegalArgumentException("Invalid hash length is not within the hashLength boundary.");
        }
        if (!isMemorySizeValid(memory)) {
            logger.error("The provided memory size {} KiB is outside the boundary of 8p to 2^32 - 1.", memory);
            throw new IllegalArgumentException("Invalid memory size is not within the memory boundary.");
        }
        if (!isParallelismValid(parallelism)) {
            logger.error("The provided parallelization factor {} is outside the boundary of 1 to 2^24 - 1.", parallelism);
            throw new IllegalArgumentException("Invalid parallelization factor exceeds the parallelism boundary.");
        }
        if (!isIterationsValid(iterations)) {
            logger.error("The iteration count {} is outside the boundary of 1 to 2^32 - 1.", iterations);
            throw new IllegalArgumentException("Invalid iteration count exceeds the iterations boundary.");
        }
        // A null salt length is rejected explicitly instead of failing with an unboxing NPE.
        if (saltLength == null || (saltLength > 0 && !isSaltLengthValid(saltLength))) {
            logger.error("The salt length {} is outside the boundary of 8 to 2^32 - 1.", saltLength);
            throw new IllegalArgumentException("Invalid salt length exceeds the saltLength boundary.");
        }
    }

    /**
     * Returns whether this hasher uses the fixed static salt rather than a random salt per hash.
     *
     * @return {@code true} if the static salt is in use
     */
    public boolean isUsingStaticSalt() {
        return usingStaticSalt;
    }

    /**
     * Returns the salt to use for the next hash operation.
     *
     * @return a copy of the static salt, or {@code saltLength} freshly generated random bytes
     */
    byte[] getSalt() {
        if (isUsingStaticSalt()) {
            // Hand out a copy so callers cannot mutate the shared constant.
            return STATIC_SALT.clone();
        }
        final byte[] salt = new byte[saltLength];
        SECURE_RANDOM.nextBytes(salt);
        return salt;
    }

    /**
     * Checks the hash length against the accepted range, warning if it is below the default.
     *
     * @param hashLength the output length in bytes
     * @return {@code true} if non-null and within {@code [4, 2^32 - 1]}
     */
    public static boolean isHashLengthValid(Integer hashLength) {
        if (hashLength == null) {
            return false;
        }
        if (hashLength < DEFAULT_HASH_LENGTH) {
            logger.warn("The provided hash length {} is below the recommended minimum {}.", hashLength, DEFAULT_HASH_LENGTH);
        }
        return hashLength >= MIN_HASH_LENGTH && hashLength <= UPPER_BOUNDARY;
    }

    /**
     * Checks the memory cost against the accepted range, warning if it is below the default.
     *
     * @param memory the memory cost in KiB
     * @return {@code true} if non-null and within {@code [8, 2^32 - 1]}
     */
    public static boolean isMemorySizeValid(Integer memory) {
        if (memory == null) {
            return false;
        }
        if (memory < DEFAULT_MEMORY) {
            logger.warn("The provided memory size {} KiB is below the recommended minimum {} KiB.", memory, DEFAULT_MEMORY);
        }
        return memory >= MIN_MEMORY_SIZE_KB && memory <= UPPER_BOUNDARY;
    }

    /**
     * Checks the parallelization factor against the accepted range, warning if it is below the minimum.
     *
     * @param parallelism the parallelization factor
     * @return {@code true} if within {@code [1, 2^24 - 1]}
     */
    public static boolean isParallelismValid(int parallelism) {
        if (parallelism < MIN_PARALLELISM) {
            logger.warn("The provided parallelization factor {} is below the recommended minimum {}.", parallelism, MIN_PARALLELISM);
        }
        return parallelism >= MIN_PARALLELISM && parallelism <= MAX_PARALLELISM;
    }

    /**
     * Checks the iteration count against the accepted range, warning if it is below the default.
     *
     * @param iterations the number of iterations
     * @return {@code true} if non-null and within {@code [1, 2^32 - 1]}
     */
    public static boolean isIterationsValid(Integer iterations) {
        if (iterations == null) {
            return false;
        }
        if (iterations < DEFAULT_ITERATIONS) {
            logger.warn("The provided iteration count {} is below the recommended minimum {}.", iterations, DEFAULT_ITERATIONS);
        }
        return iterations >= MIN_ITERATIONS && iterations <= UPPER_BOUNDARY;
    }

    /**
     * Checks the salt length. A length of {@code 0} is accepted as the marker for the static
     * salt; otherwise the length must lie in the accepted range, and a warning is logged if
     * it is below the default.
     *
     * @param saltLength the salt length in bytes
     * @return {@code true} if {@code 0}, or non-null and within {@code [8, 2^32 - 1]}
     */
    public static boolean isSaltLengthValid(Integer saltLength) {
        if (saltLength == null) {
            return false;
        }
        if (saltLength == 0) {
            logger.debug("The provided salt length 0 indicates a static salt of {} bytes", DEFAULT_SALT_LENGTH);
            return true;
        }
        if (saltLength < DEFAULT_SALT_LENGTH) {
            logger.warn("The provided dynamic salt length {} is below the recommended minimum {}", saltLength, DEFAULT_SALT_LENGTH);
        }
        return saltLength >= MIN_SALT_LENGTH && saltLength <= UPPER_BOUNDARY;
    }

    /**
     * Hashes the UTF-8 bytes of the input and returns the result hex-encoded.
     *
     * @param input the text to hash; {@code null} is treated as the empty string
     * @return the hex-encoded hash
     */
    @Override
    public String hashHex(String input) {
        return Hex.toHexString(hash(toUtf8OrEmpty(input)));
    }

    /**
     * Hashes the UTF-8 bytes of the input and returns the result Base64-encoded.
     *
     * @param input the text to hash; {@code null} is treated as the empty string
     * @return the Base64-encoded hash
     */
    @Override
    public String hashBase64(String input) {
        return Base64.toBase64String(hash(toUtf8OrEmpty(input)));
    }

    /**
     * Hashes the given bytes and returns the raw hash.
     *
     * @param input the bytes to hash; {@code null} is treated as an empty array, matching
     *              the null handling of {@link #hashHex(String)} and {@link #hashBase64(String)}
     * @return the raw hash bytes
     */
    @Override
    public byte[] hashRaw(byte[] input) {
        if (input == null) {
            logger.warn(NULL_INPUT_WARNING);
            input = new byte[0];
        }
        return hash(input);
    }

    /** Converts the text to UTF-8 bytes, warning and substituting empty input for {@code null}. */
    private static byte[] toUtf8OrEmpty(String input) {
        if (input == null) {
            logger.warn(NULL_INPUT_WARNING);
            input = "";
        }
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Runs the Argon2id generator over the input with this instance's parameters and a salt
     * from {@link #getSalt()}, logging the salt and the elapsed time at debug level.
     *
     * @param input the non-null bytes to hash
     * @return a new array of {@code hashLength} bytes containing the hash
     */
    private byte[] hash(byte[] input) {
        final byte[] salt = getSalt();
        final byte[] hash = new byte[hashLength];
        logger.debug("Creating {} byte Argon2 hash with salt [{}]", hashLength, Hex.toHexString(salt));

        final long startNanos = System.nanoTime();
        final Argon2Parameters params = new Argon2Parameters.Builder(Argon2Parameters.ARGON2_id)
                .withSalt(salt)
                .withParallelism(parallelism)
                .withMemoryAsKB(memory)
                .withIterations(iterations)
                .build();
        final Argon2BytesGenerator generator = new Argon2BytesGenerator();
        generator.init(params);
        final long initNanos = System.nanoTime();

        generator.generateBytes(input, hash);
        final long generateNanos = System.nanoTime();

        final long initDurationMicros = TimeUnit.NANOSECONDS.toMicros(initNanos - startNanos);
        final long generateDurationMicros = TimeUnit.NANOSECONDS.toMicros(generateNanos - initNanos);
        final long totalDurationMillis = TimeUnit.MICROSECONDS.toMillis(initDurationMicros + generateDurationMicros);
        logger.debug("Generated Argon2 hash in {} ms (init: {} \u00b5s, generate: {} \u00b5s)", totalDurationMillis, initDurationMicros, generateDurationMicros);
        return hash;
    }
}

