% A cabal implementation in Erlang.
% SPDX-FileCopyrightText: 2023 Henry Bubert
%
% SPDX-License-Identifier: LGPL-2.1-or-later

-module(cabal_wire).

-export([decode/1]).
-export([
    encode_post_request/2,
    encode_cancel_request/2,
    encode_channel_time_range_request/5,
    encode_channel_state_request/3,
    encode_channel_list_request/3,
    encode_hash_response/2,
    encode_post_response/2,
    encode_channel_list_response/2
]).

% low-level
-export([split_messages/1]).
-export([decode_header/1, decode_varints/2, decode_list_of_binaries/1]).
-export([decode_varint/1, encode_varint/1]).

%%%%%%%%%%%%
% Encoders %
%%%%%%%%%%%%

%% Encode a hash response (msg_type 0) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, ReqId}, {circuitId, CircId}] in that
%% order (anything else is a badmatch). Both ids are copied verbatim into the
%% message (presumably 4-byte binaries, as read back by decode_header/1).
%% Hashes is a list of binaries written back-to-back after a varint count of
%% how many there are.
encode_hash_response(Header, Hashes) when is_list(Hashes) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    HashCount = encode_varint(length(Hashes)),
    Fields = [<<0>>, CircId, ReqId, HashCount, Hashes],
    length_encode_fields(Fields).

%% Encode a post response (msg_type 1) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, ReqId}, {circuitId, CircId}] in that
%% order. Each post binary is emitted as a varint byte length followed by the
%% post bytes; the trailing 0 writes a zero length that terminates the list.
encode_post_response(Header, Posts) when is_list(Posts) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    Encoded = lists:map(
        fun(P) -> [encode_varint(byte_size(P)), P] end,
        Posts
    ),
    length_encode_fields([<<1>>, CircId, ReqId, Encoded, 0]).

%% Encode a post request (msg_type 2) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, _}, {circuitId, _}, {ttl, _}] in that
%% order. After the ids come the TTL and the number of hashes (both varints),
%% then the hash binaries concatenated.
encode_post_request(Header, Hashes) when is_list(Hashes) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    TtlBin = encode_varint(Ttl),
    CountBin = encode_varint(length(Hashes)),
    length_encode_fields([<<2>>, CircId, ReqId, TtlBin, CountBin, Hashes]).

%% Encode a cancel request (msg_type 3) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, _}, {circuitId, _}, {ttl, _}] in that
%% order. CancelId is appended verbatim after the varint TTL; the decoder
%% (decode_cancel_request/1) expects it to be exactly 4 bytes.
encode_cancel_request(Header, CancelId) when is_binary(CancelId) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    TtlBin = encode_varint(Ttl),
    length_encode_fields([<<3>>, CircId, ReqId, TtlBin, CancelId]).

%% Encode a channel time range request (msg_type 4) as a length-prefixed
%% binary.
%%
%% Header must be exactly [{requestId, _}, {circuitId, _}, {ttl, _}] in that
%% order. Channel may be a byte list (string), a binary, or any iolist. It is
%% flattened to a binary first so that the varint length prefix equals the
%% number of bytes actually written. Elements above 255 are not UTF-8 encoded;
%% iolist_to_binary/1 raises badarg for them. Start, End and Limit are written
%% as varints.
encode_channel_time_range_request(Header, Channel, Start, End, Limit) ->
    [
        {requestId, RequestId},
        {circuitId, CircuitId},
        {ttl, TTL}
    ] = Header,
    % length/1 counts list elements rather than bytes (wrong for deep
    % iolists) and fails on binaries, so measure the flattened binary.
    ChannelBin = iolist_to_binary(Channel),
    length_encode_fields([
        <<4>>,
        CircuitId,
        RequestId,
        encode_varint(TTL),
        encode_varint(byte_size(ChannelBin)),
        ChannelBin,
        encode_varint(Start),
        encode_varint(End),
        encode_varint(Limit)
    ]).

%% Encode a channel state request (msg_type 5) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, _}, {circuitId, _}, {ttl, _}] in that
%% order. Channel is converted with list_to_binary/1, so it must be a byte
%% list/iolist, not a bare binary. It is written with a varint byte-length
%% prefix. Future must be a boolean and is written as a single 0/1 byte.
encode_channel_state_request(Header, Channel, Future) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    Name = list_to_binary(Channel),
    NameLen = encode_varint(byte_size(Name)),
    Fields = [
        <<5>>, CircId, ReqId, encode_varint(Ttl), NameLen, Name, encode_bool(Future)
    ],
    length_encode_fields(Fields).

