I’ve been working on a project for my company, https://gitlab.com/paessler-labs/prtg-pyprobe, which uses Python’s asyncio library quite heavily. It’s a great framework and so far we’ve had very few issues implementing it, but I recently wrote some new code that I wanted to mock for unit testing, and that turned out to be a bit more challenging than I expected.

Here is the code that I wanted to test:

start = time.time()
async with aioboto3.resource(
    "s3", aws_access_key_id=task_data["aws_access_key"], aws_secret_access_key=task_data["aws_secret_key"]
) as s3:
    # count the buckets by iterating over the async iterator
    all_buckets = s3.buckets.all()
    i = 0
    async for _ in all_buckets:
        i += 1

    s3_bucket_data.add_channel(
        name="Total Buckets", mode="integer", kind="Custom", customunit="buckets", value=i
    )

    s3_bucket_data.message = f"Your AWS account has {i} buckets."

    # elapsed query time in milliseconds
    end = (time.time() - start) * 1000
    s3_bucket_data.add_channel(name="Total Query Time", mode="float", kind="TimeResponse", value=end)

And here is what the unit test ended up looking like:

@pytest.mark.asyncio
class TestS3TotalWork:
    @asynctest.patch("aioboto3.resource")
    async def test_sensor_s3_total(self, aioboto_mock, s3_total_sensor):
        buckets = asynctest.MagicMock()
        buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]

        aioboto_mock.return_value.__aenter__.return_value.buckets.all.return_value = buckets

        s3_total_queue = asyncio.Queue()

        await s3_total_sensor.work(task_data=task_data(), q=s3_total_queue)

        queue_result = await s3_total_queue.get()

        aioboto_mock.assert_called_once_with("s3", aws_access_key_id="1123124", aws_secret_access_key="jkh2089")
        assert queue_result["message"] == "Your AWS account has 3 buckets."
        assert {
            "customunit": "buckets",
            "kind": "Custom",
            "mode": "integer",
            "name": "Total Buckets",
            "value": 3,
        } in queue_result["channel"]

To be able to test this, I ended up using the asynctest library (https://pypi.org/project/asynctest/), which was really helpful for mocking async iterables and async context managers.
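To see what asynctest gives you out of the box, here is a tiny self-contained sketch (nothing in it comes from the project, the names are made up): its MagicMock can be driven directly with async with and async for.

import asyncio
import asynctest


async def demo():
    mock = asynctest.MagicMock()

    # __aenter__/__aexit__ are coroutine mocks, so "async with" just works;
    # the value bound by "as" is mock.__aenter__.return_value
    async with mock as entered:
        print(entered is mock.__aenter__.return_value)  # True

    # __aiter__.return_value controls what "async for" yields
    mock.__aiter__.return_value = ["a", "b", "c"]
    print([item async for item in mock])  # ['a', 'b', 'c']


asyncio.run(demo())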

The tricky part for me to figure out was what returns what, and when, for the call

all_buckets = s3.buckets.all()

The way I tried to think about it is this:

@asynctest.patch("aioboto3.resource")
...
aioboto_mock.return_value.__aenter__.return_value.buckets.all.return_value = buckets

‘aioboto_mock’ mocks the aioboto3 library’s ‘resource’ attribute.

Its ‘return_value’ is what calling aioboto3.resource(...) returns: the async context manager.

Then we step into the context, which happens through the ‘__aenter__’ method.

That needs another ‘return_value’, since that is what the context manager hands back and binds to ‘s3’.

From there we follow the attribute path to ‘buckets.all’, whose ‘return_value’ is the mocked result of s3.buckets.all().
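If it helps, the same chain written out step by step against a bare mock looks roughly like this (just a sketch to check the reasoning, no aioboto3 involved, the attribute path simply mirrors the patch in the test above):

import asyncio
import asynctest


async def walk_the_chain():
    aioboto_mock = asynctest.MagicMock()  # stands in for the patched aioboto3.resource

    # calling the patched function hands back its return_value: the context manager
    ctx = aioboto_mock("s3")
    assert ctx is aioboto_mock.return_value

    # "async with ... as s3" awaits __aenter__ and binds its return_value to s3
    s3 = await ctx.__aenter__()
    assert s3 is aioboto_mock.return_value.__aenter__.return_value

    # s3.buckets.all() returns whatever sits at ...buckets.all.return_value,
    # which is exactly the spot the test assigns the buckets mock to
    mocked_result = aioboto_mock.return_value.__aenter__.return_value.buckets.all.return_value
    assert s3.buckets.all() is mocked_result


asyncio.run(walk_the_chain())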

buckets = asynctest.MagicMock()
buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]

The buckets.all() method returns an async iterable, so to patch the result of s3.buckets.all() we set that return value to be a MagicMock whose __aiter__ yields the values we want the async for loop to see.
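Tying it back to the counting loop in the sensor code: once a MagicMock like that is wired in as the return value of s3.buckets.all(), the async for loop sees exactly those three fake bucket names. Roughly sketched on its own:

import asyncio
import asynctest


async def count_buckets():
    buckets = asynctest.MagicMock()
    buckets.__aiter__.return_value = ["bucket1", "bucket2", "bucket3"]

    # the same counting loop the sensor runs over s3.buckets.all()
    i = 0
    async for _ in buckets:
        i += 1

    print(f"Your AWS account has {i} buckets.")  # -> Your AWS account has 3 buckets.


asyncio.run(count_buckets())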

Totally simple, right? :D Hope this helps anyone else trying to understand how to mock async functions!