49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
import os
import sys

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Load the model and its processor (replace with your own model path).
model_path = "./qwen2-vl-2b"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    device_map="auto",  # let the library place weights on the available CPU/GPU
    torch_dtype=torch.float16,  # half precision to speed up inference
)
processor = AutoProcessor.from_pretrained(model_path)


def generate_reply(model, processor, image, question_text: str,
                   max_new_tokens: int) -> str:
    """Ask the vision-language model one question about one image.

    Args:
        model: The loaded generation model; must expose ``generate`` and
            ``device``.
        processor: The matching processor; must expose
            ``apply_chat_template``, ``__call__`` and ``batch_decode``.
        image: The PIL image the question refers to.
        question_text: The user's instruction or question.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply only, without the echoed prompt.
    """
    # The prompt has to go through the chat template: that is what inserts
    # the image placeholder tokens the model matches against the visual
    # features. A bare text string carries no such placeholder.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question_text},
            ],
        }
    ]
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    model_inputs = processor(
        text=[chat_text], images=[image], return_tensors="pt"
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; drop the prompt tokens so
    # only the newly generated reply is decoded.
    prompt_length = model_inputs["input_ids"].shape[1]
    reply_ids = output_ids[:, prompt_length:]
    return processor.batch_decode(reply_ids, skip_special_tokens=True)[0]


# Load a local image (replace with your own image path).
image_path = "./example.jpg"
if not os.path.exists(image_path):
    raise FileNotFoundError(f"Image file {image_path} not found!")

# convert() returns an independent in-memory copy, so the file handle can be
# closed as soon as the conversion is done.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# Example task 1: image captioning.
prompt = "Describe the content of this image in detail."
description = generate_reply(model, processor, image, prompt, max_new_tokens=100)
print("【图像描述】:", description)

# Example task 2: visual question answering (VQA).
question = "What is the main object in the image?"
answer = generate_reply(model, processor, image, question, max_new_tokens=50)
print("【问题回答】:", answer)