/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.dashscope.image;

import com.alibaba.cloud.ai.dashscope.api.DashScopeImageApi;
import com.alibaba.cloud.ai.dashscope.common.DashScopeApiConstants;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import com.alibaba.cloud.ai.dashscope.image.observation.DashScopeImageModelObservationConvention;
import com.alibaba.cloud.ai.dashscope.image.observation.DashScopeImagePromptContentObservationHandler;
import com.alibaba.cloud.ai.dashscope.spec.DashScopeApiSpec;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageMessage;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelObservationConvention;
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * {@link ImageModel} backed by the DashScope asynchronous image-generation API.
 *
 * <p>A call submits a generation task and then polls the task result through the configured
 * {@link RetryTemplate}. Every poll that does not observe a terminal status throws
 * {@link TransientAiException} so the template schedules another attempt. When the template gives
 * up, its recovery callback returns an empty response whose metadata carries
 * {@code taskStatus=TIMED_OUT}.
 */
public class DashScopeImageModel implements ImageModel {

    private static final Logger logger = LoggerFactory.getLogger(DashScopeImageModel.class);

    /** Model used when neither the runtime options nor the default options name one. */
    private static final String DEFAULT_MODEL = "wanx-v1";

    /** Fallback convention handed to Micrometer; created once instead of on every call. */
    private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION =
            new DefaultImageModelObservationConvention();

    /** Message of the exception that asks the retry template to poll an unfinished task again. */
    private static final String PENDING_MESSAGE = "Image generation still pending";

    private final DashScopeImageApi dashScopeImageApi;

    private final DashScopeImageOptions defaultOptions;

    /** Drives task polling: each retry attempt is one status query. */
    private final RetryTemplate retryTemplate;

    private final ObservationRegistry observationRegistry;

