import React, {useEffect, useMemo, useRef, useState} from 'react';
import {
    Button,
    FormControl,
    InputLabel,
    MenuItem,
    Select,
    Stack,
    TextField,
    Tooltip,
    Typography,
} from "@mui/material";
import {FinetuneOptionsReplicate} from "../../api/json/FinetuneOptionsJson";
import {ModelSpecReplicateJson} from "../../api/json/ModelSpecJson"
import NumberInput from "../form/NumberInput"
import {ModelPlatform} from "../../api/enum/ModelPlatform"
import {IntegrationRequests} from "../../api/requests/IntegrationRequests"
import {useProjectData} from "../../hooks/useProjectData"
import SpaceBetween from "../common/SpaceBetween"
import CheckCircleOutlineIcon from "@mui/icons-material/CheckCircleOutline"
import CancelOutlinedIcon from '@mui/icons-material/CancelOutlined';
import {useToasts} from "../../hooks/useToasts"
import throttle from "lodash.throttle"
import {AnyTemplateJson} from "../../api/json/TemplateJson"

/** Props accepted by the Replicate fine-tune configuration form. */
type FinetuneReplicateFormProps = {
    /** Template whose `id` is included in the emitted fine-tune options. */
    template: AnyTemplateJson
    /** Size of the training set; used to derive the default batch sizes. */
    numberOfTrainingExamples: number
    /** Base models the user can choose from; the first entry is preselected. */
    baseModels: ModelSpecReplicateJson[]
    /** Called with the assembled fine-tune options once the destination and base model version are set. */
    setOptions: (options: FinetuneOptionsReplicate) => void
    /** Called with whether the form is currently ready to be submitted. */
    setFormValid: (valid: boolean) => void
}

/**
 * Form for configuring a LoRA fine-tune of a Replicate base model.
 *
 * The user picks a base model (its latest version is looked up on Replicate), names the
 * destination model (owner/name) the fine-tune is pushed to — creating it on Replicate if it
 * does not exist yet — and can optionally adjust training and LoRA hyperparameters.
 *
 * The assembled options are reported through `setOptions`. `setFormValid` receives `true` only
 * once the destination model has been confirmed to exist on Replicate and a base model version
 * is known, i.e. exactly when the published options are complete and current.
 */
