How to use structured concurrency correctly with blocking IO code?

I am trying to use Kotlin for a P2P application and have started with the basic structure for a network node: a server and a connection class that handle sending and receiving messages.

I need to use blocking IO, because my use case is to run over Tor, and non-blocking IO does not support a SOCKS proxy.

I tried to stick to the concepts of structured concurrency as far as I understand them (being a Kotlin beginner), but the only way I can get my code running without it getting blocked is by using CoroutineScope(Dispatchers.IO).launch {...}.

As far as I understand, my current version would require manual resource management, because if the parent scope terminates, my standalone scopes would not get cancelled.
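That understanding is correct. A minimal sketch to see it in action (startWorkers and cancelParentAndInspect are hypothetical names, and delay(Long.MAX_VALUE) stands in for a long-running loop): a coroutine launched from CoroutineScope(Dispatchers.IO) gets a fresh Job with no parent, so cancelling the caller does not reach it.

```kotlin
import kotlinx.coroutines.*

// Launches one coroutine in the caller's scope and one in a standalone scope (the pattern from the question).
fun CoroutineScope.startWorkers(): Pair<Job, Job> {
    // Child of the caller's Job: cancelled together with the caller.
    val child = launch { delay(Long.MAX_VALUE) }
    // CoroutineScope(...) creates a brand-new Job with no parent, so nothing cancels this one automatically.
    val orphan = CoroutineScope(Dispatchers.IO).launch { delay(Long.MAX_VALUE) }
    return child to orphan
}

// Returns (childStillActive, orphanStillActive) after the parent has been cancelled.
fun cancelParentAndInspect(): Pair<Boolean, Boolean> = runBlocking {
    val parent = Job()
    val (child, orphan) = CoroutineScope(parent).startWorkers()
    parent.cancelAndJoin()          // waits until the parent and all of its children are done
    val result = child.isActive to orphan.isActive
    orphan.cancel()                 // the orphan has to be cleaned up by hand
    result
}
```

After the parent is cancelled, the child is no longer active while the orphan still is, which is exactly the manual resource management the question worries about.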

Besides the blocking IO functions, I also did not find a way to use the existing scope for the channel receivers. Here, too, it only works with a standalone scope.
I guess the problem is that the sender on that channel is the blocking IO coroutine, which is already running in a standalone scope…

How do I deal with such a setup in the right way?
Any help or feedback is very much appreciated! Thanks in advance!

My code is posted here:

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.consumeEach
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.net.ServerSocket
import java.net.Socket

fun main() {
    runBlocking {
        val node1 = Node(SocketFactory())
        node1.startServer(1111)

        val node2 = Node(SocketFactory())
        node2.startServer(2222)

        val address = Address.localhost(1111)
        node2.connect(address)
        node2.send(address, "Message from localhost:2222 to localhost:1111")

        node1.shutdown()
        node2.shutdown()
    }
}

class Node(val socketFactory: SocketFactory) {
    private val inboundConnections: MutableList<Connection> = mutableListOf()
    private val outboundConnections: HashMap<Address, Connection> = hashMapOf()
    private var server: Server? = null

    suspend fun startServer(serverPort: Int) {
        coroutineScope {
            // It works when using a standalone scope with a single thread executor, but I doubt that this is the
            // right way to do it... but using the scope of the caller we get blocked...
            val socketHandler = Channel<Socket>()
            CoroutineScope(Dispatchers.Default).launch {
                socketHandler.consumeEach { socket ->
                    println("Received ${socket} at handler")
                    val connection = getConnectionAndStartListen(socket)
                    inboundConnections.add(connection)
                }
            }

            val serverSocket = socketFactory.createServerSocket(serverPort)
            server = Server(serverSocket, socketHandler, CoroutineScope(Dispatchers.IO))
                .also { it.listen() }
        }
    }

    suspend fun send(address: Address, message: String) {
        val outboundConnection = outboundConnections.getOrElse(address) { connect(address) }
        outboundConnection.send(message)
    }

    suspend fun connect(address: Address) =
        coroutineScope {
            val socket = socketFactory.createSocket(address)
            getConnectionAndStartListen(socket).also { outboundConnections[address] = it }
        }

    private suspend fun getConnectionAndStartListen(socket: Socket): Connection = coroutineScope {
        // It works when using a standalone scope with a single thread executor, but I doubt that this is the
        // right way to do it... but using the scope of the caller we get blocked...
        val messageHandler = Channel<String>()
        CoroutineScope(Dispatchers.Default).launch {
            messageHandler.consumeEach { message ->
                println("Received '$message'")
            }
        }

        // As we need to use blocking IO we create a standalone scope with a single thread executor
        Connection(socket, messageHandler, CoroutineScope(Dispatchers.IO)).also { it.startListen() }
    }

    fun shutdown() {
        server?.shutdown()
        inboundConnections.forEach({ it.shutdown() })
        outboundConnections.values.forEach({ it.shutdown() })
    }
}

class Server(
    private val serverSocket: ServerSocket,
    private val socketHandler: Channel<Socket>,
    scope: CoroutineScope,
) : CoroutineScope by scope {
    fun listen() {
        launch {
            runCatching {
                while (true) {
                    println("Server listening on port ${serverSocket.localPort} for new connections")
                    val socket = serverSocket.accept()
                    println("Server accept inbound connection on port ${serverSocket.localPort}")
                    socketHandler.send(socket)
                }
            }
        }
    }

    fun shutdown() {
        runCatching { serverSocket.close() }
        socketHandler.close()
    }
}

class Connection(
    private val socket: Socket,
    private val messageHandler: Channel<String>,
    scope: CoroutineScope,
) : CoroutineScope by scope {
    private val objectOutputStream: ObjectOutputStream = ObjectOutputStream(socket.getOutputStream())
    private val objectInputStream: ObjectInputStream = ObjectInputStream(socket.getInputStream())

    fun startListen() {
        launch {
            runCatching {
                while (true) {
                    println("Listening for messages")
                    val msg = objectInputStream.readObject()
                    if (msg is String) {
                        println("Message arrived '$msg'")
                        messageHandler.send(msg)
                    }
                }
            }
        }
    }

    suspend fun send(message: String) =
        withContext(Dispatchers.IO) {
            objectOutputStream.writeObject(message)
            objectOutputStream.flush()
        }

    fun shutdown() {
        runCatching { socket.close() }
        messageHandler.close()
    }
}

// We need to use blocking IO API as we use Tor and non-blocking IO API does not support SocksProxy
// Here we use clear net, but in real app it will run over Tor and getting the ServerSocket and Socket takes
// significant time
class SocketFactory() {
    suspend fun createServerSocket(serverPort: Int) =
        withContext(Dispatchers.IO) {
            println("Create serverSocket with port $serverPort")
            ServerSocket(serverPort)
        }

    suspend fun createSocket(address: Address) =
        withContext(Dispatchers.IO) {
            println("Create socket to address $address")
            Socket(address.host, address.port)
        }
}

data class Address(val host: String, val port: Int) {
    companion object {
        fun localhost(port: Int) = Address("127.0.0.1", port)
    }

    override fun toString(): String = host + ":" + port
}
```

That’s quite a lot of code to digest. If I understand your main concern correctly, then it is not really related to blocking IO. You seem to be designing your code so that suspend functions don’t… suspend.

I believe it is (almost?) always a bad idea to start background tasks from suspend functions. Suspend functions are marked as “suspend” for a reason: they tell their callers, “I will wait for whatever needs to be done”. Scheduling asynchronous tasks from them may confuse the caller.
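A small sketch of the difference, assuming Thread.sleep(100) as a stand-in for a blocking IO call (writeDetached, writeSuspending and compare are made-up names): the first function is marked suspend but returns before its work is done, the second one actually waits.

```kotlin
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.*

// Anti-pattern: marked suspend, but it only schedules the work in a detached scope and returns at once.
suspend fun writeDetached(done: AtomicBoolean) {
    CoroutineScope(Dispatchers.IO).launch {
        Thread.sleep(100)   // stands in for a blocking IO call
        done.set(true)
    }
}

// Preferred: the function suspends until the blocking work has actually finished.
suspend fun writeSuspending(done: AtomicBoolean) {
    withContext(Dispatchers.IO) {
        Thread.sleep(100)   // stands in for a blocking IO call
        done.set(true)
    }
}

// Returns whether each flag was already set at the moment its function returned.
fun compare(): Pair<Boolean, Boolean> = runBlocking {
    val detached = AtomicBoolean(false)
    val suspending = AtomicBoolean(false)
    writeDetached(detached)
    val detachedDone = detached.get()      // typically false: the work is still running
    writeSuspending(suspending)
    val suspendingDone = suspending.get()  // always true: the call waited for the work
    detachedDone to suspendingDone
}
```

compare().second is always true, while compare().first will typically be false: the caller of writeDetached() has no way to wait for the work or to cancel it.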

So my suggestion is to let suspend functions suspend. Then the caller can choose whether to invoke a function synchronously (by calling it directly) or asynchronously (with launch()/async()).
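To illustrate, a sketch with a hypothetical fetch() function that wraps a blocking call in withContext(Dispatchers.IO): the function itself just suspends, and the caller decides whether to call it directly or to run several calls concurrently with async().

```kotlin
import kotlinx.coroutines.*

// An ordinary suspend function: it returns only when its blocking work is done.
suspend fun fetch(id: Int): String = withContext(Dispatchers.IO) {
    Thread.sleep(50)    // stands in for a blocking IO call
    "result-$id"
}

fun callerDecides(): List<String> = runBlocking {
    // Synchronous use: call it directly and get the result in place.
    val first = fetch(1)

    // Asynchronous use: the caller chooses to run two calls concurrently with async().
    val second = async { fetch(2) }
    val third = async { fetch(3) }

    listOf(first, second.await(), third.await())
}
```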

So:

  1. Remove all these CoroutineScope(Dispatchers.Default).launch { } calls and invoke their bodies directly - as a result, your functions will suspend.
  2. Make Server.listen() and Connection.startListen() synchronous as well: mark them suspend and remove launch() or replace it with withContext().
  3. Whenever you need to start some kind of long-running service but don’t want to wait for it, start it inside launch().
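Steps 2 and 3 could look roughly like this. It is a sketch that assumes a line-based text protocol instead of the ObjectInputStream from the question, so it can be tried with an in-memory stream; listen() and receiveAll() are hypothetical names.

```kotlin
import java.io.InputStream
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel

// Step 2: the read loop is a suspend function that runs its blocking reads on Dispatchers.IO
// and returns only when the stream ends. There is no launch() inside it.
suspend fun listen(input: InputStream, messages: SendChannel<String>) = withContext(Dispatchers.IO) {
    val reader = input.bufferedReader()
    while (true) {
        val line = reader.readLine() ?: break   // blocking read; null means end of stream
        messages.send(line)
    }
}

// Step 3: the caller launches the long-running loop and consumes its output in the same scope.
fun receiveAll(input: InputStream): List<String> = runBlocking {
    val messages = Channel<String>()
    launch {
        try {
            listen(input, messages)
        } finally {
            messages.close()   // lets the consumer loop below terminate
        }
    }
    val received = mutableListOf<String>()
    for (msg in messages) received += msg
    received
}
```

The producer (the blocking read loop) and the consumer (the for loop over the channel) both live in the runBlocking scope, so no standalone scope is needed; closing the channel in finally is what lets the consumer finish.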

For example, your Node.startServer() will become something like this:

```kotlin
suspend fun startServer(serverPort: Int) {
    coroutineScope {
        val socketHandler = Channel<Socket>()
        val serverSocket = socketFactory.createServerSocket(serverPort)
        // Server.listen() is now a suspend function (step 2), so Server no longer needs its own scope.
        // A local val is used because the nullable `server` property can't be smart-cast inside launch { }.
        val newServer = Server(serverSocket, socketHandler)
        server = newServer

        launch { newServer.listen() }

        socketHandler.consumeEach { socket ->
            println("Received ${socket} at handler")
            val connection = getConnectionAndStartListen(socket)
            inboundConnections.add(connection)
        }
    }
}
```

It starts two long-running operations, listen() and consumeEach(), and waits for both of them. Then use startServer() like this:

```kotlin
launch { node1.startServer(1111) }
```

As I said above, this way the caller is in control of how a long-running operation is invoked. The code is synchronous by default, but you can make it asynchronous where needed. In your original code you made almost everything asynchronous because you were concerned about “blocking”. If you prefer asynchronous code, then you don’t really need coroutines 🙂
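A toy version of that idea, assuming a hypothetical runService() that loops over a channel until it is closed: the function is written as an ordinary suspend function, and it is the caller that turns it into a background task with launch() and later decides when it ends.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel

// A long-running suspend function: it keeps handling requests until the channel is closed.
suspend fun runService(requests: ReceiveChannel<Int>, handled: MutableList<Int>) {
    for (request in requests) handled += request * 2
}

fun callerControls(): List<Int> = runBlocking {
    val requests = Channel<Int>()
    val handled = mutableListOf<Int>()

    // The caller decides to run the service in the background...
    val service = launch { runService(requests, handled) }
    requests.send(1)
    requests.send(2)

    // ...and also decides when it ends: close() lets it finish,
    // service.cancel() would stop it at its next suspension point.
    requests.close()
    service.join()
    handled
}
```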

After such a redesign you have proper structured concurrency, because all coroutines are descendants of the coroutine started by runBlocking().
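One way to check this, sketched with hypothetical runNode() and cancelFromTop() helpers where delay(Long.MAX_VALUE) stands in for the accept loop and the message handler: cancelling only the top-level job is enough to cancel every task below it, with no manual bookkeeping.

```kotlin
import kotlinx.coroutines.*

// Hypothetical node with two long-running background tasks, started as children of the caller's coroutine.
suspend fun runNode(tasks: MutableList<Job>, started: CompletableDeferred<Unit>) {
    coroutineScope {
        tasks += launch { delay(Long.MAX_VALUE) }   // e.g. the accept loop
        tasks += launch { delay(Long.MAX_VALUE) }   // e.g. the message handler
        started.complete(Unit)
    }
}

// Cancels only the top-level job and reports whether every task below it was cancelled too.
fun cancelFromTop(): Boolean = runBlocking {
    val tasks = mutableListOf<Job>()
    val started = CompletableDeferred<Unit>()
    val node = launch { runNode(tasks, started) }
    started.await()
    node.cancelAndJoin()
    tasks.size == 2 && tasks.all { it.isCancelled }
}
```

This relies on the tasks being suspended in cancellable calls such as delay(); a thread blocked in accept() or readObject() does not react to cancellation on its own.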

Also, note that blocking IO operations are not really cancellable. You need to make sure that you close all resources, because otherwise coroutines may stay stuck in a blocking IO call even after they have been cancelled.
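A minimal sketch of one way to do that for the server socket, assuming java.net.ServerSocket as in the question (acceptLoop and startAndCancel are hypothetical names): the blocking accept() runs in a child coroutine while the parent waits in join(), which is cancellable. When the parent is cancelled, the finally block closes the socket, accept() throws a SocketException, and the loop can finish.

```kotlin
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketException
import kotlinx.coroutines.*

// Accepts connections until the calling coroutine is cancelled (or the socket is closed elsewhere).
suspend fun acceptLoop(serverSocket: ServerSocket, onAccept: (Socket) -> Unit) = coroutineScope {
    // The blocking loop runs in a child coroutine on the IO dispatcher.
    val loop = launch(Dispatchers.IO) {
        try {
            while (true) onAccept(serverSocket.accept())   // accept() ignores coroutine cancellation
        } catch (e: SocketException) {
            // Thrown once serverSocket is closed: leave the loop quietly.
        }
    }
    try {
        loop.join()             // unlike accept(), join() is cancellable
    } finally {
        serverSocket.close()    // on cancellation this unblocks accept(), so `loop` can complete
    }
}

fun startAndCancel(): Boolean = runBlocking {
    val serverSocket = ServerSocket(0)                  // port 0: any free local port
    val job = launch { acceptLoop(serverSocket) { it.close() } }
    delay(200)                                         // give the loop time to block in accept()
    withTimeout(5_000) { job.cancelAndJoin() }         // finishes because the socket gets closed
    serverSocket.isClosed
}
```

The same idea should apply to a Connection blocked in readObject(): closing its Socket is what actually releases the thread.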

Thanks a lot! I have already resolved most cases, similar to what you suggested. Only the connection handler is still giving me trouble… I don’t want the outside caller to have to know about the internals of Node. So when a new connection is created from an incoming connection, I need to set up the async listeners, but as that code is inside the socket listener, it is a bit more complicated than the other cases.
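One possible approach, sketched under a few assumptions (a line-based protocol instead of object streams, and hypothetical serve(), handleConnection() and roundTrip() functions): the accept loop is itself a suspend function built on coroutineScope, and it launches a child coroutine for every accepted socket. The caller only launches serve(); the per-connection listeners stay internal but are still part of the same coroutine tree.

```kotlin
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketException
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.SendChannel

// Accept loop: every accepted socket gets its own listener coroutine, launched as a child of this
// coroutineScope, so callers never see the per-connection coroutines but they are still structured.
suspend fun serve(serverSocket: ServerSocket, messages: SendChannel<String>) = coroutineScope {
    while (true) {
        val socket = try {
            withContext(Dispatchers.IO) { serverSocket.accept() }   // blocking call, off the caller's thread
        } catch (e: SocketException) {
            break                                                   // server socket was closed: stop accepting
        }
        launch(Dispatchers.IO) { handleConnection(socket, messages) }
    }
}

// Per-connection loop: forwards each received line until the peer closes the connection.
suspend fun handleConnection(socket: Socket, messages: SendChannel<String>) {
    socket.use {
        val reader = it.getInputStream().bufferedReader()
        while (true) {
            val line = reader.readLine() ?: break   // blocking read
            messages.send(line)
        }
    }
}

// Tiny end-to-end check over localhost: one client connects and sends one line.
fun roundTrip(): String = runBlocking {
    ServerSocket(0).use { serverSocket ->
        val messages = Channel<String>()
        val server = launch { serve(serverSocket, messages) }
        val received = Socket("127.0.0.1", serverSocket.localPort).use { client ->
            client.getOutputStream().write("hello\n".toByteArray())
            withTimeout(5_000) { messages.receive() }
        }
        serverSocket.close()   // unblocks accept(); serve() then waits for its handler to finish
        server.join()
        received
    }
}
```

As noted above for blocking IO, cancelling serve() alone does not unblock accept() or readLine(); closing the server socket and the connection sockets is what ends those loops.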

Damn, my reply got deleted by the spam filter…

I hope this one makes it:

Thanks @broot for your explanations! It is much clearer now.

Here is the fixed code:
```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.consumeEach
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.net.ServerSocket
import java.net.Socket

fun main() {
    runBlocking {
        val node1 = Node(SocketFactory())
        node1.startServer(1111)
        launch { node1.setupBackgroundTasks() }

        node1.addListener(Channel<String>()
            .also {
                launch {
                    it.consumeEach { message ->
                        println("MessageListener received '$message'")
                    }
                }
            })

        val node2 = Node(SocketFactory())
        node2.startServer(2222)
        launch { node2.setupBackgroundTasks() }

        val address = Address.localhost(1111)
        val connection = node2.connect(address)
        launch { node2.setupConnectionHandler(connection) }
        node2.send(address, "Message from localhost:2222 to localhost:1111")

        node1.shutdown()
        node2.shutdown()
    }
}

class Node(val socketFactory: SocketFactory) {
    private var messageListeners = mutableListOf<Channel<String>>()
    private val inboundConnections: MutableList<Connection> = mutableListOf()
    private val outboundConnections: HashMap<Address, Connection> = hashMapOf()
    private var server: Server? = null
    private val messageHandler = Channel<String>()
    private val socketHandler = Channel<Socket>()

    suspend fun startServer(serverPort: Int) {
        val serverSocket = socketFactory.createServerSocket(serverPort)
        server = Server(serverSocket, socketHandler)
    }

    suspend fun setupBackgroundTasks() {
        coroutineScope {
            launch {
                withContext(Dispatchers.Default) {
                    socketHandler.consumeEach { socket ->
                        println("Received ${socket} at handler")
                        val connection = Connection(socket, messageHandler).also { inboundConnections.add(it) }
                        withContext(Dispatchers.IO) {
                            launch { connection.listen() }
                        }
                    }
                }
            }
            launch {
                withContext(Dispatchers.Default) {
                    messageHandler.consumeEach { message ->
                        println("Received '$message'")
                        messageListeners.forEach { it.send(message) }
                    }
                }
            }
            launch {
                server?.listen()
            }
        }
    }

    suspend fun connect(address: Address): Connection {
        val socket = socketFactory.createSocket(address)
        return Connection(socket, messageHandler).also { outboundConnections[address] = it }
    }

    suspend fun setupConnectionHandler(connection: Connection) {
        coroutineScope {
            launch {
                connection.listen()
            }
        }
    }

    suspend fun send(address: Address, message: String) = coroutineScope {
        outboundConnections.getOrElse(address) { connect(address) }.also { it.send(message) }
    }

    fun addListener(listener: Channel<String>) {
        messageListeners.add(listener)
    }

    fun shutdown() {
        server?.shutdown()
        inboundConnections.forEach({ it.shutdown() })
        outboundConnections.values.forEach({ it.shutdown() })
        messageListeners.forEach({ it.close() })
    }
}

class Server(private val serverSocket: ServerSocket, private val socketHandler: Channel<Socket>) {
    suspend fun listen() = coroutineScope {
        withContext(Dispatchers.IO) {
            runCatching {
                while (true) {
                    println("Server listening on port ${serverSocket.localPort} for new connections")
                    val socket = serverSocket.accept()
                    println("Server accept inbound connection on port ${serverSocket.localPort}")
                    socketHandler.send(socket)
                }
            }
        }
    }

    fun shutdown() {
        runCatching { serverSocket.close() }
        socketHandler.close()
    }
}

class Connection(private val socket: Socket, private val messageHandler: Channel<String>) {
    private val objectOutputStream: ObjectOutputStream = ObjectOutputStream(socket.getOutputStream())
    private val objectInputStream: ObjectInputStream = ObjectInputStream(socket.getInputStream())

    suspend fun listen() = coroutineScope {
        withContext(Dispatchers.IO) {
            runCatching {
                while (true) {
                    println("Listening for messages")
                    val msg = objectInputStream.readObject()
                    if (msg is String) {
                        println("Message arrived '$msg'")
                        messageHandler.send(msg)
                    }
                }
            }
        }
    }

    suspend fun send(message: String) =
        withContext(Dispatchers.IO) {
            println("Send message $message")
            objectOutputStream.writeObject(message)
            objectOutputStream.flush()
        }

    fun shutdown() {
        runCatching { socket.close() }
        messageHandler.close()
    }
}

class SocketFactory() {
    suspend fun createServerSocket(serverPort: Int) =
        withContext(Dispatchers.IO) {
            println("Create serverSocket with port $serverPort")
            ServerSocket(serverPort)
        }

    suspend fun createSocket(address: Address) =
        withContext(Dispatchers.IO) {
            println("Create socket to address $address")
            Socket(address.host, address.port)
        }
}

data class Address(val host: String, val port: Int) {
    companion object {
        fun localhost(port: Int) = Address("127.0.0.1", port)
    }

    override fun toString(): String = host + ":" + port
}
```