diff --git a/.mia-template/README.md b/.mia-template/README.md
index eefcf72..74c5aff 100644
--- a/.mia-template/README.md
+++ b/.mia-template/README.md
@@ -139,7 +139,3 @@ docker build . -t ai-rag-template
 ```sh
 docker run --env-file ./local.env -p 3000:3000 -d ai-rag-template
 ```
-
-### Try the ai-rag-template
-
-You can also use the ai-rag-template with a CLI. Please follow the instruction in the [related README file](./scripts/chatbotcli/README.md).
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e73ac16..cbb0c46 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+- add endpoints `POST /embeddings/generate` and `GET /embeddings/status` for embedding generation
+
 ## 0.2.0 - 2024-08-21
 
 ### Updated
diff --git a/README.md b/README.md
index 588c62b..17b622d 100644
--- a/README.md
+++ b/README.md
@@ -64,6 +64,83 @@ curl 'http://localhost:3000/chat/completions' \
 
+### Generate Embedding Endpoint (`/embeddings/generate`)
+
+The `/embeddings/generate` endpoint takes a web URL as input and executes the following operations:
+
+- crawl the webpage
+- check for links on the same domain as the webpage and store them in a list
+- scrape the page for text
+- generate the embeddings using the [configured embedding model](#configuration)
+- repeat the process for every link still in the list
+
+> **NOTE**:
+> Only one generation process can run at a time: the endpoint uses a lock to prevent multiple requests from starting the process concurrently.
+>
+> No information is returned when the process ends, whether it completes successfully or stops because of an error.
+
+***Eg***:
+
+<details>
+Request
+
+```curl
+curl 'http://localhost:3000/embeddings/generate' \
+  -H 'content-type: application/json' \
+  --data-raw '{"url":"https://docs.mia-platform.eu/"}'
+```
+
+</details>
+ +
+Response in case the runner is idle
+
+```json
+200 OK
+{
+  "statusOk": true
+}
+```
+</details>
+ +
+Response in case the runner is running
+
+```json
+409 Conflict
+{
+  "detail": "A process to generate embeddings is already in progress."
+}
+```
+</details>
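+
+Since a second request sent while the runner is busy receives the `409 Conflict` reply shown above, a client that needs to index several sites can simply wait and retry. A minimal sketch (an illustration only, assuming the service is reachable on `localhost:3000`; the helper function is hypothetical and not part of the project):
+
+```sh
+# Hypothetical helper: retries until the runner accepts the request.
+# curl -f makes the command fail on the 409 reply, which keeps the loop going.
+generate_embeddings() {
+  until curl -sf 'http://localhost:3000/embeddings/generate' \
+    -H 'content-type: application/json' \
+    --data-raw "{\"url\":\"$1\"}" > /dev/null; do
+    sleep 5
+  done
+}
+
+generate_embeddings 'https://docs.mia-platform.eu/'
+```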
+
+### Embeddings Generation Status Endpoint (`/embeddings/status`)
+
+This endpoint returns information about the status of the [embeddings generation runner](#generate-embedding-endpoint-embeddingsgenerate): either `idle` (no process currently running) or `running` (an embeddings generation process is currently happening).
+
+***Eg***:
+
+<details>
+Request
+
+```curl
+curl 'http://localhost:3000/embeddings/status' \
+  -H 'content-type: application/json'
+```
+
+</details>
+ +
+Response + +```json +200 OK +{ + "status": "idle" +} +``` +
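+
+For example, a deployment script might start a generation run and poll this endpoint until the runner returns to `idle`. A minimal sketch (an illustration only, assuming the service runs on `localhost:3000` and that `jq` is installed):
+
+```sh
+# Start a generation run, then poll the status endpoint until the runner is idle again.
+curl 'http://localhost:3000/embeddings/generate' \
+  -H 'content-type: application/json' \
+  --data-raw '{"url":"https://docs.mia-platform.eu/"}'
+
+while [ "$(curl -s 'http://localhost:3000/embeddings/status' | jq -r '.status')" = "running" ]; do
+  sleep 5
+done
+echo "Embeddings generation completed."
+```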
+
 ### Metrics Endpoint (`/-/metrics`)
 
 The `/-/metrics` endpoint exposes the metrics collected by Prometheus.
diff --git a/requirements.txt b/requirements.txt
index 87f3560..ebf0122 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ argcomplete==3.3.0
 astroid==3.1.0
 async-timeout==4.0.3
 attrs==23.2.0
+beautifulsoup4==4.12.3
 black==24.4.0
 certifi==2024.2.2
 cfgv==3.4.0
@@ -93,6 +94,7 @@ rpds-py==0.18.0
 setuptools==73.0.0
 six==1.16.0
 sniffio==1.3.1
+soupsieve==2.6
 SQLAlchemy==2.0.29
 starlette==0.37.2
 tavily-python==0.3.3
diff --git a/src/api/controllers/embeddings/embeddings_handler.py b/src/api/controllers/embeddings/embeddings_handler.py
new file mode 100644
index 0000000..9653973
--- /dev/null
+++ b/src/api/controllers/embeddings/embeddings_handler.py
@@ -0,0 +1,84 @@
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
+
+from src.api.schemas.status_ok_schema import StatusOkResponseSchema
+from src.application.embeddings.embedding_generator import EmbeddingGenerator
+from src.api.schemas.embeddings_schemas import GenerateEmbeddingsInputSchema, GenerateStatusOutputSchema
+from src.context import AppContext
+
+router = APIRouter()
+
+# This is a simplified mechanism to prevent multiple requests from generating embeddings at the same time.
+# This router is designed to allow only one embeddings generation process at a time, therefore
+# we use a lock to prevent multiple requests from starting the process concurrently.
+#
+# Should you need several methods that require locking, or very long-running tasks,
+# you might want to use a more sophisticated mechanism to handle this.
+router.lock = False
+
+def generate_embeddings_from_url(url: str, app_context: AppContext):
+    """
+    Generate embeddings for a given URL.
+
+    This method is intended to be called as a background task. It manages the lock mechanism
+    of this router: the lock is acquired while the embedding generation process is running and released when it finishes.
+
+    Args:
+        url (str): The URL to generate embeddings from.
+        app_context (AppContext): The application context.
+    """
+    logger = app_context.logger
+
+    try:
+        router.lock = True
+        embedding_generator = EmbeddingGenerator(app_context=app_context)
+        embedding_generator.generate(url)
+    # pylint: disable=W0718
+    except Exception as e:
+        logger.error(f"Error in background task: {str(e)}")
+    finally:
+        router.lock = False
+
+@router.post(
+    "/embeddings/generate",
+    response_model=StatusOkResponseSchema,
+    status_code=status.HTTP_200_OK,
+    tags=["Embeddings"]
+)
+def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, background_tasks: BackgroundTasks):
+    """
+    Generate embeddings for a given URL.
+
+    Only one generation process can run at a time: a lock prevents multiple requests from starting the process simultaneously.
+    If a process is already in progress, the endpoint returns a 409 (Conflict) status code.
+
+    Args:
+        request (Request): The request object.
+        data (GenerateEmbeddingsInputSchema): The input schema.
+        background_tasks (BackgroundTasks): The background tasks object.
+ """ + + request_context: AppContext = request.state.app_context + url = data.url + request_context.logger.info(f"Generate embeddings request received for url: {url}") + + if not router.lock: + background_tasks.add_task(generate_embeddings_from_url, url, request_context) + request_context.logger.info("Generation embeddings process started.") + return {"statusOk": True} + + raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.") + +@router.get( + "/embeddings/status", + response_model=GenerateStatusOutputSchema, + status_code=status.HTTP_200_OK, + tags=["Embeddings"] +) +def embeddings_status(): + """ + Get the status of the embeddings generation process. + + Returns: + dict: A StatusOkResponseSchema responding _True_ if there are no process in progress. Otherwise, it will return _False_. + """ + return {"status": "running" if router.lock else "idle"} diff --git a/src/api/schemas/embeddings_schemas.py b/src/api/schemas/embeddings_schemas.py new file mode 100644 index 0000000..fef795f --- /dev/null +++ b/src/api/schemas/embeddings_schemas.py @@ -0,0 +1,13 @@ +from typing import Any, Dict, Literal +from pydantic import BaseModel + + +class GenerateEmbeddingsInputSchema(BaseModel): + url: str + +class GenerateEmbeddingsOutputSchema(BaseModel): + state: str + metadata: Dict[str, Any] + +class GenerateStatusOutputSchema(BaseModel): + status: Literal["running", "idle"] diff --git a/src/app.py b/src/app.py index 171bcdf..09ac5c8 100644 --- a/src/app.py +++ b/src/app.py @@ -6,6 +6,7 @@ from src.api.controllers.core.liveness import liveness_handler from src.api.controllers.core.readiness import readiness_handler from src.api.controllers.core.metrics import metrics_handler +from src.api.controllers.embeddings import embeddings_handler from src.api.middlewares.app_context_middleware import AppContextMiddleware from src.api.middlewares.logger_middleware import LoggerMiddleware from src.configurations.configuration import get_configuration @@ -32,6 +33,7 @@ def create_app(context: AppContext) -> FastAPI: app.include_router(metrics_handler.router) app.include_router(chat_completions_handler.router) + app.include_router(embeddings_handler.router) return app diff --git a/src/application/embeddings/document_chunker.py b/src/application/embeddings/document_chunker.py new file mode 100644 index 0000000..21ed830 --- /dev/null +++ b/src/application/embeddings/document_chunker.py @@ -0,0 +1,49 @@ +""" +Module to include the EmbeddingGenerator class, a class that generates embeddings for text data. +""" + +import hashlib +from typing import List +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_experimental.text_splitter import SemanticChunker + +DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" + + +class DocumentChunker(): + """ + Initialize the DocumentChunker class. + """ + + def __init__(self, embedding: Embeddings) -> None: + self._chunker = SemanticChunker(embeddings=embedding, breakpoint_threshold_type='percentile') + + def _remove_consecutive_newlines(self, text: str) -> str: + """ + Remove duplicate newlines from the text. + """ + return "\n".join([line for line in text.split("\n\n") if line.strip()]) + + def _generate_sha(self, content: str) -> str: + """ + Generate a SHA hash for the given content. 
+ """ + return hashlib.sha256(content.encode()).hexdigest() + + def split_text_into_chunks(self, text: str, url: str) -> List[Document]: + """ + Generate chunks via semantic separation from a given text + + Args: + text (str): The input text. + url (str): The URL of the text. + """ + content = self._remove_consecutive_newlines(text) + sha = self._generate_sha(content) + + document = Document(page_content=content, metadata={"sha": sha, "url": url}) + chunks = [Document(page_content=chunk) for chunk in self._chunker.split_text(document.page_content)] + # NOTE: "copy" method actually exists. + # pylint: disable=E1101 + return [Document(page_content=chunk.page_content, metadata=document.metadata.copy()) for chunk in chunks] diff --git a/src/application/embeddings/embedding_generator.py b/src/application/embeddings/embedding_generator.py new file mode 100644 index 0000000..26b1331 --- /dev/null +++ b/src/application/embeddings/embedding_generator.py @@ -0,0 +1,162 @@ +""" +This script crawls a website and saves the embeddings extracted from the text of each page to a text file. +""" +import re +import urllib.request +from collections import deque +from urllib.error import URLError +from urllib.parse import urlparse + +from langchain_openai import OpenAIEmbeddings +import requests +from bs4 import BeautifulSoup + +from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch +from src.context import AppContext +from src.application.embeddings.document_chunker import DocumentChunker +from src.application.embeddings.hyperlink_parser import HyperlinkParser + +# Regex pattern to match a URL +HTTP_URL_PATTERN = r"^http[s]*://.+" + +class EmbeddingGenerator(): + """ + Class to generate embeddings for text data. + """ + + def __init__(self, app_context: AppContext): + self.logger = app_context.logger + mongodb_cluster_uri = app_context.env_vars.MONGODB_CLUSTER_URI + embedding_api_key = app_context.env_vars.EMBEDDINGS_API_KEY + configuration = app_context.configurations + + embedding = OpenAIEmbeddings(openai_api_key=embedding_api_key, model=configuration.embeddings.name) + + self._document_chunker = DocumentChunker(embedding=embedding) + + self._embedding_vector_store = MongoDBAtlasVectorSearch.from_connection_string( + connection_string=mongodb_cluster_uri, + namespace=f"{configuration.vectorStore.dbName}.{configuration.vectorStore.collectionName}", + embedding=embedding, + index_name=configuration.vectorStore.indexName, + embedding_key=configuration.vectorStore.embeddingKey, + relevance_score_fn=configuration.vectorStore.relevanceScoreFn, + text_key=configuration.vectorStore.textKey + ) + + + def _get_hyperlinks(self, url): + """ + Function to get the hyperlinks from a URL + """ + + try: + # Open the URL and read the HTML + with urllib.request.urlopen(url) as response: + # If the response is not HTML, return an empty list + if not response.info().get("Content-Type").startswith("text/html"): + return [] + + html = response.read().decode("utf-8") + except URLError as e: + self.logger.error(e) + return [] + + # Create the HTML Parser and then Parse the HTML to get hyperlinks + parser = HyperlinkParser() + parser.feed(html) + + return parser.hyperlinks + + + def _get_domain_hyperlinks(self, local_domain: str, url: str): + """ + Function to get the hyperlinks from a URL that are within the same domain + + Args: + local_domain (str): The domain to compare the hyperlinks against. + url (str): The URL to extract hyperlinks from. 
+
+        Returns:
+            list: A list of hyperlinks that are within the same domain.
+        """
+        clean_links = []
+        for link in set(self._get_hyperlinks(url)):
+            clean_link = None
+
+            # If the link is a URL, check if it is within the same domain
+            if re.search(HTTP_URL_PATTERN, link):
+                # Parse the URL and keep the link only if its domain matches the local domain
+                url_obj = urlparse(link)
+                if url_obj.netloc == local_domain:
+                    clean_link = link
+
+            # If the link is not a URL, check if it is a relative link
+            else:
+                if link.startswith("/"):
+                    link = link[1:]
+                elif link.startswith("#") or link.startswith("mailto:"):
+                    continue
+                clean_link = "https://" + local_domain + "/" + link
+
+            if clean_link is not None:
+                if clean_link.endswith("/"):
+                    clean_link = clean_link[:-1]
+                clean_links.append(clean_link)
+
+        return list(set(clean_links))
+
+    def generate(self, url: str):
+        """
+        Crawls the given URL and stores the embeddings generated from the text content of each page in the vector store.
+
+        Args:
+            url (str): The URL to crawl. From this URL, the crawler will extract the text content
+                of said page and any other page connected via hyperlinks (anchor tags).
+
+        Returns:
+            None
+        """
+
+        local_domain = urlparse(url).netloc
+
+        queue = deque([url])
+        seen = set([url])
+
+        # While the queue is not empty, continue crawling
+        while queue:
+            # Get the next URL from the queue (note: pop() takes from the right, so pages are visited depth-first)
+            url = queue.pop()
+            self.logger.debug(f"Scraping page: {url}")  # for debugging and to see the progress
+
+            # Get the text from the URL using BeautifulSoup
+            soup = BeautifulSoup(requests.get(url, timeout=5).text, "html.parser")
+
+            # Get the text but remove the tags
+            text = soup.get_text()
+
+            # If the crawler gets to a page that requires JavaScript, it will skip that page
+            if "You need to enable JavaScript to run this app." in text:
+                self.logger.debug(
+                    f"Unable to parse page {url} due to JavaScript being required"
+                )
+                continue
+
+            chunks = self._document_chunker.split_text_into_chunks(text=text, url=url)
+            self.logger.debug(f"Extracted {len(chunks)} chunks from the page. Generating embeddings for these...")  # for debugging and to see the progress
+            self._embedding_vector_store.add_documents(chunks)
+            self.logger.debug("Embeddings generation completed. Extracting links...")  # for debugging and to see the progress
+
+            hyperlinks = self._get_domain_hyperlinks(local_domain, url)
+            if len(hyperlinks) == 0:
+                self.logger.debug("No links found, moving on.")
+
+            # Get the hyperlinks from the URL and add them to the queue
+            for link in hyperlinks:
+                if link not in seen:
+                    self.logger.debug(f"Found new link: {link}")
+                    queue.append(link)
+                    seen.add(link)
+
+        self.logger.debug("Scraping completed.")
diff --git a/src/application/embeddings/hyperlink_parser.py b/src/application/embeddings/hyperlink_parser.py
new file mode 100644
index 0000000..62c0acf
--- /dev/null
+++ b/src/application/embeddings/hyperlink_parser.py
@@ -0,0 +1,30 @@
+"""
+Module providing the HyperlinkParser class.
+"""
+
+from html.parser import HTMLParser
+
+
+class HyperlinkParser(HTMLParser):
+    """
+    A class that parses HTML and extracts hyperlinks.
+
+    Attributes:
+        hyperlinks (list): A list to store the extracted hyperlinks.
+
+    Methods:
+        handle_starttag(tag, attrs): Overrides the HTMLParser's handle_starttag method to extract hyperlinks.
+ """ + + def __init__(self): + super().__init__() + # Create a list to store the hyperlinks + self.hyperlinks = [] + + # Override the HTMLParser's handle_starttag method to get the hyperlinks + def handle_starttag(self, tag, attrs): + attrs = dict(attrs) + + # If the tag is an anchor tag and it has an href attribute, add the href attribute to the list of hyperlinks + if tag == "a" and "href" in attrs: + self.hyperlinks.append(attrs["href"]) diff --git a/tests/src/api/controllers/embedding_handler_test.py b/tests/src/api/controllers/embedding_handler_test.py new file mode 100644 index 0000000..7cd501d --- /dev/null +++ b/tests/src/api/controllers/embedding_handler_test.py @@ -0,0 +1,41 @@ +from unittest.mock import patch + +from src.api.controllers.embeddings.embeddings_handler import router + + +def test_generate_embeddings_success(test_client): + url = "http://example.com" + data = {"url": url} + + with patch("src.api.controllers.embeddings.embeddings_handler.EmbeddingGenerator.generate") as mock_generate: + response = test_client.post("/embeddings/generate", json=data) + + assert response.status_code == 200 + assert response.json() == {"statusOk": True} + mock_generate.assert_called_once_with(url) + +def test_generate_embeddings_conflict(test_client): + url = "http://example.com" + data = {"url": url} + + router.lock = True # Simulate a process already in progress + response = test_client.post("/embeddings/generate", json=data) + + assert response.status_code == 409 + assert response.json() == {"detail": "A process to generate embeddings is already in progress."} + router.lock = False # Reset lock for other tests + +def test_embeddings_status_idle(test_client): + router.lock = False # Ensure no process is running + response = test_client.get("/embeddings/status") + + assert response.status_code == 200 + assert response.json() == {"status": "idle"} + +def test_embeddings_status_running(test_client): + router.lock = True # Simulate a process running + response = test_client.get("/embeddings/status") + + assert response.status_code == 200 + assert response.json() == {"status": "running"} + router.lock = False # Reset lock for other tests diff --git a/tests/src/application/embeddings/assets/example.html b/tests/src/application/embeddings/assets/example.html new file mode 100644 index 0000000..5da27b3 --- /dev/null +++ b/tests/src/application/embeddings/assets/example.html @@ -0,0 +1,10 @@ + + + + Lorem Ipsum Example + + +

Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.

+

Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.

+
+</body>
+</html>
\ No newline at end of file
diff --git a/tests/src/application/embeddings/document_chunker_test.py b/tests/src/application/embeddings/document_chunker_test.py
new file mode 100644
index 0000000..8178583
--- /dev/null
+++ b/tests/src/application/embeddings/document_chunker_test.py
@@ -0,0 +1,16 @@
+from unittest.mock import patch
+from langchain_openai import OpenAIEmbeddings
+from src.application.embeddings.document_chunker import DocumentChunker
+
+
+def test_split_text_into_chunks():
+    with patch("langchain_experimental.text_splitter.SemanticChunker.split_text") as mock_split_text:
+        mock_split_text.return_value = ["This is a test.", "this is another test."]
+        embedding = OpenAIEmbeddings(model="text-embedding-3-small", openai_api_key="embeddings_api_key")
+        document_chunker = DocumentChunker(embedding)
+        text = "This is a test. This is another test."
+        url = "http://example.com"
+        chunks = document_chunker.split_text_into_chunks(text, url)
+
+        assert mock_split_text.call_count == 1
+        assert len(chunks) == 2
diff --git a/tests/src/application/embeddings/embedding_generator_test.py b/tests/src/application/embeddings/embedding_generator_test.py
new file mode 100644
index 0000000..7fbfdbc
--- /dev/null
+++ b/tests/src/application/embeddings/embedding_generator_test.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+from unittest.mock import patch
+
+from src.application.embeddings.embedding_generator import EmbeddingGenerator
+
+
+def test_generate(app_context):
+    current_dir = Path(__file__).parent
+    file_path = current_dir / "assets" / "example.html"
+    with open(file_path, 'r', encoding='utf-8') as f:
+        html_content = f.read()
+
+    # Patch 'requests.get' (not 'requests.api.get', which the generator never looks up)
+    # and 'urllib.request.urlopen' so the test never touches the network.
+    with patch('requests.get') as mock_requests_get, \
+        patch('urllib.request.urlopen') as mock_urlopen, \
+        patch("langchain_experimental.text_splitter.SemanticChunker.split_text") as mock_split_text, \
+        patch('langchain_community.vectorstores.mongodb_atlas.MongoDBAtlasVectorSearch.add_documents') as mock_add_documents:
+
+        embedding_generator = EmbeddingGenerator(app_context)
+
+        mock_requests_get.return_value.text = html_content
+        # The hyperlink scan opens the URL with urllib; serve it the same HTML fixture.
+        mock_urlopen.return_value.__enter__.return_value.info.return_value.get.return_value = "text/html"
+        mock_urlopen.return_value.__enter__.return_value.read.return_value = html_content.encode("utf-8")
+        mock_split_text.return_value = ["chunk1", "chunk2"]
+
+        embedding_generator.generate("http://example.com")
+
+        mock_split_text.assert_called_once()
+        mock_add_documents.assert_called_once()
+        embedding_generator.logger.debug.assert_called()
diff --git a/tests/src/application/embeddings/hyperlink_parser_test.py b/tests/src/application/embeddings/hyperlink_parser_test.py
new file mode 100644
index 0000000..29c7587
--- /dev/null
+++ b/tests/src/application/embeddings/hyperlink_parser_test.py
@@ -0,0 +1,27 @@
+from src.application.embeddings.hyperlink_parser import HyperlinkParser
+
+
+def test_initial_hyperlinks_empty():
+    parser = HyperlinkParser()
+    assert not parser.hyperlinks
+
+def test_handle_starttag_adds_hyperlink():
+    parser = HyperlinkParser()
+    parser.handle_starttag('a', [('href', 'http://example.com')])
+    assert parser.hyperlinks == ['http://example.com']
+
+def test_handle_starttag_ignores_non_anchor_tags():
+    parser = HyperlinkParser()
+    parser.handle_starttag('div', [('href', 'http://example.com')])
+    assert not parser.hyperlinks
+
+def test_handle_starttag_ignores_anchor_without_href():
+    parser = HyperlinkParser()
+    parser.handle_starttag('a', [('class', 'link')])
+    assert not parser.hyperlinks
+
+def test_handle_starttag_multiple_hyperlinks():
+    parser = HyperlinkParser()
+    parser.handle_starttag('a', [('href', 'http://example1.com')])
+    parser.handle_starttag('a', [('href', 'http://example2.com')])
+    assert parser.hyperlinks == ['http://example1.com',
'http://example2.com']
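+
+# A sketch of one more case (an assumption, not part of the original suite): relative hrefs
+# are stored verbatim; resolving them against the crawled domain is the responsibility of
+# EmbeddingGenerator._get_domain_hyperlinks.
+def test_handle_starttag_keeps_relative_links():
+    parser = HyperlinkParser()
+    parser.handle_starttag('a', [('href', '/docs/getting-started')])
+    assert parser.hyperlinks == ['/docs/getting-started']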