// Overgram/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoSender.java
//
// 367 lines
// 12 KiB
// Java

package io.github.lonamiwebs.overgram.network;
import io.github.lonamiwebs.overgram.crypto.Authenticator;
import io.github.lonamiwebs.overgram.network.connection.Connection;
import io.github.lonamiwebs.overgram.tl.*;
import io.github.lonamiwebs.overgram.utils.BinaryReader;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
/**
 * Sends requests over an encrypted connection and dispatches whatever the server sends back.
 *
 * <p>After a successful {@link #connect(String, int)} two daemon threads are running: the send
 * loop drains {@code sendQueue} and writes packed messages to the {@link Connection}; the
 * receive loop reads from the connection and routes each unpacked message to the matching
 * {@code handle*} method. Requests awaiting an answer are kept in {@code pendingMessages},
 * keyed by message id, until a handler completes their future.
 *
 * <p>The bookkeeping collections are touched by caller threads ({@link #send}), by both loop
 * threads and by the thread performing a (re/dis)connect, so they are concurrent collections
 * and the state flags are {@code volatile}.
 */
public class MTProtoSender {
    /** Session state: creates, packs and unpacks messages and holds the auth key and salt. */
    public final MTProtoState state;

    /** Number of attempts made when connecting and when generating the auth key. */
    public int retries = 5;

    private final Connection connection;
    private String ip;
    private int port;

    /** {@code true} from {@link #connect} until the sender is disconnected. */
    private volatile boolean userConnected;

    /** {@code true} while a reconnect is stopping the current loop threads. */
    private volatile boolean reconnecting;

    /** Serializes the decision to start a reconnect so only one runs at a time. */
    private final Object reconnectLock = new Object();

    private volatile Thread sendHandle;
    private volatile Thread recvHandle;

    /** Messages waiting to be written to the connection by the send loop. */
    private final BlockingQueue<TLMessage> sendQueue = new LinkedBlockingQueue<>();

    /** Requests that were queued and have not received their answer yet, by message id. */
    private final Map<Long, TLMessage> pendingMessages = new ConcurrentHashMap<>();

    /** Ids of received messages that still have to be acknowledged to the server. */
    private final Set<Long> pendingAck = ConcurrentHashMap.newKeySet();

    /** Last acknowledgement sent; kept so it can be re-queued if it is rejected for a bad salt. */
    private volatile TLMessage lastAck;

    /**
     * Creates a sender that is not yet connected.
     *
     * @param state      session state used to build and (un)pack messages
     * @param connection transport used to exchange raw packets
     */
    public MTProtoSender(final MTProtoState state, final Connection connection) {
        this.state = state;
        this.connection = connection;
    }

    /**
     * Connects to the given address, generates an auth key if {@code state} has none, and
     * starts the send and receive threads.
     *
     * @param ip   address to connect to
     * @param port port to connect to
     * @throws IOException if connecting or generating the auth key fails {@link #retries} times;
     *                     the sender is left in the "not connected" state in that case
     */
    public void connect(final String ip, final int port) throws IOException {
        this.ip = ip;
        this.port = port;
        userConnected = true;
        try {
            doConnect();
        } catch (IOException e) {
            // A failed connect must not leave send() accepting messages nobody will deliver.
            userConnected = false;
            throw e;
        }
    }

    /**
     * Performs the actual connection using the stored address; shared by {@link #connect} and
     * the reconnect path.
     */
    private void doConnect() throws IOException {
        IOException connectFailure = null;
        boolean connected = false;
        for (int attempt = 0; attempt < retries && !connected; ++attempt) {
            try {
                connection.connect(ip, port);
                connected = true;
            } catch (IOException e) {
                connectFailure = e; // remembered so the final error carries a cause
            }
        }
        if (!connected) {
            throw new IOException("Failed to connect " + retries + " times", connectFailure);
        }

        if (state.authKey == null) {
            final MTProtoPlainSender plain = new MTProtoPlainSender(connection);
            Exception authFailure = null;
            boolean authenticated = false;
            for (int attempt = 0; attempt < retries && !authenticated; ++attempt) {
                try {
                    state.authKey = Authenticator.doAuthentication(plain);
                    authenticated = true;
                } catch (SecurityException | IOException | ClassNotFoundException e) {
                    authFailure = e;
                }
            }
            if (!authenticated) {
                disconnect();
                throw new IOException("Failed to generate AuthKey", authFailure);
            }
        }

        sendHandle = startDaemon(this::sendLoop);
        recvHandle = startDaemon(this::recvLoop);
    }

