abo/news_agent/app.py

314 lines
13 KiB
Python
Raw Normal View History

2025-07-16 15:34:54 +08:00
import os
from flask import Flask, request, render_template, send_file, jsonify, redirect, url_for, session
from docx import Document
from docx.shared import Inches
import openai
import io
import json
from config import config
import time
from functools import wraps
try:
from volcengine.maas import MaasService, MaasException
except ImportError:
MaasService = None
MaasException = None
try:
from dashscope import Generation
except ImportError:
Generation = None
app = Flask(__name__)
# Flask's `secret_key` attribute is a view onto the SECRET_KEY config entry;
# it is used to sign the session cookie that login_required relies on.
app.config['SECRET_KEY'] = config.FLASK_CONFIG['secret_key']
def login_required(f):
    """Decorator for view functions that require an authenticated session.

    The wrapped view runs only when ``session['logged_in']`` is truthy;
    otherwise the user is redirected to the ``login`` endpoint.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if session.get('logged_in'):
            return f(*args, **kwargs)
        return redirect(url_for('login'))
    return decorated_function
def generate_with_openai(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text for *prompt* with the OpenAI Chat Completions API.

    The prompt is sent as a single user message. The content of the first
    returned choice is returned. Exceptions raised by the openai client
    propagate to the caller.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.OpenAI(api_key=api_key).chat.completions.create(
        model=model,
        messages=chat_messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def generate_with_volcengine(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text for *prompt* with Volcengine Doubao via ``MaasService``.

    Raises:
        ImportError: if the volcengine SDK could not be imported.

    NOTE(review): *temperature* is accepted but never forwarded to the SDK.
    NOTE(review): published volcengine maas examples pass a *region* as the
    second ``MaasService`` argument and a single request dict to ``chat()``.
    This call passes the API key and keyword args instead; verify it against
    the installed SDK version.
    """
    if MaasService is None:
        raise ImportError("火山引擎SDK未安装请运行: pip install volcengine")
    service = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
    chat_messages = [{"role": "user", "content": prompt}]
    reply = service.chat(chat_messages, model=model, max_tokens=max_tokens)
    # The response is indexed as a dict: choices -> first -> message -> content.
    return reply['choices'][0]['message']['content']
def generate_with_dashscope(prompt, api_key, model, max_tokens=800, temperature=0.7):
    """Generate text for *prompt* with Alibaba Cloud DashScope (Tongyi Qianwen).

    This function never raises. Every failure is returned as a human-readable
    (Chinese) string in place of the generated text: a missing SDK, an API
    error code, an empty or unexpected payload, or any exception.

    Args:
        prompt: text sent as the only user message.
        api_key: DashScope API key.
        model: DashScope model name.
        max_tokens: cap on generated tokens.
        temperature: sampling temperature, forwarded to ``Generation.call``.

    Returns:
        The generated text, or an error description string.
    """
    if Generation is None:
        return "阿里云DashScope SDK未安装请运行: pip install dashscope"
    try:
        response = Generation.call(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            # Honour the caller's temperature; Generation.call accepts it.
            temperature=temperature,
            api_key=api_key,
        )
        print("阿里云原始返回:", response)
        # Check for an empty/missing response before reading attributes from it.
        if not response:
            return "阿里云API无响应请检查网络或API Key。"
        code = getattr(response, 'code', None)
        msg = getattr(response, 'message', None)
        # NOTE(review): DashScope's `code` is presumably an error-code string
        # (empty on success), so any non-empty value is treated as an error.
        if code and code != 200:
            return f"阿里云 API 错误: {msg or str(response)} (code={code})"
        if not hasattr(response, 'output'):
            return f"阿里云API返回异常无output字段原始返回{str(response)}"
        # The payload may carry either output.choices (message format) or output.text.
        output = response.output
        if hasattr(output, 'choices') and output.choices:
            return output.choices[0].message.content
        if hasattr(output, 'text') and output.text:
            return output.text
        return f"阿里云API返回异常output字段无choices或text原始返回{str(response)}"
    except Exception as e:
        import traceback
        print("阿里云API调用异常:", e)
        traceback.print_exc()
        return f"阿里云API调用异常: {str(e)}"
