chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .tingwu import tingwu
|
||||
from .tingwu.tingwu import TingWu
|
||||
from .tingwu.tingwu_realtime import TingWuRealtime, TingWuRealtimeCallback
|
||||
|
||||
from .multimodal_dialog import MultiModalDialog, MultiModalCallback
|
||||
from .dialog_state import DialogState
|
||||
from .multimodal_constants import *
|
||||
from .multimodal_request_params import *
|
||||
|
||||
__all__ = [
|
||||
'tingwu',
|
||||
'TingWu',
|
||||
'TingWuRealtime',
|
||||
'TingWuRealtimeCallback',
|
||||
'MultiModalDialog',
|
||||
'MultiModalCallback',
|
||||
'DialogState'
|
||||
]
|
||||
@@ -0,0 +1,56 @@
|
||||
# dialog_state.py
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DialogState(Enum):
|
||||
"""
|
||||
对话状态枚举类,定义了对话机器人可能处于的不同状态。
|
||||
|
||||
Attributes:
|
||||
IDLE (str): 表示机器人处于空闲状态。
|
||||
LISTENING (str): 表示机器人正在监听用户输入。
|
||||
THINKING (str): 表示机器人正在思考。
|
||||
RESPONDING (str): 表示机器人正在生成或回复中。
|
||||
"""
|
||||
IDLE = 'Idle'
|
||||
LISTENING = 'Listening'
|
||||
THINKING = 'Thinking'
|
||||
RESPONDING = 'Responding'
|
||||
|
||||
|
||||
class StateMachine:
|
||||
"""
|
||||
状态机类,用于管理机器人的状态转换。
|
||||
|
||||
Attributes:
|
||||
current_state (DialogState): 当前状态。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 初始化状态机时设置初始状态为IDLE
|
||||
self.current_state = DialogState.IDLE
|
||||
|
||||
def change_state(self, new_state: str) -> None:
|
||||
"""
|
||||
更改当前状态到指定的新状态。
|
||||
|
||||
Args:
|
||||
new_state (str): 要切换到的新状态。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果尝试切换到一个无效的状态,则抛出此异常。
|
||||
"""
|
||||
if new_state in [state.value for state in DialogState]:
|
||||
self.current_state = DialogState(new_state)
|
||||
else:
|
||||
raise ValueError("无效的状态类型")
|
||||
|
||||
def get_current_state(self) -> DialogState:
|
||||
"""
|
||||
获取当前状态。
|
||||
|
||||
Returns:
|
||||
DialogState: 当前状态。
|
||||
"""
|
||||
return self.current_state
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# multimodal conversation request directive
|
||||
|
||||
class RequestToRespondType:
|
||||
TRANSCRIPT = 'transcript'
|
||||
PROMPT = 'prompt'
|
||||
|
||||
|
||||
# multimodal conversation response directive
|
||||
RESPONSE_NAME_TASK_STARTED = "task-started"
|
||||
RESPONSE_NAME_RESULT_GENERATED = "result-generated"
|
||||
RESPONSE_NAME_TASK_FINISHED = "task-finished"
|
||||
|
||||
RESPONSE_NAME_TASK_FAILED = "TaskFailed"
|
||||
RESPONSE_NAME_STARTED = "Started"
|
||||
RESPONSE_NAME_STOPPED = "Stopped"
|
||||
RESPONSE_NAME_STATE_CHANGED = "DialogStateChanged"
|
||||
RESPONSE_NAME_REQUEST_ACCEPTED = "RequestAccepted"
|
||||
RESPONSE_NAME_SPEECH_STARTED = "SpeechStarted"
|
||||
RESPONSE_NAME_SPEECH_ENDED = "SpeechEnded" # 服务端检测到asr语音尾点时下发此事件,可选事件
|
||||
RESPONSE_NAME_RESPONDING_STARTED = "RespondingStarted" # AI语音应答开始,sdk要准备接收服务端下发的语音数据
|
||||
RESPONSE_NAME_RESPONDING_ENDED = "RespondingEnded" # AI语音应答结束
|
||||
RESPONSE_NAME_SPEECH_CONTENT = "SpeechContent" # 用户语音识别出的文本,流式全量输出
|
||||
RESPONSE_NAME_RESPONDING_CONTENT = "RespondingContent" # 统对外输出的文本,流式全量输出
|
||||
RESPONSE_NAME_ERROR = "Error" # 服务端对话中报错
|
||||
RESPONSE_NAME_HEART_BEAT = "HeartBeat" # 心跳消息
|
||||
@@ -0,0 +1,643 @@
|
||||
import json
|
||||
import platform
|
||||
import time
|
||||
import threading
|
||||
from abc import abstractmethod
|
||||
|
||||
import websocket
|
||||
|
||||
import dashscope
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.common.error import InputRequired
|
||||
from dashscope.multimodal import dialog_state
|
||||
from dashscope.multimodal.multimodal_constants import *
|
||||
from dashscope.multimodal.multimodal_request_params import RequestParameters, get_random_uuid, DashHeader, \
|
||||
RequestBodyInput, DashPayload, RequestToRespondParameters, RequestToRespondBodyInput
|
||||
from dashscope.protocol.websocket import ActionType
|
||||
|
||||
|
||||
class MultiModalCallback:
|
||||
"""
|
||||
语音聊天回调类,用于处理语音聊天过程中的各种事件。
|
||||
"""
|
||||
|
||||
def on_started(self, dialog_id: str) -> None:
|
||||
"""
|
||||
通知对话开始
|
||||
|
||||
:param dialog_id: 回调对话ID
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_stopped(self) -> None:
|
||||
"""
|
||||
通知对话停止
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_state_changed(self, state: 'dialog_state.DialogState') -> None:
|
||||
"""
|
||||
对话状态改变
|
||||
|
||||
:param state: 新的对话状态
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_speech_audio_data(self, data: bytes) -> None:
|
||||
"""
|
||||
合成音频数据回调
|
||||
|
||||
:param data: 音频数据
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_error(self, error) -> None:
|
||||
"""
|
||||
发生错误时调用此方法。
|
||||
|
||||
:param error: 错误信息
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_connected(self) -> None:
|
||||
"""
|
||||
成功连接到服务器后调用此方法。
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_responding_started(self):
|
||||
"""
|
||||
回复开始回调
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_responding_ended(self, payload):
|
||||
"""
|
||||
回复结束
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_speech_started(self):
|
||||
"""
|
||||
检测到语音输入结束
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_speech_ended(self):
|
||||
"""
|
||||
检测到语音输入结束
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_speech_content(self, payload):
|
||||
"""
|
||||
语音识别文本
|
||||
|
||||
:param payload: text
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_responding_content(self, payload):
|
||||
"""
|
||||
大模型回复文本。
|
||||
|
||||
:param payload: text
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_request_accepted(self):
|
||||
"""
|
||||
打断请求被接受。
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_close(self, close_status_code, close_msg):
|
||||
"""
|
||||
连接关闭时调用此方法。
|
||||
|
||||
:param close_status_code: 关闭状态码
|
||||
:param close_msg: 关闭消息
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MultiModalDialog:
|
||||
"""
|
||||
用于管理WebSocket连接以进行语音聊天的服务类。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
app_id: str,
|
||||
request_params: RequestParameters,
|
||||
multimodal_callback: MultiModalCallback,
|
||||
workspace_id: str = None,
|
||||
url: str = None,
|
||||
api_key: str = None,
|
||||
dialog_id: str = None,
|
||||
model: str = None
|
||||
):
|
||||
"""
|
||||
创建一个语音对话会话。
|
||||
|
||||
此方法用于初始化一个新的voice_chat会话,设置必要的参数以准备开始与模型的交互。
|
||||
:param workspace_id: 客户的workspace_id 主工作空间id,非必填字段
|
||||
:param app_id: 客户在管控台创建的应用id,可以根据值规律确定使用哪个对话系统
|
||||
:param request_params: 请求参数集合
|
||||
:param url: (str) API的URL地址。
|
||||
:param multimodal_callback: (MultimodalCallback) 回调对象,用于处理来自服务器的消息。
|
||||
:param api_key: (str) 应用程序接入的唯一key
|
||||
:param dialog_id:对话id,如果传入表示承接上下文继续聊
|
||||
:param model: 模型
|
||||
"""
|
||||
if request_params is None:
|
||||
raise InputRequired('request_params is required!')
|
||||
if url is None:
|
||||
url = dashscope.base_websocket_api_url
|
||||
if api_key is None:
|
||||
api_key = dashscope.api_key
|
||||
|
||||
self.request_params = request_params
|
||||
self.model = model
|
||||
self._voice_detection = None
|
||||
self.thread = None
|
||||
self.ws = None
|
||||
self.request = _Request()
|
||||
self._callback = multimodal_callback
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
self.workspace_id = workspace_id
|
||||
self.app_id = app_id
|
||||
self.dialog_id = dialog_id
|
||||
self.dialog_state = dialog_state.StateMachine()
|
||||
self.response = _Response(self.dialog_state, self._callback, self.close) # 传递 self.close 作为回调
|
||||
|
||||
def _on_message(self, ws, message):
|
||||
logger.debug(f"<<<<<<< Received message: {message}")
|
||||
if isinstance(message, str):
|
||||
self.response.handle_text_response(message)
|
||||
elif isinstance(message, (bytes, bytearray)):
|
||||
self.response.handle_binary_response(message)
|
||||
|
||||
def _on_error(self, ws, error):
|
||||
logger.error(f"Error: {error}")
|
||||
if self._callback:
|
||||
self._callback.on_error(error)
|
||||
|
||||
def _on_close(self, ws, close_status_code, close_msg):
|
||||
try:
|
||||
logger.debug(
|
||||
"WebSocket connection closed with status {} and message {}".format(close_status_code, close_msg))
|
||||
if close_status_code is None:
|
||||
close_status_code = 1000
|
||||
if close_msg is None:
|
||||
close_msg = "websocket is closed"
|
||||
self._callback.on_close(close_status_code, close_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
def _on_open(self, ws):
|
||||
self._callback.on_connected()
|
||||
|
||||
# def _on_pong(self, _):
|
||||
# _log.debug("on pong")
|
||||
|
||||
def start(self, dialog_id, enable_voice_detection=False):
|
||||
"""
|
||||
初始化WebSocket连接并发送启动请求
|
||||
:param dialog_id: 上下位继承标志位。新对话无需设置。
|
||||
如果继承之前的对话历史,则需要记录之前的dialog_id并传入
|
||||
:param enable_voice_detection: 是否开启语音检测,可选参数 默认False
|
||||
"""
|
||||
self._voice_detection = enable_voice_detection
|
||||
self._connect(self.api_key)
|
||||
logger.debug("connected with server.")
|
||||
self._send_start_request(dialog_id, self.request_params)
|
||||
|
||||
def start_speech(self):
|
||||
"""开始上传语音数据"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("SendSpeech", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def send_audio_data(self, speech_data: bytes):
|
||||
"""发送语音数据"""
|
||||
self.__send_binary_frame(speech_data)
|
||||
|
||||
def stop_speech(self):
|
||||
"""停止上传语音数据"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("StopSpeech", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def interrupt(self):
|
||||
"""请求服务端开始说话"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("RequestToSpeak", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def request_to_respond(self,
|
||||
request_type: str,
|
||||
text: str,
|
||||
parameters: RequestToRespondParameters = None):
|
||||
"""请求服务端直接文本合成语音"""
|
||||
_send_speech_json = self.request.generate_request_to_response_json(direction_name="RequestToRespond",
|
||||
dialog_id=self.dialog_id,
|
||||
request_type=request_type, text=text,
|
||||
parameters=parameters)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
@abstractmethod
|
||||
def request_to_respond_prompt(self, text):
|
||||
"""请求服务端通过文本请求回复文本答复"""
|
||||
return
|
||||
|
||||
def local_responding_started(self):
|
||||
"""本地tts播放开始"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("LocalRespondingStarted", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def local_responding_ended(self):
|
||||
"""本地tts播放结束"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("LocalRespondingEnded", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def send_heart_beat(self):
|
||||
"""发送心跳"""
|
||||
_send_speech_json = self.request.generate_common_direction_request("HeartBeat", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def update_info(self, parameters: RequestToRespondParameters = None):
|
||||
"""更新信息"""
|
||||
_send_speech_json = self.request.generate_update_info_json(direction_name="UpdateInfo", dialog_id=self.dialog_id, parameters=parameters)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def stop(self):
|
||||
if self.ws is None or not self.ws.sock or not self.ws.sock.connected:
|
||||
self._callback.on_close(1001, "websocket is not connected")
|
||||
return
|
||||
_send_speech_json = self.request.generate_stop_request("Stop", self.dialog_id)
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
def get_dialog_state(self) -> dialog_state.DialogState:
|
||||
return self.dialog_state.get_current_state()
|
||||
|
||||
def get_conversation_mode(self) -> str:
|
||||
"""get mode of conversation: support tap2talk/push2talk/duplex"""
|
||||
return self.request_params.upstream.mode
|
||||
|
||||
"""内部方法"""
|
||||
|
||||
def _send_start_request(self, dialog_id: str, request_params: RequestParameters):
|
||||
"""发送'Start'请求"""
|
||||
_start_json = self.request.generate_start_request(
|
||||
workspace_id=self.workspace_id,
|
||||
direction_name="Start",
|
||||
dialog_id=dialog_id,
|
||||
app_id=self.app_id,
|
||||
request_params=request_params,
|
||||
model=self.model
|
||||
)
|
||||
# send start request
|
||||
self._send_text_frame(_start_json)
|
||||
|
||||
def _run_forever(self):
|
||||
self.ws.run_forever(ping_interval=5, ping_timeout=4)
|
||||
|
||||
def _connect(self, api_key: str):
|
||||
"""初始化WebSocket连接并发送启动请求。"""
|
||||
self.ws = websocket.WebSocketApp(self.url, header=self.request.get_websocket_header(api_key),
|
||||
on_open=self._on_open,
|
||||
on_message=self._on_message,
|
||||
on_error=self._on_error,
|
||||
on_close=self._on_close)
|
||||
self.thread = threading.Thread(target=self._run_forever)
|
||||
self.ws.ping_interval = 3
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
self._wait_for_connection()
|
||||
|
||||
def close(self):
|
||||
if self.ws is None or not self.ws.sock or not self.ws.sock.connected:
|
||||
return
|
||||
self.ws.close()
|
||||
|
||||
def _wait_for_connection(self):
|
||||
"""等待WebSocket连接建立"""
|
||||
timeout = 5
|
||||
start_time = time.time()
|
||||
while not (self.ws.sock and self.ws.sock.connected) and (time.time() - start_time) < timeout:
|
||||
time.sleep(0.1) # 短暂休眠,避免密集轮询
|
||||
|
||||
def _send_text_frame(self, text: str):
|
||||
logger.info('>>>>>> send text frame : %s' % text)
|
||||
self.ws.send(text, websocket.ABNF.OPCODE_TEXT)
|
||||
|
||||
def __send_binary_frame(self, binary: bytes):
|
||||
# _log.info('send binary frame length: %d' % len(binary))
|
||||
self.ws.send(binary, websocket.ABNF.OPCODE_BINARY)
|
||||
|
||||
def __del__(self):
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
"""清理所有资源"""
|
||||
try:
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
if self.thread and self.thread.is_alive():
|
||||
# 设置标志位通知线程退出
|
||||
self.thread.join(timeout=2)
|
||||
# 清除引用
|
||||
self.ws = None
|
||||
self.thread = None
|
||||
self._callback = None
|
||||
self.response = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup: {e}")
|
||||
|
||||
|
||||
class _Request:
|
||||
def __init__(self):
|
||||
# websocket header
|
||||
self.ws_headers = None
|
||||
# request body for voice chat
|
||||
self.header = None
|
||||
self.payload = None
|
||||
# params
|
||||
self.task_id = None
|
||||
self.app_id = None
|
||||
self.workspace_id = None
|
||||
|
||||
def get_websocket_header(self, api_key):
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
'1.18.0', # dashscope version
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
self.ws_headers = {
|
||||
"User-Agent": ua,
|
||||
"Authorization": f"bearer {api_key}",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
logger.info('websocket header: {}'.format(self.ws_headers))
|
||||
return self.ws_headers
|
||||
|
||||
def generate_start_request(self, direction_name: str,
|
||||
dialog_id: str,
|
||||
app_id: str,
|
||||
request_params: RequestParameters,
|
||||
model: str = None,
|
||||
workspace_id: str = None
|
||||
) -> str:
|
||||
"""
|
||||
构建语音聊天服务的启动请求数据.
|
||||
:param app_id: 管控台应用id
|
||||
:param request_params: start请求body中的parameters
|
||||
:param direction_name:
|
||||
:param dialog_id: 对话ID.
|
||||
:param workspace_id: 管控台工作空间id, 非必填字段。
|
||||
:param model: 模型
|
||||
:return: 启动请求字典.
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.START)
|
||||
self._get_dash_request_payload(direction_name, dialog_id, app_id, workspace_id=workspace_id,
|
||||
request_params=request_params, model=model)
|
||||
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def generate_common_direction_request(self, direction_name: str, dialog_id: str) -> str:
|
||||
"""
|
||||
构建语音聊天服务的命令请求数据.
|
||||
:param direction_name: 命令.
|
||||
:param dialog_id: 对话ID.
|
||||
:return: 命令请求json.
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.CONTINUE)
|
||||
self._get_dash_request_payload(direction_name, dialog_id, self.app_id)
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def generate_stop_request(self, direction_name: str, dialog_id: str) -> str:
|
||||
"""
|
||||
构建语音聊天服务的启动请求数据.
|
||||
:param direction_name:指令名称
|
||||
:param dialog_id: 对话ID.
|
||||
:return: 启动请求json.
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.FINISHED)
|
||||
self._get_dash_request_payload(direction_name, dialog_id, self.app_id)
|
||||
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def generate_request_to_response_json(self, direction_name: str, dialog_id: str, request_type: str, text: str,
|
||||
parameters: RequestToRespondParameters = None) -> str:
|
||||
"""
|
||||
构建语音聊天服务的命令请求数据.
|
||||
:param direction_name: 命令.
|
||||
:param dialog_id: 对话ID.
|
||||
:param request_type: 服务应该采取的交互类型,transcript 表示直接把文本转语音,prompt 表示把文本送大模型回答
|
||||
:param text: 文本.
|
||||
:param parameters: 命令请求body中的parameters
|
||||
:return: 命令请求字典.
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.CONTINUE)
|
||||
|
||||
custom_input = RequestToRespondBodyInput(
|
||||
app_id=self.app_id,
|
||||
directive=direction_name,
|
||||
dialog_id=dialog_id,
|
||||
type_=request_type,
|
||||
text=text
|
||||
)
|
||||
|
||||
self._get_dash_request_payload(direction_name, dialog_id, self.app_id, request_params=parameters,
|
||||
custom_input=custom_input)
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def generate_update_info_json(self, direction_name: str, dialog_id: str,parameters: RequestToRespondParameters = None) -> str:
|
||||
"""
|
||||
构建语音聊天服务的命令请求数据.
|
||||
:param direction_name: 命令.
|
||||
:param parameters: 命令请求body中的parameters
|
||||
:return: 命令请求字典.
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.CONTINUE)
|
||||
|
||||
custom_input = RequestToRespondBodyInput(
|
||||
app_id=self.app_id,
|
||||
directive=direction_name,
|
||||
dialog_id=dialog_id,
|
||||
)
|
||||
|
||||
self._get_dash_request_payload(direction_name, dialog_id, self.app_id, request_params=parameters,
|
||||
custom_input=custom_input)
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def _get_dash_request_header(self, action: str):
|
||||
"""
|
||||
构建多模对话请求的请求协议Header
|
||||
:param action: ActionType 百炼协议action 支持:run-task, continue-task, finish-task
|
||||
"""
|
||||
if self.task_id is None:
|
||||
self.task_id = get_random_uuid()
|
||||
self.header = DashHeader(action=action, task_id=self.task_id).to_dict()
|
||||
|
||||
def _get_dash_request_payload(self, direction_name: str,
|
||||
dialog_id: str, app_id: str, workspace_id: str = None,
|
||||
request_params: RequestParameters = None, custom_input=None, model: str = None):
|
||||
"""
|
||||
构建多模对话请求的请求协议payload
|
||||
:param direction_name: 对话协议内部的指令名称
|
||||
:param dialog_id: 对话ID.
|
||||
:param app_id: 管控台应用id
|
||||
:param request_params: start请求body中的parameters
|
||||
:param custom_input: 自定义输入
|
||||
:param model: 模型
|
||||
"""
|
||||
if custom_input is not None:
|
||||
input = custom_input
|
||||
else:
|
||||
input = RequestBodyInput(
|
||||
workspace_id=workspace_id,
|
||||
app_id=app_id,
|
||||
directive=direction_name,
|
||||
dialog_id=dialog_id
|
||||
)
|
||||
|
||||
self.payload = DashPayload(
|
||||
model=model,
|
||||
input=input,
|
||||
parameters=request_params
|
||||
).to_dict()
|
||||
|
||||
|
||||
class _Response:
|
||||
def __init__(self, state: dialog_state.StateMachine, callback: MultiModalCallback, close_callback=None):
|
||||
super().__init__()
|
||||
self.dialog_id = None # 对话ID.
|
||||
self.dialog_state = state
|
||||
self._callback = callback
|
||||
self._close_callback = close_callback # 保存关闭回调函数
|
||||
|
||||
def handle_text_response(self, response_json: str):
|
||||
"""
|
||||
处理语音聊天服务的响应数据.
|
||||
:param response_json: 从服务接收到的原始JSON字符串响应。
|
||||
"""
|
||||
logger.info("<<<<<< server response: %s" % response_json)
|
||||
try:
|
||||
# 尝试将消息解析为JSON
|
||||
json_data = json.loads(response_json)
|
||||
if "status_code" in json_data["header"] and json_data["header"]["status_code"] != 200:
|
||||
logger.error("Server returned invalid message: %s" % response_json)
|
||||
if self._callback:
|
||||
self._callback.on_error(response_json)
|
||||
return
|
||||
if "event" in json_data["header"] and json_data["header"]["event"] == "task-failed":
|
||||
logger.error("Server returned invalid message: %s" % response_json)
|
||||
if self._callback:
|
||||
self._callback.on_error(response_json)
|
||||
return None
|
||||
|
||||
payload = json_data["payload"]
|
||||
if "output" in payload and payload["output"] is not None:
|
||||
response_event = payload["output"]["event"]
|
||||
logger.info("Server response event: %s" % response_event)
|
||||
self._handle_text_response_in_conversation(response_event=response_event, response_json=json_data)
|
||||
del json_data
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to parse message as JSON.")
|
||||
|
||||
def _handle_text_response_in_conversation(self, response_event: str, response_json: dict):
|
||||
payload = response_json["payload"]
|
||||
try:
|
||||
if response_event == RESPONSE_NAME_STARTED:
|
||||
self._handle_started(payload["output"])
|
||||
elif response_event == RESPONSE_NAME_STOPPED:
|
||||
self._handle_stopped()
|
||||
elif response_event == RESPONSE_NAME_STATE_CHANGED:
|
||||
self._handle_state_changed(payload["output"]["state"])
|
||||
logger.debug("service response change state: %s" % payload["output"]["state"])
|
||||
elif response_event == RESPONSE_NAME_REQUEST_ACCEPTED:
|
||||
self._handle_request_accepted()
|
||||
elif response_event == RESPONSE_NAME_SPEECH_STARTED:
|
||||
self._handle_speech_started()
|
||||
elif response_event == RESPONSE_NAME_SPEECH_ENDED:
|
||||
self._handle_speech_ended()
|
||||
elif response_event == RESPONSE_NAME_RESPONDING_STARTED:
|
||||
self._handle_responding_started()
|
||||
elif response_event == RESPONSE_NAME_RESPONDING_ENDED:
|
||||
self._handle_responding_ended(payload)
|
||||
elif response_event == RESPONSE_NAME_SPEECH_CONTENT:
|
||||
self._handle_speech_content(payload)
|
||||
elif response_event == RESPONSE_NAME_RESPONDING_CONTENT:
|
||||
self._handle_responding_content(payload)
|
||||
elif response_event == RESPONSE_NAME_ERROR:
|
||||
self._callback.on_error(json.dumps(response_json))
|
||||
elif response_event == RESPONSE_NAME_HEART_BEAT:
|
||||
logger.debug("Server response heart beat")
|
||||
else:
|
||||
logger.error("Unknown response name: {}", response_event)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to parse message as JSON.")
|
||||
|
||||
def handle_binary_response(self, message: bytes):
|
||||
# logger.debug('<<<recv binary {}'.format(len(message)))
|
||||
self._callback.on_speech_audio_data(message)
|
||||
|
||||
def _handle_request_accepted(self):
|
||||
self._callback.on_request_accepted()
|
||||
|
||||
def _handle_started(self, payload: dict):
|
||||
self.dialog_id = payload["dialog_id"]
|
||||
self._callback.on_started(self.dialog_id)
|
||||
|
||||
def _handle_stopped(self):
|
||||
self._callback.on_stopped()
|
||||
if self._close_callback is not None:
|
||||
self._close_callback()
|
||||
|
||||
def _handle_state_changed(self, state: str):
|
||||
"""
|
||||
处理语音聊天状态流转.
|
||||
:param state: 状态.
|
||||
"""
|
||||
self.dialog_state.change_state(state)
|
||||
self._callback.on_state_changed(self.dialog_state.get_current_state())
|
||||
|
||||
def _handle_speech_started(self):
|
||||
self._callback.on_speech_started()
|
||||
|
||||
def _handle_speech_ended(self):
|
||||
self._callback.on_speech_ended()
|
||||
|
||||
def _handle_responding_started(self):
|
||||
self._callback.on_responding_started()
|
||||
|
||||
def _handle_responding_ended(self, payload: dict):
|
||||
self._callback.on_responding_ended(payload)
|
||||
|
||||
def _handle_speech_content(self, payload: dict):
|
||||
self._callback.on_speech_content(payload)
|
||||
|
||||
def _handle_responding_content(self, payload: dict):
|
||||
self._callback.on_responding_content(payload)
|
||||
@@ -0,0 +1,313 @@
|
||||
from dataclasses import dataclass, field, asdict
|
||||
import uuid
|
||||
|
||||
|
||||
def get_random_uuid() -> str:
|
||||
"""生成并返回32位UUID字符串"""
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashHeader:
|
||||
action: str
|
||||
task_id: str = field(default=get_random_uuid())
|
||||
streaming: str = field(default="duplex") # 默认为 duplex
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"action": self.action,
|
||||
"task_id": self.task_id,
|
||||
"request_id": self.task_id,
|
||||
"streaming": self.streaming
|
||||
}
|
||||
|
||||
|
||||
class DashPayloadParameters:
|
||||
def to_dict(self):
|
||||
pass
|
||||
|
||||
|
||||
class DashPayloadInput:
|
||||
def to_dict(self):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashPayload:
|
||||
task_group: str = field(default="aigc")
|
||||
function: str = field(default="generation")
|
||||
model: str = field(default="")
|
||||
task: str = field(default="multimodal-generation")
|
||||
parameters: DashPayloadParameters = field(default=None)
|
||||
input: DashPayloadInput = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
payload = {
|
||||
"task_group": self.task_group,
|
||||
"function": self.function,
|
||||
"model": self.model,
|
||||
"task": self.task,
|
||||
}
|
||||
|
||||
if self.parameters is not None:
|
||||
payload["parameters"] = self.parameters.to_dict()
|
||||
|
||||
if self.input is not None:
|
||||
payload["input"] = self.input.to_dict()
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestBodyInput(DashPayloadInput):
|
||||
workspace_id: str
|
||||
app_id: str
|
||||
directive: str
|
||||
dialog_id: str
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"workspace_id": self.workspace_id,
|
||||
"app_id": self.app_id,
|
||||
"directive": self.directive,
|
||||
"dialog_id": self.dialog_id
|
||||
}
|
||||
@dataclass
|
||||
class AsrPostProcessing:
|
||||
replace_words: list = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
if self.replace_words is None:
|
||||
return None
|
||||
if len(self.replace_words) == 0:
|
||||
return None
|
||||
return {
|
||||
"replace_words": [word.to_dict() for word in self.replace_words]
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class ReplaceWord:
|
||||
source: str = field(default=None)
|
||||
target: str = field(default=None)
|
||||
match_mode: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"source": self.source,
|
||||
"target": self.target,
|
||||
"match_mode": self.match_mode
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class Upstream:
|
||||
"""struct for upstream"""
|
||||
audio_format: str = field(default="pcm") # 上行语音格式,默认pcm.支持pcm/opus
|
||||
type: str = field(default="AudioOnly") # 上行类型:AudioOnly 仅语音通话; AudioAndVideo 上传视频
|
||||
mode: str = field(default="tap2talk") # 客户端交互模式 push2talk/tap2talk/duplex
|
||||
sample_rate: int = field(default=16000) # 音频采样率
|
||||
vocabulary_id: str = field(default=None)
|
||||
asr_post_processing: AsrPostProcessing = field(default=None)
|
||||
pass_through_params: dict = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
upstream: dict = {
|
||||
"type": self.type,
|
||||
"mode": self.mode,
|
||||
"audio_format": self.audio_format,
|
||||
"sample_rate": self.sample_rate,
|
||||
"vocabulary_id": self.vocabulary_id,
|
||||
}
|
||||
if self.asr_post_processing is not None:
|
||||
upstream["asr_post_processing"] = self.asr_post_processing.to_dict()
|
||||
|
||||
if self.pass_through_params is not None:
|
||||
upstream.update(self.pass_through_params)
|
||||
return upstream
|
||||
|
||||
|
||||
@dataclass
|
||||
class Downstream:
|
||||
# transcript 返回用户语音识别结果
|
||||
# dialog 返回对话系统回答中间结果
|
||||
# 可以设置多种,以逗号分割,默认为transcript
|
||||
voice: str = field(default="") # 语音音色
|
||||
sample_rate: int = field(default=0) # 语音音色 # 合成音频采样率
|
||||
intermediate_text: str = field(default="transcript") # 控制返回给用户那些中间文本:
|
||||
debug: bool = field(default=False) # 控制是否返回debug信息
|
||||
# type_: str = field(default="Audio", metadata={"alias": "type"}) # 下行类型:Text:不需要下发语音;Audio:输出语音,默认值
|
||||
audio_format: str = field(default="pcm") # 下行语音格式,默认pcm 。支持pcm/mp3
|
||||
volume: int = field(default=50) # 语音音量 0-100
|
||||
pitch_rate: int = field(default=100) # 语音语调 50-200
|
||||
speech_rate: int = field(default=100) # 语音语速 50-200
|
||||
pass_through_params: dict = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
stream: dict = {
|
||||
"intermediate_text": self.intermediate_text,
|
||||
"debug": self.debug,
|
||||
# "type": self.type_,
|
||||
"audio_format": self.audio_format,
|
||||
"volume": self.volume,
|
||||
"pitch_rate": self.pitch_rate,
|
||||
"speech_rate": self.speech_rate
|
||||
}
|
||||
if self.voice != "":
|
||||
stream["voice"] = self.voice
|
||||
if self.sample_rate != 0:
|
||||
stream["sample_rate"] = self.sample_rate
|
||||
if self.pass_through_params is not None:
|
||||
stream.update(self.pass_through_params)
|
||||
return stream
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogAttributes:
|
||||
agent_id: str = field(default=None)
|
||||
prompt: str = field(default=None)
|
||||
vocabulary_id: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"prompt": self.prompt,
|
||||
"vocabulary_id": self.vocabulary_id
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Locations:
|
||||
city_name: str = field(default=None)
|
||||
latitude: str = field(default=None)
|
||||
longitude: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"city_name": self.city_name,
|
||||
"latitude": self.latitude,
|
||||
"longitude": self.longitude
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Network:
|
||||
ip: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"ip": self.ip
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Device:
|
||||
uuid: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"uuid": self.uuid
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientInfo:
|
||||
user_id: str
|
||||
device: Device = field(default=None)
|
||||
network: Network = field(default=None)
|
||||
location: Locations = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
info = {
|
||||
"user_id": self.user_id,
|
||||
"sdk": "python"
|
||||
}
|
||||
if self.device is not None:
|
||||
info["device"] = self.device.to_dict()
|
||||
if self.network is not None:
|
||||
info["network"] = self.network.to_dict()
|
||||
if self.location is not None:
|
||||
info["location"] = self.location.to_dict()
|
||||
return info
|
||||
|
||||
|
||||
@dataclass
|
||||
class BizParams:
|
||||
user_defined_params: dict = field(default=None)
|
||||
user_defined_tokens: dict = field(default=None)
|
||||
tool_prompts: dict = field(default=None)
|
||||
user_prompt_params: dict = field(default=None)
|
||||
user_query_params: dict = field(default=None)
|
||||
videos: list = field(default=None)
|
||||
pass_through_params: dict = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
params = {}
|
||||
if self.user_defined_params is not None:
|
||||
params["user_defined_params"] = self.user_defined_params
|
||||
if self.user_defined_tokens is not None:
|
||||
params["user_defined_tokens"] = self.user_defined_tokens
|
||||
if self.tool_prompts is not None:
|
||||
params["tool_prompts"] = self.tool_prompts
|
||||
if self.user_prompt_params is not None:
|
||||
params["user_prompt_params"] = self.user_prompt_params
|
||||
if self.user_query_params is not None:
|
||||
params["user_query_params"] = self.user_query_params
|
||||
if self.videos is not None:
|
||||
params["videos"] = self.videos
|
||||
if self.pass_through_params is not None:
|
||||
params.update(self.pass_through_params)
|
||||
return params
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestParameters(DashPayloadParameters):
|
||||
upstream: Upstream
|
||||
downstream: Downstream
|
||||
client_info: ClientInfo
|
||||
dialog_attributes: DialogAttributes = field(default=None)
|
||||
biz_params: BizParams = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
params = {
|
||||
"upstream": self.upstream.to_dict(),
|
||||
"downstream": self.downstream.to_dict(),
|
||||
"client_info": self.client_info.to_dict(),
|
||||
}
|
||||
|
||||
if self.dialog_attributes is not None:
|
||||
params["dialog_attributes"] = self.dialog_attributes.to_dict()
|
||||
if self.biz_params is not None:
|
||||
params["biz_params"] = self.biz_params.to_dict()
|
||||
return params
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestToRespondParameters(DashPayloadParameters):
|
||||
images: list = field(default=None)
|
||||
biz_params: BizParams = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
params = {
|
||||
}
|
||||
if self.images is not None:
|
||||
params["images"] = self.images
|
||||
if self.biz_params is not None:
|
||||
params["biz_params"] = self.biz_params.to_dict()
|
||||
return params
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestToRespondBodyInput(DashPayloadInput):
|
||||
app_id: str
|
||||
directive: str
|
||||
dialog_id: str
|
||||
type_: str = field(metadata={"alias": "type"}, default= None)
|
||||
text: str = field(default="")
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"app_id": self.app_id,
|
||||
"directive": self.directive,
|
||||
"dialog_id": self.dialog_id,
|
||||
"type": self.type_,
|
||||
"text": self.text
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .tingwu import TingWu
|
||||
from .tingwu_realtime import TingWuRealtime, TingWuRealtimeCallback
|
||||
|
||||
__all__ = [
|
||||
'TingWu',
|
||||
'TingWuRealtime',
|
||||
'TingWuRealtimeCallback'
|
||||
]
|
||||
@@ -0,0 +1,80 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
from dashscope.api_entities.api_request_factory import _build_api_request
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.error import ModelRequired
|
||||
|
||||
|
||||
class TingWu(BaseApi):
|
||||
"""API for TingWu APP.
|
||||
|
||||
"""
|
||||
|
||||
task = None
|
||||
task_group = None
|
||||
function = None
|
||||
|
||||
@classmethod
|
||||
def call(
|
||||
cls,
|
||||
model: str,
|
||||
user_defined_input: Dict[str, Any],
|
||||
parameters: Dict[str, Any] = None,
|
||||
api_key: str = None,
|
||||
**kwargs
|
||||
) -> DashScopeAPIResponse:
|
||||
"""Call generation model service.
|
||||
|
||||
Args:
|
||||
model (str): The requested model, such as qwen-turbo
|
||||
api_key (str, optional): The api api_key, can be None,
|
||||
if None, will get by default rule(TODO: api key doc).
|
||||
user_defined_input: custom input
|
||||
parameters: custom parameters
|
||||
**kwargs:
|
||||
base_address: base address
|
||||
additional parameters for request
|
||||
|
||||
Raises:
|
||||
InvalidInput: The history and auto_history are mutually exclusive.
|
||||
|
||||
Returns:
|
||||
Union[GenerationResponse,
|
||||
Generator[GenerationResponse, None, None]]: If
|
||||
stream is True, return Generator, otherwise GenerationResponse.
|
||||
"""
|
||||
if model is None or not model:
|
||||
raise ModelRequired('Model is required!')
|
||||
input_config, parameters = cls._build_input_parameters(input_config=user_defined_input,
|
||||
params=parameters,
|
||||
**kwargs)
|
||||
|
||||
request = _build_api_request(
|
||||
model=model,
|
||||
input=input_config,
|
||||
api_key=api_key,
|
||||
task_group=TingWu.task_group,
|
||||
task=TingWu.task,
|
||||
function=TingWu.function,
|
||||
is_service=False,
|
||||
**parameters)
|
||||
response = request.call()
|
||||
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def _build_input_parameters(cls,
|
||||
input_config,
|
||||
params: Dict[str, Any] = None,
|
||||
**kwargs):
|
||||
parameters = {}
|
||||
if params is not None:
|
||||
parameters = params
|
||||
|
||||
input_param = input_config
|
||||
|
||||
if kwargs.keys() is not None:
|
||||
for key in kwargs.keys():
|
||||
parameters[key] = kwargs[key]
|
||||
return input_param, {**parameters, **kwargs}
|
||||
@@ -0,0 +1,579 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import json
|
||||
import platform
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from queue import Queue
|
||||
import dashscope
|
||||
from dashscope.client.base_api import BaseApi
|
||||
from dashscope.common.error import (InvalidParameter, ModelRequired)
|
||||
import websocket
|
||||
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.protocol.websocket import ActionType
|
||||
|
||||
|
||||
class TingWuRealtimeCallback:
|
||||
"""An interface that defines callback methods for getting TingWu results.
|
||||
Derive from this class and implement its function to provide your own data.
|
||||
"""
|
||||
|
||||
def on_open(self) -> None:
|
||||
pass
|
||||
|
||||
def on_started(self, task_id: str) -> None:
|
||||
pass
|
||||
|
||||
def on_speech_listen(self, result: dict):
|
||||
pass
|
||||
|
||||
def on_recognize_result(self, result: dict):
|
||||
pass
|
||||
|
||||
def on_ai_result(self, result: dict):
|
||||
pass
|
||||
|
||||
def on_stopped(self) -> None:
|
||||
pass
|
||||
|
||||
def on_error(self, error_code: str, error_msg: str) -> None:
|
||||
pass
|
||||
|
||||
def on_close(self, close_status_code, close_msg):
|
||||
"""
|
||||
callback when websocket connection is closed
|
||||
|
||||
:param close_status_code
|
||||
:param close_msg
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TingWuRealtime(BaseApi):
|
||||
"""TingWuRealtime interface.
|
||||
|
||||
Args:
|
||||
model (str): The requested model_id.
|
||||
callback (TingWuRealtimeCallback): A callback that returns
|
||||
speech recognition results.
|
||||
app_id (str): The dashscope tingwu app id.
|
||||
format (str): The input audio format for TingWu request.
|
||||
sample_rate (int): The input audio sample rate.
|
||||
terminology (str): The correct instruction set id.
|
||||
workspace (str): The dashscope workspace id.
|
||||
|
||||
**kwargs:
|
||||
max_end_silence (int): The maximum end silence time.
|
||||
other_params (dict, `optional`): Other parameters.
|
||||
|
||||
Raises:
|
||||
InputRequired: Input is required.
|
||||
"""
|
||||
|
||||
SILENCE_TIMEOUT_S = 60
|
||||
|
||||
def __init__(self,
|
||||
model: str,
|
||||
callback: TingWuRealtimeCallback,
|
||||
audio_format: str = "pcm",
|
||||
sample_rate: int = 16000,
|
||||
max_end_silence: int = None,
|
||||
app_id: str = None,
|
||||
terminology: str = None,
|
||||
workspace: str = None,
|
||||
api_key: str = None,
|
||||
base_address: str = None,
|
||||
data_id: str = None,
|
||||
**kwargs):
|
||||
if api_key is None:
|
||||
self.api_key = dashscope.api_key
|
||||
else:
|
||||
self.api_key = api_key
|
||||
if base_address is None:
|
||||
self.base_address = dashscope.base_websocket_api_url
|
||||
else:
|
||||
self.base_address = base_address
|
||||
|
||||
if model is None:
|
||||
raise ModelRequired('Model is required!')
|
||||
|
||||
self.data_id = data_id
|
||||
self.max_end_silence = max_end_silence
|
||||
self.model = model
|
||||
self.audio_format = audio_format
|
||||
self.app_id = app_id
|
||||
self.terminology = terminology
|
||||
self.sample_rate = sample_rate
|
||||
# continuous recognition with start() or once recognition with call()
|
||||
self._recognition_once = False
|
||||
self._callback = callback
|
||||
self._running = False
|
||||
self._stream_data = Queue()
|
||||
self._worker = None
|
||||
self._silence_timer = None
|
||||
self._kwargs = kwargs
|
||||
self._workspace = workspace
|
||||
self._start_stream_timestamp = -1
|
||||
self._first_package_timestamp = -1
|
||||
self._stop_stream_timestamp = -1
|
||||
self._on_complete_timestamp = -1
|
||||
self.request_id_confirmed = False
|
||||
self.last_request_id = uuid.uuid4().hex
|
||||
self.request = _Request()
|
||||
self.response = _TingWuResponse(self._callback, self.close) # 传递 self.close 作为回调
|
||||
|
||||
def _on_message(self, ws, message):
|
||||
logger.debug(f"<<<<<<< Received message: {message}")
|
||||
if isinstance(message, str):
|
||||
self.response.handle_text_response(message)
|
||||
elif isinstance(message, (bytes, bytearray)):
|
||||
self.response.handle_binary_response(message)
|
||||
|
||||
def _on_error(self, ws, error):
|
||||
logger.error(f"Error: {error}")
|
||||
if self._callback:
|
||||
error_code = "" # 默认错误码
|
||||
if "connection" in str(error).lower():
|
||||
error_code = "1001" # 连接错误
|
||||
elif "timeout" in str(error).lower():
|
||||
error_code = "1002" # 超时错误
|
||||
elif "authentication" in str(error).lower():
|
||||
error_code = "1003" # 认证错误
|
||||
self._callback.on_error(error_code=error_code, error_msg=str(error))
|
||||
|
||||
def _on_close(self, ws, close_status_code, close_msg):
|
||||
try:
|
||||
logger.debug(
|
||||
"WebSocket connection closed with status {} and message {}".format(close_status_code, close_msg))
|
||||
if close_status_code is None:
|
||||
close_status_code = 1000
|
||||
if close_msg is None:
|
||||
close_msg = "websocket is closed"
|
||||
self._callback.on_close(close_status_code, close_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
|
||||
def _on_open(self, ws):
|
||||
self._callback.on_open()
|
||||
self._running = True
|
||||
|
||||
# def _on_pong(self):
|
||||
# logger.debug("on pong")
|
||||
|
||||
def start(self, **kwargs):
|
||||
"""
|
||||
interface for starting TingWu connection
|
||||
"""
|
||||
assert self._callback is not None, 'Please set the callback to get the TingWu result.' # noqa E501
|
||||
|
||||
if self._running:
|
||||
raise InvalidParameter('TingWu client has started.')
|
||||
|
||||
# self._start_stream_timestamp = -1
|
||||
# self._first_package_timestamp = -1
|
||||
# self._stop_stream_timestamp = -1
|
||||
# self._on_complete_timestamp = -1
|
||||
if self._kwargs is not None and len(self._kwargs) != 0:
|
||||
self._kwargs.update(**kwargs)
|
||||
|
||||
self._connect(self.api_key)
|
||||
logger.debug("connected with server.")
|
||||
self._send_start_request()
|
||||
|
||||
def send_audio_data(self, speech_data: bytes):
|
||||
"""send audio data to server"""
|
||||
if self._running:
|
||||
self.__send_binary_frame(speech_data)
|
||||
|
||||
def stop(self):
|
||||
if self.ws is None or not self.ws.sock or not self.ws.sock.connected:
|
||||
self._callback.on_close(1001, "websocket is not connected")
|
||||
return
|
||||
_send_speech_json = self.request.generate_stop_request("stop")
|
||||
self._send_text_frame(_send_speech_json)
|
||||
|
||||
"""inner class"""
|
||||
|
||||
def _send_start_request(self):
|
||||
"""send start request"""
|
||||
_start_json = self.request.generate_start_request(
|
||||
workspace_id=self._workspace,
|
||||
direction_name="start",
|
||||
app_id=self.app_id,
|
||||
model=self.model,
|
||||
audio_format=self.audio_format,
|
||||
sample_rate=self.sample_rate,
|
||||
terminology=self.terminology,
|
||||
max_end_silence=self.max_end_silence,
|
||||
data_id=self.data_id,
|
||||
**self._kwargs
|
||||
)
|
||||
# send start request
|
||||
self._send_text_frame(_start_json)
|
||||
|
||||
def _run_forever(self):
|
||||
self.ws.run_forever(ping_interval=5, ping_timeout=4)
|
||||
|
||||
def _connect(self, api_key: str):
|
||||
"""init websocket connection"""
|
||||
self.ws = websocket.WebSocketApp(self.base_address, header=self.request.get_websocket_header(api_key),
|
||||
on_open=self._on_open,
|
||||
on_message=self._on_message,
|
||||
on_error=self._on_error,
|
||||
on_close=self._on_close)
|
||||
self.thread = threading.Thread(target=self._run_forever)
|
||||
# 统一心跳机制配置
|
||||
self.ws.ping_interval = 5
|
||||
self.ws.ping_timeout = 4
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
self._wait_for_connection()
|
||||
|
||||
def close(self):
|
||||
if self.ws is None or not self.ws.sock or not self.ws.sock.connected:
|
||||
return
|
||||
self.ws.close()
|
||||
|
||||
def _wait_for_connection(self):
|
||||
"""wait for connection using event instead of busy waiting"""
|
||||
timeout = 5
|
||||
start_time = time.time()
|
||||
while not (self.ws.sock and self.ws.sock.connected) and (time.time() - start_time) < timeout:
|
||||
time.sleep(0.1) # 短暂休眠,避免密集轮询
|
||||
|
||||
def _send_text_frame(self, text: str):
|
||||
# 避免在日志中记录敏感信息,如API密钥等
|
||||
# 只记录非敏感信息
|
||||
if '"Authorization"' not in text:
|
||||
logger.info('>>>>>> send text frame : %s' % text)
|
||||
else:
|
||||
logger.info('>>>>>> send text frame with authorization header')
|
||||
self.ws.send(text, websocket.ABNF.OPCODE_TEXT)
|
||||
|
||||
def __send_binary_frame(self, binary: bytes):
|
||||
# _log.info('send binary frame length: %d' % len(binary))
|
||||
self.ws.send(binary, websocket.ABNF.OPCODE_BINARY)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.cleanup()
|
||||
return False
|
||||
|
||||
def cleanup(self):
|
||||
"""cleanup resources"""
|
||||
try:
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
if self.thread and self.thread.is_alive():
|
||||
# 设置标志位通知线程退出
|
||||
self.thread.join(timeout=2)
|
||||
# 清除引用
|
||||
self.ws = None
|
||||
self.thread = None
|
||||
self._callback = None
|
||||
self.response = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup: {e}")
|
||||
|
||||
def send_audio_frame(self, buffer: bytes):
|
||||
"""Push audio to server
|
||||
|
||||
Raises:
|
||||
InvalidParameter: Cannot send data to an uninitiated recognition.
|
||||
"""
|
||||
if self._running is False:
|
||||
raise InvalidParameter('TingWu client has stopped.')
|
||||
|
||||
if self._start_stream_timestamp < 0:
|
||||
self._start_stream_timestamp = time.time() * 1000
|
||||
logger.debug('send_audio_frame: {}'.format(len(buffer)))
|
||||
self.__send_binary_frame(buffer)
|
||||
|
||||
|
||||
class _Request:
|
||||
def __init__(self):
|
||||
# websocket header
|
||||
self.ws_headers = None
|
||||
# request body for voice chat
|
||||
self.header = None
|
||||
self.payload = None
|
||||
# params
|
||||
self.task_id = None
|
||||
self.app_id = None
|
||||
self.workspace_id = None
|
||||
|
||||
def get_websocket_header(self, api_key):
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
'1.18.0', # dashscope version
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
self.ws_headers = {
|
||||
"User-Agent": ua,
|
||||
"Authorization": f"bearer {api_key}",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
logger.info('websocket header: {}'.format(self.ws_headers))
|
||||
return self.ws_headers
|
||||
|
||||
def generate_start_request(self, direction_name: str,
|
||||
app_id: str,
|
||||
model: str = None,
|
||||
workspace_id: str = None,
|
||||
audio_format: str = None,
|
||||
sample_rate: int = None,
|
||||
terminology: str = None,
|
||||
max_end_silence: int = None,
|
||||
data_id: str = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
build start request.
|
||||
:param app_id: web console app id
|
||||
:param direction_name:
|
||||
:param workspace_id: web console workspace id
|
||||
:param model: model name
|
||||
:param audio_format: audio format
|
||||
:param sample_rate: sample rate
|
||||
:param terminology:
|
||||
:param max_end_silence:
|
||||
:param data_id:
|
||||
:return:
|
||||
Args:
|
||||
:
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.START)
|
||||
parameters = self._get_start_parameters(audio_format=audio_format, sample_rate=sample_rate,
|
||||
max_end_silence=max_end_silence,
|
||||
terminology=terminology,
|
||||
**kwargs)
|
||||
self._get_dash_request_payload(direction_name=direction_name, app_id=app_id, workspace_id=workspace_id,
|
||||
model=model,
|
||||
data_id=data_id,
|
||||
request_params=parameters)
|
||||
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
@staticmethod
|
||||
def _get_start_parameters(audio_format: str = None,
|
||||
sample_rate: int = None,
|
||||
terminology: str = None,
|
||||
max_end_silence: int = None,
|
||||
**kwargs):
|
||||
"""
|
||||
build start request parameters inner.
|
||||
:param kwargs: parameters
|
||||
:return
|
||||
"""
|
||||
parameters = {}
|
||||
if audio_format is not None:
|
||||
parameters['format'] = audio_format
|
||||
if sample_rate is not None:
|
||||
parameters['sampleRate'] = sample_rate
|
||||
if terminology is not None:
|
||||
parameters['terminology'] = terminology
|
||||
if max_end_silence is not None:
|
||||
parameters['maxEndSilence'] = max_end_silence
|
||||
if kwargs is not None and len(kwargs) != 0:
|
||||
parameters.update(kwargs)
|
||||
return parameters
|
||||
|
||||
def generate_stop_request(self, direction_name: str) -> str:
|
||||
"""
|
||||
build stop request.
|
||||
:param direction_name
|
||||
:return
|
||||
"""
|
||||
self._get_dash_request_header(ActionType.FINISHED)
|
||||
self._get_dash_request_payload(direction_name, self.app_id)
|
||||
|
||||
cmd = {
|
||||
"header": self.header,
|
||||
"payload": self.payload
|
||||
}
|
||||
return json.dumps(cmd)
|
||||
|
||||
def _get_dash_request_header(self, action: str):
|
||||
"""
|
||||
:param action: ActionType :run-task, continue-task, finish-task
|
||||
"""
|
||||
if self.task_id is None:
|
||||
self.task_id = get_random_uuid()
|
||||
self.header = DashHeader(action=action, task_id=self.task_id).to_dict()
|
||||
|
||||
def _get_dash_request_payload(self, direction_name: str,
|
||||
app_id: str,
|
||||
workspace_id: str = None,
|
||||
custom_input=None,
|
||||
model: str = None,
|
||||
data_id: str = None,
|
||||
request_params=None,
|
||||
):
|
||||
"""
|
||||
build start request payload inner.
|
||||
:param direction_name: inner direction name
|
||||
:param app_id: web console app id
|
||||
:param request_params: start direction body parameters
|
||||
:param custom_input: user custom input
|
||||
:param data_id: data id
|
||||
:param model: model name
|
||||
"""
|
||||
if custom_input is not None:
|
||||
input = custom_input
|
||||
else:
|
||||
input = RequestBodyInput(
|
||||
workspace_id=workspace_id,
|
||||
app_id=app_id,
|
||||
directive=direction_name,
|
||||
data_id=data_id
|
||||
)
|
||||
|
||||
self.payload = DashPayload(
|
||||
model=model,
|
||||
input=input.to_dict(),
|
||||
parameters=request_params
|
||||
).to_dict()
|
||||
|
||||
|
||||
class _TingWuResponse:
|
||||
def __init__(self, callback: TingWuRealtimeCallback, close_callback=None):
|
||||
super().__init__()
|
||||
self.task_id = None # 对话ID.
|
||||
self._callback = callback
|
||||
self._close_callback = close_callback # 保存关闭回调函数
|
||||
|
||||
def handle_text_response(self, response_json: str):
|
||||
"""
|
||||
handle text response.
|
||||
:param response_json: json format response from server
|
||||
"""
|
||||
logger.info("<<<<<< server response: %s" % response_json)
|
||||
try:
|
||||
# try to parse response as json
|
||||
json_data = json.loads(response_json)
|
||||
header = json_data.get('header', {})
|
||||
if header.get('event') == 'task-failed':
|
||||
logger.error('Server returned invalid message: %s' % response_json)
|
||||
if self._callback:
|
||||
self._callback.on_error(error_code=header.get('error_code'),
|
||||
error_msg=header.get('error_message'))
|
||||
return
|
||||
if header.get('event') == "task-started":
|
||||
self._handle_started(header.get('task_id'))
|
||||
return
|
||||
|
||||
payload = json_data.get('payload', {})
|
||||
output = payload.get('output', {})
|
||||
if output is not None:
|
||||
action = output.get('action')
|
||||
logger.info("Server response action: %s" % action)
|
||||
self._handle_tingwu_agent_text_response(action=action, response_json=json_data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to parse message as JSON.")
|
||||
|
||||
def handle_binary_response(self, response_binary: bytes):
|
||||
"""
|
||||
handle binary response.
|
||||
:param response_binary: server response binary。
|
||||
"""
|
||||
logger.info("<<<<<< server response binary length: %d" % len(response_binary))
|
||||
|
||||
def _handle_tingwu_agent_text_response(self, action: str, response_json: dict):
|
||||
payload = response_json.get('payload', {})
|
||||
output = payload.get('output', {})
|
||||
if action == "task-failed":
|
||||
self._callback.on_error(error_code=output.get('errorCode'),
|
||||
error_msg=output.get('errorMessage'))
|
||||
elif action == "speech-listen":
|
||||
self._callback.on_speech_listen(response_json)
|
||||
elif action == "recognize-result":
|
||||
self._callback.on_recognize_result(response_json)
|
||||
elif action == "ai-result":
|
||||
self._callback.on_ai_result(response_json)
|
||||
elif action == "speech-end": # ai-result事件永远会先于speech-end事件
|
||||
self._callback.on_stopped()
|
||||
if self._close_callback is not None:
|
||||
self._close_callback()
|
||||
else:
|
||||
logger.info("Unknown response name:" + action)
|
||||
|
||||
def _handle_started(self, task_id: str):
|
||||
self.task_id = task_id
|
||||
self._callback.on_started(self.task_id)
|
||||
|
||||
|
||||
def get_random_uuid() -> str:
|
||||
"""generate random uuid."""
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestBodyInput():
|
||||
app_id: str
|
||||
directive: str
|
||||
data_id: str = field(default=None)
|
||||
workspace_id: str = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
body_input = {
|
||||
"appId": self.app_id,
|
||||
"directive": self.directive,
|
||||
}
|
||||
if self.workspace_id is not None:
|
||||
body_input["workspace_id"] = self.workspace_id
|
||||
if self.data_id is not None:
|
||||
body_input["dataId"] = self.data_id
|
||||
return body_input
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashHeader:
|
||||
action: str
|
||||
task_id: str = field(default=get_random_uuid())
|
||||
streaming: str = field(default="duplex") # 默认为 duplex
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"action": self.action,
|
||||
"task_id": self.task_id,
|
||||
"request_id": self.task_id,
|
||||
"streaming": self.streaming
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashPayload:
|
||||
task_group: str = field(default="aigc")
|
||||
function: str = field(default="generation")
|
||||
model: str = field(default="")
|
||||
task: str = field(default="multimodal-generation")
|
||||
parameters: dict = field(default=None)
|
||||
input: dict = field(default=None)
|
||||
|
||||
def to_dict(self):
|
||||
payload = {
|
||||
"task_group": self.task_group,
|
||||
"function": self.function,
|
||||
"model": self.model,
|
||||
"task": self.task,
|
||||
}
|
||||
|
||||
if self.parameters is not None:
|
||||
payload["parameters"] = self.parameters
|
||||
|
||||
if self.input is not None:
|
||||
payload["input"] = self.input
|
||||
|
||||
return payload
|
||||
Reference in New Issue
Block a user