// 367 lines, 12 KiB, Java
package io.github.lonamiwebs.overgram.network;
import io.github.lonamiwebs.overgram.crypto.Authenticator;
import io.github.lonamiwebs.overgram.network.connection.Connection;
import io.github.lonamiwebs.overgram.tl.*;
import io.github.lonamiwebs.overgram.utils.BinaryReader;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
public class MTProtoSender {
|
|
public final MTProtoState state;
|
|
private final Connection connection;
|
|
private String ip;
|
|
private int port;
|
|
private boolean userConnected;
|
|
private boolean reconnecting;
|
|
|
|
private Thread sendHandle;
|
|
private Thread recvHandle;
|
|
private BlockingQueue<TLMessage> sendQueue;
|
|
private HashMap<Long, TLMessage> pendingMessages;
|
|
private HashSet<Long> pendingAck;
|
|
private TLMessage lastAck; // to acknowledge acknowledgements
|
|
|
|
public int retries = 5;
|
|
|
|
public MTProtoSender(final MTProtoState state, final Connection connection) {
|
|
this.state = state;
|
|
this.connection = connection;
|
|
|
|
sendQueue = new LinkedBlockingQueue();
|
|
pendingMessages = new HashMap<>();
|
|
pendingAck = new HashSet<>();
|
|
}
|
|
|
|
public void connect(final String ip, final int port) throws IOException {
|
|
this.ip = ip;
|
|
this.port = port;
|
|
userConnected = true;
|
|
doConnect();
|
|
}
|
|
|
|
private void doConnect() throws IOException {
|
|
boolean success = false;
|
|
for (int i = 0; i < retries; ++i) {
|
|
try {
|
|
connection.connect(ip, port);
|
|
success = true;
|
|
break;
|
|
} catch (IOException ignored) {
|
|
}
|
|
}
|
|
if (!success) {
|
|
throw new IOException("Failed to connect " + retries + " times");
|
|
}
|
|
|
|
if (state.authKey == null) {
|
|
final MTProtoPlainSender plain = new MTProtoPlainSender(connection);
|
|
success = false;
|
|
for (int i = 0; i < retries; ++i) {
|
|
try {
|
|
state.authKey = Authenticator.doAuthentication(plain);
|
|
success = true;
|
|
break;
|
|
} catch (SecurityException | IOException | ClassNotFoundException ignored) {
|
|
}
|
|
}
|
|
|
|
if (!success) {
|
|
disconnect();
|
|
throw new IOException("Failed to generate AuthKey");
|
|
}
|
|
}
|
|
|
|
sendHandle = new Thread(this::sendLoop);
|
|
sendHandle.setDaemon(true);
|
|
sendHandle.start();
|
|
|
|
recvHandle = new Thread(this::recvLoop);
|
|
recvHandle.setDaemon(true);
|
|
recvHandle.start();
|
|
}
|
|
|
|
public void disconnect() {
|
|
if (!userConnected) {
|
|
doDisconnect();
|
|
}
|
|
}
|
|
|
|
private void doDisconnect() {
|
|
userConnected = false;
|
|
connection.disconnect();
|
|
pendingMessages.clear();
|
|
pendingAck.clear();
|
|
lastAck = null;
|
|
stopHandles();
|
|
}
|
|
|
|
private void stopHandles() {
|
|
if (sendHandle != null) {
|
|
sendHandle.interrupt();
|
|
try {
|
|
sendHandle.join();
|
|
} catch (InterruptedException ignored) {
|
|
}
|
|
}
|
|
|
|
if (recvHandle != null) {
|
|
recvHandle.interrupt();
|
|
try {
|
|
recvHandle.join();
|
|
} catch (InterruptedException ignored) {
|
|
}
|
|
}
|
|
}
|
|
|
|
private void doReconnect() {
|
|
if (userConnected) {
|
|
final Thread thread = new Thread(this::reconnect);
|
|
thread.setDaemon(true);
|
|
thread.start();
|
|
}
|
|
}
|
|
|
|
private void reconnect() {
|
|
reconnecting = true;
|
|
stopHandles();
|
|
connection.disconnect();
|
|
reconnecting = false;
|
|
|
|
try {
|
|
doConnect();
|
|
} catch (IOException ignored) {
|
|
doDisconnect();
|
|
}
|
|
}
|
|
|
|
public Future send(TLRequest request) throws IOException {
|
|
if (!userConnected) {
|
|
throw new IOException("Not connected");
|
|
}
|
|
|
|
final TLMessage message = state.createMessage(request);
|
|
pendingMessages.put(message.id, message);
|
|
try {
|
|
sendQueue.put(message);
|
|
} catch (InterruptedException ignored) {
|
|
throw new IOException("Failed to put message");
|
|
}
|
|
|
|
return message.future;
|
|
}
|
|
|
|
private void sendLoop() {
|
|
while (userConnected) {
|
|
if (!pendingAck.isEmpty()) {
|
|
lastAck = state.createMessage(new Types.MsgsAck().msgIds(new ArrayList<>(pendingAck)));
|
|
try {
|
|
sendQueue.put(lastAck);
|
|
} catch (InterruptedException ignored) {
|
|
doDisconnect();
|
|
return;
|
|
}
|
|
pendingAck.clear();
|
|
}
|
|
|
|
final TLMessage message;
|
|
try {
|
|
message = sendQueue.poll(1, TimeUnit.SECONDS);
|
|
if (message == null) {
|
|
continue;
|
|
}
|
|
} catch (InterruptedException ignored) {
|
|
doDisconnect();
|
|
return;
|
|
}
|
|
|
|
final byte[] body = state.packMessage(message);
|
|
while (!message.future.isCancelled()) {
|
|
try {
|
|
connection.send(body);
|
|
break;
|
|
} catch (IOException ignored) {
|
|
// TODO e.g. on timeout retry (loop; not done yet)
|
|
try {
|
|
Thread.sleep(1000);
|
|
doReconnect();
|
|
return;
|
|
} catch (InterruptedException ignored2) {
|
|
doDisconnect();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
private void recvLoop() {
|
|
while (userConnected) {
|
|
final byte[] body;
|
|
try {
|
|
body = connection.recv();
|
|
} catch (IOException ignored) {
|
|
doReconnect();
|
|
return;
|
|
}
|
|
|
|
final TLMessage message;
|
|
try {
|
|
message = state.unpackMessage(body);
|
|
} catch (ClassNotFoundException ignored) {
|
|
continue;
|
|
}
|
|
try {
|
|
processMessage(message);
|
|
} catch (InterruptedException ignored) {
|
|
doDisconnect();
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
private void processMessage(final TLMessage message) throws InterruptedException {
|
|
pendingAck.add(message.id);
|
|
if (message.object instanceof RPCResult) {
|
|
handleRpcResult(message);
|
|
} else if (message.object instanceof MessageContainer) {
|
|
handleContainer(message);
|
|
} else if (message.object instanceof GzipPacked) {
|
|
handleGzipPacked(message);
|
|
} else if (message.object instanceof Types.MsgsAck) {
|
|
handleAck(message);
|
|
} else if (message.object instanceof Abstract.Updates || message.object instanceof Abstract.Update) {
|
|
handleUpdate(message);
|
|
} else if (message.object instanceof Types.Pong) {
|
|
handlePong(message);
|
|
} else if (message.object instanceof Types.BadServerSalt) {
|
|
handleBadServerSalt(message);
|
|
} else if (message.object instanceof Types.BadMsgNotification) {
|
|
handleBadNotification(message);
|
|
} else if (message.object instanceof Types.MsgDetailedInfo) {
|
|
handleDetailedInfo(message);
|
|
} else if (message.object instanceof Types.MsgNewDetailedInfo) {
|
|
handleNewDetailedInfo(message);
|
|
} else if (message.object instanceof Types.NewSessionCreated) {
|
|
handleNewSessionCreated(message);
|
|
} else if (message.object instanceof Types.FutureSalts) {
|
|
handleFutureSalts(message);
|
|
} else if (message.object instanceof Types.MsgsStateReq) {
|
|
handleStateForgotten(message);
|
|
} else if (message.object instanceof Types.MsgResendReq) {
|
|
handleStateForgotten(message);
|
|
} else if (message.object instanceof Types.MsgsAllInfo) {
|
|
handleMsgAll(message);
|
|
}
|
|
}
|
|
|
|
private void handleRpcResult(final TLMessage message) {
|
|
final RPCResult result = (RPCResult) message.object;
|
|
final TLMessage replyMessage = pendingMessages.remove(result.reqMsgId);
|
|
|
|
if (result.error != null) {
|
|
replyMessage.future.completeExceptionally(new RPCError(result.error));
|
|
return;
|
|
}
|
|
|
|
final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(result.result));
|
|
try {
|
|
replyMessage.future.complete(((TLRequest) replyMessage.object).readResult(reader));
|
|
} catch (ClassNotFoundException e) {
|
|
replyMessage.future.completeExceptionally(e);
|
|
}
|
|
}
|
|
|
|
public void handleContainer(final TLMessage message) throws InterruptedException {
|
|
final MessageContainer result = (MessageContainer) message.object;
|
|
for (final TLMessage innerMessage : result.messages) {
|
|
processMessage(innerMessage);
|
|
}
|
|
}
|
|
|
|
public void handleGzipPacked(final TLMessage message) throws InterruptedException {
|
|
final GzipPacked result = (GzipPacked) message.object;
|
|
message.object = result.packedObject();
|
|
processMessage(message);
|
|
}
|
|
|
|
public void handleUpdate(final TLMessage message) {
|
|
}
|
|
|
|
public void handlePong(final TLMessage message) {
|
|
final Types.Pong result = (Types.Pong) message.object;
|
|
final TLMessage replyMessage = pendingMessages.remove(result.msgId());
|
|
if (replyMessage != null) {
|
|
replyMessage.future.complete(result);
|
|
}
|
|
}
|
|
|
|
public void handleBadServerSalt(final TLMessage message) throws InterruptedException {
|
|
final Types.BadServerSalt result = (Types.BadServerSalt) message.object;
|
|
state.salt = result.newServerSalt();
|
|
if (lastAck != null && result.badMsgId() == lastAck.id) {
|
|
sendQueue.put(lastAck);
|
|
}
|
|
|
|
final TLMessage badMessage = pendingMessages.get(result.badMsgId());
|
|
if (badMessage != null) {
|
|
sendQueue.put(badMessage);
|
|
}
|
|
}
|
|
|
|
public void handleBadNotification(final TLMessage message) {
|
|
final Types.BadMsgNotification result = (Types.BadMsgNotification) message.object;
|
|
final TLMessage badMessage = pendingMessages.get(result.badMsgId());
|
|
|
|
if (result.errorCode() == 16 || result.errorCode() == 17) {
|
|
if (badMessage != null) {
|
|
// resend and update time offset
|
|
throw new UnsupportedOperationException();
|
|
}
|
|
return;
|
|
}
|
|
|
|
throw new UnsupportedOperationException();
|
|
}
|
|
|
|
public void handleDetailedInfo(final TLMessage message) {
|
|
final Types.MsgDetailedInfo result = (Types.MsgDetailedInfo) message.object;
|
|
pendingAck.add(result.answerMsgId());
|
|
}
|
|
|
|
public void handleNewDetailedInfo(final TLMessage message) {
|
|
final Types.MsgNewDetailedInfo result = (Types.MsgNewDetailedInfo) message.object;
|
|
pendingAck.add(result.answerMsgId());
|
|
}
|
|
|
|
public void handleNewSessionCreated(final TLMessage message) {
|
|
final Types.NewSessionCreated result = (Types.NewSessionCreated) message.object;
|
|
state.salt = result.serverSalt();
|
|
}
|
|
|
|
public void handleAck(final TLMessage message) {
|
|
// check if ack-ed logout
|
|
}
|
|
|
|
public void handleFutureSalts(final TLMessage message) {
|
|
// check if there's request
|
|
}
|
|
|
|
public void handleStateForgotten(final TLMessage message) {
|
|
// send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)
|
|
}
|
|
|
|
public void handleMsgAll(final TLMessage message) {
|
|
// send MsgsStateInfo(req_msg_id=message.msg_id, info=chr(1) * len(message.obj.msg_ids)
|
|
}
|
|
}
|