1
0
Fork 0
mirror of https://github.com/archtechx/tenancy.git synced 2026-02-05 16:54:05 +00:00
tenancy/src/Commands/CreateRLSPoliciesForTenantTables.php
2023-04-28 14:27:55 +02:00

67 lines
2.5 KiB
PHP

<?php
declare(strict_types=1);
namespace Stancl\Tenancy\Commands;
use Illuminate\Console\Command;
use Illuminate\Support\Facades\DB;
use Illuminate\Support\Facades\Schema;
use Stancl\Tenancy\Database\Concerns\BelongsToPrimaryModel;
/**
 * Creates PostgreSQL row-level security (RLS) policies for the models listed in
 * the `tenancy.models.rls` config.
 *
 * Two kinds of tables are handled:
 * - tables that have the tenant key column get a policy comparing that column to `current_user`;
 * - tables without it, whose model uses the BelongsToPrimaryModel trait, get a policy that only
 *   exposes rows whose parent (primary model) row has a tenant key equal to `current_user`.
 *
 * Any existing `{table}_rls_policy` policy is dropped first, so the command can be re-run.
 */
class CreateRLSPoliciesForTenantTables extends Command
{
    protected $signature = 'tenants:create-rls-policies';

    protected $description = 'Create row-level security policies for tenant tables';

    public function handle(): int
    {
        $tenantKey = tenancy()->tenantKeyColumn();

        foreach ($this->getTenantModels() as $model) {
            $table = $model->getTable();

            // Drop the previous policy (if any) so re-running the command replaces it instead of failing.
            DB::statement("DROP POLICY IF EXISTS {$table}_rls_policy ON {$table}");

            if (Schema::hasColumn($table, $tenantKey)) {
                $this->createDirectPolicy($table, $tenantKey);
                $this->components->info("Created RLS policy for table '$table'");

                continue;
            }

            // Table is not directly related to the tenant -- it can only be scoped through its primary model.
            if (! in_array(BelongsToPrimaryModel::class, class_uses_recursive($model::class), true)) {
                $modelName = $model::class;
                $this->components->info("Table '$table' is not related to tenant. Make sure $modelName uses the BelongsToPrimaryModel trait.");

                continue;
            }

            $parentName = $model->getRelationshipToPrimaryModel();
            $relationship = $model->$parentName();
            $parentModel = $relationship->make();

            $this->createPolicyThroughParent(
                $table,
                $tenantKey,
                $relationship->getForeignKeyName(),
                $parentModel->getTable(),
                $parentModel->getKeyName(),
            );
            $this->components->info("Created RLS policy for table '$table'");
        }

        return Command::SUCCESS;
    }

    /**
     * Instantiate every model class listed in the `tenancy.models.rls` config.
     *
     * A missing or null config value is treated as an empty list instead of
     * making array_map() throw a TypeError.
     *
     * @return array<object> model instances (each must provide getTable())
     */
    public function getTenantModels(): array
    {
        return array_map(
            static fn (string $modelName) => new $modelName,
            config('tenancy.models.rls') ?? [],
        );
    }

    /**
     * Create a policy for a table that stores the tenant key itself:
     * a row is visible only when its tenant key equals the current database user.
     */
    private function createDirectPolicy(string $table, string $tenantKey): void
    {
        DB::statement("CREATE POLICY {$table}_rls_policy ON {$table} USING ({$tenantKey}::TEXT = current_user);");

        $this->enableRowLevelSecurity($table);
    }

    /**
     * Create a policy for a table scoped through its parent (primary model) table:
     * a row is visible only when its foreign key points at a parent row whose
     * tenant key equals the current database user.
     */
    private function createPolicyThroughParent(
        string $table,
        string $tenantKey,
        string $foreignKey,
        string $parentTable,
        string $parentPrimaryKey,
    ): void {
        DB::statement("CREATE POLICY {$table}_rls_policy ON {$table} USING (
            {$foreignKey} IN (
                SELECT {$parentPrimaryKey}
                FROM {$parentTable}
                WHERE {$tenantKey}::TEXT = current_user
            )
        )");

        $this->enableRowLevelSecurity($table);
    }

    /**
     * Turn RLS on for the table and apply it to the table owner as well.
     *
     * PostgreSQL ignores a table's policies until RLS is ENABLEd; FORCE additionally
     * stops the table owner from bypassing them. Both statements are idempotent.
     */
    private function enableRowLevelSecurity(string $table): void
    {
        DB::statement("ALTER TABLE {$table} ENABLE ROW LEVEL SECURITY");
        DB::statement("ALTER TABLE {$table} FORCE ROW LEVEL SECURITY");
    }
}