Skip to content

SentenceTransformer embedding model: download locally for use #1361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ RUN apt-get update && \
tesseract-ocr && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Set LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
# Copy requirements file and install Python dependencies
COPY requirements.txt constraints.txt /code/
# --no-cache-dir --upgrade
RUN pip install --upgrade pip
RUN pip install -r requirements.txt -c constraints.txt

# Pre-download the sentence-transformers model at image build time so the
# container can create embeddings without network access at runtime.
# Saved to ./local_model relative to the current workdir — presumably /code
# given the earlier COPY targets; TODO confirm WORKDIR.
RUN python -c "from transformers import AutoTokenizer, AutoModel; \
name='sentence-transformers/all-MiniLM-L6-v2'; \
tok=AutoTokenizer.from_pretrained(name); \
mod=AutoModel.from_pretrained(name); \
tok.save_pretrained('./local_model'); \
mod.save_pretrained('./local_model')"

# Copy application code
COPY . /code
# Set command
Expand Down
2 changes: 1 addition & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ wrapt==1.17.2
yarl==1.20.1
youtube-transcript-api==1.1.0
zipp==3.23.0
sentence-transformers==4.1.0
sentence-transformers==5.0.0
google-cloud-logging==3.12.1
pypandoc==1.15
graphdatascience==1.15.1
Expand Down
4 changes: 2 additions & 2 deletions backend/src/QA_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
load_dotenv()

EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)

class SessionChatHistory:
history_dict = {}
Expand Down Expand Up @@ -304,6 +303,7 @@ def create_document_retriever_chain(llm, retriever):
output_parser = StrOutputParser()

splitter = TokenTextSplitter(chunk_size=CHAT_DOC_SPLIT_SIZE, chunk_overlap=0)
EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
embeddings_filter = EmbeddingsFilter(
embeddings=EMBEDDING_FUNCTION,
similarity_threshold=CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD
Expand Down Expand Up @@ -344,7 +344,7 @@ def initialize_neo4j_vector(graph, chat_mode_settings):

if not retrieval_query or not index_name:
raise ValueError("Required settings 'retrieval_query' or 'index_name' are missing.")

EMBEDDING_FUNCTION , _ = load_embedding_model(EMBEDDING_MODEL)
if keyword_index:
neo_db = Neo4jVector.from_existing_graph(
embedding=EMBEDDING_FUNCTION,
Expand Down
4 changes: 2 additions & 2 deletions backend/src/make_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
logging.basicConfig(format='%(asctime)s - %(message)s',level='INFO')

EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
EMBEDDING_FUNCTION , EMBEDDING_DIMENSION = load_embedding_model(EMBEDDING_MODEL)

def merge_relationship_between_chunk_and_entites(graph: Neo4jGraph, graph_documents_chunk_chunk_Id : list):
batch_data = []
Expand Down Expand Up @@ -41,7 +40,7 @@ def merge_relationship_between_chunk_and_entites(graph: Neo4jGraph, graph_docume
def create_chunk_embeddings(graph, chunkId_chunkDoc_list, file_name):
isEmbedding = os.getenv('IS_EMBEDDING')

embeddings, dimension = EMBEDDING_FUNCTION , EMBEDDING_DIMENSION
embeddings, dimension = load_embedding_model(EMBEDDING_MODEL)
logging.info(f'embedding model:{embeddings} and dimesion:{dimension}')
data_for_query = []
logging.info(f"update embedding and vector index for chunks")
Expand Down Expand Up @@ -161,6 +160,7 @@ def create_chunk_vector_index(graph):
vector_index_query = "SHOW INDEXES YIELD name, type, labelsOrTypes, properties WHERE name = 'vector' AND type = 'VECTOR' AND 'Chunk' IN labelsOrTypes AND 'embedding' IN properties RETURN name"
vector_index = execute_graph_query(graph,vector_index_query)
if not vector_index:
EMBEDDING_FUNCTION , EMBEDDING_DIMENSION = load_embedding_model(EMBEDDING_MODEL)
vector_store = Neo4jVector(embedding=EMBEDDING_FUNCTION,
graph=graph,
node_label="Chunk",
Expand Down
44 changes: 41 additions & 3 deletions backend/src/shared/common_fn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import hashlib
import os
from transformers import AutoTokenizer, AutoModel
from langchain_huggingface import HuggingFaceEmbeddings
from threading import Lock
from threading import Lock
import logging
from src.document_sources.youtube import create_youtube_url
from langchain_huggingface import HuggingFaceEmbeddings
Expand All @@ -16,6 +21,40 @@
import boto3
from langchain_community.embeddings import BedrockEmbeddings

# HuggingFace model id to download, and the local directory it is cached in.
# MODEL_PATH must match the path baked into the image by the Dockerfile.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_PATH = "./local_model"
# _lock guards one-time creation of the process-wide embedding singleton;
# _embedding_instance stays None until first use.
_lock = Lock()
_embedding_instance = None

def ensure_sentence_transformer_model_downloaded():
    """Download the sentence-transformer model to MODEL_PATH if not present.

    No-op when MODEL_PATH already exists (e.g. baked into the image at build
    time), so repeat calls skip the network fetch. Uses the module's logging
    setup instead of print(), consistent with the rest of this file.
    """
    if os.path.isdir(MODEL_PATH):
        # Presence of the directory is treated as "fully downloaded" —
        # NOTE(review): a partially-written directory from an earlier failed
        # download would also be accepted; confirm this cannot occur.
        logging.info("Model already downloaded at: %s", MODEL_PATH)
        return
    logging.info("Downloading model to: %s", MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    tokenizer.save_pretrained(MODEL_PATH)
    model.save_pretrained(MODEL_PATH)
    logging.info("Model downloaded and saved.")

def get_local_sentence_transformer_embedding():
    """Return the shared HuggingFaceEmbeddings instance, creating it on first use.

    Double-checked locking keeps initialization lazy and thread-safe, so
    callers never deal with import-time setup or a download race.
    """
    global _embedding_instance
    if _embedding_instance is None:
        with _lock:
            # Re-check under the lock: another thread may have initialized
            # the singleton while we were waiting.
            if _embedding_instance is None:
                # Make sure the model files exist before instantiating.
                ensure_sentence_transformer_model_downloaded()
                instance = HuggingFaceEmbeddings(model_name=MODEL_PATH)
                print("Embedding model initialized.")
                _embedding_instance = instance
    return _embedding_instance

def check_url_source(source_type, yt_url:str=None, wiki_query:str=None):
language=''
try:
Expand Down Expand Up @@ -85,9 +124,8 @@ def load_embedding_model(embedding_model_name: str):
dimension = 1536
logging.info(f"Embedding: Using bedrock titan Embeddings , Dimension:{dimension}")
else:
embeddings = HuggingFaceEmbeddings(
model_name="all-MiniLM-L6-v2"#, cache_folder="/embedding_model"
)
# embeddings = HuggingFaceEmbeddings(model_name="./local_model")
embeddings = get_local_sentence_transformer_embedding()
dimension = 384
logging.info(f"Embedding: Using Langchain HuggingFaceEmbeddings , Dimension:{dimension}")
return embeddings, dimension
Expand Down