Race condition when using for (item in channel()) if channel() is slow to return

I’m getting a race condition in my code, and the test below roughly reproduces it. In my real code the race occurs about 1 time in 20. The time it takes the channel to become ready varies (e.g. in the code below, 19 out of 20 times the getNumberChannel delay would be 5 but 1 out of 20 times it is 15).

package com.example.coroutinerace

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.produce
import kotlinx.coroutines.test.TestCoroutineDispatcher
import kotlinx.coroutines.test.TestCoroutineScope
import org.junit.Assert.assertEquals
import org.junit.Test

@ExperimentalCoroutinesApi
class ReceiveChannelTest {

    private val dispatcher = TestCoroutineDispatcher()
    private val testCoroutineScope = TestCoroutineScope(dispatcher)
    private val receivedInts = mutableListOf<Int>()
    private val libraryCodeCannotBeChanged = LibraryCodeCannotBeChanged(testCoroutineScope)

    @Test
    fun `test receive all numbers`() {
        testCoroutineScope.launch(Dispatchers.IO) {
            for (number in libraryCodeCannotBeChanged.getNumberChannel()) {
                receivedInts.add(number)
            }
        }

        libraryCodeCannotBeChanged.triggerNumbersTicking()

        runBlocking {
            delay(50)
            assertEquals(listOf(1, 2, 3), receivedInts.toList())
        }
    }

    private class LibraryCodeCannotBeChanged(val coroutineScope: CoroutineScope) {

        private var receiveChannel: ReceiveChannel<Int> = Channel()

        fun getNumberChannel(): ReceiveChannel<Int> {
            runBlocking {
                delay(15)
            }
            return receiveChannel
        }

        fun triggerNumbersTicking() {
            receiveChannel = coroutineScope.produce(Dispatchers.Default, Channel.CONFLATED) {
                (1..3).forEach {
                    send(it)
                    println(it)
                    delay(10)
                }
            }
        }
    }
}

The real code uses Tinder’s Scarlet library for WebSockets on Android. The library converts an RxJava stream to a SubscriptionChannel using .openSubscription(). getNumberChannel() is equivalent to a method returning an @Receive ReceiveChannel from an interface built with com.tinder.scarlet:stream-adapter-coroutines. triggerNumbersTicking() stands in for lifecycleRegistry.onNext(Lifecycle.State.Started).

One way the test as written can be made to pass is to delay the triggerNumbersTicking call:

    testCoroutineScope.launch(Dispatchers.IO) {
        delay(10)
        libraryCodeCannotBeChanged.triggerNumbersTicking()
    }

The delays aren’t predictable, so I would have to set a long delay and still couldn’t be certain that the test won’t sometimes fail. Is there a way to be sure that the for loop is ready and listening? Or another way to set up the test code to make the test pass without altering the code within the LibraryCodeCannotBeChanged class?

One way to solve this would be to get the channel first and only then trigger the numbers:

testCoroutineScope.launch(Dispatchers.IO) {
    val channel = libraryCodeCannotBeChanged.getNumberChannel()
    libraryCodeCannotBeChanged.triggerNumbersTicking()
    for (number in channel) {
        receivedInts.add(number)
    }
}

This is basically the same as your solution, but instead of guessing a delay it waits for getNumberChannel() itself to return before triggering, and therefore won’t have any problems even if that delay is unexpectedly long.

The other problem I see is in your runBlocking block. If the delay in getNumberChannel() gets too long, the assertion might run after only the first few values have been received. Maybe something like this:

fun `test receive all numbers`() = runBlocking {
    val channel = libraryCodeCannotBeChanged.getNumberChannel()
    libraryCodeCannotBeChanged.triggerNumbersTicking()

    testCoroutineScope.launch(Dispatchers.IO) {
        for (number in channel) {
            receivedInts.add(number)
        }
    }

    // this delay is now only required to wait for the numbers to be sent
    // it does not need to wait for the channel to be created
    delay(50)
    assertEquals(listOf(1, 2, 3), receivedInts.toList())
}
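
If you want to drop the fixed delay(50) altogether, one option is to wait for the expected number of items under a timeout. This is only a rough sketch and not your Scarlet setup: numbers() is a hypothetical stand-in producer, and receiveFirst() and its timeoutMs parameter are names I made up for illustration.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*

// Hypothetical stand-in for the library: emits 1, 2, 3 with a short pause after each.
@OptIn(ExperimentalCoroutinesApi::class)
fun CoroutineScope.numbers(): ReceiveChannel<Int> = produce {
    for (i in 1..3) {
        send(i)
        delay(10)
    }
}

// Receives exactly `count` items, throwing a timeout exception after
// `timeoutMs` instead of sleeping for a guessed amount of time.
fun receiveFirst(count: Int, timeoutMs: Long = 1_000): List<Int> = runBlocking {
    val channel = numbers()
    val received = mutableListOf<Int>()
    withTimeout(timeoutMs) {
        repeat(count) { received.add(channel.receive()) }
    }
    channel.cancel() // stop the producer so runBlocking can finish
    received
}

fun main() {
    println(receiveFirst(3)) // prints [1, 2, 3]
}
```

The test then finishes as soon as the third value arrives, and a slow channel shows up as a timeout rather than as a missing value in the assertion. Note that the stand-in uses the default rendezvous capacity, so nothing is dropped; with Channel.CONFLATED a slow receiver could still miss values.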

Thanks, that works. I also needed to change the private class: the test still wasn’t working because the object that receiveChannel pointed to was being replaced internally. Switching to BroadcastChannel meant I could no longer use the for loop in the example, so it uses consumeEach instead. In the real version, the equivalent change fixes the issue with no need to change the for loop.

package com.example.coroutinerace

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.test.TestCoroutineDispatcher
import kotlinx.coroutines.test.TestCoroutineScope
import org.junit.Assert.assertEquals
import org.junit.Test

@ExperimentalCoroutinesApi
class ReceiveChannelTest {

    private val dispatcher = TestCoroutineDispatcher()
    private val testCoroutineScope = TestCoroutineScope(dispatcher)
    private val receivedInts = mutableListOf<Int>()
    private val libraryCodeCannotBeChanged = LibraryCodeCannotBeChanged(testCoroutineScope)

    @Test
    fun `test receive all numbers`() {
        testCoroutineScope.launch(Dispatchers.IO) {
            val numChannel = libraryCodeCannotBeChanged.getNumberChannel()

            testCoroutineScope.launch(Dispatchers.IO) {
                numChannel.consumeEach {
                    receivedInts.add(it)
                }
            }

            libraryCodeCannotBeChanged.triggerNumbersTicking()
        }

        runBlocking {
            delay(250)
            assertEquals(listOf(1, 2, 3), receivedInts.toList())
        }
    }

    private class LibraryCodeCannotBeChanged(val coroutineScope: CoroutineScope) {

        private var receiveChannel: BroadcastChannel<Int> = BroadcastChannel(Channel.CONFLATED)

        fun getNumberChannel(): BroadcastChannel<Int> {
            runBlocking {
                delay(15)
            }
            return receiveChannel
        }

        fun triggerNumbersTicking() {
            coroutineScope.launch(Dispatchers.Default) {
                (1..3).forEach {
                    receiveChannel.send(it)
                    println(it)
                    delay(10)
                }
            }
        }
    }
}
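
For anyone wondering why fetching the channel first did not help with the original private class: there, triggerNumbersTicking() reassigned receiveChannel, so a reference obtained earlier pointed at an object that never received anything. Below is a minimal sketch of that pitfall. Holder, start() and demo() are hypothetical names that only mirror the shape of the first version of LibraryCodeCannotBeChanged.

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*

// Hypothetical holder with the same shape as the first LibraryCodeCannotBeChanged.
class Holder(private val scope: CoroutineScope) {
    var channel: ReceiveChannel<Int> = Channel()
        private set

    @OptIn(ExperimentalCoroutinesApi::class)
    fun start() {
        // Reassigns the property: callers that read `channel` earlier keep the old object.
        channel = scope.produce { (1..3).forEach { send(it) } }
    }
}

// Returns whether the early reference is still the live channel,
// and what (if anything) could be received from the early reference.
fun demo(): Pair<Boolean, Int?> = runBlocking {
    val holder = Holder(this)
    val before = holder.channel          // obtained before start()
    holder.start()
    val sameObject = before === holder.channel
    val fromOld = withTimeoutOrNull(100) { before.receive() }
    holder.channel.cancel()              // stop the producer so runBlocking can finish
    sameObject to fromOld
}

fun main() {
    println(demo()) // prints (false, null)
}
```

Keeping a single channel object for the lifetime of the class, as the BroadcastChannel version above does, avoids the problem because the early reference and the one being sent to are the same object.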