MultipartUploadTest.java

/*
 * Copyright (c) 2010-2012 Sonatype, Inc. All rights reserved.
 *
 * This program is licensed to you under the Apache License Version 2.0,
 * and you may not use this file except in compliance with the Apache License Version 2.0.
 * You may obtain a copy of the Apache License Version 2.0 at http://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the Apache License Version 2.0 is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the Apache License Version 2.0 for the specific language governing permissions and limitations there under.
 */
package org.asynchttpclient.request.body.multipart;

import io.github.artsok.RepeatedIfExceptionsTest;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.fileupload2.FileItemIterator;
import org.apache.commons.fileupload2.FileItemStream;
import org.apache.commons.fileupload2.FileUploadException;
import org.apache.commons.fileupload2.jaksrvlt.JakSrvltFileUpload;
import org.apache.commons.fileupload2.util.Streams;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.asynchttpclient.AbstractBasicTest;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.Request;
import org.asynchttpclient.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.jupiter.api.BeforeEach;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Writer;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.zip.GZIPInputStream;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.asynchttpclient.Dsl.asyncHttpClient;
import static org.asynchttpclient.Dsl.config;
import static org.asynchttpclient.Dsl.post;
import static org.asynchttpclient.test.TestUtils.addHttpConnector;
import static org.asynchttpclient.test.TestUtils.getClasspathFile;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * @author dominict
 */
public class MultipartUploadTest extends AbstractBasicTest {

    /**
     * Creates and starts a new Jetty {@link Server} that serves a
     * {@link MockMultipartUploadServlet} under {@code /upload}, then stores the
     * connector's local port in {@code port1} for the tests to target.
     *
     * @throws Exception if the server cannot be started
     */
    @BeforeEach
    public void setUp() throws Exception {
        server = new Server();
        ServerConnector httpConnector = addHttpConnector(server);

        ServletContextHandler uploadContext = new ServletContextHandler(ServletContextHandler.SESSIONS);
        ServletHolder uploadServlet = new ServletHolder(new MockMultipartUploadServlet());
        uploadContext.addServlet(uploadServlet, "/upload");

        server.setHandler(uploadContext);
        server.start();

        // The local port is read only after start(), as in the original sequence.
        port1 = httpConnector.getLocalPort();
    }

    /**
     * Posts one multipart request that mixes {@link FilePart}, {@link StringPart},
     * {@link InputStreamPart} and {@link ByteArrayPart} bodies, asserts an HTTP 200 response and
     * then hands the response to {@code testSentFile} to compare every file reported by the
     * servlet with the source it was built from.
     *
     * @throws Exception if a resource cannot be read or the request fails
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendingSmallFilesAndByteArray() throws Exception {
        String expectedContents = "filecontent: hello";
        String expectedContents2 = "gzipcontent: hello";
        String expectedContents3 = "filecontent: hello2";

        File testResource1File = getClasspathFile("textfile.txt");
        File testResource2File = getClasspathFile("gzip.txt.gz");
        File testResource3File = getClasspathFile("textfile2.txt");

        // On-disk twin of the payload sent through the ByteArrayPart, so that part can be
        // verified with the same file-based comparison as the others.
        File tmpFile = File.createTempFile("textbytearray", ".txt");
        try {
            Files.write(tmpFile.toPath(), expectedContents.getBytes(UTF_8));

            // The three lists are parallel and follow the order in which the file-carrying
            // parts are added to the request below:
            // file1, file2, file3, inputStream3, inputStream2, file4 (byte array), inputStream1.
            List<File> testFiles = Arrays.asList(
                    testResource1File, testResource2File, testResource3File,
                    testResource3File, testResource2File, tmpFile, testResource1File);
            List<String> expected = Arrays.asList(
                    expectedContents, expectedContents2, expectedContents3,
                    expectedContents3, expectedContents2, expectedContents, expectedContents);
            List<Boolean> gzipped = Arrays.asList(false, true, false, false, true, false, false);

            // The streams are declared before the client so that they are closed after it,
            // and are released even when the request or an assertion fails.
            try (InputStream inputStreamFile1 = new BufferedInputStream(new FileInputStream(testResource1File));
                 InputStream inputStreamFile2 = new BufferedInputStream(new FileInputStream(testResource2File));
                 InputStream inputStreamFile3 = new BufferedInputStream(new FileInputStream(testResource3File));
                 AsyncHttpClient c = asyncHttpClient(config())) {
                Request r = post("http://localhost" + ':' + port1 + "/upload")
                        .addBodyPart(new FilePart("file1", testResource1File, "text/plain", UTF_8))
                        .addBodyPart(new FilePart("file2", testResource2File, "application/x-gzip", null))
                        .addBodyPart(new StringPart("Name", "Dominic"))
                        .addBodyPart(new FilePart("file3", testResource3File, "text/plain", UTF_8))
                        .addBodyPart(new StringPart("Age", "3")).addBodyPart(new StringPart("Height", "shrimplike"))
                        .addBodyPart(new InputStreamPart("inputStream3", inputStreamFile3, testResource3File.getName(), testResource3File.length(), "text/plain", UTF_8))
                        .addBodyPart(new InputStreamPart("inputStream2", inputStreamFile2, testResource2File.getName(), testResource2File.length(), "application/x-gzip", null))
                        .addBodyPart(new StringPart("Hair", "ridiculous")).addBodyPart(new ByteArrayPart("file4",
                                expectedContents.getBytes(UTF_8), "text/plain", UTF_8, "bytearray.txt"))
                        .addBodyPart(new InputStreamPart("inputStream1", inputStreamFile1, testResource1File.getName(), testResource1File.length(), "text/plain", UTF_8))
                        .build();

                Response res = c.executeRequest(r).get();

                assertEquals(200, res.getStatusCode());

                testSentFile(expected, testFiles, res, gzipped);
            }
        } finally {
            // Best-effort cleanup that must not mask a test failure.
            if (!tmpFile.delete()) {
                tmpFile.deleteOnExit();
            }
        }
    }

    /**
     * Uploads the classpath resource {@code empty.txt} as a single {@link FilePart} to the
     * {@code /upload} endpoint and asserts that the response status is 200.
     *
     * @param disableZeroCopy value handed to {@code setDisableZeroCopy} on the client configuration
     * @throws Exception if the request fails or the client cannot be closed
     */
    private void sendEmptyFile0(boolean disableZeroCopy) throws Exception {
        File file = getClasspathFile("empty.txt");
        try (AsyncHttpClient client = asyncHttpClient(config().setDisableZeroCopy(disableZeroCopy))) {
            Request r = post("http://localhost" + ':' + port1 + "/upload")
                    .addBodyPart(new FilePart("file", file, "text/plain", UTF_8)).build();

            Response res = client.executeRequest(r).get();
            // assertEquals takes (expected, actual); keeping that order makes a failure report
            // "expected: 200" rather than treating the server's status as the expectation.
            assertEquals(200, res.getStatusCode());
        }
    }

