%%%-------------------------------------------------------------------
%% @doc enoise_chat public API
%% @end
%%%-------------------------------------------------------------------
-module(enoise_chat).
-export([start/0, main/1]).
%% @doc Placeholder application entry point; it performs no work and
%% simply returns `ok'. The command-line entry point is main/1.
-spec start() -> ok.
start() ->
    ok.
%% @doc Command-line entry point. Parses Args and runs either the client
%% (needs -mode client, -host, -port, -keys) or the server (needs
%% -mode server, -port, -keys). Incomplete options print usage and halt
%% the VM with exit status 1.
main(Args) ->
    dispatch(parse_args(Args)).

%% Chooses the run mode from the parsed option map.
dispatch(#{mode := client, host := Host, port := Port, keys := KeyFile}) ->
    {ClientSK, ClientPK} = load_keypair(KeyFile),
    {server_pk, ServerPK} = read_server_public(KeyFile),
    client_connect(Host, Port, ClientSK, ClientPK, ServerPK);
dispatch(#{mode := server, port := Port, keys := KeyFile}) ->
    {ServerSK, ServerPK} = load_keypair(KeyFile),
    server_listen(Port, ServerSK, ServerPK);
dispatch(_Incomplete) ->
    usage(),
    halt(1).
%% -----------------------
%% Client
%% -----------------------
%% Connects to Host:Port over TCP, writes the raw 5-byte header
%% <<0, 8, 0, 0, 3>> (also used as the Noise prologue), runs the
%% Noise_XK_25519_ChaChaPoly_BLAKE2b initiator handshake with the client
%% key pair and the server's static public key, then does three echo
%% round trips before closing. A TCP or handshake failure halts the VM
%% with exit status 1.
client_connect(Host, Port, ClientSK, ClientPK, ServerPK) ->
    Protocol = enoise_protocol:from_name("Noise_XK_25519_ChaChaPoly_BLAKE2b"),
    case gen_tcp:connect(Host, Port, [binary, {active, true}, {reuseaddr, true}], 5000) of
        {error, Why} ->
            io:format("TCP connect failed: ~p~n", [Why]),
            halt(1);
        {ok, Sock} ->
            io:format("Connected to ~s:~p~n", [Host, Port]),
            Header = <<0, 8, 0, 0, 3>>,
            %% The header goes out in the clear, ahead of the handshake.
            ok = gen_tcp:send(Sock, Header),
            KeyPair = enoise_keypair:new(dh25519, ClientSK, ClientPK),
            NoiseOpts = [{noise, Protocol}, {s, KeyPair}, {rs, ServerPK}, {prologue, Header}],
            case enoise:connect(Sock, NoiseOpts) of
                {error, Why} ->
                    io:format("Handshake failed: ~p~n", [Why]),
                    halt(1);
                {ok, Conn, _FinalHs} ->
                    io:format("Noise handshake complete~n"),
                    echo_roundtrip(Conn, 3),
                    enoise:close(Conn)
            end
    end.
%% Sends <<"ok\n">> over EConn N times. After each send it waits up to
%% 3 s for the identical payload to come back as {noise, EConn, Payload},
%% then pauses 100 ms. A missing echo halts the VM with exit status 1.
echo_roundtrip(_EConn, 0) ->
    ok;
echo_roundtrip(EConn, N) ->
    Payload = <<"ok\n">>,
    ok = enoise:send(EConn, Payload),
    io:format("→ Sent: ~p~n", [Payload]),
    await_echo(EConn, Payload, N),
    timer:sleep(100),
    echo_roundtrip(EConn, N - 1).

%% Blocks until EConn delivers exactly Payload; halts after 3 s otherwise.
await_echo(EConn, Payload, Round) ->
    receive
        {noise, EConn, Payload} ->
            io:format("← Got: ~p~n", [Payload])
    after 3000 ->
        io:format("Timeout waiting for echo #~p~n", [Round]),
        halt(1)
    end.
%% -----------------------
%% Server
%% -----------------------
%% Starts the Noise echo server. Listens on Port and serves every accepted
%% client with the Noise_XK_25519_ChaChaPoly_BLAKE2b responder keyed by
%% ServerSK/ServerPK. Halts the VM with status 1 if the port cannot be bound.
server_listen(Port, ServerSK, ServerPK) ->
    TestProtocol = enoise_protocol:from_name("Noise_XK_25519_ChaChaPoly_BLAKE2b"),
    %% Accepted sockets inherit these options. They start passive so each
    %% worker can read exactly the raw header that precedes the handshake
    %% without also consuming handshake bytes.
    TcpOpts = [binary, {active, false}, {reuseaddr, true}, {packet, raw}],
    case gen_tcp:listen(Port, TcpOpts) of
        {ok, ListenSock} ->
            io:format("Listening on port ~p~n", [Port]),
            accept_loop(ListenSock, TestProtocol, ServerSK, ServerPK);
        {error, Reason} ->
            io:format("Listen failed: ~p~n", [Reason]),
            halt(1)
    end.

%% Accepts clients until the listening socket is closed. Each client socket
%% is handed to its own worker process.
accept_loop(ListenSock, Proto, ServerSK, ServerPK) ->
    case gen_tcp:accept(ListenSock, 10000) of
        {ok, TcpSock} ->
            io:format("Client connected~n"),
            Opts = [
                {noise, Proto},
                {s, enoise_keypair:new(dh25519, ServerSK, ServerPK)},
                {prologue, <<0, 8, 0, 0, 3>>}
            ],
            hand_off(TcpSock, Opts),
            accept_loop(ListenSock, Proto, ServerSK, ServerPK);
        {error, timeout} ->
            %% No client arrived within the accept window. This is not an
            %% error, so keep waiting without logging.
            accept_loop(ListenSock, Proto, ServerSK, ServerPK);
        {error, closed} ->
            io:format("Listener closed~n");
        {error, Reason} ->
            io:format("Accept error: ~p~n", [Reason]),
            accept_loop(ListenSock, Proto, ServerSK, ServerPK)
    end.

%% Spawns a worker for TcpSock and makes it the socket's controlling process.
%% The worker is deliberately not linked, so a crash while serving one client
%% cannot take down the listener. It does not touch the socket until it is
%% told the ownership transfer has completed.
hand_off(TcpSock, Opts) ->
    Pid = spawn(fun() -> serve_connection(TcpSock, Opts) end),
    case gen_tcp:controlling_process(TcpSock, Pid) of
        ok ->
            Pid ! {socket_ready, TcpSock};
        {error, Reason} ->
            io:format("Socket hand-off failed: ~p~n", [Reason]),
            exit(Pid, kill),
            gen_tcp:close(TcpSock)
    end.

%% Per-client worker. Waits until it owns TcpSock, then consumes the raw
%% header that client_connect/5 writes in front of the handshake (it equals
%% the configured prologue). It then switches the socket to active mode and
%% runs the responder handshake. The socket is closed on every failure path.
serve_connection(TcpSock, Opts) ->
    receive
        {socket_ready, TcpSock} -> ok
    after 5000 ->
        exit(socket_handoff_timeout)
    end,
    {prologue, Header} = lists:keyfind(prologue, 1, Opts),
    case gen_tcp:recv(TcpSock, byte_size(Header), 5000) of
        {ok, Header} ->
            ok = inet:setopts(TcpSock, [{active, true}]),
            noise_accept(TcpSock, Opts);
        {ok, Other} ->
            io:format("Unexpected protocol header: ~p~n", [Other]),
            gen_tcp:close(TcpSock);
        {error, Reason} ->
            io:format("Reading protocol header failed: ~p~n", [Reason]),
            gen_tcp:close(TcpSock)
    end.

%% Runs the Noise responder handshake on TcpSock and echoes up to three
%% messages. The raw socket is closed if the handshake fails.
noise_accept(TcpSock, Opts) ->
    case enoise:accept(TcpSock, Opts) of
        {ok, EConn, _Hs} ->
            io:format("Handshake successful~n"),
            echo_server(EConn, 3),
            enoise:close(EConn);
        {error, Reason} ->
            io:format("Accept failed: ~p~n", [Reason]),
            gen_tcp:close(TcpSock)
    end.

%% Echoes up to N messages received on EConn back to the peer. Returns ok
%% after N echoes, or {error, timeout} if the peer stays silent for 3 s.
%% Only this connection gives up; the listener keeps running.
echo_server(_EConn, 0) ->
    ok;
echo_server(EConn, N) ->
    receive
        {noise, EConn, Data} ->
            io:format("→ Echoing: ~p~n", [Data]),
            ok = enoise:send(EConn, Data),
            echo_server(EConn, N - 1)
    after 3000 ->
        io:format("Server echo timeout~n"),
        {error, timeout}
    end.
%% -----------------------
%% Key Loading
%% -----------------------
%% Reads File with file:consult/1 and returns {SecretKey, PublicKey}. The
%% keys come from a `{my_keys, #{priv := ..., pub := ...}}' term, and each
%% one is decoded by hex/1. An unreadable file halts the VM with status 1.
%% A missing or malformed my_keys entry raises error(badkey).
load_keypair(File) ->
    case file:consult(File) of
        {error, Why} ->
            io:format("Could not load keys from ~s: ~p~n", [File, Why]),
            halt(1);
        {ok, Terms} ->
            keypair_from_terms(Terms)
    end.

%% Extracts and decodes the my_keys entry from consulted terms.
keypair_from_terms(Terms) ->
    case lists:keyfind(my_keys, 1, Terms) of
        {my_keys, #{priv := Secret, pub := Public}} -> {hex(Secret), hex(Public)};
        _ -> error(badkey)
    end.
%% Looks up a `{server_pk, Encoded}' term in File and returns
%% {server_pk, DecodedKey}, using hex/1 to decode. If the file cannot be
%% consulted, the {error, Reason} from file:consult/1 is returned unchanged.
%% A file with no server_pk entry raises error(no_server_pk).
read_server_public(File) ->
    case file:consult(File) of
        {error, _} = Failure ->
            Failure;
        {ok, Terms} ->
            case lists:keyfind(server_pk, 1, Terms) of
                false -> error(no_server_pk);
                {server_pk, Encoded} -> {server_pk, hex(Encoded)}
            end
    end.
%% Decodes key material read from the key file into a binary.
%% Accepted forms:
%%   * a binary of exactly 32 bytes, e.g. <<255,170,187,...>>, returned as-is;
%%   * a 64-character hex string (list), e.g. "FFAABB...";
%%   * any other list or binary, passed to base64_decode_or_hex_binary/1.
hex(Key) when is_binary(Key), byte_size(Key) =:= 32 ->
    Key;
hex(Key) when is_list(Key), length(Key) =:= 64 ->
    hex_binary(list_to_binary(Key));
hex(Key) when is_binary(Key); is_list(Key) ->
    base64_decode_or_hex_binary(Key).
%% Decodes S (a list or binary) as hex, falling back to base64.
%%
%% Hex is tried first because every hex digit is also a base64 character.
%% A hex string whose length is a multiple of four is therefore accepted by
%% base64:decode/1 and would silently yield the wrong bytes. Padded base64
%% of a 32-byte key always ends in '=', so it never parses as hex.
%%
%% Raises error(invalid_key_encoding) if neither decoding succeeds. The
%% input is intentionally left out of the error reason because it may be a
%% private key that would otherwise end up in crash logs.
base64_decode_or_hex_binary(S) ->
    decode_first([fun hex_binary/1, fun base64:decode/1], S).

%% Returns the result of the first decoder that does not raise. Swallowing
%% the exception is deliberate: it only selects the next candidate encoding.
decode_first([], _Encoded) ->
    error(invalid_key_encoding);
decode_first([Decode | Fallbacks], Encoded) ->
    try
        Decode(Encoded)
    catch
        _:_ -> decode_first(Fallbacks, Encoded)
    end.
%% Decodes an even-length string of hexadecimal digits (upper or lower
%% case), given as a list or a binary, into the corresponding bytes.
%% Raises function_clause on odd length or on any non-hex character.
%% The previous version called lists:splitlists/2, which does not exist,
%% so every call raised undef.
hex_binary(Bin) when is_binary(Bin) ->
    hex_binary(binary_to_list(Bin));
hex_binary(HexStr) when is_list(HexStr) ->
    list_to_binary(hex_pairs(HexStr)).

%% Converts consecutive pairs of hex digits into byte values (0..255).
hex_pairs([]) ->
    [];
hex_pairs([Hi, Lo | Rest]) ->
    [hex_digit(Hi) * 16 + hex_digit(Lo) | hex_pairs(Rest)].

%% Numeric value of one hexadecimal digit character. Unlike
%% list_to_integer/2, this rejects signs such as "+f" or "-1".
hex_digit(C) when C >= $0, C =< $9 -> C - $0;
hex_digit(C) when C >= $a, C =< $f -> C - $a + 10;
hex_digit(C) when C >= $A, C =< $F -> C - $A + 10.
%% -----------------------
%% Args & Usage
%% -----------------------
%% Turns command-line arguments into an option map. The map may contain:
%%   mode => client | server,  host => string(),
%%   port => 0..65535,          keys => string() (key file path).
%% Unrecognised arguments are skipped. A -port value that is not a valid
%% TCP port removes any port parsed so far, so the caller sees the option
%% as missing instead of crashing with badarg on untrusted CLI input.
parse_args(Args) ->
    parse_args(Args, #{}).

parse_args([], Acc) ->
    Acc;
parse_args(["-mode", "client" | Rest], Acc) ->
    parse_args(Rest, Acc#{mode => client});
parse_args(["-mode", "server" | Rest], Acc) ->
    parse_args(Rest, Acc#{mode => server});
parse_args(["-host", Host | Rest], Acc) ->
    parse_args(Rest, Acc#{host => Host});
parse_args(["-port", PortStr | Rest], Acc) ->
    case parse_port(PortStr) of
        {ok, Port} -> parse_args(Rest, Acc#{port => Port});
        error -> parse_args(Rest, maps:remove(port, Acc))
    end;
parse_args(["-keys", KeyFile | Rest], Acc) ->
    parse_args(Rest, Acc#{keys => KeyFile});
parse_args([_ | Rest], Acc) ->
    parse_args(Rest, Acc).

%% Parses a decimal TCP port number in 0..65535; returns error otherwise.
parse_port(PortStr) ->
    try list_to_integer(PortStr) of
        Port when Port >= 0, Port =< 65535 -> {ok, Port};
        _ -> error
    catch
        error:badarg -> error
    end.
%% Prints command-line help, including example server and client
%% invocations, to standard output. Returns ok.
usage() ->
    HelpLines = [
        "Usage:~n",
        " -mode client|server~n",
        " -host <host> (client only)~n",
        " -port <number>~n",
        " -keys <file> (path to key config)~n",
        "~nExample (server):~n",
        " rebar3 run -- -mode server -port 7891 -keys keys.txt~n",
        "~nExample (client):~n",
        " rebar3 run -- -mode client -host localhost -port 7891 -keys keys.txt~n"
    ],
    lists:foreach(fun(Line) -> io:format(Line) end, HelpLines).