abo/test_demo.py

81 lines
2.7 KiB
Python
Raw Normal View History

2025-07-16 15:34:54 +08:00
import importlib
import os
import unittest
from unittest.mock import patch, MagicMock

import demo
class TestDemo(unittest.TestCase):
    """Tests for the ``demo`` module.

    Covers prompt-template lookup, prompt generation, the DeepSeek query
    wrapper (with the LLM client replaced by a mock) and construction of the
    Gradio interface (with ``gr.Interface`` replaced by a mock).
    """

    # Keyword arguments shared by every prompt-building call below.
    _PRODUCT_KWARGS = {
        "product_name": "Test Product",
        "product_desc": "Test Description",
        "product_features": "Test Features",
        "writing_style": "专业正式",
        "template": "标准模板",
    }

    def setUp(self):
        """Expose a dummy API key in the environment for each test."""
        # NOTE(review): ``demo`` is imported at the top of this file, before
        # this patch is active, so anything ``demo`` read from the environment
        # at import time is not affected -- confirm it reads the key lazily.
        # Also verify the double underscore in the variable name matches the
        # name ``demo`` actually reads.
        self.patcher = patch.dict(os.environ, {'VOLCANO_DS__API_KEY': 'test_key'})
        self.patcher.start()

    def tearDown(self):
        """Undo the environment patch made in setUp."""
        self.patcher.stop()

    def test_get_prompt_template(self):
        """Known template names map to their template; unknown names fall back."""
        # Standard template
        template = demo.get_prompt_template("标准模板")
        self.assertIn("一、产品信息:", template)
        self.assertIn("二、文案结构:", template)
        # Story template
        template = demo.get_prompt_template("故事模板")
        self.assertIn("讲述用户故事", template)
        # An unknown name should return the default (standard) template
        template = demo.get_prompt_template("invalid")
        self.assertIn("一、产品信息:", template)

    def test_generate_prompt(self):
        """The generated prompt embeds the product name, style and length target."""
        prompt = demo.generate_prompt(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertIn("专业正式", prompt)
        self.assertIn("400-600字", prompt)

    @patch('demo.llm')
    def test_query_deepseek(self, mock_llm):
        """query_deepseek returns (prompt, text) and reports LLM errors as text."""
        # Successful call: the mocked client's response content is returned.
        mock_llm.invoke.return_value = MagicMock(content="Mocked response")
        prompt, response = demo.query_deepseek(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertEqual(response, "Mocked response")

        # Failing call: the exception message is expected in the returned
        # response rather than propagating out of query_deepseek.
        mock_llm.invoke.side_effect = Exception("Test error")
        prompt, response = demo.query_deepseek(**self._PRODUCT_KWARGS)
        self.assertIn("Test Product", prompt)
        self.assertIn("Error: Test error", response)

    def test_interface_creation(self):
        """Executing the ``demo`` module constructs exactly one Gradio Interface."""
        # ``demo.interface`` was built when ``demo`` was first imported at the
        # top of this file, before any patch existed. Merely reading the
        # attribute never calls ``gr.Interface`` again, so the module must be
        # re-executed while ``Interface`` is patched for the mock to see a call.
        try:
            with patch('demo.gr.Interface') as mock_interface:
                importlib.reload(demo)
            mock_interface.assert_called_once()
        finally:
            # Re-execute once more with the real Gradio class so later tests
            # do not see a module whose ``interface`` came from the mock.
            importlib.reload(demo)
if __name__ == "__main__":
    # Discover and run every TestCase in this file when it is executed as a script.
    unittest.main(module="__main__")