Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ derive_builder = { version = "0.20.2" }
thiserror = "1.0.64"
semver = "1.0.24"
uuid = { version = "1.8.2", optional = true }
parking_lot = "0.12.4"

[dev-dependencies]
tonic-build = { version = "0.12.3", features = ["prost"] }
Expand Down
110 changes: 90 additions & 20 deletions src/channel_pool.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::future::Future;
use std::sync::RwLock;
use std::time::Duration;

use parking_lot::{Mutex, RwLock};
use tonic::transport::{Channel, ClientTlsConfig, Uri};
use tonic::{Code, Status};

pub struct ChannelPool {
channel: RwLock<Option<Channel>>,
channels: RwLock<Vec<Option<Channel>>>,
channel_index_lock: Mutex<usize>,
uri: Uri,
grpc_timeout: Duration,
connection_timeout: Duration,
keep_alive_while_idle: bool,
pool_size: usize,
}

impl ChannelPool {
Expand All @@ -19,17 +21,24 @@ impl ChannelPool {
grpc_timeout: Duration,
connection_timeout: Duration,
keep_alive_while_idle: bool,
mut pool_size: usize,
) -> Self {
// Ensure `pool_size` is always >= 1
pool_size = std::cmp::max(pool_size, 1);

Self {
channel: RwLock::new(None),
channels: RwLock::new(vec![None; pool_size]),
channel_index_lock: Mutex::new(0),
uri,
grpc_timeout,
connection_timeout,
keep_alive_while_idle,
pool_size,
}
}

async fn make_channel(&self) -> Result<Channel, Status> {
/// Creates a new channel at the given index. If one already exists, it will be dropped and replaced.
async fn make_channel(&self, channel_index: usize) -> Result<Channel, Status> {
let tls = match self.uri.scheme_str() {
None => false,
Some(schema) => match schema {
Expand Down Expand Up @@ -62,29 +71,37 @@ impl ChannelPool {
endpoint
};

let channel = endpoint
let new_channel = endpoint
.connect()
.await
.map_err(|e| Status::internal(format!("Failed to connect to {}: {:?}", self.uri, e)))?;
let mut self_channel = self.channel.write().unwrap();

*self_channel = Some(channel.clone());

Ok(channel)
let mut pool_channels = self.channels.write();
pool_channels[channel_index] = Some(new_channel.clone());
Ok(new_channel)
}

async fn get_channel(&self) -> Result<Channel, Status> {
if let Some(channel) = &*self.channel.read().unwrap() {
return Ok(channel.clone());
/// Returns a channel from the pool. If `pool_size` > 1, calls will return different channels in a round-robin way.
/// Otherwise, the same channel is returned each time.
async fn get_channel(&self) -> Result<(Channel, usize), Status> {
let channel_index = self.next_channel_index();

if let Some(channel) = self
.channels
.read()
.get(channel_index)
.and_then(|i| i.as_ref())
{
return Ok((channel.clone(), channel_index));
}

let channel = self.make_channel().await?;
Ok(channel)
Ok((self.make_channel(channel_index).await?, channel_index))
}

pub async fn drop_channel(&self) {
let mut channel = self.channel.write().unwrap();
*channel = None;
/// Drops the channel at the given index.
fn drop_channel(&self, idx: usize) {
let mut channel = self.channels.write();
channel[idx] = None;
}

// Allow to retry request if channel is broken
Expand All @@ -93,7 +110,7 @@ impl ChannelPool {
f: impl Fn(Channel) -> O,
allow_retry: bool,
) -> Result<T, Status> {
let channel = self.get_channel().await?;
let (channel, channel_index) = self.get_channel().await?;

let result: Result<T, Status> = f(channel).await;

Expand All @@ -102,18 +119,44 @@ impl ChannelPool {
Ok(res) => Ok(res),
Err(err) => match err.code() {
Code::Internal | Code::Unavailable | Code::Cancelled | Code::Unknown => {
self.drop_channel().await;
if allow_retry {
let channel = self.get_channel().await?;
// Recreate the channel at the same index when reconnecting.
let channel = self.make_channel(channel_index).await?;
Ok(f(channel).await?)
} else {
// If retries aren't allowed, delete the channel so it will be recreated
// the next time it's used.
self.drop_channel(channel_index);
Err(err)
}
}
_ => Err(err)?,
},
}
}

/// Returns `true` if multiple connections being used.
fn is_connection_pooling_enabled(&self) -> bool {
// This value is never `0` becuase we enforce this in the constructor.
// 1 connection = No pooling
self.pool_size != 1
}

/// Returns the index for the next channel to use.
fn next_channel_index(&self) -> usize {
// Avoid the expensive locking operation if pooling is disabled.
if !self.is_connection_pooling_enabled() {
return 0;
}

// ChannelIndex always holds the index of the next client to return.
// Therefore we increase the counter and return the current index.
let mut channel_index = self.channel_index_lock.lock();
let curr_idx = *channel_index;
let next = (curr_idx + 1) % self.pool_size;
*channel_index = next;
curr_idx
}
}

// The future returned by get_channel needs to be Send so that the client can be
Expand All @@ -127,9 +170,36 @@ fn require_get_channel_fn_to_be_send() {
Duration::from_millis(0),
Duration::from_millis(0),
false,
2,
)
.get_channel()
.await
.expect("get channel should not error");
});
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_channel_counter() {
let channel = ChannelPool::new(
Uri::from_static("http://localhost:6444"),
Duration::default(),
Duration::default(),
false,
5,
);

assert_eq!(channel.next_channel_index(), 0);
assert_eq!(channel.next_channel_index(), 1);
assert_eq!(channel.next_channel_index(), 2);
assert_eq!(channel.next_channel_index(), 3);
assert_eq!(channel.next_channel_index(), 4);
assert_eq!(channel.next_channel_index(), 0);
assert_eq!(channel.next_channel_index(), 1);

assert_eq!(channel.channels.read().len(), 5);
}
}
1 change: 1 addition & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl QdrantClient {
cfg.timeout,
cfg.connect_timeout,
cfg.keep_alive_while_idle,
1,
);

let client = Self { channel, cfg };
Expand Down
8 changes: 4 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl Debug for NotA<bool> {

impl Display for NotA<bool> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(concat!("not a bool"))
f.write_str("not a bool")
}
}

Expand All @@ -87,7 +87,7 @@ impl Debug for NotA<i64> {

impl Display for NotA<i64> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(concat!("not an i64"))
f.write_str("not an i64")
}
}

Expand All @@ -114,7 +114,7 @@ impl Debug for NotA<f64> {

impl Display for NotA<f64> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(concat!("not a f64"))
f.write_str("not a f64")
}
}

