%% A cabal implementation in Erlang.
-module(cabal_peer_events).
-behaviour(gen_server).

-export([start_link/1]).
-export([
    register_event_handler/3,
    unregister_event_handler/2,
    notify_channels/2
]).
-export([
    init/1,
    handle_call/3,
    handle_cast/2,
    handle_continue/2,
    handle_info/2,
    terminate/2,
    code_change/3
]).

%% gen_server state for the peer-events process.
-record(state, {
    % Database handle; used for peer_list/peer_update during reconnection
    db,
    % Transport PID used for dialing; replaced when the monitored transport dies
    transport_pid,
    % Transport supervisor PID; queried for a fresh transport PID during recovery
    transport_sup,
    % Peer server PID; queried with {get_peers} to learn which peers are connected
    peer_server_pid,
    % Timer reference for the periodic (10 s) reconnect_tick message
    reconnect_timer = undefined,
    % Event handlers: #{HandlerPid => #{interval => Ms, timer_ref => Ref, pending => sets:set()}}
    % `pending` accumulates channels from notify_channels/2 until the handler's next timer tick
    event_handlers = #{}
}).

%% API
%% Start an (unregistered) peer-events server linked to the caller.
%% Args is a proplist read by init/1 (db, transport_sup, peer_server_pid).
start_link(Args) ->
    Options = [],
    gen_server:start_link(?MODULE, Args, Options).

%% Register HandlerPid (or update its interval) so it receives batched
%% {channel_event, {new_messages, Chan}} messages every IntervalMs.
register_event_handler(Pid, HandlerPid, IntervalMs) ->
    Request = {register_event_handler, HandlerPid, IntervalMs},
    gen_server:call(Pid, Request).

%% Remove HandlerPid from the notification set; replies ok even if unknown.
unregister_event_handler(Pid, HandlerPid) ->
    Request = {unregister_event_handler, HandlerPid},
    gen_server:call(Pid, Request).

%% Asynchronously mark Channels (a list) as having new messages for every
%% registered handler. Always returns ok.
notify_channels(Pid, Channels) ->
    Msg = {notify_channels, Channels},
    gen_server:cast(Pid, Msg).

%% gen_server callbacks
%% Read `db`, `transport_sup` and `peer_server_pid` from the Args proplist,
%% look up and monitor the transport process, and defer arming the reconnect
%% timer to handle_continue/2. Exits are trapped so that terminate/2 runs when
%% the supervisor shuts this server down.
init(Args) ->
    process_flag(trap_exit, true),
    [Db, TransportSup, PeerServerPid] =
        [proplists:get_value(Key, Args) || Key <- [db, transport_sup, peer_server_pid]],
    {ok, TransportPid} = cabal_transport_sup:get_transport_pid(TransportSup),
    _MonitorRef = erlang:monitor(process, TransportPid),
    io:format(
        "[PeerEvents] Starting with db=~p, transport=~p, peer_server=~p~n",
        [Db, TransportPid, PeerServerPid]
    ),
    InitialState = #state{
        db = Db,
        transport_pid = TransportPid,
        transport_sup = TransportSup,
        peer_server_pid = PeerServerPid
    },
    {ok, InitialState, {continue, start_reconnect_timer}}.

%% Arm the first reconnect_tick (10 s) right after init/1 returns.
handle_continue(start_reconnect_timer, State) ->
    Ref = erlang:send_after(10000, self(), reconnect_tick),
    io:format("[PeerEvents] Started reconnect timer~n"),
    Updated = State#state{reconnect_timer = Ref},
    {noreply, Updated}.

%% Synchronous requests: handler (un)registration, both replying `ok`.
%% Anything else gets {error, unknown_call}.
handle_call({register_event_handler, Handler, Interval}, _From, S) ->
    {reply, ok, do_register_event_handler(S, Handler, Interval)};
handle_call({unregister_event_handler, Handler}, _From, S) ->
    {reply, ok, do_unregister_event_handler(S, Handler)};
