diff --git a/examples/assistant/chat_with_assistant.ipynb b/examples/assistant/chat_with_assistant.ipynb index eb0aaa6..13d9335 100644 --- a/examples/assistant/chat_with_assistant.ipynb +++ b/examples/assistant/chat_with_assistant.ipynb @@ -128,7 +128,7 @@ " \"No matter what the user's language is, you will use the {{langugae}} to explain.\"\n", " ],\n", " tools=[AssistantTool(\n", - " type=AssistantToolType.action,\n", + " type=AssistantToolType.ACTION,\n", " id=action.action_id,\n", " )],\n", " retrievals=[],\n", diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index 6f6fd4b..a53a283 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -58,7 +58,7 @@ def list_assistants( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) # only add non-None parameters params = { "order": order, @@ -91,7 +91,7 @@ async def a_list_assistants( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) # only add non-None parameters params = { "order": order, @@ -113,7 +113,7 @@ def get_assistant(assistant_id: str) -> Assistant: :param assistant_id: The ID of the assistant. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) response: AssistantGetResponse = api_instance.get_assistant(assistant_id=assistant_id) assistant: Assistant = Assistant(**response.data) return assistant @@ -126,7 +126,7 @@ async def a_get_assistant(assistant_id: str) -> Assistant: :param assistant_id: The ID of the assistant. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) response: AssistantGetResponse = await api_instance.get_assistant(assistant_id=assistant_id) assistant: Assistant = Assistant(**response.data) return assistant @@ -156,7 +156,7 @@ def create_assistant( :return: The created assistant object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) memory_dict = memory.model_dump() body = AssistantCreateRequest( model_id=model_id, @@ -197,7 +197,7 @@ async def a_create_assistant( :return: The created assistant object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) memory_dict = memory.model_dump() body = AssistantCreateRequest( model_id=model_id, @@ -240,7 +240,7 @@ def update_assistant( :return: The updated assistant object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = AssistantUpdateRequest( model_id=model_id, name=name, @@ -282,7 +282,7 @@ async def a_update_assistant( :return: The updated assistant object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = AssistantUpdateRequest( model_id=model_id, name=name, @@ -305,7 +305,7 @@ def delete_assistant(assistant_id: str) -> None: :param assistant_id: The ID of the assistant. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) api_instance.delete_assistant(assistant_id=assistant_id) @@ -316,6 +316,6 @@ async def a_delete_assistant(assistant_id: str) -> None: :param assistant_id: The ID of the assistant. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) await api_instance.delete_assistant(assistant_id=assistant_id) diff --git a/taskingai/assistant/chat.py b/taskingai/assistant/chat.py index beb1966..9a3b8f7 100644 --- a/taskingai/assistant/chat.py +++ b/taskingai/assistant/chat.py @@ -40,7 +40,7 @@ def list_chats( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) # only add non-None parameters params = { "order": order, @@ -76,7 +76,7 @@ async def a_list_chats( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) # only add non-None parameters params = { "order": order, @@ -101,7 +101,7 @@ def get_chat(assistant_id: str, chat_id: str) -> Chat: :param chat_id: The ID of the chat. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) response: ChatGetResponse = api_instance.get_chat( assistant_id=assistant_id, chat_id=chat_id, @@ -118,7 +118,7 @@ async def a_get_chat(assistant_id: str, chat_id: str) -> Chat: :param chat_id: The ID of the chat. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) response: ChatGetResponse = await api_instance.get_chat( assistant_id=assistant_id, chat_id=chat_id, @@ -139,7 +139,7 @@ def create_chat( :return: The created chat object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = ChatCreateRequest( metadata=metadata, ) @@ -163,7 +163,7 @@ async def a_create_chat( :return: The created chat object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = ChatCreateRequest( metadata=metadata, ) @@ -189,7 +189,7 @@ def update_chat( :return: The updated chat object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = ChatUpdateRequest( metadata=metadata, ) @@ -216,7 +216,7 @@ async def a_update_chat( :return: The updated chat object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = ChatUpdateRequest( metadata=metadata, ) @@ -240,7 +240,7 @@ def delete_chat( :param chat_id: The ID of the chat. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) api_instance.delete_chat(assistant_id=assistant_id, chat_id=chat_id) @@ -255,7 +255,7 @@ async def a_delete_chat( :param chat_id: The ID of the chat. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) await api_instance.delete_chat(assistant_id=assistant_id, chat_id=chat_id) diff --git a/taskingai/assistant/message.py b/taskingai/assistant/message.py index 163c2ce..0a99945 100644 --- a/taskingai/assistant/message.py +++ b/taskingai/assistant/message.py @@ -56,7 +56,7 @@ def list_messages( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) # only add non-None parameters params = { "order": order, @@ -96,7 +96,7 @@ async def a_list_messages( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) # only add non-None parameters params = { "order": order, @@ -127,7 +127,7 @@ def get_message( :param message_id: The ID of the message. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) response: MessageGetResponse = api_instance.get_message( assistant_id=assistant_id, chat_id=chat_id, @@ -150,7 +150,7 @@ async def a_get_message( :param message_id: The ID of the message. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) response: MessageGetResponse = await api_instance.get_message( assistant_id=assistant_id, chat_id=chat_id, @@ -177,9 +177,9 @@ def create_message( :return: The created message object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = MessageCreateRequest( - role=MessageRole.user, + role=MessageRole.USER, content=MessageContent(text=text), metadata=metadata, ) @@ -208,9 +208,9 @@ async def a_create_message( :return: The created message object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = MessageCreateRequest( - role=MessageRole.user, + role=MessageRole.USER, content=MessageContent(text=text), metadata=metadata, ) @@ -238,7 +238,7 @@ def update_message( :return: The updated message object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = MessageUpdateRequest( metadata=metadata, ) @@ -267,7 +267,7 @@ async def a_update_message( :return: The updated message object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = MessageUpdateRequest( metadata=metadata, ) @@ -298,7 +298,7 @@ def generate_message( :return: The generated message object. """ - api_instance = get_api_instance(ModuleType.assistant) + api_instance = get_api_instance(ModuleType.ASSISTANT) body = MessageGenerateRequest( system_prompt_variables=system_prompt_variables, stream=stream, @@ -339,7 +339,7 @@ async def a_generate_message( :return: The generated message object. """ - api_instance = get_api_instance(ModuleType.assistant, async_client=True) + api_instance = get_api_instance(ModuleType.ASSISTANT, async_client=True) body = MessageGenerateRequest( system_prompt_variables=system_prompt_variables, stream=stream, diff --git a/taskingai/client/constants.py b/taskingai/client/constants.py index 5a9284c..b06070e 100644 --- a/taskingai/client/constants.py +++ b/taskingai/client/constants.py @@ -4,10 +4,10 @@ DEFAULT_PARENT_LOGGER_LEVEL = 'ERROR' class ModuleType(str, Enum): - assistant = "assistant" - tool = "tool" - retrieval = "retrieval" - inference = "inference" + ASSISTANT = "assistant" + TOOL = "tool" + RETRIEVAL = "retrieval" + INFERENCE = "inference" diff --git a/taskingai/client/models/entity/assistant/assistant.py b/taskingai/client/models/entity/assistant/assistant.py index b88e724..1d7364e 100644 --- a/taskingai/client/models/entity/assistant/assistant.py +++ b/taskingai/client/models/entity/assistant/assistant.py @@ -19,12 +19,12 @@ class AssistantToolType(str, Enum): - action = "action" - function = "function" + ACTION = "action" + FUNCTION = "function" class AssistantRetrievalType(str, Enum): - collection = "collection" + COLLECTION = "collection" class AssistantRetrieval(TaskingaiBaseModel): diff --git a/taskingai/client/models/entity/assistant/message.py b/taskingai/client/models/entity/assistant/message.py index 6d6f4b9..4b72746 100644 --- a/taskingai/client/models/entity/assistant/message.py +++ b/taskingai/client/models/entity/assistant/message.py @@ -12,8 +12,8 @@ class MessageRole(str, Enum): - user = "user" - assistant = "assistant" + USER = "user" + ASSISTANT = "assistant" class MessageContent(TaskingaiBaseModel): diff --git a/taskingai/client/models/entity/inference/chat_completion.py b/taskingai/client/models/entity/inference/chat_completion.py index 2a32a2f..3cb2187 100644 --- a/taskingai/client/models/entity/inference/chat_completion.py +++ b/taskingai/client/models/entity/inference/chat_completion.py @@ -22,10 +22,10 @@ class ChatCompletionRole(str, Enum): - system = "system" - assistant = "assistant" - user = "user" - function = "function" + SYSTEM = "system" + ASSISTANT = "assistant" + USER = "user" + FUNCTION = "function" class ChatCompletionFunctionCall(TaskingaiBaseModel): @@ -45,30 +45,30 @@ class ChatCompletionMessage(TaskingaiBaseModel, metaclass=ABCMeta): class ChatCompletionSystemMessage(ChatCompletionMessage): - role: ChatCompletionRole = Field(Literal[ChatCompletionRole.system]) + role: ChatCompletionRole = Field(Literal[ChatCompletionRole.SYSTEM]) class ChatCompletionUserMessage(ChatCompletionMessage): - role: ChatCompletionRole = Field(Literal[ChatCompletionRole.user]) + role: ChatCompletionRole = Field(Literal[ChatCompletionRole.USER]) class ChatCompletionAssistantMessage(ChatCompletionMessage): - role: ChatCompletionRole = Field(Literal[ChatCompletionRole.assistant]) + role: ChatCompletionRole = Field(Literal[ChatCompletionRole.ASSISTANT]) function_calls: Optional[List[ChatCompletionFunctionCall]] class ChatCompletionFunctionMessage(ChatCompletionMessage): - role: ChatCompletionRole = Field(Literal[ChatCompletionRole.function]) + role: ChatCompletionRole = Field(Literal[ChatCompletionRole.FUNCTION]) id: str class ChatCompletionFinishReason(str, Enum): - stop = "stop" - length = "length" - function_calls = "function_calls" - recitation = "recitation" - error = "error" - unknown = "unknown" + STOP = "stop" + LENGTH = "length" + FUNCTION_CALLS = "function_calls" + RECITATION = "recitation" + ERROR = "error" + UNKNOWN = "unknown" class ChatCompletion(TaskingaiBaseModel): diff --git a/taskingai/client/utils.py b/taskingai/client/utils.py index cd5b22f..d885c06 100644 --- a/taskingai/client/utils.py +++ b/taskingai/client/utils.py @@ -25,17 +25,17 @@ def get_user_agent(): return user_agent sync_api_instance_dict = { - ModuleType.assistant: None, - ModuleType.tool: None, - ModuleType.retrieval: None, - ModuleType.inference: None + ModuleType.ASSISTANT: None, + ModuleType.TOOL: None, + ModuleType.RETRIEVAL: None, + ModuleType.INFERENCE: None } async_api_instance_dict = { - ModuleType.assistant: None, - ModuleType.tool: None, - ModuleType.retrieval: None, - ModuleType.inference: None + ModuleType.ASSISTANT: None, + ModuleType.TOOL: None, + ModuleType.RETRIEVAL: None, + ModuleType.INFERENCE: None } def get_api_instance(module: ModuleType, async_client=False): @@ -50,13 +50,13 @@ def get_api_instance(module: ModuleType, async_client=False): api_client.user_agent = get_user_agent() if async_api_instance_dict.get(module) is None: - if module == ModuleType.assistant: + if module == ModuleType.ASSISTANT: async_api_instance_dict[module] = AsyncAssistantApi(api_client) - elif module == ModuleType.tool: + elif module == ModuleType.TOOL: async_api_instance_dict[module] = AsyncToolApi(api_client) - elif module == ModuleType.retrieval: + elif module == ModuleType.RETRIEVAL: async_api_instance_dict[module] = AsyncRetrievalApi(api_client) - elif module == ModuleType.inference: + elif module == ModuleType.INFERENCE: async_api_instance_dict[module] = AsyncInferenceApi(api_client) api_instance = async_api_instance_dict[module] @@ -67,13 +67,13 @@ def get_api_instance(module: ModuleType, async_client=False): if sync_api_instance_dict.get(module) is None: - if module == ModuleType.assistant: + if module == ModuleType.ASSISTANT: sync_api_instance_dict[module] = AssistantApi(api_client) - elif module == ModuleType.tool: + elif module == ModuleType.TOOL: sync_api_instance_dict[module] = ToolApi(api_client) - elif module == ModuleType.retrieval: + elif module == ModuleType.RETRIEVAL: sync_api_instance_dict[module] = RetrievalApi(api_client) - elif module == ModuleType.inference: + elif module == ModuleType.INFERENCE: sync_api_instance_dict[module] = InferenceApi(api_client) api_instance = sync_api_instance_dict[module] diff --git a/taskingai/inference/chat_completion.py b/taskingai/inference/chat_completion.py index 244e4dc..d1f21f3 100644 --- a/taskingai/inference/chat_completion.py +++ b/taskingai/inference/chat_completion.py @@ -33,7 +33,7 @@ class SystemMessage(ChatCompletionSystemMessage): def __init__(self, content: str): super().__init__( - role=ChatCompletionRole.system, + role=ChatCompletionRole.SYSTEM, content=content ) @@ -41,7 +41,7 @@ def __init__(self, content: str): class UserMessage(ChatCompletionUserMessage): def __init__(self, content: str): super().__init__( - role=ChatCompletionRole.user, + role=ChatCompletionRole.USER, content=content ) @@ -49,7 +49,7 @@ def __init__(self, content: str): class AssistantMessage(ChatCompletionAssistantMessage): def __init__(self, content: str = None, function_calls: Optional[List[FunctionCall]] = None): super().__init__( - role=ChatCompletionRole.assistant, + role=ChatCompletionRole.ASSISTANT, content=content, function_calls=function_calls ) @@ -58,7 +58,7 @@ def __init__(self, content: str = None, function_calls: Optional[List[FunctionCa class FunctionMessage(ChatCompletionFunctionMessage): def __init__(self, id: str, content: str): super().__init__( - role=ChatCompletionRole.function, + role=ChatCompletionRole.FUNCTION, id=id, content=content ) @@ -87,7 +87,7 @@ def chat_completion( :param functions: The list of functions. :return: The list of assistants. """ - api_instance = get_api_instance(ModuleType.inference) + api_instance = get_api_instance(ModuleType.INFERENCE) # only add non-None parameters body = ChatCompletionRequest( model_id=model_id, @@ -129,7 +129,7 @@ async def a_chat_completion( :param functions: The list of functions. :return: The list of assistants. """ - api_instance = get_api_instance(ModuleType.inference, async_client=True) + api_instance = get_api_instance(ModuleType.INFERENCE, async_client=True) # only add non-None parameters body = ChatCompletionRequest( model_id=model_id, diff --git a/taskingai/inference/text_embedding.py b/taskingai/inference/text_embedding.py index 46cd0ef..8e62cac 100644 --- a/taskingai/inference/text_embedding.py +++ b/taskingai/inference/text_embedding.py @@ -25,7 +25,7 @@ def text_embedding( :param input: The input text or list of input texts. :return: The list of assistants. """ - api_instance = get_api_instance(ModuleType.inference) + api_instance = get_api_instance(ModuleType.INFERENCE) # only add non-None parameters body = TextEmbeddingRequest( model_id=model_id, @@ -54,7 +54,7 @@ async def a_text_embedding( :param input: The input text or list of input texts. :return: The list of assistants. """ - api_instance = get_api_instance(ModuleType.inference, async_client=True) + api_instance = get_api_instance(ModuleType.INFERENCE, async_client=True) # only add non-None parameters body = TextEmbeddingRequest( model_id=model_id, diff --git a/taskingai/retrieval/chunk.py b/taskingai/retrieval/chunk.py index 01a55fb..30a60b2 100644 --- a/taskingai/retrieval/chunk.py +++ b/taskingai/retrieval/chunk.py @@ -23,7 +23,7 @@ def query_chunks( :param top_k: The number of most relevant chunks to return. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) # only add non-None parameters body = ChunkQueryRequest( top_k=top_k, @@ -49,7 +49,7 @@ async def a_query_chunks( :param top_k: The number of most relevant chunks to return. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) # only add non-None parameters body = ChunkQueryRequest( top_k=top_k, diff --git a/taskingai/retrieval/collection.py b/taskingai/retrieval/collection.py index a1e0284..c8a78fe 100644 --- a/taskingai/retrieval/collection.py +++ b/taskingai/retrieval/collection.py @@ -41,7 +41,7 @@ def list_collections( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) # only add non-None parameters params = { "order": order, @@ -74,7 +74,7 @@ async def a_list_collections( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) # only add non-None parameters params = { "order": order, @@ -95,7 +95,7 @@ def get_collection(collection_id: str) -> Collection: :param collection_id: The ID of the collection. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) response: CollectionGetResponse = api_instance.get_collection(collection_id=collection_id) collection: Collection = Collection(**response.data) return collection @@ -108,7 +108,7 @@ async def a_get_collection(collection_id: str) -> Collection: :param collection_id: The ID of the collection. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) response: CollectionGetResponse = await api_instance.get_collection(collection_id=collection_id) collection: Collection = Collection(**response.data) return collection @@ -135,7 +135,7 @@ def create_collection( """ # todo verify parameters - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) body = CollectionCreateRequest( embedding_model_id=embedding_model_id, capacity=capacity, @@ -170,7 +170,7 @@ async def a_create_collection( """ # todo verify parameters - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) body = CollectionCreateRequest( embedding_model_id=embedding_model_id, capacity=capacity, @@ -199,7 +199,7 @@ def update_collection( :return: The updated collection object. """ #todo: verify at least one parameter is not None - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) body = CollectionUpdateRequest( name=name, description=description, @@ -227,7 +227,7 @@ async def a_update_collection( :param authentication: The collection API authentication. :return: The updated collection object. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) body = CollectionUpdateRequest( name=name, description=description, @@ -248,7 +248,7 @@ def delete_collection(collection_id: str) -> None: :param collection_id: The ID of the collection. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) api_instance.delete_collection(collection_id=collection_id) @@ -259,6 +259,6 @@ async def a_delete_collection(collection_id: str) -> None: :param collection_id: The ID of the collection. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) await api_instance.delete_collection(collection_id=collection_id) diff --git a/taskingai/retrieval/record.py b/taskingai/retrieval/record.py index 3007980..0a395e9 100644 --- a/taskingai/retrieval/record.py +++ b/taskingai/retrieval/record.py @@ -41,7 +41,7 @@ def list_records( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) # only add non-None parameters params = { "order": order, @@ -78,7 +78,7 @@ async def a_list_records( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) # only add non-None parameters params = { "order": order, @@ -104,7 +104,7 @@ def get_record(collection_id: str, record_id: str) -> Record: :param record_id: The ID of the record. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) response: RecordGetResponse = api_instance.get_record( collection_id=collection_id, record_id=record_id, @@ -121,7 +121,7 @@ async def a_get_record(collection_id: str, record_id: str) -> Record: :param record_id: The ID of the record. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) response: RecordGetResponse = await api_instance.get_record( collection_id=collection_id, record_id=record_id, @@ -145,7 +145,7 @@ def create_text_record( :return: The created record object. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) body = RecordCreateRequest( type="text", text=text, @@ -174,7 +174,7 @@ async def a_create_text_record( :return: The created record object. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) body = RecordCreateRequest( type="text", text=text, @@ -202,7 +202,7 @@ def update_record( :return: The collection object. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) body = RecordUpdateRequest( metadata=metadata, ) @@ -229,7 +229,7 @@ async def a_update_record( :return: The collection object. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) body = RecordUpdateRequest( metadata=metadata, ) @@ -253,7 +253,7 @@ def delete_record( :param record_id: The ID of the record. """ - api_instance = get_api_instance(ModuleType.retrieval) + api_instance = get_api_instance(ModuleType.RETRIEVAL) api_instance.delete_record(collection_id=collection_id, record_id=record_id) @@ -268,7 +268,7 @@ async def a_delete_record( :param record_id: The ID of the record. """ - api_instance = get_api_instance(ModuleType.retrieval, async_client=True) + api_instance = get_api_instance(ModuleType.RETRIEVAL, async_client=True) await api_instance.delete_record(collection_id=collection_id, record_id=record_id) diff --git a/taskingai/tool/action.py b/taskingai/tool/action.py index bd92449..09c4635 100644 --- a/taskingai/tool/action.py +++ b/taskingai/tool/action.py @@ -50,7 +50,7 @@ def list_actions( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) # only add non-None parameters params = { "order": order, @@ -82,7 +82,7 @@ async def a_list_actions( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) # only add non-None parameters params = { "order": order, @@ -104,7 +104,7 @@ def get_action(action_id: str) -> Action: :param action_id: The ID of the action. """ - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) response: ActionGetResponse = api_instance.get_action(action_id=action_id) action: Action = Action(**response.data) return action @@ -117,7 +117,7 @@ async def a_get_action(action_id: str) -> Action: :param action_id: The ID of the action. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) response: ActionGetResponse = await api_instance.get_action(action_id=action_id) action: Action = Action(**response.data) return action @@ -135,7 +135,7 @@ def bulk_create_actions( """ # todo verify schema - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) if authentication is None: authentication = ActionAuthentication( type=ActionAuthenticationType.NONE, @@ -162,7 +162,7 @@ async def a_bulk_create_actions( """ # todo verify schema - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) if authentication is None: authentication = ActionAuthentication( type=ActionAuthenticationType.NONE, @@ -190,7 +190,7 @@ def update_action( :return: The updated action object. """ #todo: verify schema - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) body = ActionUpdateRequest( schema=schema, authentication=authentication, @@ -216,7 +216,7 @@ async def a_update_action( :param authentication: The action API authentication. :return: The updated action object. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) body = ActionUpdateRequest( schema=schema, authentication=authentication, @@ -236,7 +236,7 @@ def delete_action(action_id: str) -> None: :param action_id: The ID of the action. """ - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) api_instance.delete_action(action_id=action_id) @@ -247,7 +247,7 @@ async def a_delete_action(action_id: str) -> None: :param action_id: The ID of the action. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) await api_instance.delete_action(action_id=action_id) @@ -263,7 +263,7 @@ def run_action( :return: The action response. """ - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) body = ActionRunRequest( parameters=parameters, ) @@ -287,7 +287,7 @@ async def a_run_action( :return: The action response. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) body = ActionRunRequest( parameters=parameters, ) diff --git a/taskingai/tool/function.py b/taskingai/tool/function.py index 079dc14..9924b0b 100644 --- a/taskingai/tool/function.py +++ b/taskingai/tool/function.py @@ -39,7 +39,7 @@ def list_functions( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) # only add non-None parameters params = { "order": order, @@ -71,7 +71,7 @@ async def a_list_functions( if after and before: raise ValueError("Only one of after and before can be specified.") - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) # only add non-None parameters params = { "order": order, @@ -93,7 +93,7 @@ def get_function(function_id: str) -> Function: :param function_id: The ID of the function. """ - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) response: FunctionGetResponse = api_instance.get_function(function_id=function_id) function: Function = Function(**response.data) return function @@ -106,7 +106,7 @@ async def a_get_function(function_id: str) -> Function: :param function_id: The ID of the function. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) response: FunctionGetResponse = await api_instance.get_function(function_id=function_id) function: Function = Function(**response.data) return function @@ -127,7 +127,7 @@ def create_function( """ # todo verify parameters - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) body = FunctionCreateRequest( name=name, description=description, @@ -153,7 +153,7 @@ async def a_create_function( """ # todo verify parameters - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) body = FunctionCreateRequest( name=name, description=description, @@ -179,7 +179,7 @@ def update_function( :return: The updated function object. """ #todo: verify schema - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) body = FunctionUpdateRequest( name=name, description=description, @@ -207,7 +207,7 @@ async def a_update_function( :param authentication: The function API authentication. :return: The updated function object. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) body = FunctionUpdateRequest( name=name, description=description, @@ -228,7 +228,7 @@ def delete_function(function_id: str) -> None: :param function_id: The ID of the function. """ - api_instance = get_api_instance(ModuleType.tool) + api_instance = get_api_instance(ModuleType.TOOL) api_instance.delete_function(function_id=function_id) @@ -239,7 +239,7 @@ async def a_delete_function(function_id: str) -> None: :param function_id: The ID of the function. """ - api_instance = get_api_instance(ModuleType.tool, async_client=True) + api_instance = get_api_instance(ModuleType.TOOL, async_client=True) await api_instance.delete_function(function_id=function_id) diff --git a/test/common/logger.py b/test/common/logger.py index cd77471..88a5fef 100644 --- a/test/common/logger.py +++ b/test/common/logger.py @@ -33,26 +33,26 @@ def __init__(self): def logger_info_base(http_code: str, http_status: str, res): - logger.info("http_code ==>> except_res:{}, real_res:【 {} 】".format(http_code, res.status_code)) - logger.info("http_status ==>> except_res:{}, real_res:【 {} 】".format(http_status, res.json()["status"])) + logger.info("http_code ==>> except_res:{}, real_res: {}".format(http_code, res.status_code)) + logger.info("http_status ==>> except_res:{}, real_res: {}".format(http_status, res.json()["status"])) def logger_info(http_code: str, http_status: str, res): logger_info_base(http_code, http_status, res) - logger.info("total_count ==>> real_res:【 {} 】".format(res.json()["total_count"])) - logger.info("fetched_count ==>> real_res:【 {} 】".format(res.json()["fetched_count"])) + logger.info("total_count ==>> real_res: {}".format(res.json()["total_count"])) + logger.info("fetched_count ==>> real_res: {}".format(res.json()["fetched_count"])) def logger_ready_or_creating_info(http_code: str, http_status: str, res): logger_info_base(http_code, http_status, res) - logger.info("data_status ==>> real_res:【 {} 】".format(res.json()["data"]["status"])) + logger.info("data_status ==>> real_res: {}".format(res.json()["data"]["status"])) def logger_success_info(http_code: str, http_status: str, data_status, res): logger_info_base(http_code, http_status, res) - logger.info("data_status ==>> except_res:{}, real_res:【 {} 】".format(data_status, res.json()["data"]["status"])) + logger.info("data_status ==>> except_res:{}, real_res: {}".format(data_status, res.json()["data"]["status"])) def logger_error_info(http_code: str, http_status: str, data_status, res): logger_info_base(http_code, http_status, res) - logger.info("data_status ==>> except_res:{}, real_res:【 {} 】".format(data_status, res.json()["error"]["code"])) + logger.info("data_status ==>> except_res:{}, real_res: {}".format(data_status, res.json()["error"]["code"])) diff --git a/test/common/utils.py b/test/common/utils.py index 737f824..88be708 100644 --- a/test/common/utils.py +++ b/test/common/utils.py @@ -24,10 +24,6 @@ def list_to_dict(data: list): return d -def get_password(): - return ''.join(random.choices(string.ascii_letters, k=7))+str(random.randint(0, 9)) - - def assume(res, except_dict: Dict[str, Any]): pytest.assume(res.status_code == int(except_dict["except_http_code"])) pytest.assume(res.json()["status"] == except_dict["except_status"]) @@ -56,6 +52,8 @@ def assume_assistant(res, assistant_dict: Dict[str, Any]): pass elif key == 'system_prompt_template' and isinstance(value, str): pytest.assume(res[key] == [assistant_dict[key]]) + elif key in ["memory", "tool", "retrievals"]: + pass else: pytest.assume(res[key] == assistant_dict[key]) diff --git a/test/config.py b/test/config.py index 43d3a83..408d1a4 100644 --- a/test/config.py +++ b/test/config.py @@ -1,14 +1,13 @@ -# your chat_model_id +import os -chat_model_id = "KnnBZsjH" +chat_completion_model_id = os.environ.get("CHAT_COMPLETION_MODEL_ID") +if not chat_completion_model_id: + raise ValueError("chat_completion_model_id is not defined") -# your text_model_id - -text_model_id = "fTfk462c" - -# need sleep time +embedding_model_id = os.environ.get("EMBEDDING_MODEL_ID") +if not chat_completion_model_id: + raise ValueError("chat_completion_model_id is not defined") sleep_time = 1 -# TASKINGAI_API_KEY = "tag4rNUwc7sYBjTjXtGNe6WbGNhI056C" \ No newline at end of file diff --git a/test/pytest.ini b/test/pytest.ini index 51e4977..ca5138e 100644 --- a/test/pytest.ini +++ b/test/pytest.ini @@ -8,3 +8,4 @@ markers = log_cli = False +addopts = -W ignore::DeprecationWarning --clean-alluredir --alluredir=./allure-report \ No newline at end of file diff --git a/test/testcase/test_async/test_async_assistant.py b/test/testcase/test_async/test_async_assistant.py index 2cf3cc4..c7bf613 100644 --- a/test/testcase/test_async/test_async_assistant.py +++ b/test/testcase/test_async/test_async_assistant.py @@ -2,7 +2,8 @@ import pytest from taskingai.assistant import * -from test.config import chat_model_id, sleep_time, text_model_id +from taskingai.assistant.memory import AssistantNaiveMemory +from test.config import chat_completion_model_id, sleep_time, embedding_model_id from test.common.read_data import data from test.common.logger import logger from test.common.utils import list_to_dict @@ -26,14 +27,10 @@ class TestAssistant(Base): @pytest.mark.asyncio async def test_a_create_assistant(self, a_create_assistant_data): - # List assistants. - old_res = await a_list_assistants(limit=100) - old_nums = len(old_res) - # Create an assistant. assistant_dict = list_to_dict(a_create_assistant_data) - assistant_dict.update({"model_id": chat_model_id}) + assistant_dict.update({"model_id": chat_completion_model_id}) if ("retrievals" in assistant_dict.keys() and len(assistant_dict["retrievals"]) > 0 and assistant_dict["retrievals"][0]["type"] == "collection"): assistant_dict["retrievals"][0]["id"] = Base.collection_id @@ -41,32 +38,20 @@ async def test_a_create_assistant(self, a_create_assistant_data): == "action"): logger.info(f'a_create_assistant_action_id:{Base.action_id}') assistant_dict["tools"][0]["id"] = Base.action_id + assistant_dict.update({"memory": AssistantNaiveMemory()}) res = await a_create_assistant(**assistant_dict) res_dict = res.to_dict() logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant(res_dict, assistant_dict) - # Get an assistant. - - get_res = await a_get_assistant(assistant_id=res_dict["assistant_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.assistant_keys) - - # List assistants. - - new_res = await a_list_assistants(limit=100) - new_nums = len(new_res) - logger.info(f'old_nums:{old_nums}, new_nums:{new_nums}') - pytest.assume(new_nums == old_nums + 1) - @pytest.mark.run(order=19) @pytest.mark.asyncio async def test_a_list_assistants(self): # List assistants. - nums_limit = 2 + nums_limit = 1 res = await a_list_assistants(limit=nums_limit) pytest.assume(len(res) == nums_limit) @@ -111,15 +96,7 @@ async def test_a_update_assistant(self): pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) - # Get an assistant. - - get_res = await a_get_assistant(assistant_id=self.assistant_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.assistant_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - - @pytest.mark.run(order=32) + @pytest.mark.run(order=33) @pytest.mark.asyncio async def test_a_delete_assistant(self): @@ -142,6 +119,7 @@ async def test_a_delete_assistant(self): new_nums = len(new_assistants) pytest.assume(new_nums == old_nums - 1 - i) + @pytest.mark.test_async class TestChat(Base): @@ -152,12 +130,7 @@ class TestChat(Base): @pytest.mark.asyncio async def test_a_create_chat(self): - # List chats. - - old_res = await a_list_chats(assistant_id=self.assistant_id) - old_nums = len(old_res) - - for x in range(4): + for x in range(2): # Create a chat. @@ -165,25 +138,13 @@ async def test_a_create_chat(self): res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.chat_keys) - # Get a chat. - - get_res = await a_get_chat(assistant_id=self.assistant_id, chat_id=res_dict["chat_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.chat_keys) - - # List chats. - - new_res = await a_list_chats(assistant_id=self.assistant_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=23) @pytest.mark.asyncio async def test_a_list_chats(self): # List chats. - nums_limit = 2 + nums_limit = 1 res = await a_list_chats(limit=nums_limit, assistant_id=self.assistant_id) pytest.assume(len(res) == nums_limit) @@ -226,14 +187,7 @@ async def test_a_update_chat(self): pytest.assume(res_dict.keys() == self.chat_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a chat. - - get_res = await a_get_chat(assistant_id=self.assistant_id, chat_id=self.chat_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.chat_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - - @pytest.mark.run(order=31) + @pytest.mark.run(order=32) @pytest.mark.asyncio async def test_a_delete_chat(self): @@ -268,12 +222,7 @@ class TestMessage(Base): @pytest.mark.asyncio async def test_a_create_message(self): - # List messages. - - old_res = await a_list_messages(assistant_id=self.assistant_id, chat_id=self.chat_id) - old_nums = len(old_res) - - for x in range(4): + for x in range(2): # Create a user message. @@ -285,26 +234,13 @@ async def test_a_create_message(self): pytest.assume(res_dict["content"]["text"] == text) pytest.assume(res_dict["role"] == "user") - # Get a message. - - get_res = await a_get_message(assistant_id=self.assistant_id, chat_id=self.chat_id, - message_id=res_dict["message_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - - # List messages. - - new_res = await a_list_messages(assistant_id=self.assistant_id, chat_id=self.chat_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=27) @pytest.mark.asyncio async def test_a_list_messages(self): # List messages. - nums_limit = 2 + nums_limit = 1 res = await a_list_messages(limit=nums_limit, assistant_id=self.assistant_id, chat_id=self.chat_id) pytest.assume(len(res) == nums_limit) @@ -351,22 +287,10 @@ async def test_a_update_message(self): pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a message. - - get_res = await a_get_message(assistant_id=self.assistant_id, chat_id=self.chat_id, message_id=self.message_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - @pytest.mark.run(order=30) @pytest.mark.asyncio async def test_a_generate_message(self): - # List messages. - - messages = await a_list_messages(assistant_id=self.assistant_id, chat_id=self.chat_id) - old_nums = len(messages) - # Generate an assistant message. res = await a_generate_message(assistant_id=self.assistant_id, chat_id=self.chat_id, @@ -375,18 +299,6 @@ async def test_a_generate_message(self): pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["role"] == "assistant") - # Get a message. - - get_res = await a_get_message(assistant_id=self.assistant_id, chat_id=self.chat_id, - message_id=res_dict["message_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - # List messages. - - new_res = await a_list_messages(assistant_id=self.assistant_id, chat_id=self.chat_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1) - @pytest.mark.run(order=30) @pytest.mark.asyncio async def test_a_generate_message_by_stream(self): @@ -425,265 +337,3 @@ async def test_a_generate_message_by_stream(self): logger.info(f"Message: {real_str}") logger.info(f"except_list: {except_list} real_list: {real_list}") pytest.assume(set(except_list) == set(real_list)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_generate_message_in_user_message_not_created(self): - - # create chat - - chat_res = await a_create_chat(assistant_id=self.assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # Generate an assistant message. - - try: - res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_a_generate_message_in_user_message_not_created{e}') - pytest.assume("There is no user message in the chat context." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_create_message_in_generating_assistant_message(self): - - # create chat - - chat_res = await a_create_chat(assistant_id=self.assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - user_message = await a_create_message( - assistant_id=self.assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - - # create user message - - try: - user_message = await a_create_message( - assistant_id=self.assistant_id, - chat_id=chat_id, - text="count from 100 to 200 and separate numbers by comma.", - ) - except Exception as e: - logger.info(f'test_a_create_message_in_generating_assistant_message{user_message}') - pytest.assume("Chat is locked by another generation process. Please try again later." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_generate_message_in_generating_assistant_message(self): - - # create chat - - chat_res = await a_create_chat(assistant_id=self.assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - await a_create_message( - assistant_id=self.assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - stream_res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - - # Generate an assistant message by stream. - - try: - stream_res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - except Exception as e: - logger.info(f'test_a_generate_message_in_generating_assistant_message{stream_res}') - pytest.assume("Chat is locked by another generation process. Please try again later." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_generate_message_in_generated_assistant_message(self): - - # create chat - - chat_res = await a_create_chat(assistant_id=self.assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - user_message = await a_create_message( - assistant_id=self.assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}) - - # Generate an assistant message by stream. - - try: - stream_res = await a_generate_message(assistant_id=self.assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - except Exception as e: - logger.info(f'test_a_generate_message_in_generated_assistant_message{e}') - pytest.assume("Cannot generate another assistant message after an assistant message." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_generate_message_in_action_deleted_assistant(self): - - # create action - - schema = { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location", - "operationId": "GetCurrentWeather", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - }, - "post": { - "description": "UPDATE temperature for a specific location", - "operationId": "UpdateCurrentWeather", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": { - "$ref": "#/componeents/schemas/ActionCreateRequest" - } - } - } - }, - "deprecated": False - } - } - }, - "components": { - "schemas": {} - }, - "security": [] - } - action_res = await a_bulk_create_actions(schema=schema) - action_id = action_res[0].action_id - - # create an assistant - - assistant_res = await a_create_assistant(name="test", description="test", model_id=chat_model_id, tools=[{"type": "action", "id": action_id}]) - assistant_id = assistant_res.assistant_id - - # create a chat - - chat_res = await a_create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - - # create user message - - user_message = await a_create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # delete action - - await a_delete_action(action_id=action_id) - await asyncio.sleep(sleep_time) - - # Generate an assistant message by stream. - - try: - res = await a_generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_a_generate_message_in_action_deleted_assistant{e}') - pytest.assume("Some tools are not found" in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_generate_message_in_collection_deleted_assistant(self): - - # create collection - - collection_res = await a_create_collection(name="test", description="test", embedding_model_id=text_model_id, capacity=1000) - collection_id = collection_res.collection_id - - # create an assistant - - assistant_res = await a_create_assistant(name="test", description="test", model_id=chat_model_id, retrievals=[{"type": "collection", "id": collection_id}]) - assistant_id = assistant_res.assistant_id - - # create chat - - chat_res = await a_create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - - # create user message - - user_message = await a_create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 1000 and separate numbers by comma.", - ) - - # delete collection - - await a_delete_collection(collection_id) - await asyncio.sleep(sleep_time) - - # Generate an assistant message by stream. - - try: - res = await a_generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_a_generate_message_in_collection_deleted_assistant{e}') - pytest.assume(f"Collections not found" in str(e)) diff --git a/test/testcase/test_async/test_async_inference.py b/test/testcase/test_async/test_async_inference.py index 56b5880..6f5cd66 100644 --- a/test/testcase/test_async/test_async_inference.py +++ b/test/testcase/test_async/test_async_inference.py @@ -1,7 +1,7 @@ import pytest from taskingai.inference import * -from test.config import text_model_id, chat_model_id +from test.config import embedding_model_id, chat_completion_model_id from test.common.logger import logger @@ -15,7 +15,7 @@ async def test_a_chat_completion(self): # normal chat completion. normal_res = await a_chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -24,12 +24,12 @@ async def test_a_chat_completion(self): pytest.assume(normal_res.finish_reason == "stop") pytest.assume(normal_res.message.content) pytest.assume(normal_res.message.role == "assistant") - pytest.assume(normal_res.message.function_call is None) + pytest.assume(normal_res.message.function_calls is None) # multi round chat completion. multi_round_res = await a_chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -44,12 +44,12 @@ async def test_a_chat_completion(self): pytest.assume(multi_round_res.finish_reason == "stop") pytest.assume(multi_round_res.message.content) pytest.assume(multi_round_res.message.role == "assistant") - pytest.assume(multi_round_res.message.function_call is None) + pytest.assume(multi_round_res.message.function_calls is None) # config max tokens chat completion. max_tokens_res = await a_chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -66,11 +66,11 @@ async def test_a_chat_completion(self): pytest.assume(max_tokens_res.finish_reason == "length") pytest.assume(max_tokens_res.message.content) pytest.assume(max_tokens_res.message.role == "assistant") - pytest.assume(max_tokens_res.message.function_call is None) + pytest.assume(max_tokens_res.message.function_calls is None) # chat completion with stream. - stream_res = await a_chat_completion(model_id=chat_model_id, + stream_res = await a_chat_completion(model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("count from 1 to 50 and separate numbers by comma."), @@ -100,7 +100,7 @@ async def test_a_text_embedding(self): # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - str_res = await a_text_embedding(model_id=text_model_id, input=input_str) + str_res = await a_text_embedding(model_id=embedding_model_id, input=input_str) pytest.assume(len(str_res) > 0) for score in str_res: pytest.assume(float(-1) <= score <= float(1)) @@ -109,7 +109,7 @@ async def test_a_text_embedding(self): input_list = ["hello", "world"] input_list_length = len(input_list) - list_res = await a_text_embedding(model_id=text_model_id, input=input_list) + list_res = await a_text_embedding(model_id=embedding_model_id, input=input_list) pytest.assume(len(list_res) == input_list_length) for str_res in list_res: pytest.assume(len(str_res) > 0) diff --git a/test/testcase/test_async/test_async_retrieval.py b/test/testcase/test_async/test_async_retrieval.py index d152616..ab52f19 100644 --- a/test/testcase/test_async/test_async_retrieval.py +++ b/test/testcase/test_async/test_async_retrieval.py @@ -2,7 +2,7 @@ import pytest from taskingai.retrieval import a_list_collections, a_create_collection, a_get_collection, a_update_collection, a_delete_collection, a_list_records, a_create_text_record, a_get_record, a_update_record, a_delete_record, a_query_chunks -from test.config import text_model_id, sleep_time +from test.config import embedding_model_id, sleep_time from test.common.logger import logger from test.testcase.test_async.base import Base @@ -20,56 +20,30 @@ class TestCollection(Base): @pytest.mark.asyncio async def test_a_create_collection(self): - # List collections. - - old_res = await a_list_collections(limit=100) - old_nums = len(old_res) - - for x in range(4): + for x in range(2): # Create a collection. name = f"test{x}" description = "just for test" - res = await a_create_collection(name=name, description=description, embedding_model_id=text_model_id, capacity=1000) + res = await a_create_collection(name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.collection_keys) pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) - pytest.assume(res_dict["embedding_model_id"] == text_model_id) + pytest.assume(res_dict["embedding_model_id"] == embedding_model_id) pytest.assume(res_dict["capacity"] == 1000) pytest.assume(res_dict["status"] == "creating") - # Get a collection. - - await asyncio.sleep(sleep_time) - collection_id = res_dict["collection_id"] - get_res = await a_get_collection(collection_id=collection_id) - get_res_dict = get_res.to_dict() - logger.info(get_res_dict) - pytest.assume(get_res_dict.keys() == self.collection_keys) - pytest.assume(get_res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - pytest.assume(get_res_dict["embedding_model_id"] == text_model_id) - pytest.assume(get_res_dict["capacity"] == 1000) - pytest.assume(get_res_dict["status"] == "ready") - - # List collections. - - new_res = await a_list_collections(limit=100) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=10) @pytest.mark.asyncio async def test_a_list_collections(self): # List collections. - nums_limit = 2 + nums_limit = 1 res = await a_list_collections(limit=nums_limit) pytest.assume(len(res) == nums_limit) after_id = res[-1].collection_id @@ -97,7 +71,7 @@ async def test_a_get_collection(self, a_collection_id): res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.collection_keys) pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(res_dict["status"] == "ready") + pytest.assume(res_dict["status"] == "ready" or "creating") @pytest.mark.run(order=12) @pytest.mark.asyncio @@ -115,37 +89,23 @@ async def test_a_update_collection(self): pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["status"] == "ready") - # Get a collection. - - await asyncio.sleep(sleep_time) - get_res = await a_get_collection(collection_id=self.collection_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.collection_keys) - pytest.assume(get_res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - pytest.assume(get_res_dict["status"] == "ready") - - @pytest.mark.run(order=34) + @pytest.mark.run(order=35) @pytest.mark.asyncio async def test_a_delete_collection(self): # List collections. old_res = await a_list_collections(order="desc", limit=100, after=None, before=None) - old_nums = len(old_res) for index, collection in enumerate(old_res): collection_id = collection.collection_id # Delete a collection. await a_delete_collection(collection_id=collection_id) - # await asyncio.sleep(3) + new_collections = await a_list_collections(order="desc", limit=100, after=None, before=None) # List collections. collection_ids = [c.collection_id for c in new_collections] pytest.assume(collection_id not in collection_ids) - new_nums = len(new_collections) - # pytest.assume( new_nums == old_nums - 1 - index - - + + @pytest.mark.test_async class TestRecord(Base): @@ -159,12 +119,7 @@ class TestRecord(Base): @pytest.mark.asyncio async def test_a_create_text_record(self): - # List records. - - old_res = await a_list_records(collection_id=self.collection_id) - old_nums = len(old_res) - - for x in range(4): + for x in range(2): # Create a text record. @@ -176,24 +131,6 @@ async def test_a_create_text_record(self): pytest.assume(res_dict["content"]["text"] == text) pytest.assume(res_dict["status"] == "creating") - # Get a record. - - await asyncio.sleep(sleep_time*25) - record_id = res_dict["record_id"] - get_res = await a_get_record(collection_id=self.collection_id, record_id=record_id) - logger.info(f'a_create_record:get_res {get_res}') - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.record_keys) - pytest.assume(get_res_dict["content"].keys() == self.record_content_keys) - pytest.assume(get_res_dict["content"]["text"] == text) - pytest.assume(get_res_dict["status"] == "ready") - - # List records. - - new_res = await a_list_records(collection_id=self.collection_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=14) @pytest.mark.asyncio async def test_a_list_records(self, a_record_id): @@ -202,7 +139,7 @@ async def test_a_list_records(self, a_record_id): if not Base.record_id: Base.record_id = await a_record_id - nums_limit = 2 + nums_limit = 1 res = await a_list_records(limit=nums_limit, collection_id=self.collection_id) pytest.assume(len(res) == nums_limit) @@ -226,13 +163,13 @@ async def test_a_list_records(self, a_record_id): async def test_a_get_record(self): # Get a record. - await asyncio.sleep(sleep_time*25) + res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id) logger.info(f'a_get_record:{res}') res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["status"] == "ready") + pytest.assume(res_dict["status"] == "ready" or "creating") @pytest.mark.run(order=16) @pytest.mark.asyncio @@ -248,80 +185,8 @@ async def test_a_update_record(self): pytest.assume(res_dict["content"].keys() == self.record_content_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a record. - - await asyncio.sleep(sleep_time*25) - get_res = await a_get_record(collection_id=self.collection_id, record_id=self.record_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.record_keys) - pytest.assume(get_res_dict["content"].keys() == self.record_content_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - pytest.assume(get_res_dict["status"] == "ready") - - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_create_record_in_nonexistent_collection(self): - - # Create collection. - - collection_id = "nonexistent_collection_id" - - # Create a record. - - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = await a_create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_a_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection not found: {collection_id}" in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_create_record_in_creating_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - res = await a_create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - - # Create a record. - - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = await a_create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_a_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection {collection_id} is not ready." in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_create_record_in_deleting_collection(self): - - # Create collection. - name = "test" - description = "just for test" - res = await a_create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - await a_delete_collection(collection_id=collection_id) - - # Create a record. - - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = await a_create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_a_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection not found: {collection_id}" in str(e)) - - @pytest.mark.run(order=33) + @pytest.mark.run(order=34) @pytest.mark.asyncio async def test_a_delete_record(self): @@ -369,67 +234,3 @@ async def test_a_query_chunks(self): pytest.assume(query_text in chunk_dict["text"]) pytest.assume(chunk_dict["score"] >= 0) - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_query_chunks_in_creating_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - res = await a_create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = await a_query_chunks(collection_id=collection_id, query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_a_query_chunks_in_creating_collection:{e}') - pytest.assume(f"Collection {collection_id} is not ready." in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_query_chunks_in_deleting_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - collection_res = await a_create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = collection_res.collection_id - - # delete collection - - await a_delete_collection(collection_id=collection_id) - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = await a_query_chunks(collection_id=collection_id, query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_a_query_chunks_in_deleting_collection:{e}') - pytest.assume("Collections not found" in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.asyncio - @pytest.mark.test_abnormal - async def test_a_query_chunks_in_nonexistent_collection(self): - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = await a_query_chunks(collection_id="nonexistent_collection_id", query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_a_query_chunks_in_nonexistent_collection:{e}') - pytest.assume('Collections not found' in str(e)) diff --git a/test/testcase/test_async/test_async_tool.py b/test/testcase/test_async/test_async_tool.py index 25a0eb6..497410c 100644 --- a/test/testcase/test_async/test_async_tool.py +++ b/test/testcase/test_async/test_async_tool.py @@ -71,57 +71,23 @@ class TestAction(Base): @pytest.mark.asyncio async def test_a_bulk_create_actions(self): - # List actions. - - old_res = await a_list_actions(limit=100) - old_nums = len(old_res) - - for x in range(4): - - # Create an action. - - res = await a_bulk_create_actions(schema=self.schema) - for action in res: - action_dict = action.to_dict() - logger.info(action_dict) - pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) - - for key in action_dict["schema"].keys(): - if key == "paths": - if action_dict["schema"][key]["/location"] == "get": - pytest.assume(action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": - pytest.assume(action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"]["post"]) - else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) - - # Get an action. - - action_id = action_dict["action_id"] - get_res = await a_get_action(action_id=action_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.action_keys) - pytest.assume(get_res_dict["schema"].keys() == self.action_schema_keys) - - for key in action_dict["schema"].keys(): - if key == "paths": - if action_dict["schema"][key]["/location"] == "get": - pytest.assume( - action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": - pytest.assume( - action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"][ - "post"]) - else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) - - # List actions. - - new_res = await a_list_actions(limit=100) - new_nums = len(new_res) - res_num = len(res) - pytest.assume(new_nums == old_nums + res_num + x*2) + # Create an action. + + res = await a_bulk_create_actions(schema=self.schema) + for action in res: + action_dict = action.to_dict() + logger.info(action_dict) + pytest.assume(action_dict.keys() == self.action_keys) + pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) + + for key in action_dict["schema"].keys(): + if key == "paths": + if action_dict["schema"][key]["/location"] == "get": + pytest.assume(action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) + elif action_dict["schema"][key]["/location"] == "post": + pytest.assume(action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"]["post"]) + else: + pytest.assume(action_dict["schema"][key] == self.schema[key]) @pytest.mark.run(order=5) @pytest.mark.asyncio @@ -143,7 +109,7 @@ async def test_a_list_actions(self): # List actions. - nums_limit = 2 + nums_limit = 1 res = await a_list_actions(limit=nums_limit) pytest.assume(len(res) == nums_limit) @@ -225,15 +191,6 @@ async def test_a_update_action(self): pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) pytest.assume(res_dict["schema"] == update_schema) - # Get an action. - - get_res = await a_get_action(action_id=self.action_id) - get_res_dict = get_res.to_dict() - logger.info(get_res_dict) - pytest.assume(get_res_dict.keys() == self.action_keys) - pytest.assume(get_res_dict["schema"].keys() == self.action_schema_keys) - pytest.assume(res_dict["schema"] == update_schema) - @pytest.mark.run(order=40) @pytest.mark.asyncio async def test_a_delete_action(self): @@ -249,7 +206,7 @@ async def test_a_delete_action(self): # Delete an action. await a_delete_action(action_id=action_id) - await asyncio.sleep(sleep_time) + new_actions = await a_list_actions() action_ids = [action.action_id for action in new_actions] pytest.assume(action_id not in action_ids) diff --git a/test/testcase/test_sync/test_sync_assistant.py b/test/testcase/test_sync/test_sync_assistant.py index 9243a92..26de66d 100644 --- a/test/testcase/test_sync/test_sync_assistant.py +++ b/test/testcase/test_sync/test_sync_assistant.py @@ -4,7 +4,8 @@ from taskingai.assistant import * from taskingai.retrieval import * from taskingai.tool import * -from test.config import chat_model_id, text_model_id, sleep_time +from taskingai.assistant.memory import AssistantNaiveMemory +from test.config import chat_completion_model_id, embedding_model_id, sleep_time from test.common.read_data import data from test.common.logger import logger from test.common.utils import list_to_dict @@ -13,55 +14,38 @@ assistant_data = data.load_yaml("test_assistant_data.yml") + @pytest.mark.test_sync class TestAssistant: - assistant_list = ['assistant_id', 'created_timestamp', 'description', 'metadata', 'model_id', 'name', 'object', 'retrievals', 'system_prompt_template', 'tools'] + assistant_list = ['assistant_id', 'created_timestamp', 'description', 'metadata', 'model_id', 'name', 'object', 'retrievals', 'system_prompt_template', 'tools',"memory"] assistant_keys = set(assistant_list) @pytest.mark.parametrize("create_assistant_data", assistant_data["test_success_create_assistant"]) @pytest.mark.run(order=18) def test_create_assistant(self, collection_id, action_id, create_assistant_data): - # List assistants. - - old_res = list_assistants(limit=100) - old_nums = len(old_res) - # Create an assistant. assistant_dict = list_to_dict(create_assistant_data) - assistant_dict.update({"model_id": chat_model_id}) + assistant_dict.update({"model_id": chat_completion_model_id}) if "retrievals" in assistant_dict.keys() and len(assistant_dict["retrievals"]) > 0 and assistant_dict["retrievals"][0]["type"] == "collection": assistant_dict["retrievals"][0]["id"] = collection_id if "tools" in assistant_dict.keys() and len(assistant_dict["tools"]) > 0 and assistant_dict["tools"][0]["type"] == "action": assistant_dict["tools"][0]["id"] = action_id + assistant_dict.update({"memory": AssistantNaiveMemory()}) res = create_assistant(**assistant_dict) res_dict = res.to_dict() logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') pytest.assume(res_dict.keys() == self.assistant_keys) assume_assistant(res_dict, assistant_dict) - # Get an assistant. - - get_res = get_assistant(assistant_id=res_dict["assistant_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.assistant_keys) - assume_assistant(get_res_dict, assistant_dict) - - # List assistants. - - new_res = list_assistants(limit=100) - new_nums = len(new_res) - logger.info(f'old_nums:{old_nums}, new_nums:{new_nums}') - pytest.assume(new_nums == old_nums + 1) - @pytest.mark.run(order=19) def test_list_assistants(self): # List assistants. - nums_limit = 2 + nums_limit = 1 res = list_assistants(limit=nums_limit) pytest.assume(len(res) == nums_limit) @@ -101,14 +85,8 @@ def test_update_assistant(self, assistant_id): pytest.assume(res_dict.keys() == self.assistant_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) - # Get an assistant. - get_res = get_assistant(assistant_id=assistant_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.assistant_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - @pytest.mark.run(order=32) + @pytest.mark.run(order=33) def test_delete_assistant(self): # List assistants. @@ -140,11 +118,7 @@ class TestChat: @pytest.mark.run(order=22) def test_create_chat(self, assistant_id): - # List chats. - - old_res = list_chats(assistant_id=assistant_id) - old_nums = len(old_res) - for x in range(4): + for x in range(2): # Create a chat. @@ -152,24 +126,12 @@ def test_create_chat(self, assistant_id): res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.chat_keys) - # Get a chat. - - get_res = get_chat(assistant_id=assistant_id, chat_id=res_dict["chat_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.chat_keys) - - # List chats. - - new_res = list_chats(assistant_id=assistant_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=23) def test_list_chats(self, assistant_id): # List chats. - nums_limit = 2 + nums_limit = 1 res = list_chats(limit=nums_limit, assistant_id=assistant_id) pytest.assume(len(res) == nums_limit) @@ -208,14 +170,7 @@ def test_update_chat(self, assistant_id, chat_id): pytest.assume(res_dict.keys() == self.chat_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a chat. - - get_res = get_chat(assistant_id=assistant_id, chat_id=chat_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.chat_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - - @pytest.mark.run(order=31) + @pytest.mark.run(order=32) def test_delete_chat(self, assistant_id): # List chats. @@ -247,11 +202,7 @@ class TestMessage: @pytest.mark.run(order=26) def test_create_message(self, assistant_id, chat_id): - # List messages. - - old_res = list_messages(assistant_id=assistant_id, chat_id=chat_id) - old_nums = len(old_res) - for x in range(4): + for x in range(2): # Create a user message. @@ -263,24 +214,12 @@ def test_create_message(self, assistant_id, chat_id): pytest.assume(res_dict["content"]["text"] == text) pytest.assume(res_dict["role"] == "user") - # Get a message. - - get_res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=res_dict["message_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - - # List messages. - - new_res = list_messages(assistant_id=assistant_id, chat_id=chat_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=27) def test_list_messages(self, assistant_id, chat_id): # List messages. - nums_limit = 2 + nums_limit = 1 res = list_messages(limit=nums_limit, assistant_id=assistant_id, chat_id=chat_id) pytest.assume(len(res) == nums_limit) after_id = res[-1].message_id @@ -319,21 +258,9 @@ def test_update_message(self, assistant_id, chat_id, message_id): pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a message. - - get_res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=message_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - @pytest.mark.run(order=30) def test_generate_message(self, assistant_id, chat_id): - # List messages. - - messages = list_messages(assistant_id=assistant_id, chat_id=chat_id) - old_nums = len(messages) - # Generate an assistant message by no stream. res = generate_message(assistant_id=assistant_id, chat_id=chat_id, system_prompt_variables={}) @@ -341,18 +268,6 @@ def test_generate_message(self, assistant_id, chat_id): pytest.assume(res_dict.keys() == self.message_keys) pytest.assume(res_dict["role"] == "assistant") - # Get a message. - - get_res = get_message(assistant_id=assistant_id, chat_id=chat_id, message_id=res_dict["message_id"]) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.message_keys) - - # List messages. - - new_res = list_messages(assistant_id=assistant_id, chat_id=chat_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1) - @pytest.mark.run(order=30) def test_generate_message_by_stream(self): @@ -391,261 +306,3 @@ def test_generate_message_by_stream(self): logger.info(f"except_list: {except_list} real_list: {real_list}") pytest.assume(set(except_list) == set(real_list)) - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_generate_message_in_user_message_not_created(self, assistant_id): - - # create chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # Generate an assistant message. - - try: - res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_generate_message_in_user_message_not_created{e}') - pytest.assume("There is no user message in the chat context." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_create_message_in_generating_assistant_message(self, assistant_id): - - # create chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}, stream=True) - - # create user message - - try: - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 100 to 200 and separate numbers by comma.", - ) - except Exception as e: - logger.info(f'test_create_message_in_generating_assistant_message{user_message}') - pytest.assume("Chat is locked by another generation process. Please try again later." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_generate_message_in_generating_assistant_message(self, assistant_id): - - # create chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - - # Generate an assistant message by stream. - - try: - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - except Exception as e: - logger.info(f'est_generate_message_in_generating_assistant_message{stream_res}') - pytest.assume("Chat is locked by another generation process. Please try again later." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_generate_message_in_generated_assistant_message(self, assistant_id): - - # create chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - logger.info(f'chat_id:{chat_id}') - - # create user message - - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # Generate an assistant message by stream. - - res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - - # Generate an assistant message by stream. - - try: - stream_res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}, - stream=True) - except Exception as e: - logger.info(f'test_generate_message_in_generated_assistant_message{e}') - pytest.assume("Cannot generate another assistant message after an assistant message." in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_generate_message_in_action_deleted_assistant(self): - - # create action - - schema = { - "openapi": "3.1.0", - "info": { - "title": "Get weather data", - "description": "Retrieves current weather data for a location.", - "version": "v1.0.0" - }, - "servers": [ - { - "url": "https://weather.example.com" - } - ], - "paths": { - "/location": { - "get": { - "description": "Get temperature for a specific location", - "operationId": "GetCurrentWeather", - "parameters": [ - { - "name": "location", - "in": "query", - "description": "The city and state to retrieve the weather for", - "required": True, - "schema": { - "type": "string" - } - } - ], - "deprecated": False - }, - "post": { - "description": "UPDATE temperature for a specific location", - "operationId": "UpdateCurrentWeather", - "requestBody": { - "required": True, - "content": { - "application/json": { - "schema": { - "$ref": "#/componeents/schemas/ActionCreateRequest" - } - } - } - }, - "deprecated": False - } - } - }, - "components": { - "schemas": {} - }, - "security": [] - } - action_res = bulk_create_actions(schema=schema) - action_id = action_res[0].action_id - - # create an assistant - - assistant_res = create_assistant(name="test", description="test", model_id=chat_model_id, tools=[{"type": "action", "id": action_id}]) - assistant_id = assistant_res.assistant_id - - # create a chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - - # create user message - - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 100 and separate numbers by comma.", - ) - - # delete action - - delete_action(action_id=action_id) - time.sleep(sleep_time) - - # Generate an assistant message by stream. - - try: - res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_generate_message_in_action_deleted_assistant{e}') - pytest.assume("Some tools are not found" in str(e)) - - @pytest.mark.run(order=30) - @pytest.mark.test_abnormal - def test_generate_message_in_collection_deleted_assistant(self): - - # create collection - - collection_res = create_collection(name="test", description="test", embedding_model_id=text_model_id, capacity=1000) - collection_id = collection_res.collection_id - - # create an assistant - - assistant_res = create_assistant(name="test", description="test", model_id=chat_model_id, retrievals=[{"type": "collection", "id": collection_id}]) - assistant_id = assistant_res.assistant_id - - # create chat - - chat_res = create_chat(assistant_id=assistant_id) - chat_id = chat_res.chat_id - - # create user message - - user_message = create_message( - assistant_id=assistant_id, - chat_id=chat_id, - text="count from 1 to 1000 and separate numbers by comma.", - ) - - # delete collection - - delete_collection(collection_id) - time.sleep(sleep_time) - - # Generate an assistant message by stream. - - try: - res = generate_message(assistant_id=assistant_id, chat_id=chat_id, - system_prompt_variables={}) - except Exception as e: - logger.info(f'test_generate_message_in_collection_deleted_assistant{e}') - pytest.assume(f"Collections not found" in str(e)) - - - - diff --git a/test/testcase/test_sync/test_sync_inference.py b/test/testcase/test_sync/test_sync_inference.py index 9b09eb9..9b07021 100644 --- a/test/testcase/test_sync/test_sync_inference.py +++ b/test/testcase/test_sync/test_sync_inference.py @@ -1,7 +1,7 @@ import pytest from taskingai.inference import * -from test.config import text_model_id, chat_model_id +from test.config import embedding_model_id, chat_completion_model_id from test.common.logger import logger @@ -14,7 +14,7 @@ def test_chat_completion(self): # normal chat completion. normal_res = chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -23,12 +23,12 @@ def test_chat_completion(self): pytest.assume(normal_res.finish_reason == "stop") pytest.assume(normal_res.message.content) pytest.assume(normal_res.message.role == "assistant") - pytest.assume(normal_res.message.function_call is None) + pytest.assume(normal_res.message.function_calls is None) # multi round chat completion. multi_round_res = chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -43,12 +43,12 @@ def test_chat_completion(self): pytest.assume(multi_round_res.finish_reason == "stop") pytest.assume(multi_round_res.message.content) pytest.assume(multi_round_res.message.role == "assistant") - pytest.assume(multi_round_res.message.function_call is None) + pytest.assume(multi_round_res.message.function_calls is None) # config max tokens chat completion. max_tokens_res = chat_completion( - model_id=chat_model_id, + model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("Hi"), @@ -65,11 +65,11 @@ def test_chat_completion(self): pytest.assume(max_tokens_res.finish_reason == "length") pytest.assume(max_tokens_res.message.content) pytest.assume(max_tokens_res.message.role == "assistant") - pytest.assume(max_tokens_res.message.function_call is None) + pytest.assume(max_tokens_res.message.function_calls is None) # chat completion with stream. - stream_res = chat_completion(model_id=chat_model_id, + stream_res = chat_completion(model_id=chat_completion_model_id, messages=[ SystemMessage("You are a professional assistant."), UserMessage("count from 1 to 50 and separate numbers by comma."), @@ -99,7 +99,7 @@ def test_text_embedding(self): # Text embedding with str. input_str = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - str_res = text_embedding(model_id=text_model_id, input=input_str) + str_res = text_embedding(model_id=embedding_model_id, input=input_str) pytest.assume(len(str_res) > 0) for score in str_res: pytest.assume(float(-1) <= score <= float(1)) @@ -108,7 +108,7 @@ def test_text_embedding(self): input_list = ["hello", "world"] input_list_length = len(input_list) - list_res = text_embedding(model_id=text_model_id, input=input_list) + list_res = text_embedding(model_id=embedding_model_id, input=input_list) pytest.assume(len(list_res) == input_list_length) for str_res in list_res: pytest.assume(len(str_res) > 0) diff --git a/test/testcase/test_sync/test_sync_retrieval.py b/test/testcase/test_sync/test_sync_retrieval.py index 759fcf8..cc6035a 100644 --- a/test/testcase/test_sync/test_sync_retrieval.py +++ b/test/testcase/test_sync/test_sync_retrieval.py @@ -2,7 +2,7 @@ import pytest from taskingai.retrieval import list_collections, create_collection, get_collection, update_collection, delete_collection, list_records, create_text_record, get_record, update_record, delete_record, query_chunks -from test.config import text_model_id, sleep_time +from test.config import embedding_model_id, sleep_time from test.common.logger import logger @@ -18,54 +18,28 @@ class TestCollection: @pytest.mark.run(order=9) def test_create_collection(self): - # List collections. - - old_res = list_collections(limit=100) - old_nums = len(old_res) - # Create a collection. - name = "test" description = "just for test" - for x in range(4): - res = create_collection(name=name, description=description, embedding_model_id=text_model_id, capacity=1000) + for x in range(2): + res = create_collection(name=name, description=description, embedding_model_id=embedding_model_id, capacity=1000) res_dict = res.to_dict() logger.info(res_dict) pytest.assume(res_dict.keys() == self.collection_keys) pytest.assume(res_dict["configs"].keys() == self.collection_configs_keys) pytest.assume(res_dict["name"] == name) pytest.assume(res_dict["description"] == description) - pytest.assume(res_dict["embedding_model_id"] == text_model_id) + pytest.assume(res_dict["embedding_model_id"] == embedding_model_id) pytest.assume(res_dict["capacity"] == 1000) pytest.assume(res_dict["status"] == "creating") - # Get a collection. - - time.sleep(sleep_time) - collection_id = res_dict["collection_id"] - get_res = get_collection(collection_id=collection_id) - get_res_dict = get_res.to_dict() - logger.info(get_res_dict) - pytest.assume(get_res_dict.keys() == self.collection_keys) - pytest.assume(get_res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - pytest.assume(get_res_dict["embedding_model_id"] == text_model_id) - pytest.assume(get_res_dict["capacity"] == 1000) - pytest.assume(get_res_dict["status"] == "ready") - - # List collections. - - new_res = list_collections(limit=100) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) @pytest.mark.run(order=10) def test_list_collections(self): # List collections. - nums_limit = 2 + nums_limit = 1 res = list_collections(limit=nums_limit) pytest.assume(len(res) == nums_limit) after_id = res[-1].collection_id @@ -107,18 +81,7 @@ def test_update_collection(self, collection_id): pytest.assume(res_dict["description"] == description) pytest.assume(res_dict["status"] == "ready") - # Get a collection. - - time.sleep(sleep_time) - get_res = get_collection(collection_id=collection_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.collection_keys) - pytest.assume(get_res_dict["configs"].keys() == self.collection_configs_keys) - pytest.assume(get_res_dict["name"] == name) - pytest.assume(get_res_dict["description"] == description) - pytest.assume(get_res_dict["status"] == "ready") - - @pytest.mark.run(order=34) + @pytest.mark.run(order=35) def test_delete_collection(self): # List collections. @@ -132,7 +95,7 @@ def test_delete_collection(self): # Delete a collection. delete_collection(collection_id=collection_id) - time.sleep(sleep_time) + new_collections = list_collections(order="desc", limit=100, after=None, before=None) # List collections. @@ -155,15 +118,10 @@ class TestRecord: @pytest.mark.run(order=13) def test_create_text_record(self, collection_id): - # List records. - - old_res = list_records(collection_id=collection_id) - old_nums = len(old_res) - # Create a text record. text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - for x in range(4): + for x in range(2): res = create_text_record(collection_id=collection_id, text=text) res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) @@ -171,30 +129,12 @@ def test_create_text_record(self, collection_id): pytest.assume(res_dict["content"]["text"] == text) pytest.assume(res_dict["status"] == "creating") - # Get a record. - - time.sleep(sleep_time*25) - record_id = res_dict["record_id"] - get_res = get_record(collection_id=collection_id, record_id=record_id) - logger.info(f'get record response: {get_res}') - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.record_keys) - pytest.assume(get_res_dict["content"].keys() == self.record_content_keys) - pytest.assume(get_res_dict["content"]["text"] == text) - pytest.assume(get_res_dict["status"] == "ready") - - # List records. - - new_res = list_records(collection_id=collection_id) - new_nums = len(new_res) - pytest.assume(new_nums == old_nums + 1 + x) - @pytest.mark.run(order=14) def test_list_records(self, collection_id): # List records. - nums_limit = 2 + nums_limit = 1 res = list_records(limit=nums_limit, collection_id=collection_id) pytest.assume(len(res) == nums_limit) @@ -218,7 +158,6 @@ def test_get_record(self, collection_id): # list records - time.sleep(sleep_time*25) records = list_records(collection_id=collection_id) for record in records: record_id = record.record_id @@ -227,7 +166,7 @@ def test_get_record(self, collection_id): res_dict = res.to_dict() pytest.assume(res_dict.keys() == self.record_keys) pytest.assume(res_dict["content"].keys() == self.record_content_keys) - pytest.assume(res_dict["status"] == "ready") + pytest.assume(res_dict["status"] == "creating" or "ready") @pytest.mark.run(order=16) def test_update_record(self, collection_id, record_id): @@ -241,78 +180,8 @@ def test_update_record(self, collection_id, record_id): pytest.assume(res_dict["content"].keys() == self.record_content_keys) pytest.assume(res_dict["metadata"] == metadata) - # Get a record. - - time.sleep(sleep_time*25) - get_res = get_record(collection_id=collection_id, record_id=record_id) - logger.info(f'get record response: {get_res}') - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.record_keys) - pytest.assume(get_res_dict["content"].keys() == self.record_content_keys) - pytest.assume(get_res_dict["metadata"] == metadata) - pytest.assume(get_res_dict["status"] == "ready") - - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_create_record_in_nonexistent_collection(self): - - # Create collection. - - collection_id = "nonexistent_collection_id" - - # Create a record. - - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection not found: {collection_id}" in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_create_record_in_creating_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - res = create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - - # Create a record. - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection {collection_id} is not ready." in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_create_record_in_deleting_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - res = create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - delete_collection(collection_id=collection_id) - - # Create a record. - - text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data." - try: - res = create_text_record(collection_id=collection_id, text=text) - except Exception as e: - logger.info(f'test_create_record_in_creating_collection:{e}') - pytest.assume(f"Collection not found: {collection_id}" in str(e)) - - @pytest.mark.run(order=33) + @pytest.mark.run(order=34) def test_delete_record(self, collection_id): # List records. @@ -358,64 +227,3 @@ def test_query_chunks(self, collection_id): pytest.assume(query_text in chunk_dict["text"]) pytest.assume(chunk_dict["score"] >= 0) - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_query_chunks_in_creating_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - res = create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = res.collection_id - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = query_chunks(collection_id=collection_id, query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_query_chunks_in_creating_collection:{e}') - pytest.assume(f"Collection {collection_id} is not ready." in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_query_chunks_in_deleting_collection(self): - - # Create collection. - - name = "test" - description = "just for test" - collection_res = create_collection(name=name, description=description, embedding_model_id=text_model_id, - capacity=1000) - collection_id = collection_res.collection_id - - # delete collection - - delete_collection(collection_id=collection_id) - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = query_chunks(collection_id=collection_id, query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_query_chunks_in_deleting_collection:{e}') - pytest.assume("Collections not found" in str(e)) - - @pytest.mark.run(order=17) - @pytest.mark.test_abnormal - def test_query_chunks_in_nonexistent_collection(self): - - # Query chunks - - query_text = "Machine learning" - top_k = 1 - try: - res = query_chunks(collection_id="nonexistent_collection_id", query_text=query_text, top_k=top_k) - except Exception as e: - logger.info(f'test_query_chunks_in_nonexistent_collection:{e}') - pytest.assume('Collections not found' in str(e)) diff --git a/test/testcase/test_sync/test_sync_tool.py b/test/testcase/test_sync/test_sync_tool.py index eae9833..dce1db4 100644 --- a/test/testcase/test_sync/test_sync_tool.py +++ b/test/testcase/test_sync/test_sync_tool.py @@ -67,59 +67,26 @@ class TestAction: @pytest.mark.run(order=4) def test_bulk_create_actions(self): - # List actions. - - old_res = list_actions(limit=100) - old_nums = len(old_res) - for x in range(2): - - # Create an action. - - res = bulk_create_actions(schema=self.schema) - for action in res: - action_dict = action.to_dict() - logger.info(action_dict) - pytest.assume(action_dict.keys() == self.action_keys) - pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) - - for key in action_dict["schema"].keys(): - if key == "paths": - if action_dict["schema"][key]["/location"] == "get": - pytest.assume( - action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": - pytest.assume( - action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"][ - "post"]) - else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) - - # Get an action. - - action_id = action_dict["action_id"] - get_res = get_action(action_id=action_id) - get_res_dict = get_res.to_dict() - pytest.assume(get_res_dict.keys() == self.action_keys) - pytest.assume(get_res_dict["schema"].keys() == self.action_schema_keys) - - for key in action_dict["schema"].keys(): - if key == "paths": - if action_dict["schema"][key]["/location"] == "get": - pytest.assume( - action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) - elif action_dict["schema"][key]["/location"] == "post": - pytest.assume( - action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"][ - "post"]) - else: - pytest.assume(action_dict["schema"][key] == self.schema[key]) - - # List actions. - - new_res = list_actions(limit=100) - new_nums = len(new_res) - res_num = len(res) - pytest.assume(new_nums == old_nums + res_num + 2*x) + # Create an action. + + res = bulk_create_actions(schema=self.schema) + for action in res: + action_dict = action.to_dict() + logger.info(action_dict) + pytest.assume(action_dict.keys() == self.action_keys) + pytest.assume(action_dict["schema"].keys() == self.action_schema_keys) + + for key in action_dict["schema"].keys(): + if key == "paths": + if action_dict["schema"][key]["/location"] == "get": + pytest.assume( + action_dict["schema"][key]["/location"]["get"] == self.schema["paths"]["/location"]["get"]) + elif action_dict["schema"][key]["/location"] == "post": + pytest.assume( + action_dict["schema"][key]["/location"]["post"] == self.schema["paths"]["/location"][ + "post"]) + else: + pytest.assume(action_dict["schema"][key] == self.schema[key]) @pytest.mark.run(order=5) def test_run_action(self, action_id): @@ -139,7 +106,7 @@ def test_list_actions(self): # List actions. - nums_limit = 2 + nums_limit = 1 res = list_actions(limit=nums_limit) logger.info(res) pytest.assume(len(res) == nums_limit) @@ -223,16 +190,6 @@ def test_update_action(self, action_id): pytest.assume(res_dict["schema"].keys() == self.action_schema_keys) pytest.assume(res_dict["schema"] == update_schema) - # Get an action. - - get_res = get_action(action_id=action_id) - get_res_dict = get_res.to_dict() - logger.info(get_res_dict) - pytest.assume(get_res_dict.keys() == self.action_keys) - - pytest.assume(get_res_dict["schema"].keys() == self.action_schema_keys) - pytest.assume(get_res_dict["schema"] == update_schema) - @pytest.mark.run(order=40) def test_delete_action(self):