001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;

import org.apache.activemq.command.ConnectionInfo;
import org.apache.activemq.openwire.OpenWireFormat;
import org.apache.activemq.thread.TaskRunnerFactory;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
051
052public class NIOSSLTransport extends NIOTransport {
053
054    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
055
056    protected boolean needClientAuth;
057    protected boolean wantClientAuth;
058    protected String[] enabledCipherSuites;
059    protected String[] enabledProtocols;
060    protected boolean verifyHostName = false;
061
062    protected SSLContext sslContext;
063    protected SSLEngine sslEngine;
064    protected SSLSession sslSession;
065
066    protected volatile boolean handshakeInProgress = false;
067    protected SSLEngineResult.Status status = null;
068    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
069    protected TaskRunnerFactory taskRunnerFactory;
070
071    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
072        super(wireFormat, socketFactory, remoteLocation, localLocation);
073    }
074
075    public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
076            ByteBuffer inputBuffer) throws IOException {
077        super(wireFormat, socket, initBuffer);
078        this.sslEngine = engine;
079        if (engine != null) {
080            this.sslSession = engine.getSession();
081        }
082        this.inputBuffer = inputBuffer;
083    }
084
085    public void setSslContext(SSLContext sslContext) {
086        this.sslContext = sslContext;
087    }
088
089    volatile boolean hasSslEngine = false;
090
091    @Override
092    protected void initializeStreams() throws IOException {
093        if (sslEngine != null) {
094            hasSslEngine = true;
095        }
096        NIOOutputStream outputStream = null;
097        try {
098            channel = socket.getChannel();
099            channel.configureBlocking(false);
100
101            if (sslContext == null) {
102                sslContext = SSLContext.getDefault();
103            }
104
105            String remoteHost = null;
106            int remotePort = -1;
107
108            try {
109                URI remoteAddress = new URI(this.getRemoteAddress());
110                remoteHost = remoteAddress.getHost();
111                remotePort = remoteAddress.getPort();
112            } catch (Exception e) {
113            }
114
115            // initialize engine, the initial sslSession we get will need to be
116            // updated once the ssl handshake process is completed.
117            if (!hasSslEngine) {
118                if (remoteHost != null && remotePort != -1) {
119                    sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
120                } else {
121                    sslEngine = sslContext.createSSLEngine();
122                }
123
124                if (verifyHostName) {
125                    SSLParameters sslParams = new SSLParameters();
126                    sslParams.setEndpointIdentificationAlgorithm("HTTPS");
127                    sslEngine.setSSLParameters(sslParams);
128                }
129
130                sslEngine.setUseClientMode(false);
131                if (enabledCipherSuites != null) {
132                    sslEngine.setEnabledCipherSuites(enabledCipherSuites);
133                }
134
135                if (enabledProtocols != null) {
136                    sslEngine.setEnabledProtocols(enabledProtocols);
137                }
138
139                if (wantClientAuth) {
140                    sslEngine.setWantClientAuth(wantClientAuth);
141                }
142
143                if (needClientAuth) {
144                    sslEngine.setNeedClientAuth(needClientAuth);
145                }
146
147                sslSession = sslEngine.getSession();
148
149                inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
150                inputBuffer.clear();
151            }
152
153            outputStream = new NIOOutputStream(channel);
154            outputStream.setEngine(sslEngine);
155            this.dataOut = new DataOutputStream(outputStream);
156            this.buffOut = outputStream;
157
158            //If the sslEngine was not passed in, then handshake
159            if (!hasSslEngine) {
160                sslEngine.beginHandshake();
161            }
162            handshakeStatus = sslEngine.getHandshakeStatus();
163            if (!hasSslEngine) {
164                doHandshake();
165            }
166
167            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
168                @Override
169                public void onSelect(SelectorSelection selection) {
170                    try {
171                        initialized.await();
172                    } catch (InterruptedException error) {
173                        onException(IOExceptionSupport.create(error));
174                    }
175                    serviceRead();
176                }
177
178                @Override
179                public void onError(SelectorSelection selection, Throwable error) {
180                    if (error instanceof IOException) {
181                        onException((IOException) error);
182                    } else {
183                        onException(IOExceptionSupport.create(error));
184                    }
185                }
186            });
187            doInit();
188
189        } catch (Exception e) {
190            try {
191                if(outputStream != null) {
192                    outputStream.close();
193                }
194                super.closeStreams();
195            } catch (Exception ex) {}
196            throw new IOException(e);
197        }
198    }
199
200    final protected CountDownLatch initialized = new CountDownLatch(1);
201
202    protected void doInit() throws Exception {
203        taskRunnerFactory.execute(new Runnable() {
204
205            @Override
206            public void run() {
207                //Need to start in new thread to let startup finish first
208                //We can trigger a read because we know the channel is ready since the SSL handshake
209                //already happened
210                serviceRead();
211                initialized.countDown();
212            }
213        });
214    }
215
216    //Only used for the auto transport to abort the openwire init method early if already initialized
217    boolean openWireInititialized = false;
218
219    protected void doOpenWireInit() throws Exception {
220        //Do this later to let wire format negotiation happen
221        if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) {
222            initBuffer.buffer.flip();
223            if (initBuffer.buffer.hasRemaining()) {
224                nextFrameSize = -1;
225                receiveCounter += initBuffer.readSize;
226                processCommand(initBuffer.buffer);
227                processCommand(initBuffer.buffer);
228                initBuffer.buffer.clear();
229                openWireInititialized = true;
230            }
231        }
232    }
233
234    protected void finishHandshake() throws Exception {
235        if (handshakeInProgress) {
236            handshakeInProgress = false;
237            nextFrameSize = -1;
238
239            // Once handshake completes we need to ask for the now real sslSession
240            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
241            // cipher suite.
242            sslSession = sslEngine.getSession();
243        }
244    }
245
246    @Override
247    public void serviceRead() {
248        try {
249            if (handshakeInProgress) {
250                doHandshake();
251            }
252
253            doOpenWireInit();
254
255            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
256            plain.position(plain.limit());
257
258            while (true) {
259                //If the transport was already stopped then break
260                if (this.isStopped()) {
261                    return;
262                }
263
264                if (!plain.hasRemaining()) {
265
266                    int readCount = secureRead(plain);
267
268                    if (readCount == 0) {
269                        break;
270                    }
271
272                    // channel is closed, cleanup
273                    if (readCount == -1) {
274                        onException(new EOFException());
275                        selection.close();
276                        break;
277                    }
278
279                    receiveCounter += readCount;
280                }
281
282                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
283                    processCommand(plain);
284                }
285            }
286        } catch (IOException e) {
287            onException(e);
288        } catch (Throwable e) {
289            onException(IOExceptionSupport.create(e));
290        }
291    }
292
293    protected void processCommand(ByteBuffer plain) throws Exception {
294
295        // Are we waiting for the next Command or are we building on the current one
296        if (nextFrameSize == -1) {
297
298            // We can get small packets that don't give us enough for the frame size
299            // so allocate enough for the initial size value and
300            if (plain.remaining() < Integer.SIZE) {
301                if (currentBuffer == null) {
302                    currentBuffer = ByteBuffer.allocate(4);
303                }
304
305                // Go until we fill the integer sized current buffer.
306                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
307                    currentBuffer.put(plain.get());
308                }
309
310                // Didn't we get enough yet to figure out next frame size.
311                if (currentBuffer.hasRemaining()) {
312                    return;
313                } else {
314                    currentBuffer.flip();
315                    nextFrameSize = currentBuffer.getInt();
316                }
317
318            } else {
319
320                // Either we are completing a previous read of the next frame size or its
321                // fully contained in plain already.
322                if (currentBuffer != null) {
323
324                    // Finish the frame size integer read and get from the current buffer.
325                    while (currentBuffer.hasRemaining()) {
326                        currentBuffer.put(plain.get());
327                    }
328
329                    currentBuffer.flip();
330                    nextFrameSize = currentBuffer.getInt();
331
332                } else {
333                    nextFrameSize = plain.getInt();
334                }
335            }
336
337            if (wireFormat instanceof OpenWireFormat) {
338                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
339                if (nextFrameSize > maxFrameSize) {
340                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
341                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
342                }
343            }
344
345            // now we got the data, lets reallocate and store the size for the marshaler.
346            // if there's more data in plain, then the next call will start processing it.
347            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
348            currentBuffer.putInt(nextFrameSize);
349
350        } else {
351            // If its all in one read then we can just take it all, otherwise take only
352            // the current frame size and the next iteration starts a new command.
353            if (currentBuffer != null) {
354                if (currentBuffer.remaining() >= plain.remaining()) {
355                    currentBuffer.put(plain);
356                } else {
357                    byte[] fill = new byte[currentBuffer.remaining()];
358                    plain.get(fill);
359                    currentBuffer.put(fill);
360                }
361
362                // Either we have enough data for a new command or we have to wait for some more.
363                if (currentBuffer.hasRemaining()) {
364                    return;
365                } else {
366                    currentBuffer.flip();
367                    Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
368                    doConsume(command);
369                    nextFrameSize = -1;
370                    currentBuffer = null;
371               }
372            }
373        }
374    }
375
376    //Prevent concurrent access while reading from the channel
377    protected synchronized int secureRead(ByteBuffer plain) throws Exception {
378
379        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
380            int bytesRead = channel.read(inputBuffer);
381
382            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
383                return 0;
384            }
385
386            if (bytesRead == -1) {
387                sslEngine.closeInbound();
388                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
389                    return -1;
390                }
391            }
392        }
393
394        plain.clear();
395
396        inputBuffer.flip();
397        SSLEngineResult res;
398        do {
399            res = sslEngine.unwrap(inputBuffer, plain);
400        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
401                && res.bytesProduced() == 0);
402
403        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
404            finishHandshake();
405        }
406
407        status = res.getStatus();
408        handshakeStatus = res.getHandshakeStatus();
409
410        // TODO deal with BUFFER_OVERFLOW
411
412        if (status == SSLEngineResult.Status.CLOSED) {
413            sslEngine.closeInbound();
414            return -1;
415        }
416
417        inputBuffer.compact();
418        plain.flip();
419
420        return plain.remaining();
421    }
422
423    protected void doHandshake() throws Exception {
424        handshakeInProgress = true;
425        Selector selector = null;
426        SelectionKey key = null;
427        boolean readable = true;
428        try {
429            while (true) {
430                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
431                switch (handshakeStatus) {
432                    case NEED_UNWRAP:
433                        if (readable) {
434                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
435                        }
436                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
437                            long now = System.currentTimeMillis();
438                            if (selector == null) {
439                                selector = Selector.open();
440                                key = channel.register(selector, SelectionKey.OP_READ);
441                            } else {
442                                key.interestOps(SelectionKey.OP_READ);
443                            }
444                            int keyCount = selector.select(this.getSoTimeout());
445                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
446                                throw new SocketTimeoutException("Timeout during handshake");
447                            }
448                            readable = key.isReadable();
449                        }
450                        break;
451                    case NEED_TASK:
452                        Runnable task;
453                        while ((task = sslEngine.getDelegatedTask()) != null) {
454                            task.run();
455                        }
456                        break;
457                    case NEED_WRAP:
458                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
459                        break;
460                    case FINISHED:
461                    case NOT_HANDSHAKING:
462                        finishHandshake();
463                        return;
464                }
465            }
466        } finally {
467            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
468            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
469        }
470    }
471
472    @Override
473    protected void doStart() throws Exception {
474        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
475        // no need to init as we can delay that until demand (eg in doHandshake)
476        super.doStart();
477    }
478
479    @Override
480    protected void doStop(ServiceStopper stopper) throws Exception {
481        initialized.countDown();
482
483        if (taskRunnerFactory != null) {
484            taskRunnerFactory.shutdownNow();
485            taskRunnerFactory = null;
486        }
487        if (channel != null) {
488            channel.close();
489            channel = null;
490        }
491        super.doStop(stopper);
492    }
493
494    /**
495     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
496     *
497     * @param command
498     *            The Command coming in.
499     */
500    @Override
501    public void doConsume(Object command) {
502        if (command instanceof ConnectionInfo) {
503            ConnectionInfo connectionInfo = (ConnectionInfo) command;
504            connectionInfo.setTransportContext(getPeerCertificates());
505        }
506        super.doConsume(command);
507    }
508
509    /**
510     * @return peer certificate chain associated with the ssl socket
511     */
512    @Override
513    public X509Certificate[] getPeerCertificates() {
514
515        X509Certificate[] clientCertChain = null;
516        try {
517            if (sslEngine.getSession() != null) {
518                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
519            }
520        } catch (SSLPeerUnverifiedException e) {
521            if (LOG.isTraceEnabled()) {
522                LOG.trace("Failed to get peer certificates.", e);
523            }
524        }
525
526        return clientCertChain;
527    }
528
529    public boolean isNeedClientAuth() {
530        return needClientAuth;
531    }
532
533    public void setNeedClientAuth(boolean needClientAuth) {
534        this.needClientAuth = needClientAuth;
535    }
536
537    public boolean isWantClientAuth() {
538        return wantClientAuth;
539    }
540
541    public void setWantClientAuth(boolean wantClientAuth) {
542        this.wantClientAuth = wantClientAuth;
543    }
544
545    public String[] getEnabledCipherSuites() {
546        return enabledCipherSuites;
547    }
548
549    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
550        this.enabledCipherSuites = enabledCipherSuites;
551    }
552
553    public String[] getEnabledProtocols() {
554        return enabledProtocols;
555    }
556
557    public void setEnabledProtocols(String[] enabledProtocols) {
558        this.enabledProtocols = enabledProtocols;
559    }
560
561    public boolean isVerifyHostName() {
562        return verifyHostName;
563    }
564
565    public void setVerifyHostName(boolean verifyHostName) {
566        this.verifyHostName = verifyHostName;
567    }
568}