Skip to content

Commit

Permalink
Enable support for custom filesystem (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagosalvatore authored Sep 10, 2024
1 parent bca5492 commit 7cb6d06
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
__pycache__/
*.pyc
.DS_Store
.idea
44 changes: 29 additions & 15 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import os
import asyncio
from io import TextIOWrapper

import httpx
import mimetypes
import time
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Union
from pathlib import Path, PurePath, PurePosixPath
from typing import AsyncGenerator, Any, Dict, List, Optional, Union
from contextlib import asynccontextmanager
from io import BufferedIOBase

from fsspec import AbstractFileSystem
from fsspec.spec import AbstractBufferedFile
from llama_index.core.async_utils import run_jobs
from llama_index.core.bridge.pydantic import Field, field_validator
from llama_index.core.constants import DEFAULT_BASE_URL
from llama_index.core.readers.base import BasePydanticReader
from llama_index.core.readers.file.base import get_default_fs
from llama_index.core.schema import Document
from llama_parse.utils import (
nest_asyncio_err,
Expand Down Expand Up @@ -178,7 +183,10 @@ async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]:

# upload a document and get back a job_id
async def _create_job(
self, file_input: FileInput, extra_info: Optional[dict] = None
self,
file_input: FileInput,
extra_info: Optional[dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> str:
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}/api/parsing/upload"
Expand All @@ -193,7 +201,7 @@ async def _create_job(
file_name = extra_info["file_name"]
mime_type = mimetypes.guess_type(file_name)[0]
files = {"file": (file_name, file_input, mime_type)}
elif isinstance(file_input, (str, Path)):
elif isinstance(file_input, (str, Path, PurePosixPath, PurePath)):
file_path = str(file_input)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext not in SUPPORTED_FILE_TYPES:
Expand All @@ -203,7 +211,9 @@ async def _create_job(
)
mime_type = mimetypes.guess_type(file_path)[0]
# Open the file here for the duration of the async context
file_handle = open(file_path, "rb")
# Open through the caller-supplied filesystem, falling back to the default one
fs = fs or get_default_fs()
file_handle = fs.open(file_input, "rb")
files = {"file": (os.path.basename(file_path), file_handle, mime_type)}
else:
raise ValueError(
Expand Down Expand Up @@ -259,9 +269,15 @@ async def _create_job(
if file_handle is not None:
file_handle.close()

@staticmethod
def __get_filename(f: Union[TextIOWrapper, AbstractBufferedFile]) -> str:
    """Return the name carried by an open file object.

    A ``TextIOWrapper`` reports its ``name`` attribute; every other handle
    is read through its ``full_name`` attribute instead.
    """
    return f.name if isinstance(f, TextIOWrapper) else f.full_name

async def _get_job_result(
self, job_id: str, result_type: str, verbose: bool = False
) -> dict:
) -> Dict[str, Any]:
result_url = f"{self.base_url}/api/parsing/job/{job_id}/result/{result_type}"
status_url = f"{self.base_url}/api/parsing/job/{job_id}"
headers = {"Authorization": f"Bearer {self.api_key}"}
Expand Down Expand Up @@ -300,21 +316,16 @@ async def _get_job_result(

await asyncio.sleep(self.check_interval)

continue
else:
raise Exception(
f"Failed to parse the file: {job_id}, status: {status}"
)

async def _aload_data(
self,
file_path: FileInput,
extra_info: Optional[dict] = None,
fs: Optional[AbstractFileSystem] = None,
verbose: bool = False,
) -> List[Document]:
"""Load data from the input path."""
try:
job_id = await self._create_job(file_path, extra_info=extra_info)
job_id = await self._create_job(file_path, extra_info=extra_info, fs=fs)
if verbose:
print("Started parsing the file under job_id %s" % job_id)

Expand Down Expand Up @@ -345,17 +356,19 @@ async def aload_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path, bytes, BufferedIOBase)):
return await self._aload_data(
file_path, extra_info=extra_info, verbose=self.verbose
file_path, extra_info=extra_info, fs=fs, verbose=self.verbose
)
elif isinstance(file_path, list):
jobs = [
self._aload_data(
f,
extra_info=extra_info,
fs=fs,
verbose=self.verbose and not self.show_progress,
)
for f in file_path
Expand Down Expand Up @@ -384,10 +397,11 @@ def load_data(
self,
file_path: Union[List[FileInput], FileInput],
extra_info: Optional[dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
"""Load data from the input path."""
try:
return asyncio.run(self.aload_data(file_path, extra_info))
return asyncio.run(self.aload_data(file_path, extra_info, fs=fs))
except RuntimeError as e:
if nest_asyncio_err in str(e):
raise RuntimeError(nest_asyncio_msg)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "llama-parse"
version = "0.5.3"
version = "0.5.4"
description = "Parse files into RAG-Optimized formats."
authors = ["Logan Markewich <[email protected]>"]
license = "MIT"
Expand Down
16 changes: 16 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import pytest
from fsspec.implementations.local import LocalFileSystem
from httpx import AsyncClient

from llama_parse import LlamaParse


Expand Down Expand Up @@ -70,6 +72,20 @@ def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None:
assert len(result[0].text) > 0


@pytest.mark.skipif(
    not os.environ.get("LLAMA_CLOUD_API_KEY", ""),
    reason="LLAMA_CLOUD_API_KEY not set",
)
def test_simple_page_with_custom_fs() -> None:
    """Parse the sample PDF while handing the parser an explicit filesystem."""
    llama_parser = LlamaParse(result_type="markdown")
    local_fs = LocalFileSystem()
    # Resolve the fixture relative to this test module, not the working dir.
    tests_dir = os.path.dirname(__file__)
    pdf_path = os.path.join(tests_dir, "test_files/attention_is_all_you_need.pdf")
    documents = llama_parser.load_data(pdf_path, fs=local_fs)
    assert len(documents) == 1


@pytest.mark.skipif(
os.environ.get("LLAMA_CLOUD_API_KEY", "") == "",
reason="LLAMA_CLOUD_API_KEY not set",
Expand Down

0 comments on commit 7cb6d06

Please sign in to comment.