From ed12c5a0d2293743dcf12c26e0dced97fd5f057f Mon Sep 17 00:00:00 2001 From: jojii Date: Wed, 3 Sep 2025 16:54:04 +0200 Subject: [PATCH 1/8] Add support for connection pooling --- src/channel_pool.rs | 90 +++++++++++++++++++++++++++++-------- src/client/mod.rs | 1 + src/qdrant_client/config.rs | 4 ++ src/qdrant_client/mod.rs | 2 + 4 files changed, 79 insertions(+), 18 deletions(-) diff --git a/src/channel_pool.rs b/src/channel_pool.rs index 7722502d..23f9320e 100644 --- a/src/channel_pool.rs +++ b/src/channel_pool.rs @@ -1,4 +1,5 @@ use std::future::Future; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::Duration; @@ -6,11 +7,13 @@ use tonic::transport::{Channel, ClientTlsConfig, Uri}; use tonic::{Code, Status}; pub struct ChannelPool { - channel: RwLock>, + channels: RwLock>>, + channel_index: AtomicU64, uri: Uri, grpc_timeout: Duration, connection_timeout: Duration, keep_alive_while_idle: bool, + pool_size: usize, } impl ChannelPool { @@ -19,17 +22,24 @@ impl ChannelPool { grpc_timeout: Duration, connection_timeout: Duration, keep_alive_while_idle: bool, + mut pool_size: usize, ) -> Self { + // Ensure `pool_size` is always >= 1 + pool_size = std::cmp::max(pool_size, 1); + Self { - channel: RwLock::new(None), + channels: RwLock::new(vec![None; pool_size]), + channel_index: AtomicU64::new(0), uri, grpc_timeout, connection_timeout, keep_alive_while_idle, + pool_size, } } - async fn make_channel(&self) -> Result { + /// Creates a new channel at the given index. If one already exists, it will be dropped and replaced. + async fn make_channel(&self, channel_index: usize) -> Result { let tls = match self.uri.scheme_str() { None => false, Some(schema) => match schema { @@ -62,29 +72,40 @@ impl ChannelPool { endpoint }; - let channel = endpoint + let new_channel = endpoint .connect() .await .map_err(|e| Status::internal(format!("Failed to connect to {}: {:?}", self.uri, e)))?; - let mut self_channel = self.channel.write().unwrap(); - *self_channel = Some(channel.clone()); + let mut pool_channels = self.channels.write().unwrap(); + + pool_channels[channel_index] = Some(new_channel.clone()); - Ok(channel) + Ok(new_channel) } - async fn get_channel(&self) -> Result { - if let Some(channel) = &*self.channel.read().unwrap() { - return Ok(channel.clone()); + /// Returns a channel from the pool. If `pool_size` > 1, calls will return different channels in a round-robin way. + /// Otherwise, the same channel is returned each time. + async fn get_channel(&self) -> Result<(Channel, usize), Status> { + let channel_index = self.next_channel_index(); + + if let Some(channel) = self + .channels + .read() + .unwrap() + .get(channel_index) + .and_then(|i| i.as_ref()) + { + return Ok((channel.clone(), channel_index)); } - let channel = self.make_channel().await?; - Ok(channel) + Ok((self.make_channel(channel_index).await?, channel_index)) } - pub async fn drop_channel(&self) { - let mut channel = self.channel.write().unwrap(); - *channel = None; + /// Drops the channel at the given index. + fn drop_channel(&self, idx: usize) { + let mut channel = self.channels.write().unwrap(); + channel[idx] = None; } // Allow to retry request if channel is broken @@ -93,7 +114,7 @@ impl ChannelPool { f: impl Fn(Channel) -> O, allow_retry: bool, ) -> Result { - let channel = self.get_channel().await?; + let (channel, channel_index) = self.get_channel().await?; let result: Result = f(channel).await; @@ -102,11 +123,12 @@ impl ChannelPool { Ok(res) => Ok(res), Err(err) => match err.code() { Code::Internal | Code::Unavailable | Code::Cancelled | Code::Unknown => { - self.drop_channel().await; if allow_retry { - let channel = self.get_channel().await?; + // Recreate the channel at the given index. + let channel = self.make_channel(channel_index).await?; Ok(f(channel).await?) } else { + self.drop_channel(channel_index); Err(err) } } @@ -114,6 +136,11 @@ impl ChannelPool { }, } } + + /// Returns the index for the next channel to use. + fn next_channel_index(&self) -> usize { + self.channel_index.fetch_add(1, Ordering::Relaxed) as usize % self.pool_size + } } // The future returned by get_channel needs to be Send so that the client can be @@ -127,9 +154,36 @@ fn require_get_channel_fn_to_be_send() { Duration::from_millis(0), Duration::from_millis(0), false, + 2, ) .get_channel() .await .expect("get channel should not error"); }); } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_channel_counter() { + let channel = ChannelPool::new( + Uri::from_static("http://localhost:6444"), + Duration::default(), + Duration::default(), + false, + 5, + ); + + assert_eq!(channel.next_channel_index(), 0); + assert_eq!(channel.next_channel_index(), 1); + assert_eq!(channel.next_channel_index(), 2); + assert_eq!(channel.next_channel_index(), 3); + assert_eq!(channel.next_channel_index(), 4); + assert_eq!(channel.next_channel_index(), 0); + assert_eq!(channel.next_channel_index(), 1); + + assert_eq!(channel.channels.read().unwrap().len(), 5); + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 4425873f..eeb45b38 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -65,6 +65,7 @@ impl QdrantClient { cfg.timeout, cfg.connect_timeout, cfg.keep_alive_while_idle, + 1, ); let client = Self { channel, cfg }; diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index 9ed54810..a195e14d 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -38,6 +38,9 @@ pub struct QdrantConfig { /// Whether to check compatibility between the client and server versions pub check_compatibility: bool, + + /// Amount of concurrent connections. + pub pool_size: usize, } impl QdrantConfig { @@ -193,6 +196,7 @@ impl Default for QdrantConfig { api_key: None, compression: None, check_compatibility: true, + pool_size: 1, } } } diff --git a/src/qdrant_client/mod.rs b/src/qdrant_client/mod.rs index 25a20dce..8ceb04a4 100644 --- a/src/qdrant_client/mod.rs +++ b/src/qdrant_client/mod.rs @@ -107,6 +107,7 @@ impl Qdrant { config.timeout, config.connect_timeout, config.keep_alive_while_idle, + config.pool_size, ); let client = Self { channel: Arc::new(channel), @@ -151,6 +152,7 @@ impl Qdrant { config.timeout, config.connect_timeout, config.keep_alive_while_idle, + config.pool_size, ); let client = Self { From 8c9b6cd8f1f3c3c8f63f9d453b6f2eb7014186c4 Mon Sep 17 00:00:00 2001 From: jojii Date: Thu, 4 Sep 2025 08:05:28 +0200 Subject: [PATCH 2/8] Clippy --- src/error.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/error.rs b/src/error.rs index 4421b2ee..962f2b31 100644 --- a/src/error.rs +++ b/src/error.rs @@ -60,7 +60,7 @@ impl Debug for NotA { impl Display for NotA { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(concat!("not a bool")) + f.write_str("not a bool") } } @@ -87,7 +87,7 @@ impl Debug for NotA { impl Display for NotA { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(concat!("not an i64")) + f.write_str("not an i64") } } @@ -114,7 +114,7 @@ impl Debug for NotA { impl Display for NotA { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(concat!("not a f64")) + f.write_str("not a f64") } } @@ -141,7 +141,7 @@ impl Debug for NotA { impl Display for NotA { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(concat!("not a String")) + f.write_str("not a String") } } From 437c1b9a232dac0ab5acf3f503dccb64fa8d9de3 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 09:55:21 +0200 Subject: [PATCH 3/8] Optimizations --- Cargo.toml | 1 + src/channel_pool.rs | 29 ++++++++++++++++++----------- src/qdrant_client/mod.rs | 2 +- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bcdd9a8a..4f1eafe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ derive_builder = { version = "0.20.2" } thiserror = "1.0.64" semver = "1.0.24" uuid = { version = "1.8.2", optional = true } +parking_lot = "0.12.4" [dev-dependencies] tonic-build = { version = "0.12.3", features = ["prost"] } diff --git a/src/channel_pool.rs b/src/channel_pool.rs index 23f9320e..c916e43b 100644 --- a/src/channel_pool.rs +++ b/src/channel_pool.rs @@ -1,14 +1,13 @@ use std::future::Future; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::RwLock; use std::time::Duration; +use parking_lot::{Mutex, RwLock}; use tonic::transport::{Channel, ClientTlsConfig, Uri}; use tonic::{Code, Status}; pub struct ChannelPool { channels: RwLock>>, - channel_index: AtomicU64, + channel_index_lock: Mutex, uri: Uri, grpc_timeout: Duration, connection_timeout: Duration, @@ -29,7 +28,7 @@ impl ChannelPool { Self { channels: RwLock::new(vec![None; pool_size]), - channel_index: AtomicU64::new(0), + channel_index_lock: Mutex::new(0), uri, grpc_timeout, connection_timeout, @@ -77,10 +76,8 @@ impl ChannelPool { .await .map_err(|e| Status::internal(format!("Failed to connect to {}: {:?}", self.uri, e)))?; - let mut pool_channels = self.channels.write().unwrap(); - + let mut pool_channels = self.channels.write(); pool_channels[channel_index] = Some(new_channel.clone()); - Ok(new_channel) } @@ -92,7 +89,6 @@ impl ChannelPool { if let Some(channel) = self .channels .read() - .unwrap() .get(channel_index) .and_then(|i| i.as_ref()) { @@ -104,7 +100,7 @@ impl ChannelPool { /// Drops the channel at the given index. fn drop_channel(&self, idx: usize) { - let mut channel = self.channels.write().unwrap(); + let mut channel = self.channels.write(); channel[idx] = None; } @@ -139,7 +135,18 @@ impl ChannelPool { /// Returns the index for the next channel to use. fn next_channel_index(&self) -> usize { - self.channel_index.fetch_add(1, Ordering::Relaxed) as usize % self.pool_size + // Avoid expensive atomic operation if pooling is disabled. + if self.pool_size == 0 { + return 0; + } + + // ChannelIndex always holds the index of the next client to return. + // Therefore we increase the counter and return the current index. + let mut channel_index = self.channel_index_lock.lock(); + let curr_idx = *channel_index; + let next = (curr_idx + 1) % self.pool_size; + *channel_index = next; + curr_idx } } @@ -184,6 +191,6 @@ mod test { assert_eq!(channel.next_channel_index(), 0); assert_eq!(channel.next_channel_index(), 1); - assert_eq!(channel.channels.read().unwrap().len(), 5); + assert_eq!(channel.channels.read().len(), 5); } } diff --git a/src/qdrant_client/mod.rs b/src/qdrant_client/mod.rs index 8ceb04a4..bae43ee0 100644 --- a/src/qdrant_client/mod.rs +++ b/src/qdrant_client/mod.rs @@ -107,7 +107,7 @@ impl Qdrant { config.timeout, config.connect_timeout, config.keep_alive_while_idle, - config.pool_size, + 1, // No need to create a pool for the compatibility check. ); let client = Self { channel: Arc::new(channel), From 129a821dd5d8591788cee87d62048a95859d0456 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 09:59:36 +0200 Subject: [PATCH 4/8] Update comments --- src/channel_pool.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/channel_pool.rs b/src/channel_pool.rs index c916e43b..5dd6e3ed 100644 --- a/src/channel_pool.rs +++ b/src/channel_pool.rs @@ -120,10 +120,12 @@ impl ChannelPool { Err(err) => match err.code() { Code::Internal | Code::Unavailable | Code::Cancelled | Code::Unknown => { if allow_retry { - // Recreate the channel at the given index. + // Recreate the channel at the same index when reconnecting. let channel = self.make_channel(channel_index).await?; Ok(f(channel).await?) } else { + // If retries aren't allowed, delete the channel so it will be recreated + // the next time it's used. self.drop_channel(channel_index); Err(err) } @@ -135,7 +137,7 @@ impl ChannelPool { /// Returns the index for the next channel to use. fn next_channel_index(&self) -> usize { - // Avoid expensive atomic operation if pooling is disabled. + // Avoid the expensive locking operation if pooling is disabled. if self.pool_size == 0 { return 0; } From c205e0bc821595f04ac0ce849411d7c924735273 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 10:01:26 +0200 Subject: [PATCH 5/8] Add config setter --- src/qdrant_client/config.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index a195e14d..562dda91 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -181,6 +181,11 @@ impl QdrantConfig { self.check_compatibility = false; self } + + /// Set the pool size of concurrent connections. + pub fn set_pool_size(&mut self, pool_size: usize) { + self.pool_size = pool_size; + } } /// Default Qdrant client configuration. From 4fe7b23264ccca2e901ef76a5cbe37a4d63905c8 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 10:20:13 +0200 Subject: [PATCH 6/8] Specify how to disable connection pools --- src/qdrant_client/config.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index 562dda91..41678503 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -40,6 +40,7 @@ pub struct QdrantConfig { pub check_compatibility: bool, /// Amount of concurrent connections. + /// If set to 0 or 1, connection pools will be disabled. pub pool_size: usize, } @@ -183,6 +184,7 @@ impl QdrantConfig { } /// Set the pool size of concurrent connections. + /// If set to 0 or 1, connection pools will be disabled. pub fn set_pool_size(&mut self, pool_size: usize) { self.pool_size = pool_size; } From 0f34cc2a540f337475904010973a0a599212d6e5 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 11:36:51 +0200 Subject: [PATCH 7/8] Fix pooling-enabled check --- src/channel_pool.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/channel_pool.rs b/src/channel_pool.rs index 5dd6e3ed..6f4156bf 100644 --- a/src/channel_pool.rs +++ b/src/channel_pool.rs @@ -135,10 +135,17 @@ impl ChannelPool { } } + /// Returns `true` if multiple connections being used. + fn is_connection_pooling_enabled(&self) -> bool { + // This value is never `0` becuase we enforce this in the constructor. + // 1 connection = No pooling + self.pool_size != 1 + } + /// Returns the index for the next channel to use. fn next_channel_index(&self) -> usize { // Avoid the expensive locking operation if pooling is disabled. - if self.pool_size == 0 { + if !self.is_connection_pooling_enabled() { return 0; } From 5cf66e39ba8ed069411edd38a28ecaa5b2094013 Mon Sep 17 00:00:00 2001 From: jojii Date: Fri, 5 Sep 2025 13:07:14 +0200 Subject: [PATCH 8/8] Set default pool size to 3 --- src/qdrant_client/config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index 41678503..6843df32 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -203,7 +203,7 @@ impl Default for QdrantConfig { api_key: None, compression: None, check_compatibility: true, - pool_size: 1, + pool_size: 3, } } }