Overgram/lib/src/main/java/io/github/lonamiwebs/overgram/network/MTProtoState.java

151 lines
5.2 KiB
Java

package io.github.lonamiwebs.overgram.network;
import io.github.lonamiwebs.overgram.crypto.AES;
import io.github.lonamiwebs.overgram.crypto.AuthKey;
import io.github.lonamiwebs.overgram.tl.TLMessage;
import io.github.lonamiwebs.overgram.tl.TLObject;
import io.github.lonamiwebs.overgram.tl.TLRequest;
import io.github.lonamiwebs.overgram.utils.BinaryReader;
import io.github.lonamiwebs.overgram.utils.BinaryWriter;
import io.github.lonamiwebs.overgram.utils.Utils;
import javafx.util.Pair;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
/**
 * Per-session state needed to pack outgoing messages into encrypted bodies
 * and to unpack encrypted bodies received from the server.
 *
 * <p>The state consists of the authorization key, the current salt, a random
 * session ID chosen at construction time, a time offset (in seconds) applied
 * to generated message IDs, the content-related sequence counter and the last
 * message ID handed out.
 *
 * <p>No method of this class is synchronized; callers that share one instance
 * between threads must serialise access themselves.
 */
public class MTProtoState {
    /** Size in bytes of the auth key ID that prefixes every encrypted body. */
    private static final int KEY_ID_LENGTH = 8;

    /** Size in bytes of the little-endian integer a short error body carries. */
    private static final int ERROR_CODE_LENGTH = 4;

    /** Error code the original implementation labels "broken authorization". */
    private static final int BROKEN_AUTHORIZATION = -404;

    public AuthKey authKey;
    public long salt;
    private final long id;
    private long timeOffset;
    private int sequence;
    private long lastMsgId;

    /**
     * Creates a fresh state bound to {@code authKey}.
     *
     * <p>The salt, time offset, sequence counter and last message ID start at
     * zero; the session ID is drawn from a {@link SecureRandom}.
     *
     * @param authKey the authorization key used to derive AES keys and to
     *                validate incoming bodies
     */
    public MTProtoState(final AuthKey authKey) {
        this.authKey = authKey;
        salt = 0;
        id = new SecureRandom().nextLong();
        timeOffset = 0;
        sequence = 0;
        lastMsgId = 0;
    }

    /**
     * Wraps {@code object} into a {@link TLMessage} with a newly generated
     * message ID and sequence number.
     *
     * <p>Objects that are {@link TLRequest} instances are treated as
     * content-related when picking the sequence number.
     *
     * @param object the object to wrap
     * @return the message carrying {@code object}
     */
    public TLMessage createMessage(final TLObject object) {
        return new TLMessage(getNewMsgId(), getSeqNo(object instanceof TLRequest), object);
    }

    /**
     * Derives the AES key and IV for a message from its message key and the
     * authorization key.
     *
     * <p>Two SHA-256 digests are computed over the message key combined with
     * 36-byte slices of the authorization key (starting at {@code x} and
     * {@code x + 40}, where {@code x} is 0 for client messages and 8
     * otherwise). The key and IV are then interleaved from both digests.
     * NOTE(review): this matches the MTProto 2.0 key derivation as far as can
     * be told from here - confirm against the protocol description.
     *
     * @param msgKey the 16-byte message key
     * @param client {@code true} when the message goes from client to server
     * @return a pair whose key is the 32-byte AES key and whose value is the
     *         32-byte IV
     */
    private Pair<byte[], byte[]> calcKey(final byte[] msgKey, final boolean client) {
        final int x = client ? 0 : 8;
        final byte[] sha256a = Utils.sha256digest(msgKey, Arrays.copyOfRange(authKey.key, x, x + 36));
        final byte[] sha256b = Utils.sha256digest(Arrays.copyOfRange(authKey.key, x + 40, x + 76), msgKey);

        final BinaryWriter writer = new BinaryWriter(32);

        // key = a[0..8) + b[8..24) + a[24..32)
        writer.writeRaw(sha256a, 0, 8);
        writer.writeRaw(sha256b, 8, 16);
        writer.writeRaw(sha256a, 24, 8);
        final byte[] key = writer.toBytes();

        // iv = b[0..8) + a[8..24) + b[24..32)
        writer.clear();
        writer.writeRaw(sha256b, 0, 8);
        writer.writeRaw(sha256a, 8, 16);
        writer.writeRaw(sha256b, 24, 8);
        final byte[] iv = writer.toBytes();

        return new Pair<>(key, iv);
    }

    /**
     * Serialises and encrypts {@code message} into a body ready to be sent.
     *
     * <p>The plain text is the salt, the session ID and the message, followed
     * by random padding. The result is the auth key ID, the message key and
     * the AES-IGE encrypted padded plain text, in that order.
     *
     * @param message the message to pack
     * @return the encrypted body
     */
    public byte[] packMessage(final TLMessage message) {
        final BinaryWriter writer = new BinaryWriter();
        writer.write(salt);
        writer.write(id);
        writer.write(message);

        // Appending (32 - size % 16) random bytes always adds between 17 and
        // 32 bytes and leaves the total length a multiple of 16.
        final int remainder = writer.size() % 16;
        writer.writeRaw(Utils.randomBytes(32 - remainder));
        final byte[] paddedData = writer.toBytes();
        writer.clear();

        // The message key is the middle 16 bytes of
        // SHA-256(authKey[88..120) + padded plain text).
        final byte[] msgKeyLarge = Utils.sha256digest(
                Arrays.copyOfRange(authKey.key, 88, 88 + 32), paddedData);
        final byte[] msgKey = Arrays.copyOfRange(msgKeyLarge, 8, 24);
        final Pair<byte[], byte[]> keyIv = calcKey(msgKey, true);

        writer.write(authKey.keyId);
        writer.writeRaw(msgKey);
        writer.writeRaw(AES.encryptIge(paddedData, keyIv.getKey(), keyIv.getValue()));
        return writer.toBytes();
    }

