81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
import demo
|
|
import os
|
|
|
|
class TestDemo(unittest.TestCase):
    """Unit tests for the prompt helpers, LLM wrapper and UI object in ``demo``."""

    # Product inputs shared by the prompt-generation and query tests.
    _PRODUCT_KWARGS = {
        'product_name': "Test Product",
        'product_desc': "Test Description",
        'product_features': "Test Features",
        'writing_style': "专业正式",
        'template': "标准模板",
    }

    def setUp(self):
        """Patch ``os.environ`` with a fake ``VOLCANO_DS__API_KEY`` for each test."""
        # NOTE(review): ``demo`` is imported at module load, before this patch
        # is active. If ``demo`` reads the key only at import time, this patch
        # affects nothing there — confirm against ``demo``.
        self.patcher = patch.dict(os.environ, {'VOLCANO_DS__API_KEY': 'test_key'})
        self.patcher.start()
        # Cleanups run even when tearDown would not (e.g. a subclass setUp
        # fails after calling super().setUp()), so the env is always restored.
        self.addCleanup(self.patcher.stop)

    def test_get_prompt_template(self):
        """Known template names return their text; unknown names fall back."""
        standard = demo.get_prompt_template("标准模板")
        self.assertIn("一、产品信息:", standard)
        self.assertIn("二、文案结构:", standard)

        story = demo.get_prompt_template("故事模板")
        self.assertIn("讲述用户故事", story)

        # An unrecognised name should fall back to the default template,
        # which is expected to contain the product-information header.
        fallback = demo.get_prompt_template("invalid")
        self.assertIn("一、产品信息:", fallback)

    def test_generate_prompt(self):
        """The generated prompt embeds the product name, style and length text."""
        prompt = demo.generate_prompt(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertIn("专业正式", prompt)
        self.assertIn("400-600字", prompt)

    @patch('demo.llm')
    def test_query_deepseek(self, mock_llm):
        """query_deepseek returns (prompt, response) on success and on LLM error."""
        mock_llm.invoke.return_value = MagicMock(content="Mocked response")

        # Success path: the response is the ``content`` of the LLM reply.
        prompt, response = demo.query_deepseek(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertEqual(response, "Mocked response")

        # Error path: an exception from the LLM is expected to be reported in
        # the returned response string instead of being raised.
        mock_llm.invoke.side_effect = Exception("Test error")
        prompt, response = demo.query_deepseek(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertIn("Error: Test error", response)

    def test_interface_creation(self):
        """``demo.interface`` is a Gradio ``Interface`` instance."""
        # ``demo.interface`` is read as a plain module attribute, so it
        # presumably exists as soon as ``demo`` is imported. Patching
        # ``gr.Interface`` afterwards (as an earlier version did) can never
        # observe that constructor call. Check the built object's type instead.
        self.assertIsInstance(demo.interface, demo.gr.Interface)
if __name__ == '__main__':
    # Discover and run every TestCase in this file when it is executed directly.
    unittest.main()