fs-lawrisk/lawrisk/utils/ingest_lawrisk.py

101 lines
2.7 KiB
Python
Raw Permalink Normal View History

2025-10-22 19:59:48 +08:00
"""Ingest subjects, subject-permit mappings, and permits into PostgreSQL.

Reads the JSON export named by the LAWRISK_JSON environment variable
(default: risk_tables_export.json, resolved relative to the current working
directory), embeds each subject name using Aliyun DashScope embeddings
(OpenAI-compatible), and stores:

- law_sub(id TEXT PK, name TEXT, vector JSONB)
- law_sub_per(subject_id TEXT PK, permit_ids JSONB)
- the permit catalog (permit id -> name), via upsert_permits

Usage:
    python ingest_lawrisk.py

Ensure the env var DASHSCOPE_API_KEY is set.
Optionally set the PG_* vars (PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DATABASE).
"""
from __future__ import annotations
import json
import os
from typing import List
from env_loader import load_env
from lawrisk_service import (
ensure_database,
ensure_schema,
EmbeddingClient,
upsert_subjects,
upsert_subject_permits,
upsert_permits,
)
# Path of the export file. NOTE: evaluated once, at import time.
REPO_JSON = os.environ.get("LAWRISK_JSON", "risk_tables_export.json")
def _embed_subjects(
    client: EmbeddingClient,
    subjects: List[dict],
    batch_size: int = 32,
) -> List[tuple]:
    """Embed subject names in batches and return ``(id, name, vector)`` rows.

    Subjects with a falsy ``id`` or ``name`` are skipped. Names are sent to
    ``client.embed_texts`` at most ``batch_size`` at a time to keep request
    payloads small; rows are returned in input order.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
        RuntimeError: if the client returns a different number of vectors
            than names in a batch (pairing them with ``zip`` alone would
            silently drop subjects).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    metas = [
        (s.get("id"), s.get("name"))
        for s in subjects
        if s.get("id") and s.get("name")
    ]
    rows: List[tuple] = []
    for start in range(0, len(metas), batch_size):
        chunk = metas[start:start + batch_size]
        # list() so len() works even if embed_texts returns an iterator.
        vecs = list(client.embed_texts([name for _, name in chunk]))
        if len(vecs) != len(chunk):
            raise RuntimeError(
                f"Embedding client returned {len(vecs)} vectors "
                f"for {len(chunk)} subject names"
            )
        rows.extend((sid, name, vec) for (sid, name), vec in zip(chunk, vecs))
    return rows


def _subject_permit_rows(subjects: List[dict]) -> List[tuple]:
    """Return ``(subject_id, [permit_id, ...])`` rows with permit ids as strings.

    Subjects with a falsy ``id``, or whose ``permit_ids`` is not a list, are
    skipped. Subjects lacking a ``name`` are still included here even though
    they are not embedded.
    """
    rows = []
    for s in subjects:
        sid = s.get("id")
        pids = s.get("permit_ids", [])
        if sid and isinstance(pids, list):
            rows.append((sid, [str(x) for x in pids]))
    return rows


def _permit_catalog(permits: List[dict]) -> List[tuple]:
    """Return ``(permit_id, permit_name)`` rows for permits that have both."""
    return [
        (p.get("id"), p.get("name"))
        for p in permits
        if p.get("id") and p.get("name")
    ]


def main(*, json_path: str | None = None, batch_size: int = 32) -> None:
    """Load the export JSON and upsert subjects, mappings, and permits.

    Args:
        json_path: Export file to read. Defaults to the LAWRISK_JSON
            environment variable as seen after ``load_env()``, falling back
            to ``REPO_JSON``.
        batch_size: Maximum number of subject names per embedding request.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
        RuntimeError: if the embedding client returns a mismatched number
            of vectors for a batch.
    """
    # Load .env first so the service (and LAWRISK_JSON below) read the
    # correct env.
    load_env()
    ensure_database()
    ensure_schema()

    if json_path is None:
        # REPO_JSON is evaluated at import time, before load_env() has run,
        # so a LAWRISK_JSON defined only in .env would be ignored. Re-read it
        # here. Assumes load_env() exports .env into os.environ, as the
        # comment above implies -- TODO confirm.
        json_path = os.getenv("LAWRISK_JSON", REPO_JSON)
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    subjects = data.get("risk_subject", [])
    permits = data.get("risk_permit", [])

    subject_rows = _embed_subjects(EmbeddingClient(), subjects, batch_size)
    if subject_rows:
        upsert_subjects(subject_rows)

    per_rows = _subject_permit_rows(subjects)
    if per_rows:
        upsert_subject_permits(per_rows)

    per_catalog = _permit_catalog(permits)
    if per_catalog:
        upsert_permits(per_catalog)

    print(
        f"Ingested {len(subject_rows)} subjects, {len(per_rows)} subject-permit mappings, and {len(per_catalog)} permits into PostgreSQL."
    )


if __name__ == "__main__":
    main()