package io.airlift.security.jwks;

import com.google.common.io.Closer;
import io.airlift.concurrent.Threads;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.http.client.StringResponseHandler;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.security.PublicKey;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

/**
 * Fetches a JSON Web Key Set (JWKS) over HTTP and exposes the decoded public keys.
 *
 * <p>The key set is fetched once, synchronously, while the service is constructed; if that fetch
 * fails the constructor throws. Once {@link #start()} is called, a single daemon thread re-fetches
 * the key set every {@code refreshDelay}. A failed background refresh is logged and the previously
 * fetched key set stays in place.
 */
public final class JwksService {
    private static final Logger log = Logger.get(JwksService.class);

    private final URI address;
    private final HttpClient httpClient;
    private final Duration refreshDelay;
    // Latest successfully fetched key set; swapped wholesale so readers always see one complete fetch.
    private final AtomicReference<Map<String, PublicKey>> keys;

    // Non-null while the background refresher is running; only accessed under "this" lock.
    private Closer closer;

    /**
     * Creates the service and performs the initial, blocking key fetch.
     *
     * @param address URI of the JWKS endpoint
     * @param httpClient client used for every fetch
     * @param refreshDelay delay between background refreshes once {@link #start()} is called
     * @throws NullPointerException if any argument is null
     * @throws RuntimeException if the initial fetch or decode of the key set fails
     */
    public JwksService(URI address, HttpClient httpClient, Duration refreshDelay) {
        this.address = Objects.requireNonNull(address, "address is null");
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.refreshDelay = Objects.requireNonNull(refreshDelay, "refreshDelay is null");
        // Must run after the assignments above: fetchKeys() reads address and httpClient.
        // (Doing this in a field initializer would run before them and always fail with an NPE.)
        this.keys = new AtomicReference<>(fetchKeys());
    }

    /**
     * Starts periodic background refreshing of the key set. Calling this while already started is
     * a no-op.
     */
    @PostConstruct
    public synchronized void start() {
        if (closer != null) {
            return;
        }
        closer = Closer.create();

        ScheduledExecutorService executor =
                Executors.newSingleThreadScheduledExecutor(Threads.daemonThreadsNamed("JWKS loader"));
        closer.register(executor::shutdownNow);

        long delayMillis = refreshDelay.toMillis();
        ScheduledFuture<?> refreshTask = executor.scheduleWithFixedDelay(() -> {
            try {
                refreshKeys();
            } catch (Throwable t) {
                // Never let an exception escape: a scheduled task that throws is never run again.
                log.error(t, "Error fetching JWKS keys");
            }
        }, delayMillis, delayMillis, TimeUnit.MILLISECONDS);

        // Closer closes in reverse registration order: cancel the task, then shut the executor down.
        closer.register(() -> refreshTask.cancel(true));
    }

    /**
     * Stops background refreshing. Calling this when not started is a no-op. The current key set
     * remains readable afterwards.
     *
     * @throws UncheckedIOException if releasing the refresher's resources fails
     */
    @PreDestroy
    public synchronized void stop() {
        if (closer == null) {
            return;
        }
        try {
            closer.close();
        } catch (IOException e) {
            throw new UncheckedIOException("Error stopping JWKS service", e);
        } finally {
            // Clear even on failure so a subsequent start() builds a fresh refresher.
            closer = null;
        }
    }

    /**
     * Returns the most recently fetched key set. The map is the one produced by
     * {@code JwksDecoder.decodeKeys}. NOTE(review): its mutability depends on that decoder, so
     * callers should treat it as read-only.
     */
    public Map<String, PublicKey> getKeys() {
        return keys.get();
    }

    /**
     * Looks up a single key in the most recently fetched key set.
     *
     * @param keyId map key to look up (presumably the JWK {@code kid}; confirm against JwksDecoder)
     * @return the key, or empty if the current key set has no entry for {@code keyId}
     */
    public Optional<PublicKey> getKey(String keyId) {
        return Optional.ofNullable(keys.get().get(keyId));
    }

    /**
     * Synchronously re-fetches the key set and replaces the current one. If the fetch fails, the
     * exception propagates and the existing key set is kept.
     */
    public void refreshKeys() {
        keys.set(fetchKeys());
    }

    /**
     * Performs one GET against the JWKS endpoint and decodes the response body.
     *
     * @throws RuntimeException if the request fails, the status is not 200, or decoding fails. Each
     *     case carries its own message; only transport failures are reported as "Error reading".
     */
    private Map<String, PublicKey> fetchKeys() {
        Request request = Request.Builder.prepareGet().setUri(address).build();

        StringResponseHandler.StringResponse response;
        try {
            response = httpClient.execute(request, StringResponseHandler.createStringResponseHandler());
        } catch (RuntimeException e) {
            throw new RuntimeException("Error reading JWKS keys from " + address, e);
        }

        if (response.getStatusCode() != 200) {
            throw new RuntimeException("Unexpected response code " + response.getStatusCode() + " from JWKS service at " + address);
        }

        try {
            return JwksDecoder.decodeKeys(response.getBody());
        } catch (RuntimeException e) {
            throw new RuntimeException("Unable to decode JWKS response from " + address, e);
        }
    }
}
