from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column

from bytedance.jsutils.db.base_dbo import DBO
from bytedance.jsutils.db.db_manager import DBManager


# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()


class ToDictMixin:
    """Mixin that serializes selected attributes of an object into a dict.

    Classes using the mixin list the attribute names to export in
    ``_default_select_columns``.
    """

    # Names of the attributes that ``to_dict`` reads, in output order.
    _default_select_columns: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict of the attributes named in ``_default_select_columns``.

        ``datetime`` values are converted to ISO-8601 strings with
        ``isoformat()``; every other value is copied unchanged.

        Returns:
            A new dict mapping each listed attribute name to its value.

        Raises:
            AttributeError: If a listed attribute does not exist on the object.
        """
        result: Dict[str, Any] = {}
        for name in self._default_select_columns:
            value = getattr(self, name)
            result[name] = value.isoformat() if isinstance(value, datetime) else value
        return result


class PCCHostedModel(Base, ToDictMixin):
    """ORM mapping for the ``pcc_hosted_models`` table.

    Each row describes one hosted model: its name, the TOS URL where it is
    stored, a ``baseline`` string and a publish timestamp kept as text.
    Every text column defaults to the empty string.
    """

    __tablename__ = "pcc_hosted_models"

    # Attribute names exported by ``ToDictMixin.to_dict``.
    _default_select_columns = [
        "id",
        "model_name",
        "tos_url",
        "baseline",
        "created_at",
    ]

    # Auto-incrementing primary key.
    id: Mapped[int] = mapped_column(
        autoincrement=True,
        primary_key=True,
        comment="ID",
    )
    # Model name.
    model_name: Mapped[str] = mapped_column(
        default="",
        comment="模型名称",
    )
    # TOS address where the model is stored.
    tos_url: Mapped[str] = mapped_column(
        default="",
        comment="模型存放的TOS地址",
    )
    # Per the column comment: ID of the key ring used to encrypt the model.
    baseline: Mapped[str] = mapped_column(
        default="",
        comment="模型加密的密钥环ID",
    )
    # Model publish time, stored as a string (see the column comment for the
    # intended format).
    created_at: Mapped[str] = mapped_column(
        default="",
        comment="模型发布时间 YYYY-MM-DD HHmmSS",
    )


class PCCHostedModelDBO(DBO[PCCHostedModel]):
    """``DBO`` subclass parameterized with the ``PCCHostedModel`` mapping."""

    # Model class this DBO is bound to (presumably read by the ``DBO`` base
    # class to build its queries -- confirm against ``base_dbo``).
    _table = PCCHostedModel


def register_to_rds(rds_write_url: str = "", model_meta_data: Optional[Dict] = None):
    """Insert one hosted-model record into the ``pcc_hosted_models`` table.

    All inputs are validated before the database manager is initialized, so
    invalid input never triggers ``DBManager.init``.

    Args:
        rds_write_url: Connection URL handed to ``DBManager.init``. Required
            despite having a default; an empty value is rejected.
        model_meta_data: Mapping that must provide non-empty ``model_name``,
            ``tos_url`` and ``baseline`` entries. ``None`` is treated as an
            empty mapping.

    Raises:
        Exception: If ``rds_write_url`` is empty.
        AssertionError: If ``model_name``, ``tos_url`` or ``baseline`` is
            missing or empty in ``model_meta_data``.
    """
    if not rds_write_url:
        message = (
            "rds_write_url is missing, please pass rds_write_url to "
            "register_to_rds() function"
        )
        print(message)
        raise Exception(message)

    # ``None`` default instead of a shared mutable ``{}`` default.
    meta = model_meta_data if model_meta_data is not None else {}

    model_name = meta.get("model_name", "")
    tos_url = meta.get("tos_url", "")
    baseline = meta.get("baseline", "")

    # Raised explicitly rather than via ``assert`` so the checks still run
    # under ``python -O``; AssertionError is kept so existing callers that
    # catch it keep working. Falsy values such as ``None`` are rejected too.
    for field_name, field_value in (
        ("model_name", model_name),
        ("tos_url", tos_url),
        ("baseline", baseline),
    ):
        if not field_value:
            raise AssertionError(f"{field_name} is missing")

    DBManager.init("mysql", {"PCCHostedModelTable": PCCHostedModelDBO}, rds_write_url)
    db_manager = DBManager.get_instance()

    db_manager.PCCHostedModelTable.insert(
        model_name=model_name,
        tos_url=tos_url,
        baseline=baseline,
        # Naive local time rendered with ``str()``, e.g.
        # "2024-01-02 03:04:05.123456".
        created_at=str(datetime.now()),
    )
    print(f"inserting model_name={model_name}, tos_url={tos_url}, baseline={baseline} succeeded.")
    # Dump the table contents after the insert as a visual confirmation.
    print("pcc_hosted_models: ", db_manager.PCCHostedModelTable.select(where={}))