    /** Starts {@code body} on a new daemon thread and returns that thread. */
    private static Thread startDaemon(final Runnable body) {
        final Thread thread = new Thread(body);
        thread.setDaemon(true);
        thread.start();
        return thread;
    }

    /**
     * Disconnects if currently connected; does nothing otherwise. Futures of requests that are
     * still pending are cancelled.
     */
    public void disconnect() {
        if (userConnected) {
            doDisconnect();
        }
    }

    /** Marks the sender as disconnected, closes the connection and stops both loop threads. */
    private void doDisconnect() {
        userConnected = false;
        connection.disconnect();
        // Cancel instead of silently dropping, so nobody waits forever on a lost request.
        for (final Long id : pendingMessages.keySet()) {
            final TLMessage pending = pendingMessages.remove(id);
            if (pending != null) {
                pending.future.cancel(true);
            }
        }
        pendingAck.clear();
        lastAck = null;
        stopHandles();
    }

    /**
     * Interrupts both loop threads and waits for them to finish. If the calling thread is
     * interrupted while waiting, its interrupt flag is restored before returning.
     */
    private void stopHandles() {
        boolean interrupted = stopHandle(sendHandle);
        interrupted |= stopHandle(recvHandle);
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Interrupts and joins {@code handle}, unless it is {@code null} or the calling thread
     * itself (a thread cannot usefully join itself).
     *
     * @return {@code true} if the wait was cut short by an interrupt of the calling thread
     */
    private static boolean stopHandle(final Thread handle) {
        if (handle == null || handle == Thread.currentThread()) {
            return false;
        }
        handle.interrupt();
        try {
            handle.join();
            return false;
        } catch (InterruptedException e) {
            return true;
        }
    }

    /**
     * Starts a background reconnect if the user is still connected and no reconnect is already
     * in progress. Both loops may report the same connection failure, hence the lock.
     */
    private void doReconnect() {
        synchronized (reconnectLock) {
            if (!userConnected || reconnecting) {
                return;
            }
            reconnecting = true;
        }
        startDaemon(this::reconnect);
    }

    /** Stops the current loops, closes the connection and connects again. */
    private void reconnect() {
        try {
            stopHandles();
            connection.disconnect();
        } finally {
            reconnecting = false;
        }
        if (!userConnected) {
            return; // disconnected while the loops were being stopped
        }
        try {
            doConnect();
        } catch (IOException ignored) {
            // Could not get the connection back: give up and fall back to disconnected.
            doDisconnect();
        }
    }

    /**
     * Queues a request to be sent.
     *
     * @param request the request to send
     * @return the future of the message created for {@code request}; the handlers of this class
     *         complete it when the matching answer is received
     * @throws IOException if the sender is not connected or the calling thread is interrupted
     *                     while queueing (the interrupt flag is restored in that case)
     */
    @SuppressWarnings("rawtypes") // raw return type kept for source compatibility with callers
    public Future send(TLRequest request) throws IOException {
        if (!userConnected) {
            throw new IOException("Not connected");
        }
        final TLMessage message = state.createMessage(request);
        pendingMessages.put(message.id, message);
        try {
            sendQueue.put(message);
        } catch (InterruptedException e) {
            pendingMessages.remove(message.id); // it was never queued, so it can't be answered
            Thread.currentThread().interrupt();
            throw new IOException("Failed to put message", e);
        }
        return message.future;
    }

    /**
     * Body of the send thread: queues acknowledgements and writes queued messages to the
     * connection until disconnected or a reconnect begins. A failed write triggers a reconnect.
     */
    private void sendLoop() {
        try {
            while (userConnected && !reconnecting) {
                queuePendingAcks();

                // Wake up once a second so new acknowledgements and state changes are noticed.
                final TLMessage message = sendQueue.poll(1, TimeUnit.SECONDS);
                if (message == null) {
                    continue;
                }
                final byte[] body = state.packMessage(message);
                if (message.future.isCancelled()) {
                    continue;
                }
                try {
                    connection.send(body);
                } catch (IOException ignored) {
                    // TODO e.g. on timeout retry (loop; not done yet)
                    Thread.sleep(1000);
                    doReconnect();
                    return;
                }
            }
        } catch (InterruptedException ignored) {
            onLoopInterrupted();
        }
    }

    /**
     * Queues one acknowledgement for every id currently in {@code pendingAck}. Only the ids
     * actually included are removed, so ids added concurrently by the receive thread are kept
     * for the next round.
     */
    private void queuePendingAcks() throws InterruptedException {
        if (pendingAck.isEmpty()) {
            return;
        }
        final ArrayList<Long> ackIds = new ArrayList<>(pendingAck);
        final TLMessage ack = state.createMessage(new Types.MsgsAck().msgIds(ackIds));
        lastAck = ack;
        sendQueue.put(ack);
        pendingAck.removeAll(ackIds);
    }

    /**
     * Reaction of a loop thread to being interrupted. Interrupts issued by a disconnect or a
     * reconnect just end the loop; any other interrupt disconnects the sender.
     */
    private void onLoopInterrupted() {
        if (userConnected && !reconnecting) {
            doDisconnect();
        }
    }

    /**
     * Body of the receive thread: reads packets, unpacks them and processes the resulting
     * messages until disconnected or a reconnect begins. A failed read triggers a reconnect.
     */
    private void recvLoop() {
        while (userConnected && !reconnecting) {
            final byte[] body;
            try {
                body = connection.recv();
            } catch (IOException ignored) {
                doReconnect();
                return;
            }

            final TLMessage message;
            try {
                message = state.unpackMessage(body);
            } catch (ClassNotFoundException ignored) {
                continue; // the packet could not be turned into a known object; skip it
            }

            try {
                processMessage(message);
            } catch (InterruptedException ignored) {
                onLoopInterrupted();
                return;
            }
        }
    }

    /**
     * Records the message id for acknowledgement and dispatches the message to the handler
     * matching the runtime type of its object. Objects of any other type are ignored.
     */
    private void processMessage(final TLMessage message) throws InterruptedException {
        pendingAck.add(message.id);
        if (message.object instanceof RPCResult) {
            handleRpcResult(message);
        } else if (message.object instanceof MessageContainer) {
            handleContainer(message);
        } else if (message.object instanceof GzipPacked) {
            handleGzipPacked(message);
        } else if (message.object instanceof Types.MsgsAck) {
            handleAck(message);
        } else if (message.object instanceof Abstract.Updates || message.object instanceof Abstract.Update) {
            handleUpdate(message);
        } else if (message.object instanceof Types.Pong) {
            handlePong(message);
        } else if (message.object instanceof Types.BadServerSalt) {
            handleBadServerSalt(message);
        } else if (message.object instanceof Types.BadMsgNotification) {
            handleBadNotification(message);
        } else if (message.object instanceof Types.MsgDetailedInfo) {
            handleDetailedInfo(message);
        } else if (message.object instanceof Types.MsgNewDetailedInfo) {
            handleNewDetailedInfo(message);
        } else if (message.object instanceof Types.NewSessionCreated) {
            handleNewSessionCreated(message);
        } else if (message.object instanceof Types.FutureSalts) {
            handleFutureSalts(message);
        } else if (message.object instanceof Types.MsgsStateReq) {
            handleStateForgotten(message);
        } else if (message.object instanceof Types.MsgResendReq) {
            handleStateForgotten(message);
        } else if (message.object instanceof Types.MsgsAllInfo) {
            handleMsgAll(message);
        }
    }

    /**
     * Completes the future of the request an {@link RPCResult} answers: exceptionally with an
     * {@link RPCError} if the result carries an error, otherwise with the value the request
     * reads from the result bytes. Results for unknown request ids are ignored.
     */
    private void handleRpcResult(final TLMessage message) {
        final RPCResult result = (RPCResult) message.object;
        final TLMessage replyMessage = pendingMessages.remove(result.reqMsgId);
        if (replyMessage == null) {
            // Nothing is waiting for this id (e.g. pending requests were cleared on disconnect).
            return;
        }
        if (result.error != null) {
            replyMessage.future.completeExceptionally(new RPCError(result.error));
            return;
        }
        final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(result.result));
        try {
            replyMessage.future.complete(((TLRequest) replyMessage.object).readResult(reader));
        } catch (ClassNotFoundException e) {
            replyMessage.future.completeExceptionally(e);
        }
    }