    /**
     * Uploads the empty file as a file part with zero-copy disabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void sendEmptyFile() throws Exception {
        final boolean disableZeroCopy = true;
        sendEmptyFile0(disableZeroCopy);
    }

    /**
     * Uploads the empty file as a file part with zero-copy left enabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void sendEmptyFileZeroCopy() throws Exception {
        final boolean disableZeroCopy = false;
        sendEmptyFile0(disableZeroCopy);
    }

    /**
     * Uploads the classpath resource {@code empty.txt} through an {@link InputStreamPart} whose
     * content length is taken from the file, and asserts that the response status is 200.
     *
     * @param disableZeroCopy value handed to {@code setDisableZeroCopy} on the client configuration
     * @throws Exception if the file cannot be opened or the request fails
     */
    private void sendEmptyFileInputStream(boolean disableZeroCopy) throws Exception {
        File emptyFile = getClasspathFile("empty.txt");
        try (AsyncHttpClient client = asyncHttpClient(config().setDisableZeroCopy(disableZeroCopy));
             InputStream in = new BufferedInputStream(new FileInputStream(emptyFile))) {
            InputStreamPart streamPart =
                    new InputStreamPart("file", in, emptyFile.getName(), emptyFile.length(), "text/plain", UTF_8);
            Request request = post("http://localhost" + ':' + port1 + "/upload")
                    .addBodyPart(streamPart)
                    .build();

            Response response = client.executeRequest(request).get();
            assertEquals(200, response.getStatusCode());
        }
    }

