chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from . import asr, tts, tts_v2, qwen_tts, qwen_tts_realtime, qwen_omni

__all__ = [asr, tts, tts_v2, qwen_tts, qwen_tts_realtime, qwen_omni]
Binary file not shown.
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .asr_phrase_manager import AsrPhraseManager
from .recognition import Recognition, RecognitionCallback, RecognitionResult
from .transcription import Transcription
from .translation_recognizer import (TranscriptionResult, Translation,
                                     TranslationRecognizerCallback,
                                     TranslationRecognizerChat,
                                     TranslationRecognizerRealtime,
                                     TranslationRecognizerResultPack,
                                     TranslationResult)
from .vocabulary import VocabularyService, VocabularyServiceException

__all__ = [
    'Transcription', 'Recognition', 'RecognitionCallback', 'RecognitionResult',
    'AsrPhraseManager', 'VocabularyServiceException', 'VocabularyService',
    'TranslationRecognizerRealtime', 'TranslationRecognizerChat',
    'TranslationRecognizerCallback', 'Translation', 'TranslationResult',
    'TranscriptionResult', 'TranslationRecognizerResultPack'
]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict
|
||||
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
from dashscope.client.base_api import BaseAsyncApi
|
||||
from dashscope.common.error import InvalidParameter
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.customize.finetunes import FineTunes
|
||||
|
||||
|
||||
class AsrPhraseManager(BaseAsyncApi):
|
||||
"""Hot word management for speech recognition.
|
||||
"""
|
||||
@classmethod
|
||||
def create_phrases(cls,
|
||||
model: str,
|
||||
phrases: Dict[str, Any],
|
||||
training_type: str = 'compile_asr_phrase',
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""Create hot words.
|
||||
|
||||
Args:
|
||||
model (str): The requested model.
|
||||
phrases (Dict[str, Any]): A dictionary that contains phrases,
|
||||
such as {'下一首':90,'上一首':90}.
|
||||
training_type (str, `optional`): The training type,
|
||||
'compile_asr_phrase' is default.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: Parameter input is None or empty!
|
||||
|
||||
Returns:
|
||||
DashScopeAPIResponse: The results of creating hot words.
|
||||
"""
|
||||
if phrases is None or len(phrases) == 0:
|
||||
raise InvalidParameter('phrases is empty!')
|
||||
if training_type is None or len(training_type) == 0:
|
||||
raise InvalidParameter('training_type is empty!')
|
||||
|
||||
original_ft_sub_path = FineTunes.SUB_PATH
|
||||
FineTunes.SUB_PATH = 'fine-tunes'
|
||||
response = FineTunes.call(model=model,
|
||||
training_file_ids=[],
|
||||
validation_file_ids=[],
|
||||
mode=training_type,
|
||||
hyper_parameters={'phrase_list': phrases},
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
FineTunes.SUB_PATH = original_ft_sub_path
|
||||
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
logger.error('Create phrase failed, ' + str(response))
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def update_phrases(cls,
|
||||
model: str,
|
||||
phrase_id: str,
|
||||
phrases: Dict[str, Any],
|
||||
training_type: str = 'compile_asr_phrase',
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""Update the hot words marked phrase_id.
|
||||
|
||||
Args:
|
||||
model (str): The requested model.
|
||||
phrase_id (str): The ID of phrases,
|
||||
which created by create_phrases().
|
||||
phrases (Dict[str, Any]): A dictionary that contains phrases,
|
||||
such as {'暂停':90}.
|
||||
training_type (str, `optional`):
|
||||
The training type, 'compile_asr_phrase' is default.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: Parameter input is None or empty!
|
||||
|
||||
Returns:
|
||||
DashScopeAPIResponse: The results of updating hot words.
|
||||
"""
|
||||
if phrase_id is None or len(phrase_id) == 0:
|
||||
raise InvalidParameter('phrase_id is empty!')
|
||||
if phrases is None or len(phrases) == 0:
|
||||
raise InvalidParameter('phrases is empty!')
|
||||
if training_type is None or len(training_type) == 0:
|
||||
raise InvalidParameter('training_type is empty!')
|
||||
|
||||
original_ft_sub_path = FineTunes.SUB_PATH
|
||||
FineTunes.SUB_PATH = 'fine-tunes'
|
||||
response = FineTunes.call(model=model,
|
||||
training_file_ids=[],
|
||||
validation_file_ids=[],
|
||||
mode=training_type,
|
||||
hyper_parameters={'phrase_list': phrases},
|
||||
finetuned_output=phrase_id,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
FineTunes.SUB_PATH = original_ft_sub_path
|
||||
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
logger.error('Update phrase failed, ' + str(response))
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def query_phrases(cls,
|
||||
phrase_id: str,
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""Query the hot words by phrase_id.
|
||||
|
||||
Args:
|
||||
phrase_id (str): The ID of phrases,
|
||||
which created by create_phrases().
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: phrase_id input is None or empty!
|
||||
|
||||
Returns:
|
||||
AsrPhraseManagerResult: The results of querying hot words.
|
||||
"""
|
||||
if phrase_id is None or len(phrase_id) == 0:
|
||||
raise InvalidParameter('phrase_id is empty!')
|
||||
|
||||
original_ft_sub_path = FineTunes.SUB_PATH
|
||||
FineTunes.SUB_PATH = 'fine-tunes/outputs'
|
||||
response = FineTunes.get(job_id=phrase_id,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
FineTunes.SUB_PATH = original_ft_sub_path
|
||||
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
logger.error('Query phrase failed, ' + str(response))
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def list_phrases(cls,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""List all information of phrases.
|
||||
|
||||
Args:
|
||||
page (int): Page number, greater than 0, default value 1.
|
||||
page_size (int): The paging size, greater than 0
|
||||
and less than or equal to 100, default value 10.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Returns:
|
||||
DashScopeAPIResponse: The results of listing hot words.
|
||||
"""
|
||||
original_ft_sub_path = FineTunes.SUB_PATH
|
||||
FineTunes.SUB_PATH = 'fine-tunes/outputs'
|
||||
response = FineTunes.list(page=page,
|
||||
page_size=page_size,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
FineTunes.SUB_PATH = original_ft_sub_path
|
||||
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
logger.error('List phrase failed, ' + str(response))
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def delete_phrases(cls,
|
||||
phrase_id: str,
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""Delete the hot words by phrase_id.
|
||||
|
||||
Args:
|
||||
phrase_id (str): The ID of phrases,
|
||||
which created by create_phrases().
|
||||
|
||||
Raises:
|
||||
InvalidParameter: phrase_id input is None or empty!
|
||||
|
||||
Returns:
|
||||
DashScopeAPIResponse: The results of deleting hot words.
|
||||
"""
|
||||
if phrase_id is None or len(phrase_id) == 0:
|
||||
raise InvalidParameter('phrase_id is empty!')
|
||||
|
||||
original_ft_sub_path = FineTunes.SUB_PATH
|
||||
FineTunes.SUB_PATH = 'fine-tunes/outputs'
|
||||
response = FineTunes.delete(job_id=phrase_id,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
FineTunes.SUB_PATH = original_ft_sub_path
|
||||
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
logger.error('Delete phrase failed, ' + str(response))
|
||||
|
||||
return response
|
||||
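Below is a minimal usage sketch for the AsrPhraseManager API added in this file. The import path, model id and phrase weights are assumptions for illustration and are not taken from this commit.

```python
# Hypothetical sketch: managing ASR hot words with AsrPhraseManager.
# 'paraformer-v1' and the weights below are placeholder values.
from dashscope.audio.asr import AsrPhraseManager

create_rsp = AsrPhraseManager.create_phrases(
    model='paraformer-v1',
    phrases={'下一首': 90, '上一首': 90})
print(create_rsp)  # the new phrase id is carried in create_rsp.output

phrase_id = 'your-phrase-id'  # placeholder; read it from the create response
AsrPhraseManager.update_phrases(model='paraformer-v1',
                                phrase_id=phrase_id,
                                phrases={'暂停': 90})
print(AsrPhraseManager.query_phrases(phrase_id))
print(AsrPhraseManager.list_phrases(page=1, page_size=10))
AsrPhraseManager.delete_phrases(phrase_id)
```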
@@ -0,0 +1,527 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from queue import Queue
|
||||
from threading import Timer
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from dashscope.api_entities.dashscope_response import RecognitionResponse
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.constants import ApiProtocol
|
||||
from dashscope.common.error import (InputDataRequired, InputRequired,
|
||||
InvalidParameter, InvalidTask,
|
||||
ModelRequired)
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.common.utils import _get_task_group_and_task
|
||||
from dashscope.protocol.websocket import WebsocketStreamingMode
|
||||
|
||||
|
||||
class RecognitionResult(RecognitionResponse):
|
||||
"""The result set of speech recognition, including the single-sentence
|
||||
recognition result returned by the callback mode, and all recognition
|
||||
results in a synchronized manner.
|
||||
"""
|
||||
def __init__(self,
|
||||
response: RecognitionResponse,
|
||||
sentences: List[Any] = None,
|
||||
usages: List[Any] = None):
|
||||
self.status_code = response.status_code
|
||||
self.request_id = response.request_id
|
||||
self.code = response.code
|
||||
self.message = response.message
|
||||
self.usages = usages
|
||||
if sentences is not None and len(sentences) > 0:
|
||||
self.output = {'sentence': sentences}
|
||||
else:
|
||||
self.output = response.output
|
||||
if self.usages is not None and len(
|
||||
self.usages) > 0 and 'usage' in self.usages[-1]:
|
||||
self.usage = self.usages[-1]['usage']
|
||||
else:
|
||||
self.usage = None
|
||||
|
||||
def __str__(self):
|
||||
return json.dumps(RecognitionResponse.from_api_response(self),
|
||||
ensure_ascii=False)
|
||||
|
||||
def get_sentence(self) -> Union[Dict[str, Any], List[Any]]:
|
||||
"""The result of speech recognition.
|
||||
"""
|
||||
if self.output and 'sentence' in self.output:
|
||||
return self.output['sentence']
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_request_id(self) -> str:
|
||||
"""The request_id of speech recognition.
|
||||
"""
|
||||
return self.request_id
|
||||
|
||||
def get_usage(self, sentence: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get billing for the input sentence.
|
||||
"""
|
||||
if self.usages is not None:
|
||||
if sentence is not None and 'end_time' in sentence and sentence[
|
||||
'end_time'] is not None:
|
||||
for usage in self.usages:
|
||||
if usage['end_time'] == sentence['end_time']:
|
||||
return usage['usage']
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_sentence_end(sentence: Dict[str, Any]) -> bool:
|
||||
"""Determine whether the speech recognition result is the end of a sentence.
|
||||
This is a static method.
|
||||
"""
|
||||
if sentence is not None and 'end_time' in sentence and sentence[
|
||||
'end_time'] is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class RecognitionCallback():
|
||||
"""An interface that defines callback methods for getting speech recognition results. # noqa E501
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_complete(self) -> None:
|
||||
pass
|
||||
|
||||
def on_error(self, result: RecognitionResult) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, result: RecognitionResult) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class Recognition(BaseApi):
|
||||
"""Speech recognition interface.
|
||||
|
||||
Args:
|
||||
model (str): The requested model_id.
|
||||
callback (RecognitionCallback): A callback that returns
|
||||
speech recognition results.
|
||||
format (str): The input audio format for speech recognition.
|
||||
sample_rate (int): The input audio sample rate for speech recognition.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
**kwargs:
|
||||
phrase_id (list, `optional`): The ID of phrase.
|
||||
disfluency_removal_enabled(bool, `optional`): Filter mood words,
|
||||
turned off by default.
|
||||
diarization_enabled (bool, `optional`): Speech auto diarization,
|
||||
turned off by default.
|
||||
speaker_count (int, `optional`): The number of speakers.
|
||||
timestamp_alignment_enabled (bool, `optional`): Timestamp-alignment
|
||||
calibration, turned off by default.
|
||||
special_word_filter(str, `optional`): Sensitive word filter.
|
||||
audio_event_detection_enabled(bool, `optional`):
|
||||
Audio event detection, turned off by default.
|
||||
|
||||
Raises:
|
||||
InputRequired: Input is required.
|
||||
"""
|
||||
|
||||
SILENCE_TIMEOUT_S = 23
|
||||
|
||||
def __init__(self,
|
||||
model: str,
|
||||
callback: RecognitionCallback,
|
||||
format: str,
|
||||
sample_rate: int,
|
||||
workspace: str = None,
|
||||
**kwargs):
|
||||
if model is None:
|
||||
raise ModelRequired('Model is required!')
|
||||
if format is None:
|
||||
raise InputRequired('format is required!')
|
||||
if sample_rate is None:
|
||||
raise InputRequired('sample_rate is required!')
|
||||
|
||||
self.model = model
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
# continuous recognition with start() or once recognition with call()
|
||||
self._recognition_once = False
|
||||
self._callback = callback
|
||||
self._running = False
|
||||
self._stream_data = Queue()
|
||||
self._worker = None
|
||||
self._silence_timer = None
|
||||
self._kwargs = kwargs
|
||||
self._workspace = workspace
|
||||
self._start_stream_timestamp = -1
|
||||
self._first_package_timestamp = -1
|
||||
self._stop_stream_timestamp = -1
|
||||
self._on_complete_timestamp = -1
|
||||
self.request_id_confirmed = False
|
||||
self.last_request_id = uuid.uuid4().hex
|
||||
|
||||
def __del__(self):
|
||||
if self._running:
|
||||
self._running = False
|
||||
self._stream_data = Queue()
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join()
|
||||
if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501
|
||||
):
|
||||
self._silence_timer.cancel()
|
||||
self._silence_timer = None
|
||||
if self._callback:
|
||||
self._callback.on_close()
|
||||
|
||||
def __receive_worker(self):
|
||||
"""Asynchronously, initiate a real-time speech recognition request and
|
||||
obtain the result for parsing.
|
||||
"""
|
||||
responses = self.__launch_request()
|
||||
for part in responses:
|
||||
if part.status_code == HTTPStatus.OK:
|
||||
if len(part.output) == 0 or ('finished' in part.output and part.output['finished'] == True):
|
||||
self._on_complete_timestamp = time.time() * 1000
|
||||
logger.debug('last package delay {}'.format(
|
||||
self.get_last_package_delay()))
|
||||
self._callback.on_complete()
|
||||
else:
|
||||
usage: Dict[str, Any] = None
|
||||
usages: List[Any] = None
|
||||
if 'sentence' in part.output:
|
||||
if (self._first_package_timestamp < 0):
|
||||
self._first_package_timestamp = time.time() * 1000
|
||||
logger.debug('first package delay {}'.format(
|
||||
self.get_first_package_delay()))
|
||||
sentence = part.output['sentence']
|
||||
if 'heartbeat' in sentence and sentence['heartbeat'] == True:
|
||||
logger.debug('recv heartbeat')
|
||||
continue
|
||||
logger.debug(
|
||||
'Recv Result [rid:{}]:{}, isEnd: {}'.format(
|
||||
part.request_id, sentence,
|
||||
RecognitionResult.is_sentence_end(sentence)))
|
||||
if part.usage is not None:
|
||||
usage = {
|
||||
'end_time':
|
||||
part.output['sentence']['end_time'],
|
||||
'usage': part.usage
|
||||
}
|
||||
usages = [usage]
|
||||
if self.request_id_confirmed is False and part.request_id is not None:
|
||||
self.last_request_id = part.request_id
|
||||
self.request_id_confirmed = True
|
||||
|
||||
self._callback.on_event(
|
||||
RecognitionResult(
|
||||
RecognitionResponse.from_api_response(part),
|
||||
usages=usages))
|
||||
else:
|
||||
self._running = False
|
||||
self._stream_data = Queue()
|
||||
self._callback.on_error(
|
||||
RecognitionResult(
|
||||
RecognitionResponse.from_api_response(part)))
|
||||
self._callback.on_close()
|
||||
break
|
||||
|
||||
def __launch_request(self):
|
||||
"""Initiate real-time speech recognition requests.
|
||||
"""
|
||||
resources_list: list = []
|
||||
if self._phrase is not None and len(self._phrase) > 0:
|
||||
item = {'resource_id': self._phrase, 'resource_type': 'asr_phrase'}
|
||||
resources_list.append(item)
|
||||
|
||||
if len(resources_list) > 0:
|
||||
self._kwargs['resources'] = resources_list
|
||||
|
||||
self._tidy_kwargs()
|
||||
task_name, _ = _get_task_group_and_task(__name__)
|
||||
responses = super().call(model=self.model,
|
||||
task_group='audio',
|
||||
task=task_name,
|
||||
function='recognition',
|
||||
input=self._input_stream_cycle(),
|
||||
api_protocol=ApiProtocol.WEBSOCKET,
|
||||
ws_stream_mode=WebsocketStreamingMode.DUPLEX,
|
||||
is_binary_input=True,
|
||||
sample_rate=self.sample_rate,
|
||||
format=self.format,
|
||||
stream=True,
|
||||
workspace=self._workspace,
|
||||
pre_task_id=self.last_request_id,
|
||||
**self._kwargs)
|
||||
return responses
|
||||
|
||||
def start(self, phrase_id: str = None, **kwargs):
|
||||
"""Real-time speech recognition in asynchronous mode.
|
||||
Please call 'stop()' after you have completed recognition.
|
||||
|
||||
Args:
|
||||
phrase_id (str, `optional`): The ID of phrase.
|
||||
|
||||
**kwargs:
|
||||
disfluency_removal_enabled(bool, `optional`):
|
||||
Filter mood words, turned off by default.
|
||||
diarization_enabled (bool, `optional`):
|
||||
Speech auto diarization, turned off by default.
|
||||
speaker_count (int, `optional`): The number of speakers.
|
||||
timestamp_alignment_enabled (bool, `optional`):
|
||||
Timestamp-alignment calibration, turned off by default.
|
||||
special_word_filter(str, `optional`): Sensitive word filter.
|
||||
audio_event_detection_enabled(bool, `optional`):
|
||||
Audio event detection, turned off by default.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: This interface cannot be called again
|
||||
if it has already been started.
|
||||
InvalidTask: Task create failed.
|
||||
"""
|
||||
assert self._callback is not None, 'Please set the callback to get the speech recognition result.' # noqa E501
|
||||
|
||||
if self._running:
|
||||
raise InvalidParameter('Speech recognition has started.')
|
||||
|
||||
self._start_stream_timestamp = -1
|
||||
self._first_package_timestamp = -1
|
||||
self._stop_stream_timestamp = -1
|
||||
self._on_complete_timestamp = -1
|
||||
self._phrase = phrase_id
|
||||
self._kwargs.update(**kwargs)
|
||||
self._recognition_once = False
|
||||
self._worker = threading.Thread(target=self.__receive_worker)
|
||||
self._worker.start()
|
||||
if self._worker.is_alive():
|
||||
self._running = True
|
||||
self._callback.on_open()
|
||||
|
||||
# If audio data is not received for 23 seconds, the timeout exits
|
||||
self._silence_timer = Timer(Recognition.SILENCE_TIMEOUT_S,
|
||||
self._silence_stop_timer)
|
||||
self._silence_timer.start()
|
||||
else:
|
||||
self._running = False
|
||||
raise InvalidTask('Invalid task, task create failed.')
|
||||
|
||||
def call(self,
|
||||
file: str,
|
||||
phrase_id: str = None,
|
||||
**kwargs) -> RecognitionResult:
|
||||
"""Real-time speech recognition in synchronous mode.
|
||||
|
||||
Args:
|
||||
file (str): The path to the local audio file.
|
||||
phrase_id (str, `optional`): The ID of phrase.
|
||||
|
||||
**kwargs:
|
||||
disfluency_removal_enabled(bool, `optional`):
|
||||
Filter mood words, turned off by default.
|
||||
diarization_enabled (bool, `optional`):
|
||||
Speech auto diarization, turned off by default.
|
||||
speaker_count (int, `optional`): The number of speakers.
|
||||
timestamp_alignment_enabled (bool, `optional`):
|
||||
Timestamp-alignment calibration, turned off by default.
|
||||
special_word_filter(str, `optional`): Sensitive word filter.
|
||||
audio_event_detection_enabled(bool, `optional`):
|
||||
Audio event detection, turned off by default.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: This interface cannot be called again
|
||||
if it has already been started.
|
||||
InputDataRequired: The supplied file was empty.
|
||||
|
||||
Returns:
|
||||
RecognitionResult: The result of speech recognition.
|
||||
"""
|
||||
self._start_stream_timestamp = time.time() * 1000
|
||||
if self._running:
|
||||
raise InvalidParameter('Speech recognition has been called.')
|
||||
|
||||
if os.path.exists(file):
|
||||
if os.path.isdir(file):
|
||||
raise IsADirectoryError('Is a directory: ' + file)
|
||||
else:
|
||||
raise FileNotFoundError('No such file or directory: ' + file)
|
||||
|
||||
self._recognition_once = True
|
||||
self._stream_data = Queue()
|
||||
self._phrase = phrase_id
|
||||
self._kwargs.update(**kwargs)
|
||||
error_flag: bool = False
|
||||
sentences: List[Any] = []
|
||||
usages: List[Any] = []
|
||||
response: RecognitionResponse = None
|
||||
result: RecognitionResult = None
|
||||
|
||||
try:
|
||||
audio_data: bytes = None
|
||||
f = open(file, 'rb')
|
||||
if os.path.getsize(file):
|
||||
while True:
|
||||
audio_data = f.read(12800)
|
||||
if not audio_data:
|
||||
break
|
||||
else:
|
||||
self._stream_data.put(audio_data)
|
||||
else:
|
||||
raise InputDataRequired(
|
||||
'The supplied file was empty (zero bytes long)')
|
||||
f.close()
|
||||
self._stop_stream_timestamp = time.time() * 1000
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise e
|
||||
|
||||
if not self._stream_data.empty():
|
||||
self._running = True
|
||||
responses = self.__launch_request()
|
||||
for part in responses:
|
||||
if part.status_code == HTTPStatus.OK:
|
||||
if 'sentence' in part.output:
|
||||
if (self._first_package_timestamp < 0):
|
||||
self._first_package_timestamp = time.time() * 1000
|
||||
logger.debug('first package delay {}'.format(
|
||||
self._first_package_timestamp -
|
||||
self._start_stream_timestamp))
|
||||
sentence = part.output['sentence']
|
||||
logger.debug(
|
||||
'Recv Result [rid:{}]:{}, isEnd: {}'.format(
|
||||
part.request_id, sentence,
|
||||
RecognitionResult.is_sentence_end(sentence)))
|
||||
if RecognitionResult.is_sentence_end(sentence):
|
||||
sentences.append(sentence)
|
||||
|
||||
if part.usage is not None:
|
||||
usage = {
|
||||
'end_time':
|
||||
part.output['sentence']['end_time'],
|
||||
'usage': part.usage
|
||||
}
|
||||
usages.append(usage)
|
||||
|
||||
response = RecognitionResponse.from_api_response(part)
|
||||
else:
|
||||
response = RecognitionResponse.from_api_response(part)
|
||||
logger.error(response)
|
||||
error_flag = True
|
||||
break
|
||||
|
||||
self._on_complete_timestamp = time.time() * 1000
|
||||
logger.debug('last package delay {}'.format(
|
||||
self.get_last_package_delay()))
|
||||
|
||||
if error_flag:
|
||||
result = RecognitionResult(response)
|
||||
else:
|
||||
result = RecognitionResult(response, sentences, usages)
|
||||
|
||||
self._stream_data = Queue()
|
||||
self._recognition_once = False
|
||||
self._running = False
|
||||
|
||||
return result
|
||||
|
||||
def stop(self):
|
||||
"""End asynchronous speech recognition.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: Cannot stop an uninitiated recognition.
|
||||
"""
|
||||
if self._running is False:
|
||||
raise InvalidParameter('Speech recognition has stopped.')
|
||||
|
||||
self._stop_stream_timestamp = time.time() * 1000
|
||||
|
||||
self._running = False
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join()
|
||||
self._stream_data = Queue()
|
||||
if self._silence_timer is not None and self._silence_timer.is_alive():
|
||||
self._silence_timer.cancel()
|
||||
self._silence_timer = None
|
||||
if self._callback:
|
||||
self._callback.on_close()
|
||||
|
||||
def send_audio_frame(self, buffer: bytes):
|
||||
"""Push speech recognition.
|
||||
|
||||
Raises:
|
||||
InvalidParameter: Cannot send data to an uninitiated recognition.
|
||||
"""
|
||||
if self._running is False:
|
||||
raise InvalidParameter('Speech recognition has stopped.')
|
||||
|
||||
if (self._start_stream_timestamp < 0):
|
||||
self._start_stream_timestamp = time.time() * 1000
|
||||
logger.debug('send_audio_frame: {}'.format(len(buffer)))
|
||||
self._stream_data.put(buffer)
|
||||
|
||||
def _tidy_kwargs(self):
|
||||
for k in self._kwargs.copy():
|
||||
if self._kwargs[k] is None:
|
||||
self._kwargs.pop(k, None)
|
||||
|
||||
def _input_stream_cycle(self):
|
||||
while self._running:
|
||||
while self._stream_data.empty():
|
||||
if self._running:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
# Reset silence_timer when getting stream.
|
||||
if self._silence_timer is not None and self._silence_timer.is_alive( # noqa E501
|
||||
):
|
||||
self._silence_timer.cancel()
|
||||
self._silence_timer = Timer(Recognition.SILENCE_TIMEOUT_S,
|
||||
self._silence_stop_timer)
|
||||
self._silence_timer.start()
|
||||
|
||||
while not self._stream_data.empty():
|
||||
frame = self._stream_data.get()
|
||||
yield bytes(frame)
|
||||
|
||||
if self._recognition_once:
|
||||
self._running = False
|
||||
|
||||
# drain all audio data when invoking stop().
|
||||
if self._recognition_once is False:
|
||||
while not self._stream_data.empty():
|
||||
frame = self._stream_data.get()
|
||||
yield bytes(frame)
|
||||
|
||||
def _silence_stop_timer(self):
|
||||
"""If audio data is not received for a long time, exit worker.
|
||||
"""
|
||||
self._running = False
|
||||
if self._silence_timer is not None and self._silence_timer.is_alive():
|
||||
self._silence_timer.cancel()
|
||||
self._silence_timer = None
|
||||
if self._worker is not None and self._worker.is_alive():
|
||||
self._worker.join()
|
||||
self._stream_data = Queue()
|
||||
|
||||
def get_first_package_delay(self):
|
||||
"""First Package Delay is the time between start sending audio and receive first words package
|
||||
"""
|
||||
return self._first_package_timestamp - self._start_stream_timestamp
|
||||
|
||||
def get_last_package_delay(self):
|
||||
"""Last Package Delay is the time between stop sending audio and receive last words package
|
||||
"""
|
||||
return self._on_complete_timestamp - self._stop_stream_timestamp
|
||||
|
||||
# Get the request id (task id) of the last task.
|
||||
def get_last_request_id(self):
|
||||
return self.last_request_id
|
||||
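A minimal usage sketch for the Recognition API above, covering both the synchronous call() path and the streaming start()/send_audio_frame()/stop() path. The model id, audio format and file path are placeholder assumptions.

```python
# Hypothetical sketch: real-time speech recognition with Recognition.
from dashscope.audio.asr import (Recognition, RecognitionCallback,
                                 RecognitionResult)

class PrintingCallback(RecognitionCallback):
    def on_event(self, result: RecognitionResult) -> None:
        sentence = result.get_sentence()
        if RecognitionResult.is_sentence_end(sentence):
            print('sentence:', sentence)

recognizer = Recognition(model='paraformer-realtime-v1',  # assumed model id
                         callback=PrintingCallback(),
                         format='pcm',
                         sample_rate=16000)

# Synchronous, file-based recognition:
result = recognizer.call('/path/to/audio.pcm')  # placeholder path
print(result.get_sentence())

# Streaming recognition (push PCM frames, then stop):
# recognizer.start()
# recognizer.send_audio_frame(pcm_frame_bytes)
# recognizer.stop()
```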
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from dashscope.api_entities.dashscope_response import (DashScopeAPIResponse,
|
||||
TranscriptionResponse)
|
||||
from dashscope.client.base_api import BaseAsyncApi
|
||||
from dashscope.common.constants import ApiProtocol, HTTPMethod
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.common.utils import _get_task_group_and_task
|
||||
|
||||
|
||||
class Transcription(BaseAsyncApi):
|
||||
"""API for File Transcription models.
|
||||
"""
|
||||
|
||||
MAX_QUERY_TRY_COUNT = 3
|
||||
|
||||
class Models:
|
||||
paraformer_v1 = 'paraformer-v1'
|
||||
paraformer_8k_v1 = 'paraformer-8k-v1'
|
||||
paraformer_mtl_v1 = 'paraformer-mtl-v1'
|
||||
|
||||
@classmethod
|
||||
def call(cls,
|
||||
model: str,
|
||||
file_urls: List[str],
|
||||
phrase_id: str = None,
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> TranscriptionResponse:
|
||||
"""Transcribe the given files synchronously.
|
||||
|
||||
Args:
|
||||
model (str): The requested model_id.
|
||||
file_urls (List[str]): List of stored URLs.
|
||||
phrase_id (str, `optional`): The ID of phrase.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
**kwargs:
|
||||
channel_id (List[int], optional):
|
||||
The selected channel_id of audio file.
|
||||
disfluency_removal_enabled(bool, `optional`):
|
||||
Filter mood words, turned off by default.
|
||||
diarization_enabled (bool, `optional`):
|
||||
Speech auto diarization, turned off by default.
|
||||
speaker_count (int, `optional`): The number of speakers.
|
||||
timestamp_alignment_enabled (bool, `optional`):
|
||||
Timestamp-alignment calibration, turned off by default.
|
||||
special_word_filter(str, `optional`): Sensitive word filter.
|
||||
audio_event_detection_enabled(bool, `optional`):
|
||||
Audio event detection, turned off by default.
|
||||
|
||||
Returns:
|
||||
TranscriptionResponse: The result of batch transcription.
|
||||
"""
|
||||
kwargs.update(cls._fill_resource_id(phrase_id, **kwargs))
|
||||
kwargs = cls._tidy_kwargs(**kwargs)
|
||||
response = super().call(model,
|
||||
file_urls,
|
||||
api_key=api_key,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
return TranscriptionResponse.from_api_response(response)
|
||||
|
||||
@classmethod
|
||||
def async_call(cls,
|
||||
model: str,
|
||||
file_urls: List[str],
|
||||
phrase_id: str = None,
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> TranscriptionResponse:
|
||||
"""Transcribe the given files asynchronously,
|
||||
return the status of task submission for querying results subsequently.
|
||||
|
||||
Args:
|
||||
model (str): The requested model, such as 'paraformer-v1'.
|
||||
file_urls (List[str]): List of stored URLs.
|
||||
phrase_id (str, `optional`): The ID of phrase.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
**kwargs:
|
||||
channel_id (List[int], optional):
|
||||
The selected channel_id of audio file.
|
||||
disfluency_removal_enabled(bool, `optional`):
|
||||
Filter mood words, turned off by default.
|
||||
diarization_enabled (bool, `optional`):
|
||||
Speech auto diarization, turned off by default.
|
||||
speaker_count (int, `optional`): The number of speakers.
|
||||
timestamp_alignment_enabled (bool, `optional`):
|
||||
Timestamp-alignment calibration, turned off by default.
|
||||
special_word_filter(str, `optional`): Sensitive word filter.
|
||||
audio_event_detection_enabled(bool, `optional`):
|
||||
Audio event detection, turned off by default.
|
||||
|
||||
Returns:
|
||||
TranscriptionResponse: The response including task_id.
|
||||
"""
|
||||
kwargs.update(cls._fill_resource_id(phrase_id, **kwargs))
|
||||
kwargs = cls._tidy_kwargs(**kwargs)
|
||||
response = cls._launch_request(model,
|
||||
file_urls,
|
||||
api_key=api_key,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
return TranscriptionResponse.from_api_response(response)
|
||||
|
||||
@classmethod
|
||||
def fetch(cls,
|
||||
task: Union[str, TranscriptionResponse],
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> TranscriptionResponse:
|
||||
"""Fetch the status of task, including results of batch transcription when task_status is SUCCEEDED. # noqa: E501
|
||||
|
||||
Args:
|
||||
task (Union[str, TranscriptionResponse]): The task_id or
|
||||
response including task_id returned from async_call().
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Returns:
|
||||
TranscriptionResponse: The status of task_id,
|
||||
including results of batch transcription when task_status is SUCCEEDED.
|
||||
"""
|
||||
try_count: int = 0
|
||||
while True:
|
||||
try:
|
||||
response = super().fetch(task,
|
||||
api_key=api_key,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
|
||||
logger.error(e)
|
||||
try_count += 1
|
||||
if try_count <= Transcription.MAX_QUERY_TRY_COUNT:
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
try_count = 0
|
||||
break
|
||||
|
||||
return TranscriptionResponse.from_api_response(response)
|
||||
|
||||
@classmethod
|
||||
def wait(cls,
|
||||
task: Union[str, TranscriptionResponse],
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> TranscriptionResponse:
|
||||
"""Poll task until the final results of transcription is obtained.
|
||||
|
||||
Args:
|
||||
task (Union[str, TranscriptionResponse]): The task_id or
|
||||
response including task_id returned from async_call().
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Returns:
|
||||
TranscriptionResponse: The result of batch transcription.
|
||||
"""
|
||||
response = super().wait(task,
|
||||
api_key=api_key,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
return TranscriptionResponse.from_api_response(response)
|
||||
|
||||
@classmethod
|
||||
def _launch_request(cls,
|
||||
model: str,
|
||||
files: List[str],
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> DashScopeAPIResponse:
|
||||
"""Submit transcribe request.
|
||||
|
||||
Args:
|
||||
model (str): The requested model, such as 'paraformer-v1'.
|
||||
files (List[str]): List of stored URLs.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
Returns:
|
||||
DashScopeAPIResponse: The result of task submission.
|
||||
"""
|
||||
task_name, function = _get_task_group_and_task(__name__)
|
||||
|
||||
try_count: int = 0
|
||||
while True:
|
||||
try:
|
||||
response = super().async_call(model=model,
|
||||
task_group='audio',
|
||||
task=task_name,
|
||||
function=function,
|
||||
input={'file_urls': files},
|
||||
api_protocol=ApiProtocol.HTTP,
|
||||
http_method=HTTPMethod.POST,
|
||||
api_key=api_key,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
|
||||
logger.error(e)
|
||||
try_count += 1
|
||||
if try_count <= Transcription.MAX_QUERY_TRY_COUNT:
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def _fill_resource_id(cls, phrase_id: str, **kwargs):
|
||||
resources_list: list = []
|
||||
if phrase_id is not None and len(phrase_id) > 0:
|
||||
item = {'resource_id': phrase_id, 'resource_type': 'asr_phrase'}
|
||||
resources_list.append(item)
|
||||
|
||||
if len(resources_list) > 0:
|
||||
kwargs['resources'] = resources_list
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def _tidy_kwargs(cls, **kwargs):
|
||||
for k in kwargs.copy():
|
||||
if kwargs[k] is None:
|
||||
kwargs.pop(k, None)
|
||||
return kwargs
|
||||
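A minimal usage sketch for the Transcription API above. The file URL is a placeholder; the model id comes from the Models class defined in this file.

```python
# Hypothetical sketch: batch file transcription with Transcription.
from dashscope.audio.asr import Transcription

task = Transcription.async_call(
    model=Transcription.Models.paraformer_v1,
    file_urls=['https://example.com/audio.wav'])  # placeholder URL

status = Transcription.fetch(task)   # poll the task status once
result = Transcription.wait(task)    # block until the transcription finishes
print(result.output)
```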
File diff suppressed because it is too large
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.constants import ApiProtocol, HTTPMethod
|
||||
from dashscope.common.logging import logger
|
||||
|
||||
|
||||
class VocabularyServiceException(Exception):
|
||||
def __init__(self, request_id: str, status_code: int, code: str,
|
||||
error_message: str) -> None:
|
||||
self._request_id = request_id
|
||||
self._status_code = status_code
|
||||
self._code = code
|
||||
self._error_message = error_message
|
||||
|
||||
def __str__(self):
|
||||
return f'Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}'
|
||||
|
||||
|
||||
class VocabularyService(BaseApi):
|
||||
'''
|
||||
API for asr vocabulary service
|
||||
'''
|
||||
MAX_QUERY_TRY_COUNT = 3
|
||||
|
||||
def __init__(self,
|
||||
api_key=None,
|
||||
workspace=None,
|
||||
model=None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self._api_key = api_key
|
||||
self._workspace = workspace
|
||||
self._kwargs = kwargs
|
||||
self._last_request_id = None
|
||||
self.model = model
|
||||
if self.model is None:
|
||||
self.model = 'speech-biasing'
|
||||
|
||||
def __call_with_input(self, input):
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
response = super().call(model=self.model,
|
||||
task_group='audio',
|
||||
task='asr',
|
||||
function='customization',
|
||||
input=input,
|
||||
api_protocol=ApiProtocol.HTTP,
|
||||
http_method=HTTPMethod.POST,
|
||||
api_key=self._api_key,
|
||||
workspace=self._workspace,
|
||||
**self._kwargs)
|
||||
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
|
||||
logger.error(e)
|
||||
try_count += 1
|
||||
if try_count <= VocabularyService.MAX_QUERY_TRY_COUNT:
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
break
|
||||
logger.debug('>>>>recv {}'.format(response))
|
||||
return response
|
||||
|
||||
def create_vocabulary(self, target_model: str, prefix: str,
|
||||
vocabulary: List[dict]) -> str:
|
||||
'''
Create a vocabulary (hot-word list).
param: target_model The speech recognition model version this vocabulary applies to.
param: prefix Custom prefix for the vocabulary; digits and lowercase letters only, fewer than ten characters.
param: vocabulary The vocabulary entries.
return: The vocabulary identifier (vocabulary_id).
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'create_vocabulary',
|
||||
'target_model': target_model,
|
||||
'prefix': prefix,
|
||||
'vocabulary': vocabulary,
|
||||
}, )
|
||||
if response.status_code == 200:
|
||||
self._last_request_id = response.request_id
|
||||
return response.output['vocabulary_id']
|
||||
else:
|
||||
raise VocabularyServiceException(response.request_id, response.status_code,
|
||||
response.code, response.message)
|
||||
|
||||
def list_vocabularies(self,
|
||||
prefix=None,
|
||||
page_index: int = 0,
|
||||
page_size: int = 10) -> List[dict]:
|
||||
'''
List all vocabularies that have been created.
param: prefix Custom prefix; if set, only vocabulary identifiers with this prefix are returned.
param: page_index Page index of the query.
param: page_size Page size of the query.
return: The list of vocabulary identifiers.
'''
|
||||
if prefix:
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'list_vocabulary',
|
||||
'prefix': prefix,
|
||||
'page_index': page_index,
|
||||
'page_size': page_size,
|
||||
}, )
|
||||
else:
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'list_vocabulary',
|
||||
'page_index': page_index,
|
||||
'page_size': page_size,
|
||||
}, )
|
||||
if response.status_code == 200:
|
||||
self._last_request_id = response.request_id
|
||||
return response.output['vocabulary_list']
|
||||
else:
|
||||
raise VocabularyServiceException(response.request_id, response.status_code,
|
||||
response.code, response.message)
|
||||
|
||||
def query_vocabulary(self, vocabulary_id: str) -> List[dict]:
|
||||
'''
Get the content of a vocabulary.
param: vocabulary_id The vocabulary identifier.
return: The vocabulary.
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'query_vocabulary',
|
||||
'vocabulary_id': vocabulary_id,
|
||||
}, )
|
||||
if response.status_code == 200:
|
||||
self._last_request_id = response.request_id
|
||||
return response.output
|
||||
else:
|
||||
raise VocabularyServiceException(response.request_id, response.status_code,
|
||||
response.code, response.message)
|
||||
|
||||
def update_vocabulary(self, vocabulary_id: str,
|
||||
vocabulary: List[dict]) -> None:
|
||||
'''
Replace an existing vocabulary with a new one.
param: vocabulary_id The identifier of the vocabulary to replace.
param: vocabulary The new vocabulary.
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'update_vocabulary',
|
||||
'vocabulary_id': vocabulary_id,
|
||||
'vocabulary': vocabulary,
|
||||
}, )
|
||||
if response.status_code == 200:
|
||||
self._last_request_id = response.request_id
|
||||
return
|
||||
else:
|
||||
raise VocabularyServiceException(response.request_id, response.status_code,
|
||||
response.code, response.message)
|
||||
|
||||
def delete_vocabulary(self, vocabulary_id: str) -> None:
|
||||
'''
Delete a vocabulary.
param: vocabulary_id The identifier of the vocabulary to delete.
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'delete_vocabulary',
|
||||
'vocabulary_id': vocabulary_id,
|
||||
}, )
|
||||
if response.status_code == 200:
|
||||
self._last_request_id = response.request_id
|
||||
return
|
||||
else:
|
||||
raise VocabularyServiceException(response.request_id, response.status_code,
|
||||
response.code, response.message)
|
||||
|
||||
def get_last_request_id(self):
|
||||
return self._last_request_id
|
||||
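A minimal usage sketch for the VocabularyService API above. The target model, prefix and the vocabulary entry schema are assumptions for illustration.

```python
# Hypothetical sketch: hot-word vocabulary management with VocabularyService.
from dashscope.audio.asr import VocabularyService, VocabularyServiceException

service = VocabularyService()
try:
    vocab_id = service.create_vocabulary(
        target_model='paraformer-realtime-v1',  # assumed model id
        prefix='demo01',
        vocabulary=[{'text': '下一首', 'weight': 4, 'lang': 'zh'}])  # entry schema assumed
    print('created vocabulary:', vocab_id)
    print(service.query_vocabulary(vocab_id))
    service.delete_vocabulary(vocab_id)
except VocabularyServiceException as e:
    print('vocabulary request failed:', e)
```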
@@ -0,0 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .omni_realtime import (AudioFormat, MultiModality, OmniRealtimeCallback,
                            OmniRealtimeConversation)

__all__ = [
    'OmniRealtimeCallback',
    'AudioFormat',
    'MultiModality',
    'OmniRealtimeConversation',
]
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,459 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import json
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import field, dataclass
|
||||
from typing import List
|
||||
import uuid
|
||||
from enum import Enum, unique
|
||||
|
||||
import dashscope
|
||||
import websocket
|
||||
from dashscope.common.error import InputRequired, ModelRequired
|
||||
from dashscope.common.logging import logger
|
||||
|
||||
|
||||
class OmniRealtimeCallback:
|
||||
"""
|
||||
An interface that defines callback methods for getting omni-realtime results. # noqa E501
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self, close_status_code, close_msg) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, message: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranslationParams:
|
||||
"""
|
||||
TranslationParams
|
||||
"""
|
||||
language: str = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionParams:
|
||||
"""
|
||||
TranscriptionParams
|
||||
"""
|
||||
language: str = field(default=None)
|
||||
sample_rate: int = field(default=16000)
|
||||
input_audio_format: str = field(default="pcm")
|
||||
corpus: dict = field(default=None)
|
||||
corpus_text: str = field(default=None)
|
||||
|
||||
|
||||
@unique
|
||||
class AudioFormat(Enum):
|
||||
# format, sample_rate, channels, bit_rate, name
|
||||
PCM_16000HZ_MONO_16BIT = ('pcm', 16000, 'mono', '16bit', 'pcm16')
|
||||
PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', '16bit', 'pcm16')
|
||||
|
||||
def __init__(self, format, sample_rate, channels, bit_rate, format_str):
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bit_rate = bit_rate
|
||||
self.format_str = format_str
|
||||
|
||||
def __repr__(self):
|
||||
return self.format_str
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}'
|
||||
|
||||
|
||||
class MultiModality(Enum):
|
||||
"""
|
||||
MultiModality
|
||||
"""
|
||||
TEXT = 'text'
|
||||
AUDIO = 'audio'
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class OmniRealtimeConversation:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
callback: OmniRealtimeCallback,
|
||||
headers=None,
|
||||
workspace=None,
|
||||
url=None,
|
||||
additional_params=None,
|
||||
):
|
||||
"""
|
||||
Qwen Omni Realtime SDK
|
||||
Parameters:
|
||||
-----------
|
||||
model: str
|
||||
Model name.
|
||||
headers: Dict
|
||||
User-defined headers.
|
||||
callback: OmniRealtimeCallback
|
||||
Callback to receive real-time omni results.
|
||||
workspace: str
|
||||
Dashscope workspace ID.
|
||||
url: str
|
||||
Dashscope WebSocket URL.
|
||||
additional_params: Dict
|
||||
Additional parameters for the Dashscope API.
|
||||
"""
|
||||
|
||||
if model is None:
|
||||
raise ModelRequired('Model is required!')
|
||||
if callback is None:
|
||||
raise ModelRequired('Callback is required!')
|
||||
if url is None:
|
||||
url = f'wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}'
|
||||
else:
|
||||
url = f'{url}?model={model}'
|
||||
self.url = url
|
||||
self.apikey = dashscope.api_key
|
||||
self.user_headers = headers
|
||||
self.user_workspace = workspace
|
||||
self.model = model
|
||||
self.config = {}
|
||||
self.callback = callback
|
||||
self.ws = None
|
||||
self.session_id = None
|
||||
self.last_message = None
|
||||
self.last_response_id = None
|
||||
self.last_response_create_time = None
|
||||
self.last_first_text_delay = None
|
||||
self.last_first_audio_delay = None
|
||||
self.metrics = []
|
||||
|
||||
def _generate_event_id(self):
|
||||
'''
|
||||
generate random event id: event_xxxx
|
||||
'''
|
||||
return 'event_' + uuid.uuid4().hex
|
||||
|
||||
def _get_websocket_header(self, ):
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
'1.18.0', # dashscope version
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
headers = {
|
||||
'user-agent': ua,
|
||||
'Authorization': 'bearer ' + self.apikey,
|
||||
}
|
||||
if self.user_headers:
|
||||
headers = {**self.user_headers, **headers}
|
||||
if self.user_workspace:
|
||||
headers = {
|
||||
**headers,
|
||||
'X-DashScope-WorkSpace': self.user_workspace,
|
||||
}
|
||||
return headers
|
||||
|
||||
def connect(self) -> None:
|
||||
'''
|
||||
connect to server, create session and return default session configuration
|
||||
'''
|
||||
self.ws = websocket.WebSocketApp(
|
||||
self.url,
|
||||
header=self._get_websocket_header(),
|
||||
on_message=self.on_message,
|
||||
on_error=self.on_error,
|
||||
on_close=self.on_close,
|
||||
)
|
||||
self.thread = threading.Thread(target=self.ws.run_forever)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
timeout = 5  # maximum wait time in seconds
|
||||
start_time = time.time()
|
||||
while (not (self.ws.sock and self.ws.sock.connected)
|
||||
and (time.time() - start_time) < timeout):
|
||||
time.sleep(0.1)  # sleep briefly to avoid busy polling
|
||||
if not (self.ws.sock and self.ws.sock.connected):
|
||||
raise TimeoutError(
|
||||
'websocket connection could not be established within 5s. '
|
||||
'Please check your network connection, firewall settings, or server status.'
|
||||
)
|
||||
self.callback.on_open()
|
||||
|
||||
def __send_str(self, data: str, enable_log: bool = True):
|
||||
if enable_log:
|
||||
logger.debug('[omni realtime] send string: {}'.format(data))
|
||||
self.ws.send(data)
|
||||
|
||||
def update_session(self,
|
||||
output_modalities: List[MultiModality],
|
||||
voice: str = None,
|
||||
input_audio_format: AudioFormat = AudioFormat.
|
||||
PCM_16000HZ_MONO_16BIT,
|
||||
output_audio_format: AudioFormat = AudioFormat.
|
||||
PCM_24000HZ_MONO_16BIT,
|
||||
enable_input_audio_transcription: bool = True,
|
||||
input_audio_transcription_model: str = None,
|
||||
enable_turn_detection: bool = True,
|
||||
turn_detection_type: str = 'server_vad',
|
||||
prefix_padding_ms: int = 300,
|
||||
turn_detection_threshold: float = 0.2,
|
||||
turn_detection_silence_duration_ms: int = 800,
|
||||
turn_detection_param: dict = None,
|
||||
translation_params: TranslationParams = None,
|
||||
transcription_params: TranscriptionParams = None,
|
||||
**kwargs) -> None:
|
||||
'''
|
||||
update session configuration, should be used before create response
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_modalities: list[MultiModality]
|
||||
omni output modalities to be used in session
|
||||
voice: str
|
||||
voice to be used in session
|
||||
input_audio_format: AudioFormat
|
||||
input audio format
|
||||
output_audio_format: AudioFormat
|
||||
output audio format
|
||||
enable_turn_detection: bool
|
||||
enable turn detection
|
||||
turn_detection_threshold: float
|
||||
turn detection threshold, range [-1, 1]
|
||||
In a noisy environment, it may be necessary to increase the threshold to reduce false detections
|
||||
In a quiet environment, it may be necessary to decrease the threshold to improve sensitivity
|
||||
turn_detection_silence_duration_ms: int
|
||||
duration of silence in milliseconds to detect turn, range [200, 6000]
|
||||
translation_params: TranslationParams
|
||||
translation params, include language. Only effective with qwen3-livetranslate-flash-realtime model or
|
||||
further models. Do not set this parameter for other models.
|
||||
transcription_params: TranscriptionParams
|
||||
transcription params, include language, sample_rate, input_audio_format, corpus.
|
||||
Only effective with qwen3-asr-flash-realtime model or
|
||||
further models. Do not set this parameter for other models.
|
||||
'''
|
||||
self.config = {
|
||||
'modalities': [m.value for m in output_modalities],
|
||||
'voice': voice,
|
||||
'input_audio_format': input_audio_format.format_str,
|
||||
'output_audio_format': output_audio_format.format_str,
|
||||
}
|
||||
if enable_input_audio_transcription:
|
||||
self.config['input_audio_transcription'] = {
|
||||
'model': input_audio_transcription_model,
|
||||
}
|
||||
else:
|
||||
self.config['input_audio_transcription'] = None
|
||||
if enable_turn_detection:
|
||||
self.config['turn_detection'] = {
|
||||
'type': turn_detection_type,
|
||||
'threshold': turn_detection_threshold,
|
||||
'prefix_padding_ms': prefix_padding_ms,
|
||||
'silence_duration_ms': turn_detection_silence_duration_ms,
|
||||
}
|
||||
if turn_detection_param is not None:
|
||||
self.config['turn_detection'].update(turn_detection_param)
|
||||
else:
|
||||
self.config['turn_detection'] = None
|
||||
if translation_params is not None:
|
||||
self.config['translation'] = {
|
||||
'language': translation_params.language
|
||||
}
|
||||
if transcription_params is not None:
|
||||
self.config['language'] = transcription_params.language
|
||||
if transcription_params.corpus is not None:
|
||||
self.config['corpus'] = transcription_params.corpus
|
||||
if transcription_params.corpus_text is not None:
|
||||
self.config['corpus'] = {
|
||||
"text": transcription_params.corpus_text
|
||||
}
|
||||
self.config['input_audio_format'] = transcription_params.input_audio_format
|
||||
self.config['sample_rate']= transcription_params.sample_rate
|
||||
self.config.update(kwargs)
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'session.update',
|
||||
'session': self.config
|
||||
}))
|
||||
|
||||
def append_audio(self, audio_b64: str) -> None:
|
||||
'''
|
||||
send audio in base64 format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_b64: str
|
||||
base64 audio string
|
||||
'''
|
||||
logger.debug('[omni realtime] append audio: {}'.format(len(audio_b64)))
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_audio_buffer.append',
|
||||
'audio': audio_b64
|
||||
}), False)
|
||||
|
||||
def append_video(self, video_b64: str) -> None:
|
||||
'''
|
||||
send one image frame in video in base64 format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
video_b64: str
|
||||
base64 image string
|
||||
'''
|
||||
logger.debug('[omni realtime] append video: {}'.format(len(video_b64)))
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_image_buffer.append',
|
||||
'image': video_b64
|
||||
}), False)
|
||||
|
||||
def commit(self, ) -> None:
|
||||
'''
|
||||
Commit the audio and video sent before.
|
||||
When in Server VAD mode, the client does not need to use this method,
|
||||
the server will commit the audio automatically after detecting vad end.
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_audio_buffer.commit'
|
||||
}))
|
||||
|
||||
def clear_appended_audio(self, ) -> None:
|
||||
'''
|
||||
clear the audio sent to server before.
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_audio_buffer.clear'
|
||||
}))
|
||||
|
||||
def create_response(self,
|
||||
instructions: str = None,
|
||||
output_modalities: List[MultiModality] = None) -> None:
|
||||
'''
|
||||
create a response, using the audio and video committed before to request the LLM.
|
||||
When in Server VAD mode, the client does not need to use this method,
|
||||
the server will create response automatically after detecting vad
|
||||
and sending commit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
instructions: str
|
||||
instructions to llm
|
||||
output_modalities: list[MultiModality]
|
||||
omni output modalities to be used in session
|
||||
'''
|
||||
request = {
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'response.create',
|
||||
'response': {}
|
||||
}
|
||||
request['response']['instructions'] = instructions
|
||||
if output_modalities:
|
||||
request['response']['modalities'] = [
|
||||
m.value for m in output_modalities
|
||||
]
|
||||
self.__send_str(json.dumps(request))
|
||||
|
||||
def cancel_response(self, ) -> None:
|
||||
'''
|
||||
cancel the current response
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'response.cancel'
|
||||
}))
|
||||
|
||||
def send_raw(self, raw_data: str) -> None:
|
||||
'''
|
||||
send raw data to server
|
||||
'''
|
||||
self.__send_str(raw_data)
|
||||
|
||||
def close(self, ) -> None:
|
||||
'''
|
||||
close the connection to server
|
||||
'''
|
||||
self.ws.close()
|
||||
|
||||
# Callback invoked for each incoming message.
|
||||
def on_message(self, ws, message):
|
||||
if isinstance(message, str):
|
||||
logger.debug('[omni realtime] receive string {}'.format(
|
||||
message[:1024]))
|
||||
try:
|
||||
# Try to parse the message as JSON.
|
||||
json_data = json.loads(message)
|
||||
self.last_message = json_data
|
||||
self.callback.on_event(json_data)
|
||||
if 'type' in message:
|
||||
if 'session.created' == json_data['type']:
|
||||
self.session_id = json_data['session']['id']
|
||||
if 'response.created' == json_data['type']:
|
||||
self.last_response_id = json_data['response']['id']
|
||||
self.last_response_create_time = time.time() * 1000
|
||||
self.last_first_audio_delay = None
|
||||
self.last_first_text_delay = None
|
||||
elif 'response.audio_transcript.delta' == json_data[
|
||||
'type']:
|
||||
if self.last_response_create_time and self.last_first_text_delay is None:
|
||||
self.last_first_text_delay = time.time(
|
||||
) * 1000 - self.last_response_create_time
|
||||
elif 'response.audio.delta' == json_data['type']:
|
||||
if self.last_response_create_time and self.last_first_audio_delay is None:
|
||||
self.last_first_audio_delay = time.time(
|
||||
) * 1000 - self.last_response_create_time
|
||||
elif 'response.done' == json_data['type']:
|
||||
logger.info(
|
||||
'[Metric] response: {}, first text delay: {}, first audio delay: {}'
|
||||
.format(self.last_response_id,
|
||||
self.last_first_text_delay,
|
||||
self.last_first_audio_delay))
|
||||
except json.JSONDecodeError:
|
||||
logger.error('Failed to parse message as JSON.')
|
||||
raise Exception('Failed to parse message as JSON.')
|
||||
elif isinstance(message, (bytes, bytearray)):
|
||||
# Otherwise the message is binary, which is not expected here.
|
||||
logger.error(
|
||||
'should not receive binary message in omni realtime api')
|
||||
logger.debug('[omni realtime] receive binary {} bytes'.format(
|
||||
len(message)))
|
||||
|
||||
def on_close(self, ws, close_status_code, close_msg):
|
||||
self.callback.on_close(close_status_code, close_msg)
|
||||
|
||||
# Callback invoked when the WebSocket raises an error.
|
||||
def on_error(self, ws, error):
|
||||
print(f'websocket closed due to {error}')
|
||||
raise Exception(f'websocket closed due to {error}')
|
||||
|
||||
# Get the session id of the current connection.
|
||||
def get_session_id(self) -> str:
|
||||
return self.session_id
|
||||
|
||||
def get_last_message(self) -> str:
|
||||
return self.last_message
|
||||
|
||||
|
||||
def get_last_response_id(self) -> str:
|
||||
return self.last_response_id
|
||||
|
||||
def get_last_first_text_delay(self):
|
||||
return self.last_first_text_delay
|
||||
|
||||
def get_last_first_audio_delay(self):
|
||||
return self.last_first_audio_delay
|
||||
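A minimal usage sketch for the OmniRealtimeConversation API above. The model id, voice and the audio payload are placeholder assumptions; turn detection is disabled so that commit() and create_response() are driven by the client.

```python
# Hypothetical sketch: a client-driven Omni realtime session.
import base64
import dashscope
from dashscope.audio.qwen_omni import (AudioFormat, MultiModality,
                                       OmniRealtimeCallback,
                                       OmniRealtimeConversation)

dashscope.api_key = 'your-api-key'  # placeholder

class PrintingCallback(OmniRealtimeCallback):
    def on_event(self, message) -> None:
        print('event:', message)

conversation = OmniRealtimeConversation(
    model='qwen-omni-turbo-realtime',  # assumed model id
    callback=PrintingCallback())
conversation.connect()
conversation.update_session(
    output_modalities=[MultiModality.TEXT, MultiModality.AUDIO],
    voice='Cherry',                                   # assumed voice name
    input_audio_format=AudioFormat.PCM_16000HZ_MONO_16BIT,
    output_audio_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
    enable_turn_detection=False)
conversation.append_audio(base64.b64encode(b'\x00' * 3200).decode())  # silence placeholder
conversation.commit()
conversation.create_response()
conversation.close()
```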
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .speech_synthesizer import SpeechSynthesizer

__all__ = [SpeechSynthesizer]
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Generator, Union
|
||||
|
||||
from dashscope.api_entities.dashscope_response import \
|
||||
TextToSpeechResponse
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.error import InputRequired, ModelRequired
|
||||
|
||||
|
||||
class SpeechSynthesizer(BaseApi):
|
||||
"""Text-to-speech interface.
|
||||
"""
|
||||
|
||||
task_group = 'aigc'
|
||||
task = 'multimodal-generation'
|
||||
function = 'generation'
|
||||
|
||||
class Models:
|
||||
qwen_tts = 'qwen-tts'
|
||||
|
||||
@classmethod
|
||||
def call(
|
||||
cls,
|
||||
model: str,
|
||||
text: str,
|
||||
api_key: str = None,
|
||||
workspace: str = None,
|
||||
**kwargs
|
||||
) -> Union[TextToSpeechResponse, Generator[
|
||||
TextToSpeechResponse, None, None]]:
|
||||
"""Call the conversation model service.
|
||||
|
||||
Args:
|
||||
model (str): The requested model, such as 'qwen-tts'
|
||||
text (str): Text content used for speech synthesis.
|
||||
api_key (str, optional): The API key, can be None,
    if None, it will be retrieved by rule [1].
    [1]: https://help.aliyun.com/zh/dashscope/developer-reference/api-key-settings. # noqa E501
|
||||
workspace (str): The dashscope workspace id.
|
||||
**kwargs:
|
||||
stream(bool, `optional`): Enable server-sent events
    (ref: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events), # noqa E501
    in which case the result is returned incrementally.
|
||||
voice: str
|
||||
Voice name.
|
||||
|
||||
Raises:
|
||||
InputRequired: The input must include the text parameter.
|
||||
ModelRequired: The input must include the model parameter.
|
||||
|
||||
Returns:
|
||||
Union[TextToSpeechResponse,
|
||||
Generator[TextToSpeechResponse, None, None]]: If
|
||||
stream is True, return Generator, otherwise TextToSpeechResponse.
|
||||
"""
|
||||
if not text:
|
||||
raise InputRequired('text is required!')
|
||||
if model is None or not model:
|
||||
raise ModelRequired('Model is required!')
|
||||
input = {'text': text}
|
||||
if 'voice' in kwargs:
|
||||
input['voice'] = kwargs.pop('voice')
|
||||
response = super().call(model=model,
|
||||
task_group=SpeechSynthesizer.task_group,
|
||||
task=SpeechSynthesizer.task,
|
||||
function=SpeechSynthesizer.function,
|
||||
api_key=api_key,
|
||||
input=input,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
is_stream = kwargs.get('stream', False)
|
||||
if is_stream:
|
||||
return (TextToSpeechResponse.from_api_response(rsp)
|
||||
for rsp in response)
|
||||
else:
|
||||
return TextToSpeechResponse.from_api_response(response)
|
||||
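For reference, a minimal usage sketch of the qwen-tts SpeechSynthesizer above. This is not part of the committed file: the import path, API key placeholder and voice name are assumptions to adjust for your environment, and the exact layout of the audio payload inside response.output should be checked against the service documentation.

import dashscope
from dashscope.audio.qwen_tts import SpeechSynthesizer  # assumed import path

dashscope.api_key = 'your-api-key'  # placeholder

response = SpeechSynthesizer.call(
    model='qwen-tts',
    text='Hello, this is a synthesis test.',
    voice='Cherry',  # assumed voice name, check the service documentation
)
# assuming the standard DashScope response fields (status_code, output, message)
if response.status_code == 200:
    print(response.output)  # inspect the returned audio payload
else:
    print('synthesis failed:', response.message)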
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .qwen_tts_realtime import (AudioFormat, QwenTtsRealtimeCallback,
|
||||
QwenTtsRealtime)
|
||||
|
||||
__all__ = [
|
||||
'AudioFormat',
|
||||
'QwenTtsRealtimeCallback',
|
||||
'QwenTtsRealtime',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,350 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import json
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum, unique
|
||||
|
||||
import dashscope
|
||||
import websocket
|
||||
from dashscope.common.error import InputRequired, ModelRequired
|
||||
from dashscope.common.logging import logger
|
||||
|
||||
|
||||
class QwenTtsRealtimeCallback:
|
||||
"""
|
||||
An interface that defines callback methods for getting qwen-tts realtime results. # noqa E501
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self, close_status_code, close_msg) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, message: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@unique
|
||||
class AudioFormat(Enum):
|
||||
# format, sample_rate, channels, bit_rate, name
|
||||
PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', '16bit', 'pcm16')
|
||||
|
||||
def __init__(self, format, sample_rate, channels, bit_rate, format_str):
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bit_rate = bit_rate
|
||||
self.format_str = format_str
|
||||
|
||||
def __repr__(self):
|
||||
return self.format_str
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate} bit rate: {self.format_str}'
|
||||
|
||||
|
||||
class QwenTtsRealtime:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
headers=None,
|
||||
callback: QwenTtsRealtimeCallback = None,
|
||||
workspace=None,
|
||||
url=None,
|
||||
additional_params=None,
|
||||
):
|
||||
"""
|
||||
Qwen Tts Realtime SDK
|
||||
Parameters:
|
||||
-----------
|
||||
model: str
|
||||
Model name.
|
||||
headers: Dict
|
||||
User-defined headers.
|
||||
callback: QwenTtsRealtimeCallback
    Callback to receive real-time synthesis results.
|
||||
workspace: str
|
||||
Dashscope workspace ID.
|
||||
url: str
|
||||
Dashscope WebSocket URL.
|
||||
additional_params: Dict
|
||||
Additional parameters for the Dashscope API.
|
||||
"""
|
||||
|
||||
if model is None:
|
||||
raise ModelRequired('Model is required!')
|
||||
if url is None:
|
||||
url = f'wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model={model}'
|
||||
else:
|
||||
url = f'{url}?model={model}'
|
||||
self.url = url
|
||||
self.apikey = dashscope.api_key
|
||||
self.user_headers = headers
|
||||
self.user_workspace = workspace
|
||||
self.model = model
|
||||
self.config = {}
|
||||
self.callback = callback
|
||||
self.ws = None
|
||||
self.session_id = None
|
||||
self.last_message = None
|
||||
self.last_response_id = None
|
||||
self.last_first_text_time = None
|
||||
self.last_first_audio_delay = None
|
||||
self.metrics = []
|
||||
|
||||
def _generate_event_id(self):
|
||||
'''
|
||||
generate random event id: event_xxxx
|
||||
'''
|
||||
return 'event_' + uuid.uuid4().hex
|
||||
|
||||
def _get_websocket_header(self, ):
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
'1.18.0', # dashscope version
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
headers = {
|
||||
'user-agent': ua,
|
||||
'Authorization': 'bearer ' + self.apikey,
|
||||
}
|
||||
if self.user_headers:
|
||||
headers = {**self.user_headers, **headers}
|
||||
if self.user_workspace:
|
||||
headers = {
|
||||
**headers,
|
||||
'X-DashScope-WorkSpace': self.user_workspace,
|
||||
}
|
||||
return headers
|
||||
|
||||
def connect(self) -> None:
|
||||
'''
|
||||
connect to server, create session and return default session configuration
|
||||
'''
|
||||
self.ws = websocket.WebSocketApp(
|
||||
self.url,
|
||||
header=self._get_websocket_header(),
|
||||
on_message=self.on_message,
|
||||
on_error=self.on_error,
|
||||
on_close=self.on_close,
|
||||
)
|
||||
self.thread = threading.Thread(target=self.ws.run_forever)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
timeout = 5  # maximum time to wait for the connection, in seconds
start_time = time.time()
while (not (self.ws.sock and self.ws.sock.connected)
       and (time.time() - start_time) < timeout):
    time.sleep(0.1)  # short sleep to avoid busy polling
if not (self.ws.sock and self.ws.sock.connected):
    raise TimeoutError(
        'websocket connection could not be established within 5s. '
        'Please check your network connection, firewall settings, or server status.'
    )
|
||||
self.callback.on_open()
|
||||
|
||||
def __send_str(self, data: str, enable_log: bool = True):
|
||||
if enable_log:
|
||||
logger.debug('[qwen tts realtime] send string: {}'.format(data))
|
||||
self.ws.send(data)
|
||||
|
||||
def update_session(self,
|
||||
voice: str,
|
||||
response_format: AudioFormat = AudioFormat.
|
||||
PCM_24000HZ_MONO_16BIT,
|
||||
mode: str = 'server_commit',
|
||||
sample_rate: int = None,
|
||||
volume: int = None,
|
||||
speech_rate: float = None,
|
||||
audio_format: str = None,
|
||||
pitch_rate: float = None,
|
||||
bit_rate: int = None,
|
||||
language_type: str = None,
|
||||
**kwargs) -> None:
|
||||
'''
|
||||
update the session configuration; should be called before creating a response
|
||||
|
||||
Parameters
|
||||
----------
|
||||
voice: str
|
||||
voice to be used in session
|
||||
response_format: AudioFormat
|
||||
output audio format
|
||||
mode: str
|
||||
response mode, server_commit or commit
|
||||
language_type: str
|
||||
language type for synthesized audio, default is 'auto'
|
||||
sample_rate: int
|
||||
sample rate for tts, one of [8000, 16000, 22050, 24000, 44100, 48000], default is 24000
|
||||
volume: int
|
||||
volume for tts, range [0,100] default is 50
|
||||
speech_rate: float
|
||||
speech_rate for tts, range [0.5~2.0] default is 1.0
|
||||
audio_format: str
|
||||
audio format for tts, supports mp3, wav, pcm, opus; default is 'pcm'
|
||||
pitch_rate: float
|
||||
pitch_rate for tts, range [0.5~2.0] default is 1.0
|
||||
bit_rate: int
|
||||
bit_rate for tts in kbps, range 6~510, default is 128. Only applies to the opus/mp3 formats.
|
||||
'''
|
||||
self.config = {
|
||||
'voice': voice,
|
||||
'mode': mode,
|
||||
'response_format': response_format.format,
|
||||
'sample_rate': response_format.sample_rate,
|
||||
}
|
||||
if sample_rate is not None:  # override the default if explicitly configured
|
||||
self.config['sample_rate'] = sample_rate
|
||||
if volume is not None:
|
||||
self.config['volume'] = volume
|
||||
if speech_rate is not None:
|
||||
self.config['speech_rate'] = speech_rate
|
||||
if audio_format is not None:
|
||||
self.config['response_format'] = audio_format  # override the default if explicitly configured
|
||||
if pitch_rate is not None:
|
||||
self.config['pitch_rate'] = pitch_rate
|
||||
if bit_rate is not None:
|
||||
self.config['bit_rate'] = bit_rate
|
||||
|
||||
if language_type is not None:
|
||||
self.config['language_type'] = language_type
|
||||
self.config.update(kwargs)
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'session.update',
|
||||
'session': self.config
|
||||
}))
|
||||
|
||||
def append_text(self, text: str) -> None:
|
||||
'''
|
||||
send text
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
text to send
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_text_buffer.append',
|
||||
'text': text
|
||||
}))
|
||||
if self.last_first_text_time is None:
|
||||
self.last_first_text_time = time.time() * 1000
|
||||
|
||||
def commit(self, ) -> None:
|
||||
'''
|
||||
commit the text sent before, create response and start synthesis audio.
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_text_buffer.commit'
|
||||
}))
|
||||
|
||||
def clear_appended_text(self, ) -> None:
|
||||
'''
|
||||
clear the text sent to server before.
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'input_text_buffer.clear'
|
||||
}))
|
||||
|
||||
def cancel_response(self, ) -> None:
|
||||
'''
|
||||
cancel the current response
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'response.cancel'
|
||||
}))
|
||||
|
||||
def send_raw(self, raw_data: str) -> None:
|
||||
'''
|
||||
send raw data to server
|
||||
'''
|
||||
self.__send_str(raw_data)
|
||||
|
||||
def finish(self, ) -> None:
|
||||
'''
|
||||
finish input text stream, server will synthesis all text in buffer and close the connection
|
||||
'''
|
||||
self.__send_str(
|
||||
json.dumps({
|
||||
'event_id': self._generate_event_id(),
|
||||
'type': 'session.finish'
|
||||
}))
|
||||
|
||||
def close(self, ) -> None:
|
||||
'''
|
||||
close the connection to server
|
||||
'''
|
||||
self.ws.close()
|
||||
|
||||
# Callback invoked for every message received from the server.
def on_message(self, ws, message):
    if isinstance(message, str):
        logger.debug('[qwen tts realtime] receive string {}'.format(
            message[:1024]))
        try:
            # Parse the text message as JSON.
            json_data = json.loads(message)
            self.last_message = json_data
            self.callback.on_event(json_data)
            if 'type' in json_data:
                if 'session.created' == json_data['type']:
                    self.session_id = json_data['session']['id']
                if 'response.created' == json_data['type']:
                    self.last_response_id = json_data['response']['id']
                elif 'response.audio.delta' == json_data['type']:
                    if self.last_first_text_time and self.last_first_audio_delay is None:
                        self.last_first_audio_delay = time.time(
                        ) * 1000 - self.last_first_text_time
                elif 'response.done' == json_data['type']:
                    logger.debug(
                        '[Metric] response: {}, first audio delay: {}'
                        .format(self.last_response_id,
                                self.last_first_audio_delay))
        except json.JSONDecodeError:
            logger.error('Failed to parse message as JSON.')
            raise Exception('Failed to parse message as JSON.')
    elif isinstance(message, (bytes, bytearray)):
        # Binary frames are not expected from the qwen tts realtime API.
        logger.error(
            'should not receive binary message in qwen tts realtime api')
        logger.debug('[qwen tts realtime] receive binary {} bytes'.format(
            len(message)))
def on_close(self, ws, close_status_code, close_msg):
    logger.debug(
        '[qwen tts realtime] connection closed with code {} and message {}'.format(
            close_status_code, close_msg))
    self.callback.on_close(close_status_code, close_msg)
# Callback invoked when the WebSocket raises an error.
def on_error(self, ws, error):
    logger.error(f'websocket closed due to {error}')
    raise Exception(f'websocket closed due to {error}')
# Return the id of the current session.
|
||||
def get_session_id(self):
|
||||
return self.session_id
|
||||
|
||||
def get_last_message(self):
|
||||
return self.last_message
|
||||
|
||||
def get_last_response_id(self):
|
||||
return self.last_response_id
|
||||
|
||||
def get_first_audio_delay(self):
|
||||
return self.last_first_audio_delay
|
||||
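A minimal sketch of driving QwenTtsRealtime end to end, not part of the committed file. The import path, model name and voice name are assumptions; adjust them to match your account and the service documentation.

import dashscope
from dashscope.audio.qwen_tts_realtime import (AudioFormat, QwenTtsRealtime,
                                               QwenTtsRealtimeCallback)  # assumed import path

dashscope.api_key = 'your-api-key'  # placeholder

class MyCallback(QwenTtsRealtimeCallback):
    def on_event(self, message) -> None:
        # message is the parsed JSON event dispatched by on_message above
        print('event type:', message.get('type'))

tts = QwenTtsRealtime(model='qwen-tts-realtime', callback=MyCallback())  # assumed model name
tts.connect()
tts.update_session(voice='Cherry',  # assumed voice name
                   response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
                   mode='server_commit')
for chunk in ['Hello, ', 'how is the weather today?']:
    tts.append_text(chunk)
tts.finish()  # synthesize the remaining buffer; the server closes the session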
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .speech_synthesizer import (ResultCallback, SpeechSynthesisResult,
|
||||
SpeechSynthesizer)
|
||||
|
||||
__all__ = [SpeechSynthesizer, ResultCallback, SpeechSynthesisResult]
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
|
||||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.constants import ApiProtocol
|
||||
from dashscope.common.utils import _get_task_group_and_task
|
||||
|
||||
|
||||
class SpeechSynthesisResult():
|
||||
"""The result set of speech synthesis, including audio data,
|
||||
timestamp information, and final result information.
|
||||
"""
|
||||
|
||||
_audio_frame: bytes = None
|
||||
_audio_data: bytes = None
|
||||
_sentence: Dict[str, str] = None
|
||||
_sentences: List[Dict[str, str]] = None
|
||||
_response: SpeechSynthesisResponse = None
|
||||
|
||||
def __init__(self, frame: bytes, data: bytes, sentence: Dict[str, str],
|
||||
sentences: List[Dict[str, str]],
|
||||
response: SpeechSynthesisResponse):
|
||||
if frame is not None:
|
||||
self._audio_frame = bytes(frame)
|
||||
if data is not None:
|
||||
self._audio_data = bytes(data)
|
||||
if sentence is not None:
|
||||
self._sentence = sentence
|
||||
if sentences is not None:
|
||||
self._sentences = sentences
|
||||
if response is not None:
|
||||
self._response = response
|
||||
|
||||
def get_audio_frame(self) -> bytes:
|
||||
"""Obtain the audio frame data of speech synthesis through callbacks.
|
||||
"""
|
||||
return self._audio_frame
|
||||
|
||||
def get_audio_data(self) -> bytes:
|
||||
"""Get complete audio data for speech synthesis.
|
||||
"""
|
||||
return self._audio_data
|
||||
|
||||
def get_timestamp(self) -> Dict[str, str]:
|
||||
"""Obtain the timestamp information of the current speech synthesis
|
||||
sentence through the callback.
|
||||
"""
|
||||
return self._sentence
|
||||
|
||||
def get_timestamps(self) -> List[Dict[str, str]]:
|
||||
"""Get complete timestamp information for all speech synthesis sentences.
|
||||
"""
|
||||
return self._sentences
|
||||
|
||||
def get_response(self) -> SpeechSynthesisResponse:
|
||||
"""Obtain the status information of the current speech synthesis task,
|
||||
including error information and billing information.
|
||||
"""
|
||||
return self._response
|
||||
|
||||
|
||||
class ResultCallback():
|
||||
"""
|
||||
An interface that defines callback methods for getting speech synthesis results. # noqa E501
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_complete(self) -> None:
|
||||
pass
|
||||
|
||||
def on_error(self, response: SpeechSynthesisResponse) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, result: SpeechSynthesisResult) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SpeechSynthesizer(BaseApi):
|
||||
"""Text-to-speech interface.
|
||||
"""
|
||||
class AudioFormat:
|
||||
format_wav = 'wav'
|
||||
format_pcm = 'pcm'
|
||||
format_mp3 = 'mp3'
|
||||
|
||||
@classmethod
|
||||
def call(cls,
|
||||
model: str,
|
||||
text: str,
|
||||
callback: ResultCallback = None,
|
||||
workspace: str = None,
|
||||
**kwargs) -> SpeechSynthesisResult:
|
||||
"""Convert text to speech synchronously.
|
||||
|
||||
Args:
|
||||
model (str): The requested model_id.
|
||||
text (str): Text content used for speech synthesis.
|
||||
callback (ResultCallback): A callback that returns
|
||||
speech synthesis results.
|
||||
workspace (str): The dashscope workspace id.
|
||||
**kwargs:
|
||||
format(str, `optional`): Audio encoding format,
|
||||
such as pcm wav mp3, default is wav.
|
||||
sample_rate(int, `optional`): Audio sample rate,
|
||||
default is the sample rate of model.
|
||||
volume(int, `optional`): The volume of synthesized speech
|
||||
ranges from 0~100, default is 50.
|
||||
rate(float, `optional`): The speech rate of synthesized
|
||||
speech, the value range is 0.5~2.0, default is 1.0.
|
||||
pitch(float, `optional`): The intonation of synthesized
|
||||
speech,the value range is 0.5~2.0, default is 1.0.
|
||||
word_timestamp_enabled(bool, `optional`): Turn on word-level
|
||||
timestamping, default is False.
|
||||
phoneme_timestamp_enabled(bool, `optional`): Turn on phoneme
|
||||
level timestamping, default is False.
|
||||
|
||||
Returns:
|
||||
SpeechSynthesisResult: The result of synthesis.
|
||||
"""
|
||||
_callback = callback
|
||||
_audio_data: bytes = None
|
||||
_sentences: List[Dict[str, str]] = []
|
||||
_the_final_response = None
|
||||
_task_failed_flag: bool = False
|
||||
task_name, _ = _get_task_group_and_task(__name__)
|
||||
|
||||
response = super().call(model=model,
|
||||
task_group='audio',
|
||||
task=task_name,
|
||||
function='SpeechSynthesizer',
|
||||
input={'text': text},
|
||||
stream=True,
|
||||
api_protocol=ApiProtocol.WEBSOCKET,
|
||||
workspace=workspace,
|
||||
**kwargs)
|
||||
|
||||
if _callback is not None:
|
||||
_callback.on_open()
|
||||
|
||||
for part in response:
|
||||
if isinstance(part.output, bytes):
|
||||
if _callback is not None:
|
||||
audio_frame = SpeechSynthesisResult(
|
||||
bytes(part.output), None, None, None, None)
|
||||
_callback.on_event(audio_frame)
|
||||
|
||||
if _audio_data is None:
|
||||
_audio_data = bytes(part.output)
|
||||
else:
|
||||
_audio_data = _audio_data + bytes(part.output)
|
||||
|
||||
else:
|
||||
if part.status_code == HTTPStatus.OK:
|
||||
if part.output is None:
|
||||
_the_final_response = SpeechSynthesisResponse.from_api_response( # noqa E501
|
||||
part)
|
||||
else:
|
||||
if _callback is not None:
|
||||
sentence = SpeechSynthesisResult(
|
||||
None, None, part.output['sentence'], None,
|
||||
None)
|
||||
_callback.on_event(sentence)
|
||||
if len(_sentences) == 0:
|
||||
_sentences.append(part.output['sentence'])
|
||||
else:
|
||||
if _sentences[-1]['begin_time'] == part.output[
|
||||
'sentence']['begin_time']:
|
||||
if _sentences[-1]['end_time'] != part.output[
|
||||
'sentence']['end_time']:
|
||||
_sentences.pop()
|
||||
_sentences.append(part.output['sentence'])
|
||||
else:
|
||||
_sentences.append(part.output['sentence'])
|
||||
else:
|
||||
_task_failed_flag = True
|
||||
_the_final_response = SpeechSynthesisResponse.from_api_response( # noqa E501
|
||||
part)
|
||||
if _callback is not None:
|
||||
_callback.on_error(
|
||||
SpeechSynthesisResponse.from_api_response(part))
|
||||
|
||||
if _callback is not None:
|
||||
if _task_failed_flag is False:
|
||||
_callback.on_complete()
|
||||
_callback.on_close()
|
||||
|
||||
result = SpeechSynthesisResult(None, _audio_data, None, _sentences,
|
||||
_the_final_response)
|
||||
return result
|
||||
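A minimal sketch of the blocking call path above, not part of the committed file: without a callback, call() returns only after all audio has been received. The import path and the sambert model name are assumptions.

import dashscope
from dashscope.audio.tts import SpeechSynthesizer  # assumed import path

dashscope.api_key = 'your-api-key'  # placeholder

result = SpeechSynthesizer.call(model='sambert-zhichu-v1',  # assumed model name
                                text='Hello, this is a synthesis test.',
                                sample_rate=48000,
                                format='wav')
if result.get_audio_data() is not None:
    with open('output.wav', 'wb') as f:
        f.write(result.get_audio_data())
else:
    print('synthesis failed:', result.get_response())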
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .enrollment import VoiceEnrollmentException, VoiceEnrollmentService
|
||||
from .speech_synthesizer import AudioFormat, ResultCallback, SpeechSynthesizer
|
||||
|
||||
__all__ = [
|
||||
'SpeechSynthesizer', 'ResultCallback', 'AudioFormat',
|
||||
'VoiceEnrollmentException', 'VoiceEnrollmentService'
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.constants import ApiProtocol, HTTPMethod
|
||||
from dashscope.common.logging import logger
|
||||
|
||||
|
||||
class VoiceEnrollmentException(Exception):
|
||||
def __init__(self, request_id: str, status_code: int, code: str,
|
||||
error_message: str) -> None:
|
||||
self._request_id = request_id
|
||||
self._status_code = status_code
|
||||
self._code = code
|
||||
self._error_message = error_message
|
||||
|
||||
def __str__(self):
|
||||
return f'Request: {self._request_id}, Status Code: {self._status_code}, Code: {self._code}, Error Message: {self._error_message}'
|
||||
|
||||
|
||||
class VoiceEnrollmentService(BaseApi):
|
||||
'''
|
||||
API for voice clone service
|
||||
'''
|
||||
MAX_QUERY_TRY_COUNT = 3
|
||||
|
||||
def __init__(self,
|
||||
api_key=None,
|
||||
workspace=None,
|
||||
model=None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self._api_key = api_key
|
||||
self._workspace = workspace
|
||||
self._kwargs = kwargs
|
||||
self._last_request_id = None
|
||||
self.model = model
|
||||
if self.model is None:
|
||||
self.model = 'voice-enrollment'
|
||||
|
||||
def __call_with_input(self, input):
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
response = super().call(model=self.model,
|
||||
task_group='audio',
|
||||
task='tts',
|
||||
function='customization',
|
||||
input=input,
|
||||
api_protocol=ApiProtocol.HTTP,
|
||||
http_method=HTTPMethod.POST,
|
||||
api_key=self._api_key,
|
||||
workspace=self._workspace,
|
||||
**self._kwargs)
|
||||
except (asyncio.TimeoutError, aiohttp.ClientConnectorError) as e:
|
||||
logger.error(e)
|
||||
try_count += 1
|
||||
if try_count <= VoiceEnrollmentService.MAX_QUERY_TRY_COUNT:
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
break
|
||||
logger.debug('>>>>recv {}'.format(response))
|
||||
return response
|
||||
|
||||
def create_voice(self, target_model: str, prefix: str, url: str, language_hints: List[str] = None) -> str:
|
||||
'''
Create a new cloned voice.
param: target_model the speech synthesis model version that the cloned voice targets
param: prefix custom voice prefix; only digits and lowercase letters, fewer than ten characters
param: url url of the audio file used for cloning
param: language_hints target languages of the cloned voice
return: voice_id
'''
|
||||
|
||||
input_params = {
|
||||
'action': 'create_voice',
|
||||
'target_model': target_model,
|
||||
'prefix': prefix,
|
||||
'url': url
|
||||
}
|
||||
if language_hints is not None:
|
||||
input_params['language_hints'] = language_hints
|
||||
response = self.__call_with_input(input_params)
|
||||
self._last_request_id = response.request_id
|
||||
if response.status_code == 200:
|
||||
return response.output['voice_id']
|
||||
else:
|
||||
raise VoiceEnrollmentException(response.request_id, response.status_code, response.code,
|
||||
response.message)
|
||||
|
||||
def list_voices(self,
|
||||
prefix=None,
|
||||
page_index: int = 0,
|
||||
page_size: int = 10) -> List[dict]:
|
||||
'''
List all created voices.
param: prefix optional filter on the voice prefix
param: page_index index of the page to query
param: page_size size of each page
return: List[dict] list of voices, each with its id, creation time, modification time and status.
'''
|
||||
if prefix:
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'list_voice',
|
||||
'prefix': prefix,
|
||||
'page_index': page_index,
|
||||
'page_size': page_size,
|
||||
}, )
|
||||
else:
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'list_voice',
|
||||
'page_index': page_index,
|
||||
'page_size': page_size,
|
||||
}, )
|
||||
self._last_request_id = response.request_id
|
||||
if response.status_code == 200:
|
||||
return response.output['voice_list']
|
||||
else:
|
||||
raise VoiceEnrollmentException(response.request_id, response.status_code, response.code,
|
||||
response.message)
|
||||
|
||||
def query_voice(self, voice_id: str) -> List[str]:
|
||||
'''
Query the details of a created voice.
param: voice_id the voice to query
return: the voice information returned by the service
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'query_voice',
|
||||
'voice_id': voice_id,
|
||||
}, )
|
||||
self._last_request_id = response.request_id
|
||||
if response.status_code == 200:
|
||||
return response.output
|
||||
else:
|
||||
raise VoiceEnrollmentException(response.request_id, response.status_code, response.code,
|
||||
response.message)
|
||||
|
||||
def update_voice(self, voice_id: str, url: str) -> None:
|
||||
'''
Update a voice.
param: voice_id the voice id
param: url url of the audio file used for cloning
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'update_voice',
|
||||
'voice_id': voice_id,
|
||||
'url': url,
|
||||
}, )
|
||||
self._last_request_id = response.request_id
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
raise VoiceEnrollmentException(response.request_id, response.status_code, response.code,
|
||||
response.message)
|
||||
|
||||
def delete_voice(self, voice_id: str) -> None:
|
||||
'''
Delete a voice.
param: voice_id the voice to delete
'''
|
||||
response = self.__call_with_input(input={
|
||||
'action': 'delete_voice',
|
||||
'voice_id': voice_id,
|
||||
}, )
|
||||
self._last_request_id = response.request_id
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
raise VoiceEnrollmentException(response.request_id, response.status_code, response.code,
|
||||
response.message)
|
||||
|
||||
def get_last_request_id(self):
|
||||
return self._last_request_id
|
||||
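A minimal sketch of enrolling a cloned voice with VoiceEnrollmentService and listing it afterwards, not part of the committed file. The import path, target_model, prefix and sample URL are assumptions.

import dashscope
from dashscope.audio.tts_v2 import (VoiceEnrollmentException,
                                    VoiceEnrollmentService)  # assumed import path

dashscope.api_key = 'your-api-key'  # placeholder

service = VoiceEnrollmentService()
try:
    voice_id = service.create_voice(target_model='cosyvoice-v1',  # assumed model name
                                    prefix='demo1',
                                    url='https://example.com/sample.wav')  # placeholder URL
    print('enrolled voice:', voice_id)
    print('voices with this prefix:', service.list_voices(prefix='demo1'))
except VoiceEnrollmentException as e:
    print('enrollment failed:', e)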
@@ -0,0 +1,597 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import json
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum, unique
|
||||
|
||||
import websocket
|
||||
|
||||
import dashscope
|
||||
from dashscope.common.error import InputRequired, InvalidTask, ModelRequired
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.protocol.websocket import (ACTION_KEY, EVENT_KEY, HEADER,
|
||||
TASK_ID, ActionType, EventType,
|
||||
WebsocketStreamingMode)
|
||||
|
||||
|
||||
class ResultCallback:
|
||||
"""
|
||||
An interface that defines callback methods for getting speech synthesis results. # noqa E501
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_complete(self) -> None:
|
||||
pass
|
||||
|
||||
def on_error(self, message) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, message: str) -> None:
|
||||
pass
|
||||
|
||||
def on_data(self, data: bytes) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@unique
|
||||
class AudioFormat(Enum):
|
||||
DEFAULT = ('Default', 0, '0', 0)
|
||||
WAV_8000HZ_MONO_16BIT = ('wav', 8000, 'mono', 0)
|
||||
WAV_16000HZ_MONO_16BIT = ('wav', 16000, 'mono', 16)
|
||||
WAV_22050HZ_MONO_16BIT = ('wav', 22050, 'mono', 16)
|
||||
WAV_24000HZ_MONO_16BIT = ('wav', 24000, 'mono', 16)
|
||||
WAV_44100HZ_MONO_16BIT = ('wav', 44100, 'mono', 16)
|
||||
WAV_48000HZ_MONO_16BIT = ('wav', 48000, 'mono', 16)
|
||||
|
||||
MP3_8000HZ_MONO_128KBPS = ('mp3', 8000, 'mono', 128)
|
||||
MP3_16000HZ_MONO_128KBPS = ('mp3', 16000, 'mono', 128)
|
||||
MP3_22050HZ_MONO_256KBPS = ('mp3', 22050, 'mono', 256)
|
||||
MP3_24000HZ_MONO_256KBPS = ('mp3', 24000, 'mono', 256)
|
||||
MP3_44100HZ_MONO_256KBPS = ('mp3', 44100, 'mono', 256)
|
||||
MP3_48000HZ_MONO_256KBPS = ('mp3', 48000, 'mono', 256)
|
||||
|
||||
PCM_8000HZ_MONO_16BIT = ('pcm', 8000, 'mono', 16)
|
||||
PCM_16000HZ_MONO_16BIT = ('pcm', 16000, 'mono', 16)
|
||||
PCM_22050HZ_MONO_16BIT = ('pcm', 22050, 'mono', 16)
|
||||
PCM_24000HZ_MONO_16BIT = ('pcm', 24000, 'mono', 16)
|
||||
PCM_44100HZ_MONO_16BIT = ('pcm', 44100, 'mono', 16)
|
||||
PCM_48000HZ_MONO_16BIT = ('pcm', 48000, 'mono', 16)
|
||||
|
||||
OGG_OPUS_8KHZ_MONO_32KBPS = ("opus", 8000, "mono", 32)
|
||||
OGG_OPUS_8KHZ_MONO_16KBPS = ("opus", 8000, "mono", 16)
|
||||
OGG_OPUS_16KHZ_MONO_16KBPS = ("opus", 16000, "mono", 16)
|
||||
OGG_OPUS_16KHZ_MONO_32KBPS = ("opus", 16000, "mono", 32)
|
||||
OGG_OPUS_16KHZ_MONO_64KBPS = ("opus", 16000, "mono", 64)
|
||||
OGG_OPUS_24KHZ_MONO_16KBPS = ("opus", 24000, "mono", 16)
|
||||
OGG_OPUS_24KHZ_MONO_32KBPS = ("opus", 24000, "mono", 32)
|
||||
OGG_OPUS_24KHZ_MONO_64KBPS = ("opus", 24000, "mono", 64)
|
||||
OGG_OPUS_48KHZ_MONO_16KBPS = ("opus", 48000, "mono", 16)
|
||||
OGG_OPUS_48KHZ_MONO_32KBPS = ("opus", 48000, "mono", 32)
|
||||
OGG_OPUS_48KHZ_MONO_64KBPS = ("opus", 48000, "mono", 64)
|
||||
def __init__(self, format, sample_rate, channels, bit_rate):
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bit_rate = bit_rate
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.format.upper()} with {self.sample_rate}Hz sample rate, {self.channels} channel, {self.bit_rate}'
|
||||
|
||||
|
||||
class Request:
|
||||
def __init__(
|
||||
self,
|
||||
apikey,
|
||||
model,
|
||||
voice,
|
||||
format='wav',
|
||||
sample_rate=16000,
|
||||
bit_rate=64000,
|
||||
volume=50,
|
||||
speech_rate=1.0,
|
||||
pitch_rate=1.0,
|
||||
seed=0,
|
||||
synthesis_type=0,
|
||||
instruction=None,
|
||||
language_hints: list = None,
|
||||
):
|
||||
self.task_id = self.genUid()
|
||||
self.apikey = apikey
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
self.bit_rate = bit_rate
|
||||
self.volume = volume
|
||||
self.speech_rate = speech_rate
|
||||
self.pitch_rate = pitch_rate
|
||||
self.seed = seed
|
||||
self.synthesis_type = synthesis_type
|
||||
self.instruction = instruction
|
||||
self.language_hints = language_hints
|
||||
|
||||
def genUid(self):
|
||||
# Generate a random UUID to use as the task id.
|
||||
return uuid.uuid4().hex
|
||||
|
||||
def getWebsocketHeaders(self, headers, workspace):
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
'1.18.0', # dashscope version
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
self.headers = {
|
||||
'user-agent': ua,
|
||||
'Authorization': 'bearer ' + self.apikey,
|
||||
}
|
||||
if headers:
|
||||
self.headers = {**self.headers, **headers}
|
||||
if workspace:
|
||||
self.headers = {
|
||||
**self.headers,
|
||||
'X-DashScope-WorkSpace': workspace,
|
||||
}
|
||||
return self.headers
|
||||
|
||||
def getStartRequest(self, additional_params=None):
|
||||
|
||||
cmd = {
|
||||
HEADER: {
|
||||
ACTION_KEY: ActionType.START,
|
||||
TASK_ID: self.task_id,
|
||||
'streaming': WebsocketStreamingMode.DUPLEX,
|
||||
},
|
||||
'payload': {
|
||||
'model': self.model,
|
||||
'task_group': 'audio',
|
||||
'task': 'tts',
|
||||
'function': 'SpeechSynthesizer',
|
||||
'input': {},
|
||||
'parameters': {
|
||||
'voice': self.voice,
|
||||
'volume': self.volume,
|
||||
'text_type': 'PlainText',
|
||||
'sample_rate': self.sample_rate,
|
||||
'rate': self.speech_rate,
|
||||
'format': self.format,
|
||||
'pitch': self.pitch_rate,
|
||||
'seed': self.seed,
|
||||
'type': self.synthesis_type
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.format == 'opus':
|
||||
cmd['payload']['parameters']['bit_rate'] = self.bit_rate
|
||||
if additional_params:
|
||||
cmd['payload']['parameters'].update(additional_params)
|
||||
if self.instruction is not None:
|
||||
cmd['payload']['parameters']['instruction'] = self.instruction
|
||||
if self.language_hints is not None:
|
||||
cmd['payload']['parameters']['language_hints'] = self.language_hints
|
||||
return json.dumps(cmd)
|
||||
|
||||
def getContinueRequest(self, text):
|
||||
cmd = {
|
||||
HEADER: {
|
||||
ACTION_KEY: ActionType.CONTINUE,
|
||||
TASK_ID: self.task_id,
|
||||
'streaming': WebsocketStreamingMode.DUPLEX,
|
||||
},
|
||||
'payload': {
|
||||
'model': self.model,
|
||||
'task_group': 'audio',
|
||||
'task': 'tts',
|
||||
'function': 'SpeechSynthesizer',
|
||||
'input': {
|
||||
'text': text
|
||||
},
|
||||
},
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def getFinishRequest(self):
|
||||
cmd = {
|
||||
HEADER: {
|
||||
ACTION_KEY: ActionType.FINISHED,
|
||||
TASK_ID: self.task_id,
|
||||
'streaming': WebsocketStreamingMode.DUPLEX,
|
||||
},
|
||||
'payload': {
|
||||
'input': {},
|
||||
},
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
|
||||
class SpeechSynthesizer:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
voice,
|
||||
format: AudioFormat = AudioFormat.DEFAULT,
|
||||
volume=50,
|
||||
speech_rate=1.0,
|
||||
pitch_rate=1.0,
|
||||
seed=0,
|
||||
synthesis_type=0,
|
||||
instruction=None,
|
||||
language_hints: list = None,
|
||||
headers=None,
|
||||
callback: ResultCallback = None,
|
||||
workspace=None,
|
||||
url=None,
|
||||
additional_params=None,
|
||||
):
|
||||
"""
|
||||
CosyVoice Speech Synthesis SDK
|
||||
Parameters:
|
||||
-----------
|
||||
model: str
|
||||
Model name.
|
||||
voice: str
|
||||
Voice name.
|
||||
format: AudioFormat
|
||||
Synthesis audio format.
|
||||
volume: int
|
||||
The volume of the synthesized audio, with a range from 0 to 100. Default is 50.
|
||||
speech_rate: float
|
||||
The speech rate of the synthesized audio, with a range from 0.5 to 2. Default is 1.0.
|
||||
pitch_rate: float
|
||||
The pitch of the synthesized audio, with a range from 0.5 to 2. Default is 1.0.
|
||||
headers: Dict
|
||||
User-defined headers.
|
||||
callback: ResultCallback
|
||||
Callback to receive real-time synthesis results.
|
||||
workspace: str
|
||||
Dashscope workspace ID.
|
||||
url: str
|
||||
Dashscope WebSocket URL.
|
||||
seed: int
|
||||
The seed of the synthesizer, with a range from 0 to 65535. Default is 0.
|
||||
synthesis_type: int
|
||||
The type of the synthesizer, Default is 0.
|
||||
instruction: str
|
||||
The instruction of the synthesizer, max length is 128.
|
||||
language_hints: list
|
||||
The language hints of the synthesizer. Supported languages: zh, en.
|
||||
additional_params: Dict
|
||||
Additional parameters for the Dashscope API.
|
||||
"""
|
||||
|
||||
if model is None:
|
||||
raise ModelRequired('Model is required!')
|
||||
if format is None:
|
||||
raise InputRequired('format is required!')
|
||||
if url is None:
|
||||
url = dashscope.base_websocket_api_url
|
||||
self.url = url
|
||||
self.apikey = dashscope.api_key
|
||||
self.headers = headers
|
||||
self.workspace = workspace
|
||||
self.additional_params = additional_params
|
||||
self.model = model
|
||||
self.voice = voice
|
||||
self.aformat = format.format
|
||||
if (self.aformat == 'Default'):  # AudioFormat.DEFAULT falls back to mp3
|
||||
self.aformat = 'mp3'
|
||||
self.sample_rate = format.sample_rate
|
||||
if (self.sample_rate == 0):
|
||||
self.sample_rate = 22050
|
||||
|
||||
self.request = Request(
|
||||
apikey=self.apikey,
|
||||
model=model,
|
||||
voice=voice,
|
||||
format=format.format,
|
||||
sample_rate=format.sample_rate,
|
||||
bit_rate = format.bit_rate,
|
||||
volume=volume,
|
||||
speech_rate=speech_rate,
|
||||
pitch_rate=pitch_rate,
|
||||
seed=seed,
|
||||
synthesis_type=synthesis_type,
|
||||
instruction=instruction,
|
||||
language_hints=language_hints
|
||||
)
|
||||
self.last_request_id = self.request.task_id
|
||||
self.start_event = threading.Event()
|
||||
self.complete_event = threading.Event()
|
||||
self._stopped = threading.Event()
|
||||
self._audio_data: bytes = None
|
||||
self._is_started = False
|
||||
self._cancel = False
|
||||
self._cancel_lock = threading.Lock()
|
||||
self.async_call = True
|
||||
self.callback = callback
|
||||
self._is_first = True
|
||||
# since dashscope sdk will send first text in run-task
|
||||
if not self.callback:
|
||||
self.async_call = False
|
||||
self._start_stream_timestamp = -1
|
||||
self._first_package_timestamp = -1
|
||||
self._recv_audio_length = 0
|
||||
self.last_response = None
|
||||
|
||||
def __send_str(self, data: str):
|
||||
logger.debug('>>>send {}'.format(data))
|
||||
self.ws.send(data)
|
||||
|
||||
def __start_stream(self, ):
|
||||
self._start_stream_timestamp = time.time() * 1000
|
||||
self._first_package_timestamp = -1
|
||||
self._recv_audio_length = 0
|
||||
if self.callback is None:
|
||||
raise InputRequired('callback is required!')
|
||||
# reset inner params
|
||||
self._stopped.clear()
|
||||
self._stream_data = ['']
|
||||
self._worker = None
|
||||
self._audio_data: bytes = None
|
||||
|
||||
if self._is_started:
|
||||
raise InvalidTask('task has already started.')
|
||||
|
||||
self.ws = websocket.WebSocketApp(
|
||||
self.url,
|
||||
header=self.request.getWebsocketHeaders(headers=self.headers,
|
||||
workspace=self.workspace),
|
||||
on_message=self.on_message,
|
||||
on_error=self.on_error,
|
||||
on_close=self.on_close,
|
||||
)
|
||||
self.thread = threading.Thread(target=self.ws.run_forever)
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
request = self.request.getStartRequest(self.additional_params)
|
||||
# Wait until the connection is established.
timeout = 5  # maximum time to wait for the connection, in seconds
|
||||
start_time = time.time()
|
||||
while (not (self.ws.sock and self.ws.sock.connected)
|
||||
and (time.time() - start_time) < timeout):
|
||||
time.sleep(0.1)  # short sleep to avoid busy polling
|
||||
if not (self.ws.sock and self.ws.sock.connected):
|
||||
raise TimeoutError(
|
||||
'websocket connection could not be established within 5s. '
|
||||
'Please check your network connection, firewall settings, or server status.'
|
||||
)
|
||||
self.__send_str(request)
|
||||
if not self.start_event.wait(10):
    raise TimeoutError('start speech synthesizer failed within 10s.')
|
||||
self._is_started = True
|
||||
if self.callback:
|
||||
self.callback.on_open()
|
||||
|
||||
def __submit_text(self, text):
|
||||
if not self._is_started:
|
||||
raise InvalidTask('speech synthesizer has not been started.')
|
||||
|
||||
if self._stopped.is_set():
|
||||
raise InvalidTask('speech synthesizer task has stopped.')
|
||||
request = self.request.getContinueRequest(text)
|
||||
self.__send_str(request)
|
||||
|
||||
def streaming_call(self, text: str):
|
||||
"""
|
||||
Streaming input mode: You can call the stream_call function multiple times to send text.
|
||||
A session will be created on the first call.
|
||||
The session ends after calling streaming_complete.
|
||||
Parameters:
|
||||
-----------
|
||||
text: str
|
||||
utf-8 encoded text
|
||||
"""
|
||||
if self._is_first:
|
||||
self._is_first = False
|
||||
self.__start_stream()
|
||||
self.__submit_text(text)
|
||||
return None
|
||||
|
||||
def streaming_complete(self, complete_timeout_millis=600000):
|
||||
"""
|
||||
Synchronously stop the streaming input speech synthesis task.
|
||||
Wait for all remaining synthesized audio before returning
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
complete_timeout_millis: int
|
||||
Throws TimeoutError exception if it times out. If the timeout is not None
|
||||
and greater than zero, it will wait for the corresponding number of
|
||||
milliseconds; otherwise, it will wait indefinitely.
|
||||
"""
|
||||
if not self._is_started:
|
||||
raise InvalidTask('speech synthesizer has not been started.')
|
||||
if self._stopped.is_set():
|
||||
raise InvalidTask('speech synthesizer task has stopped.')
|
||||
request = self.request.getFinishRequest()
|
||||
self.__send_str(request)
|
||||
if complete_timeout_millis is not None and complete_timeout_millis > 0:
|
||||
if not self.complete_event.wait(timeout=complete_timeout_millis /
|
||||
1000):
|
||||
raise TimeoutError(
|
||||
'speech synthesizer wait for complete timeout {}ms'.format(
|
||||
complete_timeout_millis))
|
||||
else:
|
||||
self.complete_event.wait()
|
||||
self.close()
|
||||
self._stopped.set()
|
||||
self._is_started = False
|
||||
|
||||
def __waiting_for_complete(self, timeout):
|
||||
if timeout is not None and timeout > 0:
|
||||
if not self.complete_event.wait(timeout=timeout / 1000):
|
||||
raise TimeoutError(
|
||||
f'speech synthesizer wait for complete timeout {timeout}ms'
|
||||
)
|
||||
else:
|
||||
self.complete_event.wait()
|
||||
self.close()
|
||||
self._stopped.set()
|
||||
self._is_started = False
|
||||
|
||||
def async_streaming_complete(self, complete_timeout_millis=600000):
|
||||
"""
|
||||
Asynchronously stop the streaming input speech synthesis task, returns immediately.
|
||||
You need to listen and handle the STREAM_INPUT_TTS_EVENT_SYNTHESIS_COMPLETE event in the on_event callback.
|
||||
Do not destroy the object and callback before this event.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
complete_timeout_millis: int
|
||||
Throws TimeoutError exception if it times out. If the timeout is not None
|
||||
and greater than zero, it will wait for the corresponding number of
|
||||
milliseconds; otherwise, it will wait indefinitely.
|
||||
"""
|
||||
|
||||
if not self._is_started:
|
||||
raise InvalidTask('speech synthesizer has not been started.')
|
||||
if self._stopped.is_set():
|
||||
raise InvalidTask('speech synthesizer task has stopped.')
|
||||
request = self.request.getFinishRequest()
|
||||
self.__send_str(request)
|
||||
thread = threading.Thread(target=self.__waiting_for_complete,
|
||||
args=(complete_timeout_millis, ))
|
||||
thread.start()
|
||||
|
||||
def streaming_cancel(self):
|
||||
"""
|
||||
Immediately terminate the streaming input speech synthesis task
|
||||
and discard any remaining audio that is not yet delivered.
|
||||
"""
|
||||
|
||||
if not self._is_started:
|
||||
raise InvalidTask('speech synthesizer has not been started.')
|
||||
if self._stopped.is_set():
|
||||
return
|
||||
request = self.request.getFinishRequest()
|
||||
self.__send_str(request)
|
||||
self.close()
|
||||
self.start_event.set()
|
||||
self.complete_event.set()
|
||||
|
||||
# Callback invoked for every message received from the server.
|
||||
def on_message(self, ws, message):
|
||||
if isinstance(message, str):
|
||||
logger.debug('<<<recv {}'.format(message))
|
||||
try:
|
||||
# Parse the text message as JSON.
|
||||
json_data = json.loads(message)
|
||||
self.last_response = json_data
|
||||
event = json_data['header'][EVENT_KEY]
|
||||
# Handle the JSON event and invoke the corresponding callbacks.
|
||||
if EventType.STARTED == event:
|
||||
self.start_event.set()
|
||||
elif EventType.FINISHED == event:
|
||||
self.complete_event.set()
|
||||
if self.callback:
|
||||
self.callback.on_complete()
|
||||
self.callback.on_close()
|
||||
elif EventType.FAILED == event:
|
||||
self.start_event.set()
|
||||
self.complete_event.set()
|
||||
if self.async_call:
|
||||
self.callback.on_error(message)
|
||||
self.callback.on_close()
|
||||
else:
|
||||
logger.error(f'TaskFailed: {message}')
|
||||
raise Exception(f'TaskFailed: {message}')
|
||||
elif EventType.GENERATED == event:
|
||||
if self.callback:
|
||||
self.callback.on_event(message)
|
||||
else:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
logger.error('Failed to parse message as JSON.')
|
||||
raise Exception('Failed to parse message as JSON.')
|
||||
elif isinstance(message, (bytes, bytearray)):
|
||||
# Binary frames carry the synthesized audio data.
|
||||
logger.debug('<<<recv binary {}'.format(len(message)))
|
||||
if (self._recv_audio_length == 0):
|
||||
self._first_package_timestamp = time.time() * 1000
|
||||
logger.debug('first package delay {}'.format(
|
||||
self._first_package_timestamp -
|
||||
self._start_stream_timestamp))
|
||||
self._recv_audio_length += len(message) / (2 * self.sample_rate /
|
||||
1000)
|
||||
current = time.time() * 1000
|
||||
current_rtf = (current - self._start_stream_timestamp
|
||||
) / self._recv_audio_length
|
||||
logger.debug('total audio {} ms, current_rtf: {}'.format(
|
||||
self._recv_audio_length, current_rtf))
|
||||
# Only buffer the audio when not running in async (callback) mode.
|
||||
if not self.async_call:
|
||||
if self._audio_data is None:
|
||||
self._audio_data = bytes(message)
|
||||
else:
|
||||
self._audio_data = self._audio_data + bytes(message)
|
||||
if self.callback:
|
||||
self.callback.on_data(message)
|
||||
|
||||
def call(self, text: str, timeout_millis=None):
|
||||
"""
|
||||
Speech synthesis.
|
||||
If callback is set, the audio will be returned in real-time through the on_event interface.
|
||||
Otherwise, this function blocks until all audio is received and then returns the complete audio data.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
text: str
|
||||
utf-8 encoded text
|
||||
timeout_millis: int or None
|
||||
return: bytes
|
||||
If a callback is not set during initialization, the complete audio is returned
|
||||
as the function's return value. Otherwise, the return value is None.
|
||||
If the timeout is set to a value greater than zero and not None,
|
||||
it will wait for the corresponding number of milliseconds;
|
||||
otherwise, it will wait indefinitely.
|
||||
"""
|
||||
# Non-streaming synthesis is not yet supported by the large-model service, so it is simulated with the streaming interface.
|
||||
if self.additional_params is None:
    self.additional_params = {'enable_ssml': True}
else:
    self.additional_params['enable_ssml'] = True
|
||||
if not self.callback:
|
||||
self.callback = ResultCallback()
|
||||
self.__start_stream()
|
||||
self.__submit_text(text)
|
||||
if self.async_call:
|
||||
self.async_streaming_complete(timeout_millis)
|
||||
return None
|
||||
else:
|
||||
self.streaming_complete(timeout_millis)
|
||||
return self._audio_data
|
||||
|
||||
# Callback invoked when the WebSocket connection is closed.
|
||||
def on_close(self, ws, close_status_code, close_msg):
|
||||
pass
|
||||
|
||||
# Callback invoked when the WebSocket raises an error.
def on_error(self, ws, error):
    logger.error(f'websocket closed due to {error}')
    raise Exception(f'websocket closed due to {error}')
|
||||
|
||||
# Close the WebSocket connection.
|
||||
def close(self):
|
||||
self.ws.close()
|
||||
|
||||
# Return the request id (task id) of the last task.
|
||||
def get_last_request_id(self):
|
||||
return self.last_request_id
|
||||
|
||||
def get_first_package_delay(self):
|
||||
"""First Package Delay is the time between start sending text and receive first audio package
|
||||
"""
|
||||
return self._first_package_timestamp - self._start_stream_timestamp
|
||||
|
||||
def get_response(self):
|
||||
return self.last_response
|
||||
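A minimal sketch of streaming synthesis with the CosyVoice SpeechSynthesizer above, not part of the committed file. It uses a ResultCallback whose on_data receives raw audio bytes as they arrive; the import path, model name and voice name are assumptions.

import dashscope
from dashscope.audio.tts_v2 import (AudioFormat, ResultCallback,
                                    SpeechSynthesizer)  # assumed import path

dashscope.api_key = 'your-api-key'  # placeholder

class SaveAudioCallback(ResultCallback):
    def __init__(self):
        self._file = open('output.pcm', 'wb')

    def on_data(self, data: bytes) -> None:
        self._file.write(data)  # raw PCM audio chunk

    def on_close(self) -> None:
        self._file.close()

synthesizer = SpeechSynthesizer(model='cosyvoice-v1',  # assumed model name
                                voice='longxiaochun',  # assumed voice name
                                format=AudioFormat.PCM_22050HZ_MONO_16BIT,
                                callback=SaveAudioCallback())
for sentence in ['Hello, this is a synthesis test.', 'And a second sentence.']:
    synthesizer.streaming_call(sentence)
synthesizer.streaming_complete()
print('first package delay (ms):', synthesizer.get_first_package_delay())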