import React, {useEffect, useState} from 'react';
import {FinetuneRequests} from "../../api/requests/FinetuneRequests";
import {
    Alert,
    CircularProgress,
    Divider,
    Paper,
    Stack,
    SxProps,
    Table,
    TableBody,
    TableCell,
    TableContainer,
    TableHead,
    TableRow,
    Tooltip,
    Typography,
} from "@mui/material";
import {FinetunePrecheckJson} from "../../api/json/FinetunePrecheckJson";
import {ProjectJson} from "../../api/json/ProjectJson";
import {AnyFinetuneOptions} from "../../api/json/FinetuneOptionsJson";
import ApiError from "../../api/ApiError"
import {PublicStatsJson, PublicStatsRequests} from "../../api/requests/PublicStatsRequests"
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
import {ModelPlatform} from "../../api/enum/ModelPlatform"
import animations from "../../Animations.module.scss"
import {AnyTemplateJson} from "../../api/json/TemplateJson"
import {TemplateType} from "../../api/enum/TemplateType"
import {getTemplateTokenCount} from "../../library/TokenCountUtil"
import ErrorMessage from "../error/ErrorMessage"

/** Props accepted by {@link PrecheckPanel}. */
interface PrecheckPanelProps {
    /** Project being fine-tuned; its id is sent with the precheck request and its example stats are checked against platform minimums. */
    project: ProjectJson
    /** Currently selected fine-tune options; sent with the precheck request and inspected for platform and base model. */
    options: AnyFinetuneOptions
    /** Platform this panel instance represents; the panel renders nothing when it differs from `options.platform`. */
    platform: ModelPlatform
    /** Template used for the fine-tune; its id is sent with the precheck request and its type is checked against the base model. */
    template: AnyTemplateJson
    /** Called with `true` when at least one issue of severity "error" was found, `false` otherwise. */
    setBlocked: (blocked: boolean) => void
    /** Optional MUI style overrides applied to the panel's root stack. */
    sx?: SxProps
}

/** A single problem found by the precheck, rendered as an alert below the estimates. */
interface PrecheckIssue {
    /** Human-readable message shown to the user. */
    content: string
    /** "error" issues cause the panel to report itself as blocked; "warning" issues are only displayed. */
    severity: "warning" | "error"
}

/** Formats the estimated cost as US dollars; created once instead of on every render. */
const currencyFormatter = new Intl.NumberFormat('en-US', {
    style: 'currency',
    currency: 'USD',
})

/**
 * Builds the label shown for the expected fine-tune duration.
 *
 * @param minutes recent fine-tune duration for the platform, if known
 * @returns an empty string when no (or a zero) duration is available; durations above 60 minutes
 * are rounded to the nearest even number of hours, anything else is shown in minutes
 */
function formatTimeEstimate(minutes: number | null | undefined): string {
    if (!minutes) return ""
    if (minutes > 60) return `${Math.round(minutes / 120) * 2} hours`
    return `${minutes} minutes`
}

/**
 * Percentage of all tokens that come from the template, rounded to a whole number.
 * Returns 0 when there are no tokens at all so the caption never shows "NaN%".
 */
function templateSharePercent(templateTokens: number, totalTokens: number): number {
    if (!(totalTokens > 0)) return 0
    return Math.round((templateTokens / totalTokens) * 100)
}

/**
 * Derives the list of warnings and errors for a precheck result.
 *
 * @param results precheck response for the current project/options/template
 * @param project project whose training example count is compared to platform minimums
 * @param options selected fine-tune options (platform and base model are inspected)
 * @param template template whose type and token count are validated against the base model
 * @returns issues in display order; an empty array when nothing was found
 */