    /**
     * Streams the empty file as an input stream part with zero-copy disabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendEmptyFileInputStream() throws Exception {
        final boolean disableZeroCopy = true;
        sendEmptyFileInputStream(disableZeroCopy);
    }

    /**
     * Streams the empty file as an input stream part with zero-copy left enabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendEmptyFileInputStreamZeroCopy() throws Exception {
        final boolean disableZeroCopy = false;
        sendEmptyFileInputStream(disableZeroCopy);
    }

    /**
     * Uploads the classpath resource {@code textfile.txt} through an {@link InputStreamPart} and
     * asserts that the response status is 200. If the request future fails, the cause's stack
     * trace is printed and the {@link ExecutionException} is rethrown.
     *
     * @param useContentLength when {@code true} the part is created with the file's length,
     *                         otherwise the constructor without a length argument is used
     * @param disableZeroCopy  value handed to {@code setDisableZeroCopy} on the client configuration
     * @throws Exception if the file cannot be opened or the request fails
     */
    private void sendFileInputStream(boolean useContentLength, boolean disableZeroCopy) throws Exception {
        File textFile = getClasspathFile("textfile.txt");
        try (AsyncHttpClient client = asyncHttpClient(config().setDisableZeroCopy(disableZeroCopy));
             InputStream in = new BufferedInputStream(new FileInputStream(textFile))) {

            InputStreamPart streamPart = useContentLength
                    ? new InputStreamPart("file", in, textFile.getName(), textFile.length())
                    : new InputStreamPart("file", in, textFile.getName());
            Request request = post("http://localhost" + ':' + port1 + "/upload")
                    .addBodyPart(streamPart)
                    .build();

            Response response = client.executeRequest(request).get();
            assertEquals(200, response.getStatusCode());
        } catch (ExecutionException failure) {
            failure.getCause().printStackTrace();
            throw failure;
        }
    }

    /**
     * Streams the text file without passing a content length, with zero-copy disabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendFileInputStreamUnknownContentLength() throws Exception {
        final boolean useContentLength = false;
        final boolean disableZeroCopy = true;
        sendFileInputStream(useContentLength, disableZeroCopy);
    }

    /**
     * Streams the text file without passing a content length, with zero-copy left enabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendFileInputStreamZeroCopyUnknownContentLength() throws Exception {
        final boolean useContentLength = false;
        final boolean disableZeroCopy = false;
        sendFileInputStream(useContentLength, disableZeroCopy);
    }

    /**
     * Streams the text file with its length passed to the part, with zero-copy disabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendFileInputStreamKnownContentLength() throws Exception {
        final boolean useContentLength = true;
        final boolean disableZeroCopy = true;
        sendFileInputStream(useContentLength, disableZeroCopy);
    }

    /**
     * Streams the text file with its length passed to the part, with zero-copy left enabled.
     */
    @RepeatedIfExceptionsTest(repeats = 5)
    public void testSendFileInputStreamZeroCopyKnownContentLength() throws Exception {
        final boolean useContentLength = true;
        final boolean disableZeroCopy = false;
        sendFileInputStream(useContentLength, disableZeroCopy);
    }

