Skip to content

Commit

Permalink
Set a valid voice for a persona (xi_labs voices)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnniePacheco committed Sep 13, 2024
1 parent a11cd7f commit f6504f2
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 20 deletions.
63 changes: 62 additions & 1 deletion apis/paios/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,29 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/ConversationCreate'
# Collection endpoint for voices: GET is routed to backend.api.VoicesFacesView.search
# and answers with a JSON array of Voice objects.
/voices:
get:
summary: Retrieve all the xi_labs voices
tags:
- Voice Management
description: Retrieve the information of all the xi_labs voices
operationId: backend.api.VoicesFacesView.search
parameters:
# Shared sort / range / filter parameters referenced from components/parameters.
- $ref: '#/components/parameters/sort'
- $ref: '#/components/parameters/range'
- $ref: '#/components/parameters/filter'
responses:
'200':
description: OK
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Voice'
headers:
# The response also carries the shared X-Total-Count header from components/headers.
X-Total-Count:
$ref: '#/components/headers/X-Total-Count'
tags:
- name: Abilities Management
description: Installation and configuration of abilities
Expand All @@ -829,6 +852,8 @@ tags:
description: Management of Messages
- name: Conversation Management
description: Management of conversations
- name: Voice Management
description: Management of voices
components:
headers:
X-Total-Count:
Expand Down Expand Up @@ -995,6 +1020,14 @@ components:
format: date-time
example: "2023-07-29T12:34:56Z"
pattern: ^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$
# String that must start with "/" and end with ".png" (enforced by the pattern below).
voice_image_url:
type: string
example: '/voice-icon.png'
pattern: ^\/.*\.png$
# String that must start with "/" and end with ".mp3" (enforced by the pattern below).
sample_mp3_url:
type: string
example: '/Laura.mp3'
pattern: ^\/.*\.mp3$
Ability:
type: object
title: Ability
Expand Down Expand Up @@ -1329,12 +1362,40 @@ components:
title: Voice
properties:
id:
$ref: '#/components/schemas/uuid4ReadOnly'
xi_id:
$ref: '#/components/schemas/textShort'
name:
$ref: '#/components/schemas/textShort'
text_to_speak:
type: string
nullable: true
image_url:
$ref: '#/components/schemas/voice_image_url'
nullable: true
sample_url:
$ref: '#/components/schemas/sample_mp3_url'
nullable: true
VoiceCreate:
  type: object
  title: VoiceCreate
  description: Voice without id which is server-generated.
  properties:
    xi_id:
      $ref: '#/components/schemas/textShort'
    name:
      $ref: '#/components/schemas/textShort'
    text_to_speak:
      type: string
      nullable: true
    # NOTE(review): OpenAPI 3.0 tooling ignores keys placed next to a $ref, so the
    # `nullable` flags on image_url / sample_url may have no effect — confirm
    # against the spec version and validator in use.
    image_url:
      $ref: '#/components/schemas/voice_image_url'
      nullable: true
    sample_url:
      $ref: '#/components/schemas/sample_mp3_url'
      nullable: true
  required:
    # `id` is not a property of this schema (it is server-generated), so it
    # must not be listed here; requiring it would reject every create request.
    - name
Face:
type: object
title: Face
Expand Down
9 changes: 6 additions & 3 deletions backend/api/PersonasView.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ async def get(self, id: str):
return JSONResponse(persona.dict(), status_code=200)

async def post(self, body: PersonaCreateSchema):
    """Create a persona after validating the submitted data.

    The body is first checked with ``validate_persona_data``, which returns
    ``None`` when the data is acceptable and an explanatory message otherwise.

    Returns:
        A 400 JSON response carrying the validation message when the data is
        rejected; otherwise a 201 JSON response with the stored persona and
        a ``Location`` header pointing at it.
    """
    valid_msg = await self.pm.validate_persona_data(body)
    # Guard clause: reject invalid input before anything is written.
    if valid_msg is not None:
        return JSONResponse({"error": " Invalid persona: " + valid_msg}, status_code=400)

    # create_persona yields the id of the new persona (not a voice id).
    persona_id = await self.pm.create_persona(body)
    persona = await self.pm.retrieve_persona(persona_id)
    return JSONResponse(
        persona.dict(),
        status_code=201,
        headers={'Location': f'{api_base_url}/personas/{persona.id}'},
    )

