314 lines
13 KiB
Python
314 lines
13 KiB
Python
|
|
import os
|
|||
|
|
from flask import Flask, request, render_template, send_file, jsonify, redirect, url_for, session
|
|||
|
|
from docx import Document
|
|||
|
|
from docx.shared import Inches
|
|||
|
|
import openai
|
|||
|
|
import io
|
|||
|
|
import json
|
|||
|
|
from config import config
|
|||
|
|
import time
|
|||
|
|
from functools import wraps
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from volcengine.maas import MaasService, MaasException
|
|||
|
|
except ImportError:
|
|||
|
|
MaasService = None
|
|||
|
|
MaasException = None
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from dashscope import Generation
|
|||
|
|
except ImportError:
|
|||
|
|
Generation = None
|
|||
|
|
|
|||
|
|
# The Flask application object. Its session-signing secret is read from the
# project config (config.FLASK_CONFIG['secret_key']); Flask's `secret_key`
# attribute is backed by the SECRET_KEY config entry set here.
app = Flask(__name__)
app.config['SECRET_KEY'] = config.FLASK_CONFIG['secret_key']
|
|||
|
|
|
|||
|
|
def login_required(f):
    """Decorator that only lets requests with a logged-in session through.

    When ``session['logged_in']`` is truthy the wrapped view is called as-is;
    otherwise the client is redirected to the ``login`` endpoint.
    """
    @wraps(f)
    def guarded_view(*args, **kwargs):
        # Authenticated sessions go straight to the real view.
        if session.get('logged_in'):
            return f(*args, **kwargs)
        return redirect(url_for('login'))

    return guarded_view
