package com.atlassian.xwork;

import org.apache.struts2.ServletActionContext;
import org.apache.struts2.dispatcher.LocalizedMessage;
import org.apache.struts2.dispatcher.multipart.MultiPartRequestWrapper;

import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.File;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;

import static java.util.Collections.unmodifiableList;
import static java.util.stream.Collectors.joining;


/**
 * Static helpers for extracting uploaded files from a Struts {@link MultiPartRequestWrapper}.
 */
public class FileUploadUtils {

    /**
     * Returns the file backing the first entry of {@link #getUploadedFiles()}.
     *
     * @return the {@link File} of the first upload, or {@code null} if nothing was uploaded
     * @throws FileUploadException if the multipart request reported errors
     */
    public static File getSingleFile() throws FileUploadException {
        UploadedFile uploadedFile = getSingleUploadedFile();
        return uploadedFile == null ? null : uploadedFile.getFile();
    }

    /**
     * Returns the first entry of {@link #getUploadedFiles()}.
     *
     * @return the first upload, or {@code null} if nothing was uploaded
     * @throws FileUploadException if the multipart request reported errors
     */
    public static UploadedFile getSingleUploadedFile() throws FileUploadException {
        List<UploadedFile> uploadedFiles = getUploadedFiles();
        return uploadedFiles.isEmpty() ? null : uploadedFiles.get(0);
    }

    /**
     * Returns the uploads of the request obtained from {@link ServletActionContext#getRequest()}.
     * Entries whose upload failed are skipped ({@code clean == true}).
     *
     * @return a mutable list of uploads; empty if that request does not wrap a {@link MultiPartRequestWrapper}
     * @throws FileUploadException if the multipart request reported errors
     */
    public static List<UploadedFile> getUploadedFiles() throws FileUploadException {
        MultiPartRequestWrapper multiWrapper = unwrapMultiPartRequest(ServletActionContext.getRequest());
        if (multiWrapper == null) {
            // Not a multipart request, so nothing was uploaded. Without this guard
            // checkMultiPartRequestForErrors would dereference null.
            return new ArrayList<>();
        }
        return getUploadedFiles(multiWrapper, true);
    }

    /**
     * Array-returning variant of {@link #getUploadedFiles(MultiPartRequestWrapper, boolean)}.
     *
     * @param multiWrapper the multipart request to read
     * @param clean        {@code true} to skip failed uploads, {@code false} to throw on the first one
     * @return the uploads, in the same order as the list variant
     * @throws FileUploadException if the request reported errors, or if {@code clean} is {@code false}
     *                             and an upload failed
     */
    public static UploadedFile[] handleFileUpload(MultiPartRequestWrapper multiWrapper, boolean clean)
            throws FileUploadException {
        return getUploadedFiles(multiWrapper, clean).toArray(new UploadedFile[0]);
    }

    /**
     * Equivalent to {@code getUploadedFiles(multiWrapper, true)}: failed uploads are skipped.
     *
     * @param multiWrapper the multipart request to read
     * @return a mutable list of the successful uploads
     * @throws FileUploadException if the multipart request reported errors
     */
    public static List<UploadedFile> getUploadedFiles(MultiPartRequestWrapper multiWrapper) throws FileUploadException {
        return getUploadedFiles(multiWrapper, true);
    }

    /**
     * Collects every uploaded file of every file field in the multipart request.
     *
     * @param multiWrapper the multipart request to read; must not be {@code null}
     * @param clean        {@code true} to silently skip entries whose upload failed; {@code false} to throw a
     *                     {@link FileUploadException} for the first such entry
     * @return a mutable list of uploads, grouped by field in the order of
     *         {@link MultiPartRequestWrapper#getFileParameterNames()}
     * @throws FileUploadException if the request reported errors, or if {@code clean} is {@code false}
     *                             and an upload failed
     */
    public static List<UploadedFile> getUploadedFiles(MultiPartRequestWrapper multiWrapper, boolean clean)
            throws FileUploadException {
        checkMultiPartRequestForErrors(multiWrapper);

        List<UploadedFile> uploadedFiles = new ArrayList<>();
        Enumeration<String> fieldNames = multiWrapper.getFileParameterNames();

        while (fieldNames.hasMoreElements()) {
            // name of the file input field
            String inputValue = fieldNames.nextElement();

            org.apache.struts2.dispatcher.multipart.UploadedFile[] files = multiWrapper.getFiles(inputValue);
            if (files == null || files.length == 0) {
                continue;
            }

            // Per-field metadata is indexed in parallel with `files`; fetch it once rather than per file.
            String[] fileNames = multiWrapper.getFileNames(inputValue);
            String[] contentTypes = multiWrapper.getContentTypes(inputValue);

            // support multiple upload controls with the same name
            for (int i = 0; i < files.length; i++) {
                org.apache.struts2.dispatcher.multipart.UploadedFile file = files[i];

                if (file == null) { // If it's null the upload failed
                    if (clean) {
                        continue;
                    }
                    throw uploadFailure(multiWrapper.getFileSystemNames(inputValue)[i]);
                }

                uploadedFiles.add(new UploadedFile(new File(file.getAbsolutePath()), fileNames[i], contentTypes[i]));
            }
        }
        return uploadedFiles;
    }

