Skip to content

Commit

Permalink
Merge pull request #465 from coronasafe/openai_tts_engine
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashesh3 authored Feb 11, 2024
2 parents 7172a54 + 6c18a39 commit 5df8789
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 21 deletions.
19 changes: 19 additions & 0 deletions ayushma/migrations/0051_project_tts_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Generated by Django 4.2.6 on 2024-02-11 15:23

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the ``tts_engine`` field to the ``project`` model."""

    # Applies on top of migration 0050 of the ``ayushma`` app.
    dependencies = [("ayushma", "0050_alter_chat_model_alter_project_model")]

    operations = [
        migrations.AddField(
            model_name="project",
            name="tts_engine",
            # 1 = openai, 2 = google; the column defaults to google (2).
            field=models.SmallIntegerField(
                default=2,
                choices=[(1, "openai"), (2, "google")],
            ),
        ),
    ]
5 changes: 5 additions & 0 deletions ayushma/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class STTEngine(IntegerChoices):
SELF_HOSTED = 3


class TTSEngine(IntegerChoices):
    """Text-to-speech engines that can be stored on a project."""

    # Each member is (stored integer value, label).
    OPENAI = 1, "openai"
    GOOGLE = 2, "google"


class FeedBackRating(IntegerChoices):
HALLUCINATING = 1
WRONG = 2
Expand Down
5 changes: 4 additions & 1 deletion ayushma/models/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models

from ayushma.models.enums import ModelType, STTEngine
from ayushma.models.enums import ModelType, STTEngine, TTSEngine
from ayushma.models.users import User
from utils.models.base import BaseModel

Expand All @@ -16,6 +16,9 @@ class Project(BaseModel):
stt_engine = models.IntegerField(
choices=STTEngine.choices, default=STTEngine.WHISPER
)
tts_engine = models.SmallIntegerField(
choices=TTSEngine.choices, default=TTSEngine.GOOGLE
)
model = models.IntegerField(choices=ModelType.choices, default=ModelType.GPT_3_5)
preset_questions = ArrayField(models.TextField(), null=True, blank=True)
is_default = models.BooleanField(default=False)
Expand Down
1 change: 1 addition & 0 deletions ayushma/serializers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Meta:
"modified_at",
"description",
"stt_engine",
"tts_engine",
"model",
"is_default",
"display_preset_questions",
Expand Down
54 changes: 35 additions & 19 deletions ayushma/utils/language_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import re

from django.conf import settings
from google.cloud import texttospeech
from google.cloud import translate_v2 as translate
from openai import OpenAI
from rest_framework.exceptions import APIException

from ayushma.models.enums import TTSEngine


def translate_text(target, text):
try:
Expand Down Expand Up @@ -37,31 +41,43 @@ def sanitize_text(text):
return sanitized_text


def _google_text_to_speech(text, language_code):
    """Return MP3 audio for ``text`` synthesized by Google Cloud Text-to-Speech."""
    # in en-IN neural voice is not available
    if language_code == "en-IN":
        language_code = "en-US"

    client = texttospeech.TextToSpeechClient()
    response = client.synthesize_speech(
        input=texttospeech.SynthesisInput(text=text),
        voice=texttospeech.VoiceSelectionParams(
            language_code=language_code,
            # Raises KeyError for a language with no configured voice; the
            # entry point treats that like any other synthesis failure.
            name=language_code_voice_map[language_code],
        ),
        audio_config=texttospeech.AudioConfig(
            audio_encoding=texttospeech.AudioEncoding.MP3
        ),
    )
    return response.audio_content


def _openai_text_to_speech(text):
    """Return the audio read from OpenAI's speech endpoint for ``text``."""
    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    response = client.audio.speech.create(
        model="tts-1-hd",
        voice="nova",
        input=text,
    )
    return response.read()


def text_to_speech(text, language_code, service=None):
    """Synthesize speech for ``text`` with the selected TTS engine.

    This is a best-effort helper: every failure is logged and reported to the
    caller as ``None`` instead of raising.

    Args:
        text: Text to speak; it is passed through ``sanitize_text`` first.
        language_code: Language code used to select the Google voice from
            ``language_code_voice_map`` (``"en-IN"`` is remapped to
            ``"en-US"``). The OpenAI engine does not use it.
        service: A ``TTSEngine`` value. ``None`` selects
            ``TTSEngine.GOOGLE``, the same default ``Project.tts_engine``
            declares, so callers written against the older two-argument
            signature keep working.

    Returns:
        The audio payload produced by the engine, or ``None`` when the text
        is empty after sanitizing, the engine is not supported, or synthesis
        fails.
    """
    # Function-scope import so this block does not rely on a module-level
    # ``import logging`` being present.
    import logging

    logger = logging.getLogger(__name__)

    if service is None:
        service = TTSEngine.GOOGLE

    try:
        text = sanitize_text(text)
        if not text:
            # Nothing to synthesize; skip the remote call entirely.
            return None
        if service == TTSEngine.GOOGLE:
            return _google_text_to_speech(text, language_code)
        if service == TTSEngine.OPENAI:
            return _openai_text_to_speech(text)
    except Exception:
        # Keep the best-effort contract (callers receive None), but record
        # the traceback instead of printing only the message.
        logger.exception("Text-to-speech failed (service=%r)", service)
        return None

    # Previously an APIException was raised here only to be swallowed by the
    # surrounding ``except``; report the problem directly instead.
    logger.error("Service not supported: %r", service)
    return None
9 changes: 8 additions & 1 deletion ayushma/utils/openaiapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def handle_post_response(
temperature,
stats,
language,
tts_engine,
generate_audio=True,
):
chat_message: ChatMessage = ChatMessage.objects.create(
Expand All @@ -225,7 +226,9 @@ def handle_post_response(
ayushma_voice = None
if generate_audio:
stats["tts_start_time"] = time.time()
ayushma_voice = text_to_speech(translated_chat_response, user_language)
ayushma_voice = text_to_speech(
translated_chat_response, user_language, tts_engine
)
stats["tts_end_time"] = time.time()

url = None
Expand Down Expand Up @@ -324,6 +327,8 @@ def converse(
elif message.messageType == ChatMessageType.AYUSHMA:
chat_history.append(AIMessage(content=f"Ayushma: {message.message}"))

tts_engine = chat.project.tts_engine

if not stream:
lang_chain_helper = LangChainHelper(
stream=False,
Expand All @@ -347,6 +352,7 @@ def converse(
temperature,
stats,
language,
tts_engine,
generate_audio,
)

Expand Down Expand Up @@ -404,6 +410,7 @@ def converse(
temperature,
stats,
language,
tts_engine,
generate_audio,
)

Expand Down
1 change: 1 addition & 0 deletions utils/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ def get_paginated_response(self, data):
"has_previous": self.offset > 0,
"has_next": self.offset + self.limit < self.count,
"results": data,
"offset": self.offset,
}
)

0 comments on commit 5df8789

Please sign in to comment.