Skip to content

Commit

Permalink
immediate display transcribed text added (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishanExtreme authored Feb 20, 2024
1 parent 5df8789 commit fba55aa
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 4 deletions.
1 change: 1 addition & 0 deletions ayushma/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ChatMessage(BaseModel):
original_message = models.TextField(blank=True, null=True)
language = models.CharField(max_length=10, blank=False, default="en")
reference_documents = models.ManyToManyField(Document, blank=True)
# generated ayushma voice audio via TTS
audio = models.FileField(blank=True, null=True)
meta = models.JSONField(blank=True, null=True)
temperature = models.FloatField(blank=True, null=True)
Expand Down
15 changes: 12 additions & 3 deletions ayushma/serializers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ class ConverseSerializer(serializers.Serializer):
stream = serializers.BooleanField(default=True)
generate_audio = serializers.BooleanField(default=True)
noonce = serializers.CharField(required=False)
transcript_start_time = serializers.FloatField(required=False)
transcript_end_time = serializers.FloatField(required=False)


class ChatDetailSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -146,9 +148,11 @@ def get_chats(self, obj):
)
return [
{
"messageType": ChatMessageType.USER
if thread_message.role == "user"
else ChatMessageType.AYUSHMA,
"messageType": (
ChatMessageType.USER
if thread_message.role == "user"
else ChatMessageType.AYUSHMA
),
"message": thread_message.content[0].text.value,
"reference_documents": thread_message.content[0].text.annotations,
"language": "en",
Expand All @@ -159,3 +163,8 @@ def get_chats(self, obj):
chatmessages = ChatMessage.objects.filter(chat=obj).order_by("created_at")
context = {"request": self.context.get("request")}
return ChatMessageSerializer(chatmessages, many=True, context=context).data


class SpeechToTextSerializer(serializers.Serializer):
audio = serializers.FileField(required=True)
language = serializers.CharField(default="en")
6 changes: 6 additions & 0 deletions ayushma/utils/converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def converse_api(
audio = request.data.get("audio")
text = request.data.get("text")
language = request.data.get("language") or "en"

try:
service: Service = request.service
except AttributeError:
Expand Down Expand Up @@ -128,6 +129,11 @@ def converse_api(
translated_text = transcript

elif converse_type == "text":
if request.data.get("transcript_start_time") and request.data.get(
"transcript_end_time"
):
stats["transcript_start_time"] = request.data["transcript_start_time"]
stats["transcript_end_time"] = request.data["transcript_end_time"]
translated_text = text

if language != "en":
Expand Down
44 changes: 43 additions & 1 deletion ayushma/views/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import time

from django.conf import settings
from drf_spectacular.utils import extend_schema
from rest_framework import filters, status
from rest_framework.decorators import action
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.exceptions import ValidationError
from rest_framework.mixins import (
CreateModelMixin,
Expand All @@ -20,8 +22,10 @@
ChatFeedbackSerializer,
ChatSerializer,
ConverseSerializer,
SpeechToTextSerializer,
)
from ayushma.utils.converse import converse_api
from ayushma.utils.speech_to_text import speech_to_text
from utils.views.base import BaseModelViewSet
from utils.views.mixins import PartialUpdateModelMixin

Expand All @@ -42,6 +46,7 @@ class ChatViewSet(
"retrieve": ChatDetailSerializer,
"list_all": ChatDetailSerializer,
"converse": ConverseSerializer,
"speech_to_text": SpeechToTextSerializer,
}
permission_classes = (IsTempTokenOrAuthenticated,)
lookup_field = "external_id"
Expand Down Expand Up @@ -100,6 +105,43 @@ def list_all(self, *args, **kwarg):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

@extend_schema(
tags=("chats",),
)
@action(detail=True, methods=["post"])
def speech_to_text(self, *args, **kwarg):
serializer = self.get_serializer(data=self.request.data)
serializer.is_valid()

project_id = kwarg["project_external_id"]
audio = serializer.validated_data["audio"]
language = serializer.validated_data.get("language", "en")

stats = {}
try:
stt_engine = Project.objects.get(external_id=project_id).stt_engine
except Project.DoesNotExist:
return Response(
{"error": "Project not found"}, status=status.HTTP_400_BAD_REQUEST
)
try:
stats["transcript_start_time"] = time.time()
transcript = speech_to_text(stt_engine, audio, language + "-IN")
stats["transcript_end_time"] = time.time()
translated_text = transcript
except Exception as e:
print(f"Failed to transcribe speech with {stt_engine} engine: {e}")
return Response(
{
"error": "Something went wrong in getting transcription, please try again later"
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

return Response(
{"transcript": translated_text, "stats": stats}, status=status.HTTP_200_OK
)

@extend_schema(
tags=("chats",),
)
Expand Down

0 comments on commit fba55aa

Please sign in to comment.