Skip to content

Commit

Permalink
feat(client): add transaction isolation level
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jan 26, 2024
1 parent 7e21552 commit 41b8671
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 17 deletions.
121 changes: 121 additions & 0 deletions databases/sync_tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,124 @@ def test_transaction_already_closed(client: Prisma) -> None:
transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
def test_read_uncommited_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `READ_UNCOMMITED`"""
client2 = Prisma()
client2.connect()

user = client.user.create(data={'name': 'Robert'})

with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_UNCOMMITED) as tx1:
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = tx1.user.count()

with client2.tx() as tx2:
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
tx2.user.create(data={'name': 'Bobby'})

dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = tx1.user.count()

# Have dirty read
assert tx1_user.name != dirty_user.name
# Have non-repeatable read
assert tx1_user.name != non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
def test_read_commited_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `READ_COMMITED`"""
client2 = Prisma()
client2.connect()

user = client.user.create(data={'name': 'Robert'})

with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_COMMITED) as tx1:
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = tx1.user.count()

with client2.tx() as tx2:
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
tx2.user.create(data={'name': 'Bobby'})

dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# Have non-repeatable read
assert tx1_user.name != non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
def test_repeatable_read_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `REPEATABLE_READ`"""
client2 = Prisma()
client2.connect()

user = client.user.create(data={'name': 'Robert'})

with client.tx(isolation_level=prisma.TransactionIsolationLevel.REPEATABLE_READ) as tx1:
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = tx1.user.count()

with client2.tx() as tx2:
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
tx2.user.create(data={'name': 'Bobby'})

dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# No non-repeatable read
assert tx1_user.name == non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.skipif(True, reason='Available for SQL Server only')
def test_snapshot_isolation_level() -> None:
"""A transaction isolation level is set to `SNAPSHOT`"""
raise NotImplementedError


def test_serializable_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `SERIALIZABLE`"""
client2 = Prisma()
client2.connect()

user = client.user.create(data={'name': 'Robert'})

with client.tx(isolation_level=prisma.TransactionIsolationLevel.SERIALIZABLE) as tx1:
tx1_user = tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = tx1.user.count()

with client2.tx() as tx2:
tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
tx2.user.create(data={'name': 'Bobby'})

dirty_user = tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# No non-repeatable read
assert tx1_user.name == non_repeatable_user.name
# No phantom read
assert tx1_count == phantom_count
126 changes: 126 additions & 0 deletions databases/tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,129 @@ async def test_transaction_already_closed(client: Prisma) -> None:
await transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.asyncio
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
async def test_read_uncommited_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `READ_UNCOMMITED`"""
client2 = Prisma()
await client2.connect()

user = await client.user.create(data={'name': 'Robert'})

async with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_UNCOMMITED) as tx1:
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = await tx1.user.count()

async with client2.tx() as tx2:
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
await tx2.user.create(data={'name': 'Bobby'})

dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = await tx1.user.count()

# Have dirty read
assert tx1_user.name != dirty_user.name
# Have non-repeatable read
assert tx1_user.name != non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.asyncio
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
async def test_read_commited_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `READ_COMMITED`"""
client2 = Prisma()
await client2.connect()

user = await client.user.create(data={'name': 'Robert'})

async with client.tx(isolation_level=prisma.TransactionIsolationLevel.READ_COMMITED) as tx1:
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = await tx1.user.count()

async with client2.tx() as tx2:
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
await tx2.user.create(data={'name': 'Bobby'})

dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = await tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# Have non-repeatable read
assert tx1_user.name != non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.asyncio
@pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available')
async def test_repeatable_read_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `REPEATABLE_READ`"""
client2 = Prisma()
await client2.connect()

user = await client.user.create(data={'name': 'Robert'})

async with client.tx(isolation_level=prisma.TransactionIsolationLevel.REPEATABLE_READ) as tx1:
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = await tx1.user.count()

async with client2.tx() as tx2:
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
await tx2.user.create(data={'name': 'Bobby'})

dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = await tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# No non-repeatable read
assert tx1_user.name == non_repeatable_user.name
# Have phantom read
assert tx1_count != phantom_count


@pytest.mark.asyncio
@pytest.mark.skipif(True, reason='Available for SQL Server only')
async def test_snapshot_isolation_level() -> None:
"""A transaction isolation level is set to `SNAPSHOT`"""
raise NotImplementedError


@pytest.mark.asyncio
async def test_serializable_isolation_level(client: Prisma) -> None:
"""A transaction isolation level is set to `SERIALIZABLE`"""
client2 = Prisma()
await client2.connect()

user = await client.user.create(data={'name': 'Robert'})

async with client.tx(isolation_level=prisma.TransactionIsolationLevel.SERIALIZABLE) as tx1:
tx1_user = await tx1.user.find_first_or_raise(where={'id': user.id})
tx1_count = await tx1.user.count()

async with client2.tx() as tx2:
await tx2.user.update(data={'name': 'Tegan'}, where={'id': user.id})
await tx2.user.create(data={'name': 'Bobby'})

dirty_user = await tx1.user.find_first_or_raise(where={'id': user.id})

non_repeatable_user = await tx1.user.find_first_or_raise(where={'id': user.id})
phantom_count = await tx1.user.count()

# No dirty read
assert tx1_user.name == dirty_user.name
# No non-repeatable read
assert tx1_user.name == non_repeatable_user.name
# No phantom read
assert tx1_count == phantom_count
56 changes: 39 additions & 17 deletions src/prisma/_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,39 @@
import logging
import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from datetime import timedelta

from ._types import TransactionId
from .errors import TransactionNotStartedError
from ._compat import StrEnum
from ._builder import dumps

if TYPE_CHECKING:
from ._base_client import SyncBasePrisma, AsyncBasePrisma

log: logging.Logger = logging.getLogger(__name__)

__all__ = (
'TransactionIsolationLevel',
'AsyncTransactionManager',
'SyncTransactionManager',
)


_SyncPrismaT = TypeVar('_SyncPrismaT', bound='SyncBasePrisma')
_AsyncPrismaT = TypeVar('_AsyncPrismaT', bound='AsyncBasePrisma')


# See here: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels
class TransactionIsolationLevel(StrEnum):
READ_UNCOMMITED = 'ReadUncommitted'
READ_COMMITED = 'ReadCommitted'
REPEATABLE_READ = 'RepeatableRead'
SNAPSHOT = 'Snapshot'
SERIALIZABLE = 'Serializable'


class AsyncTransactionManager(Generic[_AsyncPrismaT]):
"""Context manager for wrapping a Prisma instance within a transaction.
Expand All @@ -33,8 +49,10 @@ def __init__(
client: _AsyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: TransactionIsolationLevel | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -71,14 +89,15 @@ async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = await self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = await self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down Expand Up @@ -135,8 +154,10 @@ def __init__(
client: _SyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: TransactionIsolationLevel | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -173,14 +194,15 @@ def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down
5 changes: 5 additions & 0 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
def tx(
self,
*,
isolation_level: Optional[TransactionIsolationLevel] = None,
max_wait: Union[int, timedelta] = DEFAULT_TX_MAX_WAIT,
timeout: Union[int, timedelta] = DEFAULT_TX_TIMEOUT,
) -> TransactionManager:
Expand All @@ -211,6 +212,9 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
actions within a transaction, queries will be isolated to the Prisma instance and
will not be commited to the database until the context manager exits.

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this
default with the `max_wait` argument which accepts a value in milliseconds or `datetime.timedelta`.

Expand All @@ -231,6 +235,7 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
client=self,
max_wait=max_wait,
timeout=timeout,
isolation_level=isolation_level,
)


Expand Down
Loading

0 comments on commit 41b8671

Please sign in to comment.