%% Encode a channel list request (msg_type 6) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, _}, {circuitId, _}, {ttl, _}] in that
%% order. TTL, Offset and Limit follow the ids, each as a varint.
encode_channel_list_request(Header, Offset, Limit) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    Varints = [encode_varint(V) || V <- [Ttl, Offset, Limit]],
    length_encode_fields([<<6>>, CircId, ReqId | Varints]).

%% Encode a channel list response (msg_type 7) as a length-prefixed binary.
%%
%% Header must be exactly [{requestId, ReqId}, {circuitId, CircId}] in that
%% order. Each channel name is converted to UTF-8 with
%% unicode:characters_to_binary/1 and written as a varint byte length plus
%% the bytes. The trailing 0 terminates the list.
%%
%% NOTE(review): decode_channel_list_response/1 returns names as byte lists
%% (binary_to_list/1, no UTF-8 decoding), so re-encoding a decoded non-ASCII
%% name here treats each byte as a code point and double-encodes it. Confirm
%% which representation callers are expected to pass.
encode_channel_list_response(Header, Channels) when is_list(Channels) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    Encoded = lists:map(
        fun(Name) ->
            NameBin = unicode:characters_to_binary(Name),
            [encode_varint(byte_size(NameBin)), NameBin]
        end,
        Channels
    ),
    length_encode_fields([<<7>>, CircId, ReqId, Encoded, 0]).

%%%%%%%%%%%%
% Decoders %
%%%%%%%%%%%%

%% Decode one message whose leading msg_len varint has already been removed
%% (e.g. an element of the list returned by split_messages/1).
%%
%% Returns {ok, [Header, Body]}. Header is the proplist from decode_header/1;
%% Body is the proplist produced by the decoder for the header's msgType.
%% An unknown msgType raises an error whose reason is a flat string. A
%% malformed payload raises whatever error the per-type decoder raises.
decode(Data) when is_binary(Data) ->
    {ok, [Header, Payload]} = decode_header(Data),
    Body = decode_body(proplists:get_value(msgType, Header), Payload),
    {ok, [Header, Body]}.

%% Dispatch a message payload to the decoder for its msg_type.
decode_body(0, Payload) -> decode_hash_response(Payload);
decode_body(1, Payload) -> decode_post_response(Payload);
decode_body(2, Payload) -> decode_post_request(Payload);
decode_body(3, Payload) -> decode_cancel_request(Payload);
decode_body(4, Payload) -> decode_channel_time_range_request(Payload);
decode_body(5, Payload) -> decode_channel_state_request(Payload);
decode_body(6, Payload) -> decode_channel_list_request(Payload);
decode_body(7, Payload) -> decode_channel_list_response(Payload);
decode_body(Unknown, _Payload) ->
    % Flatten so the error reason is a plain string, like the other decoder
    % errors in this module (io_lib:format/2 returns a deep list).
    erlang:error(lists:flatten(io_lib:format("Unknown message type: ~p", [Unknown]))).

%% Decode the common message header: a varint msg_type followed by a 4-byte
%% circuit id and a 4-byte request id.
%% Returns {ok, [[{msgType, _}, {circuitId, _}, {requestId, _}], Payload]}
%% where Payload is the remaining binary. Crashes with badmatch if fewer than
%% 8 bytes follow the msg_type.
decode_header(Data) ->
    {MsgType, AfterType} = decode_varint(Data),
    <<CircuitId:4/binary, ReqId:4/binary, Payload/binary>> = AfterType,
    Fields = [{msgType, MsgType}, {circuitId, CircuitId}, {requestId, ReqId}],
    {ok, [Fields, Payload]}.

%% Decode the body of a hash response (msg_type 0): a varint hash_count
%% followed by exactly hash_count 32-byte hashes.
%%
%% Returns [{hashes, Hashes}]. Raises an error (flat string reason) if the
%% stated count differs from the number of whole hashes present. It also
%% raises one if bytes are left over that do not form a complete hash. The
%% binary comprehension alone would silently drop such leftovers.
decode_hash_response(Payload) ->
    {HashCount, Rest} = decode_varint(Payload),
    Hashes = [Hash || <<Hash:32/binary>> <= Rest],
    GotCount = length(Hashes),
    TrailingBytes = byte_size(Rest) - GotCount * 32,
    if
        GotCount =/= HashCount ->
            erlang:error(
                lists:flatten(
                    io_lib:format("invalid hash_count - stated ~p but got ~p", [
                        HashCount, GotCount
                    ])
                )
            );
        TrailingBytes =/= 0 ->
            erlang:error(
                lists:flatten(
                    io_lib:format("invalid hash list - ~p trailing bytes", [TrailingBytes])
                )
            );
        true ->
            [{hashes, Hashes}]
    end.

