import { useMutation } from "@tanstack/react-query";
import { useEffect, useMemo } from "react";
import { useForm } from "react-hook-form";
import { toast } from "react-hot-toast";
import { getTrainingInputSchema } from "../../schema";
import type { ErrorWithJSON, Prediction, Version } from "../../types";
import { route } from "../../urls";
import { createVersionTraining } from "../api-playground/api";
import { usePrediction } from "../api-playground/hooks";

/**
 * Mode the training form is opened in.
 *
 * Within this file, `useCreateTrainingForm` seeds a "create" form from the
 * schema defaults only, while an "update" form additionally layers in the
 * previous training's input and a `destination` value.
 */
export type TrainingAction = "create" | "update";

/**
 * Turns a failed create-training request into a `{ message }` object that can
 * be rendered, or `null` when there is no error.
 *
 * Resolution order:
 * - no error: `null`
 * - error without `detail`: `fallbackMessage`
 * - error with `detail` and status 404: a paragraph linking to the
 *   `model_create` route
 * - any other error with `detail`: the `detail` value itself
 *
 * NOTE: despite the `parse` prefix this function calls `useMemo`, so it must
 * be invoked like a hook (unconditionally, from a component or another hook).
 * The result is memoized on `error` and `fallbackMessage`.
 *
 * @param error - The error from the request, or `null` if none occurred.
 * @param fallbackMessage - Message used when the error carries no `detail`.
 * @returns `null`, or an object whose `message` is a string or JSX element.
 */
export function parseCreateTrainingError(
  error: ErrorWithJSON | null,
  fallbackMessage: string
) {
  return useMemo(() => {
    if (!error) {
      return null;
    }

    // Without a server-provided detail there is nothing specific to show.
    if (!error.detail) {
      return {
        message: fallbackMessage,
      };
    }

    // Every non-404 error surfaces the server's detail as-is.
    if (error.status !== 404) {
      return {
        message: error.detail,
      };
    }

    // A 404 replaces the detail with guidance on creating the model.
    const modelNotFound = (
      <p>
        Model not found, or you do not have permission to access it. You
        can <a href={route("model_create")}>create the model</a> if it
        doesn't exist.
      </p>
    );

    return {
      message: modelNotFound,
    };
  }, [error, fallbackMessage]);
}

/**
 * Hook wrapping `createVersionTraining` in a react-query mutation.
 *
 * - On failure, shows `errorMessage` in an error toast (the error itself is
 *   not inspected here).
 * - Once the mutation has returned data, its `id` is handed to
 *   `usePrediction`; as soon as that query reports success the browser is
 *   navigated to the `prediction_detail` route for that id.
 *
 * @param errorMessage - Text shown in the toast when the mutation fails.
 * @returns The mutation object produced by `useMutation`.
 */
export function useCreateVersionTrainingMutation({
  errorMessage,
}: {
  errorMessage: string;
}) {
  type CreateTraining = typeof createVersionTraining;

  const mutation = useMutation<
    Awaited<ReturnType<CreateTraining>>,
    ErrorWithJSON,
    Parameters<CreateTraining>[0]
  >({
    mutationFn: createVersionTraining,
    onError: () => {
      toast.error(errorMessage);
    },
  });

  const createdTraining = mutation.data;

  // `uuid` stays null until the mutation has produced a training.
  const { isSuccess: predictionLoaded } = usePrediction({
    uuid: createdTraining?.id ?? null,
  });

  useEffect(() => {
    if (!createdTraining || !predictionLoaded) {
      return;
    }

    // Full page navigation to the new training's detail page.
    window.location.href = route("prediction_detail", {
      prediction_uuid: createdTraining.id,
    });
  }, [createdTraining, predictionLoaded]);

  return mutation;
}

/**
 * Creates the react-hook-form instance backing the training form.
 *
 * Initial values are assembled from, in increasing precedence:
 * 1. the caller-supplied `defaultValues`;
 * 2. one entry per property of the version's training input schema, using
 *    the property's `default` or `null` when it has none;
 * 3. for `action === "update"` only: the previous training's `input`, plus a
 *    `destination` taken from `trainingVersion._extras.model.name` (empty
 *    string when no `trainingVersion` is given).
 *
 * NOTE(review): because the schema-derived values are spread after
 * `defaultValues`, a caller-supplied value for a key that also exists in the
 * schema is overridden (with `null` if the schema declares no default).
 * Confirm this precedence is intended.
 *
 * @param action - Whether the form creates a new training or updates one.
 * @param version - Version whose training input schema drives the fields.
 * @param training - Previous training whose `input` pre-fills an update.
 * @param trainingVersion - Version used to derive `destination` on update.
 * @param defaultValues - Extra initial values merged in at lowest precedence.
 * @returns The value returned by `useForm`.
 */
export function useCreateTrainingForm({
  action,
  version,
  training,
  trainingVersion,
  defaultValues = {},
}: {
  action: TrainingAction;
  version: Version;
  training?: Prediction;
  trainingVersion?: Version;
  // Useful for pre-populating the form with default values,
  // like in the case of SDXL where we're still
  // hard-coding some values.
  defaultValues?: Record<string, any>;
}) {
  const defaultFormValues = useMemo(() => {
    const trainingSchema = getTrainingInputSchema(version);

    // Build the defaults in a single pass; re-spreading an accumulator for
    // every property would copy the whole object each time (O(n^2)).
    const defaultInput = Object.fromEntries(
      Object.keys(trainingSchema?.properties ?? {}).map((name) => [
        name,
        trainingSchema?.properties[name]?.default ?? null,
      ])
    );

    if (action === "create") {
      return defaultInput;
    }

    if (action === "update" && version) {
      const previousTrainingInput = training?.input ?? {};
      const destination = trainingVersion
        ? trainingVersion._extras.model.name
        : "";

      // Previous input wins over schema defaults; destination wins over both.
      return {
        ...defaultInput,
        ...previousTrainingInput,
        destination,
      };
    }

    return {};
  }, [action, training, trainingVersion, version]);

  return useForm({
    defaultValues: {
      ...defaultValues,
      ...defaultFormValues,
    },
  });
}
