Skip to content

Commit 0f04591

Browse files
authored
Merge pull request #224 from pontusmelke/1.0-tls-write-issue
Better handling of BUFFER_OVERFLOW on writes
2 parents 1ac0038 + feeee03 commit 0f04591

File tree

4 files changed

+303
-121
lines changed

4 files changed

+303
-121
lines changed

driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
2929
import javax.net.ssl.SSLEngineResult.Status;
3030

31-
import org.neo4j.driver.v1.Logger;
3231
import org.neo4j.driver.internal.util.BytePrinter;
3332
import org.neo4j.driver.v1.Config.TrustStrategy;
33+
import org.neo4j.driver.v1.Logger;
3434
import org.neo4j.driver.v1.exceptions.ClientException;
3535

3636
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
@@ -163,8 +163,8 @@ private HandshakeStatus runDelegatedTasks()
163163
* To verify if deciphering is done successfully, we could check if any bytes has been read into {@code buffer},
164164
* as the deciphered bytes will only be saved into {@code buffer} when deciphering is carried out successfully.
165165
*
166-
* @param buffer
167-
* @return
166+
* @param buffer to read data into.
167+
* @return The status of the current handshake.
168168
* @throws IOException
169169
*/
170170
private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
@@ -261,7 +261,7 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
261261
* a loop
262262
*
263263
* @param buffer contains the bytes to send to channel
264-
* @return
264+
* @return The status of the current handshake
265265
* @throws IOException
266266
*/
267267
private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
@@ -277,7 +277,7 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
277277
case OK:
278278
handshakeStatus = runDelegatedTasks();
279279
cipherOut.flip();
280-
while(cipherOut.hasRemaining())
280+
while ( cipherOut.hasRemaining() )
281281
{
282282
channel.write( cipherOut );
283283
}
@@ -287,19 +287,30 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
287287
// Enlarge the buffer and return the old status
288288
int curNetSize = cipherOut.capacity();
289289
int netSize = sslEngine.getSession().getPacketBufferSize();
290-
if ( curNetSize >= netSize || buffer.capacity() > netSize )
290+
if ( netSize > curNetSize )
291291
{
292-
// TODO
293-
throw new ClientException(
294-
String.format( "Failed to enlarge network buffer from %s to %s. This is either because the " +
295-
"new size is however less than the old size, or because the application " +
296-
"buffer size %s is so big that the application data still cannot fit into the " +
297-
"new network buffer.", curNetSize, netSize, buffer.capacity() ) );
292+
// enlarge the peer application data buffer
293+
cipherOut = ByteBuffer.allocate( netSize );
294+
logger.debug( "Enlarged network output buffer from %s to %s. " +
295+
"This operation should be a rare operation.", curNetSize, netSize );
296+
}
297+
else
298+
{
299+
// flush as much data as possible
300+
cipherOut.flip();
301+
int written = channel.write( cipherOut );
302+
if (written == 0)
303+
{
304+
throw new ClientException(
305+
String.format(
306+
"Failed to enlarge network buffer from %s to %s. This is either because the " +
307+
"new size is however less than the old size, or because the application " +
308+
"buffer size %s is so big that the application data still cannot fit into the " +
309+
"new network buffer.", curNetSize, netSize, buffer.capacity() ) );
310+
}
311+
cipherOut.compact();
312+
logger.debug( "Network output buffer couldn't be enlarged, flushing data to the channel instead." );
298313
}
299-
300-
cipherOut = ByteBuffer.allocate( netSize );
301-
logger.debug( "Enlarged network output buffer from %s to %s. " +
302-
"This operation should be a rare operation.", curNetSize, netSize );
303314
break;
304315
default:
305316
throw new ClientException( "Got unexpected status " + status );
@@ -320,9 +331,9 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
320331
* After the method call, the new position of {@code from.pos} will be {@code from.pos + p}, and similarly,
321332
* the new position of {@code to.pos} will be {@code to.pos + p}
322333
*
323-
* @param from
324-
* @param to
325-
* @return
334+
* @param from buffer to copy from
335+
* @param to buffer to copy to
336+
* @return the number of transferred bytes
326337
*/
327338
static int bufferCopy( ByteBuffer from, ByteBuffer to )
328339
{
@@ -341,9 +352,9 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to )
341352

