Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
database:
name: hn_embeddings

model:
name: text-embedding-ada-002

paths:
pickle_cache: data/embeddings_cache.pkl

openai:
api_key: YOUR_OPENAI_API_KEY

rag_pipeline:
top_k: 10
19 changes: 11 additions & 8 deletions src/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ def open_connection(dbname=None) -> duckdb.DuckDBPyConnection:
Returns:
`duckdb.DuckDBPyConnection`: A connection object to the local database.
"""
if dbname:
return duckdb.connect(f"{dbname}.db")
else:
return duckdb.connect(":memory:")

try:
if dbname:
return duckdb.connect(f"{dbname}.db")
else:
return duckdb.connect(":memory:")
except Exception as e:
print(f"Error connecting to the database: {e}")
raise

def load_extension(
con: duckdb.DuckDBPyConnection, extension: str
Expand All @@ -36,7 +39,7 @@ def load_extension(
try:
con.install_extension(extension)
con.load_extension(extension)
except:
print(f"Could not load extension {extension}")
pass
except Exception as e:
print(f"Could not load extension {extension}: {e}")
raise
return con
11 changes: 6 additions & 5 deletions src/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import src.operations as operations
import src.openai_client as openai_client
from .connection import DuckDBPyConnection

from .operations import EmbeddingKey


# Function to get embeddings, using the cache
Expand All @@ -24,7 +24,7 @@ def pickle_embeddings(
pickle_cache = operations.load_pickle_cache(pickle_path)

for text in texts:
key = (text, model)
key = EmbeddingKey(text, model)
if key not in pickle_cache:
pickle_cache[key] = openai_client.create_embedding(text, model=model)
embeddings.append(pickle_cache[key])
Expand All @@ -49,20 +49,21 @@ def duckdb_embeddings(
"""
embeddings = []
for text in texts:
key = EmbeddingKey(text, model)
# check to see if embedding is in duckdb table
result = operations.is_key_in_table(con, (text, model))
result = operations.is_key_in_table(con, key)
if result:
print("Embedding found in table")
# if so, get it
embedding = operations.get_embedding_from_table(con, text, model)
embedding = operations.get_embedding_from_table(con, key)
embeddings.append(embedding)
else:
print("Embedding not found in table")
print("Creating new embedding")
# if not, create it
embedding = openai_client.create_embedding(text, model)
# and write it to the table
operations.write_embedding_to_table(con, text, model, embedding)
operations.write_embedding_to_table(con, key, embedding)
embeddings.append(embedding)
return embeddings

Expand Down
66 changes: 66 additions & 0 deletions src/embedding_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List, Tuple
from .connection import DuckDBPyConnection
from .operations import EmbeddingKey, load_pickle_cache, save_pickle_cache, is_key_in_table, get_embedding_from_table, write_embedding_to_table
import src.openai_client as openai_client

class EmbeddingOperations:
def __init__(self, con: DuckDBPyConnection, pickle_path: str):
self.con = con
self.pickle_path = pickle_path

def pickle_embeddings(self, texts: List[str], model: str) -> List[List[float]]:
embeddings = []
pickle_cache = load_pickle_cache(self.pickle_path)

for text in texts:
key = EmbeddingKey(text, model)
if key not in pickle_cache:
pickle_cache[key] = openai_client.create_embedding(text, model=model)
embeddings.append(pickle_cache[key])
save_pickle_cache(pickle_cache, self.pickle_path)
return embeddings

def duckdb_embeddings(self, texts: List[str], model: str) -> List[List[float]]:
embeddings = []
for text in texts:
key = EmbeddingKey(text, model)
result = is_key_in_table(self.con, key)
if result:
embedding = get_embedding_from_table(self.con, key)
embeddings.append(embedding)
else:
embedding = openai_client.create_embedding(text, model)
write_embedding_to_table(self.con, key, embedding)
embeddings.append(embedding)
return embeddings

def cosine_similarity(self, l1, l2) -> float:
return self.con.execute(f"SELECT list_cosine_similarity({l1}, {l2})").fetchall()[0][0]

