Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(generator): support custom default casing for client properties #877

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pydantic.fields import PrivateAttr

from .. import config
from .utils import Faker, Sampler, clean_multiline
from .utils import Faker, Sampler, clean_multiline, to_camel_case, to_pascal_case, to_snake_case
from ..utils import DEBUG_GENERATOR, assert_never
from ..errors import UnsupportedListTypeError
from .._compat import (
Expand Down Expand Up @@ -248,6 +248,14 @@ def __str__(self) -> str:
return self.value


class ClientCasing(str, enum.Enum):
snake_case = 'snake_case'
camel_case = 'camel_case'
lower_case = 'lower_case'
upper_case = 'upper_case'
pascal_case = 'pascal_case'


class Module(BaseModel):
if TYPE_CHECKING:
spec: machinery.ModuleSpec
Expand Down Expand Up @@ -496,6 +504,7 @@ class Config(BaseSettings):
env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
)
engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')
client_casing: ClientCasing = FieldInfo(default=ClientCasing.lower_case)

# this should be a list of experimental features
# https://github.com/prisma/prisma/issues/12442
Expand Down Expand Up @@ -684,7 +693,8 @@ def name_validator(cls, name: str) -> str:
f'use a different model name with \'@@map("{name}")\'.'
)

if iskeyword(name.lower()):
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.lower_case and iskeyword(name.lower()):
raise ValueError(
f'Model name "{name}" results in a client property that shadows a Python keyword; '
f'use a different model name with \'@@map("{name}")\'.'
Expand Down Expand Up @@ -748,6 +758,13 @@ def instance_name(self) -> str:

`User` -> `Prisma().user`
"""
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.camel_case:
return to_camel_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.pascal_case:
return to_pascal_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.snake_case:
return to_snake_case(self.name)
return self.name.lower()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would be great for this to exhaustively match the enum, you can do this by using our assert_never() helper function. Here's an example from another part of the codebase

        if value == EngineType.binary:
            return value
        elif value == EngineType.dataproxy:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support the Prisma Data Proxy yet.')
        elif value == EngineType.library:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support native engine bindings yet.')
        else:  # pragma: no cover
            assert_never(value)


@property
Expand Down
27 changes: 27 additions & 0 deletions src/prisma/generator/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Union, TypeVar, Iterator
from pathlib import Path
Expand Down Expand Up @@ -122,3 +123,29 @@ def clean_multiline(string: str) -> str:
assert string, 'Expected non-empty string'
lines = string.splitlines()
return '\n'.join([dedent(lines[0]), *lines[1:]])


ACRONYM_RE = re.compile(r'([A-Z\d]+)(?=[A-Z\d]|$)')
PASCAL_RE = re.compile(r'([^\-_]+)')
SPLIT_RE = re.compile(r'([\-_]*[A-Z][^A-Z]*[\-_]*)')
UNDERSCORE_RE = re.compile(r'(?<=[^\-_])[\-_]+[^\-_]')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be great to add test cases for all these functions!



def to_snake_case(input_str: str) -> str:
input_str = ACRONYM_RE.sub(lambda m: m.group(0).title(), input_str)
input_str = '_'.join(s for s in SPLIT_RE.split(input_str) if s)
return input_str.lower()


def to_camel_case(input_str: str) -> str:
if len(input_str) != 0 and not input_str[:2].isupper():
input_str = input_str[0].lower() + input_str[1:]
return UNDERSCORE_RE.sub(lambda m: m.group(0)[-1].upper(), input_str)


def to_pascal_case(input_str: str) -> str:
def _replace_fn(match: re.Match[str]) -> str:
return match.group(1)[0].upper() + match.group(1)[1:]

input_str = to_camel_case(PASCAL_RE.sub(_replace_fn, input_str))
return input_str[0].upper() + input_str[1:] if len(input_str) != 0 else input_str
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: where did you get these functions from? did you come up with it yourself?

I'm only asking because, if you copied these from somewhere, it would be great to add a link to that place in case these need to be updated in the future.

Loading