Skip to content

Commit 48176a8

Browse files
mp911dechristophstrobl
authored andcommitted
DATAMONGO-2393 - Fix BufferOverflow in GridFS upload.
AsyncInputStreamAdapter now properly splits and buffers incoming DataBuffers according the read requests of AsyncInputStream.read(…) calls. Previously, the adapter used the input buffer size to be used as the output buffer size. A larger DataBuffer than the transfer buffer handed in through read(…) caused a BufferOverflow. Original Pull Request: spring-projects#799
1 parent 0facdcf commit 48176a8

File tree

1 file changed

+103
-30
lines changed

1 file changed

+103
-30
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/gridfs/AsyncInputStreamAdapter.java

+103-30
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import lombok.RequiredArgsConstructor;
1919
import reactor.core.CoreSubscriber;
20+
import reactor.core.publisher.Flux;
21+
import reactor.core.publisher.FluxSink;
2022
import reactor.core.publisher.Mono;
2123
import reactor.core.publisher.Operators;
2224
import reactor.util.concurrent.Queues;
@@ -25,14 +27,15 @@
2527
import java.nio.ByteBuffer;
2628
import java.util.Queue;
2729
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
30+
import java.util.concurrent.atomic.AtomicLong;
2831
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
2932
import java.util.function.BiConsumer;
3033

3134
import org.reactivestreams.Publisher;
3235
import org.reactivestreams.Subscription;
36+
3337
import org.springframework.core.io.buffer.DataBuffer;
3438
import org.springframework.core.io.buffer.DataBufferUtils;
35-
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
3639

3740
import com.mongodb.reactivestreams.client.Success;
3841
import com.mongodb.reactivestreams.client.gridfs.AsyncInputStream;
@@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
6669

6770
private final Publisher<? extends DataBuffer> buffers;
6871
private final Context subscriberContext;
69-
private final DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
7072

7173
private volatile Subscription subscription;
7274
private volatile boolean cancelled;
73-
private volatile boolean complete;
75+
private volatile boolean allDataBuffersReceived;
7476
private volatile Throwable error;
7577
private final Queue<BiConsumer<DataBuffer, Integer>> readRequests = Queues.<BiConsumer<DataBuffer, Integer>> small()
7678
.get();
7779

80+
private final Queue<DataBuffer> bufferQueue = Queues.<DataBuffer> small().get();
81+
7882
// see DEMAND
7983
volatile long demand;
8084