def _build_news_prompt(description, style, topics=None, image=None):
    """Return the news-writing prompt for *description* in the given *style*.

    *topics* is an optional list of strings; blank entries are ignored.
    *image* is an optional upload object, and only its ``filename`` is checked.
    """
    topic_str = ''
    topics = [t for t in (topics or []) if t]
    if topics:
        # Separate topics with an ideographic comma so the model can tell them
        # apart. Joining with '' ran them together into one string.
        topic_str = (
            "主题报道部分请围绕下列每个话题分别成段深入分析,每个话题都要有具体内容:"
            f"{'、'.join(topics)}\n"
        )
    # NOTE(review): only the filename is inspected. The image bytes are never
    # sent to the model, so this hint cannot make it "see" the picture.
    image_line = ''
    if image is not None and getattr(image, 'filename', None):
        image_line = "- 如有上传的新闻图片,请结合图片内容进行适当描述。\n"
    return (
        f"你是一名专业新闻记者,请以{style}风格,围绕下述新闻事件,撰写一篇结构完整、语言流畅、符合新闻报道规范的新闻稿。\n"
        "结构要求:\n"
        "1. 导语:简要概括新闻事件。\n"
        "2. 品牌介绍:如涉及品牌,请介绍品牌的背景、定位、优势或发展愿景。\n"
        f"3. 主题报道:{topic_str}每个话题单独成段。\n"
        "4. 结语:总结新闻意义或展望未来。\n"
        "其它要求:\n"
        "- 内容真实、客观,突出事件的背景、经过、影响。\n"
        "- 字数不少于500字。\n"
        # The image bullet is emitted only when an image exists (no empty "- " line).
        f"{image_line}"
        f"\n新闻事件描述:\n{description}\n"
    )


def _news_via_openai(prompt, api_key, model, max_tokens, temperature):
    """Call OpenAI chat completions and return ``(content, api_info)``."""
    api_info = {}
    client = openai.OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    usage = getattr(response, 'usage', None)
    if usage:
        api_info['提示tokens'] = usage.prompt_tokens
        api_info['生成tokens'] = usage.completion_tokens
        api_info['总tokens'] = usage.total_tokens
        # Rough cost estimate at gpt-4-1106-preview list prices
        # ($0.01 / 1K prompt tokens, $0.03 / 1K completion tokens).
        # It does not adapt to the configured model.
        price_prompt = 0.01 / 1000
        price_completion = 0.03 / 1000
        fee = usage.prompt_tokens * price_prompt + usage.completion_tokens * price_completion
        api_info['费用'] = f"${fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _news_via_volcengine(prompt, api_key, model, max_tokens, temperature):
    """Call Volcengine Doubao via MaasService and return ``(content, api_info)``.

    NOTE(review): *temperature* is not forwarded. Confirm the installed SDK's
    MaasService constructor and ``chat()`` signature before adding it.
    """
    api_info = {}
    maas = MaasService('maas-api.ml-platform-cn-beijing.volces.com', api_key)
    messages = [{"role": "user", "content": prompt}]
    response = maas.chat(messages, model=model, max_tokens=max_tokens)
    content = response['choices'][0]['message']['content']
    if 'usage' in response:
        usage = response['usage']
        api_info['提示tokens'] = usage.get('prompt_tokens')
        api_info['生成tokens'] = usage.get('completion_tokens')
        api_info['总tokens'] = usage.get('total_tokens')
        # Rough estimate assuming ¥0.005 / 1K tokens; check the official price list.
        price = 0.005 / 1000
        fee = (usage.get('total_tokens') or 0) * price
        api_info['费用'] = f"¥{fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def _news_via_dashscope(prompt, api_key, model, max_tokens, temperature):
    """Call Alibaba DashScope Generation and return ``(content, api_info)``."""
    api_info = {}
    response = Generation.call(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        api_key=api_key,
    )
    # The payload may carry either output.choices or output.text.
    if hasattr(response, 'output'):
        if hasattr(response.output, 'choices') and response.output.choices:
            content = response.output.choices[0].message.content
        elif hasattr(response.output, 'text') and response.output.text:
            content = response.output.text
        else:
            content = f"阿里云API返回异常output字段无choices或text原始返回{str(response)}"
    else:
        content = f"阿里云API返回异常无output字段原始返回{str(response)}"
    if hasattr(response, 'usage'):
        usage = response.usage
        # DashScope may name the counters input/output_tokens or prompt/completion_tokens.
        api_info['提示tokens'] = getattr(usage, 'input_tokens', None) or getattr(usage, 'prompt_tokens', None)
        api_info['生成tokens'] = getattr(usage, 'output_tokens', None) or getattr(usage, 'completion_tokens', None)
        api_info['总tokens'] = getattr(usage, 'total_tokens', None)
        # Rough estimate assuming ¥0.012 / 1K tokens; check the official price list.
        price = 0.012 / 1000
        fee = (getattr(usage, 'total_tokens', None) or 0) * price
        api_info['费用'] = f"¥{fee:.4f}"
    else:
        api_info['费用'] = '--'
    api_info['模型'] = model
    return content, api_info


