151 lines
5.2 KiB
Java
151 lines
5.2 KiB
Java
package io.github.lonamiwebs.overgram.network;
|
|
|
|
import io.github.lonamiwebs.overgram.crypto.AES;
|
|
import io.github.lonamiwebs.overgram.crypto.AuthKey;
|
|
import io.github.lonamiwebs.overgram.tl.TLMessage;
|
|
import io.github.lonamiwebs.overgram.tl.TLObject;
|
|
import io.github.lonamiwebs.overgram.tl.TLRequest;
|
|
import io.github.lonamiwebs.overgram.utils.BinaryReader;
|
|
import io.github.lonamiwebs.overgram.utils.BinaryWriter;
|
|
import io.github.lonamiwebs.overgram.utils.Utils;
|
|
import javafx.util.Pair;
|
|
|
|
import java.nio.ByteBuffer;
|
|
import java.security.SecureRandom;
|
|
import java.util.Arrays;
|
|
|
|
public class MTProtoState {
|
|
|
|
public AuthKey authKey;
|
|
public long salt;
|
|
private final long id;
|
|
private long timeOffset;
|
|
private int sequence;
|
|
private long lastMsgId;
|
|
|
|
public MTProtoState(final AuthKey authKey) {
|
|
this.authKey = authKey;
|
|
salt = 0;
|
|
id = new SecureRandom().nextLong();
|
|
timeOffset = 0;
|
|
sequence = 0;
|
|
lastMsgId = 0;
|
|
}
|
|
|
|
public TLMessage createMessage(final TLObject object) {
|
|
return new TLMessage(getNewMsgId(), getSeqNo(object instanceof TLRequest), object);
|
|
}
|
|
|
|
private Pair<byte[], byte[]> calcKey(final byte[] msgKey, final boolean client) {
|
|
final int x = client ? 0 : 8;
|
|
final byte[] sha256a = Utils.sha256digest(msgKey, Arrays.copyOfRange(authKey.key, x, x + 36));
|
|
final byte[] sha256b = Utils.sha256digest(Arrays.copyOfRange(authKey.key, x + 40, x + 76), msgKey);
|
|
|
|
final BinaryWriter writer = new BinaryWriter(32);
|
|
writer.writeRaw(sha256a, 0, 8);
|
|
writer.writeRaw(sha256b, 8, 16);
|
|
writer.writeRaw(sha256a, 24, 8);
|
|
final byte[] key = writer.toBytes();
|
|
writer.clear();
|
|
writer.writeRaw(sha256b, 0, 8);
|
|
writer.writeRaw(sha256a, 8, 16);
|
|
writer.writeRaw(sha256b, 24, 8);
|
|
|
|
return new Pair<>(key, writer.toBytes());
|
|
}
|
|
|
|
public byte[] packMessage(final TLMessage message) {
|
|
final BinaryWriter writer = new BinaryWriter();
|
|
writer.write(salt);
|
|
writer.write(id);
|
|
writer.write(message);
|
|
|
|
int padding = writer.size() % 16;
|
|
writer.writeRaw(Utils.randomBytes(32 - padding));
|
|
|
|
final byte[] paddedData = writer.toBytes();
|
|
writer.clear();
|
|
|
|
final byte[] msgKeyLarge = Utils.sha256digest(
|
|
Arrays.copyOfRange(authKey.key, 88, 88 + 32), paddedData);
|
|
|
|
final byte[] msgKey = Arrays.copyOfRange(msgKeyLarge, 8, 24);
|
|
final Pair<byte[], byte[]> keyIv = calcKey(msgKey, true);
|
|
|
|
writer.write(authKey.keyId);
|
|
writer.writeRaw(msgKey);
|
|
writer.writeRaw(AES.encryptIge(paddedData, keyIv.getKey(), keyIv.getValue()));
|
|
return writer.toBytes();
|
|
}
|
|
|
|
public TLMessage unpackMessage(final byte[] body) throws ClassNotFoundException {
|
|
if (body.length < 8) {
|
|
if (body[0] == (byte) 0x6c
|
|
&& body[1] == (byte) 0xfe
|
|
&& body[2] == (byte) 0xff
|
|
&& body[3] == (byte) 0xff) {
|
|
// -404 as little endian, broken authorization
|
|
throw new RuntimeException();
|
|
} else {
|
|
throw new RuntimeException();
|
|
}
|
|
}
|
|
|
|
final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(body));
|
|
final long keyId = reader.readLong();
|
|
if (keyId != authKey.keyId) {
|
|
throw new SecurityException("Server replied with an invalid auth key");
|
|
}
|
|
|
|
final byte[] msgKey = reader.read(16);
|
|
final Pair<byte[], byte[]> keyIv = calcKey(msgKey, false);
|
|
final byte[] plainText = AES.decryptIge(reader.read(), keyIv.getKey(), keyIv.getValue());
|
|
|
|
final byte[] ourKey = Arrays.copyOfRange(Utils.sha256digest(
|
|
Arrays.copyOfRange(authKey.key, 96, 128), plainText), 8, 24);
|
|
|
|
if (!Arrays.equals(msgKey, ourKey)) {
|
|
throw new SecurityException("Received message key doesn't match with expected one");
|
|
}
|
|
|
|
final BinaryReader tlReader = new BinaryReader(ByteBuffer.wrap(plainText));
|
|
tlReader.readLong(); // remote salt
|
|
if (tlReader.readLong() != id) {
|
|
throw new SecurityException("Server replied with a wrong session ID");
|
|
}
|
|
|
|
final long remoteMsgId = tlReader.readLong();
|
|
final int remoteSeq = tlReader.readInt();
|
|
tlReader.readInt(); // inner message length
|
|
|
|
final TLObject object = tlReader.readTl();
|
|
return new TLMessage(remoteMsgId, remoteSeq, object);
|
|
}
|
|
|
|
public long getNewMsgId() {
|
|
final long now = System.currentTimeMillis();
|
|
long newMsgId = (((now / 1000) + timeOffset) << 32) | ((now % 1000) << 2);
|
|
if (lastMsgId >= newMsgId) {
|
|
newMsgId = lastMsgId + 4;
|
|
}
|
|
|
|
lastMsgId = newMsgId;
|
|
return newMsgId;
|
|
}
|
|
|
|
public void updateTimeOffset(long correctMsgId) {
|
|
final long now = System.currentTimeMillis() / 1000L;
|
|
final long correct = correctMsgId >> 32;
|
|
timeOffset = correct - now;
|
|
lastMsgId = 0;
|
|
}
|
|
|
|
public int getSeqNo(final boolean contentRelated) {
|
|
if (contentRelated) {
|
|
return 1 + 2 * sequence++;
|
|
} else {
|
|
return 2 * sequence;
|
|
}
|
|
}
|
|
}
|