handle_call(_Unknown, _From, S) ->
    {reply, {error, unknown_call}, S}.

%% Asynchronous requests: queue channel notifications; ignore anything else.
handle_cast({notify_channels, Channels}, S) ->
    {noreply, mark_channels_for_notification(S, Channels)};
handle_cast(_Ignored, S) ->
    {noreply, S}.

%% Raw messages:
%%  - reconnect_tick: try to reconnect persistent peers, then re-arm (10 s).
%%  - {notification_timer, HandlerPid}: flush that handler's pending channels.
%%  - 'DOWN' for the current transport: schedule a lookup of the restarted
%%    transport instead of sleeping, so this server stays responsive.
%%  - reacquire_transport: ask the transport supervisor for the new PID and
%%    monitor it; stop the server if the supervisor reports an error.
handle_info(reconnect_tick, State) ->
    %% Attempt reconnections to persistent peers
    NewState = attempt_peer_reconnections(State),
    %% Reschedule timer
    TimerRef = erlang:send_after(10000, self(), reconnect_tick),
    {noreply, NewState#state{reconnect_timer = TimerRef}};
handle_info({notification_timer, HandlerPid}, State) ->
    NewState = handle_notification_timer(State, HandlerPid),
    {noreply, NewState};
handle_info({'DOWN', _Ref, process, Pid, _Reason}, State) when Pid =:= State#state.transport_pid ->
    io:format("[PeerEvents] Transport process died. Attempting recovery...~n"),
    %% Give the supervisor a moment to restart the transport. A timer message
    %% is used instead of timer:sleep/1 so calls/casts are not stalled meanwhile.
    %% Until the lookup runs, transport_pid still names the dead process.
    erlang:send_after(200, self(), reacquire_transport),
    {noreply, State};
handle_info(reacquire_transport, State = #state{transport_sup = TransportSup}) ->
    case cabal_transport_sup:get_transport_pid(TransportSup) of
        {ok, NewTPid} ->
            io:format("[PeerEvents] Reacquired transport PID: ~p~n", [NewTPid]),
            %% If the supervisor has not restarted it yet, this monitor fires
            %% immediately and the recovery cycle repeats.
            erlang:monitor(process, NewTPid),
            {noreply, State#state{transport_pid = NewTPid}};
        {error, _} ->
            {stop, transport_died_permanently, State}
    end;
handle_info(_Info, State) ->
    {noreply, State}.

%% Cancel the reconnect timer and every handler's notification timer.
terminate(_Reason, State) ->
    CancelIfSet = fun
        (undefined) -> ok;
        (Ref) -> erlang:cancel_timer(Ref)
    end,
    CancelIfSet(State#state.reconnect_timer),
    HandlerTimers = [
        maps:get(timer_ref, Handler, undefined)
     || Handler <- maps:values(State#state.event_handlers)
    ],
    lists:foreach(CancelIfSet, HandlerTimers),
    ok.

%% No state migration needed: state passes through unchanged.
code_change(_FromVsn, CurrentState, _ExtraArgs) ->
    {ok, CurrentState}.

%%%%%%%%%%%%%
%% Private %%
%%%%%%%%%%%%%

%% Register HandlerPid with a notification interval, or update the interval
%% of an existing registration (keeping its pending channels).
%% Returns the updated state.
do_register_event_handler(State = #state{event_handlers = Handlers}, HandlerPid, IntervalMs) ->
    case maps:find(HandlerPid, Handlers) of
        {ok, OldHandler} ->
            %% Already registered, update interval
            io:format("[PeerEvents] Updating event handler ~p interval to ~pms~n", [
                HandlerPid, IntervalMs
            ]),
            TimerRef = restart_notification_timer(
                maps:get(timer_ref, OldHandler, undefined), HandlerPid, IntervalMs
            ),
            NewHandler = OldHandler#{interval => IntervalMs, timer_ref => TimerRef},
            State#state{event_handlers = maps:put(HandlerPid, NewHandler, Handlers)};
        error ->
            %% New registration
            io:format("[PeerEvents] Registered event handler ~p with interval ~pms~n", [
                HandlerPid, IntervalMs
            ]),
            Handler = #{
                interval => IntervalMs,
                timer_ref => start_notification_timer(HandlerPid, IntervalMs),
                pending => sets:new()
            },
            State#state{event_handlers = maps:put(HandlerPid, Handler, Handlers)}
    end.

%% Arm the per-handler flush timer.
start_notification_timer(HandlerPid, IntervalMs) ->
    erlang:send_after(IntervalMs, self(), {notification_timer, HandlerPid}).

%% Replace a handler's timer with one using the new interval. If the old
%% timer already fired (cancel_timer/1 returns false), its
%% {notification_timer, _} message is still pending. When it is handled, the
%% timer is re-armed with the interval stored by the caller, so starting
%% another timer here would create a second, untracked timer chain.
restart_notification_timer(undefined, HandlerPid, IntervalMs) ->
    start_notification_timer(HandlerPid, IntervalMs);
restart_notification_timer(OldTimerRef, HandlerPid, IntervalMs) ->
    case erlang:cancel_timer(OldTimerRef) of
        false ->
            OldTimerRef;
        _RemainingMs ->
            start_notification_timer(HandlerPid, IntervalMs)
    end.

%% Drop HandlerPid from the handler map and cancel its timer.
%% Unknown handlers are logged and leave the state unchanged.
do_unregister_event_handler(State = #state{event_handlers = Handlers}, HandlerPid) ->
    case maps:take(HandlerPid, Handlers) of
        {Removed, Remaining} ->
            io:format("[PeerEvents] Unregistered event handler ~p~n", [HandlerPid]),
            CancelIfSet = fun
                (undefined) -> ok;
                (Ref) -> erlang:cancel_timer(Ref)
            end,
            CancelIfSet(maps:get(timer_ref, Removed, undefined)),
            State#state{event_handlers = Remaining};
        error ->
            io:format("[PeerEvents] Event handler ~p not found~n", [HandlerPid]),
            State
    end.

%% Handle a {notification_timer, HandlerPid} tick. The handler receives one
%% {channel_event, {new_messages, Chan}} per pending channel, its pending set
%% is cleared, and the timer is re-armed with the handler's interval.
%% Ignored when the handler is unknown, or when this message is stale (the
%% handler's current timer is still running).
%% A handler that is a dead local process is removed instead of being re-armed.
handle_notification_timer(State = #state{event_handlers = Handlers}, HandlerPid) ->
    case maps:find(HandlerPid, Handlers) of
        error ->
            %% Handler was unregistered; leftover timer message, ignore.
            State;
        {ok, Handler} ->
            handle_notification_tick(State, HandlerPid, Handler)
    end.

%% Act on a tick for a registered handler.
handle_notification_tick(State = #state{event_handlers = Handlers}, HandlerPid, Handler) ->
    case notification_timer_running(Handler) of
        true ->
            %% Stale message from a timer that was superseded (e.g. the
            %% interval was updated after it fired). Re-arming here would
            %% start a second timer chain for the same handler.
            State;
        false ->
            case is_dead_local_process(HandlerPid) of
                true ->
                    %% Handlers are not monitored, so clean up dead ones here
                    %% rather than keep a timer running for them forever.
                    io:format("[PeerEvents] Event handler ~p is no longer alive, removing~n", [
                        HandlerPid
                    ]),
                    State#state{event_handlers = maps:remove(HandlerPid, Handlers)};
                false ->
                    deliver_pending_channels(HandlerPid, maps:get(pending, Handler)),
                    %% Clear pending and restart timer
                    Interval = maps:get(interval, Handler),
                    TimerRef = erlang:send_after(Interval, self(), {notification_timer, HandlerPid}),
                    NewHandler = Handler#{pending => sets:new(), timer_ref => TimerRef},
                    State#state{event_handlers = maps:put(HandlerPid, NewHandler, Handlers)}
            end
    end.

%% True when the handler's stored timer has not fired yet.
notification_timer_running(Handler) ->
    case maps:get(timer_ref, Handler, undefined) of
        undefined -> false;
        TimerRef -> is_integer(erlang:read_timer(TimerRef))
    end.

%% is_process_alive/1 only accepts local pids; remote pids and registered
%% names are treated as alive (the previous behavior).
is_dead_local_process(Pid) when is_pid(Pid), node(Pid) =:= node() ->
    not erlang:is_process_alive(Pid);
is_dead_local_process(_Other) ->
    false.

%% Send one event per pending channel; log a summary when anything was sent.
deliver_pending_channels(HandlerPid, Pending) ->
    case sets:to_list(Pending) of
        [] ->
            % Nothing to notify
            ok;
        Channels ->
            lists:foreach(fun(Chan) -> send_channel_event(HandlerPid, Chan) end, Channels),
            io:format("[PeerEvents] Notified handler ~p of ~p channels~n", [
                HandlerPid, length(Channels)
            ])
    end.

%% Best-effort send. `!` raises only for an invalid destination, e.g. an
%% unregistered name; log the reason and carry on.
send_channel_event(HandlerPid, Chan) ->
    try
        HandlerPid ! {channel_event, {new_messages, Chan}}
    catch
        Class:Reason ->
            io:format("[PeerEvents] Failed to send event to handler ~p: ~p:~p~n", [
                HandlerPid, Class, Reason
            ])
    end.

%% Add every channel in Channels (a list) to each registered handler's
%% pending set. With no handlers registered the state is returned untouched.
mark_channels_for_notification(State = #state{event_handlers = Handlers}, _Channels) when
    map_size(Handlers) =:= 0
->
    State;
mark_channels_for_notification(State = #state{event_handlers = Handlers}, Channels) ->
    AddChannels = fun(_HandlerPid, Entry) ->
        Before = maps:get(pending, Entry),
        After = lists:foldl(fun sets:add_element/2, Before, Channels),
        Entry#{pending => After}
    end,
    State#state{event_handlers = maps:map(AddChannels, Handlers)}.

%% Persistent peer reconnection.
%% For each persistent peer from the database that is not currently connected
%% and whose backoff has elapsed, dial "Host:Port" through the transport and
%% record the attempt (last_attempt, attempt_count) in the database.
%% If the peer server cannot be queried (the call exits), the round is
%% skipped instead of crashing this server. Always returns State unchanged.
attempt_peer_reconnections(
    State = #state{db = Db, transport_pid = TransportPid, peer_server_pid = PeerServerPid}
) ->
    %% Get all persistent peers
    {ok, PersistentPeers} = cabal_db:peer_list(Db),
    case fetch_connected_peers(PeerServerPid) of
        {ok, Peers} ->
            lists:foreach(
                fun(PeerInfo) -> reconnect_if_due(Db, TransportPid, Peers, PeerInfo) end,
                PersistentPeers
            );
        {error, Reason} ->
            io:format("[PeerEvents] Peer server unavailable (~p); skipping reconnection round~n", [
                Reason
            ])
    end,
    State.

%% Ask the peer server for its connected peers. A non-{ok, _} reply is treated
%% as "no peers connected"; an exit (timeout, noproc) becomes {error, Reason}.
fetch_connected_peers(PeerServerPid) ->
    try gen_server:call(PeerServerPid, {get_peers}, 5000) of
        {ok, Peers} -> {ok, Peers};
        _ -> {ok, #{}}
    catch
        exit:Reason -> {error, Reason}
    end.

%% Dial PeerInfo only if it is not connected and its backoff has elapsed.
reconnect_if_due(Db, TransportPid, Peers, PeerInfo) ->
    Address = maps:get(address, PeerInfo),
    case (not is_peer_connected(Address, Peers)) andalso should_reconnect_peer(PeerInfo) of
        true -> reconnect_peer(Db, TransportPid, Address, PeerInfo);
        false -> ok
    end.

%% Parse, resolve and dial Address, then record the attempt. The attempt is
%% recorded even when parsing, resolution or dialing fails. Otherwise a
%% malformed or unresolvable address would bypass the backoff and be retried
%% on every tick.
reconnect_peer(Db, TransportPid, Address, PeerInfo) ->
    io:format("[PeerEvents] Attempting reconnection to ~s~n", [Address]),
    try
        [HostStr, PortStr] = string:split(Address, ":"),
        {ok, Host} = inet:getaddr(HostStr, inet),
        Port = list_to_integer(PortStr),
        cabal_transport:dial(TransportPid, Host, Port)
    catch
        _:Error ->
            io:format("[PeerEvents] Failed to dial ~s: ~p~n", [Address, Error])
    end,
    record_reconnect_attempt(Db, Address, PeerInfo).

%% Store last_attempt (wall-clock ms) and bump attempt_count. Best effort:
%% a database failure is logged rather than crashing the reconnect loop.
record_reconnect_attempt(Db, Address, PeerInfo) ->
    Now = os:system_time(millisecond),
    AttemptCount = maps:get(attempt_count, PeerInfo, 0),
    try
        cabal_db:peer_update(Db, Address, [
            {last_attempt, Now},
            {attempt_count, AttemptCount + 1}
        ])
    catch
        _:Error ->
            io:format("[PeerEvents] Failed to record reconnect attempt for ~s: ~p~n", [
                Address, Error
            ])
    end.

%% Calculate exponential backoff delay in milliseconds:
%% 5000 * 2^AttemptCount, capped at 300000 (5 minutes).
%% The exponent is clamped because math:pow/2 raises badarith once the result
%% overflows a float (2^1024). attempt_count only ever grows, so without the
%% clamp a long-unreachable peer would eventually crash every reconnect tick.
%% Any exponent of 6 or more already reaches the cap, so 32 changes no result.
calculate_backoff_delay(AttemptCount) ->
    % 5 seconds
    BaseDelayMs = 5000,
    % 5 minutes
    MaxDelayMs = 300000,
    Exponent = min(AttemptCount, 32),
    DelayMs = BaseDelayMs * math:pow(2, Exponent),
    min(trunc(DelayMs), MaxDelayMs).

%% True when at least the backoff delay for the peer's attempt_count has
%% passed since its last_attempt (both default to 0 when absent).
should_reconnect_peer(PeerInfo) ->
    Now = os:system_time(millisecond),
    Since = maps:get(last_attempt, PeerInfo, 0),
    Required = calculate_backoff_delay(maps:get(attempt_count, PeerInfo, 0)),
    Now - Since >= Required.

%% Check if a persistent peer address is already connected.
%% Address is the "Host:Port" string from the DB. Peers maps connection keys
%% to metadata whose `address` is either an {IP, Port} tuple or a string.
%% A tuple matches when it equals the resolved Address. A string matches when
%% it equals Address verbatim, even if Address cannot be parsed or resolved.
is_peer_connected(Address, Peers) ->
    Target = resolve_peer_address(Address),
    lists:any(
        fun({_ConnPid, PeerMeta}) ->
            peer_address_matches(maps:get(address, PeerMeta, undefined), Address, Target)
        end,
        maps:to_list(Peers)
    ).

%% Parse "Host:Port" and resolve Host to an IPv4 address: {IP, Port} or undefined.
resolve_peer_address(Address) ->
    case string:split(Address, ":") of
        [HostStr, PortStr] ->
            try
                Port = list_to_integer(PortStr),
                case inet:getaddr(HostStr, inet) of
                    {ok, IP} -> {IP, Port};
                    _ -> undefined
                end
            catch
                _:_ -> undefined
            end;
        _ ->
            undefined
    end.

%% Compare one connected peer's recorded address against the target.
peer_address_matches({_PeerIP, _PeerPort} = PeerAddr, _Address, Target) ->
    PeerAddr =:= Target;
peer_address_matches(AddrStr, Address, _Target) when is_list(AddrStr) ->
    AddrStr =:= Address;
peer_address_matches(_Other, _Address, _Target) ->
    false.