From 512b4222b3361dd5bad4fbf609c6e49ae2452343 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Thu, 22 Aug 2024 17:10:02 +0200 Subject: [PATCH 1/8] wip: generate embeddings from web scraping with api --- requirements.txt | 2 + .../embeddings/embeddings_handler.py | 59 ++++++ src/api/schemas/embeddings_schemas.py | 10 + src/app.py | 2 + .../embeddings/document_chunker.py | 47 +++++ .../embeddings/embedding_generator.py | 171 ++++++++++++++++++ .../embeddings/hyperlink_parser.py | 30 +++ 7 files changed, 321 insertions(+) create mode 100644 src/api/controllers/embeddings/embeddings_handler.py create mode 100644 src/api/schemas/embeddings_schemas.py create mode 100644 src/application/embeddings/document_chunker.py create mode 100644 src/application/embeddings/embedding_generator.py create mode 100644 src/application/embeddings/hyperlink_parser.py diff --git a/requirements.txt b/requirements.txt index 87f3560..ebf0122 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ argcomplete==3.3.0 astroid==3.1.0 async-timeout==4.0.3 attrs==23.2.0 +beautifulsoup4==4.12.3 black==24.4.0 certifi==2024.2.2 cfgv==3.4.0 @@ -93,6 +94,7 @@ rpds-py==0.18.0 setuptools==73.0.0 six==1.16.0 sniffio==1.3.1 +soupsieve==2.6 SQLAlchemy==2.0.29 starlette==0.37.2 tavily-python==0.3.3 diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py new file mode 100644 index 0000000..47a40f3 --- /dev/null +++ b/src/api/controllers/embeddings/embeddings_handler.py @@ -0,0 +1,59 @@ +import asyncio +from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status + +from src.api.schemas.status_ok_schema import StatusOkResponseSchema +from src.application.embeddings.embedding_generator import EmbeddingGenerator +from src.api.schemas.embeddings_schemas import GenerateEmbeddingsInputSchema +from src.context import AppContext + +router = APIRouter() + +lock = asyncio.Lock() + +async def generate_embeddings_from_url(url: str, app_context: AppContext): + """ + Generate embeddings for a given URL. + """ + embedding_generator = EmbeddingGenerator(app_context=app_context) + await embedding_generator.crawl(url) + +@router.post( + "/embeddings/generate", + response_model=StatusOkResponseSchema, + status_code=status.HTTP_200_OK, + tags=["Embeddings"] +) +async def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, background_tasks: BackgroundTasks): + """ + Generate embeddings for a given URL. + """ + + request_context: AppContext = request.state.app_context + url = data.url + request_context.logger.info(f"Generate embeddings request received for url: {url}") + + if lock.locked(): + request_context.logger.info("Generation embeddings process already in progress.") + + raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") + + async with lock: + background_tasks.add_task(generate_embeddings_from_url, url, request_context) + + request_context.logger.info("Generation embeddings process started.") + return {"statusOk": True} + +@router.get( + "/embeddings/status", + response_model=StatusOkResponseSchema, + status_code=status.HTTP_200_OK, + tags=["Embeddings"] +) +async def embeddings_status(): + """ + Get the status of the embeddings generation process. + """ + if lock.locked(): + raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") + + return {"statusOk": True} diff --git a/src/api/schemas/embeddings_schemas.py b/src/api/schemas/embeddings_schemas.py new file mode 100644 index 0000000..78c838c --- /dev/null +++ b/src/api/schemas/embeddings_schemas.py @@ -0,0 +1,10 @@ +from typing import Any, Dict +from pydantic import BaseModel + + +class GenerateEmbeddingsInputSchema(BaseModel): + url: str + +class GenerateEmbeddingsOutputSchema(BaseModel): + state: str + metadata: Dict[str, Any] diff --git a/src/app.py b/src/app.py index 171bcdf..09ac5c8 100644 --- a/src/app.py +++ b/src/app.py @@ -6,6 +6,7 @@ from src.api.controllers.core.liveness import liveness_handler from src.api.controllers.core.readiness import readiness_handler from src.api.controllers.core.metrics import metrics_handler +from src.api.controllers.embeddings import embeddings_handler from src.api.middlewares.app_context_middleware import AppContextMiddleware from src.api.middlewares.logger_middleware import LoggerMiddleware from src.configurations.configuration import get_configuration @@ -32,6 +33,7 @@ def create_app(context: AppContext) -> FastAPI: app.include_router(metrics_handler.router) app.include_router(chat_completions_handler.router) + app.include_router(embeddings_handler.router) return app diff --git a/src/application/embeddings/document_chunker.py b/src/application/embeddings/document_chunker.py new file mode 100644 index 0000000..88b960c --- /dev/null +++ b/src/application/embeddings/document_chunker.py @@ -0,0 +1,47 @@ +""" +Module to include the EmbeddingGenerator class, a class that generates embeddings for text data. +""" + +import hashlib +from typing import List +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_experimental.text_splitter import SemanticChunker + +DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" + + +class DocumentChunker(): + """ + Initialize the DocumentChunker class. + """ + + def __init__(self, embedding: Embeddings) -> None: + self._chunker = SemanticChunker(embeddings=embedding, breakpoint_threshold_type='percentile') + + def _remove_consecutive_newlines(self, text: str) -> str: + """ + Remove duplicate newlines from the text. + """ + return "\n".join([line for line in text.split("\n\n") if line.strip()]) + + def _generate_sha(self, content: str) -> str: + """ + Generate a SHA hash for the given content. + """ + return hashlib.sha256(content.encode()).hexdigest() + + def split_text_into_chunks(self, text: str, url: str) -> List[Document]: + """ + Generate chunks via semantic separation from a given text + + Args: + text (str): The input text. + url (str): The URL of the text. + """ + content = self._remove_consecutive_newlines(text) + sha = self._generate_sha(content) + + document = Document(page_content=content, metadata={"sha": sha, "url": url}) + chunks = [Document(page_content=chunk) for chunk in self._chunker.split_text(document.page_content)] + return [Document(page_content=chunk.page_content, metadata=document.metadata.copy()) for chunk in chunks] diff --git a/src/application/embeddings/embedding_generator.py b/src/application/embeddings/embedding_generator.py new file mode 100644 index 0000000..64f906b --- /dev/null +++ b/src/application/embeddings/embedding_generator.py @@ -0,0 +1,171 @@ +""" +This script crawls a website and saves the embeddings extracted from the text of each page to a text file. +""" +import re +import urllib.request +from collections import deque +from urllib.error import URLError +from urllib.parse import urlparse + +from langchain_openai import OpenAIEmbeddings +import requests +from bs4 import BeautifulSoup + +from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch +from src.context import AppContext +from src.application.embeddings.document_chunker import DocumentChunker +from src.application.embeddings.hyperlink_parser import HyperlinkParser + +# Regex pattern to match a URL +HTTP_URL_PATTERN = r"^http[s]*://.+" + +class SingletonMeta(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + return cls._instances[cls] + + +class EmbeddingGenerator(metaclass=SingletonMeta): + """ + Class to generate embeddings for text data. + """ + + def __init__(self, app_context: AppContext): + self.logger = app_context.logger + mongodb_cluster_uri = app_context.env_vars.MONGODB_CLUSTER_URI + embedding_api_key = app_context.env_vars.EMBEDDINGS_API_KEY + configuration = app_context.configurations + + embedding = OpenAIEmbeddings(openai_api_key=embedding_api_key, model=configuration.embeddings.name) + + self._document_chunker = DocumentChunker(embedding=embedding) + + self._embedding_vector_store = MongoDBAtlasVectorSearch.from_connection_string( + connection_string=mongodb_cluster_uri, + namespace=f"{configuration.vectorStore.dbName}.{configuration.vectorStore.collectionName}", + embedding=embedding, + index_name=configuration.vectorStore.indexName, + embedding_key=configuration.vectorStore.embeddingKey, + relevance_score_fn=configuration.vectorStore.relevanceScoreFn, + text_key=configuration.vectorStore.textKey + ) + + + def _get_hyperlinks(self, url): + """ + Function to get the hyperlinks from a URL + """ + + try: + # Open the URL and read the HTML + with urllib.request.urlopen(url) as response: + # If the response is not HTML, return an empty list + if not response.info().get("Content-Type").startswith("text/html"): + return [] + + html = response.read().decode("utf-8") + except URLError as e: + self.logger.error(e) + return [] + + # Create the HTML Parser and then Parse the HTML to get hyperlinks + parser = HyperlinkParser() + parser.feed(html) + + return parser.hyperlinks + + + def _get_domain_hyperlinks(self, local_domain: str, url: str): + """ + Function to get the hyperlinks from a URL that are within the same domain + + Args: + local_domain (str): The domain to compare the hyperlinks against. + url (str): The URL to extract hyperlinks from. + + Returns: + list: A list of hyperlinks that are within the same domain. + """ + clean_links = [] + for link in set(self._get_hyperlinks(url)): + clean_link = None + + # If the link is a URL, check if it is within the same domain + if re.search(HTTP_URL_PATTERN, link): + # Parse the URL and check if the domain is the same + url_obj = urlparse(link) + # Link should be within the same domain and should start with one of the paths + if url_obj.netloc == local_domain: + clean_link = link + + # If the link is not a URL, check if it is a relative link + else: + if link.startswith("/"): + link = link[1:] + elif link.startswith("#") or link.startswith("mailto:"): + continue + clean_link = "https://" + local_domain + "/" + link + + if clean_link is not None: + if clean_link.endswith("/"): + clean_link = clean_link[:-1] + clean_links.append(clean_link) + + return list(set(clean_links)) + + def crawl(self, url: str): + """ + Crawls the given URL and saves the text content of each page to a text file. + + Args: + url (str): The URL to crawl. + + Returns: + None + """ + + local_domain = urlparse(url).netloc + + queue = deque([url]) + seen = set([url]) + + # While the queue is not empty, continue crawling + while queue: + # Get the next URL from the queue + url = queue.pop() + self.logger.debug(f"Crawling page: {url}") # for debugging and to see the progress + + # Get the text from the URL using BeautifulSoup + soup = BeautifulSoup(requests.get(url, timeout=5).text, "html.parser") + + # Get the text but remove the tags + text = soup.get_text() + + # If the crawler gets to a page that requires JavaScript, it will stop the crawl + if "You need to enable JavaScript to run this app." in text: + self.logger.debug( + "Unable to parse page " + url + " due to JavaScript being required" + ) + continue + + chunks = self._document_chunker.split_text_into_chunks(text=text, url=url) + self.logger.debug(f"Extracted {len(chunks)} chunks from the page. Generated embeddings for these...") # for debugging and to see the progress + self._embedding_vector_store.add_documents(chunks) + self.logger.debug("Embeddings generation completed. Extracting links...") # for debugging and to see the progress + + hyperlinks = self._get_domain_hyperlinks(local_domain, url) + if len(hyperlinks) == 0: + self.logger.debug("No links found, move on.") + + # Get the hyperlinks from the URL and add them to the queue + for link in hyperlinks: + if link not in seen: + self.logger.debug(f"Found new link: {link}") + queue.append(link) + seen.add(link) + + self.logger.debug("Crawling completed.") diff --git a/src/application/embeddings/hyperlink_parser.py b/src/application/embeddings/hyperlink_parser.py new file mode 100644 index 0000000..62c0acf --- /dev/null +++ b/src/application/embeddings/hyperlink_parser.py @@ -0,0 +1,30 @@ +""" +Module providing the HyperlinkParser class. +""" + +from html.parser import HTMLParser + + +class HyperlinkParser(HTMLParser): + """ + A class that parses HTML and extracts hyperlinks. + + Attributes: + hyperlinks (list): A list to store the extracted hyperlinks. + + Methods: + handle_starttag(tag, attrs): Overrides the HTMLParser's handle_starttag method to extract hyperlinks. + """ + + def __init__(self): + super().__init__() + # Create a list to store the hyperlinks + self.hyperlinks = [] + + # Override the HTMLParser's handle_starttag method to get the hyperlinks + def handle_starttag(self, tag, attrs): + attrs = dict(attrs) + + # If the tag is an anchor tag and it has an href attribute, add the href attribute to the list of hyperlinks + if tag == "a" and "href" in attrs: + self.hyperlinks.append(attrs["href"]) From bfab19605549d65398caeba42a4a490c66f8cf68 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Thu, 22 Aug 2024 18:05:39 +0200 Subject: [PATCH 2/8] feat: using locks for generating embeddings --- .../embeddings/embeddings_handler.py | 57 +++++++++++++------ .../embeddings/embedding_generator.py | 23 +++----- 2 files changed, 47 insertions(+), 33 deletions(-) diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py index 47a40f3..286f119 100644 --- a/src/api/controllers/embeddings/embeddings_handler.py +++ b/src/api/controllers/embeddings/embeddings_handler.py @@ -1,4 +1,3 @@ -import asyncio from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status from src.api.schemas.status_ok_schema import StatusOkResponseSchema @@ -8,14 +7,28 @@ router = APIRouter() -lock = asyncio.Lock() +# This is a simplified mechanism to prevent multiple requests from generating embeddings at the same time. +# For this specific router, we designed it to only allow one request to generate embeddings at a time, therefore +# we use a lock to prevent multiple requests from starting the process at the same time. +# +# In case of need of multiple methods that requires locking or very long tasks, +# you might want to use a more sophisticated mechanism to handle this. +router.lock = False -async def generate_embeddings_from_url(url: str, app_context: AppContext): +def generate_embeddings_from_url(url: str, app_context: AppContext): """ Generate embeddings for a given URL. """ - embedding_generator = EmbeddingGenerator(app_context=app_context) - await embedding_generator.crawl(url) + logger = app_context.logger + + try: + router.lock = True + embedding_generator = EmbeddingGenerator(app_context=app_context) + embedding_generator.generate(url) + except Exception as e: + logger.error(f"Error in background task: {str(e)}") + finally: + router.lock = False @router.post( "/embeddings/generate", @@ -23,25 +36,29 @@ async def generate_embeddings_from_url(url: str, app_context: AppContext): status_code=status.HTTP_200_OK, tags=["Embeddings"] ) -async def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, background_tasks: BackgroundTasks): +def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, background_tasks: BackgroundTasks): """ Generate embeddings for a given URL. + + This method can be run only one at a time, as it uses a lock to prevent multiple requests from starting the process at the same time. + If a process is already in progress, it will return a 409 status code (Conflict). + + Args: + request (Request): The request object. + data (GenerateEmbeddingsInputSchema): The input schema. + background_tasks (BackgroundTasks): The background tasks object. """ request_context: AppContext = request.state.app_context url = data.url request_context.logger.info(f"Generate embeddings request received for url: {url}") - if lock.locked(): - request_context.logger.info("Generation embeddings process already in progress.") - - raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") - - async with lock: + if not router.lock: background_tasks.add_task(generate_embeddings_from_url, url, request_context) - - request_context.logger.info("Generation embeddings process started.") - return {"statusOk": True} + request_context.logger.info("Generation embeddings process started.") + return {"statusOk": True} + + raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") @router.get( "/embeddings/status", @@ -49,11 +66,17 @@ async def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSch status_code=status.HTTP_200_OK, tags=["Embeddings"] ) -async def embeddings_status(): +def embeddings_status(): """ Get the status of the embeddings generation process. + + Returns: + dict: A StatusOkResponseSchema responding _True_. + + Raises: + HTTPException: If a process to generate embeddings is already in progress (thus: the lock is on _True_). """ - if lock.locked(): + if router.lock: raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") return {"statusOk": True} diff --git a/src/application/embeddings/embedding_generator.py b/src/application/embeddings/embedding_generator.py index 64f906b..26b1331 100644 --- a/src/application/embeddings/embedding_generator.py +++ b/src/application/embeddings/embedding_generator.py @@ -19,17 +19,7 @@ # Regex pattern to match a URL HTTP_URL_PATTERN = r"^http[s]*://.+" -class SingletonMeta(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - instance = super().__call__(*args, **kwargs) - cls._instances[cls] = instance - return cls._instances[cls] - - -class EmbeddingGenerator(metaclass=SingletonMeta): +class EmbeddingGenerator(): """ Class to generate embeddings for text data. """ @@ -117,12 +107,13 @@ def _get_domain_hyperlinks(self, local_domain: str, url: str): return list(set(clean_links)) - def crawl(self, url: str): + def generate(self, url: str): """ Crawls the given URL and saves the text content of each page to a text file. Args: - url (str): The URL to crawl. + url (str): The URL to crawl. From this URL, the crawler will extract the text content + of said page and any other page connected via hyperlinks (anchor tags). Returns: None @@ -137,7 +128,7 @@ def crawl(self, url: str): while queue: # Get the next URL from the queue url = queue.pop() - self.logger.debug(f"Crawling page: {url}") # for debugging and to see the progress + self.logger.debug(f"Scraping page: {url}") # for debugging and to see the progress # Get the text from the URL using BeautifulSoup soup = BeautifulSoup(requests.get(url, timeout=5).text, "html.parser") @@ -148,7 +139,7 @@ def crawl(self, url: str): # If the crawler gets to a page that requires JavaScript, it will stop the crawl if "You need to enable JavaScript to run this app." in text: self.logger.debug( - "Unable to parse page " + url + " due to JavaScript being required" + f"Unable to parse page {url} due to JavaScript being required" ) continue @@ -168,4 +159,4 @@ def crawl(self, url: str): queue.append(link) seen.add(link) - self.logger.debug("Crawling completed.") + self.logger.debug("Scraping completed.") From 995f85d0f4e77ec5f9f7ca0c1522fefcf2571fff Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 11:29:49 +0200 Subject: [PATCH 3/8] fix: embedding_status to return False if a process is in progress --- src/api/controllers/embeddings/embeddings_handler.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py index 286f119..26a1f19 100644 --- a/src/api/controllers/embeddings/embeddings_handler.py +++ b/src/api/controllers/embeddings/embeddings_handler.py @@ -71,12 +71,6 @@ def embeddings_status(): Get the status of the embeddings generation process. Returns: - dict: A StatusOkResponseSchema responding _True_. - - Raises: - HTTPException: If a process to generate embeddings is already in progress (thus: the lock is on _True_). + dict: A StatusOkResponseSchema responding _True_ if there are no process in progress. Otherwise, it will return _False_. """ - if router.lock: - raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") - - return {"statusOk": True} + return {"statusOk": not router.lock} From 127562d8646d14f0a63d5b77832b7500f3314488 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 15:53:51 +0200 Subject: [PATCH 4/8] feat: minor updates, tests --- .../embeddings/embeddings_handler.py | 15 +++++-- src/api/schemas/embeddings_schemas.py | 5 ++- .../api/controllers/embedding_handler_test.py | 41 +++++++++++++++++++ .../embeddings/assets/example.html | 10 +++++ .../embeddings/document_chunker_test.py | 16 ++++++++ .../embeddings/embedding_generator_test.py | 26 ++++++++++++ .../embeddings/hyperlink_parser_test.py | 27 ++++++++++++ 7 files changed, 135 insertions(+), 5 deletions(-) create mode 100644 tests/src/api/controllers/embedding_handler_test.py create mode 100644 tests/src/application/embeddings/assets/example.html create mode 100644 tests/src/application/embeddings/document_chunker_test.py create mode 100644 tests/src/application/embeddings/embedding_generator_test.py create mode 100644 tests/src/application/embeddings/hyperlink_parser_test.py diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py index 26a1f19..2f983a6 100644 --- a/src/api/controllers/embeddings/embeddings_handler.py +++ b/src/api/controllers/embeddings/embeddings_handler.py @@ -2,7 +2,7 @@ from src.api.schemas.status_ok_schema import StatusOkResponseSchema from src.application.embeddings.embedding_generator import EmbeddingGenerator -from src.api.schemas.embeddings_schemas import GenerateEmbeddingsInputSchema +from src.api.schemas.embeddings_schemas import GenerateEmbeddingsInputSchema, GenerateStatusOutputSchema from src.context import AppContext router = APIRouter() @@ -17,7 +17,14 @@ def generate_embeddings_from_url(url: str, app_context: AppContext): """ - Generate embeddings for a given URL. + Generate embeddings for a given URL. + + This method is intended to be called as a background task. Includes managmeent of the lock mechanism + of this router, which is locked when the embedding generation process is running, and unlocked when it finishes. + + Args: + url (str): The URL to generate embeddings from. + app_context (AppContext): The application context. """ logger = app_context.logger @@ -62,7 +69,7 @@ def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, b @router.get( "/embeddings/status", - response_model=StatusOkResponseSchema, + response_model=GenerateStatusOutputSchema, status_code=status.HTTP_200_OK, tags=["Embeddings"] ) @@ -73,4 +80,4 @@ def embeddings_status(): Returns: dict: A StatusOkResponseSchema responding _True_ if there are no process in progress. Otherwise, it will return _False_. """ - return {"statusOk": not router.lock} + return {"status": "running" if router.lock else "idle"} diff --git a/src/api/schemas/embeddings_schemas.py b/src/api/schemas/embeddings_schemas.py index 78c838c..08c6d06 100644 --- a/src/api/schemas/embeddings_schemas.py +++ b/src/api/schemas/embeddings_schemas.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Literal from pydantic import BaseModel @@ -8,3 +8,6 @@ class GenerateEmbeddingsInputSchema(BaseModel): class GenerateEmbeddingsOutputSchema(BaseModel): state: str metadata: Dict[str, Any] + +class GenerateStatusOutputSchema(BaseModel): + status: Literal["running", "idle"] \ No newline at end of file diff --git a/tests/src/api/controllers/embedding_handler_test.py b/tests/src/api/controllers/embedding_handler_test.py new file mode 100644 index 0000000..9e63ed0 --- /dev/null +++ b/tests/src/api/controllers/embedding_handler_test.py @@ -0,0 +1,41 @@ +from unittest.mock import patch + +from src.api.controllers.embeddings.embeddings_handler import router + + +def test_generate_embeddings_success(test_client): + url = "http://example.com" + data = {"url": url} + + with patch("src.api.controllers.embeddings.embeddings_handler.EmbeddingGenerator.generate") as mock_generate: + response = test_client.post("/embeddings/generate", json=data) + + assert response.status_code == 200 + assert response.json() == {"statusOk": True} + mock_generate.assert_called_once_with(url) + +def test_generate_embeddings_conflict(test_client): + url = "http://example.com" + data = {"url": url} + + router.lock = True # Simulate a process already in progress + response = test_client.post("/embeddings/generate", json=data) + + assert response.status_code == 409 + assert response.json() == {"detail": "A process to generate embeddings is already in progress."} + router.lock = False # Reset lock for other tests + +def test_embeddings_status_idle(test_client): + router.lock = False # Ensure no process is running + response = test_client.get("/embeddings/status") + + assert response.status_code == 200 + assert response.json() == {"status": "idle"} + +def test_embeddings_status_running(test_client): + router.lock = True # Simulate a process running + response = test_client.get("/embeddings/status") + + assert response.status_code == 200 + assert response.json() == {"status": "running"} + router.lock = False # Reset lock for other tests \ No newline at end of file diff --git a/tests/src/application/embeddings/assets/example.html b/tests/src/application/embeddings/assets/example.html new file mode 100644 index 0000000..5da27b3 --- /dev/null +++ b/tests/src/application/embeddings/assets/example.html @@ -0,0 +1,10 @@ + + + + Lorem Ipsum Example + + +

Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.

