import { useEffect, useState } from "react";
import type {
  CogOutputSchema,
  Prediction,
  PredictionStatus,
} from "../../types";
import { PredictionDefaultOutput } from "./prediction-output";

/**
 * A `Prediction` narrowed to the shape the file streaming output component
 * in this module works with: `output` is either absent (`null`), a single
 * string, or a list of strings, and `urls` carries a `stream` entry in
 * addition to `get` and `cancel`.
 *
 * `urls.stream` is the URL `PredictionFileStreamingOutput` starts fetching
 * from when it connects to the stream.
 */
export type PredictionWithStream = Prediction & {
  output: null | string | string[];
  urls: { get: string; cancel: string; stream: string };
};

/**
 * Returns true if the schema indicates that the model supports the file
 * streaming API. This logic lives in replicate/api.
 *
 * A schema qualifies when its output is either a single URI string
 * (`{ type: "string", format: "uri" }`) or an array whose items are URI
 * strings (`{ type: "array", items: { type: "string", format: "uri" } }`).
 *
 * Note that the decision is made from `schema` alone; `_prediction` is not
 * inspected. Callers still need to check that the Prediction has a
 * `urls.stream` property in order to actually connect to the stream.
 *
 * @param _prediction - The prediction being narrowed by the type predicate.
 * @param schema - The model's output schema.
 * @returns Whether the output schema describes URI file output.
 */
export function isFileStreamingModel(
  _prediction: Prediction,
  schema: CogOutputSchema
): _prediction is PredictionWithStream {
  // Single file output.
  if (
    "type" in schema &&
    schema.type === "string" &&
    "format" in schema &&
    schema.format === "uri"
  ) {
    return true;
  }

  // List of files output.
  return (
    "type" in schema &&
    schema.type === "array" &&
    "items" in schema &&
    schema.items.type === "string" &&
    schema.items.format === "uri"
  );
}

/**
 * Renders the output of a prediction whose files are delivered through the
 * file streaming API.
 *
 * When the component mounts with a prediction that is already in a terminal
 * state it never opens the stream: a succeeded prediction is rendered from
 * `prediction.output`, anything else starts with no files. Otherwise it
 * fetches `prediction.urls.stream` and keeps following the `rel="next"` Link
 * header of each response, appending every received file to the rendered
 * output until the stream ends.
 *
 * @param props.prediction - Prediction to render; `urls.stream` is where the
 *   stream starts.
 * @param props.schema - The model's output schema; decides whether a single
 *   file is rendered as a scalar or as a list.
 */
export function PredictionFileStreamingOutput({
  prediction,
  schema,
}: {
  prediction: PredictionWithStream;
  schema: CogOutputSchema;
}) {
  const streamUrl = prediction.urls.stream;

  let initialFiles: string[] = [];
  if (prediction.status === "succeeded" && prediction.output !== null) {
    initialFiles =
      typeof prediction.output === "string"
        ? [prediction.output]
        : prediction.output;
  }

  const [files, setFiles] = useState<string[]>(initialFiles);
  // Captured once: the decision to connect is based on the status the
  // component was first rendered with, not on later status updates.
  const [initialStatus] = useState<PredictionStatus>(prediction.status);

  useEffect(() => {
    // Don't connect to stream if the component loads with a terminal prediction.
    if (
      initialStatus === "succeeded" ||
      initialStatus === "failed" ||
      initialStatus === "canceled"
    ) {
      return;
    }

    // One controller covers every request made by this effect run, so the
    // cleanup below cancels whichever request happens to be in flight.
    const controller = new AbortController();
    const { signal } = controller;

    (async () => {
      let url = streamUrl;
      while (!signal.aborted) {
        const res = await fetch(url, { signal });
        if (!res.ok) {
          if (res.status === 404) {
            // End of stream;
            return;
          }
          console.error(
            new Error(`Failed to fetch ${url} response status: ${res.status}`)
          );
          return;
        }

        if (res.status === 204) {
          // Skipped upload fallback to the URL.
          const src = res.headers.get("Location");
          if (src) {
            setFiles((state) => [...state, src]);
          }
          return;
        }

        const blob = await res.blob();
        if (signal.aborted) {
          // The effect was cleaned up while the body was being read; don't
          // append to state on behalf of a stream nobody is listening to.
          return;
        }

        // NOTE(review): nothing in this file revokes these object URLs, so
        // the blobs stay alive for the lifetime of the document.
        const src = URL.createObjectURL(blob);
        setFiles((state) => [...state, src]);

        const { next: nextUrl } = parseLinkHeader(res.headers.get("Link"));
        if (!nextUrl) {
          console.warn("missing next header in response");
          return;
        }
        url = nextUrl;
      }
    })().catch((err: unknown) => {
      if (signal.aborted) {
        // Rejection caused by our own cleanup aborting the request; this is
        // the expected way for the loop to stop, not an error worth logging.
        return;
      }
      console.error(err);
    });

    return () => controller.abort();
  }, [streamUrl, initialStatus]);

  let output: string | string[] | null = null;
  if ("type" in schema && schema.type === "string") {
    if (files.length === 1) {
      output = files[0];
    } else if (files.length > 1) {
      // This should never happen if the model's schema is telling the truth,
      // but if it isn't, we still want to report all the output.
      output = files;
    }
  } else {
    output = files;
  }

  return (
    <PredictionDefaultOutput
      alwaysRenderURLsAsDownload={false}
      output={output}
      reportFallback
      schema={schema}
      status={prediction.status}
    />
  );
}

/**
 * Parses an HTTP `Link` header into a map from relation type to target URL.
 *
 * Each link in the header has the form `<url>; param=value; ...`, with links
 * separated by commas. Only the `rel` parameter is used. The target is read
 * as a bracketed unit and quoted parameter values are read as units, so
 * commas and semicolons inside the URL or inside a quoted value are not
 * mistaken for separators. The `rel` value may be quoted or a bare token,
 * its parameter name is matched case-insensitively, and a value listing
 * several space-separated relation types registers the URL under each.
 *
 * Links without a `<...>` target or without a `rel` parameter are skipped.
 * If the same relation type appears more than once, the last link wins.
 *
 * @param header - Raw header value, or `null` when the header is absent.
 * @returns Relation type to URL; empty when nothing could be parsed.
 */
function parseLinkHeader(header: string | null): Record<string, string> {
  const parsed: Record<string, string> = {};
  if (!header) {
    return parsed;
  }

  // `<target>` followed by any number of `;param` sections. A section runs
  // until the next unquoted `;` or `,`.
  const linkPattern = /<([^>]*)>((?:\s*;(?:[^,;"]|"[^"]*")*)*)/g;
  // A single `;name`, `;name=token` or `;name="quoted"` section.
  const paramPattern = /;\s*([^=;\s"]+)\s*(?:=\s*(?:"([^"]*)"|([^;"]*)))?/g;

  let link: RegExpExecArray | null;
  while ((link = linkPattern.exec(header)) !== null) {
    const url = link[1] ?? "";
    const params = link[2] ?? "";

    // The pattern is reused for every link, so rewind it first.
    paramPattern.lastIndex = 0;
    let param: RegExpExecArray | null;
    while ((param = paramPattern.exec(params)) !== null) {
      if ((param[1] ?? "").toLowerCase() !== "rel") {
        continue;
      }

      const value = param[2] ?? param[3] ?? "";
      for (const rel of value.split(/\s+/)) {
        if (rel) {
          parsed[rel] = url;
        }
      }
      // Only the first `rel` parameter of a link is considered.
      break;
    }
  }

  return parsed;
}
