package io.github.lonamiwebs.overgram.network; import io.github.lonamiwebs.overgram.crypto.AES; import io.github.lonamiwebs.overgram.crypto.AuthKey; import io.github.lonamiwebs.overgram.tl.TLMessage; import io.github.lonamiwebs.overgram.tl.TLObject; import io.github.lonamiwebs.overgram.tl.TLRequest; import io.github.lonamiwebs.overgram.utils.BinaryReader; import io.github.lonamiwebs.overgram.utils.BinaryWriter; import io.github.lonamiwebs.overgram.utils.Utils; import javafx.util.Pair; import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.Arrays; public class MTProtoState { public AuthKey authKey; public long salt; private final long id; private long timeOffset; private int sequence; private long lastMsgId; public MTProtoState(final AuthKey authKey) { this.authKey = authKey; salt = 0; id = new SecureRandom().nextLong(); timeOffset = 0; sequence = 0; lastMsgId = 0; } public TLMessage createMessage(final TLObject object) { return new TLMessage(getNewMsgId(), getSeqNo(object instanceof TLRequest), object); } private Pair calcKey(final byte[] msgKey, final boolean client) { final int x = client ? 0 : 8; final byte[] sha256a = Utils.sha256digest(msgKey, Arrays.copyOfRange(authKey.key, x, x + 36)); final byte[] sha256b = Utils.sha256digest(Arrays.copyOfRange(authKey.key, x + 40, x + 76), msgKey); final BinaryWriter writer = new BinaryWriter(32); writer.writeRaw(sha256a, 0, 8); writer.writeRaw(sha256b, 8, 16); writer.writeRaw(sha256a, 24, 8); final byte[] key = writer.toBytes(); writer.clear(); writer.writeRaw(sha256b, 0, 8); writer.writeRaw(sha256a, 8, 16); writer.writeRaw(sha256b, 24, 8); return new Pair<>(key, writer.toBytes()); } public byte[] packMessage(final TLMessage message) { final BinaryWriter writer = new BinaryWriter(); writer.write(salt); writer.write(id); writer.write(message); int padding = writer.size() % 16; writer.writeRaw(Utils.randomBytes(32 - padding)); final byte[] paddedData = writer.toBytes(); writer.clear(); final byte[] msgKeyLarge = Utils.sha256digest( Arrays.copyOfRange(authKey.key, 88, 88 + 32), paddedData); final byte[] msgKey = Arrays.copyOfRange(msgKeyLarge, 8, 24); final Pair keyIv = calcKey(msgKey, true); writer.write(authKey.keyId); writer.writeRaw(msgKey); writer.writeRaw(AES.encryptIge(paddedData, keyIv.getKey(), keyIv.getValue())); return writer.toBytes(); } public TLMessage unpackMessage(final byte[] body) throws ClassNotFoundException { if (body.length < 8) { if (body[0] == (byte) 0x6c && body[1] == (byte) 0xfe && body[2] == (byte) 0xff && body[3] == (byte) 0xff) { // -404 as little endian, broken authorization throw new RuntimeException(); } else { throw new RuntimeException(); } } final BinaryReader reader = new BinaryReader(ByteBuffer.wrap(body)); final long keyId = reader.readLong(); if (keyId != authKey.keyId) { throw new SecurityException("Server replied with an invalid auth key"); } final byte[] msgKey = reader.read(16); final Pair keyIv = calcKey(msgKey, false); final byte[] plainText = AES.decryptIge(reader.read(), keyIv.getKey(), keyIv.getValue()); final byte[] ourKey = Arrays.copyOfRange(Utils.sha256digest( Arrays.copyOfRange(authKey.key, 96, 128), plainText), 8, 24); if (!Arrays.equals(msgKey, ourKey)) { throw new SecurityException("Received message key doesn't match with expected one"); } final BinaryReader tlReader = new BinaryReader(ByteBuffer.wrap(plainText)); tlReader.readLong(); // remote salt if (tlReader.readLong() != id) { throw new SecurityException("Server replied with a wrong session ID"); } final long remoteMsgId = tlReader.readLong(); final int remoteSeq = tlReader.readInt(); tlReader.readInt(); // inner message length final TLObject object = tlReader.readTl(); return new TLMessage(remoteMsgId, remoteSeq, object); } public long getNewMsgId() { final long now = System.currentTimeMillis(); long newMsgId = (((now / 1000) + timeOffset) << 32) | ((now % 1000) << 2); if (lastMsgId >= newMsgId) { newMsgId = lastMsgId + 4; } lastMsgId = newMsgId; return newMsgId; } public void updateTimeOffset(long correctMsgId) { final long now = System.currentTimeMillis() / 1000L; final long correct = correctMsgId >> 32; timeOffset = correct - now; lastMsgId = 0; } public int getSeqNo(final boolean contentRelated) { if (contentRelated) { return 1 + 2 * sequence++; } else { return 2 * sequence; } } }