% A Cabal implementation in Erlang.
% SPDX-FileCopyrightText: 2023 Henry Bubert
%
% SPDX-License-Identifier: LGPL-2.1-or-later

-module(enoise_cable).
-behavior(gen_server).

%% Public API
-export([
    listen/1, listen/2,
    accept/1, accept/2,
    connect/3,
    send/2,
    close/1,
    controlling_process/2,
    port/1,
    peername/1
]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).

%% Default Noise protocol configuration for Cable
%%
%% These are the fallbacks used by build_enoise_opts/1 when the caller's option
%% list has no {protocol, _}, {prologue, _} or {psk, _} entry respectively.
%%   DEFAULT_PROTOCOL - Noise protocol name, passed to enoise_protocol:from_name/1.
%%   DEFAULT_PROLOGUE - prologue bytes fed into the handshake.
%%   DEFAULT_PSK      - pre-shared key; the hex string is 64 characters of "08",
%%                      presumably decoding to 32 bytes of 0x08 (TODO confirm
%%                      against hex:hexstr_to_bin/1).
%% NOTE(review): a fixed, publicly known PSK gives no secrecy on its own; confirm
%% it is meant as a well-known network identifier rather than a secret.
%% The PSK macro expands to a function call, so it is evaluated at every use site.
-define(DEFAULT_PROTOCOL, "Noise_XXpsk0_25519_ChaChaPoly_BLAKE2b").
-define(DEFAULT_PROLOGUE, <<"CABLE1.0">>).
-define(DEFAULT_PSK,
    hex:hexstr_to_bin("0808080808080808080808080808080808080808080808080808080808080808")
).

%% Server state, shared by both roles this module can run in:
%%   listener   - owns a listen socket and spawns connection servers on accept
%%   connection - owns one post-handshake enoise connection and forwards every
%%                received message to handler_pid
%% Connection state (for active connections)
-record(state, {
    % 'listener' or 'connection'
    mode,
    % TCP socket (listen socket for listener, connection socket for connection)
    socket,
    % enoise connection state (undefined for listener)
    enoise_conn,
    % Process to send messages to (undefined for listener)
    handler_pid
}).

%% @doc Open a TCP listener on `Port'.
%%
%% A gen_server is started (linked to the caller through gen_server:start_link/3)
%% and it owns the listen socket. Hand the returned pid to accept/1 or accept/2
%% to take incoming connections, to port/1 to learn the bound port, and to
%% close/1 to shut the listener down.
%%
%% listen/1 adds no extra socket options. listen/2 appends `TcpOpts' (gen_tcp
%% options, e.g. [{reuseaddr, true}]) after the listener's built-in defaults.
%%
%% Because of the link, a startup failure (for example the port is in use) is
%% also delivered to the caller as an exit signal, in addition to the
%% {error, Reason} return value.
-spec listen(inet:port_number()) -> {ok, pid()} | {error, term()}.
listen(Port) ->
    NoExtraTcpOpts = [],
    listen(Port, NoExtraTcpOpts).

-spec listen(inet:port_number(), proplists:proplist()) -> {ok, pid()} | {error, term()}.
listen(Port, TcpOpts) ->
    InitArg = {listen, Port, TcpOpts},
    gen_server:start_link(?MODULE, InitArg, []).

%% @doc Take one incoming connection from a listener started by listen/1,2.
%%
%% The listener starts a dedicated connection server. That server performs
%% gen_tcp:accept/1 and the Noise handshake (as responder) in its own process,
%% so all socket traffic lands in its mailbox. This call blocks with no timeout
%% until the handshake finishes or fails.
%%
%% On success the connection server forwards every received message to the
%% calling process as {cable_transport, ConnPid, Data}.
%%
%% Options:
%%   - {keypair, KeyPair} - required; accept/1 passes no options and therefore
%%                          cannot complete a handshake
%%   - {psk, Binary}      - optional, 32-byte pre-shared key
%%   - {protocol, String} - optional, Noise protocol name
%%   - {prologue, Binary} - optional, protocol prologue
%%
%% Returns: {ok, ConnPid} | {error, Reason}
-spec accept(pid()) -> {ok, pid()} | {error, term()}.
accept(ListenerPid) ->
    NoOpts = [],
    accept(ListenerPid, NoOpts).

-spec accept(pid(), proplists:proplist()) -> {ok, pid()} | {error, term()}.
accept(ListenerPid, Opts) ->
    Handler = self(),
    Request = {accept_connection, Opts, Handler},
    gen_server:call(ListenerPid, Request, infinity).

