feat: embedding generation from web scraping #16

Merged · 8 commits · Sep 5, 2024
4 changes: 0 additions & 4 deletions .mia-template/README.md
@@ -139,7 +139,3 @@ docker build . -t ai-rag-template
```sh
docker run --env-file ./local.env -p 3000:3000 -d ai-rag-template
```

### Try the ai-rag-template

You can also use the ai-rag-template with a CLI. Please follow the instruction in the [related README file](./scripts/chatbotcli/README.md).
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- add endpoints `POST /embeddings/generate` and `GET /embeddings/status` for embedding generation

## 0.2.0 - 2024-08-21

### Updated
77 changes: 77 additions & 0 deletions README.md
@@ -64,6 +64,83 @@ curl 'http://localhost:3000/chat/completions' \

</details>

### Generate Embedding Endpoint (`/embeddings/generate`)

The `/embeddings/generate` endpoint takes a web URL as input and executes the following operations:

- crawl the webpage
- collect the links on the same domain as the webpage and store them in a list
- scrape the page for text
- generate the embeddings using the [configured embedding model](#configuration)
- repeat the process for every link still in the list (a rough sketch of this loop follows the note below)

> **NOTE**:
> Only one generation process can run at a time: a lock prevents multiple requests from starting the process concurrently.
>
> No information is returned when the process ends, whether it completes successfully or stops because of an error.
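
The loop above can be pictured with a short Python sketch. This is only an illustration of the crawl-and-embed flow, not the implementation added in this PR: the `requests` dependency and the `embed_and_store` callable are assumptions standing in for the template's own scraping and `EmbeddingGenerator` logic.

```python
from collections import deque
from typing import Callable
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup


def crawl_and_embed(start_url: str, embed_and_store: Callable[[str, str], None]) -> None:
    """Illustration only: breadth-first crawl of a site, embedding the text of each page."""
    domain = urlparse(start_url).netloc
    queue = deque([start_url])
    visited = set()

    while queue:
        url = queue.popleft()
        if url in visited:
            continue
        visited.add(url)

        html = requests.get(url, timeout=10).text
        soup = BeautifulSoup(html, "html.parser")

        # Keep only links that stay on the same domain and were not visited yet.
        for anchor in soup.find_all("a", href=True):
            link = urljoin(url, anchor["href"])
            if urlparse(link).netloc == domain and link not in visited:
                queue.append(link)

        # Scrape the visible text and hand it to the embedding step
        # (the callable stands in for chunking + embedding + storage).
        embed_and_store(soup.get_text(separator="\n"), url)
```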

***Eg***:

<details>
<summary>Request</summary>

```curl
curl 'http://localhost:3000/embeddings/generate' \
-H 'content-type: application/json' \
--data-raw '{"url":"https://docs.mia-platform.eu/"}'
```

</details>

<details>
<summary>Response in case the runner is idle</summary>

```json
200 OK
{
  "statusOk": true
}
```
</details>

<details>
<summary>Response in case the runner is running</summary>

```json
409 Conflict
{
  "detail": "A process to generate embeddings is already in progress."
}
```
</details>

### Embeddings Generation Status Endpoint (`/embeddings/status`)

The `/embeddings/status` endpoint returns information about the [embeddings generation runner](#generate-embedding-endpoint-embeddingsgenerate). The reported status can be either `idle` (no process currently running) or `running` (an embedding generation process is in progress). A small client-side polling sketch is shown after the examples below.

***Eg***:

<details>
<summary>Request</summary>

```curl
curl 'http://localhost:3000/embeddings/status' \
  -H 'content-type: application/json'
```

</details>

<details>
<summary>Response</summary>

```json
200 OK
{
  "status": "idle"
}
```
</details>
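
As a usage example, a client can poll `/embeddings/status` and only trigger a new generation once the runner is idle. A minimal sketch, assuming a local deployment on port 3000 and the `requests` library (neither is part of this PR):

```python
import time

import requests

BASE_URL = "http://localhost:3000"  # assumed local deployment


def generate_when_idle(url_to_index: str) -> None:
    """Wait until the runner is idle, then start a new embedding generation."""
    while requests.get(f"{BASE_URL}/embeddings/status", timeout=5).json()["status"] == "running":
        time.sleep(5)  # runner busy: wait and poll again

    response = requests.post(
        f"{BASE_URL}/embeddings/generate",
        json={"url": url_to_index},
        timeout=5,
    )
    # A 409 here means another generation process started in the meantime.
    response.raise_for_status()


generate_when_idle("https://docs.mia-platform.eu/")
```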

### Metrics Endpoint (`/-/metrics`)

The `/-/metrics` endpoint exposes the metrics collected by Prometheus.
2 changes: 2 additions & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ argcomplete==3.3.0
astroid==3.1.0
async-timeout==4.0.3
attrs==23.2.0
beautifulsoup4==4.12.3
black==24.4.0
certifi==2024.2.2
cfgv==3.4.0
@@ -93,6 +94,7 @@ rpds-py==0.18.0
setuptools==73.0.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.29
starlette==0.37.2
tavily-python==0.3.3
84 changes: 84 additions & 0 deletions src/api/controllers/embeddings/embeddings_handler.py
@@ -0,0 +1,84 @@
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status

from src.api.schemas.status_ok_schema import StatusOkResponseSchema
from src.application.embeddings.embedding_generator import EmbeddingGenerator
from src.api.schemas.embeddings_schemas import GenerateEmbeddingsInputSchema, GenerateStatusOutputSchema
from src.context import AppContext

router = APIRouter()

# This is a simplified mechanism to prevent multiple requests from generating embeddings at the same time.
# For this specific router, we designed it to only allow one request to generate embeddings at a time, therefore
# we use a lock to prevent multiple requests from starting the process at the same time.
#
# If you need multiple methods that require locking, or very long-running tasks,
# you might want to use a more sophisticated mechanism to handle this.
router.lock = False

def generate_embeddings_from_url(url: str, app_context: AppContext):
    """
    Generate embeddings for a given URL.

    This method is intended to be called as a background task. It includes management of the lock mechanism
    of this router, which is locked while the embedding generation process is running and unlocked when it finishes.

    Args:
        url (str): The URL to generate embeddings from.
        app_context (AppContext): The application context.
    """
    logger = app_context.logger

    try:
        router.lock = True
        embedding_generator = EmbeddingGenerator(app_context=app_context)
        embedding_generator.generate(url)
    # pylint: disable=W0718
    except Exception as e:
        logger.error(f"Error in background task: {str(e)}")
    finally:
        router.lock = False

@router.post(
    "/embeddings/generate",
    response_model=StatusOkResponseSchema,
    status_code=status.HTTP_200_OK,
    tags=["Embeddings"]
)
def generate_embeddings(request: Request, data: GenerateEmbeddingsInputSchema, background_tasks: BackgroundTasks):
    """
    Generate embeddings for a given URL.

    Only one generation process can run at a time: a lock prevents multiple requests from starting the process concurrently.
    If a process is already in progress, this endpoint returns a 409 status code (Conflict).

    Args:
        request (Request): The request object.
        data (GenerateEmbeddingsInputSchema): The input schema.
        background_tasks (BackgroundTasks): The background tasks object.
    """

    request_context: AppContext = request.state.app_context
    url = data.url
    request_context.logger.info(f"Generate embeddings request received for url: {url}")

    if not router.lock:
        background_tasks.add_task(generate_embeddings_from_url, url, request_context)
        request_context.logger.info("Embeddings generation process started.")
        return {"statusOk": True}

    raise HTTPException(status_code=409, detail="A process to generate embeddings is already in progress.")

@router.get(
    "/embeddings/status",
    response_model=GenerateStatusOutputSchema,
    status_code=status.HTTP_200_OK,
    tags=["Embeddings"]
)
def embeddings_status():
    """
    Get the status of the embeddings generation process.

    Returns:
        dict: A GenerateStatusOutputSchema reporting "running" if a generation process is in progress, "idle" otherwise.
    """
    return {"status": "running" if router.lock else "idle"}
13 changes: 13 additions & 0 deletions src/api/schemas/embeddings_schemas.py
@@ -0,0 +1,13 @@
from typing import Any, Dict, Literal
from pydantic import BaseModel


class GenerateEmbeddingsInputSchema(BaseModel):
    url: str

class GenerateEmbeddingsOutputSchema(BaseModel):
    state: str
    metadata: Dict[str, Any]

class GenerateStatusOutputSchema(BaseModel):
    status: Literal["running", "idle"]
2 changes: 2 additions & 0 deletions src/app.py
@@ -6,6 +6,7 @@
from src.api.controllers.core.liveness import liveness_handler
from src.api.controllers.core.readiness import readiness_handler
from src.api.controllers.core.metrics import metrics_handler
from src.api.controllers.embeddings import embeddings_handler
from src.api.middlewares.app_context_middleware import AppContextMiddleware
from src.api.middlewares.logger_middleware import LoggerMiddleware
from src.configurations.configuration import get_configuration
@@ -32,6 +33,7 @@ def create_app(context: AppContext) -> FastAPI:
    app.include_router(metrics_handler.router)

    app.include_router(chat_completions_handler.router)
    app.include_router(embeddings_handler.router)

    return app

49 changes: 49 additions & 0 deletions src/application/embeddings/document_chunker.py
@@ -0,0 +1,49 @@
"""
Module to include the EmbeddingGenerator class, a class that generates embeddings for text data.
"""

import hashlib
from typing import List
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_experimental.text_splitter import SemanticChunker

DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"


class DocumentChunker():
    """
    Class that splits text into semantically coherent chunks, ready for embedding.
    """

    def __init__(self, embedding: Embeddings) -> None:
        self._chunker = SemanticChunker(embeddings=embedding, breakpoint_threshold_type='percentile')

    def _remove_consecutive_newlines(self, text: str) -> str:
        """
        Remove duplicate newlines from the text.
        """
        return "\n".join([line for line in text.split("\n\n") if line.strip()])

    def _generate_sha(self, content: str) -> str:
        """
        Generate a SHA hash for the given content.
        """
        return hashlib.sha256(content.encode()).hexdigest()

    def split_text_into_chunks(self, text: str, url: str) -> List[Document]:
        """
        Generate chunks via semantic separation from a given text.

        Args:
            text (str): The input text.
            url (str): The URL of the text.
        """
        content = self._remove_consecutive_newlines(text)
        sha = self._generate_sha(content)

        document = Document(page_content=content, metadata={"sha": sha, "url": url})
        chunks = [Document(page_content=chunk) for chunk in self._chunker.split_text(document.page_content)]
        # NOTE: the "copy" method actually exists.
        # pylint: disable=E1101
        return [Document(page_content=chunk.page_content, metadata=document.metadata.copy()) for chunk in chunks]
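
For context, a hypothetical usage sketch of `DocumentChunker`. The `OpenAIEmbeddings` client, its model name, and the availability of an API key are assumptions for illustration and are not part of this diff:

```python
from langchain_openai import OpenAIEmbeddings  # assumed embeddings implementation; requires an API key

from src.application.embeddings.document_chunker import DocumentChunker

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
chunker = DocumentChunker(embedding=embeddings)

# Split scraped page text into semantically coherent chunks,
# each carrying the source URL and a SHA of the cleaned content.
chunks = chunker.split_text_into_chunks(
    text="Some scraped page text...\n\nAnother paragraph of the same page.",
    url="https://docs.mia-platform.eu/",
)
for chunk in chunks:
    print(chunk.metadata["sha"], chunk.metadata["url"], len(chunk.page_content))
```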