    /**
     * Verifies the servlet's response against the files that were uploaded.
     * <p>
     * The response body is split on {@code ||}; the second segment is parsed as a bracketed,
     * comma-separated list of file paths. The file at each path is compared byte-for-byte with
     * the source file at the same index, and its text (gunzipped first when flagged) is compared
     * with the expected contents at that index. Each response file is deleted once checked.
     *
     * @param expectedContents expected UTF-8 text per file, after gunzipping where flagged
     * @param sourceFiles      uploaded files, in the order the response lists them
     * @param r                response returned for the upload request
     * @param deflate          per-file flag; {@code true} means the file is gzip-compressed
     * @throws IOException if a source or response file cannot be read
     */
    private static void testSentFile(List<String> expectedContents, List<File> sourceFiles, Response r,
                                     List<Boolean> deflate) throws IOException {
        String content = r.getResponseBody();
        assertNotNull(content);
        logger.debug(content);

        String[] contentArray = content.split("\\|\\|");
        // TODO: this fails on win32
        assertEquals(2, contentArray.length);

        String tmpFiles = contentArray[1];
        assertNotNull(tmpFiles);
        assertTrue(tmpFiles.trim().length() > 2);
        // Drop the first and last character, i.e. the brackets surrounding the path list.
        tmpFiles = tmpFiles.substring(1, tmpFiles.length() - 1);

        String[] responseFiles = tmpFiles.split(",");
        assertNotNull(responseFiles);
        assertEquals(sourceFiles.size(), responseFiles.length);

        logger.debug(Arrays.toString(responseFiles));

        for (int i = 0; i < sourceFiles.size(); i++) {
            byte[] sourceBytes = Files.readAllBytes(sourceFiles.get(i).toPath());
            if (logger.isDebugEnabled()) {
                logger.debug("================");
                logger.debug("Length of file: {}", sourceBytes.length);
                logger.debug("Contents: {}", Arrays.toString(sourceBytes));
                logger.debug("================");
            }

            File tmp = new File(responseFiles[i].trim());
            try {
                logger.debug("==============================");
                logger.debug(tmp.getAbsolutePath());
                logger.debug("==============================");
                assertTrue(tmp.exists());

                byte[] bytes = Files.readAllBytes(tmp.toPath());
                assertArrayEquals(sourceBytes, bytes);

                // Decode with an explicit charset so the result does not depend on the
                // platform default encoding.
                String actualContents = deflate.get(i) ? gunzipToString(tmp) : new String(bytes, UTF_8);
                assertEquals(expectedContents.get(i), actualContents);
            } finally {
                FileUtils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Reads a gzip-compressed file and returns its decompressed contents decoded as UTF-8.
     */
    private static String gunzipToString(File gzipFile) throws IOException {
        ByteArrayOutputStream decompressed = new ByteArrayOutputStream();
        try (InputStream raw = Files.newInputStream(gzipFile.toPath());
             GZIPInputStream gzip = new GZIPInputStream(raw)) {
            byte[] buf = new byte[8192];
            int len;
            while ((len = gzip.read(buf)) != -1) {
                decompressed.write(buf, 0, len);
            }
        }
        return new String(decompressed.toByteArray(), UTF_8);
    }


    /**
     * Test servlet that consumes multipart requests. Every non-form-field part is copied into a
     * temporary file, and the response body is the number of files processed, followed by
     * {@code ||} and, for multipart requests, the {@code toString()} of the list of temporary
     * file paths. Both counters are reset while the response is written.
     */
    public static class MockMultipartUploadServlet extends HttpServlet {

        private static final Logger LOGGER = LoggerFactory.getLogger(MockMultipartUploadServlet.class);

        private static final long serialVersionUID = 1L;
        private int filesProcessed;
        private int stringsProcessed;

        MockMultipartUploadServlet() {
            stringsProcessed = 0;
        }

        synchronized void resetFilesProcessed() {
            filesProcessed = 0;
        }

        private synchronized int incrementFilesProcessed() {
            return ++filesProcessed;
        }

        // Synchronized like the mutators so a reader observes the latest written value.
        synchronized int getFilesProcessed() {
            return filesProcessed;
        }

        synchronized void resetStringsProcessed() {
            stringsProcessed = 0;
        }

        private synchronized int incrementStringsProcessed() {
            return ++stringsProcessed;
        }

        // Synchronized like the mutators so a reader observes the latest written value.
        public synchronized int getStringsProcessed() {
            return stringsProcessed;
        }

        /**
         * Handles a request: multipart content is parsed part by part, anything else is
         * answered with the counter summary only.
         *
         * @param request  the incoming request
         * @param response the response the summary is written to
         * @throws IOException if reading a part, writing a temporary file or writing the
         *                     response fails
         */
        @Override
        public void service(HttpServletRequest request, HttpServletResponse response) throws IOException {
            // Check that we have a file upload request
            if (!JakSrvltFileUpload.isMultipartContent(request)) {
                writeSummary(response, "");
                return;
            }

            List<String> files = new ArrayList<>();
            JakSrvltFileUpload upload = new JakSrvltFileUpload();
            try {
                // Parse the request
                FileItemIterator iter = upload.getItemIterator(request);
                while (iter.hasNext()) {
                    FileItemStream item = iter.next();
                    String name = item.getFieldName();
                    try (InputStream stream = item.openStream()) {
                        if (item.isFormField()) {
                            LOGGER.debug("Form field {} with value {} detected.", name, Streams.asString(stream));
                            incrementStringsProcessed();
                        } else {
                            LOGGER.debug("File field {} with file name {} detected.", name, item.getName());
                            File stored = storeInTempFile(stream);
                            incrementFilesProcessed();
                            files.add(stored.getAbsolutePath());
                        }
                    }
                }
            } catch (FileUploadException e) {
                // Still answer with whatever was processed, but leave a trace of the failure
                // instead of dropping it silently.
                LOGGER.error("Failed to parse multipart request", e);
            }
            writeSummary(response, files.toString());
        }

        /**
         * Copies the stream into a newly created temporary file that is registered for deletion
         * on JVM exit, and returns that file.
         */
        private static File storeInTempFile(InputStream stream) throws IOException {
            File tmpFile = File.createTempFile(UUID.randomUUID() + "_MockUploadServlet", ".tmp");
            tmpFile.deleteOnExit();
            try (OutputStream os = Files.newOutputStream(tmpFile.toPath())) {
                byte[] buffer = new byte[4096];
                int bytesRead;
                while ((bytesRead = stream.read(buffer)) != -1) {
                    os.write(buffer, 0, bytesRead);
                }
            }
            return tmpFile;
        }

        /**
         * Writes the files-processed count, resets both counters, then writes {@code ||}
         * followed by the given file list text.
         */
        private void writeSummary(HttpServletResponse response, String fileList) throws IOException {
            try (Writer w = response.getWriter()) {
                w.write(Integer.toString(getFilesProcessed()));
                resetFilesProcessed();
                resetStringsProcessed();
                w.write("||");
                w.write(fileList);
            }
        }
    }
}