    /** Builds the exception reported for a single failed upload entry. */
    private static FileUploadException uploadFailure(String fileSystemName) {
        FileUploadException fileUploadException = new FileUploadException();
        fileUploadException.addError(new LocalizedMessage(FileUploadUtils.class,
                "struts.messages.error.uploading",
                "Error uploading " + fileSystemName,
                new Object[]{fileSystemName}));
        return fileUploadException;
    }

    /**
     * Peels {@link HttpServletRequestWrapper} layers off {@code request} until a {@link MultiPartRequestWrapper}
     * is found.
     *
     * @param request the request to inspect; may be {@code null}
     * @return the first {@link MultiPartRequestWrapper} in the wrapper chain, or {@code null} if there is none
     */
    public static MultiPartRequestWrapper unwrapMultiPartRequest(HttpServletRequest request) {
        ServletRequest servletRequest = request;
        while (servletRequest instanceof HttpServletRequestWrapper) {
            if (servletRequest instanceof MultiPartRequestWrapper) {
                return (MultiPartRequestWrapper) servletRequest;
            }
            servletRequest = ((HttpServletRequestWrapper) servletRequest).getRequest();
        }
        return null;
    }

    /**
     * Throws if the multipart request recorded any errors. The request should always be checked
     * before processing is done on it.
     *
     * @param multiWrapper the multipart request to check; must not be {@code null}
     * @throws FileUploadException carrying every error from {@link MultiPartRequestWrapper#getErrors()},
     *                             if {@link MultiPartRequestWrapper#hasErrors()} is {@code true}
     */
    public static void checkMultiPartRequestForErrors(MultiPartRequestWrapper multiWrapper) throws FileUploadException {
        if (!multiWrapper.hasErrors()) {
            return;
        }
        FileUploadException fileUploadException = new FileUploadException();
        multiWrapper.getErrors().forEach(fileUploadException::addError);
        throw fileUploadException;
    }

    /**
     * Immutable description of one uploaded file: its location on disk, the client-supplied
     * file name and the content type.
     */
    public static final class UploadedFile {
        private final File file;
        private final String fileName;
        private final String contentType;

        public UploadedFile(File file, String fileName, String contentType) {
            this.file = file;
            this.fileName = fileName;
            this.contentType = contentType;
        }

        /** @return the file holding the uploaded content */
        public File getFile() {
            return file;
        }

        /** @return the file name reported for this upload */
        public String getFileName() {
            return fileName;
        }

        /** @return the content type reported for this upload */
        public String getContentType() {
            return contentType;
        }
    }

    /**
     * Checked exception carrying one or more {@link LocalizedMessage} upload errors.
     * {@link #getMessage()} joins their default messages with {@code ", "}.
     */
    public static final class FileUploadException extends Exception {
        private final List<LocalizedMessage> errors = new ArrayList<>();

        /** Appends an error to this exception. */
        public void addError(LocalizedMessage error) {
            errors.add(error);
        }

        /**
         * @return the default (non-localized) message of each error, in insertion order
         * @deprecated since 1.1.0, use {@link #getErrorMsgs()} instead.
         */
        @Deprecated
        public String[] getErrors() {
            return errors.stream().map(LocalizedMessage::getDefaultMessage).toArray(String[]::new);
        }

        /**
         * Use a {@link com.opensymphony.xwork2.TextProvider Struts TextProvider} to localize these errors.
         * Alternatively, you may implement your own localization mechanism, or retrieve the default error message.
         *
         * @return an unmodifiable view of the errors, in insertion order
         */
        public List<LocalizedMessage> getErrorMsgs() {
            return unmodifiableList(errors);
        }

        @Override
        public String getMessage() {
            return errors.stream().map(LocalizedMessage::getDefaultMessage).collect(joining(", "));
        }
    }
}