function collectPrecheckIssues(
    results: FinetunePrecheckJson,
    project: ProjectJson,
    options: AnyFinetuneOptions,
    template: AnyTemplateJson,
): PrecheckIssue[] {
    const found: PrecheckIssue[] = []

    if (results.examplesOverTokenLimitCount > 0)
        found.push({
            content: `${results.examplesOverTokenLimitCount} examples are over the token limit of ${results.modelSpec.tokenLimits.overall}`,
            severity: "warning",
        })

    if (options.platform === ModelPlatform.AI21 && project.stats.numberOfTrainingExamples < 50)
        found.push({content: `AI21 requires at least 50 examples to fine-tune.`, severity: "error"})

    if (options.platform === ModelPlatform.OPEN_AI && project.stats.numberOfTrainingExamples < 10)
        found.push({content: `OpenAI requires at least 10 examples to fine-tune.`, severity: "error"})

    if (results.trainingExamplesCount === 0)
        found.push({content: `No new examples were added since this model was trained.`, severity: "error"})

    // Model-specific checks need a base model; without one they are skipped rather than
    // dereferencing a missing value and throwing inside the effect.
    const baseModel = options.baseModel
    if (baseModel) {
        if (template.type === TemplateType.STANDARD && baseModel.syntax !== null) {
            found.push({content: `Standard template not recommended for this model`, severity: "warning"})
        }

        if (template.type === TemplateType.CHAT && baseModel.syntax === null) {
            found.push({content: `Chat template is incompatible with this model.`, severity: "error"})
        }

        if (getTemplateTokenCount(template, baseModel.tokenizer) == null) {
            found.push({content: `Template is missing token count.`, severity: "warning"})
        }
    }

    return found
}

/**
 * Shows the outcome of a fine-tune precheck: dataset sizes, token counts, cost and time
 * estimates, and any warnings or errors.
 *
 * The precheck is requested again whenever `project`, `template` or `options` change.
 * `setBlocked` is called with `true` when at least one issue has severity "error".
 * Nothing is rendered when `platform` differs from `options.platform`.
 */