def generate_news_content(description, style, model_type, topics=None, image=None):
    """Generate a news article with the configured LLM provider.

    Args:
        description: the news event to write about.
        style: writing style inserted into the prompt.
        model_type: provider key ('openai', 'volcengine' or 'dashscope').
        topics: optional list of topic strings; blank entries are ignored.
        image: optional upload object. Only its filename is checked, and only
            to add a hint to the prompt.

    Returns:
        ``(content, api_info)``. *content* is the generated text, or a
        human-readable error message. *api_info* is ``{}`` when the request is
        rejected before any provider call (unknown model, missing SDK).
        Otherwise it holds the provider's token counts, estimated fee and
        model name (when the call succeeded) plus '响应时间' (elapsed seconds).
    """
    prompt = _build_news_prompt(description, style, topics, image)
    model_config = config.get_model_config(model_type)
    if not model_config:
        return "不支持的模型类型", {}
    api_key = model_config['api_key']
    model = model_config['model']
    max_tokens = model_config['max_tokens']
    temperature = model_config['temperature']

    providers = {
        'openai': _news_via_openai,
        'volcengine': _news_via_volcengine,
        'dashscope': _news_via_dashscope,
    }
    call_provider = providers.get(model_type)
    if call_provider is None:
        return f"不支持的模型类型: {model_type}", {}
    # Missing optional SDKs are reported without a '响应时间' entry.
    if model_type == 'volcengine' and MaasService is None:
        return "火山引擎SDK未安装请运行: pip install volcengine", {}
    if model_type == 'dashscope' and Generation is None:
        return "阿里云DashScope SDK未安装请运行: pip install dashscope", {}

    start_time = time.perf_counter()
    try:
        content, api_info = call_provider(prompt, api_key, model, max_tokens, temperature)
    except Exception as e:
        # Surface provider failures to the UI as text rather than a 500 error page.
        content, api_info = f"生成内容时出错: {str(e)}", {}
    api_info['响应时间'] = f"{(time.perf_counter() - start_time):.2f}s"
    return content, api_info
def create_word_doc(news_content, image_stream, image_filename):
    """Build a Word document in memory and return it as a rewound BytesIO.

    The document contains a level-0 heading (``DOCUMENT_CONFIG['title']``),
    then the picture read from *image_stream* if one was given, then
    *news_content* as a single paragraph.

    *image_filename* is accepted for interface compatibility but not used.
    """
    document = Document()
    doc_cfg = config.DOCUMENT_CONFIG
    document.add_heading(doc_cfg['title'], 0)
    if image_stream:
        picture_width = Inches(doc_cfg['image_width'])
        document.add_picture(image_stream, width=picture_width)
    document.add_paragraph(news_content)
    buffer = io.BytesIO()
    document.save(buffer)
    buffer.seek(0)  # rewind so send_file reads from the start
    return buffer
