Sontranwakumo commited on
Commit
88cc76c
·
1 Parent(s): 77d75db

init: move from github

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.env.example ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ NEO4J_URI=
2
+ NEO4J_USER=neo4j
3
+ NEO4J_PASSWORD=
4
+ OPENAI_API_KEY=
5
+ GEMINI_API_KEY=
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.index filter=lfs diff=lfs merge=lfs -text
37
+ *.db filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ weights/
7
+
8
+ ## Big data
9
+ /Data
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+
32
+ # PyInstaller
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ .hypothesis/
50
+
51
+ # Translations
52
+ *.mo
53
+ *.pot
54
+
55
+ # Django stuff:
56
+ *.log
57
+ local_settings.py
58
+ db.sqlite3
59
+
60
+ # Flask stuff:
61
+ instance/
62
+ .webassets-cache
63
+
64
+ # Scrapy stuff:
65
+ .scrapy
66
+
67
+ # Sphinx documentation
68
+ docs/_build/
69
+
70
+ # PyBuilder
71
+ target/
72
+
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+
76
+ # pyenv
77
+ .python-version
78
+
79
+ # celery beat schedule file
80
+ celerybeat-schedule
81
+
82
+ # SageMath parsed files
83
+ *.sage.py
84
+
85
+ # Environments
86
+ .env
87
+ .venv
88
+ env/
89
+ venv/
90
+ ENV/
91
+ env.bak/
92
+ venv.bak/
93
+
94
+ # Spyder project settings
95
+ .spyderproject
96
+ .spyproject
97
+
98
+ # Rope project settings
99
+ .ropeproject
100
+
101
+ # mkdocs documentation
102
+ /site
103
+
104
+ # mypy
105
+ .mypy_cache/
106
+
107
+ # IDE specific files
108
+ .idea/
109
+ .vscode/
110
+ *.swp
111
+ *.swo
112
+
113
+ # FastAPI specific
114
+ .pytest_cache/
README.md CHANGED
@@ -8,4 +8,131 @@ pinned: false
8
  short_description: Crop diagnosis module
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  short_description: Crop diagnosis module
9
  ---
10
 