def get_similarity(self, text: str, model: str) -> list[tuple[str, float]]:
sql = """
WITH q1 AS (
SELECT
? as text,
?::DOUBLE[] AS embedding
),

q2 AS (
select
distinct text,
embedding::DOUBLE[] as embedding
from embeddings
)

SELECT
b.text,
list_cosine_similarity(a.embedding::DOUBLE[], b.embedding::DOUBLE[]) AS similarity
FROM q1 a
join q2 b on a.text != b.text
ORDER BY similarity DESC
LIMIT 10
"""

embedding = self.duckdb_embeddings([text], model)[0]
result = self.con.execute(sql, [text, embedding]).fetchall()
return result
12 changes: 9 additions & 3 deletions src/openai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def get_openai_client() -> OpenAI:
:return: An instance of the OpenAI client.
"""
key = os.getenv("OPENAI_API_KEY")
if not key:
raise ValueError("OpenAI API key not found in environment variables.")
client = OpenAI(api_key=key)
return client

Expand All @@ -34,9 +36,13 @@ def create_embedding(
try:
client = get_openai_client()
except Exception as e:
print(e)
print(f"Error initializing OpenAI client: {e}")
return []

text = text.replace("\n", " ")
response = client.embeddings.create(input=[text], model=model, **kwargs)
return response.data[0].embedding
try:
response = client.embeddings.create(input=[text], model=model, **kwargs)
return response.data[0].embedding
except Exception as e:
print(f"Error creating embedding: {e}")
return []
61 changes: 41 additions & 20 deletions src/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,44 @@
PickleCache = Dict[Tuple[str, str], List[float]]


class EmbeddingKey:
def __init__(self, text: str, model: str):
self._text = text
self._model = model

@property
def text(self):
return self._text

@property
def model(self):
return self._model

def __eq__(self, other):
if isinstance(other, EmbeddingKey):
return self.text == other.text and self.model == other.model
return False

def __hash__(self):
return hash((self.text, self.model))


def write_embedding_to_table(
con: DuckDBPyConnection, text: str, model: str, embedding: List[float]
con: DuckDBPyConnection, key: EmbeddingKey, embedding: List[float]
) -> DuckDBPyConnection:
"""
Writes the given embedding to the `embeddings` table in the database.

Args:
con (DuckDBPyConnection): The connection to the DuckDB database.
text (str): The text associated with the embedding.
model (str): The model used to generate the embedding.
key (EmbeddingKey): The key associated with the embedding.
embedding (List[float]): The embedding vector.

Returns:
DuckDBPyConnection: The connection to the DuckDB database after the insertion.
"""
create_table_if_not_exists(con)
con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [text, model, embedding])
con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [key.text, key.model, embedding])
return con


Expand All @@ -44,39 +65,39 @@ def create_table_if_not_exists(con) -> None:
)


def is_key_in_table(con: DuckDBPyConnection, key: Tuple[str, str]) -> bool:
def is_key_in_table(con: DuckDBPyConnection, key: EmbeddingKey) -> bool:
"""
Check if a key exists in the embeddings table.

Args:
con (DuckDBPyConnection): The connection to the DuckDB database.
key (Tuple[str, str]): The key to check in the format (text, model).
key (EmbeddingKey): The key to check.

Returns:
bool: True if the key exists in the table, False otherwise.
"""
create_table_if_not_exists(con)
result = con.execute(
"SELECT EXISTS(SELECT * FROM embeddings WHERE text=? AND model=?)",
[key[0], key[1]],
[key.text, key.model],
).fetchone()
if result:
return result[0]
return False


def list_keys_in_table(
con: DuckDBPyConnection, keys: List[Tuple[str, str]]
) -> list[tuple[str, str]]:
con: DuckDBPyConnection, keys: List[EmbeddingKey]
) -> list[EmbeddingKey]:
"""
Returns a list of keys that exist in the specified table.

Args:
con (DuckDBPyConnection): The connection to the DuckDB database.
keys (List[Tuple[str, str]]): The keys to check in the table.
keys (List[EmbeddingKey]): The keys to check in the table.

