Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Sontranwakumo
commited on
Commit
·
88cc76c
1
Parent(s):
77d75db
init: move from github
Browse files- .DS_Store +0 -0
- .env.example +5 -0
- .gitattributes +2 -0
- .gitignore +114 -0
- README.md +128 -1
- app/.DS_Store +0 -0
- app/__init__.py +0 -0
- app/api/__init__.py +0 -0
- app/api/dto/kg_query.py +23 -0
- app/api/routes.py +40 -0
- app/core/__init__.py +0 -0
- app/core/config.py +29 -0
- app/core/dependencies.py +28 -0
- app/core/type.py +46 -0
- app/data/faiss_index.index +3 -0
- app/data/image_faiss_index.index +3 -0
- app/data/vector_embeddings.db +3 -0
- app/main.py +84 -0
- app/models/__init__.py +0 -0
- app/models/crop_clip.py +98 -0
- app/models/gemini_caller.py +41 -0
- app/models/knowledge_graph.py +126 -0
- app/services/__init__.py +0 -0
- app/services/predict.py +60 -0
- app/utils/constant.py +2 -0
- app/utils/data_mapping.py +65 -0
- app/utils/extract_entity.py +25 -0
- app/utils/prompt.py +85 -0
- environment.yml +167 -0
- prepare_script/image_caption_embeddings.py +207 -0
- prepare_script/sync_neo4j_node.py +172 -0
- requirements.txt +8 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.env.example
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NEO4J_URI=
|
| 2 |
+
NEO4J_USER=neo4j
|
| 3 |
+
NEO4J_PASSWORD=
|
| 4 |
+
OPENAI_API_KEY=
|
| 5 |
+
GEMINI_API_KEY=
|
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.index filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
weights/
|
| 7 |
+
|
| 8 |
+
## Big data
|
| 9 |
+
/Data
|
| 10 |
+
|
| 11 |
+
# C extensions
|
| 12 |
+
*.so
|
| 13 |
+
|
| 14 |
+
# Distribution / packaging
|
| 15 |
+
.Python
|
| 16 |
+
build/
|
| 17 |
+
develop-eggs/
|
| 18 |
+
dist/
|
| 19 |
+
downloads/
|
| 20 |
+
eggs/
|
| 21 |
+
.eggs/
|
| 22 |
+
lib/
|
| 23 |
+
lib64/
|
| 24 |
+
parts/
|
| 25 |
+
sdist/
|
| 26 |
+
var/
|
| 27 |
+
wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
.hypothesis/
|
| 50 |
+
|
| 51 |
+
# Translations
|
| 52 |
+
*.mo
|
| 53 |
+
*.pot
|
| 54 |
+
|
| 55 |
+
# Django stuff:
|
| 56 |
+
*.log
|
| 57 |
+
local_settings.py
|
| 58 |
+
db.sqlite3
|
| 59 |
+
|
| 60 |
+
# Flask stuff:
|
| 61 |
+
instance/
|
| 62 |
+
.webassets-cache
|
| 63 |
+
|
| 64 |
+
# Scrapy stuff:
|
| 65 |
+
.scrapy
|
| 66 |
+
|
| 67 |
+
# Sphinx documentation
|
| 68 |
+
docs/_build/
|
| 69 |
+
|
| 70 |
+
# PyBuilder
|
| 71 |
+
target/
|
| 72 |
+
|
| 73 |
+
# Jupyter Notebook
|
| 74 |
+
.ipynb_checkpoints
|
| 75 |
+
|
| 76 |
+
# pyenv
|
| 77 |
+
.python-version
|
| 78 |
+
|
| 79 |
+
# celery beat schedule file
|
| 80 |
+
celerybeat-schedule
|
| 81 |
+
|
| 82 |
+
# SageMath parsed files
|
| 83 |
+
*.sage.py
|
| 84 |
+
|
| 85 |
+
# Environments
|
| 86 |
+
.env
|
| 87 |
+
.venv
|
| 88 |
+
env/
|
| 89 |
+
venv/
|
| 90 |
+
ENV/
|
| 91 |
+
env.bak/
|
| 92 |
+
venv.bak/
|
| 93 |
+
|
| 94 |
+
# Spyder project settings
|
| 95 |
+
.spyderproject
|
| 96 |
+
.spyproject
|
| 97 |
+
|
| 98 |
+
# Rope project settings
|
| 99 |
+
.ropeproject
|
| 100 |
+
|
| 101 |
+
# mkdocs documentation
|
| 102 |
+
/site
|
| 103 |
+
|
| 104 |
+
# mypy
|
| 105 |
+
.mypy_cache/
|
| 106 |
+
|
| 107 |
+
# IDE specific files
|
| 108 |
+
.idea/
|
| 109 |
+
.vscode/
|
| 110 |
+
*.swp
|
| 111 |
+
*.swo
|
| 112 |
+
|
| 113 |
+
# FastAPI specific
|
| 114 |
+
.pytest_cache/
|
README.md
CHANGED
|
@@ -8,4 +8,131 @@ pinned: false
|
|
| 8 |
short_description: Crop diagnosis module
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
short_description: Crop diagnosis module
|
| 9 |
---
|
| 10 |
|
| 11 |
+
|
| 12 |
+
# Crop Diagnosis Knowledge Graph Module
|
| 13 |
+
|
| 14 |
+
A powerful tool for querying crop diagnosis knowledge graphs using LangChain and Neo4j. This module provides an API interface to interact with a knowledge graph containing agricultural and crop disease information.
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
|
| 18 |
+
- Natural language querying of crop diagnosis knowledge graph
|
| 19 |
+
- Integration with LangChain for intelligent query processing
|
| 20 |
+
- Neo4j database backend for efficient graph operations
|
| 21 |
+
- RESTful API interface
|
| 22 |
+
- Environment-based configuration
|
| 23 |
+
|
| 24 |
+
## Prerequisites
|
| 25 |
+
|
| 26 |
+
- Python 3.8+
|
| 27 |
+
- Neo4j Database (version 5.x)
|
| 28 |
+
- OpenAI API key (for LangChain integration)
|
| 29 |
+
|
| 30 |
+
## Installation
|
| 31 |
+
|
| 32 |
+
1. Clone the repository:
|
| 33 |
+
```bash
|
| 34 |
+
git clone [repository-url]
|
| 35 |
+
cd crop-diag-module
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
2. Create and activate a virtual environment:
|
| 39 |
+
```bash
|
| 40 |
+
python -m venv venv
|
| 41 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
3. Install dependencies:
|
| 45 |
+
```bash
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Configuration
|
| 50 |
+
|
| 51 |
+
1. Create a `.env` file in the project root:
|
| 52 |
+
```bash
|
| 53 |
+
cp .env.example .env
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
2. Edit the `.env` file with your configuration:
|
| 57 |
+
```env
|
| 58 |
+
# Neo4j Configuration
|
| 59 |
+
NEO4J_URI=bolt://localhost:7687
|
| 60 |
+
NEO4J_USER=neo4j
|
| 61 |
+
NEO4J_PASSWORD=your_password
|
| 62 |
+
|
| 63 |
+
# API Configuration
|
| 64 |
+
API_HOST=0.0.0.0
|
| 65 |
+
API_PORT=8000
|
| 66 |
+
DEBUG=True
|
| 67 |
+
|
| 68 |
+
# LangChain Configuration
|
| 69 |
+
OPENAI_API_KEY=your_openai_api_key
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Replace the following values:
|
| 73 |
+
- `NEO4J_URI`: Your Neo4j database URI
|
| 74 |
+
- `NEO4J_USER`: Neo4j username
|
| 75 |
+
- `NEO4J_PASSWORD`: Neo4j password
|
| 76 |
+
- `OPENAI_API_KEY`: Your OpenAI API key
|
| 77 |
+
|
| 78 |
+
## Running the Application
|
| 79 |
+
|
| 80 |
+
1. Start the FastAPI server:
|
| 81 |
+
```bash
|
| 82 |
+
uvicorn app.main:app --reload
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
2. Access the API documentation:
|
| 86 |
+
- Swagger UI: http://localhost:8000/docs
|
| 87 |
+
- ReDoc: http://localhost:8000/redoc
|
| 88 |
+
|
| 89 |
+
## API Usage
|
| 90 |
+
|
| 91 |
+
### Query the Knowledge Graph
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
curl -X POST "http://localhost:8000/api/query" \
|
| 95 |
+
-H "Content-Type: application/json" \
|
| 96 |
+
-d '{"question": "What are the symptoms of rice blast disease?"}'
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## Project Structure
|
| 100 |
+
|
| 101 |
+
```
|
| 102 |
+
crop-diag-module/
|
| 103 |
+
├── app/
|
| 104 |
+
│ ├── api/ # API routes and endpoints
|
| 105 |
+
│ ├── core/ # Core functionality
|
| 106 |
+
│ ├── models/ # Data models
|
| 107 |
+
│ └── utils/ # Utility functions
|
| 108 |
+
├── KG/ # Knowledge Graph data
|
| 109 |
+
├── tests/ # Test cases
|
| 110 |
+
├── requirements.txt # Project dependencies
|
| 111 |
+
└── .env # Environment configuration
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Development
|
| 115 |
+
|
| 116 |
+
### Running Tests
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
pytest tests/
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### Code Style
|
| 123 |
+
|
| 124 |
+
This project follows PEP 8 style guidelines. Use the following command to check code style:
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
| 128 |
+
|
| 129 |
+
uvicorn app.main:app --reload
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
## License
|
| 133 |
+
|
| 134 |
+
[Add your license information here]
|
| 135 |
+
|
| 136 |
+
## Contributing
|
| 137 |
+
|
| 138 |
+
[Add contribution guidelines here]
|
app/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app/__init__.py
ADDED
|
File without changes
|
app/api/__init__.py
ADDED
|
File without changes
|
app/api/dto/kg_query.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
from app.core.type import Node
|
| 5 |
+
|
| 6 |
+
class QueryContext(BaseModel):
|
| 7 |
+
crop_id: Optional[str] = None
|
| 8 |
+
nodes: Optional[List[Node]] = None
|
| 9 |
+
predicted_labels: Optional[List[str]] = None
|
| 10 |
+
|
| 11 |
+
class PredictedLabel(BaseModel):
|
| 12 |
+
crop_name: str
|
| 13 |
+
label: str
|
| 14 |
+
confidence: float
|
| 15 |
+
|
| 16 |
+
class KGQueryRequest(BaseModel):
|
| 17 |
+
context: Optional[QueryContext] = None
|
| 18 |
+
crop_id: Optional[str] = None
|
| 19 |
+
additional_info: Optional[str] = None
|
| 20 |
+
|
| 21 |
+
class KGQueryResponse(BaseModel):
|
| 22 |
+
answer: str
|
| 23 |
+
sources: List[str]
|
app/api/routes.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from app.api.dto.kg_query import KGQueryRequest, KGQueryResponse, PredictedLabel, QueryContext
|
| 5 |
+
from app.core.dependencies import get_all_models, get_clip_model, get_data_mapper
|
| 6 |
+
from app.core.type import Node
|
| 7 |
+
from app.models.crop_clip import CLIPModule
|
| 8 |
+
from app.services.predict import PredictService, get_predict_service
|
| 9 |
+
from app.utils.extract_entity import extract_entities
|
| 10 |
+
|
| 11 |
+
router = APIRouter()
|
| 12 |
+
|
| 13 |
+
class QueryRequest(BaseModel):
|
| 14 |
+
question: str
|
| 15 |
+
context: Optional[List[str]] = None
|
| 16 |
+
|
| 17 |
+
class QueryResponse(BaseModel):
|
| 18 |
+
answer: str
|
| 19 |
+
sources: List[str]
|
| 20 |
+
|
| 21 |
+
@router.post("/analyze")
|
| 22 |
+
async def analyze(
|
| 23 |
+
image: UploadFile = File(None),
|
| 24 |
+
predict_service: PredictService = Depends(get_predict_service)
|
| 25 |
+
):
|
| 26 |
+
predicted_label = predict_service.predict_image(image)
|
| 27 |
+
|
| 28 |
+
return {
|
| 29 |
+
"crop_id": predicted_label[0].crop_name,
|
| 30 |
+
"predicted_labels": predicted_label,
|
| 31 |
+
"nodes": []
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@router.post("/kg-query")
|
| 36 |
+
async def query_kg(
|
| 37 |
+
request: KGQueryRequest,
|
| 38 |
+
predict_service: PredictService = Depends(get_predict_service),
|
| 39 |
+
):
|
| 40 |
+
return predict_service.retrieve_kg(request)
|
app/core/__init__.py
ADDED
|
File without changes
|
app/core/config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
class Settings(BaseSettings):
|
| 5 |
+
# Neo4j Configuration
|
| 6 |
+
neo4j_uri: str
|
| 7 |
+
neo4j_user: str
|
| 8 |
+
neo4j_password: str
|
| 9 |
+
neo4j_database: str = "neo4j"
|
| 10 |
+
|
| 11 |
+
# API Configuration
|
| 12 |
+
api_host: str = "0.0.0.0"
|
| 13 |
+
api_port: int = 8000
|
| 14 |
+
debug: bool = True
|
| 15 |
+
|
| 16 |
+
openai_api_key: str
|
| 17 |
+
gemini_api_key: str
|
| 18 |
+
|
| 19 |
+
load_clip_model: bool = True
|
| 20 |
+
load_gemini_model: bool = True
|
| 21 |
+
load_data_mapper: bool = True
|
| 22 |
+
load_knowledge_graph: bool = True
|
| 23 |
+
|
| 24 |
+
class Config:
|
| 25 |
+
env_file = ".env"
|
| 26 |
+
|
| 27 |
+
@lru_cache()
|
| 28 |
+
def get_settings() -> Settings:
|
| 29 |
+
return Settings()
|
app/core/dependencies.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends, Request
|
| 2 |
+
from app.models.crop_clip import CLIPModule
|
| 3 |
+
from app.utils.data_mapping import DataMapping
|
| 4 |
+
from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
|
| 5 |
+
|
| 6 |
+
def get_clip_model(request: Request) -> CLIPModule:
|
| 7 |
+
"""Lấy CLIP model từ app.state"""
|
| 8 |
+
return request.app.state.model_loader.clip_model
|
| 9 |
+
|
| 10 |
+
def get_data_mapper(request: Request) -> DataMapping:
|
| 11 |
+
"""Lấy DataMapper từ app.state"""
|
| 12 |
+
return request.app.state.model_loader.data_mapper
|
| 13 |
+
|
| 14 |
+
def get_knowledge_graph(request: Request) -> KnowledgeGraphUtils:
|
| 15 |
+
"""Lấy KnowledgeGraph từ app.state"""
|
| 16 |
+
return request.app.state.model_loader.knowledge_graph
|
| 17 |
+
|
| 18 |
+
def get_all_models(
|
| 19 |
+
clip_model: CLIPModule = Depends(get_clip_model),
|
| 20 |
+
data_mapper: DataMapping = Depends(get_data_mapper),
|
| 21 |
+
knowledge_graph: KnowledgeGraphUtils = Depends(get_knowledge_graph)
|
| 22 |
+
):
|
| 23 |
+
"""Lấy tất cả các model từ app.state"""
|
| 24 |
+
return {
|
| 25 |
+
"clip_model": clip_model,
|
| 26 |
+
"data_mapper": data_mapper,
|
| 27 |
+
"knowledge_graph": knowledge_graph
|
| 28 |
+
}
|
app/core/type.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
class Node(BaseModel):
|
| 6 |
+
id: str
|
| 7 |
+
label: str
|
| 8 |
+
name: str
|
| 9 |
+
properties: dict
|
| 10 |
+
score: Optional[float] = None
|
| 11 |
+
|
| 12 |
+
@staticmethod
|
| 13 |
+
def map_json_to_node(json_data: dict, label: str = None) -> 'Node':
|
| 14 |
+
node_data = {
|
| 15 |
+
"name": json_data.pop("name") if "name" in json_data else json_data["id"],
|
| 16 |
+
"id": json_data.pop("id"),
|
| 17 |
+
"label": label if label else json_data.pop("label"),
|
| 18 |
+
"properties": json_data
|
| 19 |
+
}
|
| 20 |
+
return Node(**node_data)
|
| 21 |
+
|
| 22 |
+
@staticmethod
|
| 23 |
+
def data_row_to_node(data_row: list[str], score = None) -> 'Node':
|
| 24 |
+
return Node(
|
| 25 |
+
id=data_row[1],
|
| 26 |
+
name=data_row[2],
|
| 27 |
+
label=data_row[3],
|
| 28 |
+
properties=json.loads(data_row[4]),
|
| 29 |
+
score=score
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
class Relationship(BaseModel):
|
| 33 |
+
source: str
|
| 34 |
+
target: str
|
| 35 |
+
type: str
|
| 36 |
+
properties: dict
|
| 37 |
+
|
| 38 |
+
class KnowledgeGraph(BaseModel):
|
| 39 |
+
nodes: List[Node]
|
| 40 |
+
relationships: List[Relationship]
|
| 41 |
+
|
| 42 |
+
class GraphQuery(BaseModel):
|
| 43 |
+
key: str
|
| 44 |
+
cypher: str
|
| 45 |
+
parameters: Optional[dict] = None
|
| 46 |
+
description: Optional[str] = None
|
app/data/faiss_index.index
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1af6778f8fb10ee5a2d44f2815bc288c0ebac355c74a2b144d95720af5b8171
|
| 3 |
+
size 1188909
|
app/data/image_faiss_index.index
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1f42c9ecb4da4a64cf34d90dc887564accc467233a8bfde986e6fa02b788b10
|
| 3 |
+
size 41824301
|
app/data/vector_embeddings.db
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b9b2d7cab4f196a4a739830702944fedfe10f123e2c8ca3f4285857f56e7996
|
| 3 |
+
size 5394432
|
app/main.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from contextlib import asynccontextmanager
|
| 3 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 4 |
+
from app.core.config import get_settings
|
| 5 |
+
from app.api.routes import router as api_router
|
| 6 |
+
from app.models.crop_clip import CLIPModule
|
| 7 |
+
from app.models.gemini_caller import GeminiGenerator
|
| 8 |
+
from app.utils.data_mapping import DataMapping, SingletonModel
|
| 9 |
+
from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection
|
| 10 |
+
import asyncio
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 12 |
+
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
settings = get_settings()
|
| 16 |
+
|
| 17 |
+
class ModelLoader:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.clip_model = None
|
| 20 |
+
self.gemini_model = None
|
| 21 |
+
self.sentence_transformer = None
|
| 22 |
+
self.neo4j_connection = None
|
| 23 |
+
|
| 24 |
+
def load_models(self):
|
| 25 |
+
try:
|
| 26 |
+
if settings.load_clip_model:
|
| 27 |
+
logger.info("Loading CLIP model...")
|
| 28 |
+
self.clip_model = CLIPModule()
|
| 29 |
+
logger.info("CLIP model loaded successfully")
|
| 30 |
+
|
| 31 |
+
if settings.load_gemini_model:
|
| 32 |
+
logger.info("Loading Gemini model...")
|
| 33 |
+
self.gemini_model = GeminiGenerator()
|
| 34 |
+
logger.info("Gemini model loaded successfully")
|
| 35 |
+
|
| 36 |
+
if settings.load_data_mapper:
|
| 37 |
+
logger.info("Loading DataMapper model...")
|
| 38 |
+
self.data_mapper = DataMapping()
|
| 39 |
+
logger.info("DataMapper model loaded successfully")
|
| 40 |
+
|
| 41 |
+
if settings.load_knowledge_graph:
|
| 42 |
+
logger.info("Connecting to Knowledge Graph...")
|
| 43 |
+
self.knowledge_graph = KnowledgeGraphUtils()
|
| 44 |
+
logger.info("Knowledge Graph connection established")
|
| 45 |
+
except Exception as e:
|
| 46 |
+
logger.error(f"Failed to load models: {e}")
|
| 47 |
+
raise
|
| 48 |
+
|
| 49 |
+
def close(self):
|
| 50 |
+
if self.neo4j_connection:
|
| 51 |
+
logger.info("Closing Neo4j connection...")
|
| 52 |
+
self.neo4j_connection.close()
|
| 53 |
+
self.clip_model = None
|
| 54 |
+
self.gemini_model = None
|
| 55 |
+
self.sentence_transformer = None
|
| 56 |
+
logger.info("Models released")
|
| 57 |
+
|
| 58 |
+
# Lifespan event handler
|
| 59 |
+
@asynccontextmanager
|
| 60 |
+
async def lifespan(app: FastAPI):
|
| 61 |
+
loop = asyncio.get_event_loop()
|
| 62 |
+
with ThreadPoolExecutor() as pool:
|
| 63 |
+
await loop.run_in_executor(pool, app.state.model_loader.load_models)
|
| 64 |
+
logger.info("Application startup complete")
|
| 65 |
+
yield
|
| 66 |
+
app.state.model_loader.close()
|
| 67 |
+
logger.info("Application shutdown complete")
|
| 68 |
+
|
| 69 |
+
app = FastAPI(
|
| 70 |
+
title="Crop Diagnosis Knowledge Graph API",
|
| 71 |
+
description="API for querying crop diagnosis knowledge graph using LangChain",
|
| 72 |
+
version="1.0.0",
|
| 73 |
+
debug=settings.debug,
|
| 74 |
+
lifespan=lifespan
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
app.state.model_loader = ModelLoader()
|
| 78 |
+
|
| 79 |
+
app.include_router(api_router, prefix="/api")
|
| 80 |
+
|
| 81 |
+
@app.get("/")
|
| 82 |
+
async def root():
|
| 83 |
+
return {"message": "Welcome to Crop Diagnosis Knowledge Graph API"}
|
| 84 |
+
|
app/models/__init__.py
ADDED
|
File without changes
|
app/models/crop_clip.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
import clip
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from app.api.dto.kg_query import PredictedLabel
|
| 10 |
+
|
| 11 |
+
CLASS_NAMES = ['benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
|
| 12 |
+
'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
|
| 13 |
+
'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo', 'benhSocLaNgo',
|
| 14 |
+
'benhDomLaNgo', 'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc']
|
| 15 |
+
|
| 16 |
+
CROP_NAMES = ['caChua', 'caChua', 'caChua', 'caChua', 'san', 'san', 'san', 'san',
|
| 17 |
+
'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'luaNuoc', 'luaNuoc', 'luaNuoc']
|
| 18 |
+
|
| 19 |
+
WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'clip_finetuned.pth')
|
| 20 |
+
|
| 21 |
+
class CLIPFineTuner(nn.Module):
|
| 22 |
+
def __init__(self, model, num_classes):
|
| 23 |
+
super(CLIPFineTuner, self).__init__()
|
| 24 |
+
self.model = model
|
| 25 |
+
self.classifier = nn.Linear(model.visual.output_dim, num_classes)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
features = self.model.encode_image(x).float() # Convert to float32
|
| 30 |
+
return self.classifier(features)
|
| 31 |
+
|
| 32 |
+
class CLIPModule:
|
| 33 |
+
def __init__(self):
|
| 34 |
+
model, preprocess = clip.load("ViT-B/32", jit=False)
|
| 35 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 36 |
+
self.model = CLIPFineTuner(model, 17)
|
| 37 |
+
self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
|
| 38 |
+
self.model.to(self.device)
|
| 39 |
+
self.model.eval()
|
| 40 |
+
self.classes = CLASS_NAMES
|
| 41 |
+
self.transform = preprocess
|
| 42 |
+
|
| 43 |
+
def predict_image(self, image: Image.Image):
|
| 44 |
+
output = self.__predict(image)
|
| 45 |
+
probabilities = torch.nn.functional.softmax(output, dim=1)[0]
|
| 46 |
+
|
| 47 |
+
predictions: List[PredictedLabel] = []
|
| 48 |
+
for idx, prob in enumerate(probabilities):
|
| 49 |
+
predictions.append(PredictedLabel(
|
| 50 |
+
crop_name=CROP_NAMES[idx],
|
| 51 |
+
label=self.classes[idx],
|
| 52 |
+
confidence=float(prob)
|
| 53 |
+
))
|
| 54 |
+
|
| 55 |
+
# Sắp xếp giảm dần theo xác suất
|
| 56 |
+
predictions.sort(key=lambda x: x.confidence, reverse=True)
|
| 57 |
+
|
| 58 |
+
return predictions
|
| 59 |
+
|
| 60 |
+
def __predict(self, image_input):
|
| 61 |
+
"""
|
| 62 |
+
Dự đoán nhãn cho một ảnh.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
image_input: Đường dẫn file ảnh (str) hoặc đối tượng PIL.Image
|
| 66 |
+
device: Thiết bị chạy mô hình ('cuda' hoặc 'cpu').
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
str: Nhãn dự đoán (e.g., "cassava_leaf beetle").
|
| 70 |
+
"""
|
| 71 |
+
try:
|
| 72 |
+
image = self.__handle_image(image_input)
|
| 73 |
+
image_tensor = self.transform(image)
|
| 74 |
+
except ValueError as e:
|
| 75 |
+
raise e
|
| 76 |
+
except Exception as e:
|
| 77 |
+
raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}")
|
| 78 |
+
|
| 79 |
+
if image_tensor.dim() == 3:
|
| 80 |
+
image_tensor = image_tensor.unsqueeze(0)
|
| 81 |
+
|
| 82 |
+
print(image_tensor.shape)
|
| 83 |
+
|
| 84 |
+
image_tensor = image_tensor.to(self.device)
|
| 85 |
+
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
output = self.model(image_tensor)
|
| 88 |
+
|
| 89 |
+
return output ## an array of 17 values, no softmax
|
| 90 |
+
|
| 91 |
+
def __handle_image(self, image_input):
|
| 92 |
+
if isinstance(image_input, str):
|
| 93 |
+
image = Image.open(image_input).convert('RGB')
|
| 94 |
+
elif isinstance(image_input, Image.Image):
|
| 95 |
+
image = image_input
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError("Invalid image input")
|
| 98 |
+
return image
|
app/models/gemini_caller.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import google.generativeai as genai
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
class GeminiGenerator:
|
| 8 |
+
def __init__(self, model_name="gemini-2.0-flash", temperature=0):
|
| 9 |
+
self.key = os.environ.get("GEMINI_API_KEY")
|
| 10 |
+
genai.configure(api_key=self.key)
|
| 11 |
+
|
| 12 |
+
# Cấu hình generation config
|
| 13 |
+
self.generation_config = {
|
| 14 |
+
"temperature": temperature
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
# Hệ thống prompt mặc định
|
| 18 |
+
self.system_prompt = "Bạn là một trợ lý AI hữu ích cho các dự án IT về cây trồng và bệnh cây trồng. Bạn có khả năng trích xuất thông tin từ văn bản được cung cấp và trả dữ liệu bằng tiếng Việt theo yêu cầu."
|
| 19 |
+
|
| 20 |
+
# Khởi tạo model
|
| 21 |
+
self.model = genai.GenerativeModel(
|
| 22 |
+
model_name=model_name,
|
| 23 |
+
generation_config=self.generation_config,
|
| 24 |
+
system_instruction=self.system_prompt
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def generate(self, prompt="Hello, world!", system_prompt=None):
|
| 28 |
+
# Sử dụng system prompt tùy chỉnh nếu được cung cấp
|
| 29 |
+
if system_prompt:
|
| 30 |
+
model = genai.GenerativeModel(
|
| 31 |
+
model_name=self.model.model_name,
|
| 32 |
+
generation_config=self.generation_config,
|
| 33 |
+
system_instruction=system_prompt
|
| 34 |
+
)
|
| 35 |
+
response = model.generate_content(prompt)
|
| 36 |
+
else:
|
| 37 |
+
response = self.model.generate_content(prompt)
|
| 38 |
+
return response
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
generator = GeminiGenerator()
|
app/models/knowledge_graph.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from fastapi import Depends
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 6 |
+
from app.core.config import Settings, get_settings
|
| 7 |
+
from utils.data_mapping import DataMapping
|
| 8 |
+
from utils.extract_entity import extract_entities
|
| 9 |
+
from core.type import Node
|
| 10 |
+
from neo4j import GraphDatabase
|
| 11 |
+
from utils.constant import NEO4J_LABELS, NEO4J_RELATIONS
|
| 12 |
+
|
| 13 |
+
NEO4J_URI = os.getenv("NEO4J_URI", "neo4j://localhost:7687")
|
| 14 |
+
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
| 15 |
+
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
|
| 16 |
+
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE", "neo4j")
|
| 17 |
+
|
| 18 |
+
class Neo4jConnection:
|
| 19 |
+
def __init__(self):
|
| 20 |
+
"""Khởi tạo kết nối tới Neo4j"""
|
| 21 |
+
self.uri = NEO4J_URI
|
| 22 |
+
self.user = NEO4J_USER
|
| 23 |
+
self.password = NEO4J_PASSWORD
|
| 24 |
+
self.database = NEO4J_DATABASE
|
| 25 |
+
|
| 26 |
+
self.driver = GraphDatabase.driver(
|
| 27 |
+
self.uri,
|
| 28 |
+
auth=(self.user, self.password),
|
| 29 |
+
database=self.database
|
| 30 |
+
)
|
| 31 |
+
self.entity_types = []
|
| 32 |
+
self.relations = []
|
| 33 |
+
|
| 34 |
+
with self.driver.session() as session:
|
| 35 |
+
result = session.run("CALL db.info()")
|
| 36 |
+
self.database_info = result.single().data()
|
| 37 |
+
self.entity_types = NEO4J_LABELS
|
| 38 |
+
self.relations = NEO4J_RELATIONS
|
| 39 |
+
|
| 40 |
+
def get_database_info(self):
|
| 41 |
+
"""Trả về thông tin về database đang kết nối"""
|
| 42 |
+
return self.database_info
|
| 43 |
+
|
| 44 |
+
def close(self):
|
| 45 |
+
"""Đóng kết nối tới Neo4j"""
|
| 46 |
+
if self.driver is not None:
|
| 47 |
+
self.driver.close()
|
| 48 |
+
|
| 49 |
+
def execute_query(self, query, parameters=None):
|
| 50 |
+
"""Thực thi một truy vấn Cypher bất kỳ"""
|
| 51 |
+
with self.driver.session() as session:
|
| 52 |
+
result = session.run(query, parameters)
|
| 53 |
+
return [record for record in result]
|
| 54 |
+
|
| 55 |
+
class KnowledgeGraphUtils:
|
| 56 |
+
def get_disease_from_env_factors(self, crop_id: str, params: list[Node]):
|
| 57 |
+
envFactors = [param.id for param in params if param.label == "EnvironmentalFactor"]
|
| 58 |
+
query = f"""
|
| 59 |
+
MATCH (c:Crop {{id: "{crop_id}"}})
|
| 60 |
+
WITH c
|
| 61 |
+
MATCH (d:Disease)-[:AFFECTS]-(c)
|
| 62 |
+
OPTIONAL MATCH (ef:EnvironmentalFactor)-[:FAVORS]-(d)
|
| 63 |
+
WHERE ef.id IN {envFactors}
|
| 64 |
+
OPTIONAL MATCH (ef2:EnvironmentalFactor)-[:FAVORS]-(cause:Cause)-[:CAUSES|AFFECTS]-(d)
|
| 65 |
+
WHERE ef2.id IN {envFactors}
|
| 66 |
+
WITH d, COLLECT(DISTINCT ef.id) AS direct_env, COLLECT(DISTINCT ef2.id) AS indirect_env
|
| 67 |
+
WHERE SIZE(direct_env) > 0 OR SIZE(indirect_env) > 0
|
| 68 |
+
RETURN DISTINCT d, direct_env, indirect_env
|
| 69 |
+
"""
|
| 70 |
+
kg = Neo4jConnection()
|
| 71 |
+
result = kg.execute_query(query)
|
| 72 |
+
print(result)
|
| 73 |
+
final_result = []
|
| 74 |
+
for record in result:
|
| 75 |
+
record_dict = dict(record)
|
| 76 |
+
disease = Node.map_json_to_node(dict(record_dict["d"]), "Disease")
|
| 77 |
+
env_ids = list(record_dict["direct_env"]) + list(record_dict["indirect_env"])
|
| 78 |
+
print(env_ids)
|
| 79 |
+
score = 0
|
| 80 |
+
for env_id in env_ids:
|
| 81 |
+
for param in params:
|
| 82 |
+
if param.id == env_id:
|
| 83 |
+
score = max(score, param.score)
|
| 84 |
+
disease.score = score
|
| 85 |
+
final_result.append({
|
| 86 |
+
"disease": disease,
|
| 87 |
+
"env_ids": env_ids
|
| 88 |
+
})
|
| 89 |
+
final_result.sort(key=lambda x: x["disease"].score, reverse=True)
|
| 90 |
+
|
| 91 |
+
return final_result
|
| 92 |
+
|
| 93 |
+
def get_disease_from_symptoms(self, crop_id: str, params: list[Node]) -> list:
|
| 94 |
+
symptoms = [param.id for param in params if param.label == "Symptom"]
|
| 95 |
+
query = f"""
|
| 96 |
+
MATCH (c:Crop {{id: "{crop_id}"}})
|
| 97 |
+
WITH c
|
| 98 |
+
MATCH (d:Disease)-[:AFFECTS]-(c)
|
| 99 |
+
OPTIONAL MATCH (sym1:Symptom)-[:HAS_SYMPTOM]-(d)
|
| 100 |
+
WHERE sym1.id IN {symptoms}
|
| 101 |
+
OPTIONAL MATCH (sym2:Symptom)-[:HAS_SYMPTOM|LOCATED_ON]-(p:PlantPart)-[:CONTAINS]-(d)
|
| 102 |
+
WHERE sym2.id IN {symptoms}
|
| 103 |
+
WITH d, p, c, sym1, sym2, COLLECT(DISTINCT sym1.id) AS direct_env, COLLECT(DISTINCT sym2.id) AS indirect_env
|
| 104 |
+
WHERE SIZE(direct_env) > 0 OR SIZE(indirect_env) > 0
|
| 105 |
+
RETURN d, c, p, sym1, sym2
|
| 106 |
+
"""
|
| 107 |
+
kg = Neo4jConnection()
|
| 108 |
+
result = kg.execute_query(query)
|
| 109 |
+
final_result = []
|
| 110 |
+
for record in result:
|
| 111 |
+
record_dict = dict(record)
|
| 112 |
+
disease = Node.map_json_to_node(dict(record_dict["d"]), "Disease")
|
| 113 |
+
symptom_ids = list(record_dict["sym1"]) + list(record_dict["sym2"])
|
| 114 |
+
score = 0
|
| 115 |
+
for symptom_id in symptom_ids:
|
| 116 |
+
for param in params:
|
| 117 |
+
if param.id == symptom_id:
|
| 118 |
+
score = max(score, param.score)
|
| 119 |
+
disease.score = score
|
| 120 |
+
final_result.append({
|
| 121 |
+
"disease": disease,
|
| 122 |
+
"symptom_ids": symptom_ids
|
| 123 |
+
})
|
| 124 |
+
final_result.sort(key=lambda x: x["disease"].score, reverse=True)
|
| 125 |
+
|
| 126 |
+
return final_result
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/predict.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Depends, UploadFile
|
| 2 |
+
import torch
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from app.api.dto.kg_query import KGQueryRequest, QueryContext
|
| 7 |
+
from app.core.dependencies import get_all_models
|
| 8 |
+
from app.core.type import Node
|
| 9 |
+
from app.models.crop_clip import CLIPModule
|
| 10 |
+
from app.models.knowledge_graph import KnowledgeGraphUtils
|
| 11 |
+
from app.utils.data_mapping import DataMapping
|
| 12 |
+
from app.utils.extract_entity import extract_entities
|
| 13 |
+
|
| 14 |
+
class PredictService:
|
| 15 |
+
def __init__(self, models):
|
| 16 |
+
self.models = models
|
| 17 |
+
|
| 18 |
+
def predict_image(self, image: UploadFile):
|
| 19 |
+
clip_model: CLIPModule = self.models["clip_model"]
|
| 20 |
+
image_content = image.file.read()
|
| 21 |
+
pil_image = Image.open(Image.io.BytesIO(image_content)).convert('RGB')
|
| 22 |
+
return clip_model.predict_image(pil_image)
|
| 23 |
+
|
| 24 |
+
def retrieve_kg(self, request: KGQueryRequest):
|
| 25 |
+
try:
|
| 26 |
+
kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
|
| 27 |
+
if not request.context:
|
| 28 |
+
request.context = QueryContext()
|
| 29 |
+
if request.crop_id:
|
| 30 |
+
request.context.crop_id = request.crop_id
|
| 31 |
+
if request.additional_info:
|
| 32 |
+
request.context.nodes = self.__get_nodes_from_additional_info(request.additional_info, self.models["data_mapper"])
|
| 33 |
+
env_result = kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
|
| 34 |
+
symptom_result = kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
|
| 35 |
+
context = request.context
|
| 36 |
+
context.nodes.extend([env_result["disease"] for env_result in env_result])
|
| 37 |
+
context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_result])
|
| 38 |
+
context.nodes.sort(key=lambda x: x.score, reverse=True)
|
| 39 |
+
return {
|
| 40 |
+
"context": context,
|
| 41 |
+
"env_result": env_result,
|
| 42 |
+
"symptom_result": symptom_result
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(e)
|
| 47 |
+
raise e
|
| 48 |
+
|
| 49 |
+
def __get_nodes_from_additional_info(self, additional_info: str, data_mapper: DataMapping):
|
| 50 |
+
entities = extract_entities(additional_info)
|
| 51 |
+
top_results: list[Node] = []
|
| 52 |
+
for entity in entities:
|
| 53 |
+
top_result = data_mapper.get_top_result_by_text(entity.name, 3)
|
| 54 |
+
print([result.name for result in top_result])
|
| 55 |
+
for result in top_result:
|
| 56 |
+
top_results.append(result)
|
| 57 |
+
return top_results
|
| 58 |
+
|
| 59 |
+
def get_predict_service(models = Depends(get_all_models)):
|
| 60 |
+
return PredictService(models)
|
app/utils/constant.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NEO4J_LABELS =['Disease', 'Symptom', 'Treatment', 'Cause', 'Effect', 'Prevention', 'EnvironmentalFactor', 'Stage', 'Crop', 'CropType', 'PlantPart', 'SoilType', 'DiagnosisMethod']
|
| 2 |
+
NEO4J_RELATIONS = ['CAUSES', 'HAS_SYMPTOM', 'PRODUCES', 'FAVORS', 'IS_TREATED_BY', 'PREVENTS', 'OCCURS_AT', 'BELONGS_TO', 'CONTAINS', 'LOCATED_ON', 'AFFECTS', 'IS_APPLIED_TO']
|
app/utils/data_mapping.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
import faiss
|
| 4 |
+
from pyvi.ViTokenizer import tokenize
|
| 5 |
+
import sqlite3
|
| 6 |
+
from app.core.type import Node
|
| 7 |
+
|
| 8 |
+
FAISS_INDEX_PATH = 'app/data/faiss_index.index'
|
| 9 |
+
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
|
| 10 |
+
|
| 11 |
+
class SingletonModel:
|
| 12 |
+
_instance = None
|
| 13 |
+
|
| 14 |
+
def __new__(cls):
|
| 15 |
+
if cls._instance is None:
|
| 16 |
+
cls._instance = super(SingletonModel, cls).__new__(cls)
|
| 17 |
+
cls._instance.model = SentenceTransformer('dangvantuan/vietnamese-embedding')
|
| 18 |
+
return cls._instance
|
| 19 |
+
|
| 20 |
+
class DataMapping:
|
| 21 |
+
def __init__(self):
|
| 22 |
+
try:
|
| 23 |
+
self.model: SentenceTransformer = SingletonModel().model
|
| 24 |
+
self.index: faiss.IndexFlatL2 = self.__load_faiss_index()
|
| 25 |
+
self.conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
|
| 26 |
+
self.cursor = self.conn.cursor()
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Error while initializing DataMapping: {e}")
|
| 29 |
+
raise
|
| 30 |
+
|
| 31 |
+
def __del__(self):
|
| 32 |
+
self.cursor.close()
|
| 33 |
+
self.conn.close()
|
| 34 |
+
|
| 35 |
+
def __load_faiss_index(self, index_file = FAISS_INDEX_PATH):
|
| 36 |
+
if os.path.exists(index_file):
|
| 37 |
+
index = faiss.read_index(index_file)
|
| 38 |
+
print(f"Đã nạp FAISS index từ {index_file}")
|
| 39 |
+
return index
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
def get_top_index_by_text(self, text, top_k=1, distance_threshold=float(0.6)):
|
| 43 |
+
if not text or top_k < 1:
|
| 44 |
+
raise ValueError("Invalid input: text cannot be empty and top_k must be positive")
|
| 45 |
+
|
| 46 |
+
q_token = tokenize(text)
|
| 47 |
+
q_vec = self.model.encode([q_token])
|
| 48 |
+
faiss.normalize_L2(q_vec)
|
| 49 |
+
D, I = self.index.search(q_vec, top_k)
|
| 50 |
+
mask = D[0] >= distance_threshold
|
| 51 |
+
filtered_indices = I[0][mask].tolist()
|
| 52 |
+
distances = D[0][mask].tolist()
|
| 53 |
+
return filtered_indices, distances
|
| 54 |
+
|
| 55 |
+
def get_embedding_by_id(self, id):
|
| 56 |
+
self.cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
|
| 57 |
+
return self.cursor.fetchone()
|
| 58 |
+
|
| 59 |
+
def get_top_result_by_text(self, text, top_k = 1, type = None) -> list[Node]:
|
| 60 |
+
top_index, distances = self.get_top_index_by_text(text, top_k)
|
| 61 |
+
results = [self.get_embedding_by_id(int(index)) for index in top_index]
|
| 62 |
+
if type:
|
| 63 |
+
results = [result for result in results if result[3] == type]
|
| 64 |
+
return [Node.data_row_to_node(result, distance) for result, distance in zip(results, distances)]
|
| 65 |
+
|
app/utils/extract_entity.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from string import Template
|
| 3 |
+
from fastapi import Depends
|
| 4 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 5 |
+
import dotenv
|
| 6 |
+
import os
|
| 7 |
+
from app.models.gemini_caller import GeminiGenerator
|
| 8 |
+
from app.core.type import Node
|
| 9 |
+
from app.utils.prompt import EXTRACT_ENTITIES_PROMPT
|
| 10 |
+
dotenv.load_dotenv()
|
| 11 |
+
|
| 12 |
+
def extract_entities(text: str) -> list[Node]:
|
| 13 |
+
try:
|
| 14 |
+
gemini = GeminiGenerator()
|
| 15 |
+
prompt = Template(EXTRACT_ENTITIES_PROMPT).substitute(ctext=text)
|
| 16 |
+
entities = gemini.generate(prompt)
|
| 17 |
+
entities = (json.loads(clean_text(entities.text)))["entities"]
|
| 18 |
+
return [Node.map_json_to_node(entity) for entity in entities]
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"Error while extract knowledge entities: {str(e)}")
|
| 21 |
+
return []
|
| 22 |
+
|
| 23 |
+
def clean_text(text: str):
|
| 24 |
+
text = text.replace("```json", "").replace("```", "")
|
| 25 |
+
return text
|
app/utils/prompt.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
EXTRACT_ENTITIES_PROMPT = """
|
| 2 |
+
Từ mô tả bên dưới, hãy trích xuất các Thực thể được mô tả theo định dạng được chỉ định. Đảm bảo kết quả hoàn chỉnh, không thiếu thông tin.
|
| 3 |
+
0. LUÔN LUÔN HOÀN TẤT KẾT QUẢ. Không gửi kết quả bị thiếu
|
| 4 |
+
1. Trích xuất Thực thể (entities)
|
| 5 |
+
- Mỗi thực thể phải có thuộc tính `id` là chuỗi chữ và số duy nhất, định dạng **camelCase** (ví dụ: `benhDaoOn`, `laVang`). Thuộc tính `id` được sử dụng để liên kết trong mối quan hệ.
|
| 6 |
+
- Chỉ tạo thực thể thuộc các loại được liệt kê, không tạo loại mới.
|
| 7 |
+
- Đảm bảo các thuộc tính (`name`, `description`, v.v.) khớp với nội dung văn bản.
|
| 8 |
+
|
| 9 |
+
Các loại thực thể:
|
| 10 |
+
- **Disease**: Tình trạng cây bị hại bởi vi sinh vật, nấm, hoặc yếu tố môi trường. Đảm bảo trong ngữ cảnh đầu vào chỉ đang nói đến một bệnh duy nhất.
|
| 11 |
+
- `id`: Tên bệnh ở dạng camelCase, bao gồm thông tin cây trồng nếu bệnh xuất hiện trên cây cụ thể (ví dụ: `benhDomNauSan` cho bệnh đốm nâu trên sắn, `benhDomNauCaChua` cho bệnh đốm nâu trên cà chua). Không có các giới từ như "trên".
|
| 12 |
+
- `name`: Tên bệnh trong văn bản, bao gồm thông tin cây trồng nếu có (ví dụ: "Bệnh đốm nâu trên sắn"). Nếu không có thông tin cây trồng, sử dụng tên bệnh chung (ví dụ: "Bệnh đốm nâu").
|
| 13 |
+
- `description`: Mô tả tình trạng bệnh, ưu tiên đề cập cây trồng nếu có (ví dụ: "Bệnh đốm nâu trên cây sắn do nấm gây ra"). Nếu không có thông tin, dùng "Không có mô tả cụ thể".
|
| 14 |
+
|
| 15 |
+
- **Symptom**: Dấu hiệu bất thường trên cây (lá vàng, héo, đốm).
|
| 16 |
+
- `id`: Tên triệu chứng ở dạng camelCase (ví dụ: `laVang`).
|
| 17 |
+
- `name`: Tên triệu chứng trong văn bản (ví dụ: "Lá vàng").
|
| 18 |
+
- `description`: Mô tả triệu chứng.
|
| 19 |
+
|
| 20 |
+
- **Treatment**: Biện pháp kiểm soát bệnh/sâu hại (thuốc, sinh học).
|
| 21 |
+
- `id`: Tên biện pháp ở dạng camelCase, gắn với hoạt chất hoặc loại thuốc cụ thể (ví dụ: `thuocDietNamThiophanate`).
|
| 22 |
+
- `name`: Tên biện pháp trong văn bản, phản ánh hoạt chất hoặc loại thuốc (ví dụ: "Thuốc Diệt Nấm chứa Thiophanate").
|
| 23 |
+
- `method`: Cách thực hiện biện pháp. (ví dụ: "Phun thuốc lên lá")
|
| 24 |
+
- `activeIngredient` (tùy chọn): Tên hoạt chất chính, bao gồm nồng độ nếu có (ví dụ: "Thiophanate 0.20%"). Nếu không xác định, để trống.
|
| 25 |
+
|
| 26 |
+
- **Cause**: Tác nhân gây bệnh/sâu hại (nấm, virus, côn trùng).
|
| 27 |
+
- `id`: Tên tác nhân ở dạng camelCase, gắn liền với tên của tác nhân viết gọn (ví dụ: `namMHenningsii`).
|
| 28 |
+
- `name`: Tên tác nhân trong văn bản (ví dụ: "Nấm Mycosphaerella henningsi")
|
| 29 |
+
- `type`: Loại tác nhân (nấm, virus, côn trùng, vi khuẩn, ...).
|
| 30 |
+
|
| 31 |
+
- **Effect**: Tác động của bệnh/sâu hại (giảm năng suất, cây chết).
|
| 32 |
+
- `id`: Tên tác động ở dạng camelCase, sử dụng dạng ngắn gọn và chung nhất (ví dụ: `giamNangSuat` cho mọi trường hợp liên quan đến giảm năng suất, thay vì `nangSuatGiamDangKe`).
|
| 33 |
+
- `name`: Tên tác động được chuẩn hóa, sử dụng dạng chung nhất từ văn bản (ví dụ: "Giảm năng suất" thay vì "Giảm năng suất đáng kể"). Loại bỏ các từ ngữ bổ nghĩa như "đáng kể", "nghiêm trọng".
|
| 34 |
+
- `impact`: Mô tả ngắn gọn mức độ ảnh hưởng, ưu tiên sử dụng cụm từ chung (ví dụ: "Ảnh hưởng đến sản lượng" thay vì sao chép toàn bộ mô tả chi tiết từ văn bản).
|
| 35 |
+
|
| 36 |
+
- **Prevention**: Biện pháp ngăn ngừa bệnh/sâu hại (luân canh, giống kháng).
|
| 37 |
+
- `id`: Tên biện pháp ở dạng camelCase (ví dụ: `luanCanh`).
|
| 38 |
+
- `name`: Tên biện pháp trong văn bản.
|
| 39 |
+
- `method`: Cách thực hiện biện pháp.
|
| 40 |
+
|
| 41 |
+
- **EnvironmentalFactor**: Yếu tố tự nhiên ảnh hưởng cây (nhiệt độ, độ ẩm, ...).
|
| 42 |
+
- `id`: Tên yếu tố ở dạng camelCase (ví dụ: `doAmCao`).
|
| 43 |
+
- `name`: Tên yếu tố trong văn bản.
|
| 44 |
+
- `description`: Mô tả yếu tố.
|
| 45 |
+
|
| 46 |
+
- **Stage**: Giai đoạn phát triển của cây.
|
| 47 |
+
- `id`: Tên giai đoạn ở dạng camelCase (ví dụ: `giaiDoanRaHoa`).
|
| 48 |
+
- `start`: Thời gian bắt đầu (tháng, kiểu float).
|
| 49 |
+
- `end`: Thời gian kết thúc (tháng, kiểu float).
|
| 50 |
+
|
| 51 |
+
- **Crop**: Cây trồng, không được tạo ngoài danh sách: "Lúa", "Sắn", "Cà chua", "Ngô"
|
| 52 |
+
- `id`: Tên cây ở dạng camelCase, chỉ nằm trong danh sách: [lua,san,caChua,ngo].
|
| 53 |
+
- `name`: Tên cây trong văn bản (ví dụ: "Sắn").
|
| 54 |
+
|
| 55 |
+
- **CropType**: Phân loại cây (lương thực, ăn quả, công nghiệp).
|
| 56 |
+
- `id`: Tên loại cây ở dạng camelCase (ví dụ: `luongThuc`).
|
| 57 |
+
- `name`: Tên loại cây trong văn bản.
|
| 58 |
+
|
| 59 |
+
- **PlantPart**: Phần cây bị ảnh hưởng (lá, thân, rễ, quả).
|
| 60 |
+
- `id`: Tên phần cây ở dạng camelCase (ví dụ: `la`).
|
| 61 |
+
- `name`: Tên phần cây trong văn bản.
|
| 62 |
+
|
| 63 |
+
- **SoilType**: Loại đất trồng cây.
|
| 64 |
+
- `id`: Tên loại đất ở dạng camelCase (ví dụ: `datPhuSa`).
|
| 65 |
+
- `name`: Tên loại đất trong văn bản.
|
| 66 |
+
|
| 67 |
+
- **DiagnosisMethod**: Kỹ thuật xác định bệnh/sâu hại (quan sát, xét nghiệm).
|
| 68 |
+
- `id`: Tên kỹ thuật ở dạng camelCase (ví dụ: `quanSat`).
|
| 69 |
+
- `name`: Tên kỹ thuật trong văn bản.
|
| 70 |
+
- `technique`: Cách thực hiện kỹ thuật.
|
| 71 |
+
|
| 72 |
+
2. Trả về kết quả dưới dạng JSON:
|
| 73 |
+
- Trả về JSON với một trường duy nhất là `entities`
|
| 74 |
+
- `entities`: Danh sách các thực thể, mỗi thực thể là một object với các thuộc tính theo loại.
|
| 75 |
+
|
| 76 |
+
Ví dụ:
|
| 77 |
+
```json
|
| 78 |
+
{
|
| 79 |
+
"entities": [{"label":"Disease","id":string,"name":string,"description":string}]
|
| 80 |
+
}
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
Ngữ cảnh:
|
| 84 |
+
$ctext
|
| 85 |
+
"""
|
environment.yml
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: graduated2
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
- https://repo.anaconda.com/pkgs/main
|
| 5 |
+
- https://repo.anaconda.com/pkgs/r
|
| 6 |
+
dependencies:
|
| 7 |
+
- ca-certificates=2025.2.25=hecd8cb5_0
|
| 8 |
+
- libcxx=14.0.6=h9765a3e_0
|
| 9 |
+
- libffi=3.4.4=hecd8cb5_1
|
| 10 |
+
- ncurses=6.4=hcec6c5f_0
|
| 11 |
+
- openssl=3.0.16=h184c1cd_0
|
| 12 |
+
- pip=25.1=pyhc872135_2
|
| 13 |
+
- python=3.9.21=hce00570_1
|
| 14 |
+
- readline=8.2=hca72f7f_0
|
| 15 |
+
- setuptools=78.1.1=py39hecd8cb5_0
|
| 16 |
+
- sqlite=3.45.3=h6c40b1e_0
|
| 17 |
+
- tk=8.6.14=h4d00af3_0
|
| 18 |
+
- tzdata=2025b=h04d1e81_0
|
| 19 |
+
- wheel=0.45.1=py39hecd8cb5_0
|
| 20 |
+
- xz=5.6.4=h46256e1_1
|
| 21 |
+
- zlib=1.2.13=h4b97444_1
|
| 22 |
+
- pip:
|
| 23 |
+
- aiohappyeyeballs==2.6.1
|
| 24 |
+
- aiohttp==3.11.18
|
| 25 |
+
- aiosignal==1.3.2
|
| 26 |
+
- annotated-types==0.7.0
|
| 27 |
+
- anyio==4.9.0
|
| 28 |
+
- asgiref==3.8.1
|
| 29 |
+
- async-timeout==4.0.3
|
| 30 |
+
- attrs==25.3.0
|
| 31 |
+
- backoff==2.2.1
|
| 32 |
+
- bcrypt==4.3.0
|
| 33 |
+
- build==1.2.2.post1
|
| 34 |
+
- cachetools==5.5.2
|
| 35 |
+
- certifi==2025.4.26
|
| 36 |
+
- charset-normalizer==3.4.2
|
| 37 |
+
- chromadb==1.0.8
|
| 38 |
+
- click==8.1.8
|
| 39 |
+
- coloredlogs==15.0.1
|
| 40 |
+
- dataclasses-json==0.6.7
|
| 41 |
+
- deprecated==1.2.18
|
| 42 |
+
- distro==1.9.0
|
| 43 |
+
- durationpy==0.9
|
| 44 |
+
- exceptiongroup==1.3.0
|
| 45 |
+
- fastapi==0.115.9
|
| 46 |
+
- filelock==3.18.0
|
| 47 |
+
- filetype==1.2.0
|
| 48 |
+
- flatbuffers==25.2.10
|
| 49 |
+
- frozenlist==1.6.0
|
| 50 |
+
- fsspec==2024.12.0
|
| 51 |
+
- google-ai-generativelanguage==0.6.18
|
| 52 |
+
- google-api-core==2.24.2
|
| 53 |
+
- google-auth==2.40.1
|
| 54 |
+
- googleapis-common-protos==1.70.0
|
| 55 |
+
- greenlet==3.2.2
|
| 56 |
+
- grpcio==1.72.0rc1
|
| 57 |
+
- grpcio-status==1.72.0rc1
|
| 58 |
+
- h11==0.16.0
|
| 59 |
+
- hf-xet==1.1.0
|
| 60 |
+
- httpcore==1.0.9
|
| 61 |
+
- httptools==0.6.4
|
| 62 |
+
- httpx==0.28.1
|
| 63 |
+
- httpx-sse==0.4.0
|
| 64 |
+
- huggingface-hub==0.31.1
|
| 65 |
+
- humanfriendly==10.0
|
| 66 |
+
- idna==3.10
|
| 67 |
+
- importlib-metadata==8.6.1
|
| 68 |
+
- importlib-resources==6.5.2
|
| 69 |
+
- jinja2==3.1.6
|
| 70 |
+
- jiter==0.10.0
|
| 71 |
+
- json-repair==0.39.1
|
| 72 |
+
- jsonpatch==1.33
|
| 73 |
+
- jsonpointer==3.0.0
|
| 74 |
+
- jsonschema==4.23.0
|
| 75 |
+
- jsonschema-specifications==2025.4.1
|
| 76 |
+
- kubernetes==32.0.1
|
| 77 |
+
- langchain==0.3.25
|
| 78 |
+
- langchain-community==0.3.24
|
| 79 |
+
- langchain-core==0.3.60
|
| 80 |
+
- langchain-google-genai==2.1.4
|
| 81 |
+
- langchain-neo4j==0.4.0
|
| 82 |
+
- langchain-text-splitters==0.3.8
|
| 83 |
+
- langsmith==0.3.42
|
| 84 |
+
- markdown-it-py==3.0.0
|
| 85 |
+
- markupsafe==3.0.2
|
| 86 |
+
- marshmallow==3.26.1
|
| 87 |
+
- mdurl==0.1.2
|
| 88 |
+
- mmh3==5.1.0
|
| 89 |
+
- mpmath==1.3.0
|
| 90 |
+
- multidict==6.4.4
|
| 91 |
+
- mypy-extensions==1.1.0
|
| 92 |
+
- neo4j==5.28.1
|
| 93 |
+
- neo4j-graphrag==1.7.0
|
| 94 |
+
- networkx==3.2.1
|
| 95 |
+
- numpy==2.0.2
|
| 96 |
+
- oauthlib==3.2.2
|
| 97 |
+
- onnxruntime==1.19.2
|
| 98 |
+
- openai==1.79.0
|
| 99 |
+
- opentelemetry-api==1.33.0
|
| 100 |
+
- opentelemetry-exporter-otlp-proto-common==1.33.0
|
| 101 |
+
- opentelemetry-exporter-otlp-proto-grpc==1.33.0
|
| 102 |
+
- opentelemetry-instrumentation==0.54b0
|
| 103 |
+
- opentelemetry-instrumentation-asgi==0.54b0
|
| 104 |
+
- opentelemetry-instrumentation-fastapi==0.54b0
|
| 105 |
+
- opentelemetry-proto==1.33.0
|
| 106 |
+
- opentelemetry-sdk==1.33.0
|
| 107 |
+
- opentelemetry-semantic-conventions==0.54b0
|
| 108 |
+
- opentelemetry-util-http==0.54b0
|
| 109 |
+
- orjson==3.10.18
|
| 110 |
+
- overrides==7.7.0
|
| 111 |
+
- packaging==24.2
|
| 112 |
+
- pillow==11.2.1
|
| 113 |
+
- posthog==4.0.1
|
| 114 |
+
- propcache==0.3.1
|
| 115 |
+
- proto-plus==1.26.1
|
| 116 |
+
- protobuf==6.31.0
|
| 117 |
+
- pyasn1==0.6.1
|
| 118 |
+
- pyasn1-modules==0.4.2
|
| 119 |
+
- pydantic==2.11.4
|
| 120 |
+
- pydantic-core==2.33.2
|
| 121 |
+
- pydantic-settings==2.9.1
|
| 122 |
+
- pygments==2.19.1
|
| 123 |
+
- pypdf==5.5.0
|
| 124 |
+
- pypika==0.48.9
|
| 125 |
+
- pyproject-hooks==1.2.0
|
| 126 |
+
- python-dateutil==2.9.0.post0
|
| 127 |
+
- python-dotenv==1.1.0
|
| 128 |
+
- pytz==2025.2
|
| 129 |
+
- pyyaml==6.0.2
|
| 130 |
+
- referencing==0.36.2
|
| 131 |
+
- regex==2024.11.6
|
| 132 |
+
- requests==2.32.3
|
| 133 |
+
- requests-oauthlib==2.0.0
|
| 134 |
+
- requests-toolbelt==1.0.0
|
| 135 |
+
- rich==14.0.0
|
| 136 |
+
- rpds-py==0.24.0
|
| 137 |
+
- rsa==4.9.1
|
| 138 |
+
- safetensors==0.5.3
|
| 139 |
+
- shellingham==1.5.4
|
| 140 |
+
- six==1.17.0
|
| 141 |
+
- sniffio==1.3.1
|
| 142 |
+
- sqlalchemy==2.0.41
|
| 143 |
+
- starlette==0.45.3
|
| 144 |
+
- sympy==1.14.0
|
| 145 |
+
- tenacity==9.1.2
|
| 146 |
+
- tokenizers==0.21.1
|
| 147 |
+
- tomli==2.2.1
|
| 148 |
+
- torch==2.2.2
|
| 149 |
+
- torchvision==0.17.2
|
| 150 |
+
- tqdm==4.67.1
|
| 151 |
+
- transformers==4.51.3
|
| 152 |
+
- typer==0.15.3
|
| 153 |
+
- types-pyyaml==6.0.12.20250516
|
| 154 |
+
- typing-extensions==4.13.2
|
| 155 |
+
- typing-inspect==0.9.0
|
| 156 |
+
- typing-inspection==0.4.0
|
| 157 |
+
- urllib3==2.4.0
|
| 158 |
+
- uvicorn==0.34.2
|
| 159 |
+
- uvloop==0.21.0
|
| 160 |
+
- watchfiles==1.0.5
|
| 161 |
+
- websocket-client==1.8.0
|
| 162 |
+
- websockets==15.0.1
|
| 163 |
+
- wrapt==1.17.2
|
| 164 |
+
- yarl==1.20.0
|
| 165 |
+
- zipp==3.21.0
|
| 166 |
+
- zstandard==0.23.0
|
| 167 |
+
prefix: /Users/artteiv/miniconda3/envs/graduated2
|
prepare_script/image_caption_embeddings.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import clip
|
| 9 |
+
import faiss
|
| 10 |
+
import numpy as np
|
| 11 |
+
import glob
|
| 12 |
+
|
| 13 |
+
# Đường dẫn lưu trữ
|
| 14 |
+
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
|
| 15 |
+
IMAGE_FAISS_INDEX_PATH = 'app/data/image_faiss_index.index'
|
| 16 |
+
TEXT_FAISS_INDEX_PATH = 'app/data/text_faiss_index.index'
|
| 17 |
+
|
| 18 |
+
# Đường dẫn dữ liệu
|
| 19 |
+
DATA_ROOT = '/Users/artteiv/Desktop/Graduated/chore-graduated/Data'
|
| 20 |
+
MAIN_DATA_PATH = os.path.join(DATA_ROOT, 'main_data')
|
| 21 |
+
CAPTIONS_PATH = os.path.join(DATA_ROOT, 'captions')
|
| 22 |
+
|
| 23 |
+
# Kết nối SQLite
|
| 24 |
+
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
|
| 25 |
+
cursor = conn.cursor()
|
| 26 |
+
|
| 27 |
+
# Tạo bảng embeddings cho ảnh và văn bản
|
| 28 |
+
cursor.execute('''
|
| 29 |
+
CREATE TABLE IF NOT EXISTS image_embeddings (
|
| 30 |
+
e_index INTEGER PRIMARY KEY,
|
| 31 |
+
image_path TEXT NOT NULL,
|
| 32 |
+
caption TEXT NOT NULL,
|
| 33 |
+
category TEXT NOT NULL,
|
| 34 |
+
subcategory TEXT NOT NULL
|
| 35 |
+
)
|
| 36 |
+
''')
|
| 37 |
+
|
| 38 |
+
cursor.execute('''
|
| 39 |
+
CREATE TABLE IF NOT EXISTS text_embeddings (
|
| 40 |
+
e_index INTEGER PRIMARY KEY,
|
| 41 |
+
text TEXT NOT NULL,
|
| 42 |
+
category TEXT NOT NULL,
|
| 43 |
+
subcategory TEXT NOT NULL
|
| 44 |
+
)
|
| 45 |
+
''')
|
| 46 |
+
|
| 47 |
+
def insert_image_embedding(e_index, image_path, caption, category, subcategory):
|
| 48 |
+
"""Thêm embedding ảnh vào SQLite."""
|
| 49 |
+
cursor.execute('''
|
| 50 |
+
INSERT INTO image_embeddings (e_index, image_path, caption, category, subcategory)
|
| 51 |
+
VALUES (?, ?, ?, ?, ?)
|
| 52 |
+
''', (e_index, image_path, caption, category, subcategory))
|
| 53 |
+
conn.commit()
|
| 54 |
+
print(f"Đã thêm embedding ảnh: {image_path}")
|
| 55 |
+
|
| 56 |
+
def insert_text_embedding(e_index, text, category, subcategory):
|
| 57 |
+
"""Thêm embedding văn bản vào SQLite."""
|
| 58 |
+
cursor.execute('''
|
| 59 |
+
INSERT INTO text_embeddings (e_index, text, category, subcategory)
|
| 60 |
+
VALUES (?, ?, ?, ?)
|
| 61 |
+
''', (e_index, text, category, subcategory))
|
| 62 |
+
conn.commit()
|
| 63 |
+
print(f"Đã thêm embedding văn bản: {text[:50]}...")
|
| 64 |
+
|
| 65 |
+
def save_faiss_index(index, index_file):
|
| 66 |
+
"""Lưu FAISS index vào file."""
|
| 67 |
+
faiss.write_index(index, index_file)
|
| 68 |
+
print(f"Đã lưu FAISS index vào {index_file}")
|
| 69 |
+
|
| 70 |
+
def load_faiss_index(index_file):
|
| 71 |
+
"""Nạp FAISS index từ file."""
|
| 72 |
+
if os.path.exists(index_file):
|
| 73 |
+
index = faiss.read_index(index_file)
|
| 74 |
+
print(f"Đã nạp FAISS index từ {index_file}")
|
| 75 |
+
return index
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
def compute_embeddings():
|
| 79 |
+
"""Tính toán embeddings cho ảnh và văn bản sử dụng CLIP."""
|
| 80 |
+
print("Loading CLIP model...")
|
| 81 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 82 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
| 83 |
+
print("Model loaded")
|
| 84 |
+
|
| 85 |
+
# Lấy danh sách các thư mục con (categories)
|
| 86 |
+
categories = [d for d in os.listdir(MAIN_DATA_PATH) if os.path.isdir(os.path.join(MAIN_DATA_PATH, d))]
|
| 87 |
+
|
| 88 |
+
image_paths = []
|
| 89 |
+
captions = []
|
| 90 |
+
texts = []
|
| 91 |
+
categories_list = []
|
| 92 |
+
subcategories_list = []
|
| 93 |
+
|
| 94 |
+
# Chuẩn bị dữ liệu
|
| 95 |
+
print("Processing data from directories...")
|
| 96 |
+
for category in categories:
|
| 97 |
+
# Đường dẫn đến thư mục category
|
| 98 |
+
category_path = os.path.join(MAIN_DATA_PATH, category)
|
| 99 |
+
|
| 100 |
+
# Lấy danh sách các subcategories
|
| 101 |
+
subcategories = [d for d in os.listdir(category_path) if os.path.isdir(os.path.join(category_path, d))]
|
| 102 |
+
|
| 103 |
+
for subcategory in subcategories:
|
| 104 |
+
# Đường dẫn đến thư mục ảnh và caption của subcategory
|
| 105 |
+
subcategory_image_path = os.path.join(category_path, subcategory)
|
| 106 |
+
subcategory_caption_path = os.path.join(CAPTIONS_PATH, category, subcategory)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Lấy danh sách ảnh
|
| 110 |
+
image_files = glob.glob(os.path.join(subcategory_image_path, '*.*'))
|
| 111 |
+
|
| 112 |
+
for img_path in image_files:
|
| 113 |
+
# Lấy tên file không có phần mở rộng
|
| 114 |
+
base_name = os.path.splitext(os.path.basename(img_path))[0]
|
| 115 |
+
caption_file = os.path.join(subcategory_caption_path, f"{base_name}.txt")
|
| 116 |
+
|
| 117 |
+
if os.path.exists(caption_file):
|
| 118 |
+
try:
|
| 119 |
+
# Đọc caption
|
| 120 |
+
with open(caption_file, 'r', encoding='utf-8') as f:
|
| 121 |
+
caption = f.read().strip()
|
| 122 |
+
|
| 123 |
+
# Thêm vào danh sách
|
| 124 |
+
image_paths.append(img_path)
|
| 125 |
+
captions.append(caption)
|
| 126 |
+
texts.append(caption) # Sử dụng caption làm text
|
| 127 |
+
categories_list.append(category)
|
| 128 |
+
subcategories_list.append(subcategory)
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"Error processing {img_path}: {e}")
|
| 132 |
+
continue
|
| 133 |
+
|
| 134 |
+
# Tính toán embeddings cho ảnh
|
| 135 |
+
# if image_paths:
|
| 136 |
+
# print("Computing image embeddings...")
|
| 137 |
+
# image_embeddings = []
|
| 138 |
+
# for idx, img_path in enumerate(image_paths):
|
| 139 |
+
# try:
|
| 140 |
+
# image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
|
| 141 |
+
# with torch.no_grad():
|
| 142 |
+
# image_features = model.encode_image(image)
|
| 143 |
+
# image_features = image_features.cpu().numpy()
|
| 144 |
+
# faiss.normalize_L2(image_features)
|
| 145 |
+
# image_embeddings.append(image_features[0])
|
| 146 |
+
# insert_image_embedding(idx, img_path, captions[idx], categories_list[idx], subcategories_list[idx])
|
| 147 |
+
# except Exception as e:
|
| 148 |
+
# print(f"Error processing image {img_path}: {e}")
|
| 149 |
+
# continue
|
| 150 |
+
|
| 151 |
+
# if image_embeddings:
|
| 152 |
+
# image_embeddings = np.array(image_embeddings)
|
| 153 |
+
# d = image_embeddings.shape[1]
|
| 154 |
+
# image_index = faiss.IndexFlatIP(d)
|
| 155 |
+
# image_index.add(image_embeddings)
|
| 156 |
+
# save_faiss_index(image_index, IMAGE_FAISS_INDEX_PATH)
|
| 157 |
+
|
| 158 |
+
# Tính toán embeddings cho văn bản
|
| 159 |
+
if texts:
|
| 160 |
+
print("Computing text embeddings...")
|
| 161 |
+
text_tokens = clip.tokenize(texts, truncate=True).to(device)
|
| 162 |
+
print("Kích thước của text_tokens:", text_tokens.shape)
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
text_features = model.encode_text(text_tokens)
|
| 165 |
+
text_features = text_features.cpu().numpy()
|
| 166 |
+
faiss.normalize_L2(text_features)
|
| 167 |
+
|
| 168 |
+
d = text_features.shape[1]
|
| 169 |
+
text_index = faiss.IndexFlatIP(d)
|
| 170 |
+
text_index.add(text_features)
|
| 171 |
+
|
| 172 |
+
# Lưu text embeddings vào SQLite
|
| 173 |
+
for idx, (text, category, subcategory) in enumerate(zip(texts, categories_list, subcategories_list)):
|
| 174 |
+
insert_text_embedding(idx, text, category, subcategory)
|
| 175 |
+
|
| 176 |
+
save_faiss_index(text_index, TEXT_FAISS_INDEX_PATH)
|
| 177 |
+
|
| 178 |
+
print("Processing completed")
|
| 179 |
+
return image_index if image_paths else None, text_index if texts else None
|
| 180 |
+
|
| 181 |
+
def predict_image(image_path):
|
| 182 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 183 |
+
model, preprocess = clip.load("ViT-B/32", device=device)
|
| 184 |
+
|
| 185 |
+
image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
image_features = model.encode_image(image)
|
| 188 |
+
image_features = image_features.cpu().numpy()
|
| 189 |
+
faiss.normalize_L2(image_features)
|
| 190 |
+
|
| 191 |
+
index = load_faiss_index(IMAGE_FAISS_INDEX_PATH)
|
| 192 |
+
distances, indices = index.search(image_features, k=10)
|
| 193 |
+
|
| 194 |
+
return distances, indices
|
| 195 |
+
|
| 196 |
+
if __name__ == '__main__':
|
| 197 |
+
## Predict
|
| 198 |
+
|
| 199 |
+
try:
|
| 200 |
+
image_index, text_index = compute_embeddings()
|
| 201 |
+
if image_index:
|
| 202 |
+
print(f"Image index ready with {image_index.ntotal} embeddings")
|
| 203 |
+
if text_index:
|
| 204 |
+
print(f"Text index ready with {text_index.ntotal} embeddings")
|
| 205 |
+
finally:
|
| 206 |
+
conn.close()
|
| 207 |
+
print("SQLite connection closed")
|
prepare_script/sync_neo4j_node.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sqlite3
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 6 |
+
from app.models.knowledge_graph import Neo4jConnection
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from pyvi.ViTokenizer import tokenize
|
| 9 |
+
import faiss
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Script này thực hiện lấy các entity từ neo4j từ xa về và tạo ra các data lưu trong sqlite, đồng thời tạo các embeddings
|
| 14 |
+
dựa trên từng row.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
# Kết nối SQLite
|
| 18 |
+
VECTOR_EMBEDDINGS_DB_PATH = 'app/data/vector_embeddings.db'
|
| 19 |
+
FAISS_INDEX_PATH = 'app/data/faiss_index.index'
|
| 20 |
+
|
| 21 |
+
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH)
|
| 22 |
+
cursor = conn.cursor()
|
| 23 |
+
|
| 24 |
+
# Tạo bảng embeddings nếu chưa tồn tại
|
| 25 |
+
cursor.execute('''
|
| 26 |
+
CREATE TABLE IF NOT EXISTS embeddings (
|
| 27 |
+
e_index INTEGER PRIMARY KEY,
|
| 28 |
+
id TEXT NOT NULL,
|
| 29 |
+
name TEXT NOT NULL,
|
| 30 |
+
label TEXT NOT NULL,
|
| 31 |
+
properties TEXT NOT NULL
|
| 32 |
+
)
|
| 33 |
+
''')
|
| 34 |
+
|
| 35 |
+
def insert_embedding(e_index, id, name, label, properties):
|
| 36 |
+
"""Thêm embedding vào SQLite."""
|
| 37 |
+
cursor.execute('''
|
| 38 |
+
INSERT INTO embeddings (e_index, id, name, label, properties)
|
| 39 |
+
VALUES (?, ?, ?, ?, ?)
|
| 40 |
+
''', (e_index, id, name, label, json.dumps(properties)))
|
| 41 |
+
conn.commit()
|
| 42 |
+
print(f"Đã thêm embedding: {name}")
|
| 43 |
+
|
| 44 |
+
def update_embedding(embedding_id, id, name, label, properties):
|
| 45 |
+
"""Cập nhật embedding trong SQLite."""
|
| 46 |
+
cursor.execute('''
|
| 47 |
+
UPDATE embeddings
|
| 48 |
+
SET id = ?, name = ?, label = ?, properties = ?
|
| 49 |
+
WHERE e_index = ?
|
| 50 |
+
''', (id, name, label, json.dumps(properties), embedding_id))
|
| 51 |
+
conn.commit()
|
| 52 |
+
print(f"Đã cập nhật embedding ID: {embedding_id}")
|
| 53 |
+
|
| 54 |
+
def get_all_embeddings():
|
| 55 |
+
"""Lấy tất cả embeddings từ SQLite."""
|
| 56 |
+
cursor.execute('SELECT * FROM embeddings')
|
| 57 |
+
return cursor.fetchall()
|
| 58 |
+
|
| 59 |
+
def get_embedding_by_id(embedding_id):
|
| 60 |
+
"""Lấy embedding theo e_index từ SQLite."""
|
| 61 |
+
cursor.execute('SELECT * FROM embeddings WHERE e_index = ?', (embedding_id,))
|
| 62 |
+
return cursor.fetchone()
|
| 63 |
+
|
| 64 |
+
def save_faiss_index(index, index_file=FAISS_INDEX_PATH):
|
| 65 |
+
"""Lưu FAISS index vào file."""
|
| 66 |
+
faiss.write_index(index, index_file)
|
| 67 |
+
print(f"Đã lưu FAISS index vào {index_file}")
|
| 68 |
+
|
| 69 |
+
def load_faiss_index(index_file=FAISS_INDEX_PATH):
|
| 70 |
+
"""Nạp FAISS index từ file."""
|
| 71 |
+
if os.path.exists(index_file):
|
| 72 |
+
index = faiss.read_index(index_file)
|
| 73 |
+
print(f"Đã nạp FAISS index từ {index_file}")
|
| 74 |
+
return index
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
def compute_and_save_embeddings(index_file=FAISS_INDEX_PATH):
|
| 78 |
+
"""Tính toán embeddings, lưu vào FAISS và đồng bộ metadata vào SQLite."""
|
| 79 |
+
print("Loading model...")
|
| 80 |
+
model = SentenceTransformer('dangvantuan/vietnamese-embedding')
|
| 81 |
+
print("Model loaded")
|
| 82 |
+
|
| 83 |
+
# Lấy dữ liệu từ Neo4j
|
| 84 |
+
neo4j = Neo4jConnection()
|
| 85 |
+
result = neo4j.execute_query("MATCH (n) RETURN n")
|
| 86 |
+
corpus = []
|
| 87 |
+
|
| 88 |
+
# Chuẩn bị corpus và lưu metadata vào SQLite
|
| 89 |
+
print("Processing Neo4j data and saving to SQLite...")
|
| 90 |
+
for index, record in enumerate(result):
|
| 91 |
+
print(record)
|
| 92 |
+
label = list(record["n"].labels)[0]
|
| 93 |
+
print(label)
|
| 94 |
+
embedding = dict(record["n"])
|
| 95 |
+
id = embedding.pop('id')
|
| 96 |
+
name = embedding.pop('name') if 'name' in embedding else id
|
| 97 |
+
properties = embedding
|
| 98 |
+
corpus.append(name)
|
| 99 |
+
|
| 100 |
+
# Kiểm tra và cập nhật/thêm vào SQLite
|
| 101 |
+
cursor.execute('SELECT e_index FROM embeddings WHERE e_index = ?', (index,))
|
| 102 |
+
existing = cursor.fetchone()
|
| 103 |
+
if existing:
|
| 104 |
+
update_embedding(index, id, name, label, properties)
|
| 105 |
+
else:
|
| 106 |
+
insert_embedding(index, id, name, label, properties)
|
| 107 |
+
|
| 108 |
+
# Tính toán embeddings
|
| 109 |
+
print("Tokenizing and encoding...")
|
| 110 |
+
tokenized = [tokenize(s) for s in corpus]
|
| 111 |
+
embeddings = model.encode(tokenized, show_progress_bar=False)
|
| 112 |
+
print("Encoding done")
|
| 113 |
+
|
| 114 |
+
# Chuẩn hóa embeddings
|
| 115 |
+
print("Normalizing...")
|
| 116 |
+
faiss.normalize_L2(embeddings)
|
| 117 |
+
print("Normalized")
|
| 118 |
+
|
| 119 |
+
# Tạo và lưu FAISS index
|
| 120 |
+
d = embeddings.shape[1]
|
| 121 |
+
index = faiss.IndexFlatIP(d)
|
| 122 |
+
index.add(embeddings)
|
| 123 |
+
save_faiss_index(index, index_file)
|
| 124 |
+
|
| 125 |
+
print("Processing completed")
|
| 126 |
+
return index, corpus, embeddings
|
| 127 |
+
|
| 128 |
+
def load_or_compute_embeddings(index_file=FAISS_INDEX_PATH):
|
| 129 |
+
"""Nạp hoặc tính toán embeddings và FAISS index."""
|
| 130 |
+
# Thử nạp FAISS index
|
| 131 |
+
index = load_faiss_index(index_file)
|
| 132 |
+
|
| 133 |
+
# Lấy corpus từ SQLite
|
| 134 |
+
embeddings_data = get_all_embeddings()
|
| 135 |
+
corpus = [row[2] for row in embeddings_data] # Lấy cột name
|
| 136 |
+
|
| 137 |
+
if index is None or not corpus:
|
| 138 |
+
print("No saved index or corpus found, computing new ones...")
|
| 139 |
+
index, corpus, embeddings = compute_and_save_embeddings(index_file)
|
| 140 |
+
else:
|
| 141 |
+
print("Loaded existing index and corpus")
|
| 142 |
+
|
| 143 |
+
return index, corpus
|
| 144 |
+
|
| 145 |
+
def get_qvec_by_text(model, text):
|
| 146 |
+
q_token = tokenize(text)
|
| 147 |
+
q_vec = model.encode([q_token])
|
| 148 |
+
faiss.normalize_L2(q_vec)
|
| 149 |
+
return q_vec
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
try:
|
| 153 |
+
index, corpus = load_or_compute_embeddings()
|
| 154 |
+
print(f"Index ready with {index.ntotal} embeddings, corpus size: {len(corpus)}")
|
| 155 |
+
model = SentenceTransformer('dangvantuan/vietnamese-embedding')
|
| 156 |
+
while True:
|
| 157 |
+
try:
|
| 158 |
+
query = input("Nhập câu truy vấn (nhấn Ctrl+C để thoát): ")
|
| 159 |
+
q_vec = get_qvec_by_text(model, query)
|
| 160 |
+
k = 1 # số kết quả cần lấy
|
| 161 |
+
D, I = index.search(q_vec, k)
|
| 162 |
+
print("Câu truy vấn:", query)
|
| 163 |
+
print(I[0][0])
|
| 164 |
+
print(type(I[0][0]))
|
| 165 |
+
print("Câu gần nhất:", get_embedding_by_id(int(I[0][0])), "(khoảng cách:", D[0][0], ")")
|
| 166 |
+
print("-" * 50)
|
| 167 |
+
except KeyboardInterrupt:
|
| 168 |
+
print("\nĐã dừng chương trình!")
|
| 169 |
+
break
|
| 170 |
+
finally:
|
| 171 |
+
conn.close()
|
| 172 |
+
print("SQLite connection closed")
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain>=0.3.23
|
| 2 |
+
neo4j>=5.28.1
|
| 3 |
+
python-dotenv>=1.0.1
|
| 4 |
+
fastapi>=0.115.12
|
| 5 |
+
uvicorn>=0.34.2
|
| 6 |
+
pydantic>=2.10.6
|
| 7 |
+
faiss-cpu>=1.11.0
|
| 8 |
+
|