This commit is contained in:
parent
5c05617e79
commit
2b4836ba70
2
.env
2
.env
|
|
@ -11,3 +11,5 @@ EMBED_DIM=768
|
|||
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_COLLECTION=text_chunks
|
||||
GEMINI_API_KEY = "AIzaSyDWqNUBKhaZjbFI8CW52_hKr46JtWABkGU"
|
||||
GEMINI_MODEL = "models/gemini-2.0-flash-001"
|
||||
Binary file not shown.
|
|
@ -1,66 +1,24 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM Client — tương tác với vLLM server
|
||||
-------------------------------------
|
||||
- Dùng HTTP endpoint (OpenAI-compatible) của vLLM.
|
||||
- Nhận prompt đã được build từ PromptBuilder.
|
||||
- Gọi model sinh câu trả lời.
|
||||
"""
|
||||
import os
|
||||
import google.generativeai as genai
|
||||
|
||||
import requests
|
||||
from typing import Optional, Dict, Any
|
||||
from src.core.config import VLLM_URL, VLLM_MODEL, LLM_TEMPERATURE, LLM_MAX_TOKENS
|
||||
from src.core.config import GEMINI_API_KEY, GEMINI_MODEL
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
Gọi model vLLM theo giao thức OpenAI-compatible.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = VLLM_URL,
|
||||
model: str = VLLM_MODEL,
|
||||
temperature: float = LLM_TEMPERATURE,
|
||||
max_tokens: int = LLM_MAX_TOKENS
|
||||
):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
def __init__(self, api_key: str = GEMINI_API_KEY, model: str = GEMINI_MODEL):
|
||||
if not api_key:
|
||||
raise ValueError("Thiếu GEMINI_API_KEY trong config hoặc .env")
|
||||
|
||||
def generate(self, prompt: str, system_prompt: Optional[str] = None) -> str:
|
||||
genai.configure(api_key=api_key)
|
||||
self.model = genai.GenerativeModel(model)
|
||||
|
||||
def generate(self, prompt: str) -> str:
|
||||
"""
|
||||
Gửi prompt tới vLLM API và trả về câu trả lời.
|
||||
Gửi prompt vào Gemini và trả về nội dung text.
|
||||
"""
|
||||
url = f"{self.base_url}/v1/chat/completions"
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"].strip()
|
||||
response = self.model.generate_content(prompt)
|
||||
return response.text.strip() if response.text else ""
|
||||
except Exception as e:
|
||||
print(f"[LLMClient] ❌ Error calling vLLM: {e}")
|
||||
return f"[Error] {e}"
|
||||
|
||||
|
||||
# ---- Test nhanh ----
|
||||
if __name__ == "__main__":
|
||||
client = LLMClient()
|
||||
prompt = "Viết đoạn mô tả ngắn về hành tinh Sao Hỏa."
|
||||
answer = client.generate(prompt)
|
||||
print("🪐 Kết quả từ LLM:")
|
||||
print(answer)
|
||||
print(f"[GeminiLLMClient] Lỗi khi gọi Gemini API: {e}")
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -70,14 +70,4 @@ class PromptBuilder:
|
|||
return prompt
|
||||
|
||||
|
||||
# 🔍 Kiểm tra nhanh
|
||||
if __name__ == "__main__":
|
||||
fake_docs = [
|
||||
{"file_name": "doc1.txt", "text": "Điện thoại Vivo V27 có camera 64MP.", "score": 0.87},
|
||||
{"file_name": "doc2.txt", "text": "Máy có pin 4600mAh, sạc nhanh 66W.", "score": 0.83},
|
||||
]
|
||||
|
||||
builder = PromptBuilder()
|
||||
query = "Điện thoại Vivo V27 pin bao nhiêu?"
|
||||
prompt = builder.build_prompt(query, fake_docs)
|
||||
print(prompt)
|
||||
|
|
|
|||
|
|
@ -11,59 +11,18 @@ Kết nối các module:
|
|||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from src.chatbot.llm_client import LLMClient
|
||||
from src.chatbot.retriever import Retriever
|
||||
from src.chatbot.prompt_builder import PromptBuilder
|
||||
from src.chatbot.llm_client import LLMClient
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
def __init__(self):
|
||||
# Khởi tạo các thành phần chính
|
||||
self.retriever = Retriever()
|
||||
self.prompt_builder = PromptBuilder()
|
||||
self.llm_client = LLMClient()
|
||||
self.llm = LLMClient()
|
||||
|
||||
def query(self, user_query: str, top_k: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Xử lý một truy vấn RAG hoàn chỉnh:
|
||||
- Lấy top_k văn bản liên quan
|
||||
- Xây prompt hoàn chỉnh
|
||||
- Gọi LLM sinh câu trả lời
|
||||
"""
|
||||
print(f"\n[🔍] Nhận truy vấn: {user_query}")
|
||||
|
||||
# 1️⃣ Retrieve context từ Qdrant
|
||||
context_docs = self.retriever.search(user_query, top_k=top_k)
|
||||
if not context_docs:
|
||||
return {"answer": "❌ Không tìm thấy thông tin phù hợp trong cơ sở tri thức."}
|
||||
|
||||
print(f"[📚] Lấy được {len(context_docs)} đoạn văn liên quan.")
|
||||
|
||||
# 2️⃣ Build prompt từ user_query + context
|
||||
prompt = self.prompt_builder.build_prompt(user_query, context_docs)
|
||||
|
||||
# 3️⃣ Gọi LLM
|
||||
answer = self.llm_client.generate(prompt)
|
||||
|
||||
# 4️⃣ Trả về kết quả tổng hợp
|
||||
return {
|
||||
"query": user_query,
|
||||
"context_used": context_docs,
|
||||
"prompt": prompt,
|
||||
"answer": answer,
|
||||
}
|
||||
|
||||
|
||||
# ---- Test thủ công ----
|
||||
if __name__ == "__main__":
|
||||
rag = RAGPipeline()
|
||||
query = "Hành tinh Sao Mộc có những đặc điểm nổi bật nào?"
|
||||
result = rag.query(query, top_k=3)
|
||||
|
||||
print("\n=== 💬 Kết quả Chatbot ===")
|
||||
print("Câu hỏi:", result["query"])
|
||||
print("\n🧠 Context dùng:")
|
||||
for i, ctx in enumerate(result["context_used"], 1):
|
||||
print(f"{i}. {ctx['text'][:200]}...")
|
||||
print("\n🤖 Trả lời:")
|
||||
print(result["answer"])
|
||||
def run(self, user_query: str) -> str:
|
||||
docs = self.retriever.retrieve(user_query)
|
||||
prompt = self.prompt_builder.build(user_query, docs)
|
||||
answer = self.llm.generate(prompt)
|
||||
return answer
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Retriever: Tìm kiếm các đoạn văn bản liên quan trong Qdrant.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
|
@ -63,7 +59,7 @@ class Retriever:
|
|||
if __name__ == "__main__":
|
||||
# Test nhanh
|
||||
retriever = Retriever(top_k=3)
|
||||
query = "Mahola ca sĩ nam Phi là ai?"
|
||||
query = "Mahola la ai"
|
||||
results = retriever.search(query)
|
||||
for i, r in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score={r['score']:.4f}")
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue