%% A cabal implementation in Erlang.
-module(cabal_transport).
-behavior(gen_server).

-export([start_link/1, stop/1, listener_port/1, get_address/1, get_listener_info/1]).
-export([register_handler/2, dial/3, send/3]).
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, code_change/3, terminate/2]).

%% Internal state of the transport gen_server.
-record(state, {
    % Listener PID from enoise_cable:listen
    listener_pid,
    % Port we're listening on, as reported by enoise_cable:port/1 for the listener
    listen_port,
    % Node's keypair for encryption; passed to enoise_cable:connect/3 when dialling
    % and handed out by get_listener_info/1
    key_pair,
    % PID of event handler to notify (peerNew / peerData / peerLost messages).
    % undefined until supplied via the event_handler start option or register_handler/2
    event_handler,
    % Map of ConnPid -> connection info #{addr => PeerAddr}
    % (for dialled connections PeerAddr is the {Host, Port} that was dialled)
    % ConnPid is the unique identifier for each connection
    connections = #{}
}).

%% Public API

%% Start a transport server linked to the calling process.
%%
%% Args is a proplist:
%%   - {listen_addr, {IP, Port}}: address to listen on (init/1 only uses Port)
%%   - {key_pair, KeyPair}: node's keypair; the server refuses to start without it
%%   - {event_handler, HandlerPid}: optional receiver of peer events; may also be
%%     supplied later with register_handler/2
start_link(Args) ->
    StartOpts = [],
    gen_server:start_link(?MODULE, Args, StartOpts).

%% Ask the transport server to stop. The server replies ok and exits with reason normal.
stop(Server) ->
    gen_server:call(Server, stop, 5000).

%% Make HandlerPid the receiver of peerNew / peerData / peerLost events, replacing any
%% previously registered handler. Returns ok.
register_handler(Server, HandlerPid) ->
    Request = {register_handler, HandlerPid},
    gen_server:call(Server, Request).

%% Open an encrypted connection to Host:Port (the literal host "unnamed" is dialled as
%% loopback). The caller blocks until the dial worker replies with {ok, ConnPid} or
%% {error, Reason}, bounded by the default 5 second call timeout.
dial(Server, Host, Port) ->
    gen_server:call(Server, {dial, Host, Port}, 5000).

%% Send Binary to the peer behind ConnPid, the unique per-connection process id.
%% Returns {error, unknown_connection} when the transport does not track ConnPid,
%% otherwise whatever enoise_cable:send/2 returns.
send(Server, ConnPid, Binary) ->
    Request = {send, ConnPid, Binary},
    gen_server:call(Server, Request).

%% Return {ok, Port} with the port the listener is bound to.
listener_port(Server) ->
    gen_server:call(Server, {listener_port}, 5000).

%% Return {ok, {IP, Port}} for reaching this node directly; IP is always loopback.
get_address(Server) ->
    Request = {get_address},
    gen_server:call(Server, Request).

%% Return {ok, {ListenerPid, KeyPair}}: what an accept worker needs in order to accept
%% encrypted connections on this transport's listener.
get_listener_info(Server) ->
    gen_server:call(Server, {get_listener_info}, 5000).

%% gen_server callbacks

%% gen_server init: read the start options and open the encrypted listener.
%%
%% Options:
%%   - listen_addr: {IP, Port}. Required; a missing entry fails the match below.
%%     Only Port is used.
%%   - key_pair: required. Without it the server stops with {error, no_keypair_provided}.
%%   - event_handler: optional. Defaults to undefined and can be set later via
%%     register_handler/2.
%% The accept worker is not started here; transport_sup starts it as a supervised child.
init(Args) ->
    {_IP, Port} = proplists:get_value(listen_addr, Args),
    KeyPair = proplists:get_value(key_pair, Args),
    Handler = proplists:get_value(event_handler, Args, undefined),
    case KeyPair of
        undefined ->
            {stop, {error, no_keypair_provided}};
        _ ->
            open_listener(Port, KeyPair, Handler)
    end.

%% Start the enoise_cable listener on Port and build the initial server state.
%% Returns {stop, Reason} if the listener cannot be opened.
open_listener(Port, KeyPair, Handler) ->
    case enoise_cable:listen(Port, [{reuseaddr, true}]) of
        {error, Reason} ->
            {stop, Reason};
        {ok, ListenerPid} ->
            {ok, BoundPort} = enoise_cable:port(ListenerPid),
            io:format("[Transport] Starting encrypted server on port ~p (pid: ~p)~n", [
                BoundPort, self()
            ]),
            {ok, #state{
                listener_pid = ListenerPid,
                listen_port = BoundPort,
                key_pair = KeyPair,
                event_handler = Handler
            }}
    end.

