001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019
020package org.apache.shiro.web.filter;
021
022import org.apache.shiro.lang.util.StringUtils;
023import org.apache.shiro.web.util.WebUtils;
024
025import javax.servlet.ServletRequest;
026import javax.servlet.ServletResponse;
027import javax.servlet.http.HttpServletRequest;
028import javax.servlet.http.HttpServletResponse;
029import java.util.Arrays;
030import java.util.Collections;
031import java.util.List;
032import java.util.stream.Stream;
033
034@SuppressWarnings("checkstyle:LineLength")
035/**
036 * A request filter that blocks malicious requests. Invalid request will respond with a 400 response code.
037 * <p>
038 * This filter checks and blocks the request if the following characters are found in the request URI:
039 * <ul>
040 *     <li>Semicolon - can be disabled by setting {@code blockSemicolon = false}</li>
041 *     <li>Backslash - can be disabled by setting {@code blockBackslash = false}</li>
042 *     <li>Non-ASCII characters - can be disabled by setting {@code blockNonAscii = false},
043 *          the ability to disable this check will be removed in future version.</li>
044 *     <li>Path traversals - can be disabled by setting {@code blockTraversal = false}</li>
045 * </ul>
046 *
047 * @see <a href="https://docs.spring.io/spring-security/site/docs/current/api/org/springframework/security/web/firewall/StrictHttpFirewall.html">
048 * This class was inspired by Spring Security StrictHttpFirewall</a>
049 * @since 1.6
050 */
051public class InvalidRequestFilter extends AccessControlFilter {
052
053    private static final List<String> SEMICOLON = Collections.unmodifiableList(Arrays.asList(";", "%3b", "%3B"));
054
055    private static final List<String> BACKSLASH = Collections.unmodifiableList(Arrays.asList("\\", "%5c", "%5C"));
056
057    private static final List<String> FORWARDSLASH = Collections.unmodifiableList(Arrays.asList("%2f", "%2F"));
058
059    private static final List<String> PERIOD = Collections.unmodifiableList(Arrays.asList("%2e", "%2E"));
060
061    private boolean blockSemicolon = true;
062
063    private boolean blockBackslash = !WebUtils.isAllowBackslash();
064
065    private boolean blockNonAscii = true;
066
067    private boolean blockTraversal = true;
068
069    private boolean blockEncodedPeriod = true;
070
071    private boolean blockEncodedForwardSlash = true;
072
073    private boolean blockRewriteTraversal = true;
074
075    @Override
076    protected boolean isAccessAllowed(ServletRequest req, ServletResponse response, Object mappedValue) throws Exception {
077        HttpServletRequest request = WebUtils.toHttp(req);
078        // check the original and decoded values
079        // user request string (not decoded)
080        return isValid(request.getRequestURI())
081                // decoded servlet part
082                && isValid(request.getServletPath())
083                // decoded path info (may be null)
084                && isValid(request.getPathInfo());
085    }
086
087    @SuppressWarnings("checkstyle:BooleanExpressionComplexity")
088    private boolean isValid(String uri) {
089        return !StringUtils.hasText(uri)
090               || (!containsSemicolon(uri)
091                 && !containsBackslash(uri)
092                 && !containsNonAsciiCharacters(uri)
093                 && !containsTraversal(uri)
094                 && !containsEncodedPeriods(uri)
095                 && !containsEncodedForwardSlash(uri));
096    }
097
098    @Override
099    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
100        WebUtils.toHttp(response).sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid request");
101        return false;
102    }
103
104    private boolean containsSemicolon(String uri) {
105        if (isBlockSemicolon()) {
106            return SEMICOLON.stream().anyMatch(uri::contains);
107        }
108        return false;
109    }
110
111    private boolean containsBackslash(String uri) {
112        if (isBlockBackslash()) {
113            return BACKSLASH.stream().anyMatch(uri::contains);
114        }
115        return false;
116    }
117
118    private boolean containsNonAsciiCharacters(String uri) {
119        if (isBlockNonAscii()) {
120            return !containsOnlyPrintableAsciiCharacters(uri);
121        }
122        return false;
123    }
124
125    private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
126        int length = uri.length();
127        for (int i = 0; i < length; i++) {
128            char c = uri.charAt(i);
129            if (c < '\u0020' || c > '\u007e') {
130                return false;
131            }
132        }
133        return true;
134    }
135
136    private boolean containsTraversal(String uri) {
137        if (isBlockTraversal()) {
138            return !isNormalized(uri)
139                || (isBlockRewriteTraversal() && Stream.of("/..;", "/.;").anyMatch(uri::contains));
140        }
141        return false;
142    }
143
144    private boolean containsEncodedPeriods(String uri) {
145        if (isBlockEncodedPeriod()) {
146            return PERIOD.stream().anyMatch(uri::contains);
147        }
148        return false;
149    }
150
151    private boolean containsEncodedForwardSlash(String uri) {
152        if (isBlockEncodedForwardSlash()) {
153            return FORWARDSLASH.stream().anyMatch(uri::contains);
154        }
155        return false;
156    }
157
158    /**
159     * Checks whether a path is normalized (doesn't contain path traversal sequences like
160     * "./", "/../" or "/.")
161     *
162     * @param path the path to test
163     * @return true if the path doesn't contain any path-traversal character sequences.
164     */
165    private boolean isNormalized(String path) {
166        if (path == null) {
167            return true;
168        }
169        for (int i = path.length(); i > 0; ) {
170            int slashIndex = path.lastIndexOf('/', i - 1);
171            int gap = i - slashIndex;
172            if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
173                // ".", "/./" or "/."
174                return false;
175            }
176            if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
177                return false;
178            }
179            i = slashIndex;
180        }
181        return true;
182    }
183
184    public boolean isBlockSemicolon() {
185        return blockSemicolon;
186    }
187
188    public void setBlockSemicolon(boolean blockSemicolon) {
189        this.blockSemicolon = blockSemicolon;
190    }
191
192    public boolean isBlockBackslash() {
193        return blockBackslash;
194    }
195
196    public void setBlockBackslash(boolean blockBackslash) {
197        this.blockBackslash = blockBackslash;
198    }
199
200    public boolean isBlockNonAscii() {
201        return blockNonAscii;
202    }
203
204    public void setBlockNonAscii(boolean blockNonAscii) {
205        this.blockNonAscii = blockNonAscii;
206    }
207
208    public boolean isBlockTraversal() {
209        return blockTraversal;
210    }
211
212    public void setBlockTraversal(boolean blockTraversal) {
213        this.blockTraversal = blockTraversal;
214    }
215
216    public boolean isBlockEncodedPeriod() {
217        return blockEncodedPeriod;
218    }
219
220    public void setBlockEncodedPeriod(boolean blockEncodedPeriod) {
221        this.blockEncodedPeriod = blockEncodedPeriod;
222    }
223
224    public boolean isBlockEncodedForwardSlash() {
225        return blockEncodedForwardSlash;
226    }
227
228    public void setBlockEncodedForwardSlash(boolean blockEncodedForwardSlash) {
229        this.blockEncodedForwardSlash = blockEncodedForwardSlash;
230    }
231
232    public boolean isBlockRewriteTraversal() {
233        return blockRewriteTraversal;
234    }
235
236    public void setBlockRewriteTraversal(boolean blockRewriteTraversal) {
237        this.blockRewriteTraversal = blockRewriteTraversal;
238    }
239}