async def put(self, id: str, body: PersonaCreateSchema):
await self.pm.update_persona(id, body)
Expand Down
39 changes: 39 additions & 0 deletions backend/api/VoicesFacesView.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from starlette.responses import JSONResponse
from backend.managers.VoicesFacesManager import VoicesFacesManager
from common.paths import api_base_url
from backend.pagination import parse_pagination_params
from backend.schemas import VoiceCreateSchema



class VoicesFacesView:
    """HTTP handlers for voice resources; all work is delegated to VoicesFacesManager."""

    def __init__(self):
        self.vfm = VoicesFacesManager()

    # TODO: Finish text to speech
    async def post(self, id: str, body: VoiceCreateSchema):
        """Run text-to-speech for a voice.

        Expects ``text_to_speech`` to yield a ``(response, error_message)``
        pair. A truthy error message produces a 404 JSON error; otherwise the
        response payload is returned with status 200.
        """
        response, error_message = await self.vfm.text_to_speech(id, body)
        if error_message:
            return JSONResponse({"error": error_message}, status_code=404)
        return JSONResponse(response, status_code=200)

    async def search(self, filter: str = None, range: str = None, sort: str = None):
        """List voices with optional filtering, sorting and pagination.

        ``parse_pagination_params`` either returns a ready-made JSONResponse
        (passed straight back to the client) or a tuple of
        ``(offset, limit, sort_by, sort_order, filters)``.

        Returns:
            A 200 JSON array of voices with ``X-Total-Count`` and
            ``Content-Range`` headers.
        """
        result = parse_pagination_params(filter, range, sort)
        if isinstance(result, JSONResponse):
            return result

        offset, limit, sort_by, sort_order, filters = result

        voices, total_count = await self.vfm.retrieve_voices(
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
            filters=filters,
        )

        if voices:
            content_range = f'voices {offset}-{offset + len(voices) - 1}/{total_count}'
        else:
            # An empty page has no valid first-last pair; the original formula
            # produced e.g. "voices 0--1/0". Use the "*/total" form instead.
            content_range = f'voices */{total_count}'

        headers = {
            'X-Total-Count': str(total_count),
            'Content-Range': content_range,
        }
        return JSONResponse([voice.dict() for voice in voices], status_code=200, headers=headers)
3 changes: 2 additions & 1 deletion backend/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
from .PersonasView import PersonasView
from .RagIndexingView import RagIndexingView
from .MessagesView import MessagesView
from .ConversationsView import ConversationsView
from .ConversationsView import ConversationsView
from .VoicesFacesView import VoicesFacesView
5 changes: 1 addition & 4 deletions backend/managers/MessagesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
from backend.models import Message, Conversation, Resource
from backend.db import db_session_context
from backend.schemas import MessageSchema, MessageCreateSchema
from typing import List, Tuple, Optional, Dict, Any
from typing import Tuple, Optional
from backend.utils import get_current_timestamp
from backend.managers.RagManager import RagManager

from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
import os

class MessagesManager:
Expand Down
11 changes: 10 additions & 1 deletion backend/managers/PersonasManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from backend.db import db_session_context
from backend.schemas import PersonaSchema, PersonaCreateSchema
from typing import List, Tuple, Optional, Dict, Any
from backend.managers.VoicesFacesManager import VoicesFacesManager