    private ImageModelObservationConvention observationConvention;

    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options, RetryTemplate retryTemplate) {
        this(dashScopeImageApi, options, retryTemplate, ObservationRegistry.NOOP);
    }

    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi) {
        this(dashScopeImageApi, DashScopeImageOptions.builder().model(DashScopeImageApi.DEFAULT_IMAGE_MODEL).build(),
                RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
    }

    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options) {
        this(dashScopeImageApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
    }

    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, ObservationRegistry observationRegistry) {
        this(dashScopeImageApi, DashScopeImageOptions.builder().model(DashScopeImageApi.DEFAULT_IMAGE_MODEL).build(),
                RetryUtils.DEFAULT_RETRY_TEMPLATE, observationRegistry);
    }

    /**
     * Creates a model and registers the default prompt-content observation handler on
     * {@code observationRegistry}.
     *
     * @throws IllegalArgumentException if any argument is {@code null}
     */
    public DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        this(dashScopeImageApi, options, retryTemplate, observationRegistry,
                new DashScopeImagePromptContentObservationHandler());
    }

    /**
     * Shared constructor. It registers {@code promptHandler} exactly once, so builder-created models
     * no longer end up with both the default handler and the builder's handler.
     *
     * @param promptHandler handler to register on the registry, or {@code null} to register none
     */
    private DashScopeImageModel(DashScopeImageApi dashScopeImageApi, DashScopeImageOptions options,
            RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
            ObservationHandler<ImageModelObservationContext> promptHandler) {
        Assert.notNull(dashScopeImageApi, "DashScopeImageApi must not be null");
        Assert.notNull(options, "options must not be null");
        Assert.notNull(retryTemplate, "retryTemplate must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        this.dashScopeImageApi = dashScopeImageApi;
        this.defaultOptions = options;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
        this.observationConvention = new DashScopeImageModelObservationConvention();
        if (promptHandler != null) {
            this.observationRegistry.observationConfig().observationHandler(promptHandler);
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Generates images for the first instruction of {@code request}.
     *
     * <p>Task submission and polling both run inside one observation, so its timing and errors
     * cover the whole operation.
     *
     * @param request the prompt; must be non-null with at least one instruction
     * @return the generated images; empty (with explanatory metadata) when no task id was
     *     obtained, the task ended as FAILED/CANCELED/UNKNOWN, or polling gave up
     * @throws IllegalArgumentException if {@code request} is null or has no instructions
     */
    @Override
    public ImageResponse call(ImagePrompt request) {
        Assert.notNull(request, "Prompt must not be null");
        Assert.notEmpty(request.getInstructions(), "Prompt messages must not be empty");

        ImageModelObservationContext observationContext = ImageModelObservationContext.builder()
                .imagePrompt(request)
                .provider(DashScopeApiConstants.PROVIDER_NAME)
                .build();
        Observation observation = ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION.observation(
                this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                this.observationRegistry);

        ImageResponse response = observation.observe(() -> {
            String taskId = submitImageGenTask(request);
            if (taskId == null) {
                return new ImageResponse(List.of(), toMetadataEmpty());
            }
            return awaitTaskCompletion(taskId, observation);
        });
        return Objects.requireNonNull(response);
    }

    /**
     * Polls {@code taskId} via the retry template until a terminal status is seen.
     *
     * <p>If the template gives up, the recovery callback marks the observation with
     * {@code timeout=true} and returns an empty TIMED_OUT response. NOTE(review): with a stateless
     * RetryTemplate, the recovery callback presumably also runs for non-retryable errors, which are
     * then reported as TIMED_OUT too.
     */
    private ImageResponse awaitTaskCompletion(String taskId, Observation observation) {
        return this.retryTemplate.execute(
                context -> pollTaskOnce(taskId, observation, context.getRetryCount()),
                context -> {
                    observation.lowCardinalityKeyValue("timeout", "true");
                    Throwable last = context.getLastThrowable();
                    logger.warn("Stopped polling image task {} after {} attempts: {}", taskId,
                            context.getRetryCount(), last == null ? null : last.getMessage());
                    return new ImageResponse(List.of(), toMetadataTimeout(taskId));
                });
    }

    /**
     * Queries the task status once and maps a terminal status to a response.
     *
     * @throws TransientAiException when the task is not terminal yet, or when no status could be
     *     read, so that the retry template polls again
     */
    private ImageResponse pollTaskOnce(String taskId, Observation observation, int retryCount) {
        observation.lowCardinalityKeyValue("retry.attempt", String.valueOf(retryCount));
        DashScopeApiSpec.DashScopeImageAsyncResponse response = getImageGenTask(taskId);
        String status = (response == null || response.output() == null) ? null : response.output().taskStatus();
        if (status == null) {
            // Previously an NPE here was swallowed by the recovery callback and reported as a timeout.
            throw new TransientAiException(PENDING_MESSAGE);
        }
        observation.lowCardinalityKeyValue("task.status", status);
        return switch (status) {
            case "SUCCEEDED" -> toImageResponse(response);
            // CANCELED is terminal as well; it used to be polled until the retry budget ran out.
            case "FAILED", "CANCELED", "UNKNOWN" -> new ImageResponse(List.of(), toMetadata(response));
            default -> throw new TransientAiException(PENDING_MESSAGE);
        };
    }

    /**
     * Submits an image-generation task for {@code request}.
     *
     * @return the task id, or {@code null} when the API returned no body or no output
     */
    public String submitImageGenTask(ImagePrompt request) {
        DashScopeImageOptions imageOptions = toImageOptions(request.getOptions());
        logger.debug("Image options: {}", imageOptions);
        DashScopeApiSpec.DashScopeImageRequest imageRequest = constructImageRequest(request, imageOptions);
        ResponseEntity<DashScopeApiSpec.DashScopeImageAsyncResponse> submitResponse =
                this.dashScopeImageApi.submitImageGenTask(imageRequest);
        DashScopeApiSpec.DashScopeImageAsyncResponse body = submitResponse == null ? null : submitResponse.getBody();
        if (body == null || body.output() == null) {
            logger.warn("Submit imageGen error,request: {}", request);
            return null;
        }
        return body.output().taskId();
    }

    /**
     * Resolves the effective options. Non-null runtime values win over {@link #defaultOptions};
     * {@link #DEFAULT_MODEL} is used only when neither of them sets a model.
     */
    private DashScopeImageOptions toImageOptions(ImageOptions runtimeOptions) {
        // Seeding with an empty object (instead of one carrying DEFAULT_MODEL) keeps the merge below
        // from overriding the model configured in defaultOptions.
        DashScopeImageOptions requested = runtimeOptions == null
                ? DashScopeImageOptions.builder().build()
                : ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class, DashScopeImageOptions.class);
        DashScopeImageOptions merged = ModelOptionsUtils.merge(requested, this.defaultOptions, DashScopeImageOptions.class);
        if (merged.getModel() == null) {
            merged = ModelOptionsUtils.merge(merged, DashScopeImageOptions.builder().model(DEFAULT_MODEL).build(),
                    DashScopeImageOptions.class);
        }
        return merged;
    }

    /**
     * Fetches the current state of {@code taskId}.
     *
     * @return the response body, or {@code null} when the API returned none
     */
    public DashScopeApiSpec.DashScopeImageAsyncResponse getImageGenTask(String taskId) {
        ResponseEntity<DashScopeApiSpec.DashScopeImageAsyncResponse> getImageGenResponse =
                this.dashScopeImageApi.getImageGenTaskResult(taskId);
        if (getImageGenResponse == null || getImageGenResponse.getBody() == null) {
            logger.warn("No image response returned for taskId: {}", taskId);
            return null;
        }
        return getImageGenResponse.getBody();
    }

    public DashScopeImageOptions getOptions() {
        return this.defaultOptions;
    }

    /** Converts a SUCCEEDED task (output already checked non-null) into an {@link ImageResponse}. */
    private ImageResponse toImageResponse(DashScopeApiSpec.DashScopeImageAsyncResponse asyncResp) {
        List<DashScopeApiSpec.DashScopeImageAsyncResponse.DashScopeImageAsyncResponseResult> results =
                asyncResp.output().results();
        List<ImageGeneration> generations = results == null
                ? List.of()
                : results.stream().map(result -> new ImageGeneration(new Image(result.url(), null))).toList();
        return new ImageResponse(generations, toMetadata(asyncResp));
    }

    /** Builds the API request from the first prompt instruction and the resolved options. */
    private DashScopeApiSpec.DashScopeImageRequest constructImageRequest(ImagePrompt imagePrompt, DashScopeImageOptions options) {
        String prompt = imagePrompt.getInstructions().get(0).getText();
        DashScopeApiSpec.DashScopeImageRequest.DashScopeImageRequestInput input =
                new DashScopeApiSpec.DashScopeImageRequest.DashScopeImageRequestInput(prompt,
                        options.getNegativePrompt(), options.getRefImg(), options.getFunction(),
                        options.getBaseImageUrl(), options.getMaskImageUrl(), options.getSketchImageUrl());
        DashScopeApiSpec.DashScopeImageRequest.DashScopeImageRequestParameter parameters =
                new DashScopeApiSpec.DashScopeImageRequest.DashScopeImageRequestParameter(options.getStyle(),
                        options.getSize(), options.getN(), options.getSeed(), options.getRefStrength(),
                        options.getRefMode(), options.getPromptExtend(), options.getWatermark(),
                        options.getSketchWeight(), options.getSketchExtraction(), options.getSketchColor(),
                        options.getMaskColor());
        return new DashScopeApiSpec.DashScopeImageRequest(options.getModel(), input, parameters);
    }

    /**
     * Copies usage, task metrics, request id, status, code and message into response metadata.
     * Absent (null) values are skipped rather than stored.
     */
    private ImageResponseMetadata toMetadata(DashScopeApiSpec.DashScopeImageAsyncResponse re) {
        ImageResponseMetadata md = new ImageResponseMetadata();
        DashScopeApiSpec.DashScopeImageAsyncResponse.DashScopeImageAsyncResponseUsage usage = re.usage();
        if (usage != null) {
            putIfNotNull(md, "imageCount", usage.imageCount());
        }
        putIfNotNull(md, "requestId", re.requestId());
        DashScopeApiSpec.DashScopeImageAsyncResponse.DashScopeImageAsyncResponseOutput out = re.output();
        if (out != null) {
            DashScopeApiSpec.DashScopeImageAsyncResponse.DashScopeImageAsyncResponseTaskMetrics metrics = out.taskMetrics();
            if (metrics != null) {
                putIfNotNull(md, "taskTotal", metrics.total());
                putIfNotNull(md, "taskSucceeded", metrics.SUCCEEDED());
                putIfNotNull(md, "taskFailed", metrics.FAILED());
            }
            putIfNotNull(md, "taskStatus", out.taskStatus());
            putIfNotNull(md, "code", out.code());
            putIfNotNull(md, "message", out.message());
        }
        return md;
    }

    /** Stores {@code value} under {@code key} only when it is non-null. */
    private static void putIfNotNull(ImageResponseMetadata md, String key, Object value) {
        if (value != null) {
            md.put(key, value);
        }
    }

    /** Metadata for the case where submission produced no task id. */
    private ImageResponseMetadata toMetadataEmpty() {
        ImageResponseMetadata md = new ImageResponseMetadata();
        md.put("taskStatus", "NO_TASK_ID");
        return md;
    }

    /** Metadata for a task that did not reach a terminal status before polling gave up. */
    private ImageResponseMetadata toMetadataTimeout(String taskId) {
        ImageResponseMetadata md = new ImageResponseMetadata();
        md.put("taskId", taskId);
        md.put("taskStatus", "TIMED_OUT");
        return md;
    }

    /**
     * Replaces the custom observation convention.
     *
     * @throws IllegalArgumentException if {@code observationConvention} is null
     */
    public void setObservationConvention(ImageModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    /** Fluent builder for {@link DashScopeImageModel}. */
    public static final class Builder {
        private DashScopeImageApi dashScopeImageApi;
        private DashScopeImageOptions defaultOptions = DashScopeImageOptions.builder().model(DEFAULT_MODEL).n(1).build();
        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
        private ImageModelObservationConvention observationConvention = new DashScopeImageModelObservationConvention();
        private ObservationHandler<ImageModelObservationContext> promptHandler = new DashScopeImagePromptContentObservationHandler();

        private Builder() {
        }

        public Builder dashScopeApi(DashScopeImageApi dashScopeImageApi) {
            this.dashScopeImageApi = dashScopeImageApi;
            return this;
        }

        public Builder defaultOptions(DashScopeImageOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder observationConvention(ImageModelObservationConvention observationConvention) {
            this.observationConvention = observationConvention;
            return this;
        }

        /**
         * Sets the handler registered on the observation registry. It replaces the default prompt
         * handler rather than being added next to it; {@code null} registers no handler.
         */
        public Builder promptHandler(ObservationHandler<ImageModelObservationContext> promptHandler) {
            this.promptHandler = promptHandler;
            return this;
        }

        /**
         * Builds the model.
         *
         * @throws IllegalArgumentException if the API, options, retry template, registry or
         *     convention is null
         */
        public DashScopeImageModel build() {
            DashScopeImageModel model = new DashScopeImageModel(this.dashScopeImageApi, this.defaultOptions,
                    this.retryTemplate, this.observationRegistry, this.promptHandler);
            model.setObservationConvention(this.observationConvention);
            return model;
        }
    }
}