    /** Processes every message held by a {@link MessageContainer}, in order. */
    public void handleContainer(final TLMessage message) throws InterruptedException {
        final MessageContainer result = (MessageContainer) message.object;
        for (final TLMessage innerMessage : result.messages) {
            processMessage(innerMessage);
        }
    }

    /** Replaces the message's {@link GzipPacked} object by the object it packs and reprocesses it. */
    public void handleGzipPacked(final TLMessage message) throws InterruptedException {
        final GzipPacked result = (GzipPacked) message.object;
        message.object = result.packedObject();
        processMessage(message);
    }

    /** Updates are currently ignored. */
    public void handleUpdate(final TLMessage message) {
    }

    /** Completes the future of the pending message a pong refers to, if there is one. */
    public void handlePong(final TLMessage message) {
        final Types.Pong result = (Types.Pong) message.object;
        final TLMessage replyMessage = pendingMessages.remove(result.msgId());
        if (replyMessage != null) {
            replyMessage.future.complete(result);
        }
    }

    /**
     * Adopts the salt carried by the notification and re-queues the rejected message, whether
     * that is the last acknowledgement or a pending request.
     */
    public void handleBadServerSalt(final TLMessage message) throws InterruptedException {
        final Types.BadServerSalt result = (Types.BadServerSalt) message.object;
        state.salt = result.newServerSalt();

        final TLMessage ack = lastAck; // read once: the send thread may replace it concurrently
        if (ack != null && result.badMsgId() == ack.id) {
            sendQueue.put(ack);
        }
        final TLMessage badMessage = pendingMessages.get(result.badMsgId());
        if (badMessage != null) {
            sendQueue.put(badMessage);
        }
    }