11
+
12
+ # Crop Diagnosis Knowledge Graph Module
13
+
14
+ A powerful tool for querying crop diagnosis knowledge graphs using LangChain and Neo4j. This module provides an API interface to interact with a knowledge graph containing agricultural and crop disease information.
15
+
16
+ ## Features
17
+
18
+ - Natural language querying of crop diagnosis knowledge graph
19
+ - Integration with LangChain for intelligent query processing
20
+ - Neo4j database backend for efficient graph operations
21
+ - RESTful API interface
22
+ - Environment-based configuration
23
+
24
+ ## Prerequisites
25
+
26
+ - Python 3.8+
27
+ - Neo4j Database (version 5.x)
28
+ - OpenAI API key (for LangChain integration)
29
+
30
+ ## Installation
31
+
32
+ 1. Clone the repository:
33
+ ```bash
34
+ git clone [repository-url]
35
+ cd crop-diag-module
36
+ ```
37
+
38
+ 2. Create and activate a virtual environment:
39
+ ```bash
40
+ python -m venv venv
41
+ source venv/bin/activate # On Windows: venv\Scripts\activate
42
+ ```
43
+
44
+ 3. Install dependencies:
45
+ ```bash
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ ## Configuration
50
+
51
+ 1. Create a `.env` file in the project root:
52
+ ```bash
53
+ cp .env.example .env
54
+ ```
55
+
56
+ 2. Edit the `.env` file with your configuration:
57
+ ```env
58
+ # Neo4j Configuration
59
+ NEO4J_URI=bolt://localhost:7687
60
+ NEO4J_USER=neo4j
61
+ NEO4J_PASSWORD=your_password
62
+
63
+ # API Configuration
64
+ API_HOST=0.0.0.0
65
+ API_PORT=8000
66
+ DEBUG=True
67
+
68
+ # LangChain Configuration
69
+ OPENAI_API_KEY=your_openai_api_key
70
+ ```
71
+
72
+ Replace the following values:
73
+ - `NEO4J_URI`: Your Neo4j database URI
74
+ - `NEO4J_USER`: Neo4j username
75
+ - `NEO4J_PASSWORD`: Neo4j password
76
+ - `OPENAI_API_KEY`: Your OpenAI API key
77
+
78
+ ## Running the Application
79
+
80
+ 1. Start the FastAPI server:
81
+ ```bash
82
+ uvicorn app.main:app --reload
83
+ ```
84
+
85
+ 2. Access the API documentation:
86
+ - Swagger UI: http://localhost:8000/docs
87
+ - ReDoc: http://localhost:8000/redoc
88
+
89
+ ## API Usage
90
+
91
+ ### Query the Knowledge Graph
92
+
93
+ ```bash
94
+ curl -X POST "http://localhost:8000/api/query" \
95
+ -H "Content-Type: application/json" \
96
+ -d '{"question": "What are the symptoms of rice blast disease?"}'
97
+ ```
98
+
99
+ ## Project Structure
100
+
101
+ ```
102
+ crop-diag-module/
103
+ ├── app/
104
+ │ ├── api/ # API routes and endpoints
105
+ │ ├── core/ # Core functionality
106
+ │ ├── models/ # Data models
107
+ │ └── utils/ # Utility functions
108
+ ├── KG/ # Knowledge Graph data
109
+ ├── tests/ # Test cases
110
+ ├── requirements.txt # Project dependencies
111
+ └── .env # Environment configuration
112
+ ```
113
+
114
+ ## Development
115
+
116
+ ### Running Tests
117
+
118
+ ```bash
119
+ pytest tests/
120
+ ```
121
+
122
+ ### Code Style
123
+
124
+ This project follows PEP 8 style guidelines. Use the following command to check code style:
125
+
126
+ ```bash
127
+ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
128
+
129
+ uvicorn app.main:app --reload
130
+ ```
131
+
132
+ ## License
133
+
134
+ [Add your license information here]
135
+
136
+ ## Contributing
137
+
138
+ [Add contribution guidelines here]
app/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app/__init__.py ADDED
File without changes
app/api/__init__.py ADDED
File without changes
app/api/dto/kg_query.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel
3
+
4
+ from app.core.type import Node
5
+
6
+ class QueryContext(BaseModel):
7
+ crop_id: Optional[str] = None
8
+ nodes: Optional[List[Node]] = None
9
+ predicted_labels: Optional[List[str]] = None
10
+
11
+ class PredictedLabel(BaseModel):
12
+ crop_name: str
13
+ label: str
14
+ confidence: float
15
+
16
+ class KGQueryRequest(BaseModel):
17
+ context: Optional[QueryContext] = None
18
+ crop_id: Optional[str] = None
19
+ additional_info: Optional[str] = None
20
+
21
+ class KGQueryResponse(BaseModel):
22
+ answer: str
23
+ sources: List[str]
app/api/routes.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional
4
+ from app.api.dto.kg_query import KGQueryRequest, KGQueryResponse, PredictedLabel, QueryContext
5
+ from app.core.dependencies import get_all_models, get_clip_model, get_data_mapper
6
+ from app.core.type import Node
7
+ from app.models.crop_clip import CLIPModule
8
+ from app.services.predict import PredictService, get_predict_service
9
+ from app.utils.extract_entity import extract_entities
10
+
11
+ router = APIRouter()
12
+
13
+ class QueryRequest(BaseModel):
14
+ question: str
15
+ context: Optional[List[str]] = None
16
+
17
+ class QueryResponse(BaseModel):
18
+ answer: str
19
+ sources: List[str]
20
+
21
+ @router.post("/analyze")
22
+ async def analyze(
23
+ image: UploadFile = File(None),
24
+ predict_service: PredictService = Depends(get_predict_service)
25
+ ):
26
+ predicted_label = predict_service.predict_image(image)
27
+
28
+ return {
29
+ "crop_id": predicted_label[0].crop_name,
30
+ "predicted_labels": predicted_label,
31
+ "nodes": []
32
+ }
33
+
34
+
35
+ @router.post("/kg-query")
36
+ async def query_kg(
37
+ request: KGQueryRequest,
38
+ predict_service: PredictService = Depends(get_predict_service),
39
+ ):
40
+ return predict_service.retrieve_kg(request)
app/core/__init__.py ADDED
File without changes
app/core/config.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+ from functools import lru_cache
3
+
4
+ class Settings(BaseSettings):
5
+ # Neo4j Configuration
6
+ neo4j_uri: str
7
+ neo4j_user: str
8
+ neo4j_password: str
9
+ neo4j_database: str = "neo4j"
10
+
11
+ # API Configuration
12
+ api_host: str = "0.0.0.0"
13
+ api_port: int = 8000
14
+ debug: bool = True
15
+
16
+ openai_api_key: str
17
+ gemini_api_key: str
18
+
19
+ load_clip_model: bool = True
20
+ load_gemini_model: bool = True
21
+ load_data_mapper: bool = True
22
+ load_knowledge_graph: bool = True
23
+
24
+ class Config:
25
+ env_file = ".env"
26
+
27
+ @lru_cache()
28
+ def get_settings() -> Settings:
29
+ return Settings()
app/core/dependencies.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, Request
2
+ from app.models.crop_clip import CLIPModule
3
+ from app.utils.data_mapping import DataMapping
4
+ from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
5
+
6
+ def get_clip_model(request: Request) -> CLIPModule:
7
+ """Lấy CLIP model từ app.state"""
8
+ return request.app.state.model_loader.clip_model
9
+
10
+ def get_data_mapper(request: Request) -> DataMapping:
11
+ """Lấy DataMapper từ app.state"""
12
+ return request.app.state.model_loader.data_mapper
13
+
14
+ def get_knowledge_graph(request: Request) -> KnowledgeGraphUtils:
15
+ """Lấy KnowledgeGraph từ app.state"""
16
+ return request.app.state.model_loader.knowledge_graph
17
+
18
+ def get_all_models(
19
+ clip_model: CLIPModule = Depends(get_clip_model),
20
+ data_mapper: DataMapping = Depends(get_data_mapper),
21
+ knowledge_graph: KnowledgeGraphUtils = Depends(get_knowledge_graph)
22
+ ):
23
+ """Lấy tất cả các model từ app.state"""
24
+ return {
25
+ "clip_model": clip_model,
26
+ "data_mapper": data_mapper,
27
+ "knowledge_graph": knowledge_graph
28
+ }
app/core/type.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional
3
+ import json
4
+
5
+ class Node(BaseModel):
6
+ id: str
7
+ label: str
8
+ name: str
9
+ properties: dict
10
+ score: Optional[float] = None
11
+
12
+ @staticmethod
13
+ def map_json_to_node(json_data: dict, label: str = None) -> 'Node':
14
+ node_data = {
15
+ "name": json_data.pop("name") if "name" in json_data else json_data["id"],
16
+ "id": json_data.pop("id"),
17
+ "label": label if label else json_data.pop("label"),
18
+ "properties": json_data
19
+ }
20
+ return Node(**node_data)
21
+
22
+ @staticmethod
23
+ def data_row_to_node(data_row: list[str], score = None) -> 'Node':
24
+ return Node(
25
+ id=data_row[1],
26
+ name=data_row[2],
27
+ label=data_row[3],
28
+ properties=json.loads(data_row[4]),
29
+ score=score
30
+ )
31
+
32
+ class Relationship(BaseModel):
33
+ source: str
34
+ target: str
35
+ type: str
36
+ properties: dict
37
+
38
+ class KnowledgeGraph(BaseModel):
39
+ nodes: List[Node]
40
+ relationships: List[Relationship]
41
+
42
+ class GraphQuery(BaseModel):
43
+ key: str
44
+ cypher: str
45
+ parameters: Optional[dict] = None
46
+ description: Optional[str] = None
app/data/faiss_index.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1af6778f8fb10ee5a2d44f2815bc288c0ebac355c74a2b144d95720af5b8171
3
+ size 1188909
app/data/image_faiss_index.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1f42c9ecb4da4a64cf34d90dc887564accc467233a8bfde986e6fa02b788b10
3
+ size 41824301
app/data/vector_embeddings.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b9b2d7cab4f196a4a739830702944fedfe10f123e2c8ca3f4285857f56e7996
3
+ size 5394432
app/main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from contextlib import asynccontextmanager
3
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
4
+ from app.core.config import get_settings
5
+ from app.api.routes import router as api_router
6
+ from app.models.crop_clip import CLIPModule
7
+ from app.models.gemini_caller import GeminiGenerator
8
+ from app.utils.data_mapping import DataMapping, SingletonModel
9
+ from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
10
+ import asyncio
11
+ from concurrent.futures import ThreadPoolExecutor
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+ settings = get_settings()
16
+
17
+ class ModelLoader:
18
+ def __init__(self):
19
+ self.clip_model = None
20
+ self.gemini_model = None
21
+ self.sentence_transformer = None
22
+ self.neo4j_connection = None
23
+
24
+ def load_models(self):
25
+ try:
26
+ if settings.load_clip_model:
27
+ logger.info("Loading CLIP model...")
28
+ self.clip_model = CLIPModule()
29
+ logger.info("CLIP model loaded successfully")
30
+
31
+ if settings.load_gemini_model:
32
+ logger.info("Loading Gemini model...")
33
+ self.gemini_model = GeminiGenerator()
34
+ logger.info("Gemini model loaded successfully")
35
+
36
+ if settings.load_data_mapper:
37
+ logger.info("Loading DataMapper model...")
38
+ self.data_mapper = DataMapping()
39
+ logger.info("DataMapper model loaded successfully")
40
+
41
+ if settings.load_knowledge_graph:
42
+ logger.info("Connecting to Knowledge Graph...")
43
+ self.knowledge_graph = KnowledgeGraphUtils()
44
+ logger.info("Knowledge Graph connection established")
45
+ except Exception as e:
46
+ logger.error(f"Failed to load models: {e}")
47
+ raise
48
+
49
+ def close(self):
50
+ if self.neo4j_connection:
51
+ logger.info("Closing Neo4j connection...")
52
+ self.neo4j_connection.close()
53
+ self.clip_model = None
54
+ self.gemini_model = None
55
+ self.sentence_transformer = None
56
+ logger.info("Models released")
57
+
58
+ # Lifespan event handler
59
+ @asynccontextmanager
60
+ async def lifespan(app: FastAPI):
61
+ loop = asyncio.get_event_loop()
62
+ with ThreadPoolExecutor() as pool:
63
+ await loop.run_in_executor(pool, app.state.model_loader.load_models)
64
+ logger.info("Application startup complete")
65
+ yield
66
+ app.state.model_loader.close()
67
+ logger.info("Application shutdown complete")
68
+
69
+ app = FastAPI(
70
+ title="Crop Diagnosis Knowledge Graph API",
71
+ description="API for querying crop diagnosis knowledge graph using LangChain",
72
+ version="1.0.0",
73
+ debug=settings.debug,
74
+ lifespan=lifespan
75
+ )
76
+
77
+ app.state.model_loader = ModelLoader()
78
+
79
+ app.include_router(api_router, prefix="/api")
80
+
81
+ @app.get("/")
82
+ async def root():
83
+ return {"message": "Welcome to Crop Diagnosis Knowledge Graph API"}
84
+
app/models/__init__.py ADDED
File without changes
app/models/crop_clip.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch.nn as nn
3
+ import torch
4
+ from torchvision import transforms
5
+ import clip
6
+ from PIL import Image
7
+ import os
8
+
9
+ from app.api.dto.kg_query import PredictedLabel
10
+
11
+ CLASS_NAMES = ['benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
12
+ 'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
13
+ 'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo', 'benhSocLaNgo',
14
+ 'benhDomLaNgo', 'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc']
15
+
16
+ CROP_NAMES = ['caChua', 'caChua', 'caChua', 'caChua', 'san', 'san', 'san', 'san',
17
+ 'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'luaNuoc', 'luaNuoc', 'luaNuoc']
18
+
19
+ WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'clip_finetuned.pth')
20
+
21
+ class CLIPFineTuner(nn.Module):
22
+ def __init__(self, model, num_classes):
23
+ super(CLIPFineTuner, self).__init__()
24
+ self.model = model
25
+ self.classifier = nn.Linear(model.visual.output_dim, num_classes)
26
+
27
+ def forward(self, x):
28
+ with torch.no_grad():
29
+ features = self.model.encode_image(x).float() # Convert to float32
30
+ return self.classifier(features)
31
+
32
+ class CLIPModule:
33
+ def __init__(self):
34
+ model, preprocess = clip.load("ViT-B/32", jit=False)
35
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ self.model = CLIPFineTuner(model, 17)
37
+ self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
38
+ self.model.to(self.device)
39
+ self.model.eval()
40
+ self.classes = CLASS_NAMES
41
+ self.transform = preprocess
42
+
43
+ def predict_image(self, image: Image.Image):
44
+ output = self.__predict(image)
45
+ probabilities = torch.nn.functional.softmax(output, dim=1)[0]
46
+
47
+ predictions: List[PredictedLabel] = []
48
+ for idx, prob in enumerate(probabilities):
49
+ predictions.append(PredictedLabel(
50
+ crop_name=CROP_NAMES[idx],
51
+ label=self.classes[idx],
52
+ confidence=float(prob)
53
+ ))
54
+
55
+ # Sắp xếp giảm dần theo xác suất
56
+ predictions.sort(key=lambda x: x.confidence, reverse=True)
57
+
58
+ return predictions
59
+
60
+ def __predict(self, image_input):
61
+ """
62
+ Dự đoán nhãn cho một ảnh.
63
+
64
+ Args:
65
+ image_input: Đường dẫn file ảnh (str) hoặc đối tượng PIL.Image
66
+ device: Thiết bị chạy mô hình ('cuda' hoặc 'cpu').
67
+
68
+ Returns:
69
+ str: Nhãn dự đoán (e.g., "cassava_leaf beetle").
70
+ """
71
+ try:
72
+ image = self.__handle_image(image_input)
73
+ image_tensor = self.transform(image)
74
+ except ValueError as e:
75
+ raise e
76
+ except Exception as e:
77
+ raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}")
78
+
79
+ if image_tensor.dim() == 3:
80
+ image_tensor = image_tensor.unsqueeze(0)
81
+
82
+ print(image_tensor.shape)
83
+
84
+ image_tensor = image_tensor.to(self.device)
85
+
86
+ with torch.no_grad():
87
+ output = self.model(image_tensor)
88
+
89
+ return output ## an array of 17 values, no softmax
90
+
91
+ def __handle_image(self, image_input):
92
+ if isinstance(image_input, str):
93
+ image = Image.open(image_input).convert('RGB')
94
+ elif isinstance(image_input, Image.Image):
95
+ image = image_input
96
+ else:
97
+ raise ValueError("Invalid image input")
98
+ return image
app/models/gemini_caller.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import google.generativeai as genai
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+ class GeminiGenerator:
8
+ def __init__(self, model_name="gemini-2.0-flash", temperature=0):
9
+ self.key = os.environ.get("GEMINI_API_KEY")
10
+ genai.configure(api_key=self.key)
11
+
12
+ # Cấu hình generation config
13
+ self.generation_config = {
14
+ "temperature": temperature
15
+ }
16
+
17
+ # Hệ thống prompt mặc định
18
+ self.system_prompt = "Bạn là một trợ lý AI hữu ích cho các dự án IT về cây trồng và bệnh cây trồng. Bạn có khả năng trích xuất thông tin từ văn bản được cung cấp và trả dữ liệu bằng tiếng Việt theo yêu cầu."
19
+
20
+ # Khởi tạo model
21
+ self.model = genai.GenerativeModel(
22
+ model_name=model_name,
23
+ generation_config=self.generation_config,
24
+ system_instruction=self.system_prompt
25
+ )
26
+
27
+ def generate(self, prompt="Hello, world!", system_prompt=None):
28
+ # Sử dụng system prompt tùy chỉnh nếu được cung cấp
29
+ if system_prompt:
30
+ model = genai.GenerativeModel(
31
+ model_name=self.model.model_name,
32
+ generation_config=self.generation_config,
33
+ system_instruction=system_prompt
34
+ )
35
+ response = model.generate_content(prompt)
36
+ else:
37
+ response = self.model.generate_content(prompt)
38
+ return response
39
+
40
+ if __name__ == "__main__":
41
+ generator = GeminiGenerator()
app/models/knowledge_graph.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from fastapi import Depends
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ from app.core.config import Settings, get_settings
7
+ from utils.data_mapping import DataMapping
8
+ from utils.extract_entity import extract_entities
9
+ from core.type import Node
10
+ from neo4j import GraphDatabase
11
+ from utils.constant import NEO4J_LABELS, NEO4J_RELATIONS
12
+
13
+ NEO4J_URI = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
14
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
15
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
16
+ NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")
17
+
18
+ class Neo4jConnection:
19
+ def __init__(self):
20
+ """Khởi tạo kết nối tới Neo4j"""
21
+ self.uri = NEO4J_URI
22
+ self.user = NEO4J_USER
23
+ self.password = NEO4J_PASSWORD
24
+ self.database = NEO4J_DATABASE
25
+
26
+ self.driver = GraphDatabase.driver(
27
+ self.uri,
28
+ auth=(self.user, self.password),
29
+ database=self.database
30
+ )
31
+ self.entity_types = []
32
+ self.relations = []
33
+
34
+ with self.driver.session() as session:
35
+ result = session.run("CALL db.info()")
36
+ self.database_info = result.single().data()
37
+ self.entity_types = NEO4J_LABELS
38
+ self.relations = NEO4J_RELATIONS
39
+
40
+ def get_database_info(self):
41
+ """Trả về thông tin về database đang kết nối"""
42
+ return self.database_info
43
+
44
+ def close(self):
45
+ """Đóng kết nối tới Neo4j"""
46
+ if self.driver is not None:
47
+ self.driver.close()
48
+
49
+ def execute_query(self, query, parameters=None):
50
+ """Thực thi một truy vấn Cypher bất kỳ"""
51
+ with self.driver.session() as session:
52
+ result = session.run(query, parameters)
53
+ return [record for record in result]
54
+
55
+ class KnowledgeGraphUtils:
56
+ def get_disease_from_env_factors(self, crop_id: str, params: list[Node]):
57
+ envFactors = [param.id for param in params if param.label == "EnvironmentalFactor"]
58
+ query = f"""
59
+ MATCH (c:Crop {{id: "{crop_id}"}})
60
+ WITH c
61
+ MATCH (d:Disease)-[:AFFECTS]-(c)
62
+ OPTIONAL MATCH (ef:EnvironmentalFactor)-[:FAVORS]-(d)
63
+ WHERE ef.id IN {envFactors}
64
+ OPTIONAL MATCH (ef2:EnvironmentalFactor)-[:FAVORS]-(cause:Cause)-[:CAUSES|AFFECTS]-(d)
65
+ WHERE ef2.id IN {envFactors}
66
+ WITH d, COLLECT(DISTINCT ef.id) AS direct_env, COLLECT(DISTINCT ef2.id) AS indirect_env
67
+ WHERE SIZE(direct_env) > 0 OR SIZE(indirect_env) > 0
68
+ RETURN DISTINCT d, direct_env, indirect_env
69
+ """
70
+ kg = Neo4jConnection()
71
+ result = kg.execute_query(query)
72
+ print(result)
73
+ final_result = []
74
+ for record in result:
75
+ record_dict = dict(record)
76
+ disease = Node.map_json_to_node(dict(record_dict["d"]), "Disease")
77
+ env_ids = list(record_dict["direct_env"]) + list(record_dict["indirect_env"])
78
+ print(env_ids)
79
+ score = 0
80
+ for env_id in env_ids:
81
+ for param in params:
82
+ if param.id == env_id:
83
+ score = max(score, param.score)
84
+ disease.score = score
85
+ final_result.append({
86
+ "disease": disease,
87
+ "env_ids": env_ids
88
+ })
89
+ final_result.sort(key=lambda x: x["disease"].score, reverse=True)
90
+
91
+ return final_result
92
+
93
+ def get_disease_from_symptoms(self, crop_id: str, params: list[Node]) -> list:
94
+ symptoms = [param.id for param in params if param.label == "Symptom"]
95
+ query = f"""
96
+ MATCH (c:Crop {{id: "{crop_id}"}})
97
+ WITH c
98
+ MATCH (d:Disease)-[:AFFECTS]-(c)
99
+ OPTIONAL MATCH (sym1:Symptom)-[:HAS_SYMPTOM]-(d)
100
+ WHERE sym1.id IN {symptoms}
101
+ OPTIONAL MATCH (sym2:Symptom)-[:HAS_SYMPTOM|LOCATED_ON]-(p:PlantPart)-[:CONTAINS]-(d)
102
+ WHERE sym2.id IN {symptoms}
103
+ WITH d, p, c, sym1, sym2, COLLECT(DISTINCT sym1.id) AS direct_env, COLLECT(DISTINCT sym2.id) AS indirect_env
104
+ WHERE SIZE(direct_env) > 0 OR SIZE(indirect_env) > 0
105
+ RETURN d, c, p, sym1, sym2
106
+ """
107
+ kg = Neo4jConnection()
108
+ result = kg.execute_query(query)
109
+ final_result = []
110
+ for record in result:
111
+ record_dict = dict(record)
112
+ disease = Node.map_json_to_node(dict(record_dict["d"]), "Disease")
113
+ symptom_ids = list(record_dict["sym1"]) + list(record_dict["sym2"])
114
+ score = 0
115
+ for symptom_id in symptom_ids:
116
+ for param in params:
117
+ if param.id == symptom_id:
118
+ score = max(score, param.score)
119
+ disease.score = score
120
+ final_result.append({
121
+ "disease": disease,
122
+ "symptom_ids": symptom_ids
123
+ })
124
+ final_result.sort(key=lambda x: x["disease"].score, reverse=True)
125
+
126
+ return final_result
app/services/__init__.py ADDED
File without changes
app/services/predict.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Depends, UploadFile
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+
6
+ from app.api.dto.kg_query import KGQueryRequest, QueryContext
7
+ from app.core.dependencies import get_all_models
8
+ from app.core.type import Node
9
+ from app.models.crop_clip import CLIPModule
10
+ from app.models.knowledge_graph import KnowledgeGraphUtils
11
+ from app.utils.data_mapping import DataMapping
12
+ from app.utils.extract_entity import extract_entities
13
+
14
+ class PredictService:
15
+ def __init__(self, models):
16
+ self.models = models
17
+
18
+ def predict_image(self, image: UploadFile):
19
+ clip_model: CLIPModule = self.models["clip_model"]
20
+ image_content = image.file.read()
21
+ pil_image = Image.open(Image.io.BytesIO(image_content)).convert('RGB')
22
+ return clip_model.predict_image(pil_image)
23
+
24
+ def retrieve_kg(self, request: KGQueryRequest):
25
+ try:
26
+ kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
27
+ if not request.context:
28
+ request.context = QueryContext()
29
+ if request.crop_id:
30
+ request.context.crop_id = request.crop_id
31
+ if request.additional_info:
32
+ request.context.nodes = self.__get_nodes_from_additional_info(request.additional_info, self.models["data_mapper"])
33
+ env_result = kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
34
+ symptom_result = kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
35
+ context = request.context
36
+ context.nodes.extend([env_result["disease"] for env_result in env_result])
37
+ context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_result])
38
+ context.nodes.sort(key=lambda x: x.score, reverse=True)
39
+ return {
40
+ "context": context,
41
+ "env_result": env_result,
42
+ "symptom_result": symptom_result
43
+ }
44
+
45
+ except Exception as e:
46
+ print(e)
47
+ raise e
48
+
49
+ def __get_nodes_from_additional_info(self, additional_info: str, data_mapper: DataMapping):
50
+ entities = extract_entities(additional_info)
51
+ top_results: list[Node] = []
52
+ for entity in entities:
53
+ top_result = data_mapper.get_top_result_by_text(entity.name, 3)
54
+ print([result.name for result in top_result])
55
+ for result in top_result:
56
+ top_results.append(result)
57
+ return top_results
58
+
59
+ def get_predict_service(models = Depends(get_all_models)):
60
+ return PredictService(models)
app/utils/constant.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ NEO4J_LABELS =['Disease', 'Symptom', 'Treatment', 'Cause', 'Effect', 'Prevention', 'EnvironmentalFactor', 'Stage', 'Crop', 'CropType', 'PlantPart', 'SoilType', 'DiagnosisMethod']
2
+ NEO4J_RELATIONS = ['CAUSES', 'HAS_SYMPTOM', 'PRODUCES', 'FAVORS', 'IS_TREATED_BY', 'PREVENTS', 'OCCURS_AT', 'BELONGS_TO', 'CONTAINS', 'LOCATED_ON', 'AFFECTS', 'IS_APPLIED_TO']
app/utils/data_mapping.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from sentence_transformers import SentenceTransformer
3
+ import faiss
4
+ from pyvi.ViTokenizer import tokenize
5
+ import sqlite3
6
+ from app.core.type import Node
7
+
8
+ FAISS_INDEX_PATH = 'app/data/faiss_index.index'
9
+ VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
10
+
11
+ class SingletonModel:
12
+ _instance = None
13
+
14
+ def __new__(cls):
15
+ if cls._instance is None:
16
+ cls._instance = super(SingletonModel, cls).__new__(cls)
17
+ cls._instance.model = SentenceTransformer('dangvantuan/vietnamese-embedding')
18
+ return cls._instance
19
+
20
+ class DataMapping:
21
+ def __init__(self):
22
+ try:
23
+ self.model: SentenceTransformer = SingletonModel().model
24
+ self.index: faiss.IndexFlatL2 = self.__load_faiss_index()
25
+ self.conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
26
+ self.cursor = self.conn.cursor()
27
+ except Exception as e:
28
+ print(f"Error while initializing DataMapping: {e}")
29
+ raise
30
+
31
+ def __del__(self):
32
+ self.cursor.close()
33
+ self.conn.close()
34
+
35
+ def __load_faiss_index(self, index_file = FAISS_INDEX_PATH):
36
+ if os.path.exists(index_file):
37
+ index = faiss.read_index(index_file)
38
+ print(f"Đã nạp FAISS index từ {index_file}")
39
+ return index
40
+ return None
41
+
42
+ def get_top_index_by_text(self, text, top_k=1, distance_threshold=float(0.6)):
43
+ if not text or top_k < 1:
44
+ raise ValueError("Invalid input: text cannot be empty and top_k must be positive")
45
+
46
+ q_token = tokenize(text)
47
+ q_vec = self.model.encode([q_token])
48
+ faiss.normalize_L2(q_vec)
49
+ D, I = self.index.search(q_vec, top_k)
50
+ mask = D[0] >= distance_threshold
51
+ filtered_indices = I[0][mask].tolist()
52
+ distances = D[0][mask].tolist()
53
+ return filtered_indices, distances
54
+
55
+ def get_embedding_by_id(self, id):
56
+ self.cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
57
+ return self.cursor.fetchone()
58
+
59
+ def get_top_result_by_text(self, text, top_k = 1, type = None) -> list[Node]:
60
+ top_index, distances = self.get_top_index_by_text(text, top_k)
61
+ results = [self.get_embedding_by_id(int(index)) for index in top_index]
62
+ if type:
63
+ results = [result for result in results if result[3] == type]
64
+ return [Node.data_row_to_node(result, distance) for result, distance in zip(results, distances)]
65
+
app/utils/extract_entity.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from string import Template
3
+ from fastapi import Depends
4
+ from langchain_google_genai import ChatGoogleGenerativeAI
5
+ import dotenv
6
+ import os
7
+ from app.models.gemini_caller import GeminiGenerator
8
+ from app.core.type import Node
9
+ from app.utils.prompt import EXTRACT_ENTITIES_PROMPT
10
+ dotenv.load_dotenv()
11
+
12
+ def extract_entities(text: str) -> list[Node]:
13
+ try:
14
+ gemini = GeminiGenerator()
15
+ prompt = Template(EXTRACT_ENTITIES_PROMPT).substitute(ctext=text)
16
+ entities = gemini.generate(prompt)
17
+ entities = (json.loads(clean_text(entities.text)))["entities"]
18
+ return [Node.map_json_to_node(entity) for entity in entities]
19
+ except Exception as e:
20
+ print(f"Error while extract knowledge entities: {str(e)}")
21
+ return []
22
+
23
+ def clean_text(text: str):
24
+ text = text.replace("```json", "").replace("```", "")
25
+ return text
app/utils/prompt.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EXTRACT_ENTITIES_PROMPT = """
2
+ Từ mô tả bên dưới, hãy trích xuất các Thực thể được mô tả theo định dạng được chỉ định. Đảm bảo kết quả hoàn chỉnh, không thiếu thông tin.
3
+ 0. LUÔN LUÔN HOÀN TẤT KẾT QUẢ. Không gửi kết quả bị thiếu
4
+ 1. Trích xuất Thực thể (entities)
5
+ - Mỗi thực thể phải có thuộc tính `id` là chuỗi chữ và số duy nhất, định dạng **camelCase** (ví dụ: `benhDaoOn`, `laVang`). Thuộc tính `id` được sử dụng để liên kết trong mối quan hệ.
6
+ - Chỉ tạo thực thể thuộc các loại được liệt kê, không tạo loại mới.
7
+ - Đảm bảo các thuộc tính (`name`, `description`, v.v.) khớp với nội dung văn bản.
8
+
9
+ Các loại thực thể:
10
+ - **Disease**: Tình trạng cây bị hại bởi vi sinh vật, nấm, hoặc yếu tố môi trường. Đảm bảo trong ngữ cảnh đầu vào chỉ đang nói đến một bệnh duy nhất.
11
+ - `id`: Tên bệnh ở dạng camelCase, bao gồm thông tin cây trồng nếu bệnh xuất hiện trên cây cụ thể (ví dụ: `benhDomNauSan` cho bệnh đốm nâu trên sắn, `benhDomNauCaChua` cho bệnh đốm nâu trên cà chua). Không có các giới từ như "trên".
12
+ - `name`: Tên bệnh trong văn bản, bao gồm thông tin cây trồng nếu có (ví dụ: "Bệnh đốm nâu trên sắn"). Nếu không có thông tin cây trồng, sử dụng tên bệnh chung (ví dụ: "Bệnh đốm nâu").
13
+ - `description`: Mô tả tình trạng bệnh, ưu tiên đề cập cây trồng nếu có (ví dụ: "Bệnh đốm nâu trên cây sắn do nấm gây ra"). Nếu không có thông tin, dùng "Không có mô tả cụ thể".
14
+
15
+ - **Symptom**: Dấu hiệu bất thường trên cây (lá vàng, héo, đốm).
16
+ - `id`: Tên triệu chứng ở dạng camelCase (ví dụ: `laVang`).
17
+ - `name`: Tên triệu chứng trong văn bản (ví dụ: "Lá vàng").
18
+ - `description`: Mô tả triệu chứng.
19
+
20
+ - **Treatment**: Biện pháp kiểm soát bệnh/sâu hại (thuốc, sinh học).
21
+ - `id`: Tên biện pháp ở dạng camelCase, gắn với hoạt chất hoặc loại thuốc cụ thể (ví dụ: `thuocDietNamThiophanate`).
22
+ - `name`: Tên biện pháp trong văn bản, phản ánh hoạt chất hoặc loại thuốc (ví dụ: "Thuốc Diệt Nấm chứa Thiophanate").
23
+ - `method`: Cách thực hiện biện pháp. (ví dụ: "Phun thuốc lên lá")
24
+ - `activeIngredient` (tùy chọn): Tên hoạt chất chính, bao gồm nồng độ nếu có (ví dụ: "Thiophanate 0.20%"). Nếu không xác định, để trống.
25
+
26
+ - **Cause**: Tác nhân gây bệnh/sâu hại (nấm, virus, côn trùng).
27
+ - `id`: Tên tác nhân ở dạng camelCase, gắn liền với tên của tác nhân viết gọn (ví dụ: `namMHenningsii`).
28
+ - `name`: Tên tác nhân trong văn bản (ví dụ: "Nấm Mycosphaerella henningsi")
29
+ - `type`: Loại tác nhân (nấm, virus, côn trùng, vi khuẩn, ...).
30
+
31
+ - **Effect**: Tác động của bệnh/sâu hại (giảm năng suất, cây chết).
32
+ - `id`: Tên tác động ở dạng camelCase, sử dụng dạng ngắn gọn và chung nhất (ví dụ: `giamNangSuat` cho mọi trường hợp liên quan đến giảm năng suất, thay vì `nangSuatGiamDangKe`).
33
+ - `name`: Tên tác động được chuẩn hóa, sử dụng dạng chung nhất từ văn bản (ví dụ: "Giảm năng suất" thay vì "Giảm năng suất đáng kể"). Loại bỏ các từ ngữ bổ nghĩa như "đáng kể", "nghiêm trọng".
34
+ - `impact`: Mô tả ngắn gọn mức độ ảnh hưởng, ưu tiên sử dụng cụm từ chung (ví dụ: "Ảnh hưởng đến sản lượng" thay vì sao chép toàn bộ mô tả chi tiết từ văn bản).
35
+
36
+ - **Prevention**: Biện pháp ngăn ngừa bệnh/sâu hại (luân canh, giống kháng).
37
+ - `id`: Tên biện pháp ở dạng camelCase (ví dụ: `luanCanh`).
38
+ - `name`: Tên biện pháp trong văn bản.
39
+ - `method`: Cách thực hiện biện pháp.
40
+
41
+ - **EnvironmentalFactor**: Yếu tố tự nhiên ảnh hưởng cây (nhiệt độ, độ ẩm, ...).
42
+ - `id`: Tên yếu tố ở dạng camelCase (ví dụ: `doAmCao`).
43
+ - `name`: Tên yếu tố trong văn bản.
44
+ - `description`: Mô tả yếu tố.
45
+
46
+ - **Stage**: Giai đoạn phát triển của cây.
47
+ - `id`: Tên giai đoạn ở dạng camelCase (ví dụ: `giaiDoanRaHoa`).
48
+ - `start`: Thời gian bắt đầu (tháng, kiểu float).
49
+ - `end`: Thời gian kết thúc (tháng, kiểu float).
50
+
51
+ - **Crop**: Cây trồng, không được tạo ngoài danh sách: "Lúa", "Sắn", "Cà chua", "Ngô"
52
+ - `id`: Tên cây ở dạng camelCase, chỉ nằm trong danh sách: [lua,san,caChua,ngo].
53
+ - `name`: Tên cây trong văn bản (ví dụ: "Sắn").
54
+
55
+ - **CropType**: Phân loại cây (lương thực, ăn quả, công nghiệp).
56
+ - `id`: Tên loại cây ở dạng camelCase (ví dụ: `luongThuc`).
57
+ - `name`: Tên loại cây trong văn bản.
58
+
59
+ - **PlantPart**: Phần cây bị ảnh hưởng (lá, thân, rễ, quả).
60
+ - `id`: Tên phần cây ở dạng camelCase (ví dụ: `la`).
61
+ - `name`: Tên phần cây trong văn bản.
62
+
63
+ - **SoilType**: Loại đất trồng cây.
64
+ - `id`: Tên loại đất ở dạng camelCase (ví dụ: `datPhuSa`).
65
+ - `name`: Tên loại đất trong văn bản.
66
+
67
+ - **DiagnosisMethod**: Kỹ thuật xác định bệnh/sâu hại (quan sát, xét nghiệm).
68
+ - `id`: Tên kỹ thuật ở dạng camelCase (ví dụ: `quanSat`).
69
+ - `name`: Tên kỹ thuật trong văn bản.
70
+ - `technique`: Cách thực hiện kỹ thuật.
71
+
72
+ 2. Trả về kết quả dưới dạng JSON:
73
+ - Trả về JSON với một trường duy nhất là `entities`
74
+ - `entities`: Danh sách các thực thể, mỗi thực thể là một object với các thuộc tính theo loại.
75
+
76
+ Ví dụ:
77
+ ```json
78
+ {
79
+ "entities": [{"label":"Disease","id":string,"name":string,"description":string}]
80
+ }
81
+ ```
82
+
83
+ Ngữ cảnh:
84
+ $ctext
85
+ """
environment.yml ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: graduated2
2
+ channels:
3
+ - defaults
4
+ - https://repo.anaconda.com/pkgs/main
5
+ - https://repo.anaconda.com/pkgs/r
6
+ dependencies:
7
+ - ca-certificates=2025.2.25=hecd8cb5_0
8
+ - libcxx=14.0.6=h9765a3e_0
9
+ - libffi=3.4.4=hecd8cb5_1
10
+ - ncurses=6.4=hcec6c5f_0
11
+ - openssl=3.0.16=h184c1cd_0
12
+ - pip=25.1=pyhc872135_2
13
+ - python=3.9.21=hce00570_1
14
+ - readline=8.2=hca72f7f_0
15
+ - setuptools=78.1.1=py39hecd8cb5_0
16
+ - sqlite=3.45.3=h6c40b1e_0
17
+ - tk=8.6.14=h4d00af3_0
18
+ - tzdata=2025b=h04d1e81_0
19
+ - wheel=0.45.1=py39hecd8cb5_0
20
+ - xz=5.6.4=h46256e1_1
21
+ - zlib=1.2.13=h4b97444_1
22
+ - pip:
23
+ - aiohappyeyeballs==2.6.1
24
+ - aiohttp==3.11.18
25
+ - aiosignal==1.3.2
26
+ - annotated-types==0.7.0
27
+ - anyio==4.9.0
28
+ - asgiref==3.8.1
29
+ - async-timeout==4.0.3
30
+ - attrs==25.3.0
31
+ - backoff==2.2.1
32
+ - bcrypt==4.3.0
33
+ - build==1.2.2.post1
34
+ - cachetools==5.5.2
35
+ - certifi==2025.4.26
36
+ - charset-normalizer==3.4.2
37
+ - chromadb==1.0.8
38
+ - click==8.1.8
39
+ - coloredlogs==15.0.1
40
+ - dataclasses-json==0.6.7
41
+ - deprecated==1.2.18
42
+ - distro==1.9.0
43
+ - durationpy==0.9
44
+ - exceptiongroup==1.3.0
45
+ - fastapi==0.115.9
46
+ - filelock==3.18.0
47
+ - filetype==1.2.0
48
+ - flatbuffers==25.2.10
49
+ - frozenlist==1.6.0
50
+ - fsspec==2024.12.0
51
+ - google-ai-generativelanguage==0.6.18
52
+ - google-api-core==2.24.2
53
+ - google-auth==2.40.1
54
+ - googleapis-common-protos==1.70.0
55
+ - greenlet==3.2.2
56
+ - grpcio==1.72.0rc1
57
+ - grpcio-status==1.72.0rc1
58
+ - h11==0.16.0
59
+ - hf-xet==1.1.0
60
+ - httpcore==1.0.9
61
+ - httptools==0.6.4
62
+ - httpx==0.28.1
63
+ - httpx-sse==0.4.0
64
+ - huggingface-hub==0.31.1
65
+ - humanfriendly==10.0
66
+ - idna==3.10
67
+ - importlib-metadata==8.6.1
68
+ - importlib-resources==6.5.2
69
+ - jinja2==3.1.6
70
+ - jiter==0.10.0
71
+ - json-repair==0.39.1
72
+ - jsonpatch==1.33
73
+ - jsonpointer==3.0.0
74
+ - jsonschema==4.23.0
75
+ - jsonschema-specifications==2025.4.1
76
+ - kubernetes==32.0.1
77
+ - langchain==0.3.25
78
+ - langchain-community==0.3.24
79
+ - langchain-core==0.3.60
80
+ - langchain-google-genai==2.1.4
81
+ - langchain-neo4j==0.4.0
82
+ - langchain-text-splitters==0.3.8
83
+ - langsmith==0.3.42
84
+ - markdown-it-py==3.0.0
85
+ - markupsafe==3.0.2
86
+ - marshmallow==3.26.1
87
+ - mdurl==0.1.2
88
+ - mmh3==5.1.0
89
+ - mpmath==1.3.0
90
+ - multidict==6.4.4
91
+ - mypy-extensions==1.1.0
92
+ - neo4j==5.28.1
93
+ - neo4j-graphrag==1.7.0
94
+ - networkx==3.2.1
95
+ - numpy==2.0.2
96
+ - oauthlib==3.2.2
97
+ - onnxruntime==1.19.2
98
+ - openai==1.79.0
99
+ - opentelemetry-api==1.33.0
100
+ - opentelemetry-exporter-otlp-proto-common==1.33.0
101
+ - opentelemetry-exporter-otlp-proto-grpc==1.33.0
102
+ - opentelemetry-instrumentation==0.54b0
103
+ - opentelemetry-instrumentation-asgi==0.54b0
104
+ - opentelemetry-instrumentation-fastapi==0.54b0
105
+ - opentelemetry-proto==1.33.0
106
+ - opentelemetry-sdk==1.33.0
107
+ - opentelemetry-semantic-conventions==0.54b0
108
+ - opentelemetry-util-http==0.54b0
109
+ - orjson==3.10.18
110
+ - overrides==7.7.0
111
+ - packaging==24.2
112
+ - pillow==11.2.1
113
+ - posthog==4.0.1
114
+ - propcache==0.3.1
115
+ - proto-plus==1.26.1
116
+ - protobuf==6.31.0
117
+ - pyasn1==0.6.1
118
+ - pyasn1-modules==0.4.2
119
+ - pydantic==2.11.4
120
+ - pydantic-core==2.33.2
121
+ - pydantic-settings==2.9.1
122
+ - pygments==2.19.1
123
+ - pypdf==5.5.0
124
+ - pypika==0.48.9
125
+ - pyproject-hooks==1.2.0
126
+ - python-dateutil==2.9.0.post0
127
+ - python-dotenv==1.1.0
128
+ - pytz==2025.2
129
+ - pyyaml==6.0.2
130
+ - referencing==0.36.2
131
+ - regex==2024.11.6
132
+ - requests==2.32.3
133
+ - requests-oauthlib==2.0.0
134
+ - requests-toolbelt==1.0.0
135
+ - rich==14.0.0
136
+ - rpds-py==0.24.0
137
+ - rsa==4.9.1
138
+ - safetensors==0.5.3
139
+ - shellingham==1.5.4
140
+ - six==1.17.0
141
+ - sniffio==1.3.1
142
+ - sqlalchemy==2.0.41
143
+ - starlette==0.45.3
144
+ - sympy==1.14.0
145
+ - tenacity==9.1.2
146
+ - tokenizers==0.21.1
147
+ - tomli==2.2.1
148
+ - torch==2.2.2
149
+ - torchvision==0.17.2
150
+ - tqdm==4.67.1
151
+ - transformers==4.51.3
152
+ - typer==0.15.3
153
+ - types-pyyaml==6.0.12.20250516
154
+ - typing-extensions==4.13.2
155
+ - typing-inspect==0.9.0
156
+ - typing-inspection==0.4.0
157
+ - urllib3==2.4.0
158
+ - uvicorn==0.34.2
159
+ - uvloop==0.21.0
160
+ - watchfiles==1.0.5
161
+ - websocket-client==1.8.0
162
+ - websockets==15.0.1
163
+ - wrapt==1.17.2
164
+ - yarl==1.20.0
165
+ - zipp==3.21.0
166
+ - zstandard==0.23.0
167
+ prefix: /Users/artteiv/miniconda3/envs/graduated2
prepare_script/image_caption_embeddings.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import os
4
+ import sys
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ import torch
7
+ from PIL import Image
8
+ import clip
9
+ import faiss
10
+ import numpy as np
11
+ import glob
12
+
13
+ # Đường dẫn lưu trữ
14
+ VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
15
+ IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
16
+ TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'
17
+
18
+ # Đường dẫn dữ liệu
19
+ DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
20
+ MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
21
+ CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')
22
+
23
+ # Kết nối SQLite
24
+ conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
25
+ cursor = conn.cursor()
26
+
27
+ # Tạo bảng embeddings cho ảnh và văn bản
28
+ cursor.execute('''
29
+ CREATE TABLE IF NOT EXISTS image_embeddings (
30
+ e_index INTEGER PRIMARY KEY,
31
+ image_path TEXT NOT NULL,
32
+ caption TEXT NOT NULL,
33
+ category TEXT NOT NULL,
34
+ subcategory TEXT NOT NULL
35
+ )
36
+ ''')
37
+
38
+ cursor.execute('''
39
+ CREATE TABLE IF NOT EXISTS text_embeddings (
40
+ e_index INTEGER PRIMARY KEY,
41
+ text TEXT NOT NULL,
42
+ category TEXT NOT NULL,
43
+ subcategory TEXT NOT NULL
44
+ )
45
+ ''')
46
+
47
+ def insert_image_embedding(e_index, image_path, caption, category, subcategory):
48
+ """Thêm embedding ảnh vào SQLite."""
49
+ cursor.execute('''
50
+ INSERT INTO image_embeddings (e_index, image_path, caption, category, subcategory)
51
+ VALUES (?, ?, ?, ?, ?)
52
+ ''', (e_index, image_path, caption, category, subcategory))
53
+ conn.commit()
54
+ print(f"Đã thêm embedding ảnh: {image_path}")
55
+
56
+ def insert_text_embedding(e_index, text, category, subcategory):
57
+ """Thêm embedding văn bản vào SQLite."""
58
+ cursor.execute('''
59
+ INSERT INTO text_embeddings (e_index, text, category, subcategory)
60
+ VALUES (?, ?, ?, ?)
61
+ ''', (e_index, text, category, subcategory))
62
+ conn.commit()
63
+ print(f"Đã thêm embedding văn bản: {text[:50]}...")
64
+
65
+ def save_faiss_index(index, index_file):
66
+ """Lưu FAISS index vào file."""
67
+ faiss.write_index(index, index_file)
68
+ print(f"Đã lưu FAISS index vào {index_file}")
69
+
70
+ def load_faiss_index(index_file):
71
+ """Nạp FAISS index từ file."""
72
+ if os.path.exists(index_file):
73
+ index = faiss.read_index(index_file)
74
+ print(f"Đã nạp FAISS index từ {index_file}")
75
+ return index
76
+ return None
77
+
78
+ def compute_embeddings():
79
+ """Tính toán embeddings cho ảnh và văn bản sử dụng CLIP."""
80
+ print("Loading CLIP model...")
81
+ device = "cuda" if torch.cuda.is_available() else "cpu"
82
+ model, preprocess = clip.load("ViT-B/32", device=device)
83
+ print("Model loaded")
84
+
85
+ # Lấy danh sách các thư mục con (categories)
86
+ categories = [d for d in os.listdir(MAIN_DATA_PATH) if os.path.isdir(os.path.join(MAIN_DATA_PATH, d))]
87
+
88
+ image_paths = []
89
+ captions = []
90
+ texts = []
91
+ categories_list = []
92
+ subcategories_list = []
93
+
94
+ # Chuẩn bị dữ liệu
95
+ print("Processing data from directories...")
96
+ for category in categories:
97
+ # Đường dẫn đến thư mục category
98
+ category_path = os.path.join(MAIN_DATA_PATH, category)
99
+
100
+ # Lấy danh sách các subcategories
101
+ subcategories = [d for d in os.listdir(category_path) if os.path.isdir(os.path.join(category_path, d))]
102
+
103
+ for subcategory in subcategories:
104
+ # Đường dẫn đến thư mục ảnh và caption của subcategory
105
+ subcategory_image_path = os.path.join(category_path, subcategory)
106
+ subcategory_caption_path = os.path.join(CAPTIONS_PATH, category, subcategory)
107
+
108
+
109
+ # Lấy danh sách ảnh
110
+ image_files = glob.glob(os.path.join(subcategory_image_path, '*.*'))
111
+
112
+ for img_path in image_files:
113
+ # Lấy tên file không có phần mở rộng
114
+ base_name = os.path.splitext(os.path.basename(img_path))[0]
115
+ caption_file = os.path.join(subcategory_caption_path, f"{base_name}.txt")
116
+
117
+ if os.path.exists(caption_file):
118
+ try:
119
+ # Đọc caption
120
+ with open(caption_file, 'r', encoding='utf-8') as f:
121
+ caption = f.read().strip()
122
+
123
+ # Thêm vào danh sách
124
+ image_paths.append(img_path)
125
+ captions.append(caption)
126
+ texts.append(caption) # Sử dụng caption làm text
127
+ categories_list.append(category)
128
+ subcategories_list.append(subcategory)
129
+
130
+ except Exception as e:
131
+ print(f"Error processing {img_path}: {e}")
132
+ continue
133
+
134
+ # Tính toán embeddings cho ảnh
135
+ # if image_paths:
136
+ # print("Computing image embeddings...")
137
+ # image_embeddings = []
138
+ # for idx, img_path in enumerate(image_paths):
139
+ # try:
140
+ # image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
141
+ # with torch.no_grad():
142
+ # image_features = model.encode_image(image)
143
+ # image_features = image_features.cpu().numpy()
144
+ # faiss.normalize_L2(image_features)
145
+ # image_embeddings.append(image_features[0])
146
+ # insert_image_embedding(idx, img_path, captions[idx], categories_list[idx], subcategories_list[idx])
147
+ # except Exception as e:
148
+ # print(f"Error processing image {img_path}: {e}")
149
+ # continue
150
+
151
+ # if image_embeddings:
152
+ # image_embeddings = np.array(image_embeddings)
153
+ # d = image_embeddings.shape[1]
154
+ # image_index = faiss.IndexFlatIP(d)
155
+ # image_index.add(image_embeddings)
156
+ # save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)
157
+
158
+ # Tính toán embeddings cho văn bản
159
+ if texts:
160
+ print("Computing text embeddings...")
161
+ text_tokens = clip.tokenize(texts, truncate=True).to(device)
162
+ print("Kích thước của text_tokens:", text_tokens.shape)
163
+ with torch.no_grad():
164
+ text_features = model.encode_text(text_tokens)
165
+ text_features = text_features.cpu().numpy()
166
+ faiss.normalize_L2(text_features)
167
+
168
+ d = text_features.shape[1]
169
+ text_index = faiss.IndexFlatIP(d)
170
+ text_index.add(text_features)
171
+
172
+ # Lưu text embeddings vào SQLite
173
+ for idx, (text, category, subcategory) in enumerate(zip(texts, categories_list, subcategories_list)):
174
+ insert_text_embedding(idx, text, category, subcategory)
175
+
176
+ save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)
177
+
178
+ print("Processing completed")
179
+ return image_index if image_paths else None, text_index if texts else None
180
+
181
+ def predict_image(image_path):
182
+ device = "cuda" if torch.cuda.is_available() else "cpu"
183
+ model, preprocess = clip.load("ViT-B/32", device=device)
184
+
185
+ image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
186
+ with torch.no_grad():
187
+ image_features = model.encode_image(image)
188
+ image_features = image_features.cpu().numpy()
189
+ faiss.normalize_L2(image_features)
190
+
191
+ index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
192
+ distances, indices = index.search(image_features, k=10)
193
+
194
+ return distances, indices
195
+
196
+ if __name__ == '__main__':
197
+ ## Predict
198
+
199
+ try:
200
+ image_index, text_index = compute_embeddings()
201
+ if image_index:
202
+ print(f"Image index ready with {image_index.ntotal} embeddings")
203
+ if text_index:
204
+ print(f"Text index ready with {text_index.ntotal} embeddings")
205
+ finally:
206
+ conn.close()
207
+ print("SQLite connection closed")
prepare_script/sync_neo4j_node.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ import os
4
+ import sys
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ from app.models.knowledge_graph import Neo4jConnection
7
+ from sentence_transformers import SentenceTransformer
8
+ from pyvi.ViTokenizer import tokenize
9
+ import faiss
10
+ import numpy as np
11
+
12
+ """
13
+ Script này thực hiện lấy các entity từ neo4j từ xa về và tạo ra các data lưu trong sqlite, đồng thời tạo các embeddings
14
+ dựa trên từng row.
15
+ """
16
+
17
+ # Kết nối SQLite
18
+ VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
19
+ FAISS_INDEX_PATH = 'app/data/faiss_index.index'
20
+
21
+ conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
22
+ cursor = conn.cursor()
23
+
24
+ # Tạo bảng embeddings nếu chưa tồn tại
25
+ cursor.execute('''
26
+ CREATE TABLE IF NOT EXISTS embeddings (
27
+ e_index INTEGER PRIMARY KEY,
28
+ id TEXT NOT NULL,
29
+ name TEXT NOT NULL,
30
+ label TEXT NOT NULL,
31
+ properties TEXT NOT NULL
32
+ )
33
+ ''')
34
+
35
+ def insert_embedding(e_index, id, name, label, properties):
36
+ """Thêm embedding vào SQLite."""
37
+ cursor.execute('''
38
+ INSERT INTO embeddings (e_index, id, name, label, properties)
39
+ VALUES (?, ?, ?, ?, ?)
40
+ ''', (e_index, id, name, label, json.dumps(properties)))
41
+ conn.commit()
42
+ print(f"Đã thêm embedding: {name}")
43
+
44
+ def update_embedding(embedding_id, id, name, label, properties):
45
+ """Cập nhật embedding trong SQLite."""
46
+ cursor.execute('''
47
+ UPDATE embeddings
48
+ SET id = ?, name = ?, label = ?, properties = ?
49
+ WHERE e_index = ?
50
+ ''', (id, name, label, json.dumps(properties), embedding_id))
51
+ conn.commit()
52
+ print(f"Đã cập nhật embedding ID: {embedding_id}")
53
+
54
+ def get_all_embeddings():
55
+ """Lấy tất cả embeddings từ SQLite."""
56
+ cursor.execute('SELECT * FROM embeddings')
57
+ return cursor.fetchall()
58
+
59
+ def get_embedding_by_id(embedding_id):
60
+ """Lấy embedding theo e_index từ SQLite."""
61
+ cursor.execute('SELECT * FROM embeddings WHERE e_index = ?', (embedding_id,))
62
+ return cursor.fetchone()
63
+
64
+ def save_faiss_index(index, index_file=FAISS_INDEX_PATH):
65
+ """Lưu FAISS index vào file."""
66
+ faiss.write_index(index, index_file)
67
+ print(f"Đã lưu FAISS index vào {index_file}")
68
+
69
+ def load_faiss_index(index_file=FAISS_INDEX_PATH):
70
+ """Nạp FAISS index từ file."""
71
+ if os.path.exists(index_file):
72
+ index = faiss.read_index(index_file)
73
+ print(f"Đã nạp FAISS index từ {index_file}")
74
+ return index
75
+ return None
76
+
77
+ def compute_and_save_embeddings(index_file=FAISS_INDEX_PATH):
78
+ """Tính toán embeddings, lưu vào FAISS và đồng bộ metadata vào SQLite."""
79
+ print("Loading model...")
80
+ model = SentenceTransformer('dangvantuan/vietnamese-embedding')
81
+ print("Model loaded")
82
+
83
+ # Lấy dữ liệu từ Neo4j
84
+ neo4j = Neo4jConnection()
85
+ result = neo4j.execute_query("MATCH (n) RETURN n")
86
+ corpus = []
87
+
88
+ # Chuẩn bị corpus và lưu metadata vào SQLite
89
+ print("Processing Neo4j data and saving to SQLite...")
90
+ for index, record in enumerate(result):
91
+ print(record)
92
+ label = list(record["n"].labels)[0]
93
+ print(label)
94
+ embedding = dict(record["n"])
95
+ id = embedding.pop('id')
96
+ name = embedding.pop('name') if 'name' in embedding else id
97
+ properties = embedding
98
+ corpus.append(name)
99
+
100
+ # Kiểm tra và cập nhật/thêm vào SQLite
101
+ cursor.execute('SELECT e_index FROM embeddings WHERE e_index = ?', (index,))
102
+ existing = cursor.fetchone()
103
+ if existing:
104
+ update_embedding(index, id, name, label, properties)
105
+ else:
106
+ insert_embedding(index, id, name, label, properties)
107
+
108
+ # Tính toán embeddings
109
+ print("Tokenizing and encoding...")
110
+ tokenized = [tokenize(s) for s in corpus]
111
+ embeddings = model.encode(tokenized, show_progress_bar=False)
112
+ print("Encoding done")
113
+
114
+ # Chuẩn hóa embeddings
115
+ print("Normalizing...")
116
+ faiss.normalize_L2(embeddings)
117
+ print("Normalized")
118
+
119
+ # Tạo và lưu FAISS index
120
+ d = embeddings.shape[1]
121
+ index = faiss.IndexFlatIP(d)
122
+ index.add(embeddings)
123
+ save_faiss_index(index, index_file)
124
+
125
+ print("Processing completed")
126
+ return index, corpus, embeddings
127
+
128
+ def load_or_compute_embeddings(index_file=FAISS_INDEX_PATH):
129
+ """Nạp hoặc tính toán embeddings và FAISS index."""
130
+ # Thử nạp FAISS index
131
+ index = load_faiss_index(index_file)
132
+
133
+ # Lấy corpus từ SQLite
134
+ embeddings_data = get_all_embeddings()
135
+ corpus = [row[2] for row in embeddings_data] # Lấy cột name
136
+
137
+ if index is None or not corpus:
138
+ print("No saved index or corpus found, computing new ones...")
139
+ index, corpus, embeddings = compute_and_save_embeddings(index_file)
140
+ else:
141
+ print("Loaded existing index and corpus")
142
+
143
+ return index, corpus
144
+
145
+ def get_qvec_by_text(model, text):
146
+ q_token = tokenize(text)
147
+ q_vec = model.encode([q_token])
148
+ faiss.normalize_L2(q_vec)
149
+ return q_vec
150
+
151
+ if __name__ == "__main__":
152
+ try:
153
+ index, corpus = load_or_compute_embeddings()
154
+ print(f"Index ready with {index.ntotal} embeddings, corpus size: {len(corpus)}")
155
+ model = SentenceTransformer('dangvantuan/vietnamese-embedding')
156
+ while True:
157
+ try:
158
+ query = input("Nhập câu truy vấn (nhấn Ctrl+C để thoát): ")
159
+ q_vec = get_qvec_by_text(model, query)
160
+ k = 1 # số kết quả cần lấy
161
+ D, I = index.search(q_vec, k)
162
+ print("Câu truy vấn:", query)
163
+ print(I[0][0])
164
+ print(type(I[0][0]))
165
+ print("Câu gần nhất:", get_embedding_by_id(int(I[0][0])), "(khoảng cách:", D[0][0], ")")
166
+ print("-" * 50)
167
+ except KeyboardInterrupt:
168
+ print("\nĐã dừng chương trình!")
169
+ break
170
+ finally:
171
+ conn.close()
172
+ print("SQLite connection closed")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ langchain>=0.3.23
2
+ neo4j>=5.28.1
3
+ python-dotenv>=1.0.1
4
+ fastapi>=0.115.12
5
+ uvicorn>=0.34.2
6
+ pydantic>=2.10.6
7
+ faiss-cpu>=1.11.0
8
+