import torch
import numpy as np
import itertools

def get_model():
    """Build the MLP regressor 1 -> 100 -> 100 -> 100 -> 1 with ReLU between layers.

    Returns:
        A fresh ``torch.nn.Sequential`` with randomly initialised Linear layers.
    """
    widths = [1, 100, 100, 100, 1]
    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(fan_in, fan_out))
        layers.append(torch.nn.ReLU())
    # The output layer is linear: drop the trailing activation.
    layers.pop()
    return torch.nn.Sequential(*layers)

def target(x):
    """Ground-truth function the models are fit to: the (elementwise) square of ``x``."""
    return pow(x, 2)

def get_data_domain_1(batch_size):
    """Return a zero-argument factory for an endless stream of domain-1 batches.

    Each batch is ``(x, target(x))`` where ``x`` has shape ``(batch_size, 1)``
    and is drawn uniformly from ``[0.0, 1.0)``.
    """
    low, high = 0.0, 1.0

    def generator():
        while True:
            inputs = np.random.uniform(low, high, size=(batch_size, 1))
            yield inputs, target(inputs)

    return generator

def get_data_domain_2(batch_size):
    """Return a zero-argument factory for an endless stream of domain-2 batches.

    Each batch is ``(x, target(x))`` where ``x`` has shape ``(batch_size, 1)``
    and is drawn uniformly from ``[1.0, 2.0)``.
    """
    low, high = 1.0, 2.0

    def generator():
        while True:
            inputs = np.random.uniform(low, high, size=(batch_size, 1))
            yield inputs, target(inputs)

    return generator

def eval(model, get_data, n=1000, batch_size=256, device="cuda"):
    """Return the mean squared error of ``model`` over ``n`` batches from ``get_data``.

    NOTE: this name shadows the builtin ``eval``; it is kept for existing callers.

    Args:
        model: module to evaluate; its parameters must live on ``device``.
        get_data: factory such that ``get_data(batch_size)()`` yields
            ``(x, y)`` numpy array pairs.
        n: number of batches to average over.
        batch_size: rows per batch requested from ``get_data``.
        device: device the input/target tensors are created on.

    Returns:
        Mean over batches of the per-batch MSE, as a numpy float.

    The model is switched to eval mode while evaluating and its previous
    training/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        # Evaluation never backpropagates, so skip building the autograd graph.
        with torch.no_grad():
            for x, ty in itertools.islice(get_data(batch_size)(), n):
                tensor_x = torch.tensor(x, dtype=torch.float, device=device)
                tensor_y = torch.tensor(ty, dtype=torch.float, device=device)
                y = model(tensor_x)
                batch_losses.append(((y - tensor_y) ** 2).mean().item())
    finally:
        model.train(was_training)

    return np.array(batch_losses).mean()


def l2_loss(y, ty):
    """Elementwise squared error between prediction ``y`` and target ``ty`` (no reduction)."""
    residual = y - ty
    return residual ** 2

def train(model, get_data, get_loss, lr=0.0001, batch_size=256, patience=20, device="cuda"):
    """Fit ``model`` in place with Adam, stopping once the batch loss plateaus.

    Batches are drawn from the stream ``get_data(batch_size)()``. Training
    stops after more than ``patience`` consecutive steps whose batch loss does
    not beat the best loss seen so far (or when the stream is exhausted).

    Args:
        model: module to optimise; its parameters must live on ``device``.
        get_data: factory such that ``get_data(batch_size)()`` yields
            ``(x, y)`` numpy array pairs.
        get_loss: ``get_loss(prediction, target)`` returning a tensor whose
            mean is minimised.
        lr: Adam learning rate.
        batch_size: rows per batch requested from ``get_data``.
        patience: number of consecutive non-improving steps tolerated.
        device: device the input/target tensors are created on.

    Returns:
        The same ``model`` object, trained in place.
    """
    print(list(model.named_parameters()))
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Start at +inf so the first batch always counts as an improvement; a fixed
    # threshold (previously 1.0) ended training after `patience` steps without
    # ever improving whenever the initial loss was above it.
    best_loss = float("inf")
    n_not_improved = 0
    for x, ty in get_data(batch_size)():
        tensor_x = torch.tensor(x, dtype=torch.float, device=device)
        tensor_y = torch.tensor(ty, dtype=torch.float, device=device)
        loss = get_loss(model(tensor_x), tensor_y).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

        floss = loss.item()
        if floss < best_loss:
            best_loss = floss
            n_not_improved = 0
            print(f"New best_loss: {best_loss:03f}")
        else:
            n_not_improved += 1
            if n_not_improved > patience:
                break

    return model

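# Presumed shape convention for stacked ensemble outputs (see AdaptModel.forward):
# [M, B, 1] = (number of models, batch size, output dim).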

def get_domain_adapt_model():
    """Build the domain adapter: a single learnable affine map from 1 feature to 1 feature."""
    layers = [torch.nn.Linear(1, 1)]
    return torch.nn.Sequential(*layers)

# Train an ensemble of five MLPs on domain 1, each from its own random
# initialisation, then freeze their weights.
models = []

for i in range(5):
    model = get_model().cuda()
    m = train(model, get_data_domain_1, l2_loss)
    # Module.requires_grad_ sets requires_grad on every parameter in one call.
    m.requires_grad_(False)
    models.append(m)

# Learnable 1 -> 1 adapter shared by the adaptation wrappers defined below.
dam = get_domain_adapt_model().cuda()

class AdaptModel(torch.nn.Module):
    """Apply a shared input adapter ``dam`` before every model of an ensemble.

    ``forward(x)`` returns ``torch.stack([m(dam(x)) for m in models])``, i.e.
    the per-model outputs stacked along a new leading dimension.
    """

    def __init__(self, dam, models):
        super().__init__()
        self.dam = dam
        # ModuleList registers the members so .parameters(), .to()/.cuda(),
        # .train()/.eval() and state_dict() reach them (a plain list hides them).
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        # Use this instance's own adapter and models, not module-level globals.
        adapted = self.dam(x)
        return torch.stack([m(adapted) for m in self.models])

class InferenceModel(torch.nn.Module):
    """Collapse the stacked outputs of ``multi_model`` by averaging over dimension 0."""

    def __init__(self, multi_model):
        super().__init__()
        self.multi_model = multi_model

    def forward(self, x):
        stacked = self.multi_model(x)
        return stacked.mean(dim=0)

class EnsembleModel(torch.nn.Module):
    """Average the predictions of several models evaluated on the same input."""

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members so .eval()/.train(), .to()/.cuda(),
        # .parameters() and state_dict() propagate to them.
        self.models = torch.nn.ModuleList(models)

    def forward(self, x):
        return torch.mean(torch.stack([m(x) for m in self.models]), dim=0)


class SimpleAdaptModel(torch.nn.Module):
    """Run ``model`` and then pass its output through the adapter ``dam``.

    NOTE(review): here ``dam`` is applied to the model's OUTPUT, whereas
    AdaptModel applies it to the INPUT; confirm which placement is intended.
    """

    def __init__(self, dam, model):
        super().__init__()
        self.dam = dam
        self.model = model

    def forward(self, x):
        hidden = self.model(x)
        return self.dam(hidden)


adapt_model = AdaptModel(dam, models)
# A fresh (untrained, unfrozen) MLP followed by the shared adapter `dam`.
inference_adapt_model = SimpleAdaptModel(dam, get_model().cuda())
el1_adapt_pre_train = eval(inference_adapt_model, get_data_domain_1)
# Put the whole wrapper (adapter AND inner MLP) into training mode, not just `dam`.
inference_adapt_model.train()
print("Train adaption")
train(inference_adapt_model, get_data_domain_1, l2_loss)
inference_model = EnsembleModel(models)

el1 = eval(inference_model, get_data_domain_1)
el2 = eval(inference_model, get_data_domain_2)
el1_adapt = eval(inference_adapt_model, get_data_domain_1)
el2_adapt = eval(inference_adapt_model, get_data_domain_2)
print(f"Loss(D1) = {el1}, Loss(D2) = {el2}")
print(f"Loss(AD1, pretrain) = {el1_adapt_pre_train}")
print(f"Loss(AD1) = {el1_adapt}, Loss(AD2) = {el2_adapt}")