chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import dashscope
|
||||
from dashscope.common.constants import (DEFAULT_DASHSCOPE_API_KEY_FILE_PATH,
|
||||
DEFAULT_DASHSCOPE_CACHE_PATH)
|
||||
from dashscope.common.error import AuthenticationError
|
||||
|
||||
|
||||
def get_default_api_key():
|
||||
if dashscope.api_key is not None:
|
||||
# user set environment variable DASHSCOPE_API_KEY
|
||||
return dashscope.api_key
|
||||
elif dashscope.api_key_file_path:
|
||||
# user set environment variable DASHSCOPE_API_KEY_FILE_PATH
|
||||
with open(dashscope.api_key_file_path, 'rt',
|
||||
encoding='utf-8') as f: # open with text mode.
|
||||
return f.read().strip()
|
||||
else: # Find the api key from default key file.
|
||||
if os.path.exists(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH):
|
||||
with open(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH,
|
||||
'rt',
|
||||
encoding='utf-8') as f:
|
||||
return f.read().strip()
|
||||
|
||||
raise AuthenticationError(
|
||||
'No api key provided. You can set by dashscope.api_key = your_api_key in code, ' # noqa: E501
|
||||
'or you can set it via environment variable DASHSCOPE_API_KEY= your_api_key. ' # noqa: E501
|
||||
'You can store your api key to a file, and use dashscope.api_key_file_path=api_key_file_path in code, ' # noqa: E501
|
||||
'or you can set api key file path via environment variable DASHSCOPE_API_KEY_FILE_PATH, ' # noqa: E501
|
||||
'You can call save_api_key to api_key_file_path or default path(~/.dashscope/api_key).' # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
def save_api_key(api_key: str, api_key_file_path: Optional[str] = None):
|
||||
if api_key_file_path is None:
|
||||
os.makedirs(DEFAULT_DASHSCOPE_CACHE_PATH, exist_ok=True)
|
||||
with open(DEFAULT_DASHSCOPE_API_KEY_FILE_PATH, 'w+') as f:
|
||||
f.write(api_key)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(api_key_file_path), exist_ok=True)
|
||||
with open(api_key_file_path, 'w+') as f:
|
||||
f.write(api_key)
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import dashscope
|
||||
|
||||
|
||||
def get_object_type(name: str):
|
||||
dashscope_objects = {
|
||||
'assistant': dashscope.Assistant,
|
||||
'assistant.deleted': dashscope.DeleteResponse,
|
||||
'thread.message': dashscope.ThreadMessage,
|
||||
'thread.run': dashscope.Run,
|
||||
'thread.run.step': dashscope.RunStep,
|
||||
'thread.message.file': dashscope.MessageFile,
|
||||
'assistant.file': dashscope.AssistantFile,
|
||||
'thread': dashscope.Thread,
|
||||
}
|
||||
return dashscope_objects.get(name, None)
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class BaseObjectMixin(object):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
field_type_map = self._get_fields_type()
|
||||
for k, v in kwargs.items():
|
||||
field = field_type_map.get(k, None)
|
||||
if field and v is not None:
|
||||
if dataclasses.is_dataclass(field.type): # process dataclasses
|
||||
self.__setattr__(k, field.type(**v))
|
||||
continue
|
||||
|
||||
if isinstance(v, dict):
|
||||
object_name = v.get('object', None)
|
||||
if object_name:
|
||||
object_type = get_object_type(object_name)
|
||||
if object_type:
|
||||
self.__setattr__(k, object_type(**v))
|
||||
else:
|
||||
self.__setattr__(k, v)
|
||||
else:
|
||||
self.__setattr__(k, v)
|
||||
elif isinstance(v, list):
|
||||
obj_list = self._init_list_element_recursive(field, v)
|
||||
self.__setattr__(k, obj_list)
|
||||
else:
|
||||
self.__setattr__(k, v)
|
||||
|
||||
def _init_list_element_recursive(self, field, items: list) -> List[Any]:
|
||||
obj_list = []
|
||||
for item in items:
|
||||
if field:
|
||||
# current only support List[cls_name],
|
||||
# not support List[cls_nam1, cls_name2]
|
||||
element_type = field.type.__args__[0]
|
||||
if dataclasses.is_dataclass(element_type):
|
||||
obj_list.append(element_type(**item))
|
||||
continue
|
||||
|
||||
if isinstance(item, dict):
|
||||
object_name = item.get('object', None)
|
||||
if object_name:
|
||||
object_type = get_object_type(object_name)
|
||||
if object_type:
|
||||
obj_list.append(object_type(**item))
|
||||
else:
|
||||
obj_list.append(item)
|
||||
else:
|
||||
obj_list.append(item)
|
||||
elif isinstance(item, list):
|
||||
obj_list.append(self._init_list_element_recursive(item))
|
||||
else:
|
||||
obj_list.append(item)
|
||||
return obj_list
|
||||
|
||||
def _get_fields_type(self):
|
||||
field_type_map = {}
|
||||
if dataclasses.is_dataclass(self):
|
||||
for field in dataclasses.fields(self):
|
||||
field_type_map[field.name] = field
|
||||
return field_type_map
|
||||
|
||||
def __setitem__(self, __key: Any, __value: Any) -> None:
|
||||
return self.__setattr__(__key, __value)
|
||||
|
||||
def __getitem__(self, __key: Any) -> Any:
|
||||
return self.__getattribute__(__key)
|
||||
|
||||
def __contains__(self, item):
|
||||
return hasattr(self, item)
|
||||
|
||||
def __delitem__(self, key):
|
||||
self.__delattr__(key)
|
||||
|
||||
def _recursive_to_str__(self, input_object) -> Any:
|
||||
if isinstance(input_object, list):
|
||||
output_object = []
|
||||
for item in input_object:
|
||||
output_object.append(self._recursive_to_str__(item))
|
||||
return output_object
|
||||
elif dataclasses.is_dataclass(input_object):
|
||||
output_object = {}
|
||||
for field in dataclasses.fields(input_object):
|
||||
if hasattr(input_object, field.name):
|
||||
output_object[field.name] = self._recursive_to_str__(
|
||||
getattr(input_object, field.name))
|
||||
return output_object
|
||||
else:
|
||||
return input_object
|
||||
|
||||
def __str__(self) -> str:
|
||||
real_dict = self.__dict__
|
||||
self_fields = dataclasses.fields(self)
|
||||
for field in self_fields:
|
||||
if hasattr(self, field.name):
|
||||
real_dict[field.name] = getattr(self, field.name)
|
||||
output_object = {}
|
||||
for key, value in real_dict.items():
|
||||
output_object[key] = self._recursive_to_str__(value)
|
||||
return str(output_object)
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class BaseList(BaseObjectMixin):
|
||||
status_code: int
|
||||
has_more: bool
|
||||
last_id: str
|
||||
first_id: str
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
|
||||
DASHSCOPE_API_KEY_ENV = 'DASHSCOPE_API_KEY'
|
||||
DASHSCOPE_API_KEY_FILE_PATH_ENV = 'DASHSCOPE_API_KEY_FILE_PATH'
|
||||
DASHSCOPE_API_REGION_ENV = 'DASHSCOPE_API_REGION'
|
||||
DASHSCOPE_API_VERSION_ENV = 'DASHSCOPE_API_VERSION'
|
||||
# to disable data inspection
|
||||
# export DASHSCOPE_DISABLE_DATA_INSPECTION=true
|
||||
DASHSCOPE_DISABLE_DATA_INSPECTION_ENV = 'DASHSCOPE_DISABLE_DATA_INSPECTION'
|
||||
DEFAULT_DASHSCOPE_CACHE_PATH = Path.home().joinpath('.dashscope')
|
||||
DEFAULT_DASHSCOPE_API_KEY_FILE_PATH = Path.joinpath(
|
||||
DEFAULT_DASHSCOPE_CACHE_PATH, 'api_key')
|
||||
|
||||
DEFAULT_REQUEST_TIMEOUT_SECONDS = 300
|
||||
REQUEST_TIMEOUT_KEYWORD = 'request_timeout'
|
||||
SERVICE_API_PATH = 'services'
|
||||
DASHSCOPE_LOGGING_LEVEL_ENV = 'DASHSCOPE_LOGGING_LEVEL'
|
||||
# task config keys.
|
||||
PROMPT = 'prompt'
|
||||
MESSAGES = 'messages'
|
||||
NEGATIVE_PROMPT = 'negative_prompt'
|
||||
HISTORY = 'history'
|
||||
CUSTOMIZED_MODEL_ID = 'customized_model_id'
|
||||
IMAGES = 'images'
|
||||
TEXT_EMBEDDING_INPUT_KEY = 'texts'
|
||||
SERVICE_503_MESSAGE = 'Service temporarily unavailable, possibly overloaded or not ready.' # noqa E501
|
||||
WEBSOCKET_ERROR_CODE = 44
|
||||
SSE_CONTENT_TYPE = 'text/event-stream'
|
||||
DEPRECATED_MESSAGE = 'history and auto_history are deprecated for qwen serial models and will be remove in future, use messages' # noqa E501
|
||||
SCENE = 'scene'
|
||||
MESSAGE = 'message'
|
||||
REQUEST_CONTENT_TEXT = 'text'
|
||||
REQUEST_CONTENT_IMAGE = 'image'
|
||||
REQUEST_CONTENT_AUDIO = 'audio'
|
||||
FILE_PATH_SCHEMA = 'file://'
|
||||
|
||||
ENCRYPTION_AES_SECRET_KEY_BYTES = 32
|
||||
ENCRYPTION_AES_IV_LENGTH = 12
|
||||
|
||||
REPEATABLE_STATUS = [
|
||||
HTTPStatus.SERVICE_UNAVAILABLE, HTTPStatus.GATEWAY_TIMEOUT
|
||||
]
|
||||
|
||||
|
||||
class FilePurpose:
|
||||
fine_tune = 'fine_tune'
|
||||
assistants = 'assistants'
|
||||
|
||||
|
||||
class DeploymentStatus:
|
||||
DEPLOYING = 'DEPLOYING'
|
||||
SERVING = 'RUNNING'
|
||||
DELETING = 'DELETING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
|
||||
|
||||
class ApiProtocol:
|
||||
WEBSOCKET = 'websocket'
|
||||
HTTP = 'http'
|
||||
HTTPS = 'https'
|
||||
|
||||
|
||||
class HTTPMethod:
|
||||
GET = 'GET'
|
||||
HEAD = 'HEAD'
|
||||
POST = 'POST'
|
||||
PUT = 'PUT'
|
||||
DELETE = 'DELETE'
|
||||
CONNECT = 'CONNECT'
|
||||
OPTIONS = 'OPTIONS'
|
||||
TRACE = 'TRACE'
|
||||
PATCH = 'PATCH'
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
PENDING = 'PENDING'
|
||||
SUSPENDED = 'SUSPENDED'
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
CANCELED = 'CANCELED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
UNKNOWN = 'UNKNOWN'
|
||||
|
||||
|
||||
class Tasks(object):
|
||||
TextGeneration = 'text-generation'
|
||||
AutoSpeechRecognition = 'asr'
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
from dashscope.common.constants import (DASHSCOPE_API_KEY_ENV,
|
||||
DASHSCOPE_API_KEY_FILE_PATH_ENV,
|
||||
DASHSCOPE_API_REGION_ENV,
|
||||
DASHSCOPE_API_VERSION_ENV)
|
||||
|
||||
api_region = os.environ.get(DASHSCOPE_API_REGION_ENV, 'cn-beijing')
|
||||
api_version = os.environ.get(DASHSCOPE_API_VERSION_ENV, 'v1')
|
||||
# read the api key from env
|
||||
api_key = os.environ.get(DASHSCOPE_API_KEY_ENV)
|
||||
api_key_file_path = os.environ.get(DASHSCOPE_API_KEY_FILE_PATH_ENV)
|
||||
|
||||
# define api base url, ensure end /
|
||||
base_http_api_url = os.environ.get(
|
||||
'DASHSCOPE_HTTP_BASE_URL',
|
||||
'https://dashscope.aliyuncs.com/api/%s' % (api_version))
|
||||
base_websocket_api_url = os.environ.get(
|
||||
'DASHSCOPE_WEBSOCKET_BASE_URL',
|
||||
'wss://dashscope.aliyuncs.com/api-ws/%s/inference' % (api_version))
|
||||
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
|
||||
class DashScopeException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationError(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidParameter(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTask(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedModel(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedTask(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class ModelRequired(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidModel(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidInput(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFileFormat(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedApiProtocol(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class NotImplemented(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class MultiInputsWithBinaryNotSupported(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedMessageReceived(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedData(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class AssistantError(DashScopeException):
|
||||
def __init__(self, **kwargs):
|
||||
self.message = None
|
||||
self.code = None
|
||||
self.request_id = None
|
||||
if 'message' in kwargs:
|
||||
import json
|
||||
msg = json.loads(kwargs['message'])
|
||||
if 'request_id' in msg:
|
||||
self.request_id = msg['request_id']
|
||||
if 'code' in msg:
|
||||
self.code = msg['code']
|
||||
if 'message' in msg:
|
||||
self.message = msg['message']
|
||||
|
||||
def __str__(self):
|
||||
msg = 'Request failed, request_id: %s, code: %s, message: %s' % ( # noqa E501
|
||||
self.request_id, self.code, self.message)
|
||||
return msg
|
||||
|
||||
|
||||
# for server send generation or inference error.
|
||||
class RequestFailure(DashScopeException):
|
||||
def __init__(self,
|
||||
request_id=None,
|
||||
message=None,
|
||||
name=None,
|
||||
http_code=None):
|
||||
self.request_id = request_id
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.http_code = http_code
|
||||
|
||||
def __str__(self):
|
||||
msg = 'Request failed, request_id: %s, http_code: %s error_name: %s, error_message: %s' % ( # noqa E501
|
||||
self.request_id, self.http_code, self.name, self.message)
|
||||
return msg
|
||||
|
||||
|
||||
class UnknownMessageReceived(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InputDataRequired(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class InputRequired(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedDataType(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceUnavailableError(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedHTTPMethod(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncTaskCreateFailed(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class UploadFileException(DashScopeException):
|
||||
pass
|
||||
|
||||
|
||||
class TimeoutException(DashScopeException):
|
||||
pass
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dashscope.common.constants import DASHSCOPE_LOGGING_LEVEL_ENV
|
||||
|
||||
logger = logging.getLogger('dashscope')
|
||||
|
||||
|
||||
def enable_logging():
|
||||
level = os.environ.get(DASHSCOPE_LOGGING_LEVEL_ENV, None)
|
||||
if level is not None: # set logging level.
|
||||
if level not in ['info', 'debug']:
|
||||
# set logging level env, but invalid value, use default.
|
||||
level = 'info'
|
||||
if level == 'info':
|
||||
logger.setLevel(logging.INFO)
|
||||
else:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
# set default logging handler
|
||||
console_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(filename)s - %(funcName)s - %(lineno)d - %(levelname)s - %(message)s' # noqa E501
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
# in release disable dashscope log
|
||||
# you can enable dashscope log for debugger.
|
||||
enable_logging()
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from collections import deque
|
||||
from typing import List
|
||||
|
||||
from dashscope.api_entities.dashscope_response import (ConversationResponse,
|
||||
GenerationResponse,
|
||||
Message)
|
||||
|
||||
|
||||
class MessageManager(object):
|
||||
DEFAULT_MAXIMUM_MESSAGES = 100
|
||||
|
||||
def __init__(self, max_length: int = None):
|
||||
if max_length is None:
|
||||
self._dq = deque(maxlen=MessageManager.DEFAULT_MAXIMUM_MESSAGES)
|
||||
else:
|
||||
self._dq = deque(maxlen=max_length)
|
||||
|
||||
def add_generation_response(self, response: GenerationResponse):
|
||||
self._dq.append(Message.from_generation_response(response))
|
||||
|
||||
def add_conversation_response(self, response: ConversationResponse):
|
||||
self._dq.append(Message.from_conversation_response(response))
|
||||
|
||||
def add(self, message: Message):
|
||||
"""Add message to message manager
|
||||
|
||||
Args:
|
||||
message (Message): The message to add.
|
||||
"""
|
||||
self._dq.append(message)
|
||||
|
||||
def get(self) -> List[Message]:
|
||||
return list(self._dq)
|
||||
@@ -0,0 +1,438 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import queue
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
|
||||
from dashscope.common.api_key import get_default_api_key
|
||||
from dashscope.common.constants import SSE_CONTENT_TYPE
|
||||
from dashscope.common.logging import logger
|
||||
from dashscope.version import __version__
|
||||
|
||||
|
||||
def is_validate_fine_tune_file(file_path):
|
||||
with open(file_path, encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
json.loads(line)
|
||||
except json.decoder.JSONDecodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_task_group_and_task(module_name):
|
||||
"""Get task_group and task name.
|
||||
get task_group and task name based on api file __name__
|
||||
|
||||
Args:
|
||||
module_name (str): The api file __name__
|
||||
|
||||
Returns:
|
||||
(str, str): task_group and task
|
||||
"""
|
||||
pkg, task = module_name.rsplit('.', 1)
|
||||
task = task.replace('_', '-')
|
||||
_, task_group = pkg.rsplit('.', 1)
|
||||
return task_group, task
|
||||
|
||||
|
||||
def is_path(path: str):
|
||||
"""Check the input path is valid local path.
|
||||
|
||||
Args:
|
||||
path_or_url (str): The path.
|
||||
|
||||
Returns:
|
||||
bool: If path return True, otherwise False.
|
||||
"""
|
||||
url_parsed = urlparse(path)
|
||||
if url_parsed.scheme in ('file', ''):
|
||||
return os.path.exists(url_parsed.path)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_url(url: str):
|
||||
"""Check the input url is valid url.
|
||||
|
||||
Args:
|
||||
url (str): The url
|
||||
|
||||
Returns:
|
||||
bool: If is url return True, otherwise False.
|
||||
"""
|
||||
url_parsed = urlparse(url)
|
||||
if url_parsed.scheme in ('http', 'https', 'oss'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def iter_over_async(ait):
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
ait = ait.__aiter__()
|
||||
|
||||
async def get_next():
|
||||
try:
|
||||
obj = await ait.__anext__()
|
||||
return False, obj
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
|
||||
def iter_thread(loop, message_queue):
|
||||
while True:
|
||||
try:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
message_queue.put((True, None, None))
|
||||
break
|
||||
message_queue.put((False, None, obj))
|
||||
except BaseException as e: # noqa E722
|
||||
logger.exception(e)
|
||||
message_queue.put((True, e, None))
|
||||
break
|
||||
|
||||
message_queue = queue.Queue()
|
||||
x = threading.Thread(target=iter_thread,
|
||||
args=(loop, message_queue),
|
||||
name='iter_async_thread')
|
||||
x.start()
|
||||
while True:
|
||||
finished, error, obj = message_queue.get()
|
||||
if finished:
|
||||
if error is not None:
|
||||
yield DashScopeAPIResponse(
|
||||
-1,
|
||||
'',
|
||||
'Unknown',
|
||||
message='Error type: %s, message: %s' %
|
||||
(type(error), error))
|
||||
break
|
||||
else:
|
||||
yield obj
|
||||
|
||||
|
||||
def async_to_sync(async_generator):
|
||||
for message in iter_over_async(async_generator):
|
||||
yield message
|
||||
|
||||
|
||||
def get_user_agent():
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
__version__,
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
return ua
|
||||
|
||||
|
||||
def default_headers(api_key: str = None) -> Dict[str, str]:
|
||||
ua = 'dashscope/%s; python/%s; platform/%s; processor/%s' % (
|
||||
__version__,
|
||||
platform.python_version(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
)
|
||||
headers = {'user-agent': ua}
|
||||
if api_key is None:
|
||||
api_key = get_default_api_key()
|
||||
headers['Authorization'] = 'Bearer %s' % api_key
|
||||
headers['Accept'] = 'application/json'
|
||||
return headers
|
||||
|
||||
|
||||
def join_url(base_url, *args):
|
||||
if not base_url.endswith('/'):
|
||||
base_url = base_url + '/'
|
||||
url = base_url
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
url += arg + '/'
|
||||
return url[:-1]
|
||||
|
||||
|
||||
async def _handle_aiohttp_response(response: aiohttp.ClientResponse):
|
||||
request_id = ''
|
||||
if response.status == HTTPStatus.OK:
|
||||
json_content = await response.json()
|
||||
if 'request_id' in json_content:
|
||||
request_id = json_content['request_id']
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=HTTPStatus.OK,
|
||||
output=json_content)
|
||||
else:
|
||||
if 'application/json' in response.content_type:
|
||||
error = await response.json()
|
||||
msg = ''
|
||||
if 'message' in error:
|
||||
msg = error['message']
|
||||
if 'request_id' in error:
|
||||
request_id = error['request_id']
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status,
|
||||
output=None,
|
||||
code=error['code'],
|
||||
message=msg)
|
||||
else:
|
||||
msg = await response.read()
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status,
|
||||
output=None,
|
||||
code='Unknown',
|
||||
message=msg)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SSEEvent:
|
||||
id: str
|
||||
eventType: str
|
||||
data: str
|
||||
|
||||
def __init__(self, id: str, type: str, data: str):
|
||||
self.id = id
|
||||
self.eventType = type
|
||||
self.data = data
|
||||
|
||||
|
||||
def _handle_stream(response: requests.Response):
|
||||
# TODO define done message.
|
||||
is_error = False
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
event = SSEEvent(None, None, None)
|
||||
eventType = None
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode('utf8')
|
||||
line = line.rstrip('\n').rstrip('\r')
|
||||
if line.startswith('id:'):
|
||||
id = line[len('id:'):]
|
||||
event.id = id.strip()
|
||||
elif line.startswith('event:'):
|
||||
eventType = line[len('event:'):]
|
||||
event.eventType = eventType.strip()
|
||||
if eventType == 'error':
|
||||
is_error = True
|
||||
elif line.startswith('status:'):
|
||||
status_code = line[len('status:'):]
|
||||
status_code = int(status_code.strip())
|
||||
elif line.startswith('data:'):
|
||||
line = line[len('data:'):]
|
||||
event.data = line.strip()
|
||||
if eventType is not None and eventType == 'done':
|
||||
continue
|
||||
yield (is_error, status_code, event)
|
||||
if is_error:
|
||||
break
|
||||
else:
|
||||
continue # ignore heartbeat...
|
||||
|
||||
|
||||
def _handle_error_message(error, status_code, flattened_output):
|
||||
code = None
|
||||
msg = ''
|
||||
request_id = ''
|
||||
if flattened_output:
|
||||
error['status_code'] = status_code
|
||||
return error
|
||||
if 'message' in error:
|
||||
msg = error['message']
|
||||
if 'msg' in error:
|
||||
msg = error['msg']
|
||||
if 'code' in error:
|
||||
code = error['code']
|
||||
if 'request_id' in error:
|
||||
request_id = error['request_id']
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=status_code,
|
||||
code=code,
|
||||
message=msg)
|
||||
|
||||
|
||||
def _handle_http_failed_response(
|
||||
response: requests.Response,
|
||||
flattened_output: bool = False) -> DashScopeAPIResponse:
|
||||
request_id = ''
|
||||
if 'application/json' in response.headers.get('content-type', ''):
|
||||
error = response.json()
|
||||
return _handle_error_message(error, response.status_code,
|
||||
flattened_output)
|
||||
elif SSE_CONTENT_TYPE in response.headers.get('content-type', ''):
|
||||
msgs = response.content.decode('utf-8').split('\n')
|
||||
for msg in msgs:
|
||||
if msg.startswith('data:'):
|
||||
error = json.loads(msg.replace('data:', '').strip())
|
||||
return _handle_error_message(error, response.status_code,
|
||||
flattened_output)
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status_code,
|
||||
code='Unknown',
|
||||
message=msgs)
|
||||
else:
|
||||
msg = response.content.decode('utf-8')
|
||||
if flattened_output:
|
||||
return {'status_code': response.status_code, 'message': msg}
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status_code,
|
||||
code='Unknown',
|
||||
message=msg)
|
||||
|
||||
|
||||
async def _handle_aio_stream(response):
|
||||
# TODO define done message.
|
||||
is_error = False
|
||||
status_code = HTTPStatus.BAD_REQUEST
|
||||
async for line in response.content:
|
||||
if line:
|
||||
line = line.decode('utf8')
|
||||
line = line.rstrip('\n').rstrip('\r')
|
||||
if line.startswith('event:error'):
|
||||
is_error = True
|
||||
elif line.startswith('status:'):
|
||||
status_code = line[len('status:'):]
|
||||
status_code = int(status_code.strip())
|
||||
elif line.startswith('data:'):
|
||||
line = line[len('data:'):]
|
||||
yield (is_error, status_code, line)
|
||||
if is_error:
|
||||
break
|
||||
else:
|
||||
continue # ignore heartbeat...
|
||||
|
||||
|
||||
async def _handle_aiohttp_failed_response(
|
||||
response: requests.Response,
|
||||
flattened_output: bool = False) -> DashScopeAPIResponse:
|
||||
request_id = ''
|
||||
if 'application/json' in response.content_type:
|
||||
error = await response.json()
|
||||
return _handle_error_message(error, response.status, flattened_output)
|
||||
elif SSE_CONTENT_TYPE in response.content_type:
|
||||
async for _, _, data in _handle_aio_stream(response):
|
||||
error = json.loads(data)
|
||||
return _handle_error_message(error, response.status, flattened_output)
|
||||
else:
|
||||
msg = response.content.decode('utf-8')
|
||||
if flattened_output:
|
||||
return {'status_code': response.status, 'message': msg}
|
||||
return DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status,
|
||||
code='Unknown',
|
||||
message=msg)
|
||||
|
||||
|
||||
def _handle_http_response(response: requests.Response,
|
||||
flattened_output: bool = False):
|
||||
response = _handle_http_stream_response(response, flattened_output)
|
||||
_, output = next(response)
|
||||
try:
|
||||
next(response)
|
||||
except StopIteration:
|
||||
pass
|
||||
return output
|
||||
|
||||
|
||||
def _handle_http_stream_response(response: requests.Response,
|
||||
flattened_output: bool = False):
|
||||
request_id = ''
|
||||
if (response.status_code == HTTPStatus.OK
|
||||
and SSE_CONTENT_TYPE in response.headers.get('content-type', '')):
|
||||
for is_error, status_code, event in _handle_stream(response):
|
||||
if not is_error:
|
||||
try:
|
||||
output = None
|
||||
usage = None
|
||||
msg = json.loads(event.data)
|
||||
if flattened_output:
|
||||
msg['status_code'] = response.status_code
|
||||
yield event.eventType, msg
|
||||
else:
|
||||
logger.debug('Stream message: %s' % msg)
|
||||
if not is_error:
|
||||
if 'output' in msg:
|
||||
output = msg['output']
|
||||
if 'usage' in msg:
|
||||
usage = msg['usage']
|
||||
if 'request_id' in msg:
|
||||
request_id = msg['request_id']
|
||||
yield event.eventType, DashScopeAPIResponse(
|
||||
request_id=request_id,
|
||||
status_code=HTTPStatus.OK,
|
||||
output=output,
|
||||
usage=usage)
|
||||
except json.JSONDecodeError as e:
|
||||
if flattened_output:
|
||||
yield event.eventType, {
|
||||
'status_code': response.status_code,
|
||||
'message': e.message
|
||||
}
|
||||
else:
|
||||
yield event.eventType, DashScopeAPIResponse(
|
||||
request_id=request_id,
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
output=None,
|
||||
code='Unknown',
|
||||
message=event.data)
|
||||
continue
|
||||
else:
|
||||
if flattened_output:
|
||||
yield event.eventType, {
|
||||
'status_code': status_code,
|
||||
'message': event.data
|
||||
}
|
||||
else:
|
||||
msg = json.loads(event.eventType)
|
||||
yield event.eventType, DashScopeAPIResponse(
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
output=None,
|
||||
code=msg['code']
|
||||
if 'code' in msg else None, # noqa E501
|
||||
message=msg['message']
|
||||
if 'message' in msg else None) # noqa E501
|
||||
elif response.status_code == HTTPStatus.OK or response.status_code == HTTPStatus.CREATED:
|
||||
json_content = response.json()
|
||||
if flattened_output:
|
||||
json_content['status_code'] = response.status_code
|
||||
yield None, json_content
|
||||
else:
|
||||
output = None
|
||||
usage = None
|
||||
code = None
|
||||
msg = ''
|
||||
if 'data' in json_content:
|
||||
output = json_content['data']
|
||||
if 'code' in json_content:
|
||||
code = json_content['code']
|
||||
if 'message' in json_content:
|
||||
msg = json_content['message']
|
||||
if 'output' in json_content:
|
||||
output = json_content['output']
|
||||
if 'usage' in json_content:
|
||||
usage = json_content['usage']
|
||||
if 'request_id' in json_content:
|
||||
request_id = json_content['request_id']
|
||||
json_content.pop('request_id', None)
|
||||
|
||||
if 'data' not in json_content and 'output' not in json_content:
|
||||
output = json_content
|
||||
|
||||
yield None, DashScopeAPIResponse(request_id=request_id,
|
||||
status_code=response.status_code,
|
||||
code=code,
|
||||
output=output,
|
||||
usage=usage,
|
||||
message=msg)
|
||||
else:
|
||||
yield None, _handle_http_failed_response(response, flattened_output)
|
||||
Reference in New Issue
Block a user