def get_valid_models(models):
    """Keep only the model entries that have both a truthy 'name' and 'model'.

    Insertion order is preserved. The returned values are the original
    per-model dicts, not copies.
    """
    valid = {}
    for key, cfg in models.items():
        if not (cfg.get('name') and cfg.get('model')):
            continue
        valid[key] = cfg
    return valid
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Render the login form (GET) or check submitted credentials (POST).

    The expected credentials come from the NEWS_AGENT_USERNAME and
    NEWS_AGENT_PASSWORD environment variables. When they are unset, the
    legacy pair admin / 123456 is still accepted. Set both variables in any
    real deployment.

    On success, ``session['logged_in']`` is set and the user is redirected to
    ``index``. On failure, the form is re-rendered with an error message.
    """
    import hmac

    if request.method != 'POST':
        return render_template('login.html')

    username = request.form.get('username') or ''
    password = request.form.get('password') or ''
    expected_user = os.environ.get('NEWS_AGENT_USERNAME', 'admin')
    expected_pass = os.environ.get('NEWS_AGENT_PASSWORD', '123456')
    # Constant-time comparison avoids leaking match length through timing.
    # Comparing bytes sidesteps compare_digest's ASCII-only rule for str.
    user_ok = hmac.compare_digest(username.encode('utf-8'), expected_user.encode('utf-8'))
    pass_ok = hmac.compare_digest(password.encode('utf-8'), expected_pass.encode('utf-8'))
    if user_ok and pass_ok:
        session['logged_in'] = True
        return redirect(url_for('index'))
    return render_template('login.html', error='用户名或密码错误')
@app.route('/logout')
def logout():
    """Drop the logged-in flag from the session and return to the login page."""
    session.pop('logged_in', None)  # no error if the user was not logged in
    login_page = url_for('login')
    return redirect(login_page)
@app.route('/', methods=['GET', 'POST'])
@login_required
def index():
    """Main page: show the form, generate an article, or download it as .docx.

    POST action 'generate' calls generate_news_content and re-renders the page
    with the article and its api_info. POST action 'download' returns the
    submitted text, plus an optional uploaded image, as a Word attachment.
    Any other request renders the page with empty content.
    """
    def render_page(content, model_key, **extra):
        # Shared renderer so every page view receives the same base context.
        return render_template(
            'index.html',
            styles=config.STYLES,
            models=get_valid_models(config.MODELS),
            api_key_status=config.validate_api_keys(),
            news_content=content,
            selected_model=model_key,
            **extra
        )

    form = request.form
    fallback_model = form.get('model', 'dashscope')
    action = form.get('action') if request.method == 'POST' else None

    if action == 'generate':
        description = form['description']
        style = form['style']
        model_type = form['model']
        upload = request.files.get('image')
        topic_list = [form.get(key, '').strip() for key in ('topic1', 'topic2', 'topic3')]
        content, api_info = generate_news_content(description, style, model_type, topic_list, upload)
        return render_page(content, model_type, api_info=api_info)

    if action == 'download':
        text = form.get('news_content', '')
        upload = request.files.get('image')
        if upload:
            stream, filename = upload.stream, upload.filename
        else:
            stream, filename = None, None
        word_file = create_word_doc(text, stream, filename)
        return send_file(
            word_file,
            as_attachment=True,
            download_name=config.DOCUMENT_CONFIG['default_filename']
        )

    # GET, or a POST with an unrecognised action.
    return render_page('', fallback_model)
@app.route('/api/config', methods=['GET'])
@login_required
def get_config():
    """Return model, style and API-key-status configuration as JSON.

    Credential fields are stripped from each model entry before it is sent
    to the browser. NOTE(review): per-model configs presumably include an
    'api_key' (generate_news_content reads one from get_model_config()).
    Previously, MODELS was serialised verbatim.
    """
    sensitive_fields = {'api_key', 'secret_key', 'access_key'}
    public_models = {
        name: (
            {field: value for field, value in cfg.items() if field not in sensitive_fields}
            if isinstance(cfg, dict) else cfg
        )
        for name, cfg in config.MODELS.items()
    }
    return jsonify({
        'models': public_models,
        'styles': config.STYLES,
        'available_models': config.get_available_models(),
        'api_key_status': config.validate_api_keys()
    })
@app.route('/api/generate', methods=['POST'])
@login_required
def api_generate():
    """JSON endpoint: generate an article from the submitted form fields.

    Requires 'description', 'style' and 'model'. Missing keys abort the
    request via Flask's form KeyError handling. The optional fields are
    topic1..topic3 and an 'image' upload. Responds with
    ``{'news_content': ..., 'api_info': ...}``.
    """
    form = request.form
    description, style, model_type = form['description'], form['style'], form['model']
    topic_list = [form.get(key, '').strip() for key in ('topic1', 'topic2', 'topic3')]
    upload = request.files.get('image')
    content, api_info = generate_news_content(description, style, model_type, topic_list, upload)
    payload = {'news_content': content, 'api_info': api_info}
    return jsonify(payload)
if __name__ == '__main__':
    # Start Flask's built-in server with settings from config.FLASK_CONFIG.
    # NOTE(review): if 'debug' is true, the interactive Werkzeug debugger is
    # exposed. Keep it off for any non-local host.
    flask_cfg = config.FLASK_CONFIG
    app.run(
        debug=flask_cfg['debug'],
        host=flask_cfg['host'],
        port=flask_cfg['port'],
    )