class PersonasManager:
_instance = None
Expand Down Expand Up @@ -104,4 +105,12 @@ async def retrieve_personas(self, offset: int = 0, limit: int = 100, sort_by: Op
total_count = await session.execute(count_query)
total_count = total_count.scalar()

return personas, total_count
return personas, total_count

async def validate_persona_data(self, persona_data: PersonaCreateSchema) -> Optional[str]:
    """Check that the persona data references an existing voice.

    Returns:
        ``None`` when the data is valid, otherwise a message describing why
        it was rejected.
    """
    voices_manager = VoicesFacesManager()
    requested_voice_id = persona_data.get("voice_id")

    # A missing or empty voice_id is rejected before any lookup is made.
    if not requested_voice_id:
        return "It is mandatory to provide a voice_id for a persona"

    known_voice = await voices_manager.retrieve_voice(requested_voice_id)
    return None if known_voice else "Not a valid voice_id"
8 changes: 3 additions & 5 deletions backend/managers/RagManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_ollama import OllamaEmbeddings
from common.paths import chroma_db_path
from pathlib import Path
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
import shutil
Expand All @@ -16,13 +16,11 @@
from backend.db import db_session_context
from sqlalchemy import delete, select, func
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
from typing import List, Tuple, Optional, Dict, Any, Union
from backend.managers import ResourcesManager, PersonasManager
from distutils.util import strtobool
import os
import logging
from typing import Union
from langchain.prompts import PromptTemplate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,7 +51,7 @@ async def create_index(self, resource_id: str, path_files: List[str]) -> List[di

for path in path_files:
loader = PyPDFLoader(path)
docs = loader.load() # Return a list of documents for each file
docs = loader.load()
all_docs.append(docs)
file_id = str(uuid4())
all_ids.append(file_id)
Expand Down
150 changes: 150 additions & 0 deletions backend/managers/VoicesFacesManager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from threading import Lock
import os
import requests
from backend.schemas import VoiceCreateSchema
from backend.db import db_session_context
from uuid import uuid4
from backend.models import Voice
from typing import List, Tuple, Optional, Dict, Any
from sqlalchemy import select, func
from backend.schemas import VoiceSchema


class VoicesFacesManager:
    """Singleton manager for ElevenLabs ("xi") voices.

    Voices are persisted through the ``Voice`` model. When the table is empty,
    ``retrieve_voices`` seeds it from the ElevenLabs ``/v1/voices`` endpoint.

    NOTE(review): ``requests`` is synchronous, so the HTTP calls made from the
    async methods below block the event loop while they wait.
    """

    _instance = None
    _lock = Lock()

    # Columns that may be used for sorting and filtering.
    _VOICE_COLUMNS = ('id', 'xi_id', 'name', 'image_url', 'sample_url')
    # Seconds to wait for ElevenLabs; requests has no timeout by default.
    _XI_TIMEOUT = 30
    # Fallback when XI_CHUNK_SIZE is unset or not a positive integer.
    _DEFAULT_CHUNK_SIZE = 1024

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    # object.__new__ rejects extra arguments, so none are forwarded.
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every VoicesFacesManager() call; only mark once.
        if not hasattr(self, '_initialized'):
            with self._lock:
                if not hasattr(self, '_initialized'):
                    self._initialized = True

    @classmethod
    def _stream_chunk_size(cls) -> int:
        """Read XI_CHUNK_SIZE as a positive int, falling back to the default."""
        try:
            size = int(os.environ.get('XI_CHUNK_SIZE'))
        except (TypeError, ValueError):
            return cls._DEFAULT_CHUNK_SIZE
        return size if size > 0 else cls._DEFAULT_CHUNK_SIZE

    def _apply_filters(self, query, filters: Optional[Dict[str, Any]]):
        """Add the WHERE clauses described by *filters* to *query*.

        ``name`` is matched case-insensitively as a substring, list values
        become ``IN`` clauses and everything else is an equality test. Keys
        that are not known voice columns are ignored, mirroring how an unknown
        ``sort_by`` is ignored, instead of being passed to ``getattr``.
        """
        if not filters:
            return query
        for key, value in filters.items():
            if key not in self._VOICE_COLUMNS:
                continue
            column = getattr(Voice, key)
            if key == 'name':
                query = query.filter(column.ilike(f"%{value}%"))
            elif isinstance(value, list):
                query = query.filter(column.in_(value))
            else:
                query = query.filter(column == value)
        return query

    async def map_xi_to_voice(self):
        """Fetch the ElevenLabs voice list and map it to local voice dicts.

        Returns:
            A list of dicts with ``xi_id``, ``name``, ``image_url`` and
            ``sample_url`` keys, one per voice reported by ElevenLabs.

        Raises:
            requests.HTTPError: when ElevenLabs answers with an error status.
        """
        xi_api_key = os.environ.get('XI_API_KEY')
        xi_url = "https://api.elevenlabs.io/v1/voices"
        headers = {
            "Accept": "application/json",
            "xi-api-key": xi_api_key,
            "Content-Type": "application/json"
        }

        response = requests.get(xi_url, headers=headers, timeout=self._XI_TIMEOUT)
        # Fail with the real HTTP error rather than a KeyError on 'voices' below.
        response.raise_for_status()
        data = response.json()
        return [
            {
                'xi_id': voice['voice_id'],
                'name': voice['name'],
                'image_url': '/voice-icon.png',
                'sample_url': '/' + voice['name'] + '.mp3',
            }
            for voice in data['voices']
        ]

    async def create_voice(self, voice_data: "VoiceCreateSchema") -> str:
        """Persist a new voice and return its generated id.

        ``voice_data`` is indexed by key and must provide ``xi_id``, ``name``,
        ``image_url`` and ``sample_url``.
        """
        async with db_session_context() as session:
            new_voice = Voice(
                id=str(uuid4()),
                xi_id=voice_data['xi_id'],
                name=voice_data['name'],
                image_url=voice_data['image_url'],
                sample_url=voice_data['sample_url'],
            )
            session.add(new_voice)
            await session.commit()
            await session.refresh(new_voice)
            return new_voice.id

    async def retrieve_voice(self, voice_id: str) -> Optional["VoiceSchema"]:
        """Return the voice with the given id, or ``None`` when it does not exist."""
        async with db_session_context() as session:
            query = select(Voice).filter(Voice.id == voice_id)
            result = await session.execute(query)
            voice = result.scalar_one_or_none()
            if voice:
                return VoiceSchema.from_orm(voice)
            return None

    async def retrieve_voices(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
                              sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List["VoiceSchema"], int]:
        """Return one page of voices plus the total number of matching voices.

        If no voices are stored yet, the table is first seeded from
        ElevenLabs. ``sort_by`` is honoured only for known voice columns and
        ``sort_order`` of ``'desc'`` (any case) sorts descending.

        Returns:
            ``(voices, total_count)`` where ``total_count`` ignores
            offset/limit but respects ``filters``.
        """
        async with db_session_context() as session:
            # Seed on first use: count rows instead of loading the whole table.
            stored = await session.execute(select(func.count()).select_from(Voice))
            if not stored.scalar():
                for voice in await self.map_xi_to_voice():
                    await self.create_voice(voice)

            query = self._apply_filters(select(Voice), filters)

            if sort_by in self._VOICE_COLUMNS:
                order_column = getattr(Voice, sort_by)
                query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column)

            query = query.offset(offset).limit(limit)

            result = await session.execute(query)
            voices = [VoiceSchema.from_orm(voice) for voice in result.scalars().all()]

            # Get total count
            count_query = self._apply_filters(select(func.count()).select_from(Voice), filters)
            total_count = (await session.execute(count_query)).scalar()

            return voices, total_count

    async def text_to_speech(self, voice_id: str, body) -> Tuple[Optional[str], Optional[str]]:
        """Synthesize speech with ElevenLabs and save it as ``<xi_id>.mp3``.

        The text comes from ``body['text_to_speak']`` when present and
        non-empty, otherwise the placeholder ``"Hola"`` is used.

        Returns:
            ``(output_path, None)`` on success or ``(None, error_message)`` on
            failure. This is the pair shape that VoicesFacesView.post unpacks;
            the earlier lone-string return could not be unpacked by it.

        NOTE(review): the request URL uses ``voice_id`` while the file name
        uses ``body['xi_id']``; confirm which identifier ElevenLabs expects.
        """
        xi_id = str(body['xi_id'])
        # xi_id comes from the request body and becomes a file name: refuse
        # anything that is empty or contains a directory component.
        if not xi_id or os.path.basename(xi_id) != xi_id:
            return None, "Invalid xi_id"
        output_path = f"{xi_id}.mp3"  # Path to save the output audio file

        xi_api_key = os.environ.get('XI_API_KEY')
        tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}/stream"

        headers = {
            "Accept": "application/json",
            "xi-api-key": xi_api_key
        }

        data = {
            "text": body.get('text_to_speak') or "Hola",
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.8,
                "style": 0.0,
                "use_speaker_boost": True
            }
        }

        response = requests.post(tts_url, headers=headers, json=data, stream=True,
                                 timeout=self._XI_TIMEOUT)

        if not response.ok:
            print(response.text)
            return None, response.text

        # iter_content needs an int; the raw environment value is a string.
        chunk_size = self._stream_chunk_size()
        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
        print("Audio stream saved successfully.")
        return output_path, None


Loading

0 comments on commit f6504f2

Please sign in to comment.