//! A KVM switch emulator using UDP/IP.
use crate::io::*;

use bytes::buf::ext::{BufExt, BufMutExt};
use bytes::BytesMut;
use futures::{Sink, Stream};
use log::debug;
use tokio::net::UdpSocket;
use tokio_util::codec::{Decoder, Encoder};
use tokio_util::udp::UdpFramed;

use std::io;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};

/// Settings used to open a `Network` endpoint.
///
/// The common traits are derived so that a configuration can be logged,
/// duplicated and compared by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Local address the UDP socket is bound to.
    pub server_addr: IpAddr,
    /// Multicast group that is joined when the endpoint is opened; it is also
    /// the destination used for events sent without an explicit address.
    pub multicast_addr: IpAddr,
    /// UDP port used both for the local bind and for the multicast destination.
    pub port: u16,
}

/// A UDP endpoint that exchanges `NetEvent`s.
///
/// It implements `Stream` for incoming events and `Sink` for outgoing ones,
/// either to an explicit address or to the configured multicast group.
pub struct Network {
    /// Datagram framing layer; events are (de)serialized through `Codec`.
    framed: UdpFramed<Codec>,
    /// Settings supplied to `Network::open`; read again when sending a bare
    /// event to work out the multicast destination.
    config: Config,
}

impl Network {
    /// Opens the endpoint described by `config`.
    ///
    /// A UDP socket is bound to `config.server_addr:config.port`, multicast
    /// loopback is switched off and the socket joins `config.multicast_addr`
    /// (IPv4 or IPv6, depending on the address family of the group).
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if binding the socket, changing the
    /// loopback option or joining the multicast group fails.
    pub async fn open(config: Config) -> io::Result<Self> {
        let socket = UdpSocket::bind(SocketAddr::new(config.server_addr, config.port)).await?;

        // Loopback is disabled before joining so that this endpoint's own
        // multicast traffic is not delivered back to it.
        match config.multicast_addr {
            IpAddr::V4(group) => {
                socket.set_multicast_loop_v4(false)?;
                // No specific local interface is requested.
                socket.join_multicast_v4(group, Ipv4Addr::UNSPECIFIED)?;
            }
            IpAddr::V6(group) => {
                socket.set_multicast_loop_v6(false)?;
                // Interface index 0: no specific interface is requested.
                socket.join_multicast_v6(&group, 0)?;
            }
        }

        let framed = UdpFramed::new(socket, Codec);
        Ok(Self { framed, config })
    }
}

/// Incoming side: yields every decoded event together with its sender.
impl Stream for Network {
    type Item = bincode::Result<(NetEvent, SocketAddr)>;

    /// Polls the framed socket and passes its result through unchanged,
    /// tracing successfully decoded events at debug level on the way.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let polled = Pin::new(&mut this.framed).poll_next(cx);

        // Only successful receptions are logged; pending, end-of-stream and
        // error results are returned without a trace line.
        if let Poll::Ready(Some(Ok((received, sender)))) = &polled {
            debug!("<= {} => {:#?}", sender, received);
        }

        polled
    }
}

impl Sink<(NetEvent, SocketAddr)> for Network {
    type Error = bincode::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.framed).poll_ready(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: (NetEvent, SocketAddr)) -> Result<(), Self::Error> {
        let (event, addr) = &item;
        debug!("=> {} <= {:#?}", addr, event);
        Pin::new(&mut self.framed).start_send(item)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.framed).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.framed).poll_close(cx)
    }
}

/// Outgoing side without a destination: the event is addressed to the
/// configured multicast group and handed to the address-aware sink.
impl Sink<NetEvent> for Network {
    type Error = bincode::Error;

    /// Same readiness as the address-aware sink.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<(NetEvent, SocketAddr)>>::poll_ready(self, cx)
    }

    /// Addresses `event` to `multicast_addr:port` from the stored
    /// configuration and forwards it to the address-aware sink.
    fn start_send(self: Pin<&mut Self>, event: NetEvent) -> Result<(), Self::Error> {
        let group = self.config.multicast_addr;
        let port = self.config.port;
        let addressed = (event, SocketAddr::new(group, port));
        <Self as Sink<(NetEvent, SocketAddr)>>::start_send(self, addressed)
    }

    /// Same flush as the address-aware sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<(NetEvent, SocketAddr)>>::poll_flush(self, cx)
    }

    /// Same close as the address-aware sink.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        <Self as Sink<(NetEvent, SocketAddr)>>::poll_close(self, cx)
    }
}

/// Stateless codec that converts between `NetEvent` values and their bincode
/// representation, one event per datagram.
struct Codec;

/// Decoding half: turns the bytes of a received datagram into a `NetEvent`.
impl Decoder for Codec {
    type Item = NetEvent;
    type Error = bincode::Error;

    /// Deserializes one event from `src` with bincode.
    ///
    /// The buffer is read through `src.reader()`, so the bytes consumed by
    /// bincode are removed from `src`. Any deserialization failure is
    /// returned as the error; on success the result is always `Some`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Assumption: a whole event always arrives in a single packet, so no
        // attempt is made to wait for more data.
        bincode::deserialize_from(src.reader()).map(Some)
    }
}

/// Encoding half: serializes a `NetEvent` into the outgoing datagram buffer.
impl Encoder<NetEvent> for Codec {
    type Error = bincode::Error;

    /// Appends the bincode representation of `event` to `dst`.
    ///
    /// Serialization failures are returned unchanged.
    fn encode(&mut self, event: NetEvent, dst: &mut BytesMut) -> Result<(), Self::Error> {
        // Assumption: the serialized event fits within the MTU, so it is
        // never split across datagrams.
        let sink = dst.writer();
        bincode::serialize_into(sink, &event)
    }
}

// Marks `Network` as a `NetworkInterface` addressed by `SocketAddr`. The impl
// body is empty; the trait itself is declared outside this file (presumably in
// `crate::io`, which is glob-imported above — confirm there what it requires).
impl NetworkInterface<SocketAddr> for Network { }