package com.testautomationguru.ocular.comparator;

import com.testautomationguru.ocular.exception.OcularException;
import org.arquillian.rusheye.oneoff.ImageUtils;
import org.openqa.selenium.Point;
import org.openqa.selenium.*;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.util.List;
import java.util.*;

/**
 * Static helpers for capturing, masking and saving browser screenshots taken through Selenium.
 */
public class ImageUtil {

    /** Composite used for masking; CLEAR zeroes both colour and alpha of the painted destination pixels. */
    private static final AlphaComposite COMPOSITE = AlphaComposite.getInstance(AlphaComposite.CLEAR);
    private static final Color TRANSPARENT = new Color(0, 0, 0, 0);

    /**
     * Browser-side script returning {@code {x, y}}: the top-left corner of {@code arguments[0]} in the
     * top-level window's viewport, obtained by adding the bounding-rect offset of every enclosing iframe
     * while walking up {@code window.parent} until {@code window.top}.
     * <p>
     * {@code left}/{@code top} are read instead of {@code x}/{@code y} because they hold the same values
     * for non-negative sizes and are also available on older {@code ClientRect} implementations.
     * <p>
     * NOTE(review): the iframe's own border/padding is not added, so elements inside bordered frames may
     * be offset by the border width; a cross-origin parent document is presumably not readable, which
     * would make the script fail — confirm before relying on it for such pages.
     */
    private static final String ELEMENT_POSITION_SCRIPT =
            "    let currentWindow = window;\n" +
                    "    let positions = [];\n" +
                    "    while (currentWindow !== window.top) {\n" +
                    "      let currentParentWindow = currentWindow.parent;\n" +
                    "      let iframes = currentParentWindow.document\n" +
                    "        .getElementsByTagName('iframe');\n" +
                    "      positions.push(\n" +
                    "        Array.from(iframes)\n" +
                    "          .find(frameElement => frameElement.contentWindow === currentWindow)\n" +
                    "          .getBoundingClientRect()\n" +
                    "      );\n" +
                    "      currentWindow = currentParentWindow;\n" +
                    "    }\n" +
                    "    let rect = arguments[0].getBoundingClientRect();\n" +
                    "    let frame = positions.reduce((accumulator, currentValue) => {\n" +
                    "      return {\n" +
                    "        x: accumulator.x + currentValue.left,\n" +
                    "        y: accumulator.y + currentValue.top\n" +
                    "      };\n" +
                    "    }, {x: 0, y: 0});\n" +
                    "    return {x: rect.left + frame.x, y: rect.top + frame.y};";

    /**
     * Captures the browser's current screenshot and decodes it into an image.
     *
     * @param driver the driver to capture from; it is cast to {@link TakesScreenshot}
     * @return the decoded screenshot, never {@code null}
     * @throws ClassCastException if {@code driver} does not implement {@link TakesScreenshot}
     * @throws OcularException    if the screenshot file cannot be read or no image reader can decode it
     */
    public static BufferedImage getPageSnapshot(WebDriver driver) {
        File screen = ((TakesScreenshot) driver).getScreenshotAs(OutputType.FILE);
        BufferedImage page;
        try {
            page = ImageIO.read(screen);
        } catch (Exception e) {
            throw new OcularException("Unable to get page snapshot", e);
        }
        if (page == null) {
            // ImageIO.read signals "no suitable reader" by returning null instead of throwing.
            throw new OcularException("Unable to get page snapshot",
                    new IOException("No ImageIO reader could decode screenshot file " + screen));
        }
        return page;
    }

