"""Unit tests for the ``demo`` module: prompt templates, the DeepSeek query wrapper, and the Gradio UI."""

import importlib
import os
import unittest
from unittest.mock import patch, MagicMock

import demo


class TestDemo(unittest.TestCase):
    """Tests for prompt construction, LLM querying, and interface setup in ``demo``."""

    def setUp(self):
        # Provide a fake API key in the environment for the duration of each test.
        self.patcher = patch.dict(os.environ, {'VOLCANO_DS__API_KEY': 'test_key'})
        self.patcher.start()

    def tearDown(self):
        # Restore the original environment.
        self.patcher.stop()

    def test_get_prompt_template(self):
        """Known template names return their template; unknown names fall back to the standard one."""
        # Standard template
        template = demo.get_prompt_template("标准模板")
        self.assertIn("一、产品信息:", template)
        self.assertIn("二、文案结构:", template)

        # Story template
        template = demo.get_prompt_template("故事模板")
        self.assertIn("讲述用户故事", template)

        # An unrecognised name should fall back to the default (standard) template.
        template = demo.get_prompt_template("invalid")
        self.assertIn("一、产品信息:", template)

    def test_generate_prompt(self):
        """The generated prompt embeds the product name, writing style and length guidance."""
        prompt = demo.generate_prompt(
            product_name="Test Product",
            product_desc="Test Description",
            product_features="Test Features",
            writing_style="专业正式",
            template="标准模板",
        )
        self.assertIn("Test Product", prompt)
        self.assertIn("专业正式", prompt)
        self.assertIn("400-600字", prompt)

    @patch('demo.llm')
    def test_query_deepseek(self, mock_llm):
        """query_deepseek returns (prompt, response) and reports LLM errors in the response text."""
        # Successful call: the mocked LLM's .content becomes the response.
        mock_response = MagicMock()
        mock_response.content = "Mocked response"
        mock_llm.invoke.return_value = mock_response

        prompt, response = demo.query_deepseek(
            product_name="Test Product",
            product_desc="Test Description",
            product_features="Test Features",
            writing_style="专业正式",
            template="标准模板",
        )
        self.assertIn("Test Product", prompt)
        self.assertEqual(response, "Mocked response")

        # Failing call: the exception message should be surfaced in the response.
        mock_llm.invoke.side_effect = Exception("Test error")
        prompt, response = demo.query_deepseek(
            product_name="Test Product",
            product_desc="Test Description",
            product_features="Test Features",
            writing_style="专业正式",
            template="标准模板",
        )
        self.assertIn("Test Product", prompt)
        self.assertIn("Error: Test error", response)

    def test_interface_creation(self):
        """Importing ``demo`` constructs exactly one Gradio Interface."""
        # ``demo.interface`` is a module-level object that is built when ``demo``
        # is first imported, i.e. before any patch in this test exists. Patching
        # gr.Interface and merely reading the attribute therefore never calls
        # the mock. Re-execute the module body while the patch is active so the
        # constructor call is intercepted.
        with patch('demo.gr.Interface') as mock_interface:
            importlib.reload(demo)
        try:
            mock_interface.assert_called_once()
        finally:
            # Rebuild ``demo`` with the real gradio now that the patch is gone,
            # so later tests do not see MagicMock-built module globals. This runs
            # inside the test, while the setUp environment patch is still active.
            importlib.reload(demo)


if __name__ == '__main__':
    unittest.main()