%% @doc Dial `Host':`Port' and run the Noise handshake as initiator.
%%
%% The TCP connect and the handshake both run inside the new gen_server's init,
%% so socket messages are delivered to that server's mailbox. The call returns
%% only after the handshake has completed or failed. The server is linked to the
%% caller (gen_server:start_link/3). A failed connect or handshake therefore
%% also reaches the caller as an exit signal.
%%
%% Received messages are forwarded to the calling process as
%% {cable_transport, ConnPid, Data}.
%%
%% Options:
%%   - {keypair, KeyPair} - required
%%   - {psk, Binary}      - optional, 32-byte pre-shared key
%%   - {protocol, String} - optional, Noise protocol name
%%   - {prologue, Binary} - optional, protocol prologue
%%
%% Returns: {ok, Pid} | {error, Reason}
-spec connect(string() | inet:ip_address(), inet:port_number(), proplists:proplist()) ->
    {ok, pid()} | {error, term()}.
connect(Host, Port, Opts) ->
    Handler = self(),
    InitArg = {connect, Host, Port, Opts, Handler},
    gen_server:start_link(?MODULE, InitArg, []).

%% @doc Send one message over an established connection.
%%
%% The connection server passes the binary to enoise:send/2. Framing and
%% encryption are left to the enoise layer. Only binaries are accepted; any
%% other term fails with function_clause. Uses the default gen_server call
%% timeout.
-spec send(pid(), binary()) -> ok | {error, term()}.
send(Pid, <<_/binary>> = Message) ->
    Request = {send, Message},
    gen_server:call(Pid, Request).

%% @doc Stop a listener or connection server.
%%
%% A connection closes its enoise connection. A listener closes its listen
%% socket. In both cases the server then stops normally and the call returns ok.
-spec close(pid()) -> ok.
close(Pid) ->
    Request = close,
    gen_server:call(Pid, Request).

