001/*
002 * Copyright (C) 2026 Blackilykat and contributors
003 *
004 * This program is free software: you can redistribute it and/or modify
005 * it under the terms of the GNU General Public License as published by
006 * the Free Software Foundation, either version 3 of the License, or
007 * (at your option) any later version.
008 *
009 * This program is distributed in the hope that it will be useful,
010 * but WITHOUT ANY WARRANTY; without even the implied warranty of
011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
012 * GNU General Public License for more details.
013 *
014 * You should have received a copy of the GNU General Public License
015 * along with this program. If not, see <https://www.gnu.org/licenses/>.
016 */
017
018package dev.blackilykat.pmp;
019
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import dev.blackilykat.pmp.event.EventSource;
import dev.blackilykat.pmp.event.RetroactiveEventSource;
import dev.blackilykat.pmp.messages.DisconnectMessage;
import dev.blackilykat.pmp.messages.Message;
import dev.blackilykat.pmp.messages.Request;
import dev.blackilykat.pmp.messages.Response;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import javax.net.ssl.SSLSocket;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
050
051/// Defines a connection which uses PMP's protocol.
052///
053/// The PMP protocol uses a TCP connection which should, but does not have to, be encrypted with SSL.
054///
055/// The connection is established once both sides send 4 bytes: "PMP\n".
056/// This is referred to as the PMP signature.
057/// This allows the server to recognize when it is accepting non-PMP clients and to terminate those connections.
058///
059/// Messages are sent serialized using JSON.
060/// The type of each message is defined by the property "messageType".
061/// Messages are separated by newlines.
062///
063/// Every {@value #KEEPALIVE_MS} milliseconds, an extra newline should be sent.
064/// This is equivalent to an empty line and will be treated as a keepalive.
065/// If a keepalive is not received within {@value #KEEPALIVE_MAX_MS} milliseconds,
066/// any side can assume the connection has silently dropped and terminate it.
067///
068/// @see dev.blackilykat.pmp.server.Encryption
069public class PMPConnection {
070        /// Emitted when any message is received on any client and allows to cancel it before it gets handled.
071        ///
072        /// @see MessageListener
073        public static final EventSource<ReceivingMessageEvent> EVENT_RECEIVING_MESSAGE = new EventSource<>();
074        /// Event emitted when any connection has disconnected. Contains the terminated connection as its data.
075        public static final EventSource<PMPConnection> EVENT_DISCONNECTED = new EventSource<>();
076
077        /// The default port used for transferring messages.
078        public static final int DEFAULT_MESSAGE_PORT = 6803;
079        /// The default port used for transferring files through HTTP.
080        ///
081        /// @see dev.blackilykat.pmp.server.TransferHandler
082        public static final int DEFAULT_FILE_PORT = 6804;
083
084        /// The amount of milliseconds between sending keepalives.
085        private static final int KEEPALIVE_MS = 10_000;
086        /// The amount of milliseconds since the last keepalive after which a connection can be considered dropped.
087        private static final int KEEPALIVE_MAX_MS = 30_000;
088        private static final Logger LOGGER = LogManager.getLogger(PMPConnection.class);
089        private static final ObjectMapper mapper = new ObjectMapper();
090        /// The underlying TCP socket of this connection.
091        public final Socket socket;
092        /// The name of this connection, used to differentiate logging on the server side.
093        public final String name;
094        /// Event emitted once this connection has been confirmed by receiving the other side's PMP signature.
095        public final RetroactiveEventSource<Void> eventConnected = new RetroactiveEventSource<>();
096        /// Event emitted once this connection has been terminated for any reason.
097        ///
098        /// After this is emitted, [#EVENT_DISCONNECTED] is always also emitted with this as its content.
099        public final RetroactiveEventSource<Void> eventDisconnected = new RetroactiveEventSource<>();
100        /// The thread which reads incoming messages and calls their listeners and handlers.
101        private final MessageReceivingThread messageReceivingThread;
102        /// The input stream the raw incoming serialized messages get read from.
103        private final InputStream inputStream;
104        /// The thread which serializes and writes outgoing messages.
105        ///
106        /// This is not the only thread which is allowed to write to [#outputStream].
107        /// Both [#keepaliveTimer] and any thread initiating a disconnect will also do so.
108        /// [#outputStreamLock] is used to keep exclusive access to the output stream.
109        private final MessageSendingThread messageSendingThread;
110        /// The output stream the raw serialized messages get written to.
111        ///
112        /// Any piece of code attempting to write to this stream should be enclosed in a synchronized block with [#outputStreamLock].
113        private final OutputStream outputStream;
114        /// The object used to keep exclusive access to [#outputStream].
115        private final Object outputStreamLock = new Object();
116        /// The queue of non-serialized messages to be sent.
117        /// Unless the message is urgent, it will be placed here for [#messageSendingThread] to take, serialize and send over the network.
118        private final BlockingQueue<Message> messageQueue = new LinkedBlockingQueue<>();
119        /// The timer used to send keepalive messages.
120        private final Timer keepaliveTimer;
121        /// All listeners registered for this connection.
122        private final List<MessageListener<?>> listeners = new LinkedList<>();
123        /// All requests and their ids which are still pending a final response.
124        ///
125        /// @see Response#isLastResponse()
126        private final Map<Integer, Request> pendingRequests = new HashMap<>();
127        /// Whether the connection has been confirmed by receiving the PMP signature from the other side.
128        public Boolean connected = false;
129        /// Unix timestamp of the last keepalive.
130        private long lastKeepalive;
131
132        /// Initiate a PMP connection:
133        /// - Sets required fields
134        /// - Writes the PMP signature
135        /// - Starts the [#messageReceivingThread]
136        /// - Starts the [#messageSendingThread]
137        /// - Schedules sending keepalives and checking the other side's keepalive timeout
138        public PMPConnection(Socket socket, String name) throws IOException {
139                if(!(socket instanceof SSLSocket)) {
140                        LOGGER.warn("PMP Connection with insecure socket");
141                }
142
143                this.socket = socket;
144                this.name = name;
145
146                this.inputStream = socket.getInputStream();
147                this.outputStream = socket.getOutputStream();
148                this.outputStream.write(new byte[]{'P', 'M', 'P', '\n'});
149
150                messageReceivingThread = new MessageReceivingThread();
151                messageReceivingThread.start();
152                messageSendingThread = new MessageSendingThread();
153                messageSendingThread.start();
154
155                keepaliveTimer = new Timer("Keepalive timer for " + name);
156                lastKeepalive = System.currentTimeMillis();
157                keepaliveTimer.schedule(new TimerTask() {
158                        @Override
159                        public void run() {
160                                // Keepalive timeout doesn't need to be exact so it's fine to send and check at the same time
161
162                                if(System.currentTimeMillis() - lastKeepalive > KEEPALIVE_MAX_MS) {
163                                        disconnect("Keepalive timeout");
164                                } else {
165                                        try {
166                                                sendKeepalive();
167                                        } catch(IOException e) {
168                                                disconnect("Failed to send keepalive");
169                                        }
170                                }
171                        }
172                }, KEEPALIVE_MS, KEEPALIVE_MS);
173        }
174
175        /// Adds a message to the message queue
176        public void send(Message message) {
177                if(message instanceof Request request) {
178                        request.setConnection(PMPConnection.this);
179                }
180                messageQueue.add(message);
181        }
182
183        /// Sends a message ignoring the message queue and writing to the socket on this thread.
184        /// Does not assign a request ID.
185        private void sendNow(Message message) throws IOException {
186                synchronized(outputStreamLock) {
187                        LOGGER.info("Sending message to {}: {}", name, mapper.writeValueAsString(message.withRedactedInfo()));
188
189                        outputStream.write((mapper.writeValueAsString(message) + '\n').getBytes(StandardCharsets.UTF_8));
190                }
191        }
192
193        /// Sends a keepalive on this thread.
194        ///
195        /// A keepalive is an extra newline, practically an "empty line" instead of containing a message.
196        private void sendKeepalive() throws IOException {
197                synchronized(outputStreamLock) {
198                        outputStream.write((int) '\n');
199                }
200        }
201
202        /// Terminate this connection.
203        ///
204        /// @param reason Human readable reason for why the connection was terminated, for logging.
205        public void disconnect(String reason) {
206                LOGGER.warn("Disconnecting {}: {}", name, reason);
207                _disconnect();
208        }
209
210        /// Disconnect on the message sending thread as soon as all currently queued messages are sent.
211        /// Useful to disconnect without performing network operations on the current thread.
212        ///
213        /// @param reason Human readable reason for why the connection was terminated, for logging.
214        public void disconnectSoon(String reason) {
215                LOGGER.warn("Disconnecting soon {}: {}", name, reason);
216                send(new DisconnectMessage());
217        }
218
219        /// Internal disconnect method to terminate the connection without logging.
220        ///
221        /// Attempts to send a disconnect message on this thread, closes the socket and terminates all threads.
222        private void _disconnect() {
223                boolean wasConnected = connected;
224                if(connected) {
225                        connected = false;
226                        try {
227                                sendNow(new DisconnectMessage());
228                        } catch(IOException ignored) {
229                        }
230                }
231                try {
232                        socket.close();
233                } catch(IOException ignored) {
234                }
235
236                messageReceivingThread.interrupt();
237                if(!messageSendingThread.equals(Thread.currentThread())) {
238                        messageSendingThread.interrupt();
239                }
240                keepaliveTimer.cancel();
241
242                if(wasConnected) {
243                        eventDisconnected.call(null);
244                        EVENT_DISCONNECTED.call(this);
245                }
246        }
247
248        /// Register a message listener for this connection.
249        ///
250        /// @see #unregisterListener(MessageListener)
251        /// @see MessageListener
252        public void registerListener(MessageListener<?> listener) {
253                if(listeners.contains(listener)) {
254                        throw new IllegalStateException("Listener already registered");
255                }
256                listeners.add(listener);
257        }
258
259        /// Unregister a message listener for this connection.
260        ///
261        /// @see #unregisterListener(MessageListener)
262        /// @see MessageListener
263        public void unregisterListener(MessageListener<?> listener) {
264                if(!listeners.remove(listener)) {
265                        throw new IllegalStateException("Listener wasn't registered");
266                }
267        }
268
269        /// The thread which serializes and writes outgoing messages.
270        private class MessageSendingThread extends Thread {
271                public MessageSendingThread() {
272                        super("Message sending thread for " + name);
273                }
274
275                @Override
276                public void run() {
277                        try {
278                                while(!Thread.interrupted()) {
279                                        Message message = messageQueue.take();
280
281                                        if(message instanceof DisconnectMessage) {
282                                                _disconnect();
283                                                return;
284                                        }
285
286                                        if(message instanceof Request request) {
287                                                if(request.requestId == null) {
288                                                        request.assignId();
289                                                }
290
291                                                assert !pendingRequests.containsKey(request.requestId);
292                                                pendingRequests.put(request.requestId, request);
293                                        }
294
295                                        sendNow(message);
296                                }
297                        } catch(IOException e) {
298                                if(!connected) {
299                                        return;
300                                }
301                                LOGGER.error("IO exception in message sending thread", e);
302                        } catch(InterruptedException ignored) {
303                        } catch(Exception e) {
304                                LOGGER.error("Unknown exception in message sending thread", e);
305                        } finally {
306                                if(connected) {
307                                        disconnect("Message sending thread terminated");
308                                }
309                        }
310                }
311        }
312
313        /// The thread which reads incoming messages and calls their listeners and handlers.
314        private class MessageReceivingThread extends Thread {
315                public MessageReceivingThread() {
316                        super("Message receiving thread for " + name);
317                }
318
319                @Override
320                public void run() {
321                        Queue<Byte> inputBuffer = new ArrayDeque<>();
322                        try {
323                                int read;
324                                while(!Thread.interrupted()) {
325                                        read = inputStream.read();
326                                        if(read == -1) {
327                                                break;
328                                        }
329                                        if(read != ((int) '\n')) {
330                                                inputBuffer.add((byte) read);
331                                        } else if(inputBuffer.isEmpty()) {
332                                                lastKeepalive = System.currentTimeMillis();
333                                        } else {
334                                                byte[] msg = new byte[inputBuffer.size()];
335                                                for(int i = 0; i < msg.length; i++) {
336                                                        //noinspection DataFlowIssue
337                                                        msg[i] = inputBuffer.poll();
338                                                }
339                                                String messageString = new String(msg, StandardCharsets.UTF_8);
340                                                if(messageString.equals("PMP")) {
341                                                        LOGGER.info("Received PMP signature from {}", name);
342                                                        connected = true;
343                                                        eventConnected.call(null);
344                                                        continue;
345                                                }
346                                                if(!connected) {
347                                                        disconnect("Did not receive PMP signature");
348                                                        break;
349                                                }
350                                                try {
351                                                        Message message = mapper.readValue(messageString, Message.class);
352
353
354                                                        Message printedMessage = message.withRedactedInfo();
355                                                        if(printedMessage == message) {
356                                                                //noinspection LoggingSimilarMessage
357                                                                LOGGER.info("Received message from {}: {}", name, messageString);
358                                                        } else {
359                                                                LOGGER.info("Received message from {} (some hidden values): {}", name,
360                                                                                mapper.writeValueAsString(printedMessage));
361                                                        }
362
363                                                        ReceivingMessageEvent evt = new ReceivingMessageEvent(message, PMPConnection.this);
364                                                        EVENT_RECEIVING_MESSAGE.call(evt);
365                                                        if(evt.isCancelled()) {
366                                                                continue;
367                                                        }
368
369                                                        if(message instanceof Response response) {
370                                                                Request request = pendingRequests.get(response.requestId);
371                                                                if(request != null) {
372                                                                        request.addResponse(response);
373
374                                                                        if(response.isLastResponse()) {
375                                                                                pendingRequests.remove(response.requestId);
376                                                                        }
377                                                                }
378                                                        }
379
380                                                        AtomicBoolean cancelled = new AtomicBoolean(false);
381
382                                                        for(MessageListener<?> listener : listeners) {
383                                                                if(!listener.type.isInstance(message)) {
384                                                                        continue;
385                                                                }
386                                                                LOGGER.debug("Found listener for {}", listener.type.getSimpleName());
387                                                                try {
388                                                                        listener.runCasting(message, cancelled);
389                                                                } catch(Exception e) {
390                                                                        LOGGER.error("Exception in message listener", e);
391                                                                }
392                                                        }
393
394                                                        if(cancelled.get()) {
395                                                                LOGGER.info("A {} message was cancelled", message.getClass().getSimpleName());
396                                                                continue;
397                                                        }
398
399                                                        boolean foundHandler = false;
400
401                                                        for(MessageHandler<?> handler : MessageHandler.registeredHandlers) {
402                                                                if(!handler.type.isInstance(message)) {
403                                                                        continue;
404                                                                }
405                                                                if(foundHandler) {
406                                                                        LOGGER.warn("Multiple handlers for message type {}",
407                                                                                        message.getClass().getSimpleName());
408                                                                }
409                                                                foundHandler = true;
410                                                                try {
411                                                                        handler.runCasting(PMPConnection.this, message);
412                                                                } catch(Exception e) {
413                                                                        LOGGER.error("Exception in message listener", e);
414                                                                }
415                                                        }
416
417                                                        // responses can have no handler but be handled through Request#takeResponse
418                                                        if(!foundHandler && !(message instanceof Response)) {
419                                                                LOGGER.warn("Unhandled message type {}", message.getClass().getSimpleName());
420                                                        }
421                                                } catch(JsonProcessingException e) {
422                                                        LOGGER.error("Invalid message format: {} (original message: '{}')", e.getMessage(),
423                                                                        inputBuffer.toString());
424                                                }
425                                        }
426                                }
427                        } catch(IOException e) {
428                                if(!connected) {
429                                        return;
430                                }
431                                LOGGER.error("IO exception in message receiving thread", e);
432                        } catch(Exception e) {
433                                LOGGER.error("Unknown exception in message receiving thread", e);
434                        } finally {
435                                if(connected) {
436                                        disconnect("Message receiving thread terminated");
437                                }
438                        }
439                }
440        }
441
442        /// Data for [PMPConnection#EVENT_RECEIVING_MESSAGE]
443        public static class ReceivingMessageEvent {
444                public final PMPConnection connection;
445                public final Message message;
446                private boolean cancelled;
447
448                public ReceivingMessageEvent(Message message, PMPConnection connection) {
449                        this.message = message;
450                        this.connection = connection;
451                }
452
453                public boolean isCancelled() {
454                        return cancelled;
455                }
456
457                public void cancel() {
458                        cancelled = true;
459                        LOGGER.info("Cancelled incoming message at {}", new Throwable().getStackTrace()[1]);
460                }
461        }
462
463        static {
464                mapper.registerModule(new ParameterNamesModule(JsonCreator.Mode.PROPERTIES));
465
466                MessageHandler.registeredHandlers.add(new MessageHandler<>(DisconnectMessage.class) {
467                        @Override
468                        public void run(PMPConnection connection, DisconnectMessage message) {
469                                connection.disconnect("Received disconnect message");
470                        }
471                });
472        }
473}