    /**
     * Decrypts and validates a body received from the server.
     *
     * @param body the raw encrypted body
     * @return the message contained in {@code body}
     * @throws RuntimeException       if {@code body} is shorter than an auth
     *                                key ID; the exception message tells the
     *                                -404 error code apart from any other
     *                                short body
     * @throws SecurityException      if the auth key ID, the message key or
     *                                the session ID do not match the expected
     *                                values
     * @throws ClassNotFoundException declared by the original signature;
     *                                presumably propagated from reading the
     *                                inner object
     */
    public TLMessage unpackMessage(final byte[] body) throws ClassNotFoundException {
        if (body.length < KEY_ID_LENGTH) {
            // Only look at the leading integer when it is actually there, so
            // that empty or 1-3 byte bodies get a meaningful error too.
            if (body.length >= ERROR_CODE_LENGTH
                    && readIntLittleEndian(body) == BROKEN_AUTHORIZATION) {
                throw new RuntimeException(
                        "Server replied with error -404 (broken authorization)");
            }
            throw new RuntimeException(
                    "Server replied with an invalid body of " + body.length + " bytes");
        }

        final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(body));
        final long keyId = reader.readLong();
        if (keyId != authKey.keyId) {
            throw new SecurityException("Server replied with an invalid auth key");
        }

        final byte[] msgKey = reader.read(16);
        final Pair<byte[], byte[]> keyIv = calcKey(msgKey, false);
        final byte[] plainText = AES.decryptIge(reader.read(), keyIv.getKey(), keyIv.getValue());

        // Recompute the message key from the decrypted data; for messages
        // coming from the server the auth key slice starts at 88 + 8 = 96.
        final byte[] ourKey = Arrays.copyOfRange(Utils.sha256digest(
                Arrays.copyOfRange(authKey.key, 96, 128), plainText), 8, 24);

        // Constant-time comparison so that timing does not reveal how many
        // leading bytes of a forged message key were correct.
        if (!MessageDigest.isEqual(msgKey, ourKey)) {
            throw new SecurityException("Received message key doesn't match with expected one");
        }

        final BinaryReader tlReader = new BinaryReader(ByteBuffer.wrap(plainText));
        tlReader.readLong(); // remote salt
        if (tlReader.readLong() != id) {
            throw new SecurityException("Server replied with a wrong session ID");
        }

        final long remoteMsgId = tlReader.readLong();
        final int remoteSeq = tlReader.readInt();
        tlReader.readInt(); // inner message length
        final TLObject object = tlReader.readTl();
        return new TLMessage(remoteMsgId, remoteSeq, object);
    }

    /**
     * Decodes the first four bytes of {@code bytes} as a little-endian int.
     *
     * @param bytes an array holding at least four bytes
     * @return the decoded integer
     */
    private static int readIntLittleEndian(final byte[] bytes) {
        return (bytes[0] & 0xff)
                | (bytes[1] & 0xff) << 8
                | (bytes[2] & 0xff) << 16
                | (bytes[3] & 0xff) << 24;
    }

    /**
     * Generates the next message ID.
     *
     * <p>The upper 32 bits hold the current time in seconds plus the time
     * offset; the lower bits hold the current millisecond shifted left by
     * two. If that value does not exceed the last ID handed out, the last ID
     * plus four is used instead so that IDs keep increasing.
     *
     * @return the new message ID
     */
    public long getNewMsgId() {
        final long now = System.currentTimeMillis();
        long newMsgId = (((now / 1000) + timeOffset) << 32) | ((now % 1000) << 2);
        if (lastMsgId >= newMsgId) {
            newMsgId = lastMsgId + 4;
        }

        lastMsgId = newMsgId;
        return newMsgId;
    }

    /**
     * Recomputes the time offset from a message ID known to be correct.
     *
     * <p>The offset becomes the difference between the seconds stored in the
     * upper 32 bits of {@code correctMsgId} and the local time in seconds.
     * The last message ID is reset so the next generated ID is based on the
     * corrected time alone.
     *
     * @param correctMsgId a message ID carrying the correct time
     */
    public void updateTimeOffset(long correctMsgId) {
        final long now = System.currentTimeMillis() / 1000L;
        final long correct = correctMsgId >> 32;
        timeOffset = correct - now;
        lastMsgId = 0;
    }

    /**
     * Returns the sequence number for the next message.
     *
     * <p>Content-related messages get the odd number {@code 2 * sequence + 1}
     * and advance the counter; other messages get the even number
     * {@code 2 * sequence} and leave the counter untouched.
     *
     * @param contentRelated whether the message is content-related
     * @return the sequence number to use
     */
    public int getSeqNo(final boolean contentRelated) {
        if (contentRelated) {
            return 1 + 2 * sequence++;
        } else {
            return 2 * sequence;
        }
    }
}