-module(cabal_wire).
-export([decode/1]).
-export([
encode_post_request/2,
encode_cancel_request/2,
encode_channel_time_range_request/5,
encode_channel_state_request/3,
encode_channel_list_request/3,
encode_hash_response/2,
encode_post_response/2,
encode_channel_list_response/2
]).
-export([split_messages/1]).
-export([decode_header/1, decode_varints/2, decode_list_of_binaries/1]).
-export([decode_varint/1, encode_varint/1]).
%% Encode a hash response (msg_type 0) as a length-prefixed binary.
%% Header is matched positionally and must be exactly
%% [{requestId, RequestId}, {circuitId, CircuitId}].
%% Hashes are written back to back after a varint count; they are expected
%% to be 32-byte binaries (that is what decode_hash_response/1 reads), but
%% their size is not checked here.
encode_hash_response(Header, Hashes) when is_list(Hashes) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    HashCount = encode_varint(length(Hashes)),
    Fields = [<<0>>, CircId, ReqId, HashCount, Hashes],
    length_encode_fields(Fields).
%% Encode a post response (msg_type 1) as a length-prefixed binary.
%% Header must be exactly [{requestId, RequestId}, {circuitId, CircuitId}].
%% Each post is framed as <varint byte length><post bytes>, and the list is
%% terminated by a single zero byte (a zero length).
encode_post_response(Header, Posts) when is_list(Posts) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    Framed = lists:map(
        fun(P) -> [encode_varint(byte_size(P)), P] end,
        Posts
    ),
    length_encode_fields([<<1>>, CircId, ReqId, Framed, 0]).
%% Encode a post request (msg_type 2) as a length-prefixed binary.
%% Header must be exactly [{requestId, R}, {circuitId, C}, {ttl, T}]
%% (positional match). The body is: varint TTL, varint hash count, then the
%% hashes written back to back.
encode_post_request(Header, Hashes) when is_list(Hashes) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    TtlBin = encode_varint(Ttl),
    CountBin = encode_varint(length(Hashes)),
    length_encode_fields([<<2>>, CircId, ReqId, TtlBin, CountBin, Hashes]).
%% Encode a cancel request (msg_type 3) as a length-prefixed binary.
%% Header must be exactly [{requestId, R}, {circuitId, C}, {ttl, T}].
%% CancelId is appended verbatim after the varint TTL.
%% decode_cancel_request/1 expects exactly 4 bytes there; that size is not
%% checked here.
encode_cancel_request(Header, CancelId) when is_binary(CancelId) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    TtlBin = encode_varint(Ttl),
    length_encode_fields([<<3>>, CircId, ReqId, TtlBin, CancelId]).
%% Encode a channel time range request (msg_type 4) as a length-prefixed binary.
%% Header must be exactly [{requestId, R}, {circuitId, C}, {ttl, T}]
%% (positional match).
%% Channel may be a byte list, an iolist or a binary. It is written as a
%% varint byte length followed by its bytes.
%% Start, End and Limit are each written as varints.
encode_channel_time_range_request(Header, Channel, Start, End, Limit) ->
    [
        {requestId, RequestId},
        {circuitId, CircuitId},
        {ttl, TTL}
    ] = Header,
    %% Prefix with the flattened byte count, not length/1: length/1 counts
    %% top-level list elements, which is wrong for nested iolists and crashes
    %% on a binary channel name.
    ChannelBin = iolist_to_binary(Channel),
    length_encode_fields([
        <<4>>,
        CircuitId,
        RequestId,
        encode_varint(TTL),
        encode_varint(byte_size(ChannelBin)),
        ChannelBin,
        encode_varint(Start),
        encode_varint(End),
        encode_varint(Limit)
    ]).
%% Encode a channel state request (msg_type 5) as a length-prefixed binary.
%% Header must be exactly [{requestId, R}, {circuitId, C}, {ttl, T}]
%% (positional match).
%% Channel may be a byte list, an iolist or a binary. It is written as a
%% varint byte length followed by its bytes.
%% Future must be a boolean (true -> 1, false -> 0).
%% NOTE(review): decode_channel_state_request/1 returns the raw integer for
%% `future`, so a decoded value cannot be passed back here unchanged.
encode_channel_state_request(Header, Channel, Future) ->
    [
        {requestId, RequestId},
        {circuitId, CircuitId},
        {ttl, TTL}
    ] = Header,
    %% iolist_to_binary/1 accepts everything list_to_binary/1 did, plus a
    %% plain binary channel name.
    ChannelBin = iolist_to_binary(Channel),
    length_encode_fields([
        <<5>>,
        CircuitId,
        RequestId,
        encode_varint(TTL),
        encode_varint(byte_size(ChannelBin)),
        ChannelBin,
        encode_bool(Future)
    ]).