    /**
     * Captures a screenshot and crops it to the area covered by {@code element}.
     * <p>
     * The crop origin comes from {@link #ELEMENT_POSITION_SCRIPT} (iframe-aware viewport coordinates)
     * when the driver can execute JavaScript and returns usable numbers; otherwise it falls back to
     * {@link WebElement#getLocation()}. The crop size is {@link WebElement#getSize()}.
     * <p>
     * NOTE(review): {@link BufferedImage#getSubimage} throws if the element extends beyond the captured
     * image, and no scaling is applied between element coordinates and screenshot pixels — confirm this
     * holds for the drivers/displays in use.
     *
     * @param driver  the driver to capture from; must implement {@link TakesScreenshot}
     * @param element the element to crop to
     * @return a sub-image of the page snapshot (it shares pixel data with the full snapshot)
     * @throws OcularException if the page snapshot cannot be obtained
     */
    public static BufferedImage getElementSnapshot(WebDriver driver, WebElement element) {
        Point location = locateElement(driver, element);
        int width = element.getSize().getWidth();
        int height = element.getSize().getHeight();
        return getPageSnapshot(driver).getSubimage(location.getX(), location.getY(), width, height);
    }

    /**
     * Resolves the top-left corner used to crop {@code element}: the script result when the driver is a
     * {@link JavascriptExecutor} and returns numeric {@code x}/{@code y}, otherwise
     * {@link WebElement#getLocation()}.
     */
    private static Point locateElement(WebDriver driver, WebElement element) {
        if (driver instanceof JavascriptExecutor) {
            Object result = ((JavascriptExecutor) driver).executeScript(ELEMENT_POSITION_SCRIPT, element);
            if (result instanceof Map) {
                Map<?, ?> bounding = (Map<?, ?>) result;
                Object x = bounding.get("x");
                Object y = bounding.get("y");
                // Script numbers may arrive as Long or Double depending on the value, hence Number.
                if (x instanceof Number && y instanceof Number) {
                    return new Point(((Number) x).intValue(), ((Number) y).intValue());
                }
            }
        }
        // Falling back to WebElement::getLocation when the script is unavailable or its result unusable.
        return element.getLocation();
    }

    /**
     * Clears the rectangle covered by {@code element} in {@code img}.
     *
     * @param img     the image to modify in place
     * @param element the element whose location/size define the rectangle
     * @return {@code img} itself, after masking
     */
    public static BufferedImage maskElement(BufferedImage img, WebElement element) {
        return maskArea(img, element);
    }

    /**
     * Clears the rectangle covered by each of {@code elements}, in iteration order.
     *
     * @param img      the image to modify in place
     * @param elements the elements to mask
     * @return {@code img} itself, after masking
     */
    public static BufferedImage maskElements(BufferedImage img, List<WebElement> elements) {
        for (WebElement element : elements) {
            img = maskArea(img, element);
        }
        return img;
    }

    /**
     * Writes {@code result} to {@code file} via rusheye's {@code ImageUtils.writeImage}.
     *
     * @param result the image to write
     * @param file   the destination; a bare file name is resolved against the working directory
     * @throws OcularException if writing fails with an {@link IOException}
     */
    public static void saveImage(BufferedImage result, File file) {
        // A bare name such as "diff.png" has a null parent; the absolute form always has one.
        File target = file.getAbsoluteFile();
        try {
            ImageUtils.writeImage(result, target.getParentFile(), target.getName());
        } catch (IOException e) {
            throw new OcularException("Unable to write the difference", e);
        }
    }

    /**
     * Clears (CLEAR composite) the rectangle at {@code element.getLocation()} with {@code element.getSize()}.
     * <p>
     * NOTE(review): unlike {@link #getElementSnapshot}, this uses {@link WebElement#getLocation()} and is
     * therefore not iframe-aware — confirm that is intended.
     *
     * @return {@code img} itself, modified in place
     */
    private static BufferedImage maskArea(BufferedImage img, WebElement element) {
        Point p = element.getLocation();
        int width = element.getSize().getWidth();
        int height = element.getSize().getHeight();

        Graphics2D g2d = img.createGraphics();
        try {
            g2d.setComposite(COMPOSITE);
            g2d.setColor(TRANSPARENT);
            g2d.fillRect(p.getX(), p.getY(), width, height);
        } finally {
            // Release the graphics context; it is not reclaimed promptly otherwise.
            g2d.dispose();
        }
        return img;
    }
}