function PrecheckPanel({project, options, platform, template, setBlocked, sx}: PrecheckPanelProps) {
    const [results, setResults] = useState<FinetunePrecheckJson>()
    const [issues, setIssues] = useState<PrecheckIssue[]>()
    const [systemStats, setSystemStats] = useState<PublicStatsJson>()
    const [error, setError] = useState<ApiError>()

    useEffect(() => {
        // Ignore responses of superseded requests so a slow, older precheck cannot
        // overwrite the result for the current inputs (or update an unmounted panel).
        let cancelled = false

        FinetuneRequests.precheck(project.id, "", {...options, templateId: template.id})
            .then(precheck => {
                if (cancelled) return
                setResults(precheck)
                setError(undefined)
            })
            .catch(reason => {
                if (!cancelled) setError(reason)
            })

        return () => {
            cancelled = true
        }
    }, [project, template, options])

    useEffect(() => {
        let cancelled = false

        PublicStatsRequests.getStats()
            .then(stats => {
                if (!cancelled) setSystemStats(stats)
            })
            .catch(reason => {
                if (!cancelled) setError(reason)
            })

        return () => {
            cancelled = true
        }
    }, [])

    useEffect(() => {
        if (!results || options.platform !== platform) return

        const newIssues = collectPrecheckIssues(results, project, options, template)
        setBlocked(newIssues.some(issue => issue.severity === "error"))
        setIssues(newIssues)
        // Every value read above is listed so issues never go stale. setBlocked is left out
        // on purpose: callers may pass a new function on each render, which would re-run
        // this effect without any change to its inputs.
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [results, project, options, platform, template])

    if (error) {
        return <ErrorMessage error={error}/>
    }

    if (!results) return <CircularProgress/>

    if (platform !== options.platform) {
        return <></>
    }

    const finetuneMinutes = systemStats?.finetuneTimes.find(it => it.platform.toString() === platform.toString())?.minutes
    const timeEstimate = formatTimeEstimate(finetuneMinutes)
    const templateShare = templateSharePercent(results.templateTotalTokens, results.totalTokens)

    return (
        <Stack spacing={2} className={animations.defaultIn} sx={sx}>
            <Paper variant="outlined" sx={{p: 2}}>
                <Stack spacing={1}>
                    <TableContainer sx={{boxShadow: "none"}}>
                        <Table size="small">
                            <TableHead>
                                <TableRow>
                                    <TableCell sx={{paddingLeft: "0 !important"}}>
                                        <Typography variant="subtitle1">
                                            Dataset
                                        </Typography>
                                    </TableCell>
                                    <TableCell>
                                        <Typography variant="subtitle1">
                                            Examples
                                        </Typography>
                                    </TableCell>
                                    <TableCell sx={{paddingRight: "0 !important"}}>
                                        <Typography variant="subtitle1">
                                            Tokens
                                        </Typography>
                                    </TableCell>
                                </TableRow>
                            </TableHead>
                            <TableBody>
                                <TableRow>
                                    <TableCell sx={{paddingLeft: "0 !important"}}>
                                        <Typography variant="body1">Training</Typography>
                                    </TableCell>
                                    <TableCell>
                                        <Typography variant="body1">{results.trainingExamplesCount.toLocaleString()}</Typography>
                                    </TableCell>
                                    <TableCell sx={{paddingRight: "0 !important"}}>
                                        <Typography variant="body1">{results.trainingTotalTokens.toLocaleString()}</Typography>
                                    </TableCell>
                                </TableRow>
                                {"baseModelName" in options &&
                                    <TableRow>
                                        <TableCell sx={{paddingLeft: "0 !important"}}>
                                            <Typography variant="body1">Validation</Typography>
                                        </TableCell>
                                        <TableCell>
                                            <Typography variant="body1">
                                                {results.validationExamplesCount.toLocaleString()}
                                            </Typography>
                                        </TableCell>
                                        <TableCell sx={{paddingRight: "0 !important"}}>
                                            <Typography variant="body1">
                                                {results.validationTotalTokens.toLocaleString()}
                                            </Typography>
                                        </TableCell>
                                    </TableRow>
                                }
                                <TableRow>
                                    <TableCell sx={{paddingLeft: "0 !important"}}>
                                        <Typography variant="body1">
                                            Total
                                        </Typography>
                                    </TableCell>
                                    <TableCell>
                                        <Typography variant="body1">{(results.trainingExamplesCount + results.validationExamplesCount).toLocaleString()}</Typography>
                                    </TableCell>
                                    {/* Same right padding reset as the other rows so the Tokens column stays aligned. */}
                                    <TableCell sx={{paddingRight: "0 !important"}}>
                                        <Typography variant="body1">{results.totalTokens.toLocaleString()}</Typography>
                                    </TableCell>
                                </TableRow>
                            </TableBody>
                        </Table>
                    </TableContainer>
                    <Divider sx={{width: 60}}/>
                    <Typography variant="caption">Attributed to template: {results.templateTotalTokens.toLocaleString()} tokens or {templateShare}% of total</Typography>
                </Stack>
            </Paper>

            <Paper variant="outlined" sx={{p: 2}}>
                <Stack spacing={1}>
                    <Typography variant="subtitle1">Estimate</Typography>
                    {results.estimatedCost !== null &&
                        <Typography variant="body1">
                            Cost: {currencyFormatter.format(results.estimatedCost)}
                        </Typography>
                    }
                    {timeEstimate && (
                        <Stack direction="row" spacing={0.75} alignItems="center">
                            <Typography variant="body1">
                                Time: {timeEstimate}
                            </Typography>
                            <Tooltip title="Based on recent fine-tunes for this platform">
                                <InfoOutlinedIcon color="secondary" sx={{fontSize: 17}}/>
                            </Tooltip>
                        </Stack>
                    )}
                    <Divider sx={{width: 60}}/>
                    <Typography variant="caption">Actual costs and times may vary. You are responsible for all charges.
                        Validation examples are evaluated after fine-tuning and incur additional charges.</Typography>
                </Stack>
            </Paper>

            {issues?.map((issue, index) => <Alert key={index} severity={issue.severity}>{issue.content}</Alert>)}
        </Stack>
    );
}

export default PrecheckPanel;