%% Decode the body of a post request (msg_type 2): varint TTL, varint
%% hash_count, then exactly hash_count 32-byte hashes.
%%
%% Returns [{ttl, Ttl}, {hashes, Hashes}]. Raises an error (flat string
%% reason) if the stated count differs from the number of whole hashes
%% present. It also raises one if bytes are left over that do not form a
%% complete hash. The binary comprehension alone would silently drop such
%% leftovers.
decode_post_request(Payload) ->
    [Ttl, HashCount, Rest] = decode_varints(Payload, 2),
    Hashes = [Hash || <<Hash:32/binary>> <= Rest],
    GotCount = length(Hashes),
    TrailingBytes = byte_size(Rest) - GotCount * 32,
    if
        GotCount =/= HashCount ->
            erlang:error(
                lists:flatten(
                    io_lib:format("invalid hash_count - stated ~p but got ~p", [
                        HashCount, GotCount
                    ])
                )
            );
        TrailingBytes =/= 0 ->
            erlang:error(
                lists:flatten(
                    io_lib:format("invalid hash list - ~p trailing bytes", [TrailingBytes])
                )
            );
        true ->
            [
                {ttl, Ttl},
                {hashes, Hashes}
            ]
    end.

%% Decode the body of a cancel request (msg_type 3): a varint TTL followed
%% by exactly 4 bytes of cancel id (badmatch on any other length).
decode_cancel_request(Payload) ->
    {TimeToLive, <<CancelId:4/binary>>} = decode_varint(Payload),
    [{ttl, TimeToLive}, {cancelId, CancelId}].

%% Decode the body of a channel time range request (msg_type 4): varint TTL,
%% varint channel length, the channel bytes, then varint time start, time end
%% and limit. The payload must end exactly after the limit (badmatch
%% otherwise). The channel is returned as a byte list (no UTF-8 decoding).
decode_channel_time_range_request(Payload) ->
    [Ttl, NameLen, AfterLen] = decode_varints(Payload, 2),
    <<Name:NameLen/binary, Tail/binary>> = AfterLen,
    [From, To, Max, <<>>] = decode_varints(Tail, 3),
    ChannelStr = binary_to_list(Name),
    [{ttl, Ttl}, {channel, ChannelStr}, {timestart, From}, {timeend, To}, {limit, Max}].

%% Decode the body of a channel state request (msg_type 5): varint TTL,
%% varint channel length, the channel bytes, then a varint "future" flag that
%% must end the payload. The channel is returned as a byte list. Future is
%% returned as the decoded integer; it is not converted back to a boolean.
decode_channel_state_request(Payload) ->
    [Ttl, NameLen, AfterLen] = decode_varints(Payload, 2),
    <<Name:NameLen/binary, Tail/binary>> = AfterLen,
    {FutureFlag, <<>>} = decode_varint(Tail),
    ChannelStr = binary_to_list(Name),
    [{ttl, Ttl}, {channel, ChannelStr}, {future, FutureFlag}].

%% Decode the body of a channel list request (msg_type 6): exactly three
%% varints (TTL, offset, limit) with nothing after them.
decode_channel_list_request(Payload) ->
    [Ttl, Offset, Limit, <<>>] = decode_varints(Payload, 3),
    lists:zip([ttl, offset, limit], [Ttl, Offset, Limit]).

%% Decode the body of a post response (msg_type 1): a zero-terminated list
%% of length-prefixed posts, returned as binaries.
decode_post_response(Payload) ->
    Posts = decode_list_of_binaries(Payload),
    [{posts, Posts}].

%% Decode the body of a channel list response (msg_type 7): a
%% zero-terminated list of length-prefixed channel names, each returned as a
%% byte list.
decode_channel_list_response(Payload) ->
    Channels = decode_list_of_strings(Payload),
    [{channels, Channels}].

% Helpers

