Beatrust techBlog

The official developer blog of Beatrust, Inc.

Custom Prisma Client for RLS

Author: Neo Chiu

Background

We have used Prisma + Postgres since our prototype, and last year we started migrating all data to RLS (Row Level Security). We manage multi-tenant data in one database, and we don't want data to be accessed across tenants. RLS restricts data access with security policies at the database engine level, which prevents unexpected access from the client side.

You can find more details in Japanese at "NestJS + Prisma2 で作る RLS の世界".

After a few months, we started to face some issues.

  • increased latency even without a high rate of simultaneous requests.
  • memory usage becomes unreasonably high.

Where is the problem?

Prisma client is a JavaScript interface for the Prisma engine. The Prisma engine is implemented in Rust and shipped as a pre-built binary executable that is fetched during npm install. When you start a new Prisma client, it starts a Prisma engine to manage queries and connections.

Here comes the problem: to prevent pollution between requests, we ask NestJS to handle each incoming request in Scope.REQUEST. NestJS then creates new instances of all providers for each request so that they are exclusive to it. As a result, every request has to wait for the Prisma engine to start and for a new connection to the DB to open. If the endpoint logic calls the Prisma client multiple times in parallel, the client automatically opens more connections, up to the pool limit, which is num_physical_cpus * 2 + 1 by default.

So the problem is that:

  • Starting the Prisma engine and opening a connection take time.
  • If the Prisma client is called in parallel, queries need to wait for new connections to open.
  • The Prisma engine consumes memory, yet we create one for every request.
  • The Prisma engine is started as a child process, so it can't be terminated as quickly as the request ends.
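
To make the cost concrete, here is a minimal sketch of the default pool size arithmetic, together with Prisma's `connection_limit` connection URL parameter for capping it. The helper names `defaultPoolSize` and `withConnectionLimit` and the database URL are assumptions made up for illustration; they are not Prisma APIs.

```typescript
// Default pool size per the formula quoted above: num_physical_cpus * 2 + 1.
function defaultPoolSize(physicalCpus: number): number {
  return physicalCpus * 2 + 1;
}

// Hypothetical helper: append Prisma's `connection_limit` URL parameter.
function withConnectionLimit(databaseUrl: string, limit: number): string {
  const separator = databaseUrl.indexOf("?") === -1 ? "?" : "&";
  return databaseUrl + separator + "connection_limit=" + limit;
}

// Made-up example: a 4-core machine allows up to 9 connections per client.
const poolSize = defaultPoolSize(4);
const limitedUrl = withConnectionLimit(
  "postgresql://user:password@localhost:5432/app",
  poolSize
);
```

With one client per request, that limit applies to every request separately, which is why parallel queries inside a single endpoint can multiply the number of open connections.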

Share one global client

To prevent the need to keep creating new Prisma engine instances, we should start just one Prisma client.

Originally we did the same as the official example, but a service written this way is affected by Scope.REQUEST: a new client is created for every request.

@Injectable({ scope: Scope.REQUEST })
export class PrismaService extends PrismaClient {}

So instead of extending PrismaClient, we create an empty class and use the global client internally. In this case, we have to manually bind methods between the custom instance and the global client.

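import { PrismaClient } from "@prisma/client";

// created once at module load and shared by every request
const globalClient = new PrismaClient();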

// we don't directly extend PrismaClient so we need to extend it in TS for typing
export interface PrismaRlsClient extends PrismaClient {}
export class PrismaRlsClient {
  constructor() {
    this._reflectProto();
  }

  private _reflectProto() {
    Object.keys(globalClient).forEach((k) => {
      this[k] = globalClient[k];
    });
  }
}

Now we can use PrismaRlsClient to create a NestJS provider which shares the same global client.
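
The effect of this pattern can be checked without a database. The sketch below uses a stand-in `FakeClient` class instead of the real `PrismaClient` (an assumption made so the example runs on its own) and counts how many "engines" are started when several request-scoped wrappers are created. `RequestScopedClient` is likewise a hypothetical name.

```typescript
// Stand-in for PrismaClient (an assumption for illustration): counts how many
// times an "engine" is started so that the sharing is observable.
let enginesStarted = 0;

class FakeClient {
  user = { findUnique: (id: number) => ({ id }) };

  constructor() {
    enginesStarted += 1;
  }
}

// Created once, at module load.
const sharedClient = new FakeClient();

// Per-request wrapper: copies property references instead of extending.
class RequestScopedClient {
  [key: string]: unknown;

  constructor() {
    const source = sharedClient as unknown as { [key: string]: unknown };
    for (const key of Object.keys(source)) {
      this[key] = source[key];
    }
  }
}

// Two "requests" create two wrappers, but only one engine was started.
const first = new RequestScopedClient();
const second = new RequestScopedClient();
```

Both wrappers point at the same `user` delegate, so creating them is cheap compared to starting a new engine per request.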

Handle RLS without breaking changes

Before, we created a whole new client / engine for each request. This guarantees that the client is always under the same context for RLS. Since we are sharing one global client, we have to solve RLS at the query level. For that, we need to use transactions.

prisma.$transaction([
  // set variables for RLS
  prisma.$executeRawUnsafe(`SET orgId='${value}'`),
  // target query
  prisma.user.findUnique(),
  // clean variables of current connection to prevent pollution
  prisma.$executeRawUnsafe(`SET orgId=''`),
]);
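
Because the SET statement is built by string interpolation, the value has to be safe to embed in SQL. One hedged way to reduce the risk is to quote the value before building the statement. In this sketch `quoteLiteral` and `buildSetOrgId` are hypothetical helpers, and the quoting rule assumes Postgres' default `standard_conforming_strings = on`.

```typescript
// Hypothetical helper: quote a string as a SQL literal by doubling single quotes.
// Assumes Postgres' default `standard_conforming_strings = on`.
function quoteLiteral(value: string): string {
  return "'" + value.replace(/'/g, "''") + "'";
}

// Hypothetical helper: build the SET statement used in the transaction above.
function buildSetOrgId(orgId: string): string {
  return "SET orgId=" + quoteLiteral(orgId);
}

// A normal value and a hostile one; the hostile input stays inside one literal.
const safeStatement = buildSetOrgId("org_123");
const hostileStatement = buildSetOrgId("x'; DROP TABLE users; --");
```

If the org ID comes from a trusted source such as a verified token, this is a second line of defense rather than the main one.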

So... are we going to rewrite all existing code in the above format? We definitely don't want to refactor all of our existing code by hand. Since we are already reflecting all methods under the hood, why not leverage this to wrap all methods at the same time? To do that, we need to find out where the model delegates are. Hard-coding all models is not a good idea, because model names are decided by the schema and models may be added or renamed. Instead, we iterate over all properties of the client to find the model delegates.

class PrismaRlsClient {
  private _reflectModel() {
    Object.keys(globalClient).forEach((name) => {
      const delegate = globalClient[name];
      if (typeof delegate?.findUnique !== "function") {
        // model delegate should have `findUnique` method
        return;
      }

      // create delegate object on this instance
      this[name] = {};

      Object.keys(delegate).forEach((k) => {
        this[name][k] = this._rebuildQueryMethod(delegate[k]);
      });
    });
  }
}
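
The detection rule can be tried in isolation. The sketch below runs the same `findUnique` check against a stand-in object rather than a generated Prisma client; the object's shape and the helper name `findModelDelegateNames` are assumptions for illustration.

```typescript
// Stand-in for a generated client (an assumption): two model delegates plus
// non-model members that should be skipped.
const fakeClient = {
  user: { findUnique: () => "user", findMany: () => ["user"] },
  org: { findUnique: () => "org", findMany: () => ["org"] },
  $connect: () => undefined,
  _engine: { version: "stub" },
};

// Return the names of properties that look like model delegates,
// i.e. objects that expose a `findUnique` function.
function findModelDelegateNames(client: object): string[] {
  const source = client as { [key: string]: any };
  return Object.keys(source).filter(
    (name) =>
      source[name] != null && typeof source[name].findUnique === "function"
  );
}

const delegateNames = findModelDelegateNames(fakeClient);
```

Only `user` and `org` pass the check, so new models added to the schema are picked up without touching this code.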

Next, let's rebuild every method so that it runs inside a transaction with RLS support. We create a private method that builds the SET statements and integrate it into the rebuild function.

type AnyFunction = (...args: any[]) => any;

class PrismaRlsClient {
  rls: {
    orgId?: string;
  } = {};

  private _getRlsPromises() {
    return {
      before: [globalClient.$executeRawUnsafe(`SET orgId='${this.rls.orgId}'`)],
      after: [globalClient.$executeRawUnsafe(`SET orgId=''`)],
    };
  }

  private _rebuildQueryMethod<Fn extends AnyFunction>(method: Fn): Fn {
    // rebuild function only
    if (typeof method !== "function") return method;
    const newFn = (...args: Parameters<Fn>) => {
      const { before, after } = this._getRlsPromises();
      return (
        globalClient
          .$transaction([
            ...before,
            Reflect.apply(method, globalClient, args),
            ...after,
          ])
          // extract real method result
          .then((results) => results[before.length])
      );
    };
    return newFn as Fn;
  }
}
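
The wrapping logic can be exercised without Prisma. In the sketch below, `Op` is a lazy function standing in for a PrismaPromise and `fakeTransaction` stands in for `$transaction`; both are assumptions for illustration, as is the helper name `withRls`. It shows that the wrapped query runs between the two SET statements and that the caller only sees the query's own result.

```typescript
// A lazy operation, standing in for PrismaPromise (an assumption): nothing
// runs until the transaction invokes it.
type Op = () => Promise<unknown>;

// Records the order in which operations actually ran.
const executed: string[] = [];

// Stand-in for $transaction: runs operations one by one and collects results.
async function fakeTransaction(ops: Op[]): Promise<unknown[]> {
  const results: unknown[] = [];
  for (const op of ops) {
    results.push(await op());
  }
  return results;
}

// Build a labelled operation that records when it runs.
function makeOp(label: string, result: unknown): Op {
  return async () => {
    executed.push(label);
    return result;
  };
}

// Run a query between the RLS statements and return only the query's result.
function withRls(orgId: string, query: Op): Promise<unknown> {
  const before = [makeOp(`SET orgId='${orgId}'`, 0)];
  const after = [makeOp("SET orgId=''", 0)];
  return fakeTransaction([...before, query, ...after]).then(
    (results) => results[before.length]
  );
}
```

Indexing with `before.length` keeps the extraction correct even if more statements are added to `before` later.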

We also need to take care of some of the client's own methods. For the raw query functions, we can directly use the rebuild function above.

class PrismaRlsClient {
  private _reflectProto() {
    this.$executeRaw = this._rebuildQueryMethod(globalClient.$executeRaw);
    this.$executeRawUnsafe = this._rebuildQueryMethod(
      globalClient.$executeRawUnsafe
    );
    this.$queryRaw = this._rebuildQueryMethod(globalClient.$queryRaw);
    this.$queryRawUnsafe = this._rebuildQueryMethod(
      globalClient.$queryRawUnsafe
    );
  }
}

For $transaction itself, we have to pass the original method from the global client because we cannot use $transaction within $transaction.

class PrismaRlsClient {
  _global = globalClient;

  private _rebuildTransaction() {
    return <P extends PrismaPromise<any>[]>(arg: [...P]) => {
      const { before, after } = this._getRlsPromises();
      return (
        globalClient
          .$transaction([...before, ...arg, ...after] as [...P])
          // extract real method result
          .then(
            (results) =>
              results.slice(
                before.length,
                before.length + arg.length
              ) as typeof results
          )
      );
    };
  }
}

// inside $transaction, use the original (unwrapped) methods of the global client
rlsClient.$transaction([
  rlsClient._global.org.update(),
  rlsClient._global.user.update(),
]);
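
The result handling of the rebuilt `$transaction` boils down to a slice. A minimal sketch with made-up result values; `extractCallerResults` is a hypothetical helper name, not part of the client above.

```typescript
// Drop the results of the RLS statements and keep only the caller's results.
function extractCallerResults<T>(
  results: T[],
  beforeCount: number,
  callerCount: number
): T[] {
  return results.slice(beforeCount, beforeCount + callerCount);
}

// One SET before, two caller queries, one reset after (values are made up).
const rawResults = [0, { org: "updated" }, { user: "updated" }, 0];
const callerResults = extractCallerResults(rawResults, 1, 2);
```

The caller gets back exactly as many results as operations it passed in, in the same order, as if the SET statements had never been part of the transaction.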

Try it yourself!

You can find a full example in this GitHub repo.

Conclusion

Though there are still some issues, this approach solved our performance issues and has been working well for almost one year.

Recently, we have been working on improving:

  • correct type-safety.
  • better transaction usage.
  • support for interactive transactions.
  • RLS client generator which can accept any PrismaClient.

We also keep watching how the community discussion evolves on "Supporting Postgres' SET across queries of a request #5128". The idea of @rls is good, and we may consider supporting it at Beatrust.

If you are interested in joining us, we are currently looking for Full Stack Engineers. Check Beatrust Careers for more information!