function FinetuneReplicateForm({template, numberOfTrainingExamples, baseModels, setOptions, setFormValid}: FinetuneReplicateFormProps) {
    const {id} = useProjectData()

    // NOTE(review): assumes `baseModels` is non-empty; `baseModel.name` is read unconditionally below.
    const [baseModel, setBaseModel] = useState<ModelSpecReplicateJson>(baseModels[0])
    const [baseModelVersion, setBaseModelVersion] = useState<string>("")
    const [modelOwner, setModelOwner] = useState<string>("")
    const [modelName, setModelName] = useState<string>("")
    const [showAdvanced, setShowAdvanced] = useState<boolean>(false)
    // true: destination exists on Replicate; false: lookup/creation failed; undefined: not (yet) verified.
    const [modelFound, setModelFound] = useState<boolean>()

    const [numberOfEpochs, setNumberOfEpochs] = useState<number>(3)
    const [learningRate, setLearningRate] = useState<number>(0.0002)

    // Default batch size: ~2% of the training set, clamped to the range [1, 256].
    const defaultBatchSize = Math.max(Math.min(Math.round(numberOfTrainingExamples * 0.02), 256), 1)
    const [trainBatchSize, setTrainBatchSize] = useState<number>(defaultBatchSize)
    const [microBatchSize, setMicroBatchSize] = useState<number>(defaultBatchSize)

    // TODO: Offer QLoRA
    const [loraRank, setLoraRank] = useState<number>(32)
    const [loraAlpha, setLoraAlpha] = useState<number>(32)
    const [loraDropout, setLoraDropout] = useState<number>(0.05)

    const toasts = useToasts()

    // Bumped whenever the destination changes or a new lookup/creation starts. Responses only
    // update `modelFound` if no newer request has been issued since, so a slow response for an
    // older destination can never overwrite the result for the current one.
    const destinationRequestSeq = useRef(0)

    /** Verifies that `owner/name` exists on Replicate and records the result in `modelFound`. */
    const checkModel = (owner: string, name: string) => {
        const requestId = ++destinationRequestSeq.current
        if (!owner || !name) {
            // Incomplete destination: nothing to verify, and any earlier result no longer applies.
            setModelFound(undefined)
            return
        }

        // Any lookup failure is treated as "not found".
        IntegrationRequests.getReplicateModel(owner, name)
            .then(() => {
                if (requestId === destinationRequestSeq.current) setModelFound(true)
            })
            .catch(() => {
                if (requestId === destinationRequestSeq.current) setModelFound(false)
            })
    }

    // Created once so the throttle actually spans consecutive keystrokes. The values to check are
    // passed as arguments instead of being captured, so the memoized closure never goes stale.
    const throttledCheckModel = useMemo(
        () => throttle((owner: string, name: string) => checkModel(owner, name), 300),
        [],
    )

    /** Creates the destination model on Replicate using the selected base model's hardware SKU. */
    const createModel = () => {
        // Supersede any in-flight lookup so a late "not found" cannot override the creation result.
        const requestId = ++destinationRequestSeq.current
        IntegrationRequests.createReplicateModel(modelOwner, modelName, baseModel.usageSpec.hardwareSku)
            .then(() => {
                if (requestId === destinationRequestSeq.current) setModelFound(true)
                toasts.showMessage("Model created successfully on Replicate.", "success")
            })
            .catch(error => {
                toasts.showError(error)
                if (requestId === destinationRequestSeq.current) setModelFound(false)
            })
    }

    // Publish the options whenever all required fields are present and any input changes.
    useEffect(() => {
        if (!modelOwner || !modelName || !baseModelVersion) return

        setOptions({
            platform: ModelPlatform.REPLICATE,
            templateId: template.id,
            modelOwner: modelOwner,
            modelName: modelName,
            baseModel: baseModel,
            baseModelOwner: baseModel.name.split("/")[0],
            baseModelName: baseModel.name.split("/")[1],
            baseModelVersion: baseModelVersion,
            numberOfEpochs: numberOfEpochs,
            learningRate: learningRate,
            trainBatchSize: trainBatchSize,
            microBatchSize: microBatchSize,
            loraRank: loraRank,
            loraAlpha: loraAlpha,
            loraDropout: loraDropout,
        })
    }, [template.id, modelOwner, modelName, baseModel, baseModelVersion, numberOfEpochs, learningRate, trainBatchSize, microBatchSize, loraRank, loraAlpha, loraDropout])

    // Valid only under the same conditions in which options are published above, plus a verified
    // destination. Otherwise the parent could submit options left over from an earlier state.
    useEffect(() => {
        setFormValid(modelFound === true && !!modelOwner && !!modelName && !!baseModelVersion)
    }, [modelFound, modelOwner, modelName, baseModelVersion])

    // Look up the latest version of the selected base model ("owner/name").
    useEffect(() => {
        const [owner, name] = baseModel.name.split("/")
        if (!owner || !name) return

        // Ignore the response if the user has switched base models in the meantime.
        let cancelled = false
        IntegrationRequests.getReplicateModel(owner, name)
            .then(response => {
                if (!cancelled) setBaseModelVersion(response.latestVersion)
            })
            .catch(error => {
                if (!cancelled) toasts.showError(error)
            })
        return () => {
            cancelled = true
        }
    }, [baseModel])

    // Prefill the destination with the project's most recently used Replicate model. Setting the
    // fields triggers the destination effect below, which verifies the model.
    useEffect(() => {
        IntegrationRequests.getReplicateLatestDestination(id)
            .then(destination => {
                if (destination) {
                    setModelOwner(destination.modelOwner)
                    setModelName(destination.modelName)
                }
            })
            .catch(error => toasts.showError(error))
    }, [])

    // Re-verify the destination whenever it is edited. The previous result no longer applies, so
    // clear it (and invalidate in-flight lookups) until the throttled check completes.
    useEffect(() => {
        destinationRequestSeq.current++
        setModelFound(undefined)
        throttledCheckModel(modelOwner, modelName)
    }, [modelOwner, modelName, throttledCheckModel])

    // Drop any pending trailing check when the form unmounts.
    useEffect(() => () => throttledCheckModel.cancel(), [throttledCheckModel])

    return (
        <>
            <FormControl>
                <InputLabel id="base-model-select-label">Base Model</InputLabel>
                <Select
                    labelId="base-model-select-label"
                    value={baseModel.name}
                    label="Base Model"
                    fullWidth
                    onChange={(e) => setBaseModel(baseModels.find(it => it.name === e.target.value)!)}
                >
                    {baseModels.map(model =>
                        <MenuItem key={model.name} value={model.name}>{model.displayName}</MenuItem>,
                    )}
                </Select>
            </FormControl>

            <TextField
                label="Base Model Version"
                aria-label="base model version"
                autoComplete="off"
                helperText="Version of the model to fine-tune."
                value={baseModelVersion}
                onChange={(event) => setBaseModelVersion(event.target.value)}
            />

            <SpaceBetween>
                <Stack direction="row" spacing={1} alignItems="center">
                    <Typography variant="subtitle1">Destination</Typography>
                    {modelFound === true && <CheckCircleOutlineIcon color="success"/>}
                    {modelFound === false &&
                        <Tooltip title="Model not found on Replicate">
                            <CancelOutlinedIcon color="error"/>
                        </Tooltip>
                    }
                </Stack>
            </SpaceBetween>

            <Stack direction="row" alignItems="baseline" spacing={1}>
                <TextField
                    label="Model Owner"
                    aria-label="model owner"
                    autoComplete="off"
                    helperText="Your Replicate username."
                    value={modelOwner}
                    onChange={(event) => setModelOwner(event.target.value)}
                    onBlur={() => checkModel(modelOwner, modelName)}
                    sx={{flex:1}}
                />

                <TextField
                    label="Model Name"
                    aria-label="model name"
                    autoComplete="off"
                    helperText="Your model on Replicate."
                    value={modelName}
                    onChange={(event) => setModelName(event.target.value)}
                    onBlur={() => checkModel(modelOwner, modelName)}
                    sx={{flex:1}}
                />

                {modelFound === false && modelOwner !== "" && modelName !== "" &&
                    <Button onClick={() => createModel()}>Create</Button>
                }
            </Stack>

            {showAdvanced ?
                <>
                    <Stack direction="row" justifyContent="space-between">
                        <Typography variant="subtitle1">Advanced</Typography>
                    </Stack>

                    <NumberInput label="Number of Epochs"
                                 helperText="The number of full cycles to go through the training dataset."
                                 min={0}
                                 max={100}
                                 value={numberOfEpochs}
                                 onChange={setNumberOfEpochs}
                    />
                    <NumberInput label="Learning Rate"
                                 helperText="Sets the size of adjustment the model makes to its weights for each batch. Higher trains faster but can lead to divergence."
                                 min={0}
                                 max={0.01}
                                 value={learningRate}
                                 onChange={setLearningRate}
                    />

                    <Stack direction="row" spacing={1}>
                        <NumberInput label="Training Batch Size"
                                     min={1}
                                     max={1000}
                                     value={trainBatchSize}
                                     onChange={setTrainBatchSize}
                                     sx={{flex:1}}
                        />
                        <NumberInput label="Micro Batch Size"
                                     min={1}
                                     max={1000}
                                     value={microBatchSize}
                                     onChange={setMicroBatchSize}
                                     sx={{flex:1}}
                        />
                    </Stack>

                    <Typography variant="subtitle1">LoRA Hyperparameters</Typography>
                    <NumberInput label="Rank"
                                 min={1}
                                 max={50000}
                                 helperText="Higher values lead to more precise training of the model's weights."
                                 value={loraRank}
                                 onChange={setLoraRank}
                    />

                    <Stack direction="row" spacing={2}>
                        <NumberInput label="Alpha"
                                     min={1}
                                     max={50000}
                                     helperText="Divide by rank to get scale factor."
                                     value={loraAlpha}
                                     onChange={setLoraAlpha}
                                     sx={{flex: 1}}
                        />
                        {/* Scale factor is derived (alpha / rank); editing it rewrites alpha. */}
                        <NumberInput label="Scale Factor"
                                     min={0}
                                     max={10}
                                     value={loraAlpha / loraRank}
                                     onChange={newValue => {
                                         setLoraAlpha(loraRank * newValue)
                                     }}
                                     helperText="Factor to multiply weight changes before adding to model weights."
                                     sx={{flex: 1}}
                        />
                    </Stack>

                    <NumberInput label="Dropout"
                                 min={0}
                                 max={1}
                                 helperText="Higher values can decrease the chance of over-fitting but may inhibit learning."
                                 value={loraDropout}
                                 onChange={setLoraDropout}
                    />
                </> : <Button sx={{width: 150}} onClick={() => setShowAdvanced(true)}>Show Advanced</Button>
            }
        </>
    );
}

export default FinetuneReplicateForm;