|
|||
|
|
|
|||
|
|
def generate_with_openai(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with OpenAI's chat-completions API.

    Args:
        prompt: text sent as the single user message.
        api_key: OpenAI API key used to construct the client.
        model: model name passed to ``chat.completions.create``.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature.

    Returns:
        The ``content`` of the first returned choice's message.
    """
    request_kwargs = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    completion = openai.OpenAI(api_key=api_key).chat.completions.create(**request_kwargs)
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|||
|
|
|
|||
|
|
def generate_with_volcengine(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with Volcengine's Doubao model through the MaaS SDK.

    Args:
        prompt: text sent as the single user message.
        api_key: credential passed as the second MaasService constructor argument.
        model: model name forwarded to ``maas.chat``.
        max_tokens: upper bound on generated tokens.
        temperature: accepted for signature parity with the other
            ``generate_with_*`` helpers; currently NOT forwarded to the SDK.

    Returns:
        ``response['choices'][0]['message']['content']`` from the chat call.

    Raises:
        ImportError: if the ``volcengine`` SDK failed to import at startup.

    NOTE(review): published volcengine.maas examples appear to build
    ``MaasService(host, region)``, set credentials via ``set_ak``/``set_sk``
    and call ``chat(req_dict)``; confirm this call shape against the
    installed SDK version.
    """
    if MaasService is not None:
        service = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
        chat_history = [{"role": "user", "content": prompt}]
        reply = service.chat(chat_history, model=model, max_tokens=max_tokens)
        return reply['choices'][0]['message']['content']
    raise ImportError("火山引擎SDK未安装,请运行: pip install volcengine")
|
|||
|
|
|
|||
|
|
def generate_with_dashscope(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with Alibaba Tongyi Qianwen through the DashScope SDK.

    This helper never raises for SDK/API problems: every failure is returned
    as a human-readable (Chinese) message in place of the generated text, so
    callers that display the return value directly keep working.

    Args:
        prompt: text sent as the single user message.
        api_key: DashScope API key.
        model: DashScope model name.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, forwarded to ``Generation.call``.

    Returns:
        The generated text, or an error description string.
    """
    import logging

    logger = logging.getLogger(__name__)
    if Generation is None:
        return "阿里云DashScope SDK未安装,请运行: pip install dashscope"
    try:
        response = Generation.call(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            # Previously accepted by this function but silently dropped.
            temperature=temperature,
            api_key=api_key,
        )
        # Raw responses can be large; keep them out of stdout unless debugging.
        logger.debug("阿里云原始返回: %s", response)
        if not response:
            return "阿里云API无响应,请检查网络或API Key。"

        code = getattr(response, 'code', None)
        msg = getattr(response, 'message', None)
        # DashScope responses also carry an HTTP-style ``status_code``; check it
        # too, in case a failed call comes back with an empty ``code``.
        status = getattr(response, 'status_code', None)
        if (code and code != 200) or (status is not None and status != 200):
            return f"阿里云 API 错误: {msg or str(response)} (code={code or status})"

        # Accept both message-format (output.choices) and text-format (output.text) replies.
        if not hasattr(response, 'output'):
            return f"阿里云API返回异常:无output字段,原始返回:{str(response)}"
        output = response.output
        if getattr(output, 'choices', None):
            return output.choices[0].message.content
        if getattr(output, 'text', None):
            return output.text
        return f"阿里云API返回异常:output字段无choices或text,原始返回:{str(response)}"
    except Exception as e:
        # Best-effort boundary: log with traceback, report the error as text.
        logger.exception("阿里云API调用异常: %s", e)
        return f"阿里云API调用异常: {str(e)}"
|
|||
|
|
|
|||
|
|
def _build_news_prompt(description, style, topics=None, image=None):
    """Assemble the (Chinese) news-writing prompt shared by every model backend.

    Args:
        description: free-text description of the news event.
        style: writing style inserted into the opening instruction.
        topics: optional list of topic strings; empty entries are ignored.
        image: optional uploaded file object; only its ``filename`` is inspected.

    Returns:
        The complete prompt string.
    """
    topic_str = ''
    kept_topics = [t for t in (topics or []) if t]
    if kept_topics:
        topic_str = (
            "主题报道部分请围绕下列每个话题分别成段深入分析,每个话题都要有具体内容:"
            f"{'、'.join(kept_topics)}。\n"
        )
    # Only emit the image bullet when a file with a non-empty name was given;
    # previously an empty "- " bullet was left in the prompt otherwise.
    image_line = ''
    if image is not None and getattr(image, 'filename', None):
        image_line = "- 如有上传的新闻图片,请结合图片内容进行适当描述。\n"
    return (
        f"你是一名专业新闻记者,请以{style}风格,围绕下述新闻事件,撰写一篇结构完整、语言流畅、符合新闻报道规范的新闻稿。\n"
        "结构要求:\n"
        "1. 导语:简要概括新闻事件。\n"
        "2. 品牌介绍:如涉及品牌,请介绍品牌的背景、定位、优势或发展愿景。\n"
        f"3. 主题报道:{topic_str}每个话题单独成段。\n"
        "4. 结语:总结新闻意义或展望未来。\n"
        "其它要求:\n"
        "- 内容真实、客观,突出事件的背景、经过、影响。\n"
        "- 字数不少于500字。\n"
        f"{image_line}"
        f"\n新闻事件描述:\n{description}\n"
    )


def _news_via_openai(prompt, api_key, model, max_tokens, temperature):
    """Call OpenAI chat completions; return ``(content, api_info)``."""
    api_info = {}
    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    usage = getattr(response, 'usage', None)
    if usage:
        api_info['提示tokens'] = usage.prompt_tokens
        api_info['生成tokens'] = usage.completion_tokens
        api_info['总tokens'] = usage.total_tokens
        # Rough cost estimate at gpt-4-1106-preview list prices
        # ($0.01 / 1K prompt tokens, $0.03 / 1K completion tokens),
        # applied regardless of the configured model.
        price_prompt = 0.01 / 1000
        price_completion = 0.03 / 1000
        fee = usage.prompt_tokens * price_prompt + usage.completion_tokens * price_completion
        api_info['费用'] = f"${fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _news_via_volcengine(prompt, api_key, model, max_tokens, temperature):
    """Call the Volcengine MaaS chat API; return ``(content, api_info)``.

    ``temperature`` is not forwarded. NOTE(review): confirm the SDK's
    parameter name/shape before passing it through.
    """
    api_info = {}
    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
    messages = [{"role": "user", "content": prompt}]
    response = maas.chat(messages, model=model, max_tokens=max_tokens)
    content = response['choices'][0]['message']['content']
    if 'usage' in response:
        usage = response['usage']
        api_info['提示tokens'] = usage.get('prompt_tokens')
        api_info['生成tokens'] = usage.get('completion_tokens')
        api_info['总tokens'] = usage.get('total_tokens')
        # Placeholder estimate of ¥0.005 / 1K tokens; check official pricing.
        price = 0.005 / 1000
        total = usage.get('total_tokens') or 0
        api_info['费用'] = f"¥{total * price:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _news_via_dashscope(prompt, api_key, model, max_tokens, temperature):
    """Call DashScope ``Generation.call``; return ``(content, api_info)``.

    API-level failures are reported through ``content`` as an error string
    rather than raised.
    """
    api_info = {}
    response = Generation.call(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        # Previously read from the model config but never sent.
        temperature=temperature,
        api_key=api_key,
    )
    status = getattr(response, 'status_code', None)
    if status is not None and status != 200:
        code = getattr(response, 'code', None)
        msg = getattr(response, 'message', None)
        content = f"阿里云 API 错误: {msg or str(response)} (code={code or status})"
    elif not hasattr(response, 'output'):
        content = f"阿里云API返回异常:无output字段,原始返回:{str(response)}"
    elif getattr(response.output, 'choices', None):
        # Message-format reply.
        content = response.output.choices[0].message.content
    elif getattr(response.output, 'text', None):
        # Text-format reply.
        content = response.output.text
    else:
        content = f"阿里云API返回异常:output字段无choices或text,原始返回:{str(response)}"

    usage = getattr(response, 'usage', None)
    if usage:
        prompt_tokens = getattr(usage, 'input_tokens', None) or getattr(usage, 'prompt_tokens', None)
        completion_tokens = getattr(usage, 'output_tokens', None) or getattr(usage, 'completion_tokens', None)
        total = getattr(usage, 'total_tokens', None)
        if total is None and (prompt_tokens is not None or completion_tokens is not None):
            # In case only input/output counts are reported.
            total = (prompt_tokens or 0) + (completion_tokens or 0)
        api_info['提示tokens'] = prompt_tokens
        api_info['生成tokens'] = completion_tokens
        api_info['总tokens'] = total
        # Placeholder estimate of ¥0.012 / 1K tokens; check official pricing.
        price = 0.012 / 1000
        api_info['费用'] = f"¥{(total or 0) * price:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def generate_news_content(description, style, model_type, topics=None, image=None):
    """Generate a news article with the configured model backend.

    Args:
        description: description of the news event.
        style: requested writing style.
        model_type: backend key ('openai', 'volcengine' or 'dashscope');
            looked up with ``config.get_model_config``.
        topics: optional list of up to three topic strings (blanks ignored).
        image: optional uploaded file; only its filename is used, to decide
            whether the prompt mentions an image.

    Returns:
        ``(content, api_info)``. ``content`` is the article or an error
        message. ``api_info`` holds token counts, an estimated fee, the model
        name and ``响应时间``. It is ``{}`` for configuration/SDK errors and
        only the response time when the backend call raised.
    """
    prompt = _build_news_prompt(description, style, topics, image)

    model_config = config.get_model_config(model_type)
    if not model_config:
        return "不支持的模型类型", {}
    api_key = model_config['api_key']
    model = model_config['model']
    max_tokens = model_config['max_tokens']
    temperature = model_config['temperature']

    backends = {
        'openai': _news_via_openai,
        'volcengine': _news_via_volcengine,
        'dashscope': _news_via_dashscope,
    }
    backend = backends.get(model_type)
    if backend is None:
        return f"不支持的模型类型: {model_type}", {}
    if model_type == 'volcengine' and MaasService is None:
        return "火山引擎SDK未安装,请运行: pip install volcengine", {}
    if model_type == 'dashscope' and Generation is None:
        return "阿里云DashScope SDK未安装,请运行: pip install dashscope", {}

    start_time = time.time()
    try:
        content, api_info = backend(prompt, api_key, model, max_tokens, temperature)
    except Exception as e:
        # UI boundary: surface the failure as text, but keep a traceback in the log.
        app.logger.exception("generate_news_content failed for %s", model_type)
        content = f"生成内容时出错: {str(e)}"
        api_info = {}
    api_info['响应时间'] = f"{(time.time() - start_time):.2f}s"
    return content, api_info
|
|||
|
|
|
|||
|
|
def create_word_doc(news_content, image_stream, image_filename):
    """Build an in-memory .docx with the configured title, optional image and article.

    Args:
        news_content: article text. Each non-blank line becomes its own Word
            paragraph; blank lines are dropped.
        image_stream: readable binary stream for the picture, or a falsy
            value to omit the picture.
        image_filename: name of the uploaded image. Unused here; kept so
            existing callers keep working.

    Returns:
        An ``io.BytesIO`` positioned at offset 0 containing the saved document.
    """
    doc = Document()
    doc.add_heading(config.DOCUMENT_CONFIG['title'], 0)
    if image_stream:
        doc.add_picture(image_stream, width=Inches(config.DOCUMENT_CONFIG['image_width']))

    # python-docx maps "\n" inside run text to soft line breaks, so passing the
    # whole article to one add_paragraph() call produced a single giant
    # paragraph. Emit one real paragraph per non-blank line instead; an empty
    # article still yields one (empty) paragraph as before.
    lines = [line.rstrip() for line in (news_content or '').splitlines() if line.strip()]
    for line in lines or ['']:
        doc.add_paragraph(line)

    output = io.BytesIO()
    doc.save(output)
    output.seek(0)
    return output
|
|||
|
|
|
|||
|
|
def get_valid_models(models):
    """Keep only model entries that define both a truthy ``name`` and ``model``.

    Surviving entries keep their original keys, value objects and order.
    """
    valid = {}
    for model_key, model_cfg in models.items():
        if not (model_cfg.get('name') and model_cfg.get('model')):
            continue
        valid[model_key] = model_cfg
    return valid
|
|||
|
|
|
|||
|
|
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or check submitted credentials (POST).

    The expected credentials default to the previously hard-coded
    ``admin`` / ``123456``. They can be overridden with the ``ADMIN_USERNAME``
    and ``ADMIN_PASSWORD`` environment variables so the real password does not
    have to live in source. On success the session is marked ``logged_in`` and
    the user is redirected to ``index``. On failure the form is re-rendered
    with an error message.
    """
    import hmac

    if request.method != 'POST':
        return render_template('login.html')

    username = request.form.get('username') or ''
    password = request.form.get('password') or ''
    expected_user = os.environ.get('ADMIN_USERNAME', 'admin')
    expected_password = os.environ.get('ADMIN_PASSWORD', '123456')
    # Constant-time comparison on UTF-8 bytes (compare_digest rejects non-ASCII
    # str). Both checks always run, so timing does not reveal which field failed.
    user_ok = hmac.compare_digest(username.encode('utf-8'), expected_user.encode('utf-8'))
    pass_ok = hmac.compare_digest(password.encode('utf-8'), expected_password.encode('utf-8'))
    if user_ok and pass_ok:
        session['logged_in'] = True
        return redirect(url_for('index'))
    return render_template('login.html', error='用户名或密码错误')
|
|||
|
|
|
|||
|
|
@app.route('/logout')
def logout():
    """Clear the ``logged_in`` flag from the session and go back to the login page."""
    login_page = url_for('login')
    session.pop('logged_in', None)
    return redirect(login_page)
|
|||
|
|
|
|||
|
|
@app.route('/', methods=['GET', 'POST'])
@login_required
def index():
    """Main page of the news generator.

    GET, or a POST with an unrecognised ``action``, renders the page with no article.
    POST ``action=generate`` builds an article with generate_news_content()
    and re-renders the page with the article and its API usage info.
    POST ``action=download`` turns the posted ``news_content`` (plus the
    optional uploaded image) into a Word file sent as an attachment.
    """
    def render_page(article, model_choice, **extra):
        # Template context shared by every HTML response of this view.
        return render_template(
            'index.html',
            styles=config.STYLES,
            models=get_valid_models(config.MODELS),
            api_key_status=config.validate_api_keys(),
            news_content=article,
            selected_model=model_choice,
            **extra
        )

    default_model = request.form.get('model', 'dashscope')
    action = request.form.get('action') if request.method == 'POST' else None

    if action == 'generate':
        form = request.form
        description = form['description']
        style = form['style']
        model_type = form['model']
        upload = request.files.get('image')
        topics = [form.get(f'topic{n}', '').strip() for n in (1, 2, 3)]
        article, api_info = generate_news_content(description, style, model_type, topics, upload)
        # Re-render with the generated article shown on the right-hand side.
        return render_page(article, model_type, api_info=api_info)

    if action == 'download':
        article = request.form.get('news_content', '')
        upload = request.files.get('image')
        if upload:
            word_file = create_word_doc(article, upload.stream, upload.filename)
        else:
            word_file = create_word_doc(article, None, None)
        return send_file(
            word_file,
            as_attachment=True,
            download_name=config.DOCUMENT_CONFIG['default_filename'],
        )

    return render_page('', default_model)
|
|||
|
|
|
|||
|
|
@app.route('/api/config', methods=['GET'])
@login_required
def get_config():
    """Return model/style configuration and API-key status as JSON.

    Secret fields such as ``api_key`` are stripped from every model entry
    before serialisation so credentials are never echoed to the browser.
    Model configs obtained through ``config.get_model_config`` carry an
    ``api_key``; NOTE(review): presumably ``config.MODELS`` holds those same
    dicts — confirm. Non-dict entries are passed through unchanged.
    """
    secret_fields = {'api_key', 'api_secret', 'secret_key', 'access_key'}
    public_models = {
        name: (
            {field: value for field, value in entry.items() if field not in secret_fields}
            if isinstance(entry, dict) else entry
        )
        for name, entry in config.MODELS.items()
    }
    return jsonify({
        'models': public_models,
        'styles': config.STYLES,
        'available_models': config.get_available_models(),
        'api_key_status': config.validate_api_keys(),
    })
|
|||
|
|
|
|||
|
|
@app.route('/api/generate', methods=['POST'])
@login_required
def api_generate():
    """JSON endpoint: generate an article from posted form fields.

    Required form fields: ``description``, ``style``, ``model``.
    Optional: ``topic1``..``topic3`` and an ``image`` file.

    Returns:
        ``{"news_content": ..., "api_info": {...}}``. If a required field is
        absent, responds 400 with ``{"error": ...}`` instead of Flask's HTML
        error page.
    """
    form = request.form
    missing = [field for field in ('description', 'style', 'model') if field not in form]
    if missing:
        return jsonify({'error': f"缺少必填字段: {', '.join(missing)}"}), 400

    topics = [form.get(f'topic{n}', '').strip() for n in (1, 2, 3)]
    news_content, api_info = generate_news_content(
        form['description'], form['style'], form['model'], topics, request.files.get('image')
    )
    return jsonify({'news_content': news_content, 'api_info': api_info})
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Development-server entry point; debug/host/port come from
    # config.FLASK_CONFIG. NOTE(review): avoid debug=True on a publicly
    # reachable host, since the Werkzeug debugger allows code execution.
    server_settings = config.FLASK_CONFIG
    app.run(
        debug=server_settings['debug'],
        host=server_settings['host'],
        port=server_settings['port'],
    )
|