2025-07-16 15:34:54 +08:00

314 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from flask import Flask, request, render_template, send_file, jsonify, redirect, url_for, session
from docx import Document
from docx.shared import Inches
import openai
import io
import json
from config import config
import time
from functools import wraps
try:
from volcengine.maas import MaasService, MaasException
except ImportError:
MaasService = None
MaasException = None
try:
from dashscope import Generation
except ImportError:
Generation = None
# Flask application object. SECRET_KEY is required by Flask's `session`,
# which the login views below rely on; the value comes from config.
app = Flask(__name__)
app.config['SECRET_KEY'] = config.FLASK_CONFIG['secret_key']
def login_required(f):
    """Decorate a view so it only runs for a logged-in session.

    When ``session['logged_in']`` is missing or falsy the request is
    redirected to the ``login`` endpoint instead of calling the view.
    The wrapper keeps the view's metadata via ``functools.wraps``.
    """
    @wraps(f)
    def guarded_view(*args, **kwargs):
        if session.get('logged_in'):
            return f(*args, **kwargs)
        return redirect(url_for('login'))
    return guarded_view
def generate_with_openai(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate content with the OpenAI Chat Completions API.

    Args:
        prompt: the user message, sent as a single chat turn.
        api_key: OpenAI API key used to build the client.
        model: chat model name.
        max_tokens: maximum number of tokens to generate.
        temperature: sampling temperature.

    Returns:
        The message content of the first returned choice.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.OpenAI(api_key=api_key).chat.completions.create(
        model=model,
        messages=chat_messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def generate_with_volcengine(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate content with Volcengine Doubao through the MaaS SDK.

    Args:
        prompt: the user message, sent as a single chat turn.
        api_key: credential passed to ``MaasService``.
        model: model identifier passed to ``chat``.
        max_tokens: maximum number of tokens to generate.
        temperature: accepted for signature parity with the other
            ``generate_with_*`` helpers but NOT forwarded to the SDK call.
            NOTE(review): confirm whether ``MaasService.chat`` accepts it.

    Returns:
        ``response['choices'][0]['message']['content']``.

    Raises:
        ImportError: if the volcengine SDK could not be imported.
    """
    sdk_missing = MaasService is None
    if sdk_missing:
        raise ImportError("火山引擎SDK未安装请运行: pip install volcengine")
    endpoint = 'maas-api.ml-platform-cn-beijing.volces.com'
    client = MaasService(endpoint, api_key)
    reply = client.chat([{"role": "user", "content": prompt}], model=model, max_tokens=max_tokens)
    first = reply['choices'][0]
    return first['message']['content']
def generate_with_dashscope(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate content with Alibaba Cloud Tongyi Qianwen via DashScope.

    This helper never raises. A missing SDK, an API error, an unexpected
    response shape, or any exception during the call is reported as a
    (Chinese) message string returned in place of the content.

    Args:
        prompt: the user message, sent as a single chat turn.
        api_key: DashScope API key.
        model: DashScope model name.
        max_tokens: maximum number of tokens to generate.
        temperature: sampling temperature, forwarded to the API.

    Returns:
        The generated text, or an error description string.
    """
    import logging

    logger = logging.getLogger(__name__)
    if Generation is None:
        return "阿里云DashScope SDK未安装请运行: pip install dashscope"
    try:
        response = Generation.call(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            api_key=api_key,
        )
        # Raw responses contain the generated text; keep them at debug level
        # instead of printing every response to stdout.
        logger.debug("阿里云原始返回: %s", response)
        if not response:
            return "阿里云API无响应请检查网络或API Key。"
        status_code = getattr(response, 'status_code', None)
        code = getattr(response, 'code', None)
        msg = getattr(response, 'message', None)
        # A non-200 HTTP status or a non-empty error code means the call failed.
        if (status_code is not None and status_code != 200) or (code and code != 200):
            return f"阿里云 API 错误: {msg or str(response)} (code={code or status_code})"
        # Support both output.choices and output.text response shapes.
        if not hasattr(response, 'output'):
            return f"阿里云API返回异常无output字段原始返回{str(response)}"
        output = response.output
        if hasattr(output, 'choices') and output.choices:
            return output.choices[0].message.content
        if hasattr(output, 'text') and output.text:
            return output.text
        return f"阿里云API返回异常output字段无choices或text原始返回{str(response)}"
    except Exception as e:
        # Best-effort helper: log with traceback, then report the failure as text.
        logger.exception("阿里云API调用异常: %s", e)
        return f"阿里云API调用异常: {str(e)}"
def _build_news_prompt(description, style, topics=None, image=None):
    """Assemble the news-writing prompt shared by every model provider.

    Args:
        description: free-text description of the news event.
        style: writing-style label inserted into the prompt.
        topics: optional list of topic strings; falsy entries are dropped.
        image: optional uploaded file object; only its ``filename`` is checked.

    Returns:
        The prompt string sent to the model.
    """
    topics = [t for t in (topics or []) if t]
    topic_str = ''
    if topics:
        # Separate topics explicitly so the model can tell them apart.
        topic_str = (
            "主题报道部分请围绕下列每个话题分别成段深入分析,每个话题都要有具体内容:"
            f"{'、'.join(topics)}\n"
        )
    # Only emit the image bullet when a file was actually uploaded, so the
    # prompt never contains an empty "- " list item.
    has_image = image is not None and bool(getattr(image, 'filename', None))
    image_line = "- 如有上传的新闻图片,请结合图片内容进行适当描述。\n" if has_image else ''
    return (
        f"你是一名专业新闻记者,请以{style}风格,围绕下述新闻事件,撰写一篇结构完整、语言流畅、符合新闻报道规范的新闻稿。\n"
        "结构要求:\n"
        "1. 导语:简要概括新闻事件。\n"
        "2. 品牌介绍:如涉及品牌,请介绍品牌的背景、定位、优势或发展愿景。\n"
        f"3. 主题报道:{topic_str}每个话题单独成段。\n"
        "4. 结语:总结新闻意义或展望未来。\n"
        "其它要求:\n"
        "- 内容真实、客观,突出事件的背景、经过、影响。\n"
        "- 字数不少于500字。\n"
        f"{image_line}"
        f"\n新闻事件描述:\n{description}\n"
    )


def _call_openai(prompt, api_key, model, max_tokens, temperature):
    """Call OpenAI chat completions; return (content, api_info)."""
    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    api_info = {}
    usage = getattr(response, 'usage', None)
    if usage:
        api_info['提示tokens'] = usage.prompt_tokens
        api_info['生成tokens'] = usage.completion_tokens
        api_info['总tokens'] = usage.total_tokens
        # Rough estimate at gpt-4-1106-preview prices ($0.01/1K prompt,
        # $0.03/1K completion tokens); not adjusted per model.
        price_prompt = 0.01 / 1000
        price_completion = 0.03 / 1000
        fee = usage.prompt_tokens * price_prompt + usage.completion_tokens * price_completion
        api_info['费用'] = f"${fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _call_volcengine(prompt, api_key, model, max_tokens, temperature):
    """Call Volcengine Doubao via MaasService; return (content, api_info).

    NOTE(review): ``temperature`` is accepted for a uniform signature but is
    not forwarded to ``maas.chat`` -- confirm whether the SDK supports it.
    """
    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
    response = maas.chat([{"role": "user", "content": prompt}], model=model, max_tokens=max_tokens)
    content = response['choices'][0]['message']['content']
    api_info = {}
    if 'usage' in response:
        usage = response['usage']
        api_info['提示tokens'] = usage.get('prompt_tokens')
        api_info['生成tokens'] = usage.get('completion_tokens')
        api_info['总tokens'] = usage.get('total_tokens')
        # Rough estimate assuming ¥0.005 per 1K tokens; check official pricing.
        price = 0.005 / 1000
        fee = (usage.get('total_tokens') or 0) * price
        api_info['费用'] = f"¥{fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _call_dashscope(prompt, api_key, model, max_tokens, temperature):
    """Call Alibaba DashScope Generation; return (content, api_info)."""
    response = Generation.call(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        api_key=api_key,
    )
    # Support both output.choices and output.text response shapes.
    if hasattr(response, 'output'):
        output = response.output
        if hasattr(output, 'choices') and output.choices:
            content = output.choices[0].message.content
        elif hasattr(output, 'text') and output.text:
            content = output.text
        else:
            content = f"阿里云API返回异常output字段无choices或text原始返回{str(response)}"
    else:
        content = f"阿里云API返回异常无output字段原始返回{str(response)}"
    api_info = {}
    if hasattr(response, 'usage'):
        usage = response.usage
        # Token field names differ between response versions; accept either.
        api_info['提示tokens'] = getattr(usage, 'input_tokens', None) or getattr(usage, 'prompt_tokens', None)
        api_info['生成tokens'] = getattr(usage, 'output_tokens', None) or getattr(usage, 'completion_tokens', None)
        api_info['总tokens'] = getattr(usage, 'total_tokens', None)
        # Rough estimate assuming ¥0.012 per 1K tokens; check official pricing.
        price = 0.012 / 1000
        fee = (getattr(usage, 'total_tokens', None) or 0) * price
        api_info['费用'] = f"¥{fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def generate_news_content(description, style, model_type, topics=None, image=None):
    """Generate a news article with the selected model provider.

    Args:
        description: free-text description of the news event.
        style: writing-style label for the article.
        model_type: provider key: 'openai', 'volcengine' or 'dashscope'.
        topics: optional list of up to three topic strings; blanks are ignored.
        image: optional uploaded image; only its presence affects the prompt.

    Returns:
        ``(content, api_info)``. ``content`` is the article text or an error
        message. When a provider call was attempted, ``api_info`` holds token
        counts, estimated cost, the model name and the response time. It is
        ``{}`` when no call was made (unknown model type or missing SDK).

    Raises:
        KeyError: if the model config lacks api_key/model/max_tokens/temperature.
    """
    prompt = _build_news_prompt(description, style, topics, image)
    model_config = config.get_model_config(model_type)
    if not model_config:
        return "不支持的模型类型", {}
    api_key = model_config['api_key']
    model = model_config['model']
    max_tokens = model_config['max_tokens']
    temperature = model_config['temperature']

    if model_type == 'openai':
        call_provider = _call_openai
    elif model_type == 'volcengine':
        if MaasService is None:
            return "火山引擎SDK未安装请运行: pip install volcengine", {}
        call_provider = _call_volcengine
    elif model_type == 'dashscope':
        if Generation is None:
            return "阿里云DashScope SDK未安装请运行: pip install dashscope", {}
        call_provider = _call_dashscope
    else:
        return f"不支持的模型类型: {model_type}", {}

    api_info = {}
    start_time = time.perf_counter()  # monotonic clock for elapsed time
    try:
        content, api_info = call_provider(prompt, api_key, model, max_tokens, temperature)
    except Exception as e:
        # Provider/network failures are reported to the user, not raised.
        content = f"生成内容时出错: {str(e)}"
    api_info['响应时间'] = f"{(time.perf_counter() - start_time):.2f}s"
    return content, api_info
def create_word_doc(news_content, image_stream, image_filename):
    """Build an in-memory Word (.docx) document for download.

    The document contains:
    - the configured title as a level-0 heading;
    - the uploaded image, when ``image_stream`` is truthy, sized to the
      configured width in inches;
    - ``news_content`` as a single paragraph.

    Args:
        news_content: article text placed in the body.
        image_stream: readable stream of the uploaded image, or None.
        image_filename: currently unused; kept for interface compatibility.

    Returns:
        An ``io.BytesIO`` rewound to offset 0, ready to pass to ``send_file``.
    """
    doc_settings = config.DOCUMENT_CONFIG
    word_doc = Document()
    word_doc.add_heading(doc_settings['title'], 0)
    if image_stream:
        picture_width = Inches(doc_settings['image_width'])
        word_doc.add_picture(image_stream, width=picture_width)
    word_doc.add_paragraph(news_content)
    buffer = io.BytesIO()
    word_doc.save(buffer)
    buffer.seek(0)
    return buffer
def get_valid_models(models):
    """Keep only model entries complete enough to be shown in the UI.

    An entry survives when both its 'name' and 'model' values are truthy.
    The input's iteration order is preserved, and values are the same
    objects as in the input (not copies).

    Args:
        models: mapping of model key -> settings mapping.

    Returns:
        A new dict with the surviving entries.
    """
    valid = {}
    for key, settings in models.items():
        if not settings.get('name') or not settings.get('model'):
            continue
        valid[key] = settings
    return valid
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Render the login form (GET) or check submitted credentials (POST).

    The expected credentials default to the built-in admin/123456 pair. They
    can be overridden with the NEWS_APP_USERNAME / NEWS_APP_PASSWORD
    environment variables, so the real password need not live in source.
    Credentials are compared with ``hmac.compare_digest``, which avoids timing
    side channels. On success the session is marked logged in and the user is
    redirected to ``index``. On failure the form is re-rendered with an error.
    """
    import hmac

    if request.method != 'POST':
        return render_template('login.html')

    username = request.form.get('username') or ''
    password = request.form.get('password') or ''
    # TODO: replace with a real user store (database) and hashed passwords.
    expected_username = os.environ.get('NEWS_APP_USERNAME', 'admin')
    expected_password = os.environ.get('NEWS_APP_PASSWORD', '123456')
    # Compare UTF-8 bytes: compare_digest rejects non-ASCII str arguments.
    # Both checks run unconditionally so timing does not reveal which field failed.
    username_ok = hmac.compare_digest(username.encode('utf-8'), expected_username.encode('utf-8'))
    password_ok = hmac.compare_digest(password.encode('utf-8'), expected_password.encode('utf-8'))
    if username_ok and password_ok:
        session['logged_in'] = True
        return redirect(url_for('index'))
    return render_template('login.html', error='用户名或密码错误')
@app.route('/logout')
def logout():
    """Remove the login flag from the session and go back to the login page."""
    session.pop('logged_in', None)
    login_url = url_for('login')
    return redirect(login_url)
@app.route('/', methods=['GET', 'POST'])
@login_required
def index():
    """Main page: show the form, generate an article, or download it as Word.

    - POST with ``action=generate`` calls generate_news_content and re-renders
      the page with the article and the API statistics.
    - POST with ``action=download`` returns the submitted article as a .docx.
    - Anything else (GET, or an unrecognised action) renders the empty form.
    """
    def render_page(**page_vars):
        # Template values shared by every render of index.html.
        return render_template(
            'index.html',
            styles=config.STYLES,
            models=get_valid_models(config.MODELS),
            api_key_status=config.validate_api_keys(),
            **page_vars
        )

    selected_model = request.form.get('model', 'dashscope')
    action = request.form.get('action') if request.method == 'POST' else None

    if action == 'generate':
        form = request.form
        description = form['description']
        style = form['style']
        model_type = form['model']
        image = request.files.get('image')
        topics = [form.get(f'topic{n}', '').strip() for n in (1, 2, 3)]
        news_content, api_info = generate_news_content(description, style, model_type, topics, image)
        # Re-render so the generated article appears in the page's right-hand panel.
        return render_page(news_content=news_content, selected_model=model_type, api_info=api_info)

    if action == 'download':
        news_content = request.form.get('news_content', '')
        image = request.files.get('image')
        image_stream, image_name = (image.stream, image.filename) if image else (None, None)
        word_file = create_word_doc(news_content, image_stream, image_name)
        return send_file(word_file, as_attachment=True, download_name=config.DOCUMENT_CONFIG['default_filename'])

    # GET, first visit, or a POST with an unrecognised action.
    return render_page(news_content='', selected_model=selected_model)
@app.route('/api/config', methods=['GET'])
@login_required
def get_config():
    """Return model/style configuration and API-key status as JSON.

    Each model entry is copied without its 'api_key' field. Provider
    credentials therefore never reach the browser, even if config.MODELS
    carries them.
    """
    # NOTE(review): assumes each config.MODELS value is a mapping (as
    # get_valid_models also assumes); confirm there are no other secret fields.
    secret_fields = {'api_key'}
    public_models = {
        model_key: {
            field: value
            for field, value in settings.items()
            if field not in secret_fields
        }
        for model_key, settings in config.MODELS.items()
    }
    return jsonify({
        'models': public_models,
        'styles': config.STYLES,
        'available_models': config.get_available_models(),
        'api_key_status': config.validate_api_keys()
    })
@app.route('/api/generate', methods=['POST'])
@login_required
def api_generate():
    """JSON endpoint that generates an article from the submitted form.

    Reads 'description', 'style' and 'model' (required), 'topic1'..'topic3'
    (optional) and an optional 'image' upload. Responds with
    ``{"news_content": ..., "api_info": ...}``.
    """
    form = request.form
    description, style, model_type = form['description'], form['style'], form['model']
    topics = [form.get(key, '').strip() for key in ('topic1', 'topic2', 'topic3')]
    image = request.files.get('image')
    news_content, api_info = generate_news_content(description, style, model_type, topics, image)
    return jsonify({'news_content': news_content, 'api_info': api_info})
if __name__ == '__main__':
    # Start the built-in server with debug/host/port taken from FLASK_CONFIG.
    # NOTE(review): confirm FLASK_CONFIG['debug'] is disabled outside development.
    server_cfg = config.FLASK_CONFIG
    run_kwargs = {key: server_cfg[key] for key in ('debug', 'host', 'port')}
    app.run(**run_kwargs)