+

Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.

+ + \ No newline at end of file diff --git a/tests/src/application/embeddings/document_chunker_test.py b/tests/src/application/embeddings/document_chunker_test.py new file mode 100644 index 0000000..8178583 --- /dev/null +++ b/tests/src/application/embeddings/document_chunker_test.py @@ -0,0 +1,16 @@ +from unittest.mock import patch +from langchain_openai import OpenAIEmbeddings +from src.application.embeddings.document_chunker import DocumentChunker + + +def test_split_text_into_chunks(): + with patch("langchain_experimental.text_splitter.SemanticChunker.split_text") as mock_split_text: + mock_split_text.return_value = ["This is a test.", "this is another test."] + embedding = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key="embeddings_api_key") + document_chunker = DocumentChunker(embedding) + text = "This is a test. This is another test." + url = "http://example.com" + chunks = document_chunker.split_text_into_chunks(text, url) + + assert mock_split_text.call_count == 1 + assert len(chunks) == 2 diff --git a/tests/src/application/embeddings/embedding_generator_test.py b/tests/src/application/embeddings/embedding_generator_test.py new file mode 100644 index 0000000..7fbfdbc --- /dev/null +++ b/tests/src/application/embeddings/embedding_generator_test.py @@ -0,0 +1,26 @@ +from pathlib import Path +from unittest.mock import patch + +from src.application.embeddings.embedding_generator import EmbeddingGenerator + + +def test_generate(app_context): + current_dir = Path(__file__).parent + file_path = current_dir / "assets" / "example.html" + with open(file_path, 'r', encoding='utf-8') as f: + html_content = f.read() + + with patch('requests.api.get') as mock_requests_get, \ + patch("langchain_experimental.text_splitter.SemanticChunker.split_text") as mock_split_text, \ + patch('langchain_community.vectorstores.mongodb_atlas.MongoDBAtlasVectorSearch.add_documents') as mock_add_documents: + + embedding_generator = EmbeddingGenerator(app_context) + + mock_requests_get.return_value.text = html_content + mock_split_text.return_value = ["chunk1", "chunk2"] + + embedding_generator.generate("http://example.com") + + mock_split_text.assert_called_once() + mock_add_documents.assert_called_once() + embedding_generator.logger.debug.assert_called() diff --git a/tests/src/application/embeddings/hyperlink_parser_test.py b/tests/src/application/embeddings/hyperlink_parser_test.py new file mode 100644 index 0000000..80b37b6 --- /dev/null +++ b/tests/src/application/embeddings/hyperlink_parser_test.py @@ -0,0 +1,27 @@ +from src.application.embeddings.hyperlink_parser import HyperlinkParser + + +def test_initial_hyperlinks_empty(): + parser = HyperlinkParser() + assert not parser.hyperlinks + +def test_handle_starttag_adds_hyperlink(): + parser = HyperlinkParser() + parser.handle_starttag('a', [('href', 'http://example.com')]) + assert parser.hyperlinks == ['http://example.com'] + +def test_handle_starttag_ignores_non_anchor_tags(): + parser = HyperlinkParser() + parser.handle_starttag('div', [('href', 'http://example.com')]) + assert not parser.hyperlinks + +def test_handle_starttag_ignores_anchor_without_href(): + parser = HyperlinkParser() + parser.handle_starttag('a', [('class', 'link')]) + assert not parser.hyperlinks + +def test_handle_starttag_multiple_hyperlinks(): + parser = HyperlinkParser() + parser.handle_starttag('a', [('href', 'http://example1.com')]) + parser.handle_starttag('a', [('href', 'http://example2.com')]) + assert parser.hyperlinks == ['http://example1.com', 'http://example2.com'] \ No newline at end of file From c08d5e73a5e4bce0e6b9ff990a0188357b84d5b0 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 16:29:32 +0200 Subject: [PATCH 5/8] fix: lint --- src/api/controllers/embeddings/embeddings_handler.py | 1 + src/api/schemas/embeddings_schemas.py | 2 +- src/application/embeddings/document_chunker.py | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py index 2f983a6..9653973 100644 --- a/src/api/controllers/embeddings/embeddings_handler.py +++ b/src/api/controllers/embeddings/embeddings_handler.py @@ -32,6 +32,7 @@ def generate_embeddings_from_url(url: str, app_context: AppContext): router.lock = True embedding_generator = EmbeddingGenerator(app_context=app_context) embedding_generator.generate(url) + # pylint: disable=W0718 except Exception as e: logger.error(f"Error in background task: {str(e)}") finally: diff --git a/src/api/schemas/embeddings_schemas.py b/src/api/schemas/embeddings_schemas.py index 08c6d06..fef795f 100644 --- a/src/api/schemas/embeddings_schemas.py +++ b/src/api/schemas/embeddings_schemas.py @@ -10,4 +10,4 @@ class GenerateEmbeddingsOutputSchema(BaseModel): metadata: Dict[str, Any] class GenerateStatusOutputSchema(BaseModel): - status: Literal["running", "idle"] \ No newline at end of file + status: Literal["running", "idle"] diff --git a/src/application/embeddings/document_chunker.py b/src/application/embeddings/document_chunker.py index 88b960c..21ed830 100644 --- a/src/application/embeddings/document_chunker.py +++ b/src/application/embeddings/document_chunker.py @@ -44,4 +44,6 @@ def split_text_into_chunks(self, text: str, url: str) -> List[Document]: document = Document(page_content=content, metadata={"sha": sha, "url": url}) chunks = [Document(page_content=chunk) for chunk in self._chunker.split_text(document.page_content)] + # NOTE: "copy" method actually exists. + # pylint: disable=E1101 return [Document(page_content=chunk.page_content, metadata=document.metadata.copy()) for chunk in chunks] From c58892bbf58d0f6c1c8800182b22c677e0a2c180 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 16:32:42 +0200 Subject: [PATCH 6/8] chore: changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e73ac16..cbb0c46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- add endpoints `POST /embeddings/generate` and `GET /embeddings/status` for embedding generation + ## 0.2.0 - 2024-08-21 ### Updated From 8f3c7960c2dfa2d46a82680b120b0796c93a8671 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 16:50:53 +0200 Subject: [PATCH 7/8] chore: documentation --- .mia-template/README.md | 4 --- README.md | 77 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/.mia-template/README.md b/.mia-template/README.md index eefcf72..74c5aff 100644 --- a/.mia-template/README.md +++ b/.mia-template/README.md @@ -139,7 +139,3 @@ docker build . -t ai-rag-template ```sh docker run --env-file ./local.env -p 3000:3000 -d ai-rag-template ``` - -### Try the ai-rag-template - -You can also use the ai-rag-template with a CLI. Please follow the instruction in the [related README file](./scripts/chatbotcli/README.md). diff --git a/README.md b/README.md index 588c62b..17b622d 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,83 @@ curl 'http://localhost:3000/chat/completions' \ +### Generate Embedding Endpoint (`/embeddings/generate`) + +The `/embeddings/generate` endpoint takes as input a web URL and execute the following operation: + +- crawl the webpage +- check for links on the same domain of the webpage and store them in a list +- scrape the page for text +- generate the embeddings using the [configured embedding model](#configuration) +- start again from every link still in the list + +> **NOTE**: +> This method can be run only one at a time, as it uses a lock to prevent multiple requests from starting the process at the same time. +> +> No information are returned when the process ends, either as completed or stopped because of an error. + +***Eg***: + +
+Request + +```curl +curl 'http://localhost:3000/embedding/generation' \ + -H 'content-type: application/json' \ + --data-raw '{"url":"https://docs.mia-platform.eu/"}' +``` + +
+ +
+Response in case the runner is idle + +```json +200 OK +{ + "statusOk": "true" +} +``` +
+ +
+Response in case the runner is runnning + +```json +409 Conflict +{ + "detail": "A process to generate embeddings is already in progress." +} +``` +
+ +### Generation Embedding Status Endpoint (`/embeddings/generate`) + +This request returns to the user information regarding the [embeddings generation runner](#generate-embedding-endpoint-embeddingsgenerate). Could be either `idle` (no process currently running) or `running` (a process of generating embeddings is actually happenning). + +***Eg***: + +
+Request + +```curl +curl 'http://localhost:3000/embedding/status' \ + -H 'content-type: application/json' \ +``` + +
+ +
+Response + +```json +200 OK +{ + "status": "idle" +} +``` +
+ ### Metrics Endpoint (`/-/metrics`) The `/-/metrics` endpoint exposes the metrics collected by Prometheus. From 76e05a2c329b4ba330469ef1391ee6f6e7f3c0e5 Mon Sep 17 00:00:00 2001 From: ThisIsDemetrio Date: Fri, 23 Aug 2024 16:54:21 +0200 Subject: [PATCH 8/8] fix: lint --- tests/src/api/controllers/embedding_handler_test.py | 2 +- tests/src/application/embeddings/hyperlink_parser_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/src/api/controllers/embedding_handler_test.py b/tests/src/api/controllers/embedding_handler_test.py index 9e63ed0..7cd501d 100644 --- a/tests/src/api/controllers/embedding_handler_test.py +++ b/tests/src/api/controllers/embedding_handler_test.py @@ -38,4 +38,4 @@ def test_embeddings_status_running(test_client): assert response.status_code == 200 assert response.json() == {"status": "running"} - router.lock = False # Reset lock for other tests \ No newline at end of file + router.lock = False # Reset lock for other tests diff --git a/tests/src/application/embeddings/hyperlink_parser_test.py b/tests/src/application/embeddings/hyperlink_parser_test.py index 80b37b6..29c7587 100644 --- a/tests/src/application/embeddings/hyperlink_parser_test.py +++ b/tests/src/application/embeddings/hyperlink_parser_test.py @@ -24,4 +24,4 @@ def test_handle_starttag_multiple_hyperlinks(): parser = HyperlinkParser() parser.handle_starttag('a', [('href', 'http://example1.com')]) parser.handle_starttag('a', [('href', 'http://example2.com')]) - assert parser.hyperlinks == ['http://example1.com', 'http://example2.com'] \ No newline at end of file + assert parser.hyperlinks == ['http://example1.com', 'http://example2.com']