%% Flatten an iolist of already-encoded fields into a binary and prefix it
%% with its byte length as a varint (the message's msg_len).
length_encode_fields(Fields) ->
    Body = iolist_to_binary(Fields),
    iolist_to_binary([encode_varint(byte_size(Body)), Body]).

%% Encode a boolean as a single byte: true -> <<1>>, false -> <<0>>.
%% Any other term raises function_clause.
encode_bool(true) -> <<1>>;
encode_bool(false) -> <<0>>.

%% Split a buffer of concatenated messages into the individual message
%% bodies, stripping each message's varint msg_len prefix.
%%
%% Returns {ok, Messages} with the messages in wire order; an empty buffer
%% yields {ok, []}. Raises an error if a message claims more bytes than
%% remain in the buffer.
split_messages(Data) -> split_messages(Data, []).

%% Acc holds the messages split so far in reverse order. Prepending and
%% reversing once keeps this O(n), unlike appending with ++ on every step.
split_messages(<<>>, Acc) ->
    {ok, lists:reverse(Acc)};
split_messages(Data, Acc) ->
    {Len, Rest} = decode_varint(Data),
    case Rest of
        <<Msg:Len/binary, Next/binary>> ->
            split_messages(Next, [Msg | Acc]);
        _ ->
            erlang:error("Not enough Data for the complete message")
    end.

%% Decode a zero-terminated list of length-prefixed fields, returning each
%% field as a byte list (binary_to_list/1; no UTF-8 decoding is done).
decode_list_of_strings(Data) ->
    decode_list_of(Data, [], string).

%% Decode a zero-terminated list of length-prefixed fields as binaries.
decode_list_of_binaries(Data) ->
    decode_list_of(Data, [], binary).

%% Decode fields of the form <varint Len><Len bytes> from Data until a field
%% length of 0 is read. Each field is converted according to Kind
%% (string | binary), and the results are appended in wire order after Acc.
%% Bytes after the terminating 0 are ignored. A truncated field raises
%% badmatch.
decode_list_of(Data, Acc, Kind) ->
    Acc ++ decode_fields(Data, Kind, []).

%% Collect converted fields in reverse (O(1) prepend) and reverse once at the
%% terminator, avoiding the O(n^2) cost of appending each field with ++.
decode_fields(Data, Kind, RevAcc) ->
    case decode_varint(Data) of
        {0, _Ignored} ->
            lists:reverse(RevAcc);
        {Len, Rest} ->
            <<Field:Len/binary, Rest2/binary>> = Rest,
            decode_fields(Rest2, Kind, [convert_field(Kind, Field) | RevAcc])
    end.

%% Convert a raw field binary to the representation requested by Kind.
convert_field(string, Field) -> binary_to_list(Field);
convert_field(binary, Field) -> Field.

%% Decode N consecutive varints from the front of Data.
%% Returns the N integers in wire order followed by the undecoded remainder
%% as the final element, e.g. [A, B, Rest] for N = 2.
decode_varints(Data, N) when N >= 0 ->
    decode_varints(Data, N, []).

%% Acc holds the integers decoded so far, most recent first; the remainder is
%% pushed last so a single reverse yields [Int1, ..., IntN, Rest].
decode_varints(Remainder, 0, Acc) ->
    lists:reverse([Remainder | Acc]);
decode_varints(Data, N, Acc) when N >= 0 ->
    {Value, Tail} = decode_varint(Data),
    decode_varints(Tail, N - 1, [Value | Acc]).

% stolen from https://github.com/tomas-abrahamsson/gpb/blob/edda1006d863a09509673778c455d33d88e6edbc/src/gpb.erl#L1074

%% @equiv decode_varint(Bin, 64)
-spec decode_varint(binary()) -> {non_neg_integer(), Rest :: binary()}.
decode_varint(Bin) ->
    decode_varint(Bin, 64).

%% @doc Decode an unsigned varint from the front of `Bin'.
%% Each byte contributes its low 7 bits, least-significant group first; a set
%% high bit means another byte follows. The decoded integer is masked to at
%% most `MaxNumBits' bits. At most ten bytes are accepted. A continuation bit
%% on the tenth byte, or input that ends mid-varint, raises function_clause.
%% Returns `{Value, Rest}' where Rest is the input following the varint.
-spec decode_varint(binary(), pos_integer()) -> {non_neg_integer(), binary()}.
decode_varint(Bin, MaxNumBits) ->
    de_vi(Bin, 0, 0, MaxNumBits).

%% Shift is the bit offset of the next 7-bit group and Sum the value so far.
%% The guard (57 = 64 - 7) allows at most nine continuation bytes.
de_vi(<<1:1, Group:7, Tail/binary>>, Shift, Sum, Bits) when Shift < 57 ->
    de_vi(Tail, Shift + 7, Sum + (Group bsl Shift), Bits);
de_vi(<<0:1, Group:7, Tail/binary>>, Shift, Sum, Bits) ->
    {(Sum + (Group bsl Shift)) band ((1 bsl Bits) - 1), Tail}.

%% @doc Encode a non-negative integer as an unsigned varint: 7 bits per byte,
%% least-significant group first, with the high bit set on every byte except
%% the last. Negative or non-integer input raises function_clause. Without the
%% guard, a negative value would be written as a single truncated
%% two's-complement byte that does not decode back to the original number.
-spec encode_varint(non_neg_integer()) -> binary().
encode_varint(N) when is_integer(N), N >= 0 -> iolist_to_binary(en_vi(N)).

%% Build the varint as an iolist: emit the low 7 bits with the continuation
%% bit set until the remaining value fits in a single byte.
en_vi(N) when N =< 127 -> <<N>>;
en_vi(N) when N >= 128 -> [<<1:1, (N band 127):7>>, en_vi(N bsr 7)].