342353
/**
343354
* Create SSLEngine with the SSLContext just created.
344-
* @param host
345-
* @param port
346-
* @param sslContext
355+
* @param host the host to connect to
356+
* @param port the port to connect to
357+
* @param sslContext the current ssl context
347358
*/
348359
private static SSLEngine createSSLEngine( String host, int port, SSLContext sslContext )
349360
{
@@ -431,10 +442,10 @@ public void close() throws IOException
431442
channel.close();
432443
logger.debug( "TLS connection closed" );
433444
}
434-
catch(IOException e)
445+
catch ( IOException e )
435446
{
436447
// Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly.
437-
logger.warn( "TLS socket could not be closed cleanly: '"+e.getMessage()+"'", e );
448+
logger.warn( "TLS socket could not be closed cleanly: '" + e.getMessage() + "'", e );
438449
}
439450
}
440451

driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java renamed to driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java

Lines changed: 9 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,8 @@
2222
import org.junit.Test;
2323

2424
import java.io.IOException;
25-
import java.io.OutputStream;
26-
import java.net.InetSocketAddress;
27-
import java.net.ServerSocket;
28-
import java.net.Socket;
2925
import java.nio.ByteBuffer;
3026
import java.nio.channels.ByteChannel;
31-
import java.nio.channels.SocketChannel;
3227
import java.security.GeneralSecurityException;
3328
import java.security.KeyManagementException;
3429
import java.security.KeyStore;
@@ -39,35 +34,12 @@
3934
import java.security.cert.X509Certificate;
4035
import javax.net.ssl.KeyManagerFactory;
4136
import javax.net.ssl.SSLContext;
42-
import javax.net.ssl.SSLEngine;
43-
import javax.net.ssl.SSLServerSocketFactory;
4437
import javax.net.ssl.TrustManager;
4538
import javax.net.ssl.X509TrustManager;
4639

47-
import org.neo4j.driver.internal.connector.socket.TLSSocketChannel;
48-
import org.neo4j.driver.internal.logging.DevNullLogger;
49-
50-
import static org.hamcrest.core.IsEqual.equalTo;
51-
import static org.junit.Assert.assertThat;
52-
53-
/**
54-
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
55-
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
56-
* for the following variables:
57-
*
58-
* - Network frame size
59-
* - Bolt message size
60-
* - Read buffer size
61-
*
62-
* It tests every possible combination, and it does this currently only for the read path, expanding
63-
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
64-
* handshake, transferring the data, and verifying the data is correct after decryption.
65-
*/
66-
public class TLSSocketChannelFragmentationIT
40+
public abstract class TLSSocketChannelFragmentation
6741
{
68-
private SSLContext sslCtx;
69-
private byte[] blobOfData;
70-
private ServerSocket server;
42+
protected SSLContext sslCtx;
7143

7244
@Before
7345
public void setup() throws Throwable
@@ -76,17 +48,6 @@ public void setup() throws Throwable
7648
createServer();
7749
}
7850

79-
private void blobOfDataSize( int dataBlobSize )
80-
{
81-
blobOfData = new byte[dataBlobSize];
82-
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
83-
// fill the data blob with different values.
84-
for ( int i = 0; i < blobOfData.length; i++ )
85-
{
86-
blobOfData[i] = (byte) (i % 128);
87-
}
88-
}
89-
9051
@Test
9152
public void shouldHandleFuzziness() throws Throwable
9253
{
@@ -109,34 +70,9 @@ public void shouldHandleFuzziness() throws Throwable
10970
}
11071
}
11172

112-
private void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException
113-
{
114-
blobOfDataSize(blobOfDataSize);
115-
SSLEngine engine = sslCtx.createSSLEngine();
116-
engine.setUseClientMode( true );
117-
ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) );
118-
ch = new LittleAtATimeChannel( ch, networkFrameSize );
119-
120-
TLSSocketChannel channel = new TLSSocketChannel(ch, new DevNullLogger(), engine);
121-
try
122-
{
123-
ByteBuffer readBuffer = ByteBuffer.allocate( blobOfData.length );
124-
while ( readBuffer.position() < readBuffer.capacity() )
125-
{
126-
readBuffer.limit(Math.min( readBuffer.capacity(), readBuffer.position() + userBufferSize ));
127-
channel.read( readBuffer );
128-
}
129-
130-
assertThat(readBuffer.array(), equalTo(blobOfData));
131-
}
132-
finally
133-
{
134-
channel.close();
135-
}
136-
}
137-
138-
private void createSSLContext()
139-
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException, KeyManagementException
73+
protected void createSSLContext()
74+
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException,
75+
UnrecoverableKeyException, KeyManagementException
14076
{
14177
KeyStore ks = KeyStore.getInstance("JKS");
14278
char[] password = "password".toCharArray();
@@ -159,40 +95,16 @@ public X509Certificate[] getAcceptedIssuers() {
15995
}}, null );
16096
}
16197