%% @doc Redirect incoming {cable_transport, ConnPid, Data} messages to `NewHandler'.
%%
%% Unlike gen_tcp:controlling_process/2, the server does not check who is
%% asking: any process holding the pid may change the handler.
-spec controlling_process(pid(), pid()) -> ok.
controlling_process(Pid, NewHandler) ->
    Request = {controlling_process, NewHandler},
    gen_server:call(Pid, Request).

%% @doc Return the local TCP port a listener is bound to.
%%
%% Handy after listen(0) to learn the port the OS picked. A connection server
%% answers {error, unknown_call}.
-spec port(pid()) -> {ok, inet:port_number()} | {error, term()}.
port(ListenerPid) ->
    Request = get_port,
    gen_server:call(ListenerPid, Request).

%% @doc Return the remote {IP, Port} of a connection's TCP socket.
%%
%% A listener answers {error, unknown_call}.
-spec peername(pid()) -> {ok, {inet:ip_address(), inet:port_number()}} | {error, term()}.
peername(ConnPid) ->
    Request = get_peername,
    gen_server:call(ConnPid, Request).

%%====================================================================
%% gen_server callbacks
%%====================================================================

%% @private
%% Initialise the server in one of three roles:
%%   {listen, Port, TcpOpts}
%%       open a listen socket (listener mode)
%%   {connect, Host, Port, Opts, HandlerPid}
%%       dial out and run the Noise handshake as initiator (connection mode)
%%   {do_accept, ListenSocket, Opts, HandlerPid, CallerFrom}
%%       accept one connection on ListenSocket and run the handshake as
%%       responder (connection mode); the result is replied straight to
%%       CallerFrom, the original accept/2 caller
%%
%% In the two connection roles the enoise options are built before any socket
%% work. build_enoise_opts/1 raises on a missing keypair, so an invalid option
%% list now fails before a client connection is accepted (and dropped) or an
%% outgoing TCP connection is opened.
init({listen, Port, TcpOpts}) ->
    % Note: enoise requires {active, once} or {active, true}, NOT {active, false}
    DefaultOpts = [binary, {packet, 0}, {active, once}, {reuseaddr, true}],
    case gen_tcp:listen(Port, DefaultOpts ++ TcpOpts) of
        {ok, ListenSocket} ->
            {ok, #state{
                mode = listener,
                socket = ListenSocket,
                enoise_conn = undefined,
                handler_pid = undefined
            }};
        {error, Reason} ->
            {stop, Reason}
    end;
init({connect, Host, Port, Opts, HandlerPid}) ->
    % Validate/build options first so a bad option list never opens a socket.
    EnoiseOpts = build_enoise_opts(Opts),
    % Note: enoise requires {active, true} for handshake (multiple messages needed)
    TcpOpts = [binary, {packet, 0}, {active, true}, {nodelay, true}],
    case gen_tcp:connect(Host, Port, TcpOpts) of
        {ok, Socket} ->
            case enoise:connect(Socket, EnoiseOpts) of
                {ok, EConn, _HandshakeState} ->
                    {ok, connection_state(Socket, EConn, HandlerPid)};
                {error, Reason} ->
                    gen_tcp:close(Socket),
                    {stop, Reason}
            end;
        {error, Reason} ->
            {stop, Reason}
    end;
init({do_accept, ListenSocket, Opts, HandlerPid, CallerFrom}) ->
    % Validate/build options before blocking in accept, so a bad option list
    % does not consume (and then drop) a pending client connection.
    EnoiseOpts = build_enoise_opts(Opts),
    case gen_tcp:accept(ListenSocket) of
        {ok, Socket} ->
            % Set socket to active mode for enoise handshake (needs multiple messages).
            % NOTE(review): the result is ignored; presumably a failure here
            % (e.g. peer already gone) surfaces from enoise:accept/2 below.
            inet:setopts(Socket, [{active, true}]),
            case enoise:accept(Socket, EnoiseOpts) of
                {ok, EConn, _HandshakeState} ->
                    gen_server:reply(CallerFrom, {ok, self()}),
                    {ok, connection_state(Socket, EConn, HandlerPid)};
                {error, Reason} ->
                    io:format(user, "[enoise_cable] Noise handshake failed: ~p~n", [Reason]),
                    gen_tcp:close(Socket),
                    gen_server:reply(CallerFrom, {error, Reason}),
                    {stop, Reason}
            end;
        {error, Reason} ->
            io:format(user, "[enoise_cable] TCP accept failed: ~p~n", [Reason]),
            gen_server:reply(CallerFrom, {error, Reason}),
            {stop, Reason}
    end.

%% Build the server state for a connection whose Noise handshake has completed.
connection_state(Socket, EConn, HandlerPid) ->
    #state{
        mode = connection,
        socket = Socket,
        enoise_conn = EConn,
        handler_pid = HandlerPid
    }.

%% @private
%% Synchronous requests.
%%
%% Listener mode:  {accept_connection, Opts, HandlerPid}, get_port, close
%% Connection mode: {send, Message}, get_peername, close
%% Either mode:     {controlling_process, NewHandler}
%% Anything else, including a request sent to the wrong mode, is answered with
%% {error, unknown_call}.
handle_call(
    {accept_connection, Opts, HandlerPid},
    From,
    State = #state{mode = listener, socket = ListenSocket}
) ->
    % Start the connection server UNLINKED. With start_link, a handshake failure
    % ({stop, Reason} from init/1) or a later abnormal stop such as {tcp_error, _}
    % sends an exit signal to this listener. The listener does not trap exits, so
    % it would die, and through its own link so would the process that called
    % listen/2. As a consequence, connection servers are not torn down if the
    % listener itself crashes.
    %
    % gen_server:start/3 returns only after the new server's init/1 has finished
    % accepting and handshaking. While an accept is pending, this listener cannot
    % serve other calls (port/1, close/1).
    case gen_server:start(?MODULE, {do_accept, ListenSocket, Opts, HandlerPid, From}, []) of
        {ok, _ConnPid} ->
            % The {ok, ConnPid} reply was already sent from the new server's init/1.
            {noreply, State};
        {error, Reason} ->
            % Needed when init/1 crashed before replying (e.g. missing keypair).
            % NOTE(review): when init/1 returned {stop, Reason} it has already
            % replied, so this is a second reply. Recent OTP call implementations
            % discard late replies; older releases leave a stray message in the
            % caller's mailbox. Confirm against the target OTP version.
            {reply, {error, Reason}, State}
    end;
handle_call({send, Message}, _From, State = #state{mode = connection, enoise_conn = EConn}) ->
    Res = enoise:send(EConn, Message),
    {reply, Res, State};
handle_call(close, _From, State = #state{mode = connection, enoise_conn = EConn}) ->
    enoise:close(EConn),
    {stop, normal, ok, State};
handle_call(close, _From, State = #state{mode = listener, socket = ListenSocket}) ->
    gen_tcp:close(ListenSocket),
    {stop, normal, ok, State};
handle_call({controlling_process, NewHandler}, _From, State) ->
    {reply, ok, State#state{handler_pid = NewHandler}};
handle_call(get_port, _From, State = #state{mode = listener, socket = ListenSocket}) ->
    % inet:port/1 already returns {ok, Port} | {error, Reason}; pass it through.
    {reply, inet:port(ListenSocket), State};
handle_call(get_peername, _From, State = #state{mode = connection, socket = Socket}) ->
    % inet:peername/1 already returns {ok, Address} | {error, Reason}; pass it through.
    {reply, inet:peername(Socket), State};
handle_call(_Request, _From, State) ->
    {reply, {error, unknown_call}, State}.

%% @private
%% No asynchronous requests are defined; every cast is ignored.
handle_cast(_Request, CurrentState) ->
    {noreply, CurrentState}.

%% @private
%% Raw messages delivered to the server.
%%
%% {noise, EConn, Data}
%%     One message received through enoise. If EConn is this server's own
%%     enoise connection, it is forwarded to the handler as
%%     {cable_transport, self(), Data}. Otherwise it is dropped silently.
%% {tcp_closed, Socket} / {tcp_error, Socket, Reason} (connection mode only)
%%     Stop the server when Socket is ours (normal / {tcp_error, Reason});
%%     otherwise log and keep running. Socket is a port handle, not {Host, Port}.
%% Anything else is logged and ignored.
handle_info({noise, EConn, Data}, State = #state{enoise_conn = EConn, handler_pid = Handler}) ->
    Handler ! {cable_transport, self(), Data},
    {noreply, State};
handle_info({noise, _ForeignEConn, _Data}, State) ->
    {noreply, State};
handle_info({tcp_closed, Socket}, State = #state{mode = connection, socket = Socket}) ->
    io:format(user, "[enoise_cable connection ~p] TCP connection closed by peer~n", [self()]),
    {stop, normal, State};
handle_info({tcp_closed, _ForeignSocket}, State = #state{mode = connection}) ->
    io:format(
        user,
        "[enoise_cable connection ~p] Received tcp_closed for unknown socket~n",
        [self()]
    ),
    {noreply, State};
handle_info({tcp_error, Socket, Reason}, State = #state{mode = connection, socket = Socket}) ->
    io:format(user, "[enoise_cable connection ~p] TCP error: ~p~n", [self(), Reason]),
    {stop, {tcp_error, Reason}, State};
handle_info({tcp_error, _ForeignSocket, _Reason}, State = #state{mode = connection}) ->
    io:format(
        user,
        "[enoise_cable connection ~p] Received tcp_error for unknown socket~n",
        [self()]
    ),
    {noreply, State};
handle_info(Info, State = #state{mode = Mode}) ->
    io:format(user, "[enoise_cable ~p ~p] Unexpected message: ~p~n", [Mode, self(), Info]),
    {noreply, State}.

%% @private
%% Release the resource this server owns: the listen socket for a listener,
%% the enoise connection for a connection. The resource may already be closed
%% (e.g. by close/1 or by the peer), so cleanup is deliberately best-effort:
%% every exception class is swallowed and terminate/2 always returns ok.
terminate(_Reason, #state{mode = listener, socket = ListenSocket}) ->
    try
        gen_tcp:close(ListenSocket)
    catch
        _:_ -> ok
    end,
    ok;
terminate(_Reason, #state{mode = connection, enoise_conn = EConn}) ->
    try
        enoise:close(EConn)
    catch
        _:_ -> ok
    end,
    ok.

%% @private
%% Hot code upgrade: the state layout is unchanged, so it is kept as-is.
code_change(_FromVsn, UnchangedState, _UpgradeExtra) ->
    {ok, UnchangedState}.

%%====================================================================
%% Private functions
%%====================================================================

%% Translate this module's option list into the option list enoise expects.
%%
%% keypair is mandatory. When it is absent, this raises
%% error({missing_required_option, keypair}). protocol, psk and prologue fall
%% back to ?DEFAULT_PROTOCOL, ?DEFAULT_PSK and ?DEFAULT_PROLOGUE. The protocol
%% name is resolved through enoise_protocol:from_name/1, and the single PSK is
%% wrapped in a list for the {psks, _} entry.
build_enoise_opts(Opts) ->
    KeyPair =
        case proplists:get_value(keypair, Opts) of
            undefined -> erlang:error({missing_required_option, keypair});
            Found -> Found
        end,
    Fallbacks = [
        {protocol, ?DEFAULT_PROTOCOL},
        {psk, ?DEFAULT_PSK},
        {prologue, ?DEFAULT_PROLOGUE}
    ],
    [Protocol, PSK, Prologue] =
        [proplists:get_value(Key, Opts, Fallback) || {Key, Fallback} <- Fallbacks],
    NoiseProtocol = enoise_protocol:from_name(Protocol),
    [
        {noise, NoiseProtocol},
        {s, KeyPair},
        {psks, [PSK]},
        {prologue, Prologue}
    ].