314 lines
13 KiB
Python
314 lines
13 KiB
Python
import hmac
import io
import json
import logging
import os
import time
from collections.abc import Mapping
from functools import wraps

import openai
from docx import Document
from docx.shared import Inches
from flask import Flask, request, render_template, send_file, jsonify, redirect, url_for, session

from config import config

try:
    from volcengine.maas import MaasService, MaasException
except ImportError:
    MaasService = None
    MaasException = None

try:
    from dashscope import Generation
except ImportError:
    Generation = None


app = Flask(__name__)
# `app.secret_key` is Flask's alias for the SECRET_KEY config entry; it is the
# key used to sign the session cookie that stores the `logged_in` flag.
app.config['SECRET_KEY'] = config.FLASK_CONFIG['secret_key']


def login_required(f):
    """Decorate a view so that only logged-in sessions can reach it.

    The wrapped view runs when the session carries a truthy ``logged_in``
    flag; otherwise the visitor is redirected to the ``login`` endpoint.
    ``functools.wraps`` keeps the view's name, which Flask uses as the
    endpoint name.
    """
    @wraps(f)
    def guarded_view(*args, **kwargs):
        if session.get('logged_in'):
            return f(*args, **kwargs)
        return redirect(url_for('login'))

    return guarded_view


def generate_with_openai(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with the OpenAI Chat Completions API.

    Args:
        prompt: text sent as a single user message.
        api_key: OpenAI API key used to build the client.
        model: chat model name.
        max_tokens: completion length limit.
        temperature: sampling temperature.

    Returns:
        The message content of the first returned choice. SDK errors are not
        caught here and propagate to the caller.
    """
    completion = openai.OpenAI(api_key=api_key).chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def generate_with_volcengine(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with Volcengine Doubao through the MaaS SDK.

    Args:
        prompt: text sent as a single user message.
        api_key: credential passed to ``MaasService``.
        model: model name passed to ``chat``.
        max_tokens: completion length limit.
        temperature: accepted for signature parity with the other helpers,
            but NOT forwarded to the SDK.

    Returns:
        ``response['choices'][0]['message']['content']`` from the chat reply.

    Raises:
        ImportError: when the optional ``volcengine`` SDK is not installed.
    """
    if MaasService is not None:
        # NOTE(review): the second constructor argument is the API key here;
        # confirm this matches the installed SDK's MaasService signature.
        service = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
        chat_messages = [{"role": "user", "content": prompt}]
        reply = service.chat(chat_messages, model=model, max_tokens=max_tokens)
        return reply['choices'][0]['message']['content']
    raise ImportError("火山引擎SDK未安装,请运行: pip install volcengine")


def generate_with_dashscope(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text with Alibaba Tongyi Qianwen via the DashScope SDK.

    Unlike the OpenAI and Volcengine helpers, this function never raises: a
    missing SDK, an API error, an unexpected payload or any exception is
    reported as a human-readable (Chinese) message returned in place of the
    generated content.

    Args:
        prompt: text sent as a single user message.
        api_key: DashScope API key.
        model: DashScope model name.
        max_tokens: completion length limit.
        temperature: sampling temperature; forwarded unless ``None``.

    Returns:
        The generated text, or an error message string.
    """
    logger = logging.getLogger(__name__)
    if Generation is None:
        return "阿里云DashScope SDK未安装,请运行: pip install dashscope"

    call_kwargs = {
        'model': model,
        'messages': [{"role": "user", "content": prompt}],
        'max_tokens': max_tokens,
        'api_key': api_key,
    }
    if temperature is not None:
        # The parameter used to be accepted but silently dropped.
        call_kwargs['temperature'] = temperature

    try:
        response = Generation.call(**call_kwargs)
        # Raw payloads are debugging aids only; keep them out of stdout.
        logger.debug("阿里云原始返回: %s", response)
        if not response:
            return "阿里云API无响应,请检查网络或API Key。"

        code = getattr(response, 'code', None)
        status_code = getattr(response, 'status_code', None)
        # `code` is an error identifier on failure; `status_code` is the HTTP
        # status. Either one signalling failure means there is no usable output.
        if (code and code != 200) or (status_code is not None and status_code != 200):
            msg = getattr(response, 'message', None)
            return f"阿里云 API 错误: {msg or str(response)} (code={code})"

        # Accept both response shapes: output.choices and output.text.
        if not hasattr(response, 'output'):
            return f"阿里云API返回异常:无output字段,原始返回:{str(response)}"
        output = response.output
        if getattr(output, 'choices', None):
            return output.choices[0].message.content
        if getattr(output, 'text', None):
            return output.text
        return f"阿里云API返回异常:output字段无choices或text,原始返回:{str(response)}"
    except Exception as e:
        # Deliberate best-effort: surface the failure as text, but keep the traceback in the log.
        logger.exception("阿里云API调用异常")
        return f"阿里云API调用异常: {str(e)}"


# Rough per-token prices used only for the cost estimate shown in the UI.
# NOTE(review): placeholders, not live prices - confirm against each
# provider's current price list before relying on the figures.
_OPENAI_PROMPT_PRICE = 0.01 / 1000        # USD per prompt token (gpt-4-1106-preview rate)
_OPENAI_COMPLETION_PRICE = 0.03 / 1000    # USD per completion token (gpt-4-1106-preview rate)
_VOLCENGINE_TOKEN_PRICE = 0.005 / 1000    # CNY per token (assumed)
_DASHSCOPE_TOKEN_PRICE = 0.012 / 1000     # CNY per token (assumed)


def _build_news_prompt(description, style, topics=None, image=None):
    """Assemble the news-writing prompt sent to the LLM.

    Blank topic strings are ignored. The image hint is added only when an
    upload with a non-empty ``filename`` is present; the model receives only
    this textual hint, not the image itself.
    """
    topics = [t for t in (topics or []) if t]
    topic_str = ''
    if topics:
        # No trailing newline: this text continues list item 3 below.
        topic_str = (
            "主题报道部分请围绕下列每个话题分别成段深入分析,每个话题都要有具体内容:"
            f"{'、'.join(topics)}。"
        )

    # Emit the image bullet only when there is an image, so the requirement
    # list never ends with an empty "- " item.
    image_line = ''
    if image is not None and getattr(image, 'filename', None):
        image_line = "- 如有上传的新闻图片,请结合图片内容进行适当描述。\n"

    return (
        f"你是一名专业新闻记者,请以{style}风格,围绕下述新闻事件,撰写一篇结构完整、语言流畅、符合新闻报道规范的新闻稿。\n"
        "结构要求:\n"
        "1. 导语:简要概括新闻事件。\n"
        "2. 品牌介绍:如涉及品牌,请介绍品牌的背景、定位、优势或发展愿景。\n"
        f"3. 主题报道:{topic_str}每个话题单独成段。\n"
        "4. 结语:总结新闻意义或展望未来。\n"
        "其它要求:\n"
        "- 内容真实、客观,突出事件的背景、经过、影响。\n"
        "- 字数不少于500字。\n"
        f"{image_line}"
        f"\n新闻事件描述:\n{description}\n"
    )


def _news_via_openai(prompt, api_key, model, max_tokens, temperature):
    """Call OpenAI chat completions; return ``(content, api_info)``."""
    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    info = {}
    usage = getattr(response, 'usage', None)
    if usage:
        info['提示tokens'] = usage.prompt_tokens
        info['生成tokens'] = usage.completion_tokens
        info['总tokens'] = usage.total_tokens
        fee = (usage.prompt_tokens * _OPENAI_PROMPT_PRICE
               + usage.completion_tokens * _OPENAI_COMPLETION_PRICE)
        info['费用'] = f"${fee:.4f}"
    else:
        info['费用'] = '--'
    info['模型'] = model
    return content, info


def _news_via_volcengine(prompt, api_key, model, max_tokens, temperature):
    """Call Volcengine MaaS chat; return ``(content, api_info)``.

    ``temperature`` is accepted for a uniform helper signature but is not
    forwarded (the original call did not pass it either).
    """
    # NOTE(review): confirm MaasService's constructor/chat signature against the installed SDK.
    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
    response = maas.chat([{"role": "user", "content": prompt}], model=model, max_tokens=max_tokens)
    content = response['choices'][0]['message']['content']
    info = {}
    if 'usage' in response:
        usage = response['usage']
        info['提示tokens'] = usage.get('prompt_tokens')
        info['生成tokens'] = usage.get('completion_tokens')
        info['总tokens'] = usage.get('total_tokens')
        fee = (usage.get('total_tokens') or 0) * _VOLCENGINE_TOKEN_PRICE
        info['费用'] = f"¥{fee:.4f}"
    else:
        info['费用'] = '--'
    info['模型'] = model
    return content, info


def _dashscope_content(response):
    """Extract generated text from a DashScope response, or describe why not."""
    code = getattr(response, 'code', None)
    status_code = getattr(response, 'status_code', None)
    # Same error check as generate_with_dashscope: without it an API error
    # was misreported as a malformed payload.
    if (code and code != 200) or (status_code is not None and status_code != 200):
        msg = getattr(response, 'message', None)
        return f"阿里云 API 错误: {msg or str(response)} (code={code})"
    # Accept both response shapes: output.choices and output.text.
    if not hasattr(response, 'output'):
        return f"阿里云API返回异常:无output字段,原始返回:{str(response)}"
    output = response.output
    if getattr(output, 'choices', None):
        return output.choices[0].message.content
    if getattr(output, 'text', None):
        return output.text
    return f"阿里云API返回异常:output字段无choices或text,原始返回:{str(response)}"


def _news_via_dashscope(prompt, api_key, model, max_tokens, temperature):
    """Call DashScope Generation; return ``(content, api_info)``."""
    call_kwargs = {
        'model': model,
        'messages': [{"role": "user", "content": prompt}],
        'max_tokens': max_tokens,
        'api_key': api_key,
    }
    if temperature is not None:
        # The configured temperature used to be read and then ignored for this provider.
        call_kwargs['temperature'] = temperature
    response = Generation.call(**call_kwargs)
    content = _dashscope_content(response)
    info = {}
    if hasattr(response, 'usage'):
        usage = response.usage
        # Accept both input/output_tokens and prompt/completion_tokens naming.
        info['提示tokens'] = getattr(usage, 'input_tokens', None) or getattr(usage, 'prompt_tokens', None)
        info['生成tokens'] = getattr(usage, 'output_tokens', None) or getattr(usage, 'completion_tokens', None)
        total = getattr(usage, 'total_tokens', None)
        info['总tokens'] = total
        fee = (total or 0) * _DASHSCOPE_TOKEN_PRICE
        info['费用'] = f"¥{fee:.4f}"
    else:
        info['费用'] = '--'
    info['模型'] = model
    return content, info


def generate_news_content(description, style, model_type, topics=None, image=None):
    """Generate a news article with the configured LLM provider.

    Args:
        description: free-text description of the news event.
        style: writing style inserted into the prompt.
        model_type: provider key: 'openai', 'volcengine' or 'dashscope'.
        topics: optional list of topic strings (the routes pass up to
            three); blank entries are ignored.
        image: optional uploaded file; only its presence (non-empty
            ``filename``) affects the prompt.

    Returns:
        ``(content, api_info)``. Failures are reported as a message in
        ``content`` rather than raised. ``api_info`` holds token counts,
        estimated fee, model name and response time when available.
    """
    prompt = _build_news_prompt(description, style, topics, image)

    model_config = config.get_model_config(model_type)
    if not model_config:
        return "不支持的模型类型", {}
    api_key = model_config['api_key']
    model = model_config['model']
    max_tokens = model_config['max_tokens']
    temperature = model_config['temperature']

    # Missing optional SDKs are reported without timing info.
    if model_type == 'volcengine' and MaasService is None:
        return "火山引擎SDK未安装,请运行: pip install volcengine", {}
    if model_type == 'dashscope' and Generation is None:
        return "阿里云DashScope SDK未安装,请运行: pip install dashscope", {}

    providers = {
        'openai': _news_via_openai,
        'volcengine': _news_via_volcengine,
        'dashscope': _news_via_dashscope,
    }
    provider = providers.get(model_type)
    if provider is None:
        return f"不支持的模型类型: {model_type}", {}

    start = time.perf_counter()  # monotonic clock for measuring elapsed time
    try:
        content, api_info = provider(prompt, api_key, model, max_tokens, temperature)
    except Exception as e:
        # Deliberate best-effort: show the error on the page instead of a 500,
        # but keep the traceback in the server log.
        logging.getLogger(__name__).exception("生成内容时出错")
        content = f"生成内容时出错: {str(e)}"
        api_info = {'模型': model}
    api_info['响应时间'] = f"{(time.perf_counter() - start):.2f}s"
    return content, api_info


def create_word_doc(news_content, image_stream, image_filename):
    """Build a .docx containing the article and return it as an in-memory file.

    Args:
        news_content: article text, added as a single paragraph.
        image_stream: optional readable image stream; inserted under the
            heading when truthy.
        image_filename: currently unused; kept for interface compatibility.

    Returns:
        An ``io.BytesIO`` rewound to position 0, ready for ``send_file``.
    """
    doc_settings = config.DOCUMENT_CONFIG
    document = Document()
    document.add_heading(doc_settings['title'], 0)
    if image_stream:
        picture_width = Inches(doc_settings['image_width'])
        document.add_picture(image_stream, width=picture_width)
    document.add_paragraph(news_content)

    buffer = io.BytesIO()
    document.save(buffer)
    buffer.seek(0)
    return buffer


def get_valid_models(models):
    """Return only the model configurations complete enough to display.

    An entry is kept when it is a mapping with a truthy ``name`` and a truthy
    ``model``. Entries that are not mappings (e.g. ``None`` from a partially
    filled config) are treated as incomplete and skipped instead of raising
    ``AttributeError``. An empty or ``None`` *models* yields ``{}``.

    Args:
        models: mapping of model key -> model configuration.

    Returns:
        A new dict with the surviving ``key: config`` pairs.
    """
    if not models:
        return {}
    return {
        key: cfg
        for key, cfg in models.items()
        if isinstance(cfg, Mapping) and cfg.get('name') and cfg.get('model')
    }


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Render the login form (GET) or check submitted credentials (POST).

    The expected credentials default to the demo pair ``admin`` / ``123456``.
    They can be overridden with the ``ADMIN_USERNAME`` / ``ADMIN_PASSWORD``
    environment variables. On success the session is marked ``logged_in``
    and the user is redirected to ``index``. On failure the form is
    re-rendered with an error message.
    """
    if request.method == 'POST':
        username = request.form.get('username') or ''
        password = request.form.get('password') or ''
        # TODO: replace with a real user store and hashed passwords.
        expected_username = os.environ.get('ADMIN_USERNAME', 'admin')
        expected_password = os.environ.get('ADMIN_PASSWORD', '123456')
        # Constant-time comparison on UTF-8 bytes. compare_digest rejects
        # non-ASCII str input, so encode first. Both checks always run, so
        # timing does not reveal which field was wrong.
        username_ok = hmac.compare_digest(username.encode('utf-8'), expected_username.encode('utf-8'))
        password_ok = hmac.compare_digest(password.encode('utf-8'), expected_password.encode('utf-8'))
        if username_ok and password_ok:
            session['logged_in'] = True
            return redirect(url_for('index'))
        return render_template('login.html', error='用户名或密码错误')
    return render_template('login.html')


@app.route('/logout')
def logout():
    """Clear the login flag from the session and send the user to the login page."""
    session.pop('logged_in', None)
    login_page = url_for('login')
    return redirect(login_page)


@app.route('/', methods=['GET', 'POST'])
@login_required
def index():
    """Main page: show the form, generate an article, or export it to Word.

    - POST ``action=generate``: generate an article from the form fields and
      render it together with API metadata.
    - POST ``action=download``: return the submitted article (plus the
      optional uploaded image) as a Word attachment.
    - Anything else: render the empty form.
    """
    fallback_model = request.form.get('model', 'dashscope')

    def render_page(**context):
        # Context shared by every render of index.html.
        return render_template(
            'index.html',
            styles=config.STYLES,
            models=get_valid_models(config.MODELS),
            api_key_status=config.validate_api_keys(),
            **context
        )

    action = request.form.get('action') if request.method == 'POST' else None

    if action == 'generate':
        description = request.form['description']
        style = request.form['style']
        model_type = request.form['model']
        upload = request.files.get('image')
        topics = [request.form.get(f'topic{n}', '').strip() for n in (1, 2, 3)]
        article, api_info = generate_news_content(description, style, model_type, topics, upload)
        return render_page(news_content=article, selected_model=model_type, api_info=api_info)

    if action == 'download':
        article = request.form.get('news_content', '')
        upload = request.files.get('image')
        if upload:
            image_stream, image_name = upload.stream, upload.filename
        else:
            image_stream, image_name = None, None
        word_file = create_word_doc(article, image_stream, image_name)
        return send_file(word_file, as_attachment=True, download_name=config.DOCUMENT_CONFIG['default_filename'])

    # GET, or a POST with an unrecognised action.
    return render_page(news_content='', selected_model=fallback_model)


# Model-config fields holding credentials; they must never be sent to the browser.
_SECRET_MODEL_FIELDS = frozenset({'api_key', 'secret_key', 'access_key'})


def _redact_model_configs(models):
    """Return a copy of *models* whose mapping entries omit credential fields.

    A *models* value that is not a mapping, and entries that are not
    mappings, are passed through unchanged.
    """
    if not isinstance(models, Mapping):
        return models
    return {
        name: (
            {field: value for field, value in cfg.items() if field not in _SECRET_MODEL_FIELDS}
            if isinstance(cfg, Mapping)
            else cfg
        )
        for name, cfg in models.items()
    }


@app.route('/api/config', methods=['GET'])
@login_required
def get_config():
    """Return models, styles, available models and API-key status as JSON.

    Model entries come from the same configuration that
    ``generate_news_content`` reads ``api_key`` from. Credential fields are
    therefore stripped before the data leaves the server. Presumably clients
    can rely on ``api_key_status`` instead; verify against the frontend.
    """
    return jsonify({
        'models': _redact_model_configs(config.MODELS),
        'styles': config.STYLES,
        'available_models': config.get_available_models(),
        'api_key_status': config.validate_api_keys(),
    })


@app.route('/api/generate', methods=['POST'])
@login_required
def api_generate():
    """JSON endpoint that generates an article from the posted form fields.

    Expects ``description``, ``style`` and ``model`` (a missing one makes the
    form lookup fail), plus optional ``topic1``-``topic3`` and an ``image``
    upload. Responds with ``{"news_content": ..., "api_info": ...}``.
    """
    form = request.form
    description, style, model_type = form['description'], form['style'], form['model']
    topics = [form.get(f'topic{n}', '').strip() for n in (1, 2, 3)]
    image = request.files.get('image')
    article, api_info = generate_news_content(description, style, model_type, topics, image)
    return jsonify({'news_content': article, 'api_info': api_info})


if __name__ == '__main__':
    # Start Flask's built-in server with debug/host/port taken from config.FLASK_CONFIG.
    flask_settings = config.FLASK_CONFIG
    run_options = {
        'debug': flask_settings['debug'],
        'host': flask_settings['host'],
        'port': flask_settings['port'],
    }
    app.run(**run_options)