@@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
8892
@Override
8993
public Publisher<Integer> read(ByteBuffer dst) {
9094

91-
return Mono.create(sink -> {
95+
return Flux.create(sink -> {
9296

97+
AtomicLong written = new AtomicLong();
9398
readRequests.offer((db, bytecount) -> {
9499

95100
try {
96101

97102
if (error != null) {
98-
99-
sink.error(error);
103+
onError(sink, error);
100104
return;
101105
}
102106

103107
if (bytecount == -1) {
104108

105-
sink.success(-1);
109+
onComplete(sink, written.get() > 0 ? written.intValue() : -1);
106110
return;
107111
}
108112

109113
ByteBuffer byteBuffer = db.asByteBuffer();
110-
int toWrite = byteBuffer.remaining();
114+
int remaining = byteBuffer.remaining();
115+
int writeCapacity = Math.min(dst.remaining(), remaining);
116+
int limit = Math.min(byteBuffer.position() + writeCapacity, byteBuffer.capacity());
117+
int toWrite = limit - byteBuffer.position();
118+
119+
if (toWrite == 0) {
111120

121+
onComplete(sink, written.intValue());
122+
return;
123+
}
124+
125+
int oldPosition = byteBuffer.position();
126+
127+
byteBuffer.limit(toWrite);
112128
dst.put(byteBuffer);
113-
sink.success(toWrite);
129+
byteBuffer.limit(byteBuffer.capacity());
130+
byteBuffer.position(oldPosition);
131+
db.readPosition(db.readPosition() + toWrite);
132+
written.addAndGet(toWrite);
114133

115134
} catch (Exception e) {
116-
sink.error(e);
135+
onError(sink, e);
117136
} finally {
118-
DataBufferUtils.release(db);
137+
138+
if (db != null && db.readableByteCount() == 0) {
139+
DataBufferUtils.release(db);
140+
}
119141
}
120142
});
121143

122-
request(1);
144+
sink.onCancel(this::terminatePendingReads);
145+
sink.onDispose(this::terminatePendingReads);
146+
sink.onRequest(this::request);
123147
});
124148
}
125149

150+
void onError(FluxSink<Integer> sink, Throwable e) {
151+
152+
readRequests.poll();
153+
sink.error(e);
154+
}
155+
156+
void onComplete(FluxSink<Integer> sink, int writtenBytes) {
157+
158+
readRequests.poll();
159+
DEMAND.decrementAndGet(this);
160+
sink.next(writtenBytes);
161+
sink.complete();
162+
}
163+
126164
/*
127165
* (non-Javadoc)
128166
* @see com.mongodb.reactivestreams.client.gridfs.AsyncInputStream#skip(long)
@@ -144,17 +182,19 @@ public Publisher<Success> close() {
144182
cancelled = true;
145183

146184
if (error != null) {
185+
terminatePendingReads();
147186
sink.error(error);
148187
return;
149188
}
150189

190+
terminatePendingReads();
151191
sink.success(Success.SUCCESS);
152192
});
153193
}
154194

155-
protected void request(int n) {
195+
protected void request(long n) {
156196

157-
if (complete) {
197+
if (allDataBuffersReceived && bufferQueue.isEmpty()) {
158198

159199
terminatePendingReads();
160200
return;
@@ -176,18 +216,51 @@ protected void request(int n) {
176216
requestFromSubscription(subscription);
177217
}
178218
}
219+
179220
}
180221

181222
void requestFromSubscription(Subscription subscription) {
182223

183-
long demand = DEMAND.get(AsyncInputStreamAdapter.this);
184-
185224
if (cancelled) {
186225
subscription.cancel();
187226
}
188227

189-
if (demand > 0 && DEMAND.compareAndSet(AsyncInputStreamAdapter.this, demand, demand - 1)) {
190-
subscription.request(1);
228+
drainLoop();
229+
}
230+
231+
void drainLoop() {
232+
233+
while (DEMAND.get(AsyncInputStreamAdapter.this) > 0) {
234+
235+
DataBuffer wip = bufferQueue.peek();
236+
237+
if (wip == null) {
238+
break;
239+
}
240+
241+
if (wip.readableByteCount() == 0) {
242+
bufferQueue.poll();
243+
continue;
244+
}
245+
246+
BiConsumer<DataBuffer, Integer> consumer = AsyncInputStreamAdapter.this.readRequests.peek();
247+
if (consumer == null) {
248+
break;
249+
}
250+
251+
consumer.accept(wip, wip.readableByteCount());
252+
}
253+
254+
if (bufferQueue.isEmpty()) {
255+
256+
if (allDataBuffersReceived) {
257+
terminatePendingReads();
258+
return;
259+
}
260+
261+
if (demand > 0) {
262+
subscription.request(1);
263+
}
191264
}
192265
}
193266

@@ -199,7 +272,7 @@ void terminatePendingReads() {
199272
BiConsumer<DataBuffer, Integer> readers;
200273

201274
while ((readers = readRequests.poll()) != null) {
202-
readers.accept(factory.wrap(new byte[0]), -1);
275+
readers.accept(null, -1);
203276
}
204277
}
205278

@@ -214,53 +287,53 @@ public Context currentContext() {
214287
public void onSubscribe(Subscription s) {
215288

216289
AsyncInputStreamAdapter.this.subscription = s;
217-
218-
Operators.addCap(DEMAND, AsyncInputStreamAdapter.this, -1);
219290
s.request(1);
220291
}
221292

222293
@Override
223294
public void onNext(DataBuffer dataBuffer) {
224295

225-
if (cancelled || complete) {
296+
if (cancelled || allDataBuffersReceived) {
226297
DataBufferUtils.release(dataBuffer);
227298
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
228299
return;
229300
}
230301

231-
BiConsumer<DataBuffer, Integer> poll = AsyncInputStreamAdapter.this.readRequests.poll();
302+
BiConsumer<DataBuffer, Integer> readRequest = AsyncInputStreamAdapter.this.readRequests.peek();
232303

233-
if (poll == null) {
304+
if (readRequest == null) {
234305

235306
DataBufferUtils.release(dataBuffer);
236307
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
237308
subscription.cancel();
238309
return;
239310
}
240311

241-
poll.accept(dataBuffer, dataBuffer.readableByteCount());
312+
bufferQueue.offer(dataBuffer);
242313

243-
requestFromSubscription(subscription);
314+
drainLoop();
244315
}
245316

246317
@Override
247318
public void onError(Throwable t) {
248319

249-
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.complete) {
320+
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.allDataBuffersReceived) {
250321
Operators.onErrorDropped(t, AsyncInputStreamAdapter.this.subscriberContext);
251322
return;
252323
}
253324

254325
AsyncInputStreamAdapter.this.error = t;
255-
AsyncInputStreamAdapter.this.complete = true;
326+
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
256327
terminatePendingReads();
257328
}
258329

259330
@Override
260331
public void onComplete() {
261332

262-
AsyncInputStreamAdapter.this.complete = true;
263-
terminatePendingReads();
333+
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
334+
if (bufferQueue.isEmpty()) {
335+
terminatePendingReads();
336+
}
264337
}
265338
}
266339
}

0 commit comments

Comments
 (0)