Expand All @@ -141,7 +141,7 @@ impl Debug for NotA<String> {

impl Display for NotA<String> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.write_str(concat!("not a String"))
f.write_str("not a String")
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ pub struct QdrantConfig {

/// Whether to check compatibility between the client and server versions
pub check_compatibility: bool,

/// Amount of concurrent connections.
/// If set to 0 or 1, connection pools will be disabled.
pub pool_size: usize,
}

impl QdrantConfig {
Expand Down Expand Up @@ -178,6 +182,12 @@ impl QdrantConfig {
self.check_compatibility = false;
self
}

/// Set the pool size of concurrent connections.
/// If set to 0 or 1, connection pools will be disabled.
pub fn set_pool_size(&mut self, pool_size: usize) {
self.pool_size = pool_size;
}
}

/// Default Qdrant client configuration.
Expand All @@ -193,6 +203,7 @@ impl Default for QdrantConfig {
api_key: None,
compression: None,
check_compatibility: true,
pool_size: 3,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ impl Qdrant {
config.timeout,
config.connect_timeout,
config.keep_alive_while_idle,
1, // No need to create a pool for the compatibility check.
);
let client = Self {
channel: Arc::new(channel),
Expand Down Expand Up @@ -151,6 +152,7 @@ impl Qdrant {
config.timeout,
config.connect_timeout,
config.keep_alive_while_idle,
config.pool_size,
);

let client = Self {
Expand Down