162-
private void createServer() throws IOException
163-
{
164-
SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory();
165-
server = ssf.createServerSocket(0);
98+
protected abstract void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException,
99+
GeneralSecurityException;
166100

167-
new Thread(new Runnable()
168-
{
169-
@Override
170-
public void run()
171-
{
172-
try
173-
{
174-
while(true)
175-
{
176-
Socket client = server.accept();
177-
OutputStream outputStream = client.getOutputStream();
178-
outputStream.write( blobOfData );
179-
outputStream.flush();
180-
// client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event
181-
}
182-
}
183-
catch ( IOException e )
184-
{
185-
e.printStackTrace();
186-
}
187-
}
188-
}).start();
189-
}
101+
protected abstract void createServer() throws IOException;
190102

191103
/**
192104
* Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate
193105
* different network frame sizes in this test.
194106
*/
195-
private static class LittleAtATimeChannel implements ByteChannel
107+
protected static class LittleAtATimeChannel implements ByteChannel
196108
{
197109
private final ByteChannel delegate;
198110
private final int maxFrameSize;
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/**
2+
* Copyright (c) 2002-2016 "Neo Technology,"
3+
* Network Engine for Objects in Lund AB [http://neotechnology.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
package org.neo4j.driver.v1.integration;
20+
21+
import java.io.IOException;
22+
import java.io.OutputStream;
23+
import java.net.InetSocketAddress;
24+
import java.net.ServerSocket;
25+
import java.net.Socket;
26+
import java.nio.ByteBuffer;
27+
import java.nio.channels.ByteChannel;
28+
import java.nio.channels.SocketChannel;
29+
import java.security.GeneralSecurityException;
30+
import javax.net.ssl.SSLEngine;
31+
import javax.net.ssl.SSLServerSocketFactory;
32+
33+
import org.neo4j.driver.internal.connector.socket.TLSSocketChannel;
34+
import org.neo4j.driver.internal.logging.DevNullLogger;
35+
36+
import static org.hamcrest.core.IsEqual.equalTo;
37+
import static org.junit.Assert.assertThat;
38+
39+
/**
40+
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
41+
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
42+
* for the following variables:
43+
*
44+
* - Network frame size
45+
* - Bolt message size
46+
* - Read buffer size
47+
*
48+
* It tests every possible combination, and it does this currently only for the read path, expanding
49+
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
50+
* handshake, transferring the data, and verifying the data is correct after decryption.
51+
*/
52+
public class TLSSocketChannelReadFragmentationIT extends TLSSocketChannelFragmentation
53+
{
54+
private byte[] blobOfData;
55+
private ServerSocket server;
56+
57+
58+
59+
private void blobOfDataSize( int dataBlobSize )
60+
{
61+
blobOfData = new byte[dataBlobSize];
62+
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
63+
// fill the data blob with different values.
64+
for ( int i = 0; i < blobOfData.length; i++ )
65+
{
66+
blobOfData[i] = (byte) (i % 128);
67+
}
68+
}
69+
70+
protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException
71+
{
72+
blobOfDataSize(blobOfDataSize);
73+
SSLEngine engine = sslCtx.createSSLEngine();
74+
engine.setUseClientMode( true );
75+
ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) );
76+
ch = new LittleAtATimeChannel( ch, networkFrameSize );
77+
78+
try ( TLSSocketChannel channel = new TLSSocketChannel( ch, new DevNullLogger(), engine ) )
79+
{
80+
ByteBuffer readBuffer = ByteBuffer.allocate( blobOfData.length );
81+
while ( readBuffer.position() < readBuffer.capacity() )
82+
{
83+
readBuffer.limit( Math.min( readBuffer.capacity(), readBuffer.position() + userBufferSize ) );
84+
channel.read( readBuffer );
85+
}
86+
87+
assertThat( readBuffer.array(), equalTo( blobOfData ) );
88+
}
89+
}
90+
91+
protected void createServer() throws IOException
92+
{
93+
SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory();
94+
server = ssf.createServerSocket(0);
95+
96+
new Thread(new Runnable()
97+
{
98+
@Override
99+
public void run()
100+
{
101+
try
102+
{
103+
//noinspection InfiniteLoopStatement
104+
while(true)
105+
{
106+
Socket client = server.accept();
107+
OutputStream outputStream = client.getOutputStream();
108+
outputStream.write( blobOfData );
109+
outputStream.flush();
110+
// client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event
111+
}
112+
}
113+
catch ( IOException e )
114+
{
115+
e.printStackTrace();
116+
}
117+
}
118+
}).start();
119+
}
120+
121+
}

0 commit comments

Comments
 (0)