%% Synchronous requests. Every clause replies inline except {dial, ...}, whose reply is
%% sent later by the dial worker through gen_server:reply/2.
handle_call(stop, _From, State) ->
    {stop, normal, ok, State};
handle_call({register_handler, NewHandler}, _From, State) ->
    io:format("[Transport] Registering handler: ~p~n", [NewHandler]),
    {reply, ok, State#state{event_handler = NewHandler}};
handle_call({dial, Host, Port}, From, #state{key_pair = KeyPair} = State) ->
    % The literal host "unnamed" is dialled as loopback.
    Target =
        if
            Host =:= "unnamed" -> {127, 0, 0, 1};
            true -> Host
        end,
    io:format("[Transport] Attempting encrypted connection to ~p:~p~n", [Target, Port]),
    % Capture our pid here: inside the fun, self() would be the worker's own pid.
    Server = self(),
    % The worker is linked, so it cannot outlive the transport.
    % NOTE(review): init/1 does not set trap_exit, so an abnormal exit of the worker also
    % kills the transport instead of arriving as an 'EXIT' message. Confirm this is intended.
    _ = spawn_link(fun() -> do_dial(Target, Port, KeyPair, Server, From) end),
    {noreply, State};
handle_call({send, ConnPid, Binary}, _From, #state{connections = Conns} = State) ->
    % Only send over connections this transport is tracking.
    case maps:is_key(ConnPid, Conns) of
        true ->
            {reply, enoise_cable:send(ConnPid, Binary), State};
        false ->
            io:format("[Transport] Cannot send to unknown connection: ~p~n", [ConnPid]),
            {reply, {error, unknown_connection}, State}
    end;
handle_call({listener_port}, _From, #state{listen_port = ListenPort} = State) ->
    {reply, {ok, ListenPort}, State};
handle_call({get_address}, _From, #state{listen_port = ListenPort} = State) ->
    % Always reports loopback together with the listening port.
    {reply, {ok, {{127, 0, 0, 1}, ListenPort}}, State};
handle_call(
    {get_listener_info}, _From, #state{listener_pid = Listener, key_pair = KeyPair} = State
) ->
    {reply, {ok, {Listener, KeyPair}}, State};
handle_call(Unknown, _From, State) ->
    io:format("Unhandled call: ~p~n", [Unknown]),
    {reply, {error, unknown_call}, State}.

%% The transport defines no casts; anything that arrives is logged and ignored.
handle_cast(Unexpected, State) ->
    io:format("Unhandled cast: ~p~n", [Unexpected]),
    {noreply, State}.

%% Asynchronous events from connection processes and from the dial / accept workers.
%%
%% Events are forwarded to the registered handler as
%%   {peerNew, ConnPid, PeerAddr} | {peerData, ConnPid, Data} | {peerLost, ConnPid, PeerAddr}
%% where ConnPid uniquely identifies a connection. While no handler is registered
%% (event_handler is undefined), events are logged and dropped instead of crashing the
%% server; see notify_handler/2.

%% New connection from a dial operation or the accept worker: monitor and record it.
handle_info(
    {new_connection, ConnPid, PeerAddr},
    State = #state{event_handler = Handler, connections = Conns}
) ->
    io:format("[Transport] New encrypted connection established with ~p (conn: ~p)~n", [
        PeerAddr, ConnPid
    ]),
    % The monitor's 'DOWN' message drives cleanup and the peerLost notification.
    erlang:monitor(process, ConnPid),
    NewConns = maps:put(ConnPid, #{addr => PeerAddr}, Conns),
    notify_handler(Handler, {peerNew, ConnPid, PeerAddr}),
    {noreply, State#state{connections = NewConns}};
%% Incoming decrypted data from enoise_cable: forward it only for known connections.
handle_info(
    {cable_transport, ConnPid, Data}, State = #state{event_handler = Handler, connections = Conns}
) ->
    case maps:get(ConnPid, Conns, undefined) of
        undefined ->
            io:format("[Transport] Received data from unknown connection ~p~n", [ConnPid]),
            {noreply, State};
        _ConnInfo ->
            notify_handler(Handler, {peerData, ConnPid, Data}),
            {noreply, State}
    end;
%% A monitored connection process terminated: forget it and tell the handler.
handle_info(
    {'DOWN', _Ref, process, ConnPid, _Reason},
    State = #state{
        event_handler = Handler,
        connections = Conns
    }
) ->
    case maps:get(ConnPid, Conns, undefined) of
        undefined ->
            io:format("[Transport] Unknown connection ~p terminated~n", [ConnPid]),
            {noreply, State};
        ConnInfo ->
            PeerAddr = maps:get(addr, ConnInfo),
            io:format("[Transport] Connection ~p closed with peer ~p~n", [ConnPid, PeerAddr]),
            NewConns = maps:remove(ConnPid, Conns),
            notify_handler(Handler, {peerLost, ConnPid, PeerAddr}),
            {noreply, State#state{connections = NewConns}}
    end;
handle_info(Info, State) ->
    io:format("Unhandled info: ~p~n", [Info]),
    {noreply, State}.

%% Deliver a {Tag, ConnPid, Payload} event to the handler and return ok.
%% The handler is undefined until one is supplied via the event_handler start option or
%% register_handler/2. Sending to the unregistered atom `undefined` would raise badarg and
%% crash the transport, so in that case the event is logged and dropped.
notify_handler(undefined, {Tag, ConnPid, _Payload}) ->
    io:format("[Transport] No event handler registered; dropping ~p for ~p~n", [Tag, ConnPid]),
    ok;
notify_handler(Handler, Event) ->
    Handler ! Event,
    ok.

%% Hot-upgrade hook: the state is carried over as-is, with no migration.
code_change(_FromVsn, State, _Extra) -> {ok, State}.

%% Best-effort cleanup: close every tracked connection, then the listener. Always
%% returns ok.
%% No io:format here: the group leader may already be dead (e.g. during test cleanup).
%% NOTE(review): init/1 does not enable trap_exit, so a supervisor shutdown kills this
%% process without running terminate/2. Only stop/1 reaches it. Confirm this is intended.
terminate(_Reason, #state{listener_pid = Listener, connections = Conns}) ->
    lists:foreach(fun close_quietly/1, maps:keys(Conns)),
    close_quietly(Listener).

%% Close an enoise_cable process, ignoring any failure. During shutdown the peer or
%% process may already be gone, and cleanup must not abort halfway.
close_quietly(Pid) ->
    try
        enoise_cable:close(Pid)
    catch
        _:_ -> ok
    end,
    ok.

%% Private functions

%% Establish an outgoing encrypted connection on behalf of a pending dial/3 call.
%%
%% Runs in the linked worker spawned by handle_call({dial, ...}). From is that call's
%% caller and receives exactly one reply here: {ok, ConnPid} or {error, Reason}.
%%
%% The transport is told about the connection *before* ownership is handed over with
%% enoise_cable:controlling_process/2. Once ownership moves, the connection delivers
%% {cable_transport, ConnPid, Data} straight to the transport, and handle_info/2 drops
%% data from connections it has not registered. Announcing first means new_connection is
%% already queued at the transport before the connection can send it anything (on a local
%% node).
%%
%% A failed handover is reported to the caller rather than crashing this worker. The
%% worker is linked and the transport does not trap exits, so a crash would take the
%% transport down too.
do_dial(Host, Port, KeyPair, TransportPid, From) ->
    Opts = [{keypair, KeyPair}],
    case enoise_cable:connect(Host, Port, Opts) of
        {ok, ConnPid} ->
            io:format("[Transport] Encrypted connection established to ~p:~p (conn: ~p)~n", [
                Host, Port, ConnPid
            ]),
            % The dialled Host/Port stand in for the peer address; the remote socket
            % address is not queried.
            PeerAddr = {Host, Port},
            TransportPid ! {new_connection, ConnPid, PeerAddr},
            case enoise_cable:controlling_process(ConnPid, TransportPid) of
                ok ->
                    gen_server:reply(From, {ok, ConnPid});
                {error, HandoverReason} ->
                    io:format("[Transport] Failed to hand connection ~p to transport: ~p~n", [
                        ConnPid, HandoverReason
                    ]),
                    % The transport already monitors ConnPid. Assuming close/1 ends the
                    % connection process, the resulting 'DOWN' removes the entry and
                    % sends peerLost to the handler.
                    try
                        enoise_cable:close(ConnPid)
                    catch
                        _:_ -> ok
                    end,
                    gen_server:reply(From, {error, HandoverReason})
            end;
        {error, Reason} ->
            io:format("[Transport] Connection error to ~p:~p: ~p~n", [Host, Port, Reason]),
            gen_server:reply(From, {error, Reason})
    end.