    /**
     * Handling of bad message notifications is not implemented. Error codes 16 and 17 for a
     * message that is not pending are ignored; every other case throws.
     *
     * @throws UnsupportedOperationException for every case that would need real handling
     */
    public void handleBadNotification(final TLMessage message) {
        final Types.BadMsgNotification result = (Types.BadMsgNotification) message.object;
        final TLMessage badMessage = pendingMessages.get(result.badMsgId());
        if (result.errorCode() == 16 || result.errorCode() == 17) {
            if (badMessage != null) {
                // TODO resend and update time offset
                throw new UnsupportedOperationException(
                        "Resending after bad message notification " + result.errorCode()
                                + " is not implemented");
            }
            return;
        }
        throw new UnsupportedOperationException(
                "Bad message notification " + result.errorCode() + " is not handled");
    }

    /** Schedules an acknowledgement for the answer message the notification refers to. */
    public void handleDetailedInfo(final TLMessage message) {
        final Types.MsgDetailedInfo result = (Types.MsgDetailedInfo) message.object;
        pendingAck.add(result.answerMsgId());
    }

    /** Schedules an acknowledgement for the answer message the notification refers to. */
    public void handleNewDetailedInfo(final TLMessage message) {
        final Types.MsgNewDetailedInfo result = (Types.MsgNewDetailedInfo) message.object;
        pendingAck.add(result.answerMsgId());
    }

    /** Adopts the server salt announced with the new session. */
    public void handleNewSessionCreated(final TLMessage message) {
        final Types.NewSessionCreated result = (Types.NewSessionCreated) message.object;
        state.salt = result.serverSalt();
    }

    /** Not implemented; acknowledgements from the server are currently ignored. */
    public void handleAck(final TLMessage message) {
        // TODO check if ack-ed logout
    }

    /** Not implemented; future salts are currently ignored. */
    public void handleFutureSalts(final TLMessage message) {
        // TODO check if there's request
    }

    /** Not implemented; state and resend requests from the server are currently ignored. */
    public void handleStateForgotten(final TLMessage message) {
        // TODO send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)
    }

    /** Not implemented; these notifications are currently ignored. */
    public void handleMsgAll(final TLMessage message) {
        // TODO decide whether anything has to be sent back (the original note here duplicated
        // the one in handleStateForgotten) — confirm against the protocol documentation
    }
}