001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032import java.util.concurrent.CountDownLatch;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.SSLContext;
036import javax.net.ssl.SSLEngine;
037import javax.net.ssl.SSLEngineResult;
038import javax.net.ssl.SSLEngineResult.HandshakeStatus;
039import javax.net.ssl.SSLParameters;
040import javax.net.ssl.SSLPeerUnverifiedException;
041import javax.net.ssl.SSLSession;
042
043import org.apache.activemq.command.ConnectionInfo;
044import org.apache.activemq.openwire.OpenWireFormat;
045import org.apache.activemq.thread.TaskRunnerFactory;
046import org.apache.activemq.util.IOExceptionSupport;
047import org.apache.activemq.util.ServiceStopper;
048import org.apache.activemq.wireformat.WireFormat;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052public class NIOSSLTransport extends NIOTransport {
053
054    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
055
056    protected boolean needClientAuth;
057    protected boolean wantClientAuth;
058    protected String[] enabledCipherSuites;
059    protected String[] enabledProtocols;
060    protected boolean verifyHostName = false;
061
062    protected SSLContext sslContext;
063    protected SSLEngine sslEngine;
064    protected SSLSession sslSession;
065
066    protected volatile boolean handshakeInProgress = false;
067    protected SSLEngineResult.Status status = null;
068    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
069    protected TaskRunnerFactory taskRunnerFactory;
070
071    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
072        super(wireFormat, socketFactory, remoteLocation, localLocation);
073    }
074
075    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
076            ByteBuffer inputBuffer) throws IOException {
077        super(wireFormat, socket, initBuffer);
078        this.sslEngine = engine;
079        if (engine != null) {
080            this.sslSession = engine.getSession();
081        }
082        this.inputBuffer = inputBuffer;
083    }
084
085    public void setSslContext(SSLContext sslContext) {
086        this.sslContext = sslContext;
087    }
088
089    volatile boolean hasSslEngine = false;
090
091    @Override
092    protected void initializeStreams() throws IOException {
093        if (sslEngine != null) {
094            hasSslEngine = true;
095        }
096        NIOOutputStream outputStream = null;
097        try {
098            channel = socket.getChannel();
099            channel.configureBlocking(false);
100
101            if (sslContext == null) {
102                sslContext = SSLContext.getDefault();
103            }
104
105            String remoteHost = null;
106            int remotePort = -1;
107
108            try {
109                URI remoteAddress = new URI(this.getRemoteAddress());
110                remoteHost = remoteAddress.getHost();
111                remotePort = remoteAddress.getPort();
112            } catch (Exception e) {
113            }
114
115            // initialize engine, the initial sslSession we get will need to be
116            // updated once the ssl handshake process is completed.
117            if (!hasSslEngine) {
118                if (remoteHost != null && remotePort != -1) {
119                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
120                } else {
121                    sslEngine = sslContext.createSSLEngine();
122                }
123
124                if (verifyHostName) {
125                    SSLParameters sslParams = new SSLParameters();
126                    sslParams.setEndpointIdentificationAlgorithm("HTTPS");
127                    sslEngine.setSSLParameters(sslParams);
128                }
129
130                sslEngine.setUseClientMode(false);
131                if (enabledCipherSuites != null) {
132                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
133                }
134
135                if (enabledProtocols != null) {
136                    sslEngine.setEnabledProtocols(enabledProtocols);
137                }
138
139                if (wantClientAuth) {
140                    sslEngine.setWantClientAuth(wantClientAuth);
141                }
142
143                if (needClientAuth) {
144                    sslEngine.setNeedClientAuth(needClientAuth);
145                }
146
147                sslSession = sslEngine.getSession();
148
149                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
150                inputBuffer.clear();
151            }
152
153            outputStream = new NIOOutputStream(channel);
154            outputStream.setEngine(sslEngine);
155            this.dataOut = new DataOutputStream(outputStream);
156            this.buffOut = outputStream;
157
158            //If the sslEngine was not passed in, then handshake
159            if (!hasSslEngine) {
160                sslEngine.beginHandshake();
161            }
162            handshakeStatus = sslEngine.getHandshakeStatus();
163            if (!hasSslEngine) {
164                doHandshake();
165            }
166
167            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
168                @Override
169                public void onSelect(SelectorSelection selection) {
170                    try {
171                        initialized.await();
172                    } catch (InterruptedException error) {
173                        onException(IOExceptionSupport.create(error));
174                    }
175                    serviceRead();
176                }
177
178                @Override
179                public void onError(SelectorSelection selection, Throwable error) {
180                    if (error instanceof IOException) {
181                        onException((IOException) error);
182                    } else {
183                        onException(IOExceptionSupport.create(error));
184                    }
185                }
186            });
187            doInit();
188
189        } catch (Exception e) {
190            try {
191                if(outputStream != null) {
192                    outputStream.close();
193                }
194                super.closeStreams();
195            } catch (Exception ex) {}
196            throw new IOException(e);
197        }
198    }
199
200    final protected CountDownLatch initialized = new CountDownLatch(1);
201
202    protected void doInit() throws Exception {
203        taskRunnerFactory.execute(new Runnable() {
204
205            @Override
206            public void run() {
207                //Need to start in new thread to let startup finish first
208                //We can trigger a read because we know the channel is ready since the SSL handshake
209                //already happened
210                serviceRead();
211                initialized.countDown();
212            }
213        });
214    }
215
216    //Only used for the auto transport to abort the openwire init method early if already initialized
217    boolean openWireInititialized = false;
218
219    protected void doOpenWireInit() throws Exception {
220        //Do this later to let wire format negotiation happen
221        if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) {
222            initBuffer.buffer.flip();
223            if (initBuffer.buffer.hasRemaining()) {
224                nextFrameSize = -1;
225                receiveCounter += initBuffer.readSize;
226                processCommand(initBuffer.buffer);
227                processCommand(initBuffer.buffer);
228                initBuffer.buffer.clear();
229                openWireInititialized = true;
230            }
231        }
232    }
233
234    protected void finishHandshake() throws Exception {
235        if (handshakeInProgress) {
236            handshakeInProgress = false;
237            nextFrameSize = -1;
238
239            // Once handshake completes we need to ask for the now real sslSession
240            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
241            // cipher suite.
242            sslSession = sslEngine.getSession();
243        }
244    }
245
246    //Prevent concurrent access to SSLEngine
247    @Override
248    public synchronized void serviceRead() {
249        try {
250            if (handshakeInProgress) {
251                doHandshake();
252            }
253
254            doOpenWireInit();
255
256            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
257            plain.position(plain.limit());
258
259            while (true) {
260                if (!plain.hasRemaining()) {
261
262                    int readCount = secureRead(plain);
263
264                    if (readCount == 0) {
265                        break;
266                    }
267
268                    // channel is closed, cleanup
269                    if (readCount == -1) {
270                        onException(new EOFException());
271                        selection.close();
272                        break;
273                    }
274
275                    receiveCounter += readCount;
276                }
277
278                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
279                    processCommand(plain);
280                }
281            }
282        } catch (IOException e) {
283            onException(e);
284        } catch (Throwable e) {
285            onException(IOExceptionSupport.create(e));
286        }
287    }
288
289    protected void processCommand(ByteBuffer plain) throws Exception {
290
291        // Are we waiting for the next Command or are we building on the current one
292        if (nextFrameSize == -1) {
293
294            // We can get small packets that don't give us enough for the frame size
295            // so allocate enough for the initial size value and
296            if (plain.remaining() < Integer.SIZE) {
297                if (currentBuffer == null) {
298                    currentBuffer = ByteBuffer.allocate(4);
299                }
300
301                // Go until we fill the integer sized current buffer.
302                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
303                    currentBuffer.put(plain.get());
304                }
305
306                // Didn't we get enough yet to figure out next frame size.
307                if (currentBuffer.hasRemaining()) {
308                    return;
309                } else {
310                    currentBuffer.flip();
311                    nextFrameSize = currentBuffer.getInt();
312                }
313
314            } else {
315
316                // Either we are completing a previous read of the next frame size or its
317                // fully contained in plain already.
318                if (currentBuffer != null) {
319
320                    // Finish the frame size integer read and get from the current buffer.
321                    while (currentBuffer.hasRemaining()) {
322                        currentBuffer.put(plain.get());
323                    }
324
325                    currentBuffer.flip();
326                    nextFrameSize = currentBuffer.getInt();
327
328                } else {
329                    nextFrameSize = plain.getInt();
330                }
331            }
332
333            if (wireFormat instanceof OpenWireFormat) {
334                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
335                if (nextFrameSize > maxFrameSize) {
336                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
337                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
338                }
339            }
340
341            // now we got the data, lets reallocate and store the size for the marshaler.
342            // if there's more data in plain, then the next call will start processing it.
343            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
344            currentBuffer.putInt(nextFrameSize);
345
346        } else {
347            // If its all in one read then we can just take it all, otherwise take only
348            // the current frame size and the next iteration starts a new command.
349            if (currentBuffer != null) {
350                if (currentBuffer.remaining() >= plain.remaining()) {
351                    currentBuffer.put(plain);
352                } else {
353                    byte[] fill = new byte[currentBuffer.remaining()];
354                    plain.get(fill);
355                    currentBuffer.put(fill);
356                }
357
358                // Either we have enough data for a new command or we have to wait for some more.
359                if (currentBuffer.hasRemaining()) {
360                    return;
361                } else {
362                    currentBuffer.flip();
363                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
364                    doConsume(command);
365                    nextFrameSize = -1;
366                    currentBuffer = null;
367               }
368            }
369        }
370    }
371
372    protected int secureRead(ByteBuffer plain) throws Exception {
373
374        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
375            int bytesRead = channel.read(inputBuffer);
376
377            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
378                return 0;
379            }
380
381            if (bytesRead == -1) {
382                sslEngine.closeInbound();
383                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
384                    return -1;
385                }
386            }
387        }
388
389        plain.clear();
390
391        inputBuffer.flip();
392        SSLEngineResult res;
393        do {
394            res = sslEngine.unwrap(inputBuffer, plain);
395        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
396                && res.bytesProduced() == 0);
397
398        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
399            finishHandshake();
400        }
401
402        status = res.getStatus();
403        handshakeStatus = res.getHandshakeStatus();
404
405        // TODO deal with BUFFER_OVERFLOW
406
407        if (status == SSLEngineResult.Status.CLOSED) {
408            sslEngine.closeInbound();
409            return -1;
410        }
411
412        inputBuffer.compact();
413        plain.flip();
414
415        return plain.remaining();
416    }
417
418    protected void doHandshake() throws Exception {
419        handshakeInProgress = true;
420        Selector selector = null;
421        SelectionKey key = null;
422        boolean readable = true;
423        try {
424            while (true) {
425                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
426                switch (handshakeStatus) {
427                    case NEED_UNWRAP:
428                        if (readable) {
429                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
430                        }
431                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
432                            long now = System.currentTimeMillis();
433                            if (selector == null) {
434                                selector = Selector.open();
435                                key = channel.register(selector, SelectionKey.OP_READ);
436                            } else {
437                                key.interestOps(SelectionKey.OP_READ);
438                            }
439                            int keyCount = selector.select(this.getSoTimeout());
440                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
441                                throw new SocketTimeoutException("Timeout during handshake");
442                            }
443                            readable = key.isReadable();
444                        }
445                        break;
446                    case NEED_TASK:
447                        Runnable task;
448                        while ((task = sslEngine.getDelegatedTask()) != null) {
449                            task.run();
450                        }
451                        break;
452                    case NEED_WRAP:
453                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
454                        break;
455                    case FINISHED:
456                    case NOT_HANDSHAKING:
457                        finishHandshake();
458                        return;
459                }
460            }
461        } finally {
462            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
463            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
464        }
465    }
466
467    @Override
468    protected void doStart() throws Exception {
469        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
470        // no need to init as we can delay that until demand (eg in doHandshake)
471        super.doStart();
472    }
473
474    @Override
475    protected void doStop(ServiceStopper stopper) throws Exception {
476        initialized.countDown();
477
478        if (taskRunnerFactory != null) {
479            taskRunnerFactory.shutdownNow();
480            taskRunnerFactory = null;
481        }
482        if (channel != null) {
483            channel.close();
484            channel = null;
485        }
486        super.doStop(stopper);
487    }
488
489    /**
490     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
491     *
492     * @param command
493     *            The Command coming in.
494     */
495    @Override
496    public void doConsume(Object command) {
497        if (command instanceof ConnectionInfo) {
498            ConnectionInfo connectionInfo = (ConnectionInfo) command;
499            connectionInfo.setTransportContext(getPeerCertificates());
500        }
501        super.doConsume(command);
502    }
503
504    /**
505     * @return peer certificate chain associated with the ssl socket
506     */
507    @Override
508    public X509Certificate[] getPeerCertificates() {
509
510        X509Certificate[] clientCertChain = null;
511        try {
512            if (sslEngine.getSession() != null) {
513                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
514            }
515        } catch (SSLPeerUnverifiedException e) {
516            if (LOG.isTraceEnabled()) {
517                LOG.trace("Failed to get peer certificates.", e);
518            }
519        }
520
521        return clientCertChain;
522    }
523
524    public boolean isNeedClientAuth() {
525        return needClientAuth;
526    }
527
528    public void setNeedClientAuth(boolean needClientAuth) {
529        this.needClientAuth = needClientAuth;
530    }
531
532    public boolean isWantClientAuth() {
533        return wantClientAuth;
534    }
535
536    public void setWantClientAuth(boolean wantClientAuth) {
537        this.wantClientAuth = wantClientAuth;
538    }
539
540    public String[] getEnabledCipherSuites() {
541        return enabledCipherSuites;
542    }
543
544    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
545        this.enabledCipherSuites = enabledCipherSuites;
546    }
547
548    public String[] getEnabledProtocols() {
549        return enabledProtocols;
550    }
551
552    public void setEnabledProtocols(String[] enabledProtocols) {
553        this.enabledProtocols = enabledProtocols;
554    }
555
556    public boolean isVerifyHostName() {
557        return verifyHostName;
558    }
559
560    public void setVerifyHostName(boolean verifyHostName) {
561        this.verifyHostName = verifyHostName;
562    }
563}