package io.github.lonamiwebs.overgram.network; import io.github.lonamiwebs.overgram.crypto.Authenticator; import io.github.lonamiwebs.overgram.network.connection.Connection; import io.github.lonamiwebs.overgram.tl.*; import io.github.lonamiwebs.overgram.utils.BinaryReader; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; public class MTProtoSender { public final MTProtoState state; private final Connection connection; private String ip; private int port; private boolean userConnected; private boolean reconnecting; private Thread sendHandle; private Thread recvHandle; private BlockingQueue sendQueue; private HashMap pendingMessages; private HashSet pendingAck; private TLMessage lastAck; // to acknowledge acknowledgements public int retries = 5; public MTProtoSender(final MTProtoState state, final Connection connection) { this.state = state; this.connection = connection; sendQueue = new LinkedBlockingQueue(); pendingMessages = new HashMap<>(); pendingAck = new HashSet<>(); } public void connect(final String ip, final int port) throws IOException { this.ip = ip; this.port = port; userConnected = true; doConnect(); } private void doConnect() throws IOException { boolean success = false; for (int i = 0; i < retries; ++i) { try { connection.connect(ip, port); success = true; break; } catch (IOException ignored) { } } if (!success) { throw new IOException("Failed to connect " + retries + " times"); } if (state.authKey == null) { final MTProtoPlainSender plain = new MTProtoPlainSender(connection); success = false; for (int i = 0; i < retries; ++i) { try { state.authKey = Authenticator.doAuthentication(plain); success = true; break; } catch (SecurityException | IOException | ClassNotFoundException ignored) { } } if (!success) { disconnect(); throw new IOException("Failed to generate AuthKey"); } } sendHandle = new Thread(this::sendLoop); sendHandle.setDaemon(true); sendHandle.start(); recvHandle = new Thread(this::recvLoop); recvHandle.setDaemon(true); recvHandle.start(); } public void disconnect() { if (!userConnected) { doDisconnect(); } } private void doDisconnect() { userConnected = false; connection.disconnect(); pendingMessages.clear(); pendingAck.clear(); lastAck = null; stopHandles(); } private void stopHandles() { if (sendHandle != null) { sendHandle.interrupt(); try { sendHandle.join(); } catch (InterruptedException ignored) { } } if (recvHandle != null) { recvHandle.interrupt(); try { recvHandle.join(); } catch (InterruptedException ignored) { } } } private void doReconnect() { if (userConnected) { final Thread thread = new Thread(this::reconnect); thread.setDaemon(true); thread.start(); } } private void reconnect() { reconnecting = true; stopHandles(); connection.disconnect(); reconnecting = false; try { doConnect(); } catch (IOException ignored) { doDisconnect(); } } public Future send(TLRequest request) throws IOException { if (!userConnected) { throw new IOException("Not connected"); } final TLMessage message = state.createMessage(request); pendingMessages.put(message.id, message); try { sendQueue.put(message); } catch (InterruptedException ignored) { throw new IOException("Failed to put message"); } return message.future; } private void sendLoop() { while (userConnected) { if (!pendingAck.isEmpty()) { lastAck = state.createMessage(new Types.MsgsAck().msgIds(new ArrayList<>(pendingAck))); try { sendQueue.put(lastAck); } catch (InterruptedException ignored) { doDisconnect(); return; } pendingAck.clear(); } final TLMessage message; try { message = sendQueue.poll(1, TimeUnit.SECONDS); if (message == null) { continue; } } catch (InterruptedException ignored) { doDisconnect(); return; } final byte[] body = state.packMessage(message); while (!message.future.isCancelled()) { try { connection.send(body); break; } catch (IOException ignored) { // TODO e.g. on timeout retry (loop; not done yet) try { Thread.sleep(1000); doReconnect(); return; } catch (InterruptedException ignored2) { doDisconnect(); return; } } } } } private void recvLoop() { while (userConnected) { final byte[] body; try { body = connection.recv(); } catch (IOException ignored) { doReconnect(); return; } final TLMessage message; try { message = state.unpackMessage(body); } catch (ClassNotFoundException ignored) { continue; } try { processMessage(message); } catch (InterruptedException ignored) { doDisconnect(); return; } } } private void processMessage(final TLMessage message) throws InterruptedException { pendingAck.add(message.id); if (message.object instanceof RPCResult) { handleRpcResult(message); } else if (message.object instanceof MessageContainer) { handleContainer(message); } else if (message.object instanceof GzipPacked) { handleGzipPacked(message); } else if (message.object instanceof Types.MsgsAck) { handleAck(message); } else if (message.object instanceof Abstract.Updates || message.object instanceof Abstract.Update) { handleUpdate(message); } else if (message.object instanceof Types.Pong) { handlePong(message); } else if (message.object instanceof Types.BadServerSalt) { handleBadServerSalt(message); } else if (message.object instanceof Types.BadMsgNotification) { handleBadNotification(message); } else if (message.object instanceof Types.MsgDetailedInfo) { handleDetailedInfo(message); } else if (message.object instanceof Types.MsgNewDetailedInfo) { handleNewDetailedInfo(message); } else if (message.object instanceof Types.NewSessionCreated) { handleNewSessionCreated(message); } else if (message.object instanceof Types.FutureSalts) { handleFutureSalts(message); } else if (message.object instanceof Types.MsgsStateReq) { handleStateForgotten(message); } else if (message.object instanceof Types.MsgResendReq) { handleStateForgotten(message); } else if (message.object instanceof Types.MsgsAllInfo) { handleMsgAll(message); } } private void handleRpcResult(final TLMessage message) { final RPCResult result = (RPCResult) message.object; final TLMessage replyMessage = pendingMessages.remove(result.reqMsgId); if (result.error != null) { replyMessage.future.completeExceptionally(new RPCError(result.error)); return; } final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(result.result)); try { replyMessage.future.complete(((TLRequest) replyMessage.object).readResult(reader)); } catch (ClassNotFoundException e) { replyMessage.future.completeExceptionally(e); } } public void handleContainer(final TLMessage message) throws InterruptedException { final MessageContainer result = (MessageContainer) message.object; for (final TLMessage innerMessage : result.messages) { processMessage(innerMessage); } } public void handleGzipPacked(final TLMessage message) throws InterruptedException { final GzipPacked result = (GzipPacked) message.object; message.object = result.packedObject(); processMessage(message); } public void handleUpdate(final TLMessage message) { } public void handlePong(final TLMessage message) { final Types.Pong result = (Types.Pong) message.object; final TLMessage replyMessage = pendingMessages.remove(result.msgId()); if (replyMessage != null) { replyMessage.future.complete(result); } } public void handleBadServerSalt(final TLMessage message) throws InterruptedException { final Types.BadServerSalt result = (Types.BadServerSalt) message.object; state.salt = result.newServerSalt(); if (lastAck != null && result.badMsgId() == lastAck.id) { sendQueue.put(lastAck); } final TLMessage badMessage = pendingMessages.get(result.badMsgId()); if (badMessage != null) { sendQueue.put(badMessage); } } public void handleBadNotification(final TLMessage message) { final Types.BadMsgNotification result = (Types.BadMsgNotification) message.object; final TLMessage badMessage = pendingMessages.get(result.badMsgId()); if (result.errorCode() == 16 || result.errorCode() == 17) { if (badMessage != null) { // resend and update time offset throw new UnsupportedOperationException(); } return; } throw new UnsupportedOperationException(); } public void handleDetailedInfo(final TLMessage message) { final Types.MsgDetailedInfo result = (Types.MsgDetailedInfo) message.object; pendingAck.add(result.answerMsgId()); } public void handleNewDetailedInfo(final TLMessage message) { final Types.MsgNewDetailedInfo result = (Types.MsgNewDetailedInfo) message.object; pendingAck.add(result.answerMsgId()); } public void handleNewSessionCreated(final TLMessage message) { final Types.NewSessionCreated result = (Types.NewSessionCreated) message.object; state.salt = result.serverSalt(); } public void handleAck(final TLMessage message) { // check if ack-ed logout } public void handleFutureSalts(final TLMessage message) { // check if there's request } public void handleStateForgotten(final TLMessage message) { // send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids) } public void handleMsgAll(final TLMessage message) { // send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids) } }