%% Encode a channel list request (msg_type 6) as a length-prefixed binary.
%% Header must be exactly [{requestId, R}, {circuitId, C}, {ttl, T}].
%% The body is three varints: TTL, Offset, Limit.
encode_channel_list_request(Header, Offset, Limit) ->
    [{requestId, ReqId}, {circuitId, CircId}, {ttl, Ttl}] = Header,
    Varints = [encode_varint(V) || V <- [Ttl, Offset, Limit]],
    length_encode_fields([<<6>>, CircId, ReqId | Varints]).
%% Encode a channel list response (msg_type 7) as a length-prefixed binary.
%% Header must be exactly [{requestId, RequestId}, {circuitId, CircuitId}].
%% Each channel is converted with unicode:characters_to_binary/1 and framed as
%% <varint byte length><bytes>. The list ends with a single zero byte.
%% NOTE(review): decode_list_of_strings/1 returns names via binary_to_list/1
%% (raw bytes), while this encoder treats list elements as Unicode code points.
%% Non-ASCII names may therefore not round-trip; confirm the intended contract.
encode_channel_list_response(Header, Channels) when is_list(Channels) ->
    [{requestId, ReqId}, {circuitId, CircId}] = Header,
    Framed = lists:map(
        fun(Name) ->
            NameBin = unicode:characters_to_binary(Name),
            [encode_varint(byte_size(NameBin)), NameBin]
        end,
        Channels
    ),
    length_encode_fields([<<7>>, CircId, ReqId, Framed, 0]).
