From fba55aa2216d658855b20a396f41b0e0ab0c1c02 Mon Sep 17 00:00:00 2001
From: Ishan Mishra
Date: Tue, 20 Feb 2024 13:39:01 +0530
Subject: [PATCH] immediate display transcribed text added (#472)

---
 ayushma/models/chat.py      |  1 +
 ayushma/serializers/chat.py | 15 ++++++++++---
 ayushma/utils/converse.py   |  6 +++++
 ayushma/views/chat.py       | 44 ++++++++++++++++++++++++++++++++++++-
 4 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/ayushma/models/chat.py b/ayushma/models/chat.py
index c83d005c..425633ca 100644
--- a/ayushma/models/chat.py
+++ b/ayushma/models/chat.py
@@ -32,6 +32,7 @@ class ChatMessage(BaseModel):
     original_message = models.TextField(blank=True, null=True)
     language = models.CharField(max_length=10, blank=False, default="en")
     reference_documents = models.ManyToManyField(Document, blank=True)
+    # generated ayushma voice audio via TTS
     audio = models.FileField(blank=True, null=True)
     meta = models.JSONField(blank=True, null=True)
     temperature = models.FloatField(blank=True, null=True)
diff --git a/ayushma/serializers/chat.py b/ayushma/serializers/chat.py
index ec713884..e133e3cd 100644
--- a/ayushma/serializers/chat.py
+++ b/ayushma/serializers/chat.py
@@ -110,6 +110,8 @@ class ConverseSerializer(serializers.Serializer):
     stream = serializers.BooleanField(default=True)
     generate_audio = serializers.BooleanField(default=True)
     noonce = serializers.CharField(required=False)
+    transcript_start_time = serializers.FloatField(required=False)
+    transcript_end_time = serializers.FloatField(required=False)
 
 
 class ChatDetailSerializer(serializers.ModelSerializer):
@@ -146,9 +148,11 @@ def get_chats(self, obj):
             )
             return [
                 {
-                    "messageType": ChatMessageType.USER
-                    if thread_message.role == "user"
-                    else ChatMessageType.AYUSHMA,
+                    "messageType": (
+                        ChatMessageType.USER
+                        if thread_message.role == "user"
+                        else ChatMessageType.AYUSHMA
+                    ),
                     "message": thread_message.content[0].text.value,
                     "reference_documents": thread_message.content[0].text.annotations,
                     "language": "en",
@@ -159,3 +163,8 @@ def get_chats(self, obj):
         chatmessages = ChatMessage.objects.filter(chat=obj).order_by("created_at")
         context = {"request": self.context.get("request")}
         return ChatMessageSerializer(chatmessages, many=True, context=context).data
+
+
+class SpeechToTextSerializer(serializers.Serializer):
+    audio = serializers.FileField(required=True)
+    language = serializers.CharField(default="en")
diff --git a/ayushma/utils/converse.py b/ayushma/utils/converse.py
index 2dacb0b1..19ea4aca 100644
--- a/ayushma/utils/converse.py
+++ b/ayushma/utils/converse.py
@@ -33,6 +33,7 @@ def converse_api(
     audio = request.data.get("audio")
     text = request.data.get("text")
     language = request.data.get("language") or "en"
+
     try:
         service: Service = request.service
     except AttributeError:
@@ -128,6 +129,11 @@ def converse_api(
         translated_text = transcript
 
     elif converse_type == "text":
+        if request.data.get("transcript_start_time") and request.data.get(
+            "transcript_end_time"
+        ):
+            stats["transcript_start_time"] = request.data["transcript_start_time"]
+            stats["transcript_end_time"] = request.data["transcript_end_time"]
         translated_text = text
 
     if language != "en":
diff --git a/ayushma/views/chat.py b/ayushma/views/chat.py
index 7a1cceaf..c6f72163 100644
--- a/ayushma/views/chat.py
+++ b/ayushma/views/chat.py
@@ -1,7 +1,9 @@
+import time
+
 from django.conf import settings
 from drf_spectacular.utils import extend_schema
 from rest_framework import filters, status
-from rest_framework.decorators import action
+from rest_framework.decorators import action, api_view, permission_classes
 from rest_framework.exceptions import ValidationError
 from rest_framework.mixins import (
     CreateModelMixin,
@@ -20,8 +22,10 @@
     ChatFeedbackSerializer,
     ChatSerializer,
     ConverseSerializer,
+    SpeechToTextSerializer,
 )
 from ayushma.utils.converse import converse_api
+from ayushma.utils.speech_to_text import speech_to_text
 from utils.views.base import BaseModelViewSet
 from utils.views.mixins import PartialUpdateModelMixin
 
@@ -42,6 +46,7 @@ class ChatViewSet(
         "retrieve": ChatDetailSerializer,
         "list_all": ChatDetailSerializer,
         "converse": ConverseSerializer,
+        "speech_to_text": SpeechToTextSerializer,
     }
     permission_classes = (IsTempTokenOrAuthenticated,)
     lookup_field = "external_id"
@@ -100,6 +105,43 @@ def list_all(self, *args, **kwarg):
         serializer = self.get_serializer(queryset, many=True)
         return Response(serializer.data)
 
+    @extend_schema(
+        tags=("chats",),
+    )
+    @action(detail=True, methods=["post"])
+    def speech_to_text(self, *args, **kwarg):
+        serializer = self.get_serializer(data=self.request.data)
+        serializer.is_valid(raise_exception=True)  # invalid payload -> HTTP 400
+
+        project_id = kwarg["project_external_id"]
+        audio = serializer.validated_data["audio"]
+        language = serializer.validated_data.get("language", "en")
+
+        stats = {}
+        try:
+            stt_engine = Project.objects.get(external_id=project_id).stt_engine
+        except Project.DoesNotExist:
+            return Response(
+                {"error": "Project not found"}, status=status.HTTP_400_BAD_REQUEST
+            )
+        try:
+            stats["transcript_start_time"] = time.time()
+            transcript = speech_to_text(stt_engine, audio, language + "-IN")
+            stats["transcript_end_time"] = time.time()
+            translated_text = transcript
+        except Exception as e:
+            print(f"Failed to transcribe speech with {stt_engine} engine: {e}")
+            return Response(
+                {
+                    "error": "Something went wrong in getting transcription, please try again later"
+                },
+                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            )
+
+        return Response(
+            {"transcript": translated_text, "stats": stats}, status=status.HTTP_200_OK
+        )
+
     @extend_schema(
         tags=("chats",),
     )