diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/Overgram.java b/lib/src/main/java/io/github/lonamiwebs/overgram/Overgram.java index a48ab8a..8807b08 100644 --- a/lib/src/main/java/io/github/lonamiwebs/overgram/Overgram.java +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/Overgram.java @@ -1,22 +1,18 @@ package io.github.lonamiwebs.overgram; -import io.github.lonamiwebs.overgram.crypto.Authenticator; -import io.github.lonamiwebs.overgram.network.MTProtoPlainSender; -import io.github.lonamiwebs.overgram.network.connection.Connection; +import io.github.lonamiwebs.overgram.network.MTProtoSender; +import io.github.lonamiwebs.overgram.network.MTProtoState; import io.github.lonamiwebs.overgram.network.connection.TcpFull; import java.io.IOException; public class Overgram { - public static void main(final String... args) throws IOException, ClassNotFoundException { - final Connection connection = new TcpFull(); + public static void main(final String... args) throws IOException { + final MTProtoSender sender = new MTProtoSender(new MTProtoState(), new TcpFull()); try { - connection.connect("149.154.167.91", 443); - - final MTProtoPlainSender sender = new MTProtoPlainSender(connection); - Authenticator.doAuthentication(sender); + sender.connect("149.154.167.91", 443); } finally { - connection.disconnect(); + sender.disconnect(); } } } diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoSender.java b/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoSender.java new file mode 100644 index 0000000..105cadc --- /dev/null +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoSender.java @@ -0,0 +1,349 @@ +package io.github.lonamiwebs.overgram.network; + +import io.github.lonamiwebs.overgram.crypto.Authenticator; +import io.github.lonamiwebs.overgram.network.connection.Connection; +import io.github.lonamiwebs.overgram.tl.*; +import io.github.lonamiwebs.overgram.utils.BinaryReader; + +import javax.naming.OperationNotSupportedException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.concurrent.*; + +public class MTProtoSender { + private final MTProtoState state; + private final Connection connection; + private String ip; + private int port; + private boolean userConnected; + private boolean reconnecting; + + private Thread sendHandle; + private Thread recvHandle; + private BlockingQueue sendQueue; + private HashMap pendingMessages; + private HashSet pendingAck; + private TLMessage lastAck; // to acknowledge acknowledgements + + public int retries = 5; + + public MTProtoSender(final MTProtoState state, final Connection connection) { + this.state = state; + this.connection = connection; + + sendQueue = new LinkedBlockingQueue(); + pendingMessages = new HashMap<>(); + pendingAck = new HashSet<>(); + } + + public void connect(final String ip, final int port) throws IOException { + this.ip = ip; + this.port = port; + userConnected = true; + doConnect(); + } + + private void doConnect() throws IOException { + boolean success = false; + for (int i = 0; i < retries; ++i) { + try { + connection.connect(ip, port); + success = true; + break; + } catch (IOException ignored) { + } + } + if (!success) { + throw new IOException("Failed to connect " + retries + " times"); + } + + if (state.authKey == null) { + final MTProtoPlainSender plain = new MTProtoPlainSender(connection); + success = false; + for (int i = 0; i < retries; ++i) { + try { + state.authKey = Authenticator.doAuthentication(plain); + success = true; + break; + } catch (SecurityException | IOException | ClassNotFoundException ignored) { + } + } + + if (!success) { + disconnect(); + throw new IOException("Failed to generate AuthKey"); + } + } + + sendHandle = new Thread(this::sendLoop); + sendHandle.setDaemon(true); + sendHandle.start(); + + recvHandle = new Thread(this::recvLoop); + recvHandle.setDaemon(true); + recvHandle.start(); + } + + public void disconnect() { + if (!userConnected) { + doDisconnect(); + } + } + + private void doDisconnect() { + userConnected = false; + connection.disconnect(); + pendingMessages.clear(); + pendingAck.clear(); + lastAck = null; + stopHandles(); + } + + private void stopHandles() { + if (sendHandle != null) { + sendHandle.interrupt(); + try { + sendHandle.join(); + } catch (InterruptedException ignored) { + } + } + + if (recvHandle != null) { + recvHandle.interrupt(); + try { + recvHandle.join(); + } catch (InterruptedException ignored) { + } + } + } + + private void doReconnect() { + if (userConnected) { + final Thread thread = new Thread(this::reconnect); + thread.setDaemon(true); + thread.start(); + } + } + + private void reconnect() { + reconnecting = true; + stopHandles(); + connection.disconnect(); + reconnecting = false; + + try { + doConnect(); + } catch (IOException ignored) { + doDisconnect(); + } + } + + public Future send(TLRequest request) throws IOException { + if (!userConnected) { + throw new IOException("Not connected"); + } + + final TLMessage message = state.createMessage(request); + pendingMessages.put(message.id, message); + try { + sendQueue.put(message); + } catch (InterruptedException ignored) { + throw new IOException("Failed to put message"); + } + + return message.future; + } + + private void sendLoop() { + while (userConnected) { + if (!pendingAck.isEmpty()) { + lastAck = state.createMessage(new Types.MsgsAck().msgIds(new ArrayList<>(pendingAck))); + try { + sendQueue.put(lastAck); + } catch (InterruptedException ignored) { + doDisconnect(); + return; + } + } + + final TLMessage message; + try { + message = sendQueue.poll(1, TimeUnit.SECONDS); + if (message == null) { + continue; + } + } catch (InterruptedException ignored) { + doDisconnect(); + return; + } + + final byte[] body = state.packMessage(message); + while (!message.future.isCancelled()) { + try { + connection.send(body); + } catch (IOException ignored) { + try { + Thread.sleep(1000); + doReconnect(); + return; + } catch (InterruptedException ignored2) { + doDisconnect(); + return; + } + } + } + } + } + + + private void recvLoop() { + while (userConnected) { + final byte[] body; + try { + body = connection.recv(); + } catch (IOException ignored) { + doReconnect(); + return; + } + + final TLMessage message = state.unpackMessage(body); + try { + processMessage(message); + } catch (InterruptedException ignored) { + doDisconnect(); + return; + } + } + } + + private void processMessage(final TLMessage message) throws InterruptedException { + pendingAck.add(message.id); + if (message.object instanceof RpcResult) { + handleRpcResult(message); + } else if (message.object instanceof MessageContainer) { + handleContainer(message); + } else if (message.object instanceof GzipPacked) { + handleGzipPacked(message); + } else if (message.object instanceof Types.MsgsAck) { + handleAck(message); + } else if (message.object instanceof Abstract.Updates || message.object instanceof Abstract.Update) { + handleUpdate(message); + } else if (message.object instanceof Types.Pong) { + handlePong(message); + } else if (message.object instanceof Types.BadServerSalt) { + handleBadServerSalt(message); + } else if (message.object instanceof Types.BadMsgNotification) { + handleBadNotification(message); + } else if (message.object instanceof Types.MsgDetailedInfo) { + handleDetailedInfo(message); + } else if (message.object instanceof Types.MsgNewDetailedInfo) { + handleNewDetailedInfo(message); + } else if (message.object instanceof Types.NewSessionCreated) { + handleNewSessionCreated(message); + } else if (message.object instanceof Types.FutureSalts) { + handleFutureSalts(message); + } else if (message.object instanceof Types.MsgsStateReq) { + handleStateForgotten(message); + } else if (message.object instanceof Types.MsgResendReq) { + handleStateForgotten(message); + } else if (message.object instanceof Types.MsgsAllInfo) { + handleMsgAll(message); + } + } + + private void handleRpcResult(final TLMessage message) { + final RpcResult result = (RpcResult) message.object; + final TLMessage replyMessage = pendingMessages.remove(result.reqMsgId()); + + // TODO RPC error + final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(result.result())); + replyMessage.future.complete(((TLRequest) replyMessage.object).readResult(reader)); + } + + public void handleContainer(final TLMessage message) throws InterruptedException { + final MessageContainer result = (MessageContainer) message.object; + for (final TLMessage innerMessage : result.messages()) { + processMessage(innerMessage); + } + } + + public void handleGzipPacked(final TLMessage message) throws InterruptedException { + final GzipPacked result = (GzipPacked) message.object; + message.object = result.packedObject(); + processMessage(message); + } + + public void handleUpdate(final TLMessage message) { + } + + public void handlePong(final TLMessage message) { + final Types.Pong result = (Types.Pong) message.object; + final TLMessage replyMessage = pendingMessages.remove(result.msgId()); + if (replyMessage != null) { + replyMessage.future.complete(result); + } + } + + public void handleBadServerSalt(final TLMessage message) throws InterruptedException { + final Types.BadServerSalt result = (Types.BadServerSalt) message.object; + state.salt = result.newServerSalt(); + if (lastAck != null && result.badMsgId() == lastAck.id) { + sendQueue.put(lastAck); + } + + final TLMessage badMessage = pendingMessages.get(result.badMsgId()); + if (badMessage != null) { + sendQueue.put(badMessage); + } + } + + public void handleBadNotification(final TLMessage message) { + final Types.BadMsgNotification result = (Types.BadMsgNotification) message.object; + final TLMessage badMessage = pendingMessages.get(result.badMsgId()); + + if (result.errorCode() == 16 || result.errorCode() == 17) { + if (badMessage != null) { + // resend and update time offset + throw new UnsupportedOperationException(); + } + return; + } + + throw new UnsupportedOperationException(); + } + + public void handleDetailedInfo(final TLMessage message) { + final Types.MsgDetailedInfo result = (Types.MsgDetailedInfo) message.object; + pendingAck.add(result.answerMsgId()); + } + + public void handleNewDetailedInfo(final TLMessage message) { + final Types.MsgNewDetailedInfo result = (Types.MsgNewDetailedInfo) message.object; + pendingAck.add(result.answerMsgId()); + } + + public void handleNewSessionCreated(final TLMessage message) { + final Types.NewSessionCreated result = (Types.NewSessionCreated) message.object; + state.salt = result.serverSalt(); + } + + public void handleAck(final TLMessage message) { + // check if ack-ed logout + } + + public void handleFutureSalts(final TLMessage message) { + // check if there's request + } + + public void handleStateForgotten(final TLMessage message) { + // send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids) + } + + public void handleMsgAll(final TLMessage message) { + // send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids) + } +} diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoState.java b/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoState.java index 76fa68b..2a9ea49 100644 --- a/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoState.java +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoState.java @@ -1,11 +1,15 @@ package io.github.lonamiwebs.overgram.network; +import io.github.lonamiwebs.overgram.crypto.AuthKey; +import io.github.lonamiwebs.overgram.tl.TLMessage; import io.github.lonamiwebs.overgram.tl.TLObject; import java.security.SecureRandom; public class MTProtoState { + public AuthKey authKey; + public long salt; private final long id; private long timeOffset; private int sequence; @@ -16,11 +20,11 @@ public class MTProtoState { timeOffset = 0; sequence = 0; lastMsgId = 0; - - // TODO auth_key, salt + authKey = null; + salt = 0; } - public Object createMessage(final TLObject object, final long afterId) { + public TLMessage createMessage(final TLObject object) { throw new UnsupportedOperationException(); } @@ -28,11 +32,11 @@ public class MTProtoState { throw new UnsupportedOperationException(); } - public byte[] packMessage(final Object message) { + public byte[] packMessage(final TLMessage message) { throw new UnsupportedOperationException(); } - public Object unpackMessage(final byte[] body) { + public TLMessage unpackMessage(final byte[] body) { throw new UnsupportedOperationException(); } diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/GzipPacked.java b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/GzipPacked.java new file mode 100644 index 0000000..bb087a9 --- /dev/null +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/GzipPacked.java @@ -0,0 +1,22 @@ +package io.github.lonamiwebs.overgram.tl; + +import io.github.lonamiwebs.overgram.utils.BinaryReader; +import io.github.lonamiwebs.overgram.utils.BinaryWriter; + +public class GzipPacked extends TLObject { + public static final int CONSTRUCTOR_ID = 812830625; + + @Override + public void serialize(BinaryWriter writer) { + throw new UnsupportedOperationException(); + } + + @Override + public void deserialize(BinaryReader reader) throws ClassNotFoundException { + throw new UnsupportedOperationException(); + } + + public TLObject packedObject() { + throw new UnsupportedOperationException(); + } +} diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/MessageContainer.java b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/MessageContainer.java new file mode 100644 index 0000000..8507cc9 --- /dev/null +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/MessageContainer.java @@ -0,0 +1,24 @@ +package io.github.lonamiwebs.overgram.tl; + +import io.github.lonamiwebs.overgram.utils.BinaryReader; +import io.github.lonamiwebs.overgram.utils.BinaryWriter; + +import java.util.List; + +public class MessageContainer extends TLObject { + public static final int CONSTRUCTOR_ID = 1945237724; + + @Override + public void serialize(BinaryWriter writer) { + throw new UnsupportedOperationException(); + } + + @Override + public void deserialize(BinaryReader reader) throws ClassNotFoundException { + throw new UnsupportedOperationException(); + } + + public List messages() { + throw new UnsupportedOperationException(); + } +} diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/RpcResult.java b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/RpcResult.java new file mode 100644 index 0000000..1793a24 --- /dev/null +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/RpcResult.java @@ -0,0 +1,26 @@ +package io.github.lonamiwebs.overgram.tl; + +import io.github.lonamiwebs.overgram.utils.BinaryReader; +import io.github.lonamiwebs.overgram.utils.BinaryWriter; + +public class RpcResult extends TLObject { + public static final int CONSTRUCTOR_ID = -212046591; + + @Override + public void serialize(BinaryWriter writer) { + throw new UnsupportedOperationException(); + } + + @Override + public void deserialize(BinaryReader reader) throws ClassNotFoundException { + throw new UnsupportedOperationException(); + } + + public long reqMsgId() { + throw new UnsupportedOperationException(); + } + + public byte[] result() { + throw new UnsupportedOperationException(); + } +} diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLMessage.java b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLMessage.java new file mode 100644 index 0000000..3c50f44 --- /dev/null +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLMessage.java @@ -0,0 +1,10 @@ +package io.github.lonamiwebs.overgram.tl; + +import java.util.concurrent.CompletableFuture; + +public class TLMessage { + public CompletableFuture future; + public long id; + + public TLObject object; +} diff --git a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLRequest.java b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLRequest.java index 2a9249c..acb1787 100644 --- a/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLRequest.java +++ b/lib/src/main/java/io/github/lonamiwebs/overgram/tl/TLRequest.java @@ -1,4 +1,7 @@ package io.github.lonamiwebs.overgram.tl; -public abstract class TLRequest extends TLObject { +import io.github.lonamiwebs.overgram.utils.BinaryReader; + +public abstract class TLRequest extends TLObject { + public abstract T readResult(final BinaryReader reader); }