Returns:
List[Tuple[str, str]]: A list of keys that exist in the table.
List[EmbeddingKey]: A list of keys that exist in the table.
"""
keys_in_table = []

Expand Down Expand Up @@ -117,7 +138,8 @@ def write_pickle_cache_to_duckdb(con: DuckDBPyConnection, pickle_path: str) -> N
cache = load_pickle_cache(pickle_path)
create_table_if_not_exists(con)
for key, value in cache.items():
write_embedding_to_table(con, key[0], key[1], value)
embedding_key = EmbeddingKey(key[0], key[1])
write_embedding_to_table(con, embedding_key, value)


# Function to save the cache to a file
Expand All @@ -136,24 +158,23 @@ def save_pickle_cache(cache: PickleCache, cache_path: str) -> None:
pickle.dump(cache, file)


def get_embedding_from_table(con: DuckDBPyConnection, text: str, model: str) -> List[float]:
def get_embedding_from_table(con: DuckDBPyConnection, key: EmbeddingKey) -> List[float]:
"""
Retrieves the embedding from the 'embeddings' table based on the given text and model.
Retrieves the embedding from the 'embeddings' table based on the given key.

Args:
con (DuckDBPyConnection): The connection to the DuckDB database.
text (str): The text to search for in the 'text' column of the table.
model (str): The model to search for in the 'model' column of the table.
key (EmbeddingKey): The key to search for in the table.

Returns:
List[float]: The embedding associated with the given text and model.
List[float]: The embedding associated with the given key.

Raises:
ValueError: If the embedding for the given text and model is not found in the table.
ValueError: If the embedding for the given key is not found in the table.
"""
result = con.execute(
"SELECT embedding FROM embeddings WHERE text=? AND model=?", [text, model]
"SELECT embedding FROM embeddings WHERE text=? AND model=?", [key.text, key.model]
).fetchone()
if result:
return result[0]
raise ValueError(f"Embedding for {text} with model {model} not found in table")
raise ValueError(f"Embedding for {key.text} with model {key.model} not found in table")
27 changes: 27 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import pytest
import yaml
from dotenv import load_dotenv

load_dotenv()

@pytest.fixture
def config_data():
with open("config.yaml", "r") as file:
return yaml.safe_load(file)

def test_database_name(config_data):
assert config_data["database"]["name"] == "hn_embeddings"

def test_model_name(config_data):
assert config_data["model"]["name"] == "text-embedding-ada-002"

def test_pickle_cache_path(config_data):
assert config_data["paths"]["pickle_cache"] == "data/embeddings_cache.pkl"

def test_openai_api_key():
api_key = os.getenv("OPENAI_API_KEY")
assert api_key == "YOUR_OPENAI_API_KEY"

def test_rag_pipeline_top_k(config_data):
assert config_data["rag_pipeline"]["top_k"] == 10
44 changes: 44 additions & 0 deletions tests/test_embedding_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
from src.embedding_operations import EmbeddingOperations
from src.connection import open_connection
from src.operations import EmbeddingKey

@pytest.fixture
def setup_db():
con = open_connection(":memory:")
con.execute("CREATE TABLE embeddings (text VARCHAR, model VARCHAR, embedding DOUBLE[])")
yield con
con.close()

def test_pickle_embeddings(setup_db):
con = setup_db
embedding_ops = EmbeddingOperations(con, "test_cache.pkl")
texts = ["test text 1", "test text 2"]
model = "test-model"
embeddings = embedding_ops.pickle_embeddings(texts, model)
assert len(embeddings) == 2

def test_duckdb_embeddings(setup_db):
con = setup_db
embedding_ops = EmbeddingOperations(con, "test_cache.pkl")
texts = ["test text 1", "test text 2"]
model = "test-model"
embeddings = embedding_ops.duckdb_embeddings(texts, model)
assert len(embeddings) == 2

def test_cosine_similarity(setup_db):
con = setup_db
embedding_ops = EmbeddingOperations(con, "test_cache.pkl")
l1 = [1.0, 2.0, 3.0]
l2 = [1.0, 2.0, 3.0]
similarity = embedding_ops.cosine_similarity(l1, l2)
assert similarity == 1.0

def test_get_similarity(setup_db):
con = setup_db
embedding_ops = EmbeddingOperations(con, "test_cache.pkl")
text = "test text"
model = "test-model"
con.execute("INSERT INTO embeddings VALUES (?, ?, ?)", [text, model, [1.0, 2.0, 3.0]])
result = embedding_ops.get_similarity(text, model)
assert len(result) == 0