from vllm import LLM, SamplingParams
import torch
# AWQで量子化されたモデルを指定
model_name = "llm-jp/llm-jp-3.1-1.8b-instruct4"

llm = LLM(model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=1,
            dtype=torch.bfloat16,
            enforce_eager=True,
            enable_prefix_caching=True
            ) # 必要に応じてGPU数を指定

# 推論用のサンプリングパラメータを設定
sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=100)

# プロンプトを定義
prompts = [
    "vLLMの量子化について教えてください。",
    "日本の首都はどこですか？",
]

# 推論を実行
outputs = llm.generate(prompts, sampling_params)

# 結果を出力
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}\n")