%% Decode one message whose outer length prefix has already been stripped
%% (for example, one element of split_messages/1's result).
%% Returns {ok, [Header, Body]}:
%%   Header - the proplist produced by decode_header/1;
%%   Body   - a proplist whose keys depend on the message type.
%% Raises an error with a flat string reason for an unknown msgType.
decode(Data) when is_binary(Data) ->
    {ok, [Header, Payload]} = decode_header(Data),
    Body =
        case proplists:get_value(msgType, Header) of
            0 -> decode_hash_response(Payload);
            1 -> decode_post_response(Payload);
            2 -> decode_post_request(Payload);
            3 -> decode_cancel_request(Payload);
            4 -> decode_channel_time_range_request(Payload);
            5 -> decode_channel_state_request(Payload);
            6 -> decode_channel_list_request(Payload);
            7 -> decode_channel_list_response(Payload);
            Unknown ->
                %% Flatten so the reason is a plain string, like the other
                %% decode errors in this module (io_lib:format/2 returns a
                %% possibly deep char list).
                ErrMsg = io_lib:format("Unknown message type: ~p", [Unknown]),
                erlang:error(lists:flatten(ErrMsg))
        end,
    {ok, [Header, Body]}.
%% Parse the common message header from an un-prefixed message.
%% The header layout is: varint msg_type, 4-byte circuit id, 4-byte request id.
%% Returns {ok, [Header, Payload]}, where Header is
%% [{msgType, _}, {circuitId, _}, {requestId, _}] and Payload is the rest.
decode_header(Data) ->
    {MsgType, AfterType} = decode_varint(Data),
    <<CircuitId:4/binary, ReqId:4/binary, Payload/binary>> = AfterType,
    Fields = [{msgType, MsgType}, {circuitId, CircuitId}, {requestId, ReqId}],
    {ok, [Fields, Payload]}.
%% Decode the body of a hash response (msg_type 0): a varint hash_count
%% followed by hash_count 32-byte hashes.
%% Returns [{hashes, [binary()]}].
%% Raises an error (flat string reason) if the stated count differs from the
%% number of complete hashes, or if leftover bytes do not form a whole hash.
decode_hash_response(Payload) ->
    {HashCount, Rest} = decode_varint(Payload),
    Hashes = [Hash || <<Hash:32/binary>> <= Rest],
    NumHashes = length(Hashes),
    %% The binary generator stops silently at a trailing partial hash, so
    %% leftover bytes must be checked explicitly.
    Leftover = byte_size(Rest) rem 32,
    if
        NumHashes =/= HashCount ->
            erlang:error(lists:flatten(io_lib:format(
                "invalid hash_count - stated ~p but got ~p", [HashCount, NumHashes]
            )));
        Leftover =/= 0 ->
            erlang:error(lists:flatten(io_lib:format(
                "~p trailing bytes after hashes", [Leftover]
            )));
        true ->
            [{hashes, Hashes}]
    end.
%% Decode the body of a post request (msg_type 2): varint TTL, varint
%% hash_count, then hash_count 32-byte hashes.
%% Returns [{ttl, Ttl}, {hashes, [binary()]}].
%% Raises an error (flat string reason) if the stated count differs from the
%% number of complete hashes, or if leftover bytes do not form a whole hash.
decode_post_request(Payload) ->
    [Ttl, HashCount, Rest] = decode_varints(Payload, 2),
    Hashes = [Hash || <<Hash:32/binary>> <= Rest],
    NumHashes = length(Hashes),
    %% The binary generator stops silently at a trailing partial hash, so
    %% leftover bytes must be checked explicitly.
    Leftover = byte_size(Rest) rem 32,
    if
        NumHashes =/= HashCount ->
            erlang:error(lists:flatten(io_lib:format(
                "invalid hash_count - stated ~p but got ~p", [HashCount, NumHashes]
            )));
        Leftover =/= 0 ->
            erlang:error(lists:flatten(io_lib:format(
                "~p trailing bytes after hashes", [Leftover]
            )));
        true ->
            [{ttl, Ttl}, {hashes, Hashes}]
    end.
%% Decode the body of a cancel request (msg_type 3): a varint TTL followed by
%% exactly 4 bytes of cancel id. Any other remainder size is a badmatch.
%% Returns [{ttl, Ttl}, {cancelId, CancelId}].
decode_cancel_request(Payload) ->
    {Ttl, <<CancelId:4/binary>>} = decode_varint(Payload),
    [{ttl, Ttl}, {cancelId, CancelId}].
%% Decode the body of a channel time range request (msg_type 4).
%% The layout is: varint TTL, varint channel length, channel bytes, then
%% varints time_start, time_end and limit, with nothing after them.
%% The channel is returned as a byte list (binary_to_list/1).
decode_channel_time_range_request(Payload) ->
    [Ttl, ChannelLen, AfterLen] = decode_varints(Payload, 2),
    <<ChannelBin:ChannelLen/binary, AfterChannel/binary>> = AfterLen,
    [TimeStart, TimeEnd, Limit, <<>>] = decode_varints(AfterChannel, 3),
    Channel = binary_to_list(ChannelBin),
    [
        {ttl, Ttl},
        {channel, Channel},
        {timestart, TimeStart},
        {timeend, TimeEnd},
        {limit, Limit}
    ].
%% Decode the body of a channel state request (msg_type 5).
%% The layout is: varint TTL, varint channel length, channel bytes, then a
%% varint `future` flag, with nothing after it.
%% `future` is returned as the raw integer; it is not converted to a boolean.
decode_channel_state_request(Payload) ->
    [Ttl, ChannelLen, AfterLen] = decode_varints(Payload, 2),
    <<ChannelBin:ChannelLen/binary, AfterChannel/binary>> = AfterLen,
    {Future, <<>>} = decode_varint(AfterChannel),
    Channel = binary_to_list(ChannelBin),
    [{ttl, Ttl}, {channel, Channel}, {future, Future}].
%% Decode the body of a channel list request (msg_type 6): exactly three
%% varints (TTL, offset, limit) with no trailing bytes.
decode_channel_list_request(Payload) ->
    [Ttl, Offset, Limit, <<>>] = decode_varints(Payload, 3),
    [{ttl, Ttl}, {offset, Offset}, {limit, Limit}].
%% Decode the body of a post response (msg_type 1): a zero-terminated list of
%% length-prefixed post binaries. Returns [{posts, [binary()]}].
decode_post_response(Payload) ->
    Posts = decode_list_of_binaries(Payload),
    [{posts, Posts}].
%% Decode the body of a channel list response (msg_type 7): a zero-terminated
%% list of length-prefixed channel names, returned as byte lists.
decode_channel_list_response(Payload) ->
    Channels = decode_list_of_strings(Payload),
    [{channels, Channels}].
%% Flatten an iolist of message fields into a binary and prefix it with its
%% byte size, encoded as a varint.
length_encode_fields(Lst) ->
    Body = iolist_to_binary(Lst),
    Prefix = encode_varint(byte_size(Body)),
    iolist_to_binary([Prefix, Body]).
%% Encode a boolean as a single byte: true -> 1, false -> 0.
%% Any other term raises function_clause.
encode_bool(true) -> <<1>>;
encode_bool(false) -> <<0>>.
%% Split a buffer of concatenated messages into their bodies.
%% Each message is a varint byte length followed by that many bytes, as
%% produced by the encode_* functions; the length prefix is stripped.
%% Returns {ok, [binary()]} in buffer order; an empty buffer yields {ok, []}.
%% Raises an error if the last message is truncated.
split_messages(Data) -> split_messages(Data, []).

%% Acc holds the message bodies found so far, in reverse order.
split_messages(<<>>, Acc) ->
    {ok, lists:reverse(Acc)};
split_messages(Data, Acc) ->
    {Len, Rest} = decode_varint(Data),
    case Rest of
        <<Msg:Len/binary, Next/binary>> ->
            split_messages(Next, [Msg | Acc]);
        _ ->
            erlang:error("Not enough Data for the complete message")
    end.
%% Decode a zero-terminated list of length-prefixed fields, returning each
%% field as a byte list (binary_to_list/1).
decode_list_of_strings(Data) ->
    decode_list_of(Data, [], string).

%% Decode a zero-terminated list of length-prefixed fields as binaries.
decode_list_of_binaries(Data) ->
    decode_list_of(Data, [], binary).

%% Shared loop: read a varint length, then that many bytes, until a zero
%% length is read. Bytes after the zero terminator are ignored.
%% Acc holds already-decoded items in reverse order and is reversed once at
%% the end; this replaces repeated Acc ++ [X], which is O(n^2).
%% Kind is `string` or `binary`; any other value raises case_clause on the
%% first non-empty field.
decode_list_of(Data, Acc, Kind) ->
    case decode_varint(Data) of
        {0, _Rest} ->
            lists:reverse(Acc);
        {Len, Rest} ->
            <<Field:Len/binary, Rest2/binary>> = Rest,
            Item =
                case Kind of
                    string -> binary_to_list(Field);
                    binary -> Field
                end,
            decode_list_of(Rest2, [Item | Acc], Kind)
    end.
%% Decode N consecutive varints from the front of Data.
%% Returns the N integers followed by the unconsumed remainder binary as the
%% last element, e.g. [A, B, Rest] for N = 2.
decode_varints(Data, N) when N >= 0 ->
    decode_varints(Data, N, []).

%% Rev collects the decoded integers newest-first; the remainder is pushed
%% last and the whole list is reversed once.
decode_varints(Rest, 0, Rev) ->
    lists:reverse([Rest | Rev]);
decode_varints(Bin, Remaining, Rev) when Remaining >= 0 ->
    {Value, Tail} = decode_varint(Bin),
    decode_varints(Tail, Remaining - 1, [Value | Rev]).
%% Decode one unsigned LEB128-style varint from the front of Bin.
%% Each byte carries 7 value bits, least-significant group first; the high
%% bit is set on every byte except the last.
%% Returns {Value, Rest}. Raises function_clause on truncated input or when
%% the encoding uses more continuation bytes than 64 bits allow.
-spec decode_varint(binary()) -> {non_neg_integer(), Rest :: binary()}.
decode_varint(Bin) -> decode_varint(Bin, 64).

%% Like decode_varint/1, but limits the result to MaxNumBits bits.
%% Continuation bytes are accepted only while another full 7-bit group fits
%% under MaxNumBits. Bits of the final group beyond MaxNumBits are masked off.
-spec decode_varint(binary(), pos_integer()) -> {non_neg_integer(), binary()}.
decode_varint(Bin, MaxNumBits) -> de_vi(Bin, 0, 0, MaxNumBits).

%% N is the bit offset of the current 7-bit group; Acc is the value so far.
%% The guard uses MaxNumBits (previously a hard-coded 64), so narrower widths
%% reject over-long encodings; for 64 it is unchanged.
de_vi(<<1:1, X:7, Rest/binary>>, N, Acc, MaxNumBits) when N < (MaxNumBits - 7) ->
    de_vi(Rest, N + 7, (X bsl N) + Acc, MaxNumBits);
de_vi(<<0:1, X:7, Rest/binary>>, N, Acc, MaxNumBits) ->
    Mask = (1 bsl MaxNumBits) - 1,
    {((X bsl N) + Acc) band Mask, Rest}.
%% Encode a non-negative integer as an unsigned LEB128-style varint.
%% Each byte carries 7 value bits, least-significant group first; the high
%% bit is set on every byte except the last.
%% Negative and non-integer input raise function_clause. Emitting <<N>> for a
%% negative N would wrap it into one byte and silently corrupt the encoding.
-spec encode_varint(non_neg_integer()) -> binary().
encode_varint(N) when is_integer(N), N >= 0 -> en_vi(N, <<>>).

%% Append 7-bit groups to Acc, marking every group but the last as continued.
en_vi(N, Acc) when N < 128 -> <<Acc/binary, N>>;
en_vi(N, Acc) -> en_vi(N bsr 7, <<Acc/binary, 1:1, (N band 127):7>>).