from vllm import LLM, SamplingParams

model_name = "llm-jp/llm-jp-3.1-1.8b-instruct"

llm = LLM(model=model_name,
            quantization="bitsandbytes",
            trust_remote_code=True,
            tensor_parallel_size=1,
            dtype=torch.bfloat16,
            enforce_eager=True,
            load_format="bitsandbytes"
            )
# 推論用のサンプリングパラメータを設定
sampling_params = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=100)

# プロンプトを定義
prompts = [
    "vLLMの量子化について教えてください。",
    "日本の首都はどこですか？",
]

# 推論を実行
outputs = llm.generate(prompts, sampling_params)

# 結果を出力
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}")
    print(f"Generated: {generated_text!r}\n")
