import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):
    """Transformer ControlNet branch producing one residual per main-model block.

    The module patch-embeds the input ``x`` and the control ``hint`` with two
    separate ``PatchEmbed`` layers, sums them, and feeds the result together
    with a conditioning vector (timestep embedding, plus an optional ``y``
    embedding) through ``num_layers`` ``DismantledBlock`` instances. Each
    block output is projected by its own linear layer and repeated
    ``ceil(main_model_double / num_layers)`` times in the returned tuple.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        """Build the embedders, transformer blocks and output projections.

        Args:
            img_size: Forwarded to both ``PatchEmbed`` layers.
            patch_size: Patch edge length; also used in ``forward`` to derive
                the positional-embedding grid size.
            in_chans: Channel count forwarded to both ``PatchEmbed`` layers.
            attention_head_dim: Per-head width; the hidden size is
                ``num_attention_heads * attention_head_dim``.
            num_attention_heads: Number of heads of every transformer block.
            adm_in_channels: Input width of the first ``y`` embedder.
            num_layers: Number of transformer blocks (and output projections).
            main_model_double: Numerator of the per-block repeat count
                (presumably the number of main-model blocks that consume the
                residuals -- confirm against the caller).
            double_y_emb: When True, ``y`` is embedded twice
                (``orig_y_embedder`` then ``y_embedder``); ``forward`` also
                skips the sincos positional table and does not chain blocks.
            device: Device passed to every sub-module.
            dtype: Dtype passed to every sub-module.
            pos_embed_max_size: Only its None-ness is used: ``x_embedder`` is
                built with ``strict_img_size=True`` exactly when this is None.
            operations: Namespace providing the layer classes (``Linear`` is
                used directly here; the sub-modules receive it as well).

        Raises:
            ValueError: If ``operations`` is None.
        """
        super().__init__()
        if operations is None:
            # `operations.Linear` is dereferenced below, so None can never
            # work; fail with a clear message instead of an AttributeError.
            raise ValueError("ControlNetEmbedder requires an `operations` namespace, got None")

        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            # Two-stage embedding: adm_in_channels -> hidden -> hidden.
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # TODO double check this logic when 8b
        # NOTE: this flag is stored but not read anywhere in this class.
        self.use_y_embedder = True

        # One linear output projection per transformer block.
        self.controlnet_blocks = nn.ModuleList(
            operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            for _ in range(len(self.transformer_blocks))
        )

        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint: Optional[torch.Tensor] = None,
    ) -> Dict[str, Tuple[Tensor, ...]]:
        """Compute the projected block outputs for ``x`` guided by ``hint``.

        Args:
            x: Input whose last two dimensions are used as height and width
                before it is patch-embedded by ``x_embedder``.
            timesteps: Passed to ``t_embedder`` to build the conditioning.
            y: Optional vector conditioning; when given, its embedding is
                added to the timestep embedding.
            context: Accepted for interface compatibility; not used here.
            hint: Control input, patch-embedded by ``pos_embed_input`` and
                added to the embedded ``x``. Required despite the default.

        Returns:
            ``{"output": outputs}`` where ``outputs`` is a tuple holding, for
            each transformer block in order, its linearly projected output
            repeated ``ceil(main_model_double / num_blocks)`` times (the
            repeats are the same tensor object).

        Raises:
            ValueError: If ``hint`` is None.
        """
        if hint is None:
            raise ValueError("ControlNetEmbedder.forward() requires a `hint` tensor, got None")

        height, width = x.shape[-2], x.shape[-1]
        x = self.x_embedder(x)
        if not self.double_y_emb:
            grid_h = (height + 1) // self.patch_size
            grid_w = (width + 1) // self.patch_size
            # In-place add on purpose: x keeps its own dtype whatever dtype
            # the positional table comes back in.
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, grid_w, grid_h, device=x.device)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        # Each block output is emitted `repeat` times in the result.
        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        outputs: List[Tensor] = []
        for block, projection in zip(self.transformer_blocks, self.controlnet_blocks):
            out = block(x, c)
            if not self.double_y_emb:
                # Blocks are chained only in the single-y-embedding variant;
                # otherwise every block receives the same input x.
                x = out
            outputs.extend([projection(out)] * repeat)

        return {"output": tuple(outputs)}