package run
import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/openfga/openfga/cmd/util"
)
// bindRunFlagsFunc returns a cobra PreRun hook that binds the cobra cmd flags to the
// equivalent config values managed by viper. This bridges the config between cobra
// flags and viper flags.
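//
// Because each setting is bound to both a flag and an environment variable, viper
// resolves it using its standard precedence: command-line flag, then environment
// variable, then config file, then the default value. For example, the following
// two invocations are equivalent ways to set the gRPC address:
//
//	openfga run --grpc-addr=0.0.0.0:8081
//	OPENFGA_GRPC_ADDR=0.0.0.0:8081 openfga run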
func bindRunFlagsFunc(flags *pflag.FlagSet) func(*cobra.Command, []string) {
return func(command *cobra.Command, args []string) {
util.MustBindPFlag("experimentals", flags.Lookup("experimentals"))
util.MustBindEnv("experimentals", "OPENFGA_EXPERIMENTALS")
util.MustBindPFlag("accessControl.enabled", flags.Lookup("access-control-enabled"))
util.MustBindEnv("accessControl.enabled", "OPENFGA_ACCESS_CONTROL_ENABLED")
util.MustBindPFlag("accessControl.storeId", flags.Lookup("access-control-store-id"))
util.MustBindEnv("accessControl.storeId", "OPENFGA_ACCESS_CONTROL_STORE_ID")
util.MustBindPFlag("accessControl.modelId", flags.Lookup("access-control-model-id"))
util.MustBindEnv("accessControl.modelId", "OPENFGA_ACCESS_CONTROL_MODEL_ID")
command.MarkFlagsRequiredTogether("access-control-enabled", "access-control-store-id", "access-control-model-id")
util.MustBindPFlag("grpc.addr", flags.Lookup("grpc-addr"))
util.MustBindEnv("grpc.addr", "OPENFGA_GRPC_ADDR")
util.MustBindPFlag("grpc.tls.enabled", flags.Lookup("grpc-tls-enabled"))
util.MustBindEnv("grpc.tls.enabled", "OPENFGA_GRPC_TLS_ENABLED")
util.MustBindPFlag("grpc.tls.cert", flags.Lookup("grpc-tls-cert"))
util.MustBindEnv("grpc.tls.cert", "OPENFGA_GRPC_TLS_CERT")
util.MustBindPFlag("grpc.tls.key", flags.Lookup("grpc-tls-key"))
util.MustBindEnv("grpc.tls.key", "OPENFGA_GRPC_TLS_KEY")
command.MarkFlagsRequiredTogether("grpc-tls-enabled", "grpc-tls-cert", "grpc-tls-key")
util.MustBindPFlag("http.enabled", flags.Lookup("http-enabled"))
util.MustBindEnv("http.enabled", "OPENFGA_HTTP_ENABLED")
util.MustBindPFlag("http.addr", flags.Lookup("http-addr"))
util.MustBindEnv("http.addr", "OPENFGA_HTTP_ADDR")
util.MustBindPFlag("http.tls.enabled", flags.Lookup("http-tls-enabled"))
util.MustBindEnv("http.tls.enabled", "OPENFGA_HTTP_TLS_ENABLED")
util.MustBindPFlag("http.tls.cert", flags.Lookup("http-tls-cert"))
util.MustBindEnv("http.tls.cert", "OPENFGA_HTTP_TLS_CERT")
util.MustBindPFlag("http.tls.key", flags.Lookup("http-tls-key"))
util.MustBindEnv("http.tls.key", "OPENFGA_HTTP_TLS_KEY")
command.MarkFlagsRequiredTogether("http-tls-enabled", "http-tls-cert", "http-tls-key")
util.MustBindPFlag("http.upstreamTimeout", flags.Lookup("http-upstream-timeout"))
util.MustBindEnv("http.upstreamTimeout", "OPENFGA_HTTP_UPSTREAM_TIMEOUT", "OPENFGA_HTTP_UPSTREAMTIMEOUT")
util.MustBindPFlag("http.corsAllowedOrigins", flags.Lookup("http-cors-allowed-origins"))
util.MustBindEnv("http.corsAllowedOrigins", "OPENFGA_HTTP_CORS_ALLOWED_ORIGINS", "OPENFGA_HTTP_CORSALLOWEDORIGINS")
util.MustBindPFlag("http.corsAllowedHeaders", flags.Lookup("http-cors-allowed-headers"))
util.MustBindEnv("http.corsAllowedHeaders", "OPENFGA_HTTP_CORS_ALLOWED_HEADERS", "OPENFGA_HTTP_CORSALLOWEDHEADERS")
util.MustBindPFlag("authn.method", flags.Lookup("authn-method"))
util.MustBindEnv("authn.method", "OPENFGA_AUTHN_METHOD")
util.MustBindPFlag("authn.preshared.keys", flags.Lookup("authn-preshared-keys"))
util.MustBindEnv("authn.preshared.keys", "OPENFGA_AUTHN_PRESHARED_KEYS")
util.MustBindPFlag("authn.oidc.audience", flags.Lookup("authn-oidc-audience"))
util.MustBindEnv("authn.oidc.audience", "OPENFGA_AUTHN_OIDC_AUDIENCE")
util.MustBindPFlag("authn.oidc.issuer", flags.Lookup("authn-oidc-issuer"))
util.MustBindEnv("authn.oidc.issuer", "OPENFGA_AUTHN_OIDC_ISSUER")
util.MustBindPFlag("authn.oidc.issuerAliases", flags.Lookup("authn-oidc-issuer-aliases"))
util.MustBindEnv("authn.oidc.issuerAliases", "OPENFGA_AUTHN_OIDC_ISSUER_ALIASES")
util.MustBindPFlag("authn.oidc.subjects", flags.Lookup("authn-oidc-subjects"))
util.MustBindEnv("authn.oidc.subjects", "OPENFGA_AUTHN_OIDC_SUBJECTS")
util.MustBindPFlag("authn.oidc.clientIdClaims", flags.Lookup("authn-oidc-client-id-claims"))
util.MustBindEnv("authn.oidc.clientIdClaims", "OPENFGA_AUTHN_OIDC_CLIENT_ID_CLAIMS")
util.MustBindPFlag("datastore.engine", flags.Lookup("datastore-engine"))
util.MustBindEnv("datastore.engine", "OPENFGA_DATASTORE_ENGINE")
util.MustBindPFlag("datastore.uri", flags.Lookup("datastore-uri"))
util.MustBindEnv("datastore.uri", "OPENFGA_DATASTORE_URI")
util.MustBindPFlag("datastore.secondaryUri", flags.Lookup("datastore-secondary-uri"))
util.MustBindEnv("datastore.secondaryUri", "OPENFGA_DATASTORE_SECONDARY_URI")
util.MustBindPFlag("datastore.secondaryUsername", flags.Lookup("datastore-secondary-username"))
util.MustBindEnv("datastore.secondaryUsername", "OPENFGA_DATASTORE_SECONDARY_USERNAME")
util.MustBindPFlag("datastore.secondaryPassword", flags.Lookup("datastore-secondary-password"))
util.MustBindEnv("datastore.secondaryPassword", "OPENFGA_DATASTORE_SECONDARY_PASSWORD")
util.MustBindPFlag("datastore.username", flags.Lookup("datastore-username"))
util.MustBindEnv("datastore.username", "OPENFGA_DATASTORE_USERNAME")
util.MustBindPFlag("datastore.password", flags.Lookup("datastore-password"))
util.MustBindEnv("datastore.password", "OPENFGA_DATASTORE_PASSWORD")
util.MustBindPFlag("datastore.maxCacheSize", flags.Lookup("datastore-max-cache-size"))
util.MustBindEnv("datastore.maxCacheSize", "OPENFGA_DATASTORE_MAX_CACHE_SIZE", "OPENFGA_DATASTORE_MAXCACHESIZE")
util.MustBindPFlag("datastore.minOpenConns", flags.Lookup("datastore-min-open-conns"))
util.MustBindEnv("datastore.minOpenConns", "OPENFGA_DATASTORE_MIN_OPEN_CONNS")
util.MustBindPFlag("datastore.maxOpenConns", flags.Lookup("datastore-max-open-conns"))
util.MustBindEnv("datastore.maxOpenConns", "OPENFGA_DATASTORE_MAX_OPEN_CONNS", "OPENFGA_DATASTORE_MAXOPENCONNS")
util.MustBindPFlag("datastore.minIdleConns", flags.Lookup("datastore-min-idle-conns"))
util.MustBindEnv("datastore.minIdleConns", "OPENFGA_DATASTORE_MIN_IDLE_CONNS")
util.MustBindPFlag("datastore.maxIdleConns", flags.Lookup("datastore-max-idle-conns"))
util.MustBindEnv("datastore.maxIdleConns", "OPENFGA_DATASTORE_MAX_IDLE_CONNS", "OPENFGA_DATASTORE_MAXIDLECONNS")
util.MustBindPFlag("datastore.connMaxIdleTime", flags.Lookup("datastore-conn-max-idle-time"))
util.MustBindEnv("datastore.connMaxIdleTime", "OPENFGA_DATASTORE_CONN_MAX_IDLE_TIME", "OPENFGA_DATASTORE_CONNMAXIDLETIME")
util.MustBindPFlag("datastore.connMaxLifetime", flags.Lookup("datastore-conn-max-lifetime"))
util.MustBindEnv("datastore.connMaxLifetime", "OPENFGA_DATASTORE_CONN_MAX_LIFETIME", "OPENFGA_DATASTORE_CONNMAXLIFETIME")
util.MustBindPFlag("datastore.metrics.enabled", flags.Lookup("datastore-metrics-enabled"))
util.MustBindEnv("datastore.metrics.enabled", "OPENFGA_DATASTORE_METRICS_ENABLED")
util.MustBindPFlag("playground.enabled", flags.Lookup("playground-enabled"))
util.MustBindEnv("playground.enabled", "OPENFGA_PLAYGROUND_ENABLED")
util.MustBindPFlag("playground.port", flags.Lookup("playground-port"))
util.MustBindEnv("playground.port", "OPENFGA_PLAYGROUND_PORT")
util.MustBindPFlag("profiler.enabled", flags.Lookup("profiler-enabled"))
util.MustBindEnv("profiler.enabled", "OPENFGA_PROFILER_ENABLED")
util.MustBindPFlag("profiler.addr", flags.Lookup("profiler-addr"))
util.MustBindEnv("profiler.addr", "OPENFGA_PROFILER_ADDRESS")
util.MustBindPFlag("log.format", flags.Lookup("log-format"))
util.MustBindEnv("log.format", "OPENFGA_LOG_FORMAT")
util.MustBindPFlag("log.level", flags.Lookup("log-level"))
util.MustBindEnv("log.level", "OPENFGA_LOG_LEVEL")
util.MustBindPFlag("log.timestampFormat", flags.Lookup("log-timestamp-format"))
util.MustBindEnv("log.timestampFormat", "OPENFGA_LOG_TIMESTAMP_FORMAT")
util.MustBindPFlag("trace.enabled", flags.Lookup("trace-enabled"))
util.MustBindEnv("trace.enabled", "OPENFGA_TRACE_ENABLED")
util.MustBindPFlag("trace.otlp.endpoint", flags.Lookup("trace-otlp-endpoint"))
util.MustBindEnv("trace.otlp.endpoint", "OPENFGA_TRACE_OTLP_ENDPOINT")
util.MustBindPFlag("trace.otlp.tls.enabled", flags.Lookup("trace-otlp-tls-enabled"))
util.MustBindEnv("trace.otlp.tls.enabled", "OPENFGA_TRACE_OTLP_TLS_ENABLED")
util.MustBindPFlag("trace.sampleRatio", flags.Lookup("trace-sample-ratio"))
util.MustBindEnv("trace.sampleRatio", "OPENFGA_TRACE_SAMPLE_RATIO")
util.MustBindPFlag("trace.serviceName", flags.Lookup("trace-service-name"))
util.MustBindEnv("trace.serviceName", "OPENFGA_TRACE_SERVICE_NAME")
util.MustBindPFlag("metrics.enabled", flags.Lookup("metrics-enabled"))
util.MustBindEnv("metrics.enabled", "OPENFGA_METRICS_ENABLED")
util.MustBindPFlag("metrics.addr", flags.Lookup("metrics-addr"))
util.MustBindEnv("metrics.addr", "OPENFGA_METRICS_ADDR")
util.MustBindPFlag("metrics.enableRPCHistograms", flags.Lookup("metrics-enable-rpc-histograms"))
util.MustBindEnv("metrics.enableRPCHistograms", "OPENFGA_METRICS_ENABLE_RPC_HISTOGRAMS")
util.MustBindPFlag("maxChecksPerBatchCheck", flags.Lookup("max-checks-per-batch-check"))
util.MustBindEnv("maxChecksPerBatchCheck", "OPENFGA_MAX_CHECKS_PER_BATCH_CHECK")
util.MustBindPFlag("maxConcurrentChecksPerBatchCheck", flags.Lookup("max-concurrent-checks-per-batch-check"))
util.MustBindEnv("maxConcurrentChecksPerBatchCheck", "OPENFGA_MAX_CONCURRENT_CHECKS_PER_BATCH_CHECK")
util.MustBindPFlag("maxTuplesPerWrite", flags.Lookup("max-tuples-per-write"))
util.MustBindEnv("maxTuplesPerWrite", "OPENFGA_MAX_TUPLES_PER_WRITE", "OPENFGA_MAXTUPLESPERWRITE")
util.MustBindPFlag("maxTypesPerAuthorizationModel", flags.Lookup("max-types-per-authorization-model"))
util.MustBindEnv("maxTypesPerAuthorizationModel", "OPENFGA_MAX_TYPES_PER_AUTHORIZATION_MODEL", "OPENFGA_MAXTYPESPERAUTHORIZATIONMODEL")
util.MustBindPFlag("maxAuthorizationModelSizeInBytes", flags.Lookup("max-authorization-model-size-in-bytes"))
util.MustBindEnv("maxAuthorizationModelSizeInBytes", "OPENFGA_MAX_AUTHORIZATION_MODEL_SIZE_IN_BYTES", "OPENFGA_MAXAUTHORIZATIONMODELSIZEINBYTES")
util.MustBindPFlag("maxConcurrentReadsForListObjects", flags.Lookup("max-concurrent-reads-for-list-objects"))
util.MustBindEnv("maxConcurrentReadsForListObjects", "OPENFGA_MAX_CONCURRENT_READS_FOR_LIST_OBJECTS", "OPENFGA_MAXCONCURRENTREADSFORLISTOBJECTS")
util.MustBindPFlag("maxConcurrentReadsForListUsers", flags.Lookup("max-concurrent-reads-for-list-users"))
util.MustBindEnv("maxConcurrentReadsForListUsers", "OPENFGA_MAX_CONCURRENT_READS_FOR_LIST_USERS", "OPENFGA_MAXCONCURRENTREADSFORLISTUSERS")
util.MustBindPFlag("maxConcurrentReadsForCheck", flags.Lookup("max-concurrent-reads-for-check"))
util.MustBindEnv("maxConcurrentReadsForCheck", "OPENFGA_MAX_CONCURRENT_READS_FOR_CHECK", "OPENFGA_MAXCONCURRENTREADSFORCHECK")
util.MustBindPFlag("maxConditionEvaluationCost", flags.Lookup("max-condition-evaluation-cost"))
util.MustBindEnv("maxConditionEvaluationCost", "OPENFGA_MAX_CONDITION_EVALUATION_COST", "OPENFGA_MAXCONDITIONEVALUATIONCOST")
util.MustBindPFlag("changelogHorizonOffset", flags.Lookup("changelog-horizon-offset"))
util.MustBindEnv("changelogHorizonOffset", "OPENFGA_CHANGELOG_HORIZON_OFFSET", "OPENFGA_CHANGELOGHORIZONOFFSET")
util.MustBindPFlag("resolveNodeLimit", flags.Lookup("resolve-node-limit"))
util.MustBindEnv("resolveNodeLimit", "OPENFGA_RESOLVE_NODE_LIMIT", "OPENFGA_RESOLVENODELIMIT")
util.MustBindPFlag("resolveNodeBreadthLimit", flags.Lookup("resolve-node-breadth-limit"))
util.MustBindEnv("resolveNodeBreadthLimit", "OPENFGA_RESOLVE_NODE_BREADTH_LIMIT", "OPENFGA_RESOLVENODEBREADTHLIMIT")
util.MustBindPFlag("listObjectsDeadline", flags.Lookup("listObjects-deadline"))
util.MustBindEnv("listObjectsDeadline", "OPENFGA_LIST_OBJECTS_DEADLINE", "OPENFGA_LISTOBJECTSDEADLINE")
util.MustBindPFlag("listObjectsMaxResults", flags.Lookup("listObjects-max-results"))
util.MustBindEnv("listObjectsMaxResults", "OPENFGA_LIST_OBJECTS_MAX_RESULTS", "OPENFGA_LISTOBJECTSMAXRESULTS")
util.MustBindPFlag("listUsersDeadline", flags.Lookup("listUsers-deadline"))
util.MustBindEnv("listUsersDeadline", "OPENFGA_LIST_USERS_DEADLINE", "OPENFGA_LISTUSERSDEADLINE")
util.MustBindPFlag("listUsersMaxResults", flags.Lookup("listUsers-max-results"))
util.MustBindEnv("listUsersMaxResults", "OPENFGA_LIST_USERS_MAX_RESULTS", "OPENFGA_LISTUSERSMAXRESULTS")
util.MustBindPFlag("checkCache.limit", flags.Lookup("check-cache-limit"))
util.MustBindEnv("checkCache.limit", "OPENFGA_CHECK_CACHE_LIMIT")
// The configuration below is deprecated in favour of OPENFGA_CHECK_CACHE_LIMIT
util.MustBindPFlag("cache.limit", flags.Lookup("check-query-cache-limit"))
util.MustBindEnv("cache.limit", "OPENFGA_CHECK_QUERY_CACHE_LIMIT")
util.MustBindPFlag("cacheController.enabled", flags.Lookup("cache-controller-enabled"))
util.MustBindEnv("cacheController.enabled", "OPENFGA_CACHE_CONTROLLER_ENABLED")
util.MustBindPFlag("cacheController.ttl", flags.Lookup("cache-controller-ttl"))
util.MustBindEnv("cacheController.ttl", "OPENFGA_CACHE_CONTROLLER_TTL")
util.MustBindPFlag("checkIteratorCache.enabled", flags.Lookup("check-iterator-cache-enabled"))
util.MustBindEnv("checkIteratorCache.enabled", "OPENFGA_CHECK_ITERATOR_CACHE_ENABLED")
util.MustBindPFlag("checkIteratorCache.maxResults", flags.Lookup("check-iterator-cache-max-results"))
util.MustBindEnv("checkIteratorCache.maxResults", "OPENFGA_CHECK_ITERATOR_CACHE_MAX_RESULTS")
util.MustBindPFlag("checkIteratorCache.ttl", flags.Lookup("check-iterator-cache-ttl"))
util.MustBindEnv("checkIteratorCache.ttl", "OPENFGA_CHECK_ITERATOR_CACHE_TTL")
util.MustBindPFlag("checkQueryCache.enabled", flags.Lookup("check-query-cache-enabled"))
util.MustBindEnv("checkQueryCache.enabled", "OPENFGA_CHECK_QUERY_CACHE_ENABLED")
util.MustBindPFlag("checkQueryCache.ttl", flags.Lookup("check-query-cache-ttl"))
util.MustBindEnv("checkQueryCache.ttl", "OPENFGA_CHECK_QUERY_CACHE_TTL")
util.MustBindPFlag("listObjectsIteratorCache.enabled", flags.Lookup("list-objects-iterator-cache-enabled"))
util.MustBindEnv("listObjectsIteratorCache.enabled", "OPENFGA_LIST_OBJECTS_ITERATOR_CACHE_ENABLED")
util.MustBindPFlag("listObjectsIteratorCache.maxResults", flags.Lookup("list-objects-iterator-cache-max-results"))
util.MustBindEnv("listObjectsIteratorCache.maxResults", "OPENFGA_LIST_OBJECTS_ITERATOR_CACHE_MAX_RESULTS")
util.MustBindPFlag("listObjectsIteratorCache.ttl", flags.Lookup("list-objects-iterator-cache-ttl"))
util.MustBindEnv("listObjectsIteratorCache.ttl", "OPENFGA_LIST_OBJECTS_ITERATOR_CACHE_TTL")
util.MustBindPFlag("sharedIterator.enabled", flags.Lookup("shared-iterator-enabled"))
util.MustBindEnv("sharedIterator.enabled", "OPENFGA_SHARED_ITERATOR_ENABLED")
util.MustBindPFlag("sharedIterator.limit", flags.Lookup("shared-iterator-limit"))
util.MustBindEnv("sharedIterator.limit", "OPENFGA_SHARED_ITERATOR_LIMIT")
util.MustBindPFlag("requestDurationDatastoreQueryCountBuckets", flags.Lookup("request-duration-datastore-query-count-buckets"))
util.MustBindEnv("requestDurationDatastoreQueryCountBuckets", "OPENFGA_REQUEST_DURATION_DATASTORE_QUERY_COUNT_BUCKETS")
util.MustBindPFlag("requestDurationDispatchCountBuckets", flags.Lookup("request-duration-dispatch-count-buckets"))
util.MustBindEnv("requestDurationDispatchCountBuckets", "OPENFGA_REQUEST_DURATION_DISPATCH_COUNT_BUCKETS")
util.MustBindPFlag("contextPropagationToDatastore", flags.Lookup("context-propagation-to-datastore"))
util.MustBindEnv("contextPropagationToDatastore", "OPENFGA_CONTEXT_PROPAGATION_TO_DATASTORE")
util.MustBindPFlag("checkDispatchThrottling.enabled", flags.Lookup("check-dispatch-throttling-enabled"))
util.MustBindEnv("checkDispatchThrottling.enabled", "OPENFGA_CHECK_DISPATCH_THROTTLING_ENABLED")
util.MustBindPFlag("checkDispatchThrottling.frequency", flags.Lookup("check-dispatch-throttling-frequency"))
util.MustBindEnv("checkDispatchThrottling.frequency", "OPENFGA_CHECK_DISPATCH_THROTTLING_FREQUENCY")
util.MustBindPFlag("checkDispatchThrottling.threshold", flags.Lookup("check-dispatch-throttling-threshold"))
util.MustBindEnv("checkDispatchThrottling.threshold", "OPENFGA_CHECK_DISPATCH_THROTTLING_THRESHOLD")
util.MustBindPFlag("checkDispatchThrottling.maxThreshold", flags.Lookup("check-dispatch-throttling-max-threshold"))
util.MustBindEnv("checkDispatchThrottling.maxThreshold", "OPENFGA_CHECK_DISPATCH_THROTTLING_MAX_THRESHOLD")
util.MustBindPFlag("listObjectsDispatchThrottling.enabled", flags.Lookup("listObjects-dispatch-throttling-enabled"))
util.MustBindEnv("listObjectsDispatchThrottling.enabled", "OPENFGA_LIST_OBJECTS_DISPATCH_THROTTLING_ENABLED")
util.MustBindPFlag("listObjectsDispatchThrottling.frequency", flags.Lookup("listObjects-dispatch-throttling-frequency"))
util.MustBindEnv("listObjectsDispatchThrottling.frequency", "OPENFGA_LIST_OBJECTS_DISPATCH_THROTTLING_FREQUENCY")
util.MustBindPFlag("listObjectsDispatchThrottling.threshold", flags.Lookup("listObjects-dispatch-throttling-threshold"))
util.MustBindEnv("listObjectsDispatchThrottling.threshold", "OPENFGA_LIST_OBJECTS_DISPATCH_THROTTLING_THRESHOLD")
util.MustBindPFlag("listObjectsDispatchThrottling.maxThreshold", flags.Lookup("listObjects-dispatch-throttling-max-threshold"))
util.MustBindEnv("listObjectsDispatchThrottling.maxThreshold", "OPENFGA_LIST_OBJECTS_DISPATCH_THROTTLING_MAX_THRESHOLD")
util.MustBindPFlag("listUsersDispatchThrottling.enabled", flags.Lookup("listUsers-dispatch-throttling-enabled"))
util.MustBindEnv("listUsersDispatchThrottling.enabled", "OPENFGA_LIST_USERS_DISPATCH_THROTTLING_ENABLED")
util.MustBindPFlag("listUsersDispatchThrottling.frequency", flags.Lookup("listUsers-dispatch-throttling-frequency"))
util.MustBindEnv("listUsersDispatchThrottling.frequency", "OPENFGA_LIST_USERS_DISPATCH_THROTTLING_FREQUENCY")
util.MustBindPFlag("listUsersDispatchThrottling.threshold", flags.Lookup("listUsers-dispatch-throttling-threshold"))
util.MustBindEnv("listUsersDispatchThrottling.threshold", "OPENFGA_LIST_USERS_DISPATCH_THROTTLING_THRESHOLD")
util.MustBindPFlag("listUsersDispatchThrottling.maxThreshold", flags.Lookup("listUsers-dispatch-throttling-max-threshold"))
util.MustBindEnv("listUsersDispatchThrottling.maxThreshold", "OPENFGA_LIST_USERS_DISPATCH_THROTTLING_MAX_THRESHOLD")
util.MustBindPFlag("checkDatastoreThrottle.threshold", flags.Lookup("check-datastore-throttle-threshold"))
util.MustBindEnv("checkDatastoreThrottle.threshold", "OPENFGA_CHECK_DATASTORE_THROTTLE_THRESHOLD")
util.MustBindPFlag("checkDatastoreThrottle.duration", flags.Lookup("check-datastore-throttle-duration"))
util.MustBindEnv("checkDatastoreThrottle.duration", "OPENFGA_CHECK_DATASTORE_THROTTLE_DURATION")
util.MustBindPFlag("listObjectsDatastoreThrottle.threshold", flags.Lookup("listObjects-datastore-throttle-threshold"))
util.MustBindEnv("listObjectsDatastoreThrottle.threshold", "OPENFGA_LIST_OBJECTS_DATASTORE_THROTTLE_THRESHOLD")
util.MustBindPFlag("listObjectsDatastoreThrottle.duration", flags.Lookup("listObjects-datastore-throttle-duration"))
util.MustBindEnv("listObjectsDatastoreThrottle.duration", "OPENFGA_LIST_OBJECTS_DATASTORE_THROTTLE_DURATION")
util.MustBindPFlag("listUsersDatastoreThrottle.threshold", flags.Lookup("listUsers-datastore-throttle-threshold"))
util.MustBindEnv("listUsersDatastoreThrottle.threshold", "OPENFGA_LIST_USERS_DATASTORE_THROTTLE_THRESHOLD")
util.MustBindPFlag("listUsersDatastoreThrottle.duration", flags.Lookup("listUsers-datastore-throttle-duration"))
util.MustBindEnv("listUsersDatastoreThrottle.duration", "OPENFGA_LIST_USERS_DATASTORE_THROTTLE_DURATION")
util.MustBindPFlag("requestTimeout", flags.Lookup("request-timeout"))
util.MustBindEnv("requestTimeout", "OPENFGA_REQUEST_TIMEOUT")
// At the current time, these planner settings have no effect unless the corresponding check experimental feature is enabled
util.MustBindPFlag("planner.evictionThreshold", flags.Lookup("planner-eviction-threshold"))
util.MustBindEnv("planner.evictionThreshold", "OPENFGA_PLANNER_EVICTION_THRESHOLD")
util.MustBindPFlag("planner.cleanupInterval", flags.Lookup("planner-cleanup-interval"))
util.MustBindEnv("planner.cleanupInterval", "OPENFGA_PLANNER_CLEANUP_INTERVAL")
}
}
// Package run contains the command to run an OpenFGA server.
package run
import (
"context"
"crypto/tls"
"errors"
"fmt"
"html/template"
"net"
"net/http"
"net/http/pprof"
"os"
"os/signal"
goruntime "runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-logr/logr"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
grpc_prometheus "github.com/jon-whit/go-grpc-prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/cors"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
semconv "go.opentelemetry.io/otel/semconv/v1.12.0"
"go.opentelemetry.io/otel/trace/noop"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
healthv1pb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"sigs.k8s.io/controller-runtime/pkg/log"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/assets"
"github.com/openfga/openfga/internal/authn"
"github.com/openfga/openfga/internal/authn/oidc"
"github.com/openfga/openfga/internal/authn/presharedkey"
"github.com/openfga/openfga/internal/build"
authnmw "github.com/openfga/openfga/internal/middleware/authn"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/gateway"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/middleware"
httpmiddleware "github.com/openfga/openfga/pkg/middleware/http"
"github.com/openfga/openfga/pkg/middleware/logging"
"github.com/openfga/openfga/pkg/middleware/recovery"
"github.com/openfga/openfga/pkg/middleware/requestid"
"github.com/openfga/openfga/pkg/middleware/storeid"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/server/health"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/memory"
"github.com/openfga/openfga/pkg/storage/mysql"
"github.com/openfga/openfga/pkg/storage/postgres"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
"github.com/openfga/openfga/pkg/storage/sqlite"
"github.com/openfga/openfga/pkg/telemetry"
)
const (
datastoreEngineFlag = "datastore-engine"
datastoreURIFlag = "datastore-uri"
)
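// NewRunCommand returns the `run` cobra command, exposing every server setting as a
// flag whose default comes from serverconfig.DefaultConfig() and binding each flag to
// viper in the command's PreRun hook.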
func NewRunCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "run",
Short: "Run the OpenFGA server",
Long: "Run the OpenFGA server.",
Run: run,
Args: cobra.NoArgs,
}
defaultConfig := serverconfig.DefaultConfig()
flags := cmd.Flags()
flags.StringSlice("experimentals", defaultConfig.Experimentals, "a list of experimental features to enable. Allowed values: `enable-consistency-params`, `enable-check-optimizations`, `enable-list-objects-optimizations`, `enable-access-control`")
flags.Bool("access-control-enabled", defaultConfig.AccessControl.Enabled, "enable/disable the access control feature")
flags.String("access-control-store-id", defaultConfig.AccessControl.StoreID, "the store ID of the OpenFGA store that will be used to access the access control store")
flags.String("access-control-model-id", defaultConfig.AccessControl.ModelID, "the model ID of the OpenFGA store that will be used to access the access control store")
cmd.MarkFlagsRequiredTogether("access-control-enabled", "access-control-store-id", "access-control-model-id")
flags.String("grpc-addr", defaultConfig.GRPC.Addr, "the host:port address to serve the grpc server on")
flags.Bool("grpc-tls-enabled", defaultConfig.GRPC.TLS.Enabled, "enable/disable transport layer security (TLS)")
flags.String("grpc-tls-cert", defaultConfig.GRPC.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection")
flags.String("grpc-tls-key", defaultConfig.GRPC.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection")
cmd.MarkFlagsRequiredTogether("grpc-tls-enabled", "grpc-tls-cert", "grpc-tls-key")
flags.Bool("http-enabled", defaultConfig.HTTP.Enabled, "enable/disable the OpenFGA HTTP server")
flags.String("http-addr", defaultConfig.HTTP.Addr, "the host:port address to serve the HTTP server on")
flags.Bool("http-tls-enabled", defaultConfig.HTTP.TLS.Enabled, "enable/disable transport layer security (TLS)")
flags.String("http-tls-cert", defaultConfig.HTTP.TLS.CertPath, "the (absolute) file path of the certificate to use for the TLS connection")
flags.String("http-tls-key", defaultConfig.HTTP.TLS.KeyPath, "the (absolute) file path of the TLS key that should be used for the TLS connection")
cmd.MarkFlagsRequiredTogether("http-tls-enabled", "http-tls-cert", "http-tls-key")
flags.Duration("http-upstream-timeout", defaultConfig.HTTP.UpstreamTimeout, "the timeout duration for proxying HTTP requests upstream to the grpc endpoint")
flags.StringSlice("http-cors-allowed-origins", defaultConfig.HTTP.CORSAllowedOrigins, "specifies the CORS allowed origins")
flags.StringSlice("http-cors-allowed-headers", defaultConfig.HTTP.CORSAllowedHeaders, "specifies the CORS allowed headers")
flags.String("authn-method", defaultConfig.Authn.Method, "the authentication method to use")
flags.StringSlice("authn-preshared-keys", defaultConfig.Authn.Keys, "one or more preshared keys to use for authentication")
flags.String("authn-oidc-audience", defaultConfig.Authn.Audience, "the OIDC audience of the tokens being signed by the authorization server")
flags.String("authn-oidc-issuer", defaultConfig.Authn.Issuer, "the OIDC issuer (authorization server) signing the tokens, and where the keys will be fetched from")
flags.StringSlice("authn-oidc-issuer-aliases", defaultConfig.Authn.IssuerAliases, "the OIDC issuer DNS aliases that will be accepted as valid when verifying the `iss` field of the JWTs.")
flags.StringSlice("authn-oidc-subjects", defaultConfig.Authn.Subjects, "the OIDC subject names that will be accepted as valid when verifying the `sub` field of the JWTs. If empty, every `sub` will be allowed")
flags.StringSlice("authn-oidc-client-id-claims", defaultConfig.Authn.ClientIDClaims, "the ClientID claims that will be used to parse the clientID - configure in order of priority (first is highest). Defaults to [`azp`, `client_id`]")
flags.String("datastore-engine", defaultConfig.Datastore.Engine, "the datastore engine that will be used for persistence")
flags.String("datastore-uri", defaultConfig.Datastore.URI, "the connection uri to use to connect to the datastore (for any engine other than 'memory')")
flags.String("datastore-secondary-uri", defaultConfig.Datastore.SecondaryURI, "the connection uri to use to connect to the secondary datastore (for postgres only)")
flags.String("datastore-username", "", "the connection username to use to connect to the datastore (overwrites any username provided in the connection uri)")
flags.String("datastore-password", "", "the connection password to use to connect to the datastore (overwrites any password provided in the connection uri)")
flags.String("datastore-secondary-username", "", "the connection username to use to connect to the secondary datastore (overwrites any username provided in the connection uri)")
flags.String("datastore-secondary-password", "", "the connection password to use to connect to the secondary datastore (overwrites any password provided in the connection uri)")
flags.Int("datastore-max-cache-size", defaultConfig.Datastore.MaxCacheSize, "the maximum number of authorization models that will be cached in memory")
flags.Int("datastore-min-open-conns", defaultConfig.Datastore.MinOpenConns, "the minimum number of open connections to the datastore")
flags.Int("datastore-max-open-conns", defaultConfig.Datastore.MaxOpenConns, "the maximum number of open connections to the datastore")
flags.Int("datastore-min-idle-conns", defaultConfig.Datastore.MinIdleConns, "the minimum number of connections to the datastore in the idle connection pool")
flags.Int("datastore-max-idle-conns", defaultConfig.Datastore.MaxIdleConns, "the maximum number of connections to the datastore in the idle connection pool")
flags.Duration("datastore-conn-max-idle-time", defaultConfig.Datastore.ConnMaxIdleTime, "the maximum amount of time a connection to the datastore may be idle")
flags.Duration("datastore-conn-max-lifetime", defaultConfig.Datastore.ConnMaxLifetime, "the maximum amount of time a connection to the datastore may be reused")
flags.Bool("datastore-metrics-enabled", defaultConfig.Datastore.Metrics.Enabled, "enable/disable sql metrics")
flags.Bool("playground-enabled", defaultConfig.Playground.Enabled, "enable/disable the OpenFGA Playground")
flags.Int("playground-port", defaultConfig.Playground.Port, "the port to serve the local OpenFGA Playground on")
flags.Bool("profiler-enabled", defaultConfig.Profiler.Enabled, "enable/disable pprof profiling")
flags.String("profiler-addr", defaultConfig.Profiler.Addr, "the host:port address to serve the pprof profiler server on")
flags.String("log-format", defaultConfig.Log.Format, "the log format to output logs in")
flags.String("log-level", defaultConfig.Log.Level, "the log level to use")
flags.String("log-timestamp-format", defaultConfig.Log.TimestampFormat, "the timestamp format to use for log messages")
flags.Bool("trace-enabled", defaultConfig.Trace.Enabled, "enable tracing")
flags.String("trace-otlp-endpoint", defaultConfig.Trace.OTLP.Endpoint, "the endpoint of the trace collector")
flags.Bool("trace-otlp-tls-enabled", defaultConfig.Trace.OTLP.TLS.Enabled, "use TLS connection for trace collector")
flags.Float64("trace-sample-ratio", defaultConfig.Trace.SampleRatio, "the fraction of traces to sample. 1 means all, 0 means none.")
flags.String("trace-service-name", defaultConfig.Trace.ServiceName, "the service name included in sampled traces.")
flags.Bool("metrics-enabled", defaultConfig.Metrics.Enabled, "enable/disable prometheus metrics on the '/metrics' endpoint")
flags.String("metrics-addr", defaultConfig.Metrics.Addr, "the host:port address to serve the prometheus metrics server on")
flags.Bool("metrics-enable-rpc-histograms", defaultConfig.Metrics.EnableRPCHistograms, "enables prometheus histogram metrics for RPC latency distributions")
flags.Uint32("max-concurrent-checks-per-batch-check", defaultConfig.MaxConcurrentChecksPerBatchCheck, "the maximum number of checks that can be processed concurrently in a batch check request")
flags.Uint32("max-checks-per-batch-check", defaultConfig.MaxChecksPerBatchCheck, "the maximum number of tuples allowed in a BatchCheck request")
flags.Int("max-tuples-per-write", defaultConfig.MaxTuplesPerWrite, "the maximum allowed number of tuples per Write transaction")
flags.Int("max-types-per-authorization-model", defaultConfig.MaxTypesPerAuthorizationModel, "the maximum allowed number of type definitions per authorization model")
flags.Int("max-authorization-model-size-in-bytes", defaultConfig.MaxAuthorizationModelSizeInBytes, "the maximum size in bytes allowed for persisting an Authorization Model.")
flags.Uint32("max-concurrent-reads-for-list-users", defaultConfig.MaxConcurrentReadsForListUsers, "the maximum allowed number of concurrent datastore reads in a single ListUsers query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
flags.Uint32("max-concurrent-reads-for-list-objects", defaultConfig.MaxConcurrentReadsForListObjects, "the maximum allowed number of concurrent datastore reads in a single ListObjects or StreamedListObjects query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
flags.Uint32("max-concurrent-reads-for-check", defaultConfig.MaxConcurrentReadsForCheck, "the maximum allowed number of concurrent datastore reads in a single Check query. A high number will consume more connections from the datastore pool and will attempt to prioritize performance for the request at the expense of other queries performance.")
flags.Uint64("max-condition-evaluation-cost", defaultConfig.MaxConditionEvaluationCost, "the maximum cost for CEL condition evaluation before a request returns an error")
flags.Int("changelog-horizon-offset", defaultConfig.ChangelogHorizonOffset, "the offset (in minutes) from the current time. Changes that occur after this offset will not be included in the response of ReadChanges")
flags.Uint32("resolve-node-limit", defaultConfig.ResolveNodeLimit, "maximum resolution depth to attempt before throwing an error (defines how deeply nested an authorization model can be before a query errors out).")
flags.Uint32("resolve-node-breadth-limit", defaultConfig.ResolveNodeBreadthLimit, "defines how many nodes on a given level can be evaluated concurrently in a Check resolution tree")
flags.Duration("listObjects-deadline", defaultConfig.ListObjectsDeadline, "the timeout deadline for serving ListObjects and StreamedListObjects requests")
flags.Uint32("listObjects-max-results", defaultConfig.ListObjectsMaxResults, "the maximum results to return in non-streaming ListObjects API responses. If 0, all results can be returned")
flags.Duration("listUsers-deadline", defaultConfig.ListUsersDeadline, "the timeout deadline for serving ListUsers requests. If 0, there is no deadline")
flags.Uint32("listUsers-max-results", defaultConfig.ListUsersMaxResults, "the maximum results to return in ListUsers API responses. If 0, all results can be returned")
flags.Uint32("check-cache-limit", defaultConfig.CheckCache.Limit, "if check-query-cache-enabled or check-iterator-cache-enabled, this is the size limit of the cache")
flags.Bool("shared-iterator-enabled", defaultConfig.SharedIterator.Enabled, "enabling sharing of datastore iterators with different consumers. Each iterator is the result of a database query, for example usersets related to a specific object, or objects related to a specific user, up to a certain number of tuples per iterator.")
flags.Uint32("shared-iterator-limit", defaultConfig.SharedIterator.Limit, "if shared-iterator-enabled is enabled, this is the limit of the number of iterators that can be shared.")
flags.Bool("check-iterator-cache-enabled", defaultConfig.CheckIteratorCache.Enabled, "enable caching of datastore iterators. The key is a string representing a database query, and the value is a list of tuples. Each iterator is the result of a database query, for example usersets related to a specific object, or objects related to a specific user, up to a certain number of tuples per iterator. If the request's consistency is HIGHER_CONSISTENCY, this cache is not used.")
flags.Uint32("check-iterator-cache-max-results", defaultConfig.CheckIteratorCache.MaxResults, "if caching of datastore iterators of Check requests is enabled, this is the limit of tuples to cache per key.")
flags.Duration("check-iterator-cache-ttl", defaultConfig.CheckIteratorCache.TTL, "if caching of datastore iterators of Check requests is enabled, this is the TTL of each value")
flags.Bool("list-objects-iterator-cache-enabled", defaultConfig.ListObjectsIteratorCache.Enabled, "enable caching of datastore iterators for ListObjects. The key is a string representing a database query, and the value is a list of tuples. Each iterator is the result of a database query, for example usersets related to a specific object, or objects related to a specific user, up to a certain number of tuples per iterator. If the request's consistency is HIGHER_CONSISTENCY, this cache is not used.")
flags.Uint32("list-objects-iterator-cache-max-results", defaultConfig.ListObjectsIteratorCache.MaxResults, "if caching of datastore iterators of ListObjects requests is enabled, this is the limit of tuples to cache per key.")
flags.Duration("list-objects-iterator-cache-ttl", defaultConfig.ListObjectsIteratorCache.TTL, "if caching of datastore iterators of ListObjects requests is enabled, this is the TTL of each value")
flags.Bool("check-query-cache-enabled", defaultConfig.CheckQueryCache.Enabled, "enable caching of Check requests. For example, if you have a relation define viewer: owner or editor, and the query is Check(user:anne, viewer, doc:1), we'll evaluate the owner relation and the editor relation and cache both results: (user:anne, viewer, doc:1) -> allowed=true and (user:anne, owner, doc:1) -> allowed=true. The cache is stored in-memory; the cached values are overwritten on every change in the result, and cleared after the configured TTL. This flag improves latency, but turns Check and ListObjects into eventually consistent APIs. If the request's consistency is HIGHER_CONSISTENCY, this cache is not used.")
flags.Uint32("check-query-cache-limit", defaultConfig.CheckCache.Limit, "DEPRECATED: Use check-cache-limit instead. If caching of Check and ListObjects calls is enabled, this is the size limit of the cache")
flags.Duration("check-query-cache-ttl", defaultConfig.CheckQueryCache.TTL, "if check-query-cache-enabled, this is the TTL of each value")
flags.Bool("cache-controller-enabled", defaultConfig.CacheController.Enabled, "enabling dynamic invalidation of check query cache and check iterator cache based on whether there are recent tuple writes. If enabled, cache will be invalidated when either 1) there are tuples written to the store OR 2) the check query cache or check iterator cache TTL has expired.")
flags.Duration("cache-controller-ttl", defaultConfig.CacheController.TTL, "if cache controller is enabled, control how frequent read changes are invoked internally to query for recent tuple writes to the store.")
// Unfortunately, UintSlice/IntSlice flags do not work well as environment variables, so we stick with string slices and convert back to integers
flags.StringSlice("request-duration-datastore-query-count-buckets", defaultConfig.RequestDurationDatastoreQueryCountBuckets, "datastore query count buckets used in labelling request_duration_ms.")
flags.StringSlice("request-duration-dispatch-count-buckets", defaultConfig.RequestDurationDispatchCountBuckets, "dispatch count (i.e number of concurrent traversals to resolve a query) buckets used in labelling request_duration_ms.")
flags.Bool("context-propagation-to-datastore", defaultConfig.ContextPropagationToDatastore, "enable propagation of a request's context to the datastore")
flags.Bool("check-dispatch-throttling-enabled", defaultConfig.CheckDispatchThrottling.Enabled, "enable throttling for Check requests when the request's number of dispatches is high. Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch threshold over requests whose dispatch count exceeds the configured threshold.")
flags.Duration("check-dispatch-throttling-frequency", defaultConfig.CheckDispatchThrottling.Frequency, "defines how frequent Check dispatch throttling will be evaluated. This controls how frequently throttled dispatch Check requests are dispatched.")
flags.Uint32("check-dispatch-throttling-threshold", defaultConfig.CheckDispatchThrottling.Threshold, "define the number of dispatches above which Check requests will be throttled.")
flags.Uint32("check-dispatch-throttling-max-threshold", defaultConfig.CheckDispatchThrottling.MaxThreshold, "define the maximum dispatch threshold beyond which a Check requests will be throttled. 0 will use the 'check-dispatch-throttling-threshold' value as maximum")
flags.Bool("listObjects-dispatch-throttling-enabled", defaultConfig.ListObjectsDispatchThrottling.Enabled, "enable throttling when a ListObjects request's number of dispatches is high. Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch threshold over requests whose dispatch count exceeds the configured threshold.")
flags.Duration("listObjects-dispatch-throttling-frequency", defaultConfig.ListObjectsDispatchThrottling.Frequency, "defines how frequent ListObjects dispatch throttling will be evaluated. Frequency controls how frequently throttled dispatch ListObjects requests are dispatched.")
flags.Uint32("listObjects-dispatch-throttling-threshold", defaultConfig.ListObjectsDispatchThrottling.Threshold, "defines the number of dispatches above which ListObjects requests will be throttled.")
flags.Uint32("listObjects-dispatch-throttling-max-threshold", defaultConfig.ListObjectsDispatchThrottling.MaxThreshold, "define the maximum dispatch threshold beyond which a ListObjects requests will be throttled. 0 will use the 'listObjects-dispatch-throttling-threshold' value as maximum")
flags.Bool("listUsers-dispatch-throttling-enabled", defaultConfig.ListUsersDispatchThrottling.Enabled, "enable throttling when a ListUsers request's number of dispatches is high. Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch threshold over requests whose dispatch count exceeds the configured threshold.")
flags.Duration("listUsers-dispatch-throttling-frequency", defaultConfig.ListUsersDispatchThrottling.Frequency, "defines how frequent ListUsers dispatch throttling will be evaluated. Frequency controls how frequently throttled dispatch ListUsers requests are dispatched.")
flags.Uint32("listUsers-dispatch-throttling-threshold", defaultConfig.ListUsersDispatchThrottling.Threshold, "defines the number of dispatches above which ListUsers requests will be throttled.")
flags.Uint32("listUsers-dispatch-throttling-max-threshold", defaultConfig.ListUsersDispatchThrottling.MaxThreshold, "define the maximum dispatch threshold beyond which a list users requests will be throttled. 0 will use the 'listUsers-dispatch-throttling-threshold' value as maximum")
flags.Int("check-datastore-throttle-threshold", defaultConfig.CheckDatastoreThrottle.Threshold, "define the number of datastore requests allowed before being throttled.")
flags.Duration("check-datastore-throttle-duration", defaultConfig.CheckDatastoreThrottle.Duration, "defines the time for which the datastore request will be suspended for being throttled.")
flags.Int("listObjects-datastore-throttle-threshold", defaultConfig.ListObjectsDatastoreThrottle.Threshold, "define the number of datastore requests allowed before being throttled.")
flags.Duration("listObjects-datastore-throttle-duration", defaultConfig.ListObjectsDatastoreThrottle.Duration, "defines the time for which the datastore request will be suspended for being throttled.")
flags.Int("listUsers-datastore-throttle-threshold", defaultConfig.ListUsersDatastoreThrottle.Threshold, "define the number of datastore requests allowed before being throttled.")
flags.Duration("listUsers-datastore-throttle-duration", defaultConfig.ListUsersDatastoreThrottle.Duration, "defines the time for which the datastore request will be suspended for being throttled.")
flags.Duration("request-timeout", defaultConfig.RequestTimeout, "configures request timeout. If both HTTP upstream timeout and request timeout are specified, request timeout will be used.")
flags.Duration("planner-eviction-threshold", defaultConfig.Planner.EvictionThreshold, "how long a planner key can be unused before being evicted")
flags.Duration("planner-cleanup-interval", defaultConfig.Planner.CleanupInterval, "how often the planner checks for stale keys")
// NOTE: if you add a new flag here, update bindRunFlagsFunc, too
cmd.PreRun = bindRunFlagsFunc(flags)
return cmd
}
// ReadConfig returns the OpenFGA server configuration based on the values provided in the server's 'config.yaml' file.
// The 'config.yaml' file is loaded from '/etc/openfga', '$HOME/.openfga', or the current working directory. If no configuration
// file is present, the default values are returned.
func ReadConfig() (*serverconfig.Config, error) {
config := serverconfig.DefaultConfig()
viper.SetTypeByDefaultValue(true)
err := viper.ReadInConfig()
if err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("failed to load server config: %w", err)
}
}
if err := viper.Unmarshal(config); err != nil {
return nil, fmt.Errorf("failed to unmarshal server config: %w", err)
}
return config, nil
}
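// run is the command's entry point: it loads and verifies the configuration, builds
// the logger, and runs the server until it terminates, panicking on any startup error.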
func run(_ *cobra.Command, _ []string) {
config, err := ReadConfig()
if err != nil {
panic(err)
}
if err := config.Verify(); err != nil {
panic(err)
}
logger := logger.MustNewLogger(config.Log.Format, config.Log.Level, config.Log.TimestampFormat)
serverCtx := &ServerContext{Logger: logger}
if err := serverCtx.Run(context.Background(), config); err != nil {
panic(err)
}
}
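// ServerContext holds the long-lived dependencies used to run the server; at the
// moment that is just the logger.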
type ServerContext struct {
Logger logger.Logger
}
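// convertStringArrayToUintArray converts the request-duration bucket flag values from
// strings to uints, silently skipping any value that fails to parse as an integer.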
func convertStringArrayToUintArray(stringArray []string) []uint {
uintArray := []uint{}
for _, val := range stringArray {
// note that we have already validated that each array item is a non-negative integer
valInt, err := strconv.Atoi(val)
if err == nil {
uintArray = append(uintArray, uint(valInt))
}
}
return uintArray
}
// telemetryConfig configures tracing when it is enabled and returns the function that
// must be called to flush and shut down tracing. The returned function uses its own
// bounded context, so shutdown may be incomplete if flushing exceeds that timeout.
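//
// For illustration, a caller obtains the closer and invokes it during shutdown
// (the variable name here is arbitrary):
//
//	closeTracing := s.telemetryConfig(config)
//	defer func() { _ = closeTracing() }()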
func (s *ServerContext) telemetryConfig(config *serverconfig.Config) func() error {
if config.Trace.Enabled {
s.Logger.Info(fmt.Sprintf("🕵 tracing enabled: sampling ratio is %v and sending traces to '%s', tls: %t", config.Trace.SampleRatio, config.Trace.OTLP.Endpoint, config.Trace.OTLP.TLS.Enabled))
options := []telemetry.TracerOption{
telemetry.WithOTLPEndpoint(
config.Trace.OTLP.Endpoint,
),
telemetry.WithAttributes(
semconv.ServiceNameKey.String(config.Trace.ServiceName),
semconv.ServiceVersionKey.String(build.Version),
),
telemetry.WithSamplingRatio(config.Trace.SampleRatio),
}
if !config.Trace.OTLP.TLS.Enabled {
options = append(options, telemetry.WithOTLPInsecure())
}
tp := telemetry.MustNewTracerProvider(options...)
return func() error {
// can take up to 5 seconds to complete (https://github.com/open-telemetry/opentelemetry-go/blob/aebcbfcbc2962957a578e9cb3e25dc834125e318/sdk/trace/batch_span_processor.go#L97)
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
defer cancel()
return errors.Join(tp.ForceFlush(ctx), tp.Shutdown(ctx))
}
}
otel.SetTracerProvider(noop.NewTracerProvider())
return func() error {
return nil
}
}
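// datastoreConfig instantiates the configured storage engine ('memory', 'mysql',
// 'postgres', or 'sqlite') together with the continuation token serializer that
// matches it.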
func (s *ServerContext) datastoreConfig(config *serverconfig.Config) (storage.OpenFGADatastore, encoder.ContinuationTokenSerializer, error) {
// SQL Token Serializer by default
tokenSerializer := sqlcommon.NewSQLContinuationTokenSerializer()
datastoreOptions := []sqlcommon.DatastoreOption{
sqlcommon.WithSecondaryURI(config.Datastore.SecondaryURI),
sqlcommon.WithUsername(config.Datastore.Username),
sqlcommon.WithPassword(config.Datastore.Password),
sqlcommon.WithSecondaryUsername(config.Datastore.SecondaryUsername),
sqlcommon.WithSecondaryPassword(config.Datastore.SecondaryPassword),
sqlcommon.WithLogger(s.Logger),
sqlcommon.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite),
sqlcommon.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel),
sqlcommon.WithMaxOpenConns(config.Datastore.MaxOpenConns),
sqlcommon.WithMinOpenConns(config.Datastore.MinOpenConns),
sqlcommon.WithMaxIdleConns(config.Datastore.MaxIdleConns),
sqlcommon.WithMinIdleConns(config.Datastore.MinIdleConns),
sqlcommon.WithConnMaxIdleTime(config.Datastore.ConnMaxIdleTime),
sqlcommon.WithConnMaxLifetime(config.Datastore.ConnMaxLifetime),
}
if config.Datastore.Metrics.Enabled {
datastoreOptions = append(datastoreOptions, sqlcommon.WithMetrics())
}
dsCfg := sqlcommon.NewConfig(datastoreOptions...)
var datastore storage.OpenFGADatastore
var err error
switch config.Datastore.Engine {
case "memory":
// override for "memory" datastore
tokenSerializer = encoder.NewStringContinuationTokenSerializer()
opts := []memory.StorageOption{
memory.WithMaxTypesPerAuthorizationModel(config.MaxTypesPerAuthorizationModel),
memory.WithMaxTuplesPerWrite(config.MaxTuplesPerWrite),
}
datastore = memory.New(opts...)
case "mysql":
datastore, err = mysql.New(config.Datastore.URI, dsCfg)
if err != nil {
return nil, nil, fmt.Errorf("initialize mysql datastore: %w", err)
}
case "postgres":
datastore, err = postgres.New(config.Datastore.URI, dsCfg)
if err != nil {
return nil, nil, fmt.Errorf("initialize postgres datastore: %w", err)
}
case "sqlite":
datastore, err = sqlite.New(config.Datastore.URI, dsCfg)
if err != nil {
return nil, nil, fmt.Errorf("initialize sqlite datastore: %w", err)
}
default:
return nil, nil, fmt.Errorf("storage engine '%s' is unsupported", config.Datastore.Engine)
}
s.Logger.Info(fmt.Sprintf("using '%v' storage engine", config.Datastore.Engine))
return datastore, tokenSerializer, nil
}
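// authenticatorConfig builds the authenticator selected by authn.method: 'none',
// 'preshared', or 'oidc'.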
func (s *ServerContext) authenticatorConfig(config *serverconfig.Config) (authn.Authenticator, error) {
var authenticator authn.Authenticator
var err error
switch config.Authn.Method {
case "none":
s.Logger.Warn("authentication is disabled")
authenticator = authn.NoopAuthenticator{}
case "preshared":
s.Logger.Info("using 'preshared' authentication")
authenticator, err = presharedkey.NewPresharedKeyAuthenticator(config.Authn.Keys)
case "oidc":
s.Logger.Info("using 'oidc' authentication")
authenticator, err = oidc.NewRemoteOidcAuthenticator(config.Authn.Issuer, config.Authn.IssuerAliases, config.Authn.Audience, config.Authn.Subjects, config.Authn.ClientIDClaims)
default:
return nil, fmt.Errorf("unsupported authentication method '%v'", config.Authn.Method)
}
if err != nil {
return nil, fmt.Errorf("failed to initialize authenticator: %w", err)
}
return authenticator, nil
}
// Run returns an error if the server was unable to start successfully.
// If it started and terminated successfully, it returns a nil error.
func (s *ServerContext) Run(ctx context.Context, config *serverconfig.Config) error {
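// Derive a context that is cancelled when the process receives an interrupt or
// termination signal, so that shutdown can proceed gracefully.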
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, os.Kill, syscall.SIGTERM)
defer stop()
tracerProviderCloser := s.telemetryConfig(config)
if len(config.Experimentals) > 0 {
s.Logger.Info(fmt.Sprintf("🧪 experimental features enabled: %v", config.Experimentals))
}
var experimentals []string
experimentals = append(experimentals, config.Experimentals...)
datastore, continuationTokenSerializer, err := s.datastoreConfig(config)
if err != nil {
return err
}
authenticator, err := s.authenticatorConfig(config)
if err != nil {
return err
}
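// Assemble the gRPC server options. Interceptor order matters: panic recovery must
// run first, and the stream-wrapping interceptors must come last (see the inline
// comments below).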
serverOpts := []grpc.ServerOption{
grpc.MaxRecvMsgSize(serverconfig.DefaultMaxRPCMessageSizeInBytes),
grpc.ChainUnaryInterceptor(
[]grpc.UnaryServerInterceptor{
grpc_recovery.UnaryServerInterceptor( // panic middleware must be 1st in chain
grpc_recovery.WithRecoveryHandlerContext(
recovery.PanicRecoveryHandler(s.Logger),
),
),
grpc_ctxtags.UnaryServerInterceptor(), // needed for logging
requestid.NewUnaryInterceptor(), // add request_id to ctxtags
}...,
),
grpc.ChainStreamInterceptor(
[]grpc.StreamServerInterceptor{
grpc_recovery.StreamServerInterceptor( // panic middleware must be 1st in chain
grpc_recovery.WithRecoveryHandlerContext(
recovery.PanicRecoveryHandler(s.Logger),
),
),
grpc_ctxtags.StreamServerInterceptor(), // needed for logging
requestid.NewStreamingInterceptor(), // add request_id to ctxtags
}...,
),
}
if config.RequestTimeout > 0 {
timeoutMiddleware := middleware.NewTimeoutInterceptor(config.RequestTimeout, s.Logger)
serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(timeoutMiddleware.NewUnaryTimeoutInterceptor()))
serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(timeoutMiddleware.NewStreamTimeoutInterceptor()))
}
serverOpts = append(serverOpts,
grpc.ChainUnaryInterceptor(
[]grpc.UnaryServerInterceptor{
storeid.NewUnaryInterceptor(), // if available, add store_id to ctxtags
logging.NewLoggingInterceptor(s.Logger), // needed to log invalid requests
validator.UnaryServerInterceptor(),
}...,
),
grpc.ChainStreamInterceptor(
[]grpc.StreamServerInterceptor{
validator.StreamServerInterceptor(),
}...,
),
)
if config.Metrics.Enabled {
serverOpts = append(serverOpts,
grpc.ChainUnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
grpc.ChainStreamInterceptor(grpc_prometheus.StreamServerInterceptor))
if config.Metrics.EnableRPCHistograms {
grpc_prometheus.EnableHandlingTimeHistogram()
}
}
if config.Trace.Enabled {
serverOpts = append(serverOpts, grpc.StatsHandler(otelgrpc.NewServerHandler()))
}
serverOpts = append(serverOpts, grpc.ChainUnaryInterceptor(
[]grpc.UnaryServerInterceptor{
grpcauth.UnaryServerInterceptor(authnmw.AuthFunc(authenticator)),
}...),
grpc.ChainStreamInterceptor(
[]grpc.StreamServerInterceptor{
grpcauth.StreamServerInterceptor(authnmw.AuthFunc(authenticator)),
// The following interceptors wrap the server stream with our own
// wrapper and must come last.
storeid.NewStreamingInterceptor(),
logging.NewStreamingLoggingInterceptor(s.Logger),
}...,
),
)
if config.GRPC.TLS.Enabled {
if config.GRPC.TLS.CertPath == "" || config.GRPC.TLS.KeyPath == "" {
return errors.New("'grpc.tls.cert' and 'grpc.tls.key' configs must be set")
}
grpcGetCertificate, err := watchAndLoadCertificateWithCertWatcher(ctx, config.GRPC.TLS.CertPath, config.GRPC.TLS.KeyPath, s.Logger)
if err != nil {
return err
}
creds := credentials.NewTLS(&tls.Config{
GetCertificate: grpcGetCertificate,
})
serverOpts = append(serverOpts, grpc.Creds(creds))
s.Logger.Info("gRPC TLS is enabled, serving connections using the provided certificate")
} else {
s.Logger.Warn("gRPC TLS is disabled, serving connections using insecure plaintext")
}
var profilerServer *http.Server
if config.Profiler.Enabled {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
profilerServer = &http.Server{Addr: config.Profiler.Addr, Handler: mux}
go func() {
s.Logger.Info(fmt.Sprintf("🔬 starting pprof profiler on '%s'", config.Profiler.Addr))
if err := profilerServer.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
s.Logger.Fatal("failed to start pprof profiler", zap.Error(err))
}
}
s.Logger.Info("profiler shut down.")
}()
}
var metricsServer *http.Server
if config.Metrics.Enabled {
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
metricsServer = &http.Server{Addr: config.Metrics.Addr, Handler: mux}
go func() {
s.Logger.Info(fmt.Sprintf("📈 starting prometheus metrics server on '%s'", config.Metrics.Addr))
if err := metricsServer.ListenAndServe(); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
s.Logger.Fatal("failed to start prometheus metrics server", zap.Error(err))
}
}
s.Logger.Info("metrics server shut down.")
}()
}
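// Construct the OpenFGA server itself, threading every relevant config value through
// functional options.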
svr := server.MustNewServerWithOpts(
server.WithDatastore(datastore),
server.WithContinuationTokenSerializer(continuationTokenSerializer),
server.WithAuthorizationModelCacheSize(config.Datastore.MaxCacheSize),
server.WithLogger(s.Logger),
server.WithTransport(gateway.NewRPCTransport(s.Logger)),
server.WithResolveNodeLimit(config.ResolveNodeLimit),
server.WithResolveNodeBreadthLimit(config.ResolveNodeBreadthLimit),
server.WithChangelogHorizonOffset(config.ChangelogHorizonOffset),
server.WithListObjectsDeadline(config.ListObjectsDeadline),
server.WithListObjectsMaxResults(config.ListObjectsMaxResults),
server.WithListUsersDeadline(config.ListUsersDeadline),
server.WithListUsersMaxResults(config.ListUsersMaxResults),
server.WithMaxConcurrentReadsForListObjects(config.MaxConcurrentReadsForListObjects),
server.WithMaxConcurrentReadsForCheck(config.MaxConcurrentReadsForCheck),
server.WithMaxConcurrentReadsForListUsers(config.MaxConcurrentReadsForListUsers),
server.WithCacheControllerEnabled(config.CacheController.Enabled),
server.WithCacheControllerTTL(config.CacheController.TTL),
server.WithCheckCacheLimit(config.CheckCache.Limit),
server.WithCheckIteratorCacheEnabled(config.CheckIteratorCache.Enabled),
server.WithCheckIteratorCacheMaxResults(config.CheckIteratorCache.MaxResults),
server.WithCheckIteratorCacheTTL(config.CheckIteratorCache.TTL),
server.WithCheckQueryCacheEnabled(config.CheckQueryCache.Enabled),
server.WithCheckQueryCacheTTL(config.CheckQueryCache.TTL),
server.WithRequestDurationByQueryHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDatastoreQueryCountBuckets)),
server.WithRequestDurationByDispatchCountHistogramBuckets(convertStringArrayToUintArray(config.RequestDurationDispatchCountBuckets)),
server.WithMaxAuthorizationModelSizeInBytes(config.MaxAuthorizationModelSizeInBytes),
server.WithContextPropagationToDatastore(config.ContextPropagationToDatastore),
server.WithDispatchThrottlingCheckResolverEnabled(config.CheckDispatchThrottling.Enabled),
server.WithDispatchThrottlingCheckResolverFrequency(config.CheckDispatchThrottling.Frequency),
server.WithDispatchThrottlingCheckResolverThreshold(config.CheckDispatchThrottling.Threshold),
server.WithDispatchThrottlingCheckResolverMaxThreshold(config.CheckDispatchThrottling.MaxThreshold),
server.WithListObjectsDispatchThrottlingEnabled(config.ListObjectsDispatchThrottling.Enabled),
server.WithListObjectsDispatchThrottlingFrequency(config.ListObjectsDispatchThrottling.Frequency),
server.WithListObjectsDispatchThrottlingThreshold(config.ListObjectsDispatchThrottling.Threshold),
server.WithListObjectsDispatchThrottlingMaxThreshold(config.ListObjectsDispatchThrottling.MaxThreshold),
server.WithListUsersDispatchThrottlingEnabled(config.ListUsersDispatchThrottling.Enabled),
server.WithListUsersDispatchThrottlingFrequency(config.ListUsersDispatchThrottling.Frequency),
server.WithListUsersDispatchThrottlingThreshold(config.ListUsersDispatchThrottling.Threshold),
server.WithListUsersDispatchThrottlingMaxThreshold(config.ListUsersDispatchThrottling.MaxThreshold),
server.WithCheckDatabaseThrottle(config.CheckDatastoreThrottle.Threshold, config.CheckDatastoreThrottle.Duration),
server.WithListObjectsDatabaseThrottle(config.ListObjectsDatastoreThrottle.Threshold, config.ListObjectsDatastoreThrottle.Duration),
server.WithListUsersDatabaseThrottle(config.ListUsersDatastoreThrottle.Threshold, config.ListUsersDatastoreThrottle.Duration),
server.WithListObjectsIteratorCacheEnabled(config.ListObjectsIteratorCache.Enabled),
server.WithListObjectsIteratorCacheMaxResults(config.ListObjectsIteratorCache.MaxResults),
server.WithListObjectsIteratorCacheTTL(config.ListObjectsIteratorCache.TTL),
server.WithMaxChecksPerBatchCheck(config.MaxChecksPerBatchCheck),
server.WithMaxConcurrentChecksPerBatchCheck(config.MaxConcurrentChecksPerBatchCheck),
server.WithSharedIteratorEnabled(config.SharedIterator.Enabled),
server.WithSharedIteratorLimit(config.SharedIterator.Limit),
server.WithPlanner(planner.New(&planner.Config{
EvictionThreshold: config.Planner.EvictionThreshold,
CleanupInterval: config.Planner.CleanupInterval,
})),
// The shared iterator watchdog timeout is set to config.RequestTimeout + 2 seconds
// to provide a small buffer for operations that might slightly exceed the request timeout.
server.WithSharedIteratorTTL(config.RequestTimeout+2*time.Second),
server.WithExperimentals(experimentals...),
server.WithAccessControlParams(config.AccessControl.Enabled, config.AccessControl.StoreID, config.AccessControl.ModelID, config.Authn.Method),
server.WithContext(ctx),
)
s.Logger.Info(
"starting openfga service...",
zap.String("version", build.Version),
zap.String("date", build.Date),
zap.String("commit", build.Commit),
zap.String("go-version", goruntime.Version()),
zap.Any("config", config),
)
// nosemgrep: grpc-server-insecure-connection
grpcServer := grpc.NewServer(serverOpts...)
openfgav1.RegisterOpenFGAServiceServer(grpcServer, svr)
healthServer := &health.Checker{TargetService: svr, TargetServiceName: openfgav1.OpenFGAService_ServiceDesc.ServiceName}
healthv1pb.RegisterHealthServer(grpcServer, healthServer)
reflection.Register(grpcServer)
lis, err := net.Listen("tcp", config.GRPC.Addr)
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}
go func() {
s.Logger.Info(fmt.Sprintf("🚀 starting gRPC server on '%s'...", lis.Addr().String()))
if err := grpcServer.Serve(lis); err != nil {
if !errors.Is(err, grpc.ErrServerStopped) {
s.Logger.Fatal("failed to start gRPC server", zap.Error(err))
}
}
s.Logger.Info("gRPC server shut down.")
}()
var httpServer *http.Server
if config.HTTP.Enabled {
runtime.DefaultContextTimeout = serverconfig.DefaultContextTimeout(config)
dialOpts := []grpc.DialOption{
// nolint:staticcheck // ignoring gRPC deprecations
grpc.WithBlock(),
}
if config.Trace.Enabled {
dialOpts = append(dialOpts, grpc.WithStatsHandler(otelgrpc.NewClientHandler()))
}
if config.GRPC.TLS.Enabled {
creds, err := credentials.NewClientTLSFromFile(config.GRPC.TLS.CertPath, "")
if err != nil {
s.Logger.Fatal("failed to load gRPC credentials", zap.Error(err))
}
dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds))
} else {
dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// nolint:staticcheck // ignoring gRPC deprecations
conn, err := grpc.DialContext(timeoutCtx, config.GRPC.Addr, dialOpts...)
if err != nil {
s.Logger.Fatal("failed to connect to gRPC server", zap.Error(err))
}
defer conn.Close()
muxOpts := []runtime.ServeMuxOption{
runtime.WithForwardResponseOption(httpmiddleware.HTTPResponseModifier),
runtime.WithErrorHandler(func(c context.Context, sr *runtime.ServeMux, mm runtime.Marshaler, w http.ResponseWriter, r *http.Request, e error) {
intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e))
httpmiddleware.CustomHTTPErrorHandler(c, w, r, serverErrors.NewEncodedError(intCode, e.Error()))
}),
runtime.WithStreamErrorHandler(func(ctx context.Context, e error) *status.Status {
intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(e))
encodedErr := serverErrors.NewEncodedError(intCode, e.Error())
return status.Convert(encodedErr)
}),
runtime.WithHealthzEndpoint(healthv1pb.NewHealthClient(conn)),
runtime.WithOutgoingHeaderMatcher(func(s string) (string, bool) { return s, true }),
}
mux := runtime.NewServeMux(muxOpts...)
if err := openfgav1.RegisterOpenFGAServiceHandler(ctx, mux, conn); err != nil {
return err
}
handler := http.Handler(mux)
if config.Trace.Enabled {
handler = otelhttp.NewHandler(handler, "grpc-gateway")
}
httpServer = &http.Server{
Addr: config.HTTP.Addr,
Handler: recovery.HTTPPanicRecoveryHandler(cors.New(cors.Options{
AllowedOrigins: config.HTTP.CORSAllowedOrigins,
AllowCredentials: true,
AllowedHeaders: config.HTTP.CORSAllowedHeaders,
AllowedMethods: []string{http.MethodGet, http.MethodPost,
http.MethodHead, http.MethodPatch, http.MethodDelete, http.MethodPut},
}).Handler(handler), s.Logger),
}
listener, err := net.Listen("tcp", config.HTTP.Addr)
if err != nil {
return err
}
if config.HTTP.TLS.Enabled {
if config.HTTP.TLS.CertPath == "" || config.HTTP.TLS.KeyPath == "" {
s.Logger.Fatal("'http.tls.cert' and 'http.tls.key' configs must be set")
}
httpGetCertificate, err := watchAndLoadCertificateWithCertWatcher(ctx, config.HTTP.TLS.CertPath, config.HTTP.TLS.KeyPath, s.Logger)
if err != nil {
return err
}
listener = tls.NewListener(listener, &tls.Config{
GetCertificate: httpGetCertificate,
})
s.Logger.Info("HTTP TLS is enabled, serving connections using the provided certificate")
} else {
s.Logger.Warn("HTTP TLS is disabled, serving connections using insecure plaintext")
}
go func() {
s.Logger.Info(fmt.Sprintf("🚀 starting HTTP server on '%s'...", httpServer.Addr))
if err := httpServer.Serve(listener); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
s.Logger.Fatal("HTTP server closed with unexpected error", zap.Error(err))
}
}
s.Logger.Info("HTTP server shut down.")
}()
}
var playground *http.Server
if config.Playground.Enabled {
if !config.HTTP.Enabled {
return errors.New("the HTTP server must be enabled to run the openfga playground")
}
authMethod := config.Authn.Method
if authMethod != "none" && authMethod != "preshared" {
return errors.New("the playground only supports authn methods 'none' and 'preshared'")
}
playgroundAddr := fmt.Sprintf(":%d", config.Playground.Port)
s.Logger.Info(fmt.Sprintf("🛝 starting openfga playground on http://localhost%s/playground", playgroundAddr))
tmpl, err := template.ParseFS(assets.EmbedPlayground, "playground/index.html")
if err != nil {
return fmt.Errorf("failed to parse playground index.html as Go template: %w", err)
}
fileServer := http.FileServer(http.FS(assets.EmbedPlayground))
policy := backoff.NewExponentialBackOff()
policy.MaxElapsedTime = 3 * time.Second
var conn net.Conn
err = backoff.Retry(
func() error {
conn, err = net.Dial("tcp", config.HTTP.Addr)
return err
},
policy,
)
if err != nil {
return fmt.Errorf("failed to establish playground connection to HTTP server: %w", err)
}
playgroundAPIToken := ""
if authMethod == "preshared" {
playgroundAPIToken = config.Authn.Keys[0]
}
mux := http.NewServeMux()
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/playground") {
if r.URL.Path == "/playground" || r.URL.Path == "/playground/index.html" {
err := tmpl.Execute(w, struct {
HTTPServerURL string
PlaygroundAPIToken string
}{
HTTPServerURL: conn.RemoteAddr().String(),
PlaygroundAPIToken: playgroundAPIToken,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
s.Logger.Error("failed to execute/render the playground web template", zap.Error(err))
}
return
}
fileServer.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
}))
playground = &http.Server{Addr: playgroundAddr, Handler: mux}
go func() {
if err := playground.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
s.Logger.Fatal("failed to start the openfga playground server", zap.Error(err))
}
s.Logger.Info("playground shut down.")
}()
}
// wait for cancellation signal
<-ctx.Done()
s.Logger.Info("attempting to shutdown gracefully...")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if playground != nil {
if err := playground.Shutdown(ctx); err != nil {
s.Logger.Info("failed to shutdown the playground", zap.Error(err))
}
}
if httpServer != nil {
if err := httpServer.Shutdown(ctx); err != nil {
s.Logger.Info("failed to shutdown the http server", zap.Error(err))
}
}
if profilerServer != nil {
if err := profilerServer.Shutdown(ctx); err != nil {
s.Logger.Info("failed to shutdown the profiler", zap.Error(err))
}
}
if metricsServer != nil {
if err := metricsServer.Shutdown(ctx); err != nil {
s.Logger.Info("failed to shutdown the prometheus metrics server", zap.Error(err))
}
}
grpcServer.GracefulStop()
svr.Close()
authenticator.Close()
if err := tracerProviderCloser(); err != nil {
s.Logger.Error("failed to shutdown tracing", zap.Error(err))
}
s.Logger.Info("server exited. goodbye 👋")
return nil
}
func watchAndLoadCertificateWithCertWatcher(ctx context.Context, certPath, keyPath string, logger logger.Logger) (func(*tls.ClientHelloInfo) (*tls.Certificate, error), error) {
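// Assumption: certwatcher logs through controller-runtime's package-global logger, and
// logr.New(nil) yields a discard logger (equivalent to logr.Discard()), so this keeps it quiet.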
log.SetLogger(logr.New(nil))
// Create a certificate watcher
watcher, err := certwatcher.New(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("failed to create certwatcher: %w", err)
}
// Load the initial certificate
if err := watcher.ReadCertificate(); err != nil {
return nil, fmt.Errorf("failed to load initial certificate: %w", err)
}
logger.Info("Initial TLS certificate loaded.", zap.String("certPath", certPath), zap.String("keyPath", keyPath))
// Start watching for certificate changes
go func() {
logger.Info("Starting certificate watcher...", zap.String("certPath", certPath), zap.String("keyPath", keyPath))
if err := watcher.Start(ctx); err != nil {
logger.Error("Certwatcher encountered an error", zap.Error(err))
}
}()
// Return a function that retrieves the updated certificate
getCertificate := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return watcher.GetCertificate(hello)
}
return getCertificate, nil
}
// Package util provides common helpers for the spf13/cobra CLI commands
// used throughout this project.
package util
import (
"os"
"path/filepath"
"testing"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/memory"
"github.com/openfga/openfga/pkg/storage/mysql"
"github.com/openfga/openfga/pkg/storage/postgres"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
"github.com/openfga/openfga/pkg/storage/sqlite"
storagefixtures "github.com/openfga/openfga/pkg/testfixtures/storage"
)
// MustBindPFlag attempts to bind a specific key to a pflag (as used by cobra) and panics
// if the binding fails with a non-nil error.
func MustBindPFlag(key string, flag *pflag.Flag) {
if err := viper.BindPFlag(key, flag); err != nil {
panic("failed to bind pflag: " + err.Error())
}
}
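// MustBindEnv binds the given viper key (the first argument) to one or more environment
// variable names and panics if the binding fails.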
func MustBindEnv(input ...string) {
if err := viper.BindEnv(input...); err != nil {
panic("failed to bind env key: " + err.Error())
}
}
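// Contains reports whether v is present in s.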
func Contains[E comparable](s []E, v E) bool {
return Index(s, v) >= 0
}
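// Index returns the index of the first occurrence of v in s, or -1 if v is not present.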
func Index[E comparable](s []E, v E) int {
for i, vs := range s {
if v == vs {
return i
}
}
return -1
}
// MustBootstrapDatastore returns the datastore's container, the datastore, and the URI to connect to it.
// It automatically cleans up the container after the test finishes.
func MustBootstrapDatastore(t testing.TB, engine string) (storagefixtures.DatastoreTestContainer, storage.OpenFGADatastore, string) {
container := storagefixtures.RunDatastoreTestContainer(t, engine)
uri := container.GetConnectionURI(true)
var ds storage.OpenFGADatastore
var err error
cfg := sqlcommon.NewConfig()
switch engine {
case "memory":
ds = memory.New()
case "postgres":
ds, err = postgres.New(uri, cfg)
case "mysql":
ds, err = mysql.New(uri, cfg)
case "sqlite":
ds, err = sqlite.New(uri, cfg)
default:
t.Fatalf("unsupported datastore engine: %q", engine)
}
require.NoError(t, err)
t.Cleanup(ds.Close)
return container, ds, uri
}
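// A minimal test-usage sketch (the engine value is illustrative):
//
//	func TestMyDatastore(t *testing.T) {
//		_, ds, uri := MustBootstrapDatastore(t, "postgres")
//		_ = uri // connection URI of the spawned test container
//		_ = ds  // ready-to-use datastore; closed automatically via t.Cleanup
//	}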
func PrepareTempConfigDir(t *testing.T) string {
_, err := os.Stat("/etc/openfga/config.yaml")
require.ErrorIs(t, err, os.ErrNotExist, "Config file at /etc/openfga/config.yaml would disturb test result.")
homedir := t.TempDir()
t.Setenv("HOME", homedir)
confdir := filepath.Join(homedir, ".openfga")
require.NoError(t, os.Mkdir(confdir, 0750))
return confdir
}
func PrepareTempConfigFile(t *testing.T, config string) {
confdir := PrepareTempConfigDir(t)
confFile, err := os.Create(filepath.Join(confdir, "config.yaml"))
require.NoError(t, err)
_, err = confFile.WriteString(config)
require.NoError(t, err)
require.NoError(t, confFile.Close())
}
package authn
import (
"context"
"github.com/MicahParks/keyfunc/v2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/authclaims"
)
var (
ErrUnauthenticated = status.Error(codes.Code(openfgav1.AuthErrorCode_unauthenticated), "unauthenticated")
ErrMissingBearerToken = status.Error(codes.Code(openfgav1.AuthErrorCode_bearer_token_missing), "missing bearer token")
)
type Authenticator interface {
// Authenticate returns the AuthClaims info (if available) and a nil error if the subject is
// authenticated, or a non-nil error with an appropriate cause otherwise.
Authenticate(requestContext context.Context) (*authclaims.AuthClaims, error)
// Close cleans up the authenticator.
Close()
}
type NoopAuthenticator struct{}
var _ Authenticator = (*NoopAuthenticator)(nil)
func (n NoopAuthenticator) Authenticate(requestContext context.Context) (*authclaims.AuthClaims, error) {
return &authclaims.AuthClaims{
Subject: "",
Scopes: nil,
}, nil
}
func (n NoopAuthenticator) Close() {}
// OidcConfig contains authorization server metadata. See https://datatracker.ietf.org/doc/html/rfc8414#section-2
type OidcConfig struct {
Issuer string `json:"issuer"`
JWKsURI string `json:"jwks_uri"`
}
type OIDCAuthenticator interface {
GetConfiguration() (*OidcConfig, error)
GetKeys() (*keyfunc.JWKS, error)
}
package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"slices"
"strings"
"time"
"github.com/MicahParks/keyfunc/v2"
jwt "github.com/golang-jwt/jwt/v5"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/hashicorp/go-retryablehttp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/authn"
"github.com/openfga/openfga/pkg/authclaims"
)
type RemoteOidcAuthenticator struct {
MainIssuer string
IssuerAliases []string
Audience string
Subjects []string
ClientIDClaims []string
JwksURI string
JWKs *keyfunc.JWKS
httpClient *http.Client
}
var (
jwkRefreshInterval = 48 * time.Hour
errInvalidClaims = status.Error(codes.Code(openfgav1.AuthErrorCode_invalid_claims), "invalid claims")
fetchJWKs = fetchJWK
)
var _ authn.Authenticator = (*RemoteOidcAuthenticator)(nil)
var _ authn.OIDCAuthenticator = (*RemoteOidcAuthenticator)(nil)
func NewRemoteOidcAuthenticator(mainIssuer string, issuerAliases []string, audience string, subjects []string, clientIDClaims []string) (*RemoteOidcAuthenticator, error) {
client := retryablehttp.NewClient()
client.Logger = nil
oidc := &RemoteOidcAuthenticator{
MainIssuer: mainIssuer,
IssuerAliases: issuerAliases,
Audience: audience,
Subjects: subjects,
httpClient: client.StandardClient(),
ClientIDClaims: clientIDClaims,
}
// The client ID claims are resolved as follows:
// 1. If the user has set them in configuration, use those.
// 2. Otherwise, use the following defaults, in order:
//    a. `azp`: the OpenID standard, https://openid.net/specs/openid-connect-core-1_0.html#IDToken
//    b. `client_id`: RFC 9068, https://www.rfc-editor.org/rfc/rfc9068.html#name-data-structure
if len(oidc.ClientIDClaims) == 0 {
oidc.ClientIDClaims = []string{"azp", "client_id"}
}
err := fetchJWKs(oidc)
if err != nil {
return nil, err
}
return oidc, nil
}
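// A hypothetical construction sketch (the issuer and audience values are illustrative; note
// that the constructor eagerly fetches the issuer's JWKS, so the issuer must be reachable):
//
//	authenticator, err := NewRemoteOidcAuthenticator(
//		"https://issuer.example.com", // main issuer
//		nil,                          // no issuer aliases
//		"https://api.example.com",    // expected audience ("" skips the audience check)
//		nil,                          // no subject allow-list
//		nil,                          // default client ID claims: "azp", then "client_id"
//	)
//	if err != nil {
//		return err
//	}
//	defer authenticator.Close()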
func (oidc *RemoteOidcAuthenticator) Authenticate(requestContext context.Context) (*authclaims.AuthClaims, error) {
authHeader, err := grpcauth.AuthFromMD(requestContext, "Bearer")
if err != nil {
return nil, authn.ErrMissingBearerToken
}
options := []jwt.ParserOption{
jwt.WithValidMethods([]string{"RS256"}),
jwt.WithIssuedAt(),
jwt.WithExpirationRequired(),
}
if strings.TrimSpace(oidc.Audience) != "" {
options = append(options, jwt.WithAudience(oidc.Audience))
}
jwtParser := jwt.NewParser(options...)
token, err := jwtParser.Parse(authHeader, func(token *jwt.Token) (any, error) {
return oidc.JWKs.Keyfunc(token)
})
if err != nil || !token.Valid {
return nil, errInvalidClaims
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errInvalidClaims
}
validIssuers := []string{
oidc.MainIssuer,
}
validIssuers = append(validIssuers, oidc.IssuerAliases...)
ok = slices.ContainsFunc(validIssuers, func(issuer string) bool {
v := jwt.NewValidator(jwt.WithIssuer(issuer))
err := v.Validate(claims)
return err == nil
})
if !ok {
return nil, errInvalidClaims
}
if len(oidc.Subjects) > 0 {
ok = slices.ContainsFunc(oidc.Subjects, func(subject string) bool {
v := jwt.NewValidator(jwt.WithSubject(subject))
err := v.Validate(claims)
return err == nil
})
if !ok {
return nil, errInvalidClaims
}
}
// optional subject
subject := ""
if subjectClaim, ok := claims["sub"]; ok {
if subject, ok = subjectClaim.(string); !ok {
return nil, errInvalidClaims
}
}
clientID := ""
for _, claimString := range oidc.ClientIDClaims {
clientID, ok = claims[claimString].(string)
if ok {
break
}
}
principal := &authclaims.AuthClaims{
Subject: subject,
Scopes: make(map[string]bool),
ClientID: clientID,
}
// optional scopes
if scopeKey, ok := claims["scope"]; ok {
if scope, ok := scopeKey.(string); ok {
scopes := strings.Split(scope, " ")
for _, s := range scopes {
principal.Scopes[s] = true
}
}
}
return principal, nil
}
func fetchJWK(oidc *RemoteOidcAuthenticator) error {
oidcConfig, err := oidc.GetConfiguration()
if err != nil {
return fmt.Errorf("error fetching OIDC configuration: %w", err)
}
oidc.JwksURI = oidcConfig.JWKsURI
jwks, err := oidc.GetKeys()
if err != nil {
return fmt.Errorf("error fetching OIDC keys: %w", err)
}
oidc.JWKs = jwks
return nil
}
func (oidc *RemoteOidcAuthenticator) GetKeys() (*keyfunc.JWKS, error) {
jwks, err := keyfunc.Get(oidc.JwksURI, keyfunc.Options{
Client: oidc.httpClient,
RefreshInterval: jwkRefreshInterval,
})
if err != nil {
return nil, fmt.Errorf("error fetching keys from %v: %w", oidc.JwksURI, err)
}
return jwks, nil
}
func (oidc *RemoteOidcAuthenticator) GetConfiguration() (*authn.OidcConfig, error) {
wellKnown := strings.TrimSuffix(oidc.MainIssuer, "/") + "/.well-known/openid-configuration"
req, err := http.NewRequest("GET", wellKnown, nil)
if err != nil {
return nil, fmt.Errorf("error forming request to get OIDC: %w", err)
}
res, err := oidc.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error getting OIDC: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code getting OIDC: %v", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
oidcConfig := &authn.OidcConfig{}
if err := json.Unmarshal(body, oidcConfig); err != nil {
return nil, fmt.Errorf("failed parsing document: %w", err)
}
if oidcConfig.Issuer == "" {
return nil, errors.New("missing issuer value")
}
if oidcConfig.JWKsURI == "" {
return nil, errors.New("missing jwks_uri value")
}
return oidcConfig, nil
}
func (oidc *RemoteOidcAuthenticator) Close() {
oidc.JWKs.EndBackground()
}
package presharedkey
import (
"context"
"errors"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/openfga/openfga/internal/authn"
"github.com/openfga/openfga/pkg/authclaims"
)
type PresharedKeyAuthenticator struct {
ValidKeys map[string]struct{}
}
var _ authn.Authenticator = (*PresharedKeyAuthenticator)(nil)
func NewPresharedKeyAuthenticator(validKeys []string) (*PresharedKeyAuthenticator, error) {
if len(validKeys) < 1 {
return nil, errors.New("invalid auth configuration, please specify at least one key")
}
vKeys := make(map[string]struct{})
for _, k := range validKeys {
vKeys[k] = struct{}{}
}
return &PresharedKeyAuthenticator{ValidKeys: vKeys}, nil
}
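// A minimal usage sketch (the keys are illustrative; the request context must carry a
// "Bearer <key>" authorization header for Authenticate to succeed):
//
//	authenticator, err := NewPresharedKeyAuthenticator([]string{"key1", "key2"})
//	if err != nil {
//		return err
//	}
//	claims, err := authenticator.Authenticate(ctx)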
func (pka *PresharedKeyAuthenticator) Authenticate(ctx context.Context) (*authclaims.AuthClaims, error) {
authHeader, err := grpcauth.AuthFromMD(ctx, "Bearer")
if err != nil {
return nil, authn.ErrMissingBearerToken
}
if _, found := pka.ValidKeys[authHeader]; found {
return &authclaims.AuthClaims{
Subject: "", // no user information in this auth method
}, nil
}
return nil, authn.ErrUnauthenticated
}
func (pka *PresharedKeyAuthenticator) Close() {}
package authz
import (
"context"
"fmt"
"strings"
"sync"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
parser "github.com/openfga/language/pkg/go/utils"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/authclaims"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const (
accessControlKey = "access_control"
// MaxModulesInRequest is the maximum number of modules a user may write to in a single request
// if they do not have write permissions to the store.
MaxModulesInRequest = 1
// Relations.
CanCallReadAuthorizationModels = "can_call_read_authorization_models"
CanCallRead = "can_call_read"
CanCallWrite = "can_call_write"
CanCallListObjects = "can_call_list_objects"
CanCallCheck = "can_call_check"
CanCallListUsers = "can_call_list_users"
CanCallWriteAssertions = "can_call_write_assertions"
CanCallReadAssertions = "can_call_read_assertions"
CanCallWriteAuthorizationModels = "can_call_write_authorization_models"
CanCallListStores = "can_call_list_stores"
CanCallCreateStore = "can_call_create_stores"
CanCallGetStore = "can_call_get_store"
CanCallDeleteStore = "can_call_delete_store"
CanCallExpand = "can_call_expand"
CanCallReadChanges = "can_call_read_changes"
StoreType = "store"
ModuleType = "module"
ApplicationType = "application"
SystemType = "system"
SystemRelationOnStore = "system"
RootSystemID = "fga"
)
var (
ErrUnauthorizedResponse = status.Error(codes.Code(openfgav1.AuthErrorCode_forbidden), "the principal is not authorized to perform the action")
SystemObjectID = fmt.Sprintf("%s:%s", SystemType, RootSystemID)
tracer = otel.Tracer("internal/authz")
)
type StoreIDType string
func (s StoreIDType) String() string {
return fmt.Sprintf("%s:%s", StoreType, string(s))
}
type ClientIDType string
func (c ClientIDType) String() string {
return fmt.Sprintf("%s:%s", ApplicationType, string(c))
}
type ModuleIDType string
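// String returns the module object identifier for the given module name, in the form
// "module:<store ID>|<module>", e.g. "module:01STORE|billing" (IDs illustrative).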
func (m ModuleIDType) String(module string) string {
return fmt.Sprintf(`%s:%s|%s`, ModuleType, string(m), module)
}
type Config struct {
StoreID string
ModelID string
}
type AuthorizerInterface interface {
Authorize(ctx context.Context, storeID string, apiMethod apimethod.APIMethod, modules ...string) error
AuthorizeCreateStore(ctx context.Context) error
AuthorizeListStores(ctx context.Context) error
ListAuthorizedStores(ctx context.Context) ([]string, error)
GetModulesForWriteRequest(ctx context.Context, req *openfgav1.WriteRequest, typesys *typesystem.TypeSystem) ([]string, error)
AccessControlStoreID() string
}
type NoopAuthorizer struct{}
func NewAuthorizerNoop() *NoopAuthorizer {
return &NoopAuthorizer{}
}
func (a *NoopAuthorizer) Authorize(ctx context.Context, storeID string, apiMethod apimethod.APIMethod, modules ...string) error {
return nil
}
func (a *NoopAuthorizer) AuthorizeCreateStore(ctx context.Context) error {
return nil
}
func (a *NoopAuthorizer) AuthorizeListStores(ctx context.Context) error {
return nil
}
func (a *NoopAuthorizer) ListAuthorizedStores(ctx context.Context) ([]string, error) {
return nil, nil
}
func (a *NoopAuthorizer) GetModulesForWriteRequest(ctx context.Context, req *openfgav1.WriteRequest, typesys *typesystem.TypeSystem) ([]string, error) {
return nil, nil
}
func (a *NoopAuthorizer) AccessControlStoreID() string {
return ""
}
type Authorizer struct {
config *Config
server ServerInterface
logger logger.Logger
}
func NewAuthorizer(config *Config, server ServerInterface, logger logger.Logger) *Authorizer {
return &Authorizer{
config: config,
server: server,
logger: logger,
}
}
type authorizationError struct {
Cause string
}
func (e *authorizationError) Error() string {
return e.Cause
}
func (a *Authorizer) getRelation(apiMethod apimethod.APIMethod) (string, error) {
// TODO: Add a golangci-linter rule to ensure all the possible cases are handled.
switch apiMethod {
case apimethod.ReadAuthorizationModel, apimethod.ReadAuthorizationModels:
return CanCallReadAuthorizationModels, nil
case apimethod.Read:
return CanCallRead, nil
case apimethod.Write:
return CanCallWrite, nil
case apimethod.ListObjects, apimethod.StreamedListObjects:
return CanCallListObjects, nil
case apimethod.Check, apimethod.BatchCheck:
return CanCallCheck, nil
case apimethod.ListUsers:
return CanCallListUsers, nil
case apimethod.WriteAssertions:
return CanCallWriteAssertions, nil
case apimethod.ReadAssertions:
return CanCallReadAssertions, nil
case apimethod.WriteAuthorizationModel:
return CanCallWriteAuthorizationModels, nil
case apimethod.ListStores:
return CanCallListStores, nil
case apimethod.CreateStore:
return CanCallCreateStore, nil
case apimethod.GetStore:
return CanCallGetStore, nil
case apimethod.DeleteStore:
return CanCallDeleteStore, nil
case apimethod.Expand:
return CanCallExpand, nil
case apimethod.ReadChanges:
return CanCallReadChanges, nil
default:
return "", fmt.Errorf("unknown API method: %s", apiMethod)
}
}
func (a *Authorizer) AccessControlStoreID() string {
if a.config != nil {
return a.config.StoreID
}
return ""
}
// Authorize checks if the user has access to the resource.
func (a *Authorizer) Authorize(ctx context.Context, storeID string, apiMethod apimethod.APIMethod, modules ...string) error {
methodName := "Authorize"
ctx, span := tracer.Start(ctx, methodName, trace.WithAttributes(
attribute.String("storeID", storeID),
attribute.String("apiMethod", apiMethod.String()),
attribute.String("modules", strings.Join(modules, ",")),
))
defer span.End()
grpc_ctxtags.Extract(ctx).Set(accessControlKey, methodName)
claims, err := checkAuthClaims(ctx)
if err != nil {
return err
}
relation, err := a.getRelation(apiMethod)
if err != nil {
return &authorizationError{Cause: fmt.Sprintf("error getting relation: %v", err)}
}
contextualTuples := openfgav1.ContextualTupleKeys{
TupleKeys: []*openfgav1.TupleKey{
getSystemAccessTuple(storeID),
},
}
// Check if there is top-level authorization first, before checking modules
err = a.individualAuthorize(ctx, claims.ClientID, relation, StoreIDType(storeID).String(), &contextualTuples)
if err == nil {
return nil
}
if len(modules) > 0 {
// If there is no top-level authorization but the module limit is exceeded, return an error about
// that limit; having a limit ensures we do not run too many checks per write when modules are used.
if len(modules) > MaxModulesInRequest {
return &authorizationError{Cause: fmt.Sprintf("the principal cannot write tuples of more than %v module(s) in a single request (modules in request: %v)", MaxModulesInRequest, len(modules))}
}
return a.moduleAuthorize(ctx, claims.ClientID, relation, storeID, modules)
}
// If there are no modules to check, return the top-level authorization error
return err
}
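// A hypothetical caller sketch for a Write request (the wiring is illustrative, not the
// server's exact flow):
//
//	modules, err := authorizer.GetModulesForWriteRequest(ctx, writeReq, typesys)
//	if err != nil {
//		return err
//	}
//	if err := authorizer.Authorize(ctx, storeID, apimethod.Write, modules...); err != nil {
//		return ErrUnauthorizedResponse
//	}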
// AuthorizeCreateStore checks if the user has access to create a store.
func (a *Authorizer) AuthorizeCreateStore(ctx context.Context) error {
methodName := "AuthorizeCreateStore"
ctx, span := tracer.Start(ctx, methodName)
defer span.End()
grpc_ctxtags.Extract(ctx).Set(accessControlKey, methodName)
claims, err := checkAuthClaims(ctx)
if err != nil {
return err
}
relation, err := a.getRelation(apimethod.CreateStore)
if err != nil {
return err
}
return a.individualAuthorize(ctx, claims.ClientID, relation, SystemObjectID, &openfgav1.ContextualTupleKeys{})
}
// AuthorizeListStores checks if the user has access to list stores.
func (a *Authorizer) AuthorizeListStores(ctx context.Context) error {
methodName := "AuthorizeListStores"
ctx, span := tracer.Start(ctx, methodName)
defer span.End()
grpc_ctxtags.Extract(ctx).Set(accessControlKey, methodName)
claims, err := checkAuthClaims(ctx)
if err != nil {
return err
}
relation, err := a.getRelation(apimethod.ListStores)
if err != nil {
return err
}
return a.individualAuthorize(ctx, claims.ClientID, relation, SystemObjectID, &openfgav1.ContextualTupleKeys{})
}
// ListAuthorizedStores returns the list of store IDs that the user has access to.
func (a *Authorizer) ListAuthorizedStores(ctx context.Context) ([]string, error) {
methodName := "ListAuthorizedStores"
ctx, span := tracer.Start(ctx, methodName)
defer span.End()
grpc_ctxtags.Extract(ctx).Set(accessControlKey, methodName)
claims, err := checkAuthClaims(ctx)
if err != nil {
return nil, err
}
req := &openfgav1.ListObjectsRequest{
StoreId: a.config.StoreID,
AuthorizationModelId: a.config.ModelID,
User: ClientIDType(claims.ClientID).String(),
Relation: CanCallGetStore,
Type: StoreType,
}
// Disable authz check for the list objects request.
ctx = authclaims.ContextWithSkipAuthzCheck(ctx, true)
resp, err := a.server.ListObjects(ctx, req)
if err != nil {
return nil, &authorizationError{Cause: fmt.Sprintf("list objects returned error: %v", err)}
}
storeIDs := make([]string, len(resp.GetObjects()))
storePrefix := StoreType + ":"
for i, store := range resp.GetObjects() {
storeIDs[i] = strings.TrimPrefix(store, storePrefix)
}
return storeIDs, nil
}
// GetModulesForWriteRequest returns the modules that should be checked for the write request.
// If we encounter a type with no attached module, we return no modules so that the authz check
// falls back to the store level. Otherwise we return the list of unique modules encountered so
// that FGA-on-FGA can check each of them afterwards.
func (a *Authorizer) GetModulesForWriteRequest(ctx context.Context, req *openfgav1.WriteRequest, typesys *typesystem.TypeSystem) ([]string, error) {
methodName := "GetModulesForWriteRequest"
ctx, span := tracer.Start(ctx, methodName)
defer span.End()
grpc_ctxtags.Extract(ctx).Set(accessControlKey, methodName)
tuples := make([]TupleKeyInterface, len(req.GetWrites().GetTupleKeys())+len(req.GetDeletes().GetTupleKeys()))
var index int
for _, tuple := range req.GetWrites().GetTupleKeys() {
tuples[index] = tuple
index++
}
for _, tuple := range req.GetDeletes().GetTupleKeys() {
tuples[index] = tuple
index++
}
modulesMap, err := extractModulesFromTuples(tuples, typesys)
if err != nil {
return nil, err
}
modules := make([]string, len(modulesMap))
var i int
for module := range modulesMap {
modules[i] = module
i++
}
return modules, nil
}
// TupleKeyInterface is an interface that both TupleKeyWithoutCondition and TupleKey implement.
type TupleKeyInterface interface {
GetObject() string
GetRelation() string
}
// extractModulesFromTuples extracts the modules from the given tuples. If any type has no
// module, it returns a nil map so that the caller can fall back to store-level authorization.
func extractModulesFromTuples[T TupleKeyInterface](tupleKeys []T, typesys *typesystem.TypeSystem) (map[string]struct{}, error) {
modulesMap := make(map[string]struct{})
for _, tupleKey := range tupleKeys {
objType, _ := tuple.SplitObject(tupleKey.GetObject())
objectType, ok := typesys.GetTypeDefinition(objType)
if !ok {
return nil, &authorizationError{Cause: fmt.Sprintf("type '%s' not found", objType)}
}
module, err := parser.GetModuleForObjectTypeRelation(objectType, tupleKey.GetRelation())
if err != nil {
return nil, err
}
if module == "" {
return nil, nil
}
modulesMap[module] = struct{}{}
}
return modulesMap, nil
}
func (a *Authorizer) individualAuthorize(ctx context.Context, clientID, relation, object string, contextualTuples *openfgav1.ContextualTupleKeys) error {
ctx, span := tracer.Start(ctx, "individualAuthorize", trace.WithAttributes(
attribute.String("clientID", clientID),
attribute.String("relation", relation),
attribute.String("object", object),
))
defer span.End()
req := &openfgav1.CheckRequest{
StoreId: a.config.StoreID,
AuthorizationModelId: a.config.ModelID,
TupleKey: &openfgav1.CheckRequestTupleKey{
User: ClientIDType(clientID).String(),
Relation: relation,
Object: object,
},
ContextualTuples: contextualTuples,
}
// Disable authz check for the check request.
ctx = authclaims.ContextWithSkipAuthzCheck(ctx, true)
resp, err := a.server.Check(ctx, req)
if err != nil {
return &authorizationError{Cause: fmt.Sprintf("check returned error: %v", err)}
}
if !resp.GetAllowed() {
return &authorizationError{Cause: "check returned not allowed"}
}
return nil
}
// moduleAuthorize checks if the user has access to each of the modules and returns the first
// error encountered, if any.
func (a *Authorizer) moduleAuthorize(ctx context.Context, clientID, relation, storeID string, modules []string) error {
ctx, span := tracer.Start(ctx, "moduleAuthorize", trace.WithAttributes(
attribute.String("clientID", clientID),
attribute.String("relation", relation),
attribute.String("storeID", storeID),
attribute.String("modules", strings.Join(modules, ",")),
))
defer span.End()
var wg sync.WaitGroup
errorChannel := make(chan error, len(modules))
for _, module := range modules {
wg.Add(1)
go func(module string) {
defer wg.Done()
contextualTuples := openfgav1.ContextualTupleKeys{
TupleKeys: []*openfgav1.TupleKey{
{
User: StoreIDType(storeID).String(),
Relation: StoreType,
Object: ModuleIDType(storeID).String(module),
},
getSystemAccessTuple(storeID),
},
}
err := a.individualAuthorize(ctx, clientID, relation, ModuleIDType(storeID).String(module), &contextualTuples)
if err != nil {
errorChannel <- err
}
}(module)
}
wg.Wait()
close(errorChannel)
for err := range errorChannel {
if err != nil {
return err
}
}
return nil
}
// checkAuthClaims checks the auth claims in the context.
func checkAuthClaims(ctx context.Context) (*authclaims.AuthClaims, error) {
claims, found := authclaims.AuthClaimsFromContext(ctx)
if !found || claims.ClientID == "" {
return nil, &authorizationError{Cause: "client ID not found in context or is empty"}
}
return claims, nil
}
func getSystemAccessTuple(storeID string) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
User: SystemObjectID,
Relation: SystemRelationOnStore,
Object: StoreIDType(storeID).String(),
}
}
package cachecontroller
import (
"context"
"math"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
)
var (
tracer = otel.Tracer("internal/cachecontroller")
cacheTotalCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "cachecontroller_cache_total_count",
Help: "The total number of cachecontroller requests.",
})
cacheHitCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "cachecontroller_cache_hit_count",
Help: "The total number of cache hits from cachecontroller requests.",
})
cacheInvalidationCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "cachecontroller_cache_invalidation_count",
Help: "The total number of invalidations performed by the cache controller.",
})
findChangesAndInvalidateHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "cachecontroller_invalidation_duration_ms",
Help: "The duration (in ms) required for cache controller to find changes and invalidate labeled by whether invalidation is required and buckets of changes size.",
Buckets: []float64{5, 10, 25, 50, 100, 200, 500, 1000, 5000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"invalidation_type"})
)
type CacheController interface {
// DetermineInvalidationTime returns the timestamp of the last write for the specified store
// if it is cached; otherwise it returns the zero time and triggers InvalidateIfNeeded().
DetermineInvalidationTime(context.Context, string) time.Time
// InvalidateIfNeeded checks to see if an invalidation is currently in progress for a store,
// and if not it will spawn a goroutine to invalidate cached records conditionally
// based on timestamp. It may invalidate all cache records, some, or none.
InvalidateIfNeeded(context.Context, string)
}
type NoopCacheController struct{}
func (c *NoopCacheController) DetermineInvalidationTime(_ context.Context, _ string) time.Time {
return time.Time{}
}
func (c *NoopCacheController) InvalidateIfNeeded(_ context.Context, _ string) {
}
func NewNoopCacheController() CacheController {
return &NoopCacheController{}
}
// InMemoryCacheControllerOpt defines an option that can be used to change the behavior of InMemoryCacheController
// instance.
type InMemoryCacheControllerOpt func(*InMemoryCacheController)
// WithLogger sets the logger for InMemoryCacheController.
func WithLogger(logger logger.Logger) InMemoryCacheControllerOpt {
return func(inm *InMemoryCacheController) {
inm.logger = logger
}
}
// InMemoryCacheController invalidates cache iterator (InMemoryCache) and sub-problem cache
// (CachedCheckResolver) entries that are more recent than the last write for the specified store.
// Note that invalidation is done asynchronously, and only after a Check request is received,
// so it is eventually consistent.
type InMemoryCacheController struct {
ds storage.OpenFGADatastore
cache storage.InMemoryCache[any]
// ttl for the entry that keeps the last timestamp for a Write for a storeID.
ttl time.Duration
iteratorCacheTTL time.Duration
inflightInvalidations sync.Map
logger logger.Logger
// for testing purposes
wg sync.WaitGroup
}
func NewCacheController(ds storage.OpenFGADatastore, cache storage.InMemoryCache[any], ttl time.Duration, iteratorCacheTTL time.Duration, opts ...InMemoryCacheControllerOpt) CacheController {
c := &InMemoryCacheController{
ds: ds,
cache: cache,
ttl: ttl,
iteratorCacheTTL: iteratorCacheTTL,
inflightInvalidations: sync.Map{},
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(c)
}
return c
}
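// A minimal usage sketch (the datastore, cache, logger, and TTL values are illustrative):
//
//	ctrl := NewCacheController(ds, cache, 10*time.Second, 10*time.Second, WithLogger(myLogger))
//	lastWrite := ctrl.DetermineInvalidationTime(ctx, storeID)
//	if lastWrite.IsZero() {
//		// not cached; an asynchronous invalidation pass has been triggered
//	}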
// DetermineInvalidationTime returns the timestamp of the last write for the specified store
// if it is cached; otherwise it returns the zero time and triggers InvalidateIfNeeded().
func (c *InMemoryCacheController) DetermineInvalidationTime(
ctx context.Context,
storeID string,
) time.Time {
ctx, span := tracer.Start(ctx, "cacheController.DetermineInvalidationTime", trace.WithAttributes(attribute.Bool("cached", false)))
defer span.End()
cacheTotalCounter.Inc()
cacheKey := storage.GetChangelogCacheKey(storeID)
cacheResp := c.cache.Get(cacheKey)
c.logger.Debug("InMemoryCacheController DetermineInvalidationTime cache attempt",
zap.String("store_id", storeID),
zap.Bool("hit", cacheResp != nil),
)
if cacheResp != nil {
if entry, ok := cacheResp.(*storage.ChangelogCacheEntry); ok {
// the TTL grace period hasn't been breached
if entry.LastModified.Add(c.ttl).After(time.Now()) {
cacheHitCounter.Inc()
span.SetAttributes(attribute.Bool("cached", true))
return entry.LastModified
}
}
}
c.InvalidateIfNeeded(ctx, storeID)
return time.Time{}
}
// findChangesDescending is a wrapper around ReadChanges. Note that ReadChanges actually returns an error when there are zero changes to return.
func (c *InMemoryCacheController) findChangesDescending(ctx context.Context, storeID string) ([]*openfgav1.TupleChange, string, error) {
opts := storage.ReadChangesOptions{
SortDesc: true,
Pagination: storage.PaginationOptions{
PageSize: storage.DefaultPageSize,
From: "",
}}
return c.ds.ReadChanges(ctx, storeID, storage.ReadChangesFilter{}, opts)
}
func (c *InMemoryCacheController) InvalidateIfNeeded(ctx context.Context, storeID string) {
span := trace.SpanFromContext(ctx)
_, present := c.inflightInvalidations.LoadOrStore(storeID, struct{}{})
if present {
span.SetAttributes(attribute.Bool("cache_controller_invalidation", false))
// If an invalidation is already in progress, abort.
return
}
span.SetAttributes(attribute.Bool("cache_controller_invalidation", true))
c.wg.Add(1)
go func() {
// we deliberately do not propagate the request context, to avoid early
// cancellation and polluting its span.
c.findChangesAndInvalidateIfNecessary(ctx, storeID)
c.inflightInvalidations.Delete(storeID)
c.wg.Done()
}()
}
type changelogResultMsg struct {
err error
changes []*openfgav1.TupleChange
}
// findChangesAndInvalidateIfNecessary checks the most recent entry in this store's changelog against the most
// recent cached changelog entry. If the most recent changelog entry is no newer than the cached changelog
// timestamp, no invalidation is necessary and we return. Otherwise, we invalidate cache entries for the
// changelog records that fall within the iterator cache's TTL window (anything older cannot still be cached).
func (c *InMemoryCacheController) findChangesAndInvalidateIfNecessary(parentCtx context.Context, storeID string) {
start := time.Now()
ctx, span := tracer.Start(context.Background(), "cacheController.findChangesAndInvalidateIfNecessary")
defer span.End()
link := trace.LinkFromContext(ctx)
trace.SpanFromContext(parentCtx).AddLink(link)
cacheKey := storage.GetChangelogCacheKey(storeID)
lastCacheRecord := c.cache.Get(cacheKey)
lastInvalidation := time.Time{}
if lastCacheRecord != nil {
if decodedRecord, ok := lastCacheRecord.(*storage.ChangelogCacheEntry); ok {
// if the changelog cache entry is available and valid, use its last-modified
// time for better consistency. Otherwise, lastInvalidation remains the
// beginning of time, which implies the need to invalidate all records.
lastInvalidation = decodedRecord.LastModified
} else {
c.logger.Error("Unable to cast lastCacheRecord properly", zap.String("cacheKey", cacheKey))
}
}
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
done := make(chan changelogResultMsg, 1)
c.wg.Add(1)
go func() {
changes, _, err := c.findChangesDescending(ctx, storeID)
concurrency.TrySendThroughChannel(ctx, changelogResultMsg{err: err, changes: changes}, done)
c.wg.Done()
}()
var changes []*openfgav1.TupleChange
select {
case <-ctx.Done():
// no need to modify cacheKey; a new attempt will be made once the inflight invalidation is cleared
return
case msg := <-done:
if msg.err != nil {
telemetry.TraceError(span, msg.err)
// do not allow any cache read until next refresh
c.invalidateIteratorCache(storeID)
return
}
changes = msg.changes
}
lastChangelog := changes[0]
entry := &storage.ChangelogCacheEntry{
LastModified: lastChangelog.GetTimestamp().AsTime(),
}
defer c.cache.Set(cacheKey, entry, utils.JitterDuration(c.ttl, time.Minute)) // add buffer between checks
invalidationType := "none"
if !lastChangelog.GetTimestamp().AsTime().After(lastInvalidation) {
// no new changes, no need to perform invalidations
span.SetAttributes(attribute.String("invalidationType", invalidationType))
c.logger.Debug("InMemoryCacheController findChangesAndInvalidateIfNecessary no invalidation as entry.LastModified before last verified",
zap.String("store_id", storeID),
zap.Time("entry.LastModified", entry.LastModified),
zap.Time("timestampOfLastInvalidation", lastInvalidation))
findChangesAndInvalidateHistogram.WithLabelValues(invalidationType).Observe(float64(time.Since(start).Milliseconds()))
return
}
lastIteratorInvalidation := time.Now().Add(-c.iteratorCacheTTL)
// need to consider that there might be just one change.
// Iterate from the oldest to the most recent change to determine whether the last change is
// part of the current batch. Remember that changes[0] is the most recent change and
// changes[len(changes)-1] is the oldest, because changes is ordered from most recent to oldest.
idx := len(changes) - 1
for ; idx >= 0; idx-- {
// idx ends at the oldest change that happened after the last iterator invalidation;
// that change and everything newer is used to invalidate cache entries.
//
// Note that we only want to add invalidation entries for changes with timestamp >= now - iterator cache's TTL,
// because anything older than that would no longer live in the iterator cache anyway.
if changes[idx].GetTimestamp().AsTime().After(lastIteratorInvalidation) {
break
}
}
// all changes happened after the last invalidation, thus we should revoke all the cached iterators for the store.
if idx == len(changes)-1 {
invalidationType = "full"
c.invalidateIteratorCache(storeID)
} else {
// only a subset of changes are new, revoke the respective ones.
lastModified := time.Now()
if idx >= 0 {
invalidationType = "partial"
}
for ; idx >= 0; idx-- {
t := changes[idx].GetTupleKey()
c.invalidateIteratorCacheByObjectRelation(storeID, t.GetObject(), t.GetRelation(), lastModified)
// We invalidate all iterators for the tuple's user and object type, regardless of the relation.
c.invalidateIteratorCacheByUserAndObjectType(storeID, t.GetUser(), tuple.GetType(t.GetObject()), lastModified)
}
}
if invalidationType != "none" {
cacheInvalidationCounter.Inc()
}
c.logger.Debug("InMemoryCacheController findChangesAndInvalidateIfNecessary invalidation",
zap.String("store_id", storeID),
zap.Time("entry.LastModified", entry.LastModified),
zap.Time("timestampOfLastIteratorInvalidation", lastIteratorInvalidation),
zap.String("invalidationType", invalidationType))
span.SetAttributes(attribute.String("invalidationType", invalidationType))
findChangesAndInvalidateHistogram.WithLabelValues(invalidationType).Observe(float64(time.Since(start).Milliseconds()))
}
// invalidateIteratorCache writes a new key to the cache with a very long TTL.
// An alternative implementation could delete invalid keys, but this approach is faster (see storagewrappers.findInCache).
func (c *InMemoryCacheController) invalidateIteratorCache(storeID string) {
c.cache.Set(storage.GetInvalidIteratorCacheKey(storeID), &storage.InvalidEntityCacheEntry{LastModified: time.Now()}, math.MaxInt)
}
// invalidateIteratorCacheByObjectRelation writes a new key to the cache.
// An alternative implementation could delete invalid keys, but this approach is faster (see storagewrappers.findInCache).
func (c *InMemoryCacheController) invalidateIteratorCacheByObjectRelation(storeID, object, relation string, ts time.Time) {
c.cache.Set(storage.GetInvalidIteratorByObjectRelationCacheKey(storeID, object, relation), &storage.InvalidEntityCacheEntry{LastModified: ts}, c.iteratorCacheTTL)
}
// invalidateIteratorCacheByUserAndObjectType writes a new key to the cache.
// An alternative implementation could delete invalid keys, but this approach is faster (see storagewrappers.findInCache).
func (c *InMemoryCacheController) invalidateIteratorCacheByUserAndObjectType(storeID, user, objectType string, ts time.Time) {
c.cache.Set(storage.GetInvalidIteratorByUserObjectTypeCacheKeys(storeID, []string{user}, objectType)[0], &storage.InvalidEntityCacheEntry{LastModified: ts}, c.iteratorCacheTTL)
}
package checkutil
import (
"context"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition/eval"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// BuildTupleKeyConditionFilter returns a TupleKeyConditionFilterFunc that evaluates whether
// the condition on a given tuple key is met.
func BuildTupleKeyConditionFilter(ctx context.Context, reqCtx *structpb.Struct, typesys *typesystem.TypeSystem) storage.TupleKeyConditionFilterFunc {
return func(t *openfgav1.TupleKey) (bool, error) {
// a tuple with no condition, or a condition not found in the typesystem, is handled by eval.EvaluateTupleCondition
cond, _ := typesys.GetCondition(t.GetCondition().GetName())
return eval.EvaluateTupleCondition(ctx, t, cond, reqCtx)
}
}
// userFilter returns an ObjectRelation filter whose object is the specified user.
// If the user's type is publicly assignable, the filter also includes the typed
// public wildcard.
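// For example, assuming the "user" type is publicly assignable, userFilter(true, "user:anne", "user")
// returns filters for both "user:anne" and the typed wildcard "user:*".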
func userFilter(hasPubliclyAssignedType bool,
user,
userType string) []*openfgav1.ObjectRelation {
if !hasPubliclyAssignedType || user == tuple.TypedPublicWildcard(userType) {
return []*openfgav1.ObjectRelation{{
Object: user,
}}
}
return []*openfgav1.ObjectRelation{
{Object: user},
{Object: tuple.TypedPublicWildcard(userType)},
}
}
// TODO: These (graph.ResolveCheckRequest, graph.ResolveCheckResponse) should be moved to a
// shared package to avoid duplication and allow better composition.
type resolveCheckRequest interface {
GetStoreID() string
GetTupleKey() *openfgav1.TupleKey
GetConsistency() openfgav1.ConsistencyPreference
GetContext() *structpb.Struct
}
func IteratorReadUsersetTuples(ctx context.Context,
req resolveCheckRequest,
allowedUserTypeRestrictions []*openfgav1.RelationReference) (storage.TupleKeyIterator, error) {
opts := storage.ReadUsersetTuplesOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
typesys, _ := typesystem.TypesystemFromContext(ctx)
ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
iter, err := ds.ReadUsersetTuples(ctx, req.GetStoreID(), storage.ReadUsersetTuplesFilter{
Object: req.GetTupleKey().GetObject(),
Relation: req.GetTupleKey().GetRelation(),
AllowedUserTypeRestrictions: allowedUserTypeRestrictions,
}, opts)
if err != nil {
return nil, err
}
return storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
),
BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
), nil
}
// IteratorReadStartingFromUser returns a storage iterator over tuples that match the request's
// user, the given object type and relation, and the specified objectIDs filter.
func IteratorReadStartingFromUser(ctx context.Context,
typesys *typesystem.TypeSystem,
ds storage.RelationshipTupleReader,
req resolveCheckRequest,
objectRel string,
objectIDs storage.SortedSet,
sortContextualTuples bool) (storage.TupleKeyIterator, error) {
storeID := req.GetStoreID()
reqTupleKey := req.GetTupleKey()
opts := storage.ReadStartingWithUserOptions{
WithResultsSortedAscending: sortContextualTuples,
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
user := reqTupleKey.GetUser()
userType := tuple.GetType(user)
objectType, relation := tuple.SplitObjectRelation(objectRel)
// TODO: add an optimization to filter out users whose type does not match
relationReference := typesystem.DirectRelationReference(objectType, relation)
hasPubliclyAssignedType, _ := typesys.IsPubliclyAssignable(relationReference, userType)
iter, err := ds.ReadStartingWithUser(ctx, storeID,
storage.ReadStartingWithUserFilter{
ObjectType: objectType,
Relation: relation,
UserFilter: userFilter(hasPubliclyAssignedType, user, userType),
ObjectIDs: objectIDs,
}, opts)
if err != nil {
return nil, err
}
return storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
),
BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
), nil
}
type V2RelationFunc func(*openfgav1.RelationReference) string
// BuildUsersetV2RelationFunc returns the reference's relation.
func BuildUsersetV2RelationFunc() V2RelationFunc {
return func(ref *openfgav1.RelationReference) string {
return ref.GetRelation()
}
}
// BuildTTUV2RelationFunc will always return the computedRelation regardless of the reference.
func BuildTTUV2RelationFunc(computedRelation string) V2RelationFunc {
return func(_ *openfgav1.RelationReference) string {
return computedRelation
}
}
package concurrency
import (
"context"
"github.com/sourcegraph/conc/pool"
)
type Pool = pool.ContextPool
// NewPool returns a new pool where each task respects context cancellation.
// Wait() will only return the first error seen.
func NewPool(ctx context.Context, maxGoroutines int) *Pool {
return pool.New().
WithContext(ctx).
WithCancelOnError().
WithFirstError().
WithMaxGoroutines(maxGoroutines)
}
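// A minimal usage sketch (`items` and `process` are illustrative):
//
//	p := NewPool(ctx, 4)
//	for _, item := range items {
//		item := item
//		p.Go(func(ctx context.Context) error {
//			return process(ctx, item)
//		})
//	}
//	if err := p.Wait(); err != nil {
//		// first error seen by any task
//	}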
// TrySendThroughChannel attempts to send an object through a channel and reports whether
// the send succeeded. If the context is canceled first, it does not send the object.
func TrySendThroughChannel[T any](ctx context.Context, msg T, channel chan<- T) bool {
select {
case <-ctx.Done():
return false
case channel <- msg:
return true
}
}
package condition
import (
"context"
"fmt"
"reflect"
"sync"
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
celtypes "github.com/google/cel-go/common/types"
"go.opentelemetry.io/otel"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition/metrics"
"github.com/openfga/openfga/internal/condition/types"
)
var tracer = otel.Tracer("openfga/internal/condition")
var celBaseEnv *cel.Env
func init() {
var envOpts []cel.EnvOption
for _, customTypeOpts := range types.CustomParamTypes {
envOpts = append(envOpts, customTypeOpts...)
}
envOpts = append(envOpts, types.IPAddressEnvOption(), cel.EagerlyValidateDeclarations(true))
env, err := cel.NewEnv(envOpts...)
if err != nil {
panic(fmt.Sprintf("failed to construct CEL base env: %v", err))
}
celBaseEnv = env
}
var emptyEvaluationResult = EvaluationResult{}
type EvaluationResult struct {
Cost uint64
ConditionMet bool
MissingParameters []string
}
// EvaluableCondition represents a condition that can eventually be evaluated
// given a CEL expression and a set of parameters. Calling .Evaluate() will
// optionally call .Compile() which validates and compiles the expression and
// parameter type definitions if it hasn't been done already.
type EvaluableCondition struct {
*openfgav1.Condition
celProgramOpts []cel.ProgramOption
celEnv *cel.Env
celProgram cel.Program
compileOnce sync.Once
}
// Compile compiles a condition expression with a CEL environment
// constructed from the condition's parameter type definitions into a valid
// AST that can be evaluated at a later time.
func (e *EvaluableCondition) Compile() error {
var compileErr error
e.compileOnce.Do(func() {
if err := e.compile(); err != nil {
compileErr = err
return
}
})
return compileErr
}
func (e *EvaluableCondition) compile() error {
start := time.Now()
var err error
defer func() {
if err == nil {
metrics.Metrics.ObserveCompilationDuration(time.Since(start))
}
}()
var envOpts []cel.EnvOption
conditionParamTypes := map[string]*types.ParameterType{}
for paramName, paramTypeRef := range e.GetParameters() {
paramType, err := types.DecodeParameterType(paramTypeRef)
if err != nil {
return &CompilationError{
Condition: e.Name,
Cause: fmt.Errorf("failed to decode parameter type for parameter '%s': %v", paramName, err),
}
}
conditionParamTypes[paramName] = paramType
}
for paramName, paramType := range conditionParamTypes {
envOpts = append(envOpts, cel.Variable(paramName, paramType.CelType()))
}
env, err := celBaseEnv.Extend(envOpts...)
if err != nil {
return &CompilationError{
Condition: e.Name,
Cause: err,
}
}
source := common.NewStringSource(e.Expression, e.Name)
ast, issues := env.CompileSource(source)
if issues != nil {
if err = issues.Err(); err != nil {
return &CompilationError{
Condition: e.Name,
Cause: err,
}
}
}
e.celProgramOpts = append(e.celProgramOpts, cel.EvalOptions(cel.OptPartialEval))
prg, err := env.Program(ast, e.celProgramOpts...)
if err != nil {
return &CompilationError{
Condition: e.Name,
Cause: fmt.Errorf("condition expression construction: %w", err),
}
}
if !reflect.DeepEqual(ast.OutputType(), cel.BoolType) {
return &CompilationError{
Condition: e.Name,
Cause: fmt.Errorf("expected a bool condition expression output, but got '%s'", ast.OutputType()),
}
}
e.celEnv = env
e.celProgram = prg
return nil
}
// CastContextToTypedParameters converts the provided context to typed condition
// parameters. It returns an error if context values are supplied but the condition
// defines no parameters, or if a context value fails to convert to its declared
// parameter type; context fields without a matching parameter are ignored.
func (e *EvaluableCondition) CastContextToTypedParameters(contextMap map[string]*structpb.Value) (map[string]any, error) {
if len(contextMap) == 0 {
return nil, nil
}
parameterTypes := e.GetParameters()
if len(parameterTypes) == 0 {
return nil, &ParameterTypeError{
Condition: e.Name,
Cause: fmt.Errorf("no parameters defined for the condition"),
}
}
converted := make(map[string]any, len(contextMap))
for parameterKey, paramTypeRef := range parameterTypes {
contextValue, ok := contextMap[parameterKey]
if !ok {
continue
}
varType, err := types.DecodeParameterType(paramTypeRef)
if err != nil {
return nil, &ParameterTypeError{
Condition: e.Name,
Cause: fmt.Errorf("failed to decode condition parameter type '%s': %v", paramTypeRef.GetTypeName(), err),
}
}
convertedParam, err := varType.ConvertValue(contextValue.AsInterface())
if err != nil {
return nil, &ParameterTypeError{
Condition: e.Name,
Cause: fmt.Errorf("failed to convert context parameter '%s': %w", parameterKey, err),
}
}
converted[parameterKey] = convertedParam
}
return converted, nil
}
// Evaluate evaluates the provided CEL condition expression with a CEL environment
// constructed from the condition's parameter type definitions, using the context maps provided.
// If more than one context map is provided and their keys overlap, the value from the
// last map wins. If any parameters are missing, ConditionMet is always false.
func (e *EvaluableCondition) Evaluate(
ctx context.Context,
contextMaps ...map[string]*structpb.Value,
) (EvaluationResult, error) {
_, span := tracer.Start(ctx, "Evaluate")
defer span.End()
if err := e.Compile(); err != nil {
return emptyEvaluationResult, NewEvaluationError(e.Name, err)
}
contextFields := contextMaps[0]
if contextFields == nil {
contextFields = map[string]*structpb.Value{}
}
// merge context fields
clonedContextFields := maps.Clone(contextFields)
for _, fields := range contextMaps[1:] {
maps.Copy(clonedContextFields, fields)
}
typedParams, err := e.CastContextToTypedParameters(clonedContextFields)
if err != nil {
return emptyEvaluationResult, NewEvaluationError(e.Name, err)
}
activation, err := e.celEnv.PartialVars(typedParams)
if err != nil {
return emptyEvaluationResult, NewEvaluationError(e.Name, fmt.Errorf("failed to construct condition partial vars: %v", err))
}
var missingParameters []string
for key := range e.GetParameters() {
if _, ok := activation.ResolveName(key); ok {
continue
}
missingParameters = append(missingParameters, key)
}
out, details, err := e.celProgram.ContextEval(ctx, activation)
if err != nil {
return emptyEvaluationResult, NewEvaluationError(
e.Name,
fmt.Errorf("failed to evaluate condition expression: %v", err),
)
}
var evaluationCost uint64
if details != nil {
cost := details.ActualCost()
if cost != nil {
evaluationCost = *cost
}
}
if celtypes.IsUnknown(out) {
return EvaluationResult{
ConditionMet: false,
MissingParameters: missingParameters,
Cost: evaluationCost,
}, nil
}
conditionMetVal, err := out.ConvertToNative(reflect.TypeOf(false))
if err != nil {
return emptyEvaluationResult, NewEvaluationError(
e.Name,
fmt.Errorf("failed to convert condition output to bool: %v", err),
)
}
conditionMet, ok := conditionMetVal.(bool)
if !ok {
return emptyEvaluationResult, NewEvaluationError(
e.Name,
fmt.Errorf("expected CEL type conversion to return native Go bool"),
)
}
return EvaluationResult{
ConditionMet: conditionMet,
MissingParameters: missingParameters,
Cost: evaluationCost,
}, nil
}
// WithTrackEvaluationCost enables CEL evaluation cost tracking on the EvaluableCondition
// and returns the mutated EvaluableCondition. The expectation is that this is called on the
// uncompiled condition because it modifies the behavior of the CEL program constructed by Compile.
func (e *EvaluableCondition) WithTrackEvaluationCost() *EvaluableCondition {
e.celProgramOpts = append(e.celProgramOpts, cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost))
return e
}
// WithMaxEvaluationCost enables CEL evaluation cost enforcement on the EvaluableCondition and
// returns the mutated EvaluableCondition. The expectation is that this is called on the
// uncompiled condition because it modifies the behavior of the CEL program constructed by Compile.
func (e *EvaluableCondition) WithMaxEvaluationCost(cost uint64) *EvaluableCondition {
e.celProgramOpts = append(e.celProgramOpts, cel.CostLimit(cost))
return e
}
// WithInterruptCheckFrequency sets the upper limit on the number of iterations within a
// CEL comprehension that are evaluated before CEL interrupts evaluation to check for
// cancellation, and returns the mutated EvaluableCondition.
// The expectation is that this is called on the uncompiled condition because it modifies
// the behavior of the CEL program constructed by Compile.
func (e *EvaluableCondition) WithInterruptCheckFrequency(checkFrequency uint) *EvaluableCondition {
e.celProgramOpts = append(e.celProgramOpts, cel.InterruptCheckFrequency(checkFrequency))
return e
}
// NewUncompiled returns a new EvaluableCondition that has not yet
// validated and compiled its expression.
func NewUncompiled(condition *openfgav1.Condition) *EvaluableCondition {
return &EvaluableCondition{Condition: condition}
}
// NewCompiled returns a new EvaluableCondition with a validated and
// compiled expression.
func NewCompiled(condition *openfgav1.Condition) (*EvaluableCondition, error) {
compiled := NewUncompiled(condition)
if err := compiled.Compile(); err != nil {
return nil, err
}
return compiled, nil
}
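// A minimal end-to-end sketch of the lifecycle above (illustrative only; the
// condition name, expression, and parameter are hypothetical): construct an
// uncompiled condition, opt into cost tracking, then evaluate it against some
// request context.
func exampleEvaluate(ctx context.Context) (bool, error) {
	cond := NewUncompiled(&openfgav1.Condition{
		Name:       "is_banned",
		Expression: "banned == true",
		Parameters: map[string]*openfgav1.ConditionParamTypeRef{
			"banned": {TypeName: openfgav1.ConditionParamTypeRef_TYPE_NAME_BOOL},
		},
	}).WithTrackEvaluationCost()
	contextStruct, err := structpb.NewStruct(map[string]any{"banned": true})
	if err != nil {
		return false, err
	}
	result, err := cond.Evaluate(ctx, contextStruct.GetFields()) // compiles lazily on first use
	if err != nil {
		return false, err
	}
	return result.ConditionMet, nil
}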
package condition
import (
"fmt"
"github.com/natefinch/wrap"
)
var ErrEvaluationFailed = fmt.Errorf("failed to evaluate relationship condition")
type CompilationError struct {
Condition string
Cause error
}
func (e *CompilationError) Error() string {
return fmt.Sprintf("failed to compile expression on condition '%s' - %v", e.Condition, e.Cause)
}
func (e *CompilationError) Unwrap() error {
return e.Cause
}
type EvaluationError struct {
Condition string
Cause error
}
func NewEvaluationError(condition string, cause error) error {
return wrap.With(&EvaluationError{
Condition: condition,
Cause: cause,
}, ErrEvaluationFailed)
}
func (e *EvaluationError) Error() string {
if _, ok := e.Cause.(*ParameterTypeError); ok {
return e.Unwrap().Error()
}
return fmt.Sprintf("'%s' - %v", e.Condition, e.Cause)
}
func (e *EvaluationError) Unwrap() error {
return e.Cause
}
type ParameterTypeError struct {
Condition string
Cause error
}
func (e *ParameterTypeError) Error() string {
return fmt.Sprintf("parameter type error on condition '%s' - %v", e.Condition, e.Cause)
}
func (e *ParameterTypeError) Unwrap() error {
return e.Cause
}
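// Because NewEvaluationError wraps both the typed *EvaluationError and the
// ErrEvaluationFailed sentinel, callers can match either form (sketch):
//
//	if errors.Is(err, ErrEvaluationFailed) { /* any evaluation failure */ }
//	var evalErr *EvaluationError
//	if errors.As(err, &evalErr) { /* inspect evalErr.Condition */ }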
package eval
import (
"context"
"fmt"
"strconv"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/condition/metrics"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/internal/condition/eval")
// EvaluateTupleCondition looks at the given tuple's condition and returns an evaluation result for the given context.
// If the tuple doesn't have a condition, it exits early and doesn't create a span.
// If the tuple's condition isn't found in the model it returns an EvaluationError.
func EvaluateTupleCondition(
ctx context.Context,
tupleKey *openfgav1.TupleKey,
evaluableCondition *condition.EvaluableCondition,
context *structpb.Struct,
) (bool, error) {
if tupleKey.GetCondition().GetName() == "" {
return true, nil
}
if evaluableCondition == nil || tupleKey.GetCondition().GetName() != evaluableCondition.GetName() {
err := condition.NewEvaluationError(tupleKey.GetCondition().GetName(), fmt.Errorf("condition was not found"))
return false, err
}
ctx, span := tracer.Start(ctx, "EvaluateTupleCondition", trace.WithAttributes(
attribute.String("tuple_key", tuple.TupleKeyWithConditionToString(tupleKey)),
attribute.String("condition_name", tupleKey.GetCondition().GetName())))
defer span.End()
start := time.Now()
// merge the request context with the tuple's condition context (the tuple context wins on overlapping keys)
contextFields := []map[string]*structpb.Value{
{},
}
if context != nil {
contextFields = []map[string]*structpb.Value{context.GetFields()}
}
tupleContext := tupleKey.GetCondition().GetContext()
if tupleContext != nil {
contextFields = append(contextFields, tupleContext.GetFields())
}
conditionResult, err := evaluableCondition.Evaluate(ctx, contextFields...)
if err != nil {
telemetry.TraceError(span, err)
return false, err
}
if len(conditionResult.MissingParameters) > 0 {
return false, condition.NewEvaluationError(
tupleKey.GetCondition().GetName(),
fmt.Errorf("tuple '%s' is missing context parameters '%v'",
tuple.TupleKeyToString(tupleKey),
conditionResult.MissingParameters),
)
}
metrics.Metrics.ObserveEvaluationDuration(time.Since(start))
metrics.Metrics.ObserveEvaluationCost(conditionResult.Cost)
span.SetAttributes(attribute.Bool("condition_met", conditionResult.ConditionMet),
attribute.String("condition_cost", strconv.FormatUint(conditionResult.Cost, 10)),
attribute.StringSlice("condition_missing_params", conditionResult.MissingParameters),
)
return conditionResult.ConditionMet, nil
}
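// Caller sketch (illustrative; the typesystem lookup shown is hypothetical):
//
//	cond, _ := typesys.GetCondition(tk.GetCondition().GetName())
//	allowed, err := EvaluateTupleCondition(ctx, tk, cond, req.GetContext())
//	if err != nil { /* condition missing or failed to evaluate */ }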
// Package metrics provides various metric and telemetry definitions for OpenFGA Conditions.
package metrics
import (
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/pkg/server/config"
)
// Metrics provides access to Condition metrics.
var Metrics *ConditionMetrics
func init() {
m := &ConditionMetrics{
compilationTime: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "condition_compilation_duration_ms",
Help: "A histogram measuring the compilation time (in milliseconds) of a Condition.",
Buckets: []float64{1, 5, 15, 50, 100, 250, 500, 1000},
}),
evaluationTime: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "condition_evaluation_duration_ms",
Help: "A histogram measuring the evaluation time (in milliseconds) of a Condition.",
Buckets: []float64{0.1, 0.25, 0.5, 1, 5, 15, 50, 100, 250, 500},
}),
evaluationCost: promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "condition_evaluation_cost",
Help: "A histogram of the CEL evaluation cost of a Condition in a Relationship Tuple",
Buckets: utils.LinearBuckets(0, config.DefaultMaxConditionEvaluationCost, 10),
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: config.DefaultMaxConditionEvaluationCost,
NativeHistogramMinResetDuration: time.Hour,
}),
}
Metrics = m
}
type ConditionMetrics struct {
compilationTime prometheus.Histogram
evaluationTime prometheus.Histogram
evaluationCost prometheus.Histogram
}
// ObserveCompilationDuration records the duration (in milliseconds) that Condition compilation took.
func (m *ConditionMetrics) ObserveCompilationDuration(elapsed time.Duration) {
m.compilationTime.Observe(float64(elapsed.Milliseconds()))
}
// ObserveEvaluationDuration records the duration (in milliseconds) that Condition evaluation took.
func (m *ConditionMetrics) ObserveEvaluationDuration(elapsed time.Duration) {
m.evaluationTime.Observe(float64(elapsed.Milliseconds()))
}
// ObserveEvaluationCost records the CEL evaluation cost the Condition required to resolve the expression.
func (m *ConditionMetrics) ObserveEvaluationCost(cost uint64) {
m.evaluationCost.Observe(float64(cost))
}
package types
import (
"fmt"
"math/big"
"reflect"
"time"
)
func primitiveTypeConverterFunc[T any](value any) (any, error) {
v, ok := value.(T)
if !ok {
return nil, fmt.Errorf("expected type value '%T', but found '%s'", *new(T), reflect.TypeOf(value))
}
return v, nil
}
func numericTypeConverterFunc[T int64 | uint64 | float64](value any) (any, error) {
v, ok := value.(T)
if ok {
return v, nil
}
floatValue, ok := value.(float64)
bigFloat := big.NewFloat(floatValue)
if !ok {
stringValue, ok := value.(string)
if !ok {
return nil, fmt.Errorf("expected type value '%T', but found '%s'", *new(T), reflect.TypeOf(value))
}
f, _, err := big.ParseFloat(stringValue, 10, 64, 0)
if err != nil {
return nil, fmt.Errorf("expected a %T value, but found invalid string value '%v'", *new(T), value)
}
bigFloat = f
}
n := *new(T)
switch any(n).(type) {
case int64:
if !bigFloat.IsInt() {
return nil, fmt.Errorf("expected an int value, but found numeric value '%s'", bigFloat.String())
}
numericValue, _ := bigFloat.Int64()
return numericValue, nil
case uint64:
if !bigFloat.IsInt() {
return nil, fmt.Errorf("expected a uint value, but found numeric value '%s'", bigFloat.String())
}
numericValue, _ := bigFloat.Int64()
if numericValue < 0 {
return nil, fmt.Errorf("expected a uint value, but found int64 value '%s'", bigFloat.String())
}
return uint64(numericValue), nil
case float64:
numericValue, a := bigFloat.Float64()
if a == big.Above || a == big.Below {
return nil, fmt.Errorf("number cannot be represented as a float64: %s", bigFloat.String())
}
return numericValue, nil
default:
return nil, fmt.Errorf("unsupported numeric type in numerical parameter type conversion: %T", n)
}
}
func anyTypeConverterFunc(value any) (any, error) {
return value, nil
}
func durationTypeConverterFunc(value any) (any, error) {
v, ok := value.(string)
if !ok {
return nil, fmt.Errorf("expected a duration string, but found: %T '%v'", value, value)
}
d, err := time.ParseDuration(v)
if err != nil {
return nil, fmt.Errorf("expected a valid duration string, but found: '%v'", value)
}
return d, nil
}
func timestampTypeConverterFunc(value any) (any, error) {
v, ok := value.(string)
if !ok {
return nil, fmt.Errorf("expected RFC 3339 formatted timestamp string, but found: %T '%v'", value, value)
}
d, err := time.Parse(time.RFC3339, v)
if err != nil {
return nil, fmt.Errorf("expected RFC 3339 formatted timestamp string, but found '%s'", v)
}
return d, nil
}
func ipaddressTypeConverterFunc(value any) (any, error) {
ipaddr, ok := value.(IPAddress)
if ok {
return ipaddr, nil
}
v, ok := value.(string)
if !ok {
return nil, fmt.Errorf("expected an ipaddress string, but found: %T '%v'", value, value)
}
d, err := ParseIPAddress(v)
if err != nil {
return nil, fmt.Errorf("expected a well-formed IP address, but found: '%s'", v)
}
return d, nil
}
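// Conversion behavior at a glance (illustrative examples): numericTypeConverterFunc[int64]
// accepts float64(5) and "5" (both yielding int64(5)) but rejects float64(5.5);
// durationTypeConverterFunc accepts "1h30m"; timestampTypeConverterFunc accepts
// "2024-01-01T00:00:00Z"; ipaddressTypeConverterFunc accepts either an IPAddress
// value or a parseable address string.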
package types
import (
"fmt"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
func DecodeParameterType(conditionParamType *openfgav1.ConditionParamTypeRef) (*ParameterType, error) {
paramTypedef, ok := paramTypeDefinitions[conditionParamType.GetTypeName()]
if !ok {
return nil, fmt.Errorf("unknown condition parameter type `%s`", conditionParamType.GetTypeName())
}
if len(conditionParamType.GetGenericTypes()) != int(paramTypedef.genericTypeCount) {
return nil, fmt.Errorf(
"condition parameter type `%s` requires %d generic types; found %d",
conditionParamType.GetTypeName(),
paramTypedef.genericTypeCount,
len(conditionParamType.GetGenericTypes()),
)
}
genericTypes := make([]ParameterType, 0, paramTypedef.genericTypeCount)
for _, encodedGenericType := range conditionParamType.GetGenericTypes() {
genericType, err := DecodeParameterType(encodedGenericType)
if err != nil {
return nil, err
}
genericTypes = append(genericTypes, *genericType)
}
return paramTypedef.toParameterType(genericTypes)
}
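// Sketch: decoding the wire representation of `list<string>` (the helper name
// is hypothetical; only DecodeParameterType is part of this package).
func exampleDecodeListOfString() (*ParameterType, error) {
	return DecodeParameterType(&openfgav1.ConditionParamTypeRef{
		TypeName: openfgav1.ConditionParamTypeRef_TYPE_NAME_LIST,
		GenericTypes: []*openfgav1.ConditionParamTypeRef{
			{TypeName: openfgav1.ConditionParamTypeRef_TYPE_NAME_STRING},
		},
	})
}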
package types
import (
"fmt"
"github.com/google/cel-go/cel"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
func mapTypeConverterFunc(genericTypes []ParameterType) ParameterType {
return ParameterType{
name: openfgav1.ConditionParamTypeRef_TYPE_NAME_MAP,
celType: cel.MapType(cel.StringType, genericTypes[0].celType),
genericTypes: genericTypes,
typedParamConverter: func(value any) (any, error) {
v, ok := value.(map[string]any)
if !ok {
return nil, fmt.Errorf("map requires a map, found: %T", value)
}
converted := make(map[string]any, len(v))
for key, item := range v {
convertedItem, err := genericTypes[0].ConvertValue(item)
if err != nil {
return nil, fmt.Errorf("found an invalid value for key '%s': %w", key, err)
}
converted[key] = convertedItem
}
return converted, nil
},
}
}
func listTypeConverterFunc(genericTypes []ParameterType) ParameterType {
return ParameterType{
name: openfgav1.ConditionParamTypeRef_TYPE_NAME_LIST,
celType: cel.ListType(genericTypes[0].celType),
genericTypes: genericTypes,
typedParamConverter: func(value any) (any, error) {
v, ok := value.([]any)
if !ok {
return nil, fmt.Errorf("list requires a list, found: %T", value)
}
converted := make([]any, len(v))
for index, item := range v {
convertedItem, err := genericTypes[0].ConvertValue(item)
if err != nil {
return nil, fmt.Errorf("found an invalid list item at index `%d`: %w", index, err)
}
converted[index] = convertedItem
}
return converted, nil
},
}
}
package types
import (
"fmt"
"net/netip"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
var ipaddrLibraryDecls = map[string][]cel.FunctionOpt{
"ipaddress": {
cel.Overload("string_to_ipaddress", []*cel.Type{cel.StringType}, ipaddrCelType,
cel.UnaryBinding(stringToIPAddress))},
}
var ipaddrLib = &IPAddress{}
func IPAddressEnvOption() cel.EnvOption {
return cel.Lib(ipaddrLib)
}
func (ip IPAddress) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{}
for name, overloads := range ipaddrLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (ip IPAddress) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
// IPAddressType defines a ParameterType that is used to represent IP addresses in CEL expressions.
var IPAddressType = registerCustomParamType(
openfgav1.ConditionParamTypeRef_TYPE_NAME_IPADDRESS,
cel.ObjectType("IPAddress"),
ipaddressTypeConverterFunc,
cel.Function("in_cidr",
cel.MemberOverload("ipaddr_in_cidr",
[]*cel.Type{cel.ObjectType("IPAddress"), cel.StringType},
cel.BoolType,
cel.BinaryBinding(ipaddressCELBinaryBinding),
),
),
)
// IPAddress represents a network IP address.
type IPAddress struct {
addr netip.Addr
}
// ParseIPAddress attempts to parse the provided ip string. If the provided string does
// not define a well-formed IP address, then an error is returned.
func ParseIPAddress(ip string) (IPAddress, error) {
addr, err := netip.ParseAddr(ip)
if err != nil {
return IPAddress{}, err
}
return IPAddress{addr}, nil
}
// ipaddrCelType defines a CEL type for the IPAddress and registers it as a receiver type.
var ipaddrCelType = cel.ObjectType("IPAddress", traits.ReceiverType)
// ConvertToNative implements the CEL ref.Val.ConvertToNative.
//
// See https://pkg.go.dev/github.com/google/cel-go/common/types/ref#Val
func (ip IPAddress) ConvertToNative(typeDesc reflect.Type) (any, error) {
if reflect.TypeOf(ip).AssignableTo(typeDesc) {
return ip, nil
}
switch typeDesc {
case reflect.TypeOf(""):
return ip.addr.String(), nil
default:
return nil, fmt.Errorf("failed to convert from type '%s' to native Go type 'IPAddress'", typeDesc)
}
}
// ConvertToType implements the CEL ref.Val.ConvertToType.
//
// See https://pkg.go.dev/github.com/google/cel-go/common/types/ref#Val
func (ip IPAddress) ConvertToType(typeValue ref.Type) ref.Val {
switch typeValue {
case types.StringType:
return types.String(ip.addr.String())
case types.TypeType:
return ipaddrCelType
default:
return types.NewErr("failed to convert from CEL type '%s' to '%s'", ipaddrCelType, typeValue)
}
}
// Equal implements the CEL ref.Val.Equal.
//
// See https://pkg.go.dev/github.com/google/cel-go/common/types/ref#Val
func (ip IPAddress) Equal(other ref.Val) ref.Val {
otherip, ok := other.(IPAddress)
if !ok {
return types.NoSuchOverloadErr()
}
return types.Bool(ip.addr.Compare(otherip.addr) == 0)
}
// Type implements the CEL ref.Val.Type.
//
// See https://pkg.go.dev/github.com/google/cel-go/common/types/ref#Val
func (ip IPAddress) Type() ref.Type {
return ipaddrCelType
}
// Value implements ref.Val.Value.
//
// See https://pkg.go.dev/github.com/google/cel-go/common/types/ref#Val
func (ip IPAddress) Value() any {
return ip
}
// ipaddressCELBinaryBinding implements a cel.BinaryBinding that is used as a receiver overload for
// comparing an ipaddress value against a network CIDR defined as a string. If the ipaddress is
// within the CIDR range this binding returns true; otherwise it returns false or an error.
//
// See https://pkg.go.dev/github.com/google/cel-go/cel#BinaryBinding
func ipaddressCELBinaryBinding(lhs, rhs ref.Val) ref.Val {
cidr, ok := rhs.Value().(string)
if !ok {
return types.NewErr("a CIDR string is required for comparison")
}
network, err := netip.ParsePrefix(cidr)
if err != nil {
return types.NewErr("'%s' is a malformed CIDR string", cidr)
}
ipaddr, ok := lhs.(IPAddress)
if !ok {
return types.NewErr("an IPAddress parameter value is required for comparison")
}
return types.Bool(network.Contains(ipaddr.addr))
}
func stringToIPAddress(arg ref.Val) ref.Val {
ipStr, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
ipaddr, err := ParseIPAddress(ipStr)
if err != nil {
return types.NewErr("%s", err.Error())
}
return ipaddr
}
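// Together, the `ipaddress` constructor and the `in_cidr` member overload allow
// model conditions such as (sketch; condition and parameter names are illustrative):
//
//	condition from_allowed_network(src ipaddress) {
//	    src.in_cidr("10.0.0.0/8") || src == ipaddress("10.0.0.1")
//	}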
package types
import (
"fmt"
"strings"
"github.com/google/cel-go/cel"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
var CustomParamTypes = map[openfgav1.ConditionParamTypeRef_TypeName][]cel.EnvOption{}
var paramTypeDefinitions = map[openfgav1.ConditionParamTypeRef_TypeName]paramTypeDefinition{}
// typedParamValueConverter defines a signature that implementations can provide
// to enforce type constraints on any values provided.
type typedParamValueConverter func(value any) (any, error)
// paramTypeDefinition represents a parameter type definition included in a relationship condition.
//
// For example, the following condition defines two parameter type definitions (user and color), and
// the 'user' parameter has a type that is a map<any> (a map with a generic type of any) and the 'color'
// parameter has a type that is a string.
//
// condition favorite_color(user map<any>, color string) {
// user.favoriteColor == color
// }
type paramTypeDefinition struct {
// name is the name/keyword for the type (e.g. 'string', 'timestamp', 'duration', 'map', 'list', 'ipaddress')
name openfgav1.ConditionParamTypeRef_TypeName
genericTypeCount uint
toParameterType func(genericType []ParameterType) (*ParameterType, error)
}
var paramTypeString = map[openfgav1.ConditionParamTypeRef_TypeName]string{
openfgav1.ConditionParamTypeRef_TYPE_NAME_ANY: "any",
openfgav1.ConditionParamTypeRef_TYPE_NAME_BOOL: "bool",
openfgav1.ConditionParamTypeRef_TYPE_NAME_STRING: "string",
openfgav1.ConditionParamTypeRef_TYPE_NAME_INT: "int",
openfgav1.ConditionParamTypeRef_TYPE_NAME_UINT: "uint",
openfgav1.ConditionParamTypeRef_TYPE_NAME_DOUBLE: "double",
openfgav1.ConditionParamTypeRef_TYPE_NAME_DURATION: "duration",
openfgav1.ConditionParamTypeRef_TYPE_NAME_TIMESTAMP: "timestamp",
openfgav1.ConditionParamTypeRef_TYPE_NAME_MAP: "map",
openfgav1.ConditionParamTypeRef_TYPE_NAME_LIST: "list",
openfgav1.ConditionParamTypeRef_TYPE_NAME_IPADDRESS: "ipaddress",
}
func registerParamTypeWithGenerics(
paramTypeKeyword openfgav1.ConditionParamTypeRef_TypeName,
genericTypeCount uint,
toParameterType func(genericType []ParameterType) ParameterType,
) func(genericTypes ...ParameterType) (ParameterType, error) {
paramTypeDefinitions[paramTypeKeyword] = paramTypeDefinition{
name: paramTypeKeyword,
genericTypeCount: genericTypeCount,
toParameterType: func(genericTypes []ParameterType) (*ParameterType, error) {
if uint(len(genericTypes)) != genericTypeCount {
return nil, fmt.Errorf("type `%s` requires %d generic types; found %d", paramTypeKeyword, genericTypeCount, len(genericTypes))
}
built := toParameterType(genericTypes)
return &built, nil
},
}
return func(genericTypes ...ParameterType) (ParameterType, error) {
if uint(len(genericTypes)) != genericTypeCount {
return ParameterType{}, fmt.Errorf("invalid number of parameters given to type constructor. expected: %d, found: %d", genericTypeCount, len(genericTypes))
}
return toParameterType(genericTypes), nil
}
}
func registerParamType(
paramTypeKeyword openfgav1.ConditionParamTypeRef_TypeName,
celType *cel.Type,
typedParamConverter typedParamValueConverter,
) ParameterType {
paramType := ParameterType{
name: paramTypeKeyword,
celType: celType,
genericTypes: nil,
typedParamConverter: typedParamConverter,
}
paramTypeDefinitions[paramTypeKeyword] = paramTypeDefinition{
name: paramTypeKeyword,
genericTypeCount: 0,
toParameterType: func(genericTypes []ParameterType) (*ParameterType, error) {
return &paramType, nil
},
}
return paramType
}
func registerCustomParamType(
paramTypeKeyword openfgav1.ConditionParamTypeRef_TypeName,
celType *cel.Type,
typeConverter typedParamValueConverter,
celOpts ...cel.EnvOption,
) ParameterType {
CustomParamTypes[paramTypeKeyword] = celOpts
return registerParamType(paramTypeKeyword, celType, typeConverter)
}
// ParameterType defines the canonical representation of parameter types supported in conditions.
type ParameterType struct {
name openfgav1.ConditionParamTypeRef_TypeName
celType *cel.Type
genericTypes []ParameterType
typedParamConverter typedParamValueConverter
}
func NewParameterType(
name openfgav1.ConditionParamTypeRef_TypeName,
celType *cel.Type,
generics []ParameterType,
typedParamConverter typedParamValueConverter,
) ParameterType {
return ParameterType{
name,
celType,
generics,
typedParamConverter,
}
}
// CelType returns the underlying Google CEL type for the variable type.
func (pt ParameterType) CelType() *cel.Type {
return pt.celType
}
func (pt ParameterType) String() string {
if len(pt.genericTypes) > 0 {
genericTypeStrings := make([]string, 0, len(pt.genericTypes))
for _, genericType := range pt.genericTypes {
genericTypeStrings = append(genericTypeStrings, genericType.String())
}
// e.g. map<int>
return fmt.Sprintf("%s<%s>", pt.name, strings.Join(genericTypeStrings, ", "))
}
str, ok := paramTypeString[pt.name]
if !ok {
return "unknown"
}
return str
}
func (pt ParameterType) ConvertValue(value any) (any, error) {
converted, err := pt.typedParamConverter(value)
if err != nil {
return nil, err
}
return converted, nil
}
package graph
type CheckResolverOrderedBuilder struct {
resolvers []CheckResolver
localCheckerOptions []LocalCheckerOption
shadowLocalCheckerOptions []LocalCheckerOption
shadowResolverEnabled bool
shadowResolverOptions []ShadowResolverOpt
cachedCheckResolverEnabled bool
cachedCheckResolverOptions []CachedCheckResolverOpt
dispatchThrottlingCheckResolverEnabled bool
dispatchThrottlingCheckResolverOptions []DispatchThrottlingCheckResolverOpt
}
type CheckResolverOrderedBuilderOpt func(checkResolver *CheckResolverOrderedBuilder)
// WithLocalCheckerOpts sets the opts to be used to build LocalChecker.
func WithLocalCheckerOpts(opts ...LocalCheckerOption) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.localCheckerOptions = opts
}
}
// WithLocalShadowCheckerOpts sets the opts to be used to build the shadow LocalChecker.
func WithLocalShadowCheckerOpts(opts ...LocalCheckerOption) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.shadowLocalCheckerOptions = opts
}
}
// WithShadowResolverEnabled toggles the shadow resolver.
func WithShadowResolverEnabled(enabled bool) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.shadowResolverEnabled = enabled
}
}
// WithShadowResolverOpts sets the opts to be used to build ShadowResolver.
func WithShadowResolverOpts(opts ...ShadowResolverOpt) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.shadowResolverOptions = opts
}
}
// WithCachedCheckResolverOpts sets the opts to be used to build CachedCheckResolver.
func WithCachedCheckResolverOpts(enabled bool, opts ...CachedCheckResolverOpt) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.cachedCheckResolverEnabled = enabled
r.cachedCheckResolverOptions = opts
}
}
// WithDispatchThrottlingCheckResolverOpts sets the opts to be used to build DispatchThrottlingCheckResolver.
func WithDispatchThrottlingCheckResolverOpts(enabled bool, opts ...DispatchThrottlingCheckResolverOpt) CheckResolverOrderedBuilderOpt {
return func(r *CheckResolverOrderedBuilder) {
r.dispatchThrottlingCheckResolverEnabled = enabled
r.dispatchThrottlingCheckResolverOptions = opts
}
}
func NewOrderedCheckResolvers(opts ...CheckResolverOrderedBuilderOpt) *CheckResolverOrderedBuilder {
checkResolverBuilder := &CheckResolverOrderedBuilder{}
for _, opt := range opts {
opt(checkResolverBuilder)
}
return checkResolverBuilder
}
// Build constructs a CheckResolver that is composed of various CheckResolvers in the manner of a circular linked list.
// The resolvers should be added from least resource intensive to most resource intensive.
//
// [...Other resolvers depending on the opts order]
// LocalChecker ----------------------------^
//
// The returned CheckResolverCloser should be used to close all resolvers involved in the list.
func (c *CheckResolverOrderedBuilder) Build() (CheckResolver, CheckResolverCloser, error) {
c.resolvers = []CheckResolver{}
if c.cachedCheckResolverEnabled {
cachedCheckResolver, err := NewCachedCheckResolver(c.cachedCheckResolverOptions...)
if err != nil {
return nil, nil, err
}
c.resolvers = append(c.resolvers, cachedCheckResolver)
}
if c.dispatchThrottlingCheckResolverEnabled {
c.resolvers = append(c.resolvers, NewDispatchThrottlingCheckResolver(c.dispatchThrottlingCheckResolverOptions...))
}
if c.shadowResolverEnabled {
main := NewLocalChecker(c.localCheckerOptions...)
shadow := NewLocalChecker(c.shadowLocalCheckerOptions...)
c.resolvers = append(c.resolvers, NewShadowChecker(main, shadow, c.shadowResolverOptions...))
} else {
c.resolvers = append(c.resolvers, NewLocalChecker(c.localCheckerOptions...))
}
for i, resolver := range c.resolvers {
if i == len(c.resolvers)-1 {
resolver.SetDelegate(c.resolvers[0])
continue
}
resolver.SetDelegate(c.resolvers[i+1])
}
return c.resolvers[0], c.close, nil
}
// close ensures all the constructed CheckResolvers are closed.
func (c *CheckResolverOrderedBuilder) close() {
for _, resolver := range c.resolvers {
resolver.Close()
}
}
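// Builder usage sketch (the options shown are illustrative, not required):
//
//	resolver, closer, err := NewOrderedCheckResolvers(
//	    WithCachedCheckResolverOpts(true, WithCacheTTL(10*time.Second)),
//	    WithDispatchThrottlingCheckResolverOpts(false),
//	).Build()
//	if err != nil { /* handle */ }
//	defer closer()
//	// resolver.ResolveCheck(...) now walks cache -> local checker -> cache -> ...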
// LocalCheckResolver returns the local checker in the chain of CheckResolver.
func LocalCheckResolver(resolver CheckResolver) (*LocalChecker, bool) {
if resolver == nil {
return nil, false
}
localChecker, ok := resolver.(*LocalChecker)
if ok {
return localChecker, true
}
shadowChecker, ok := resolver.(*ShadowResolver)
if ok {
return LocalCheckResolver(shadowChecker.main)
}
delegate := resolver.GetDelegate()
if delegate != nil {
if delegate == resolver {
// the delegate is itself; stop here to avoid an infinite loop
return nil, false
}
return LocalCheckResolver(delegate)
}
return nil, false
}
package graph
import (
"context"
"strconv"
"time"
"github.com/cespare/xxhash/v2"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
)
const (
defaultMaxCacheSize = 10000
defaultCacheTTL = 10 * time.Second
)
var (
checkCacheTotalCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "check_cache_total_count",
Help: "The total number of calls to ResolveCheck.",
})
checkCacheHitCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "check_cache_hit_count",
Help: "The total number of cache hits for ResolveCheck.",
})
checkCacheInvalidHit = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "check_cache_invalid_hit_count",
Help: "The total number of cache hits for ResolveCheck that were discarded because they were invalidated.",
})
)
var _ storage.CacheItem = (*CheckResponseCacheEntry)(nil)
type CheckResponseCacheEntry struct {
LastModified time.Time
CheckResponse *ResolveCheckResponse
}
func (c *CheckResponseCacheEntry) CacheEntityType() string {
return "check_response"
}
// CachedCheckResolver attempts to resolve check sub-problems via prior computations before
// delegating the request to some underlying CheckResolver.
type CachedCheckResolver struct {
delegate CheckResolver
cache storage.InMemoryCache[any]
cacheTTL time.Duration
logger logger.Logger
// allocatedCache is used to denote whether the cache is allocated by this struct.
// If so, CachedCheckResolver is responsible for cleaning up.
allocatedCache bool
}
var _ CheckResolver = (*CachedCheckResolver)(nil)
// CachedCheckResolverOpt defines an option that can be used to change the behavior of cachedCheckResolver
// instance.
type CachedCheckResolverOpt func(*CachedCheckResolver)
// WithCacheTTL sets the TTL (as a duration) for any single Check cache key value.
func WithCacheTTL(ttl time.Duration) CachedCheckResolverOpt {
return func(ccr *CachedCheckResolver) {
ccr.cacheTTL = ttl
}
}
// WithExistingCache sets the cache to the specified cache.
// Note that the original cache will not be stopped as it may still be used by others. It is up to the caller
// to check whether the original cache should be stopped.
func WithExistingCache(cache storage.InMemoryCache[any]) CachedCheckResolverOpt {
return func(ccr *CachedCheckResolver) {
ccr.cache = cache
}
}
// WithLogger sets the logger for the cached check resolver.
func WithLogger(logger logger.Logger) CachedCheckResolverOpt {
return func(ccr *CachedCheckResolver) {
ccr.logger = logger
}
}
// NewCachedCheckResolver constructs a CheckResolver that delegates Check resolution to the provided delegate,
// but before delegating the query to the delegate a cache-key lookup is made to see if the Check sub-problem
// has already recently been computed. If the Check sub-problem is in the cache, then the response is returned
// immediately and no re-computation is necessary.
// NOTE: the returned ResolveCheck resolution metadata is left at its default values, since no database lookup was actually performed.
func NewCachedCheckResolver(opts ...CachedCheckResolverOpt) (*CachedCheckResolver, error) {
checker := &CachedCheckResolver{
cacheTTL: defaultCacheTTL,
logger: logger.NewNoopLogger(),
}
checker.delegate = checker
for _, opt := range opts {
opt(checker)
}
if checker.cache == nil {
checker.allocatedCache = true
cacheOptions := []storage.InMemoryLRUCacheOpt[any]{
storage.WithMaxCacheSize[any](defaultMaxCacheSize),
}
var err error
checker.cache, err = storage.NewInMemoryLRUCache[any](cacheOptions...)
if err != nil {
return nil, err
}
}
return checker, nil
}
// SetDelegate sets this CachedCheckResolver's dispatch delegate.
func (c *CachedCheckResolver) SetDelegate(delegate CheckResolver) {
c.delegate = delegate
}
// GetDelegate returns this CachedCheckResolver's dispatch delegate.
func (c *CachedCheckResolver) GetDelegate() CheckResolver {
return c.delegate
}
// Close will deallocate resource allocated by the CachedCheckResolver
// It will not deallocate cache if it has been passed in from WithExistingCache.
func (c *CachedCheckResolver) Close() {
if c.allocatedCache {
c.cache.Stop()
}
}
func (c *CachedCheckResolver) ResolveCheck(
ctx context.Context,
req *ResolveCheckRequest,
) (*ResolveCheckResponse, error) {
span := trace.SpanFromContext(ctx)
cacheKey := BuildCacheKey(*req)
tryCache := req.Consistency != openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY
if tryCache {
checkCacheTotalCounter.Inc()
if cachedResp := c.cache.Get(cacheKey); cachedResp != nil {
res := cachedResp.(*CheckResponseCacheEntry)
isValid := res.LastModified.After(req.LastCacheInvalidationTime)
c.logger.Debug("CachedCheckResolver found cache key",
zap.String("store_id", req.GetStoreID()),
zap.String("authorization_model_id", req.GetAuthorizationModelID()),
zap.String("tuple_key", req.GetTupleKey().String()),
zap.Bool("isValid", isValid))
span.SetAttributes(attribute.Bool("cached", isValid))
if isValid {
checkCacheHitCounter.Inc()
// return a copy to avoid races across goroutines
return res.CheckResponse.clone(), nil
}
// we tried the cache and hit an invalid entry
checkCacheInvalidHit.Inc()
} else {
c.logger.Debug("CachedCheckResolver not found cache key",
zap.String("store_id", req.GetStoreID()),
zap.String("authorization_model_id", req.GetAuthorizationModelID()),
zap.String("tuple_key", req.GetTupleKey().String()))
}
}
// Not in the cache, or the request asked for HIGHER_CONSISTENCY, so delegate the check.
resp, err := c.delegate.ResolveCheck(ctx, req)
if err != nil {
telemetry.TraceError(span, err)
return nil, err
}
// When the response indicates a cycle was detected, the result is indeterminate because the
// parent of the cycle could still resolve to true. Thus, we don't cache the result and let
// the parent handle it.
if resp.GetCycleDetected() {
span.SetAttributes(attribute.Bool("cycle_detected", true))
c.logger.Debug("CachedCheckResolver not saving to cache due to cycle",
zap.String("store_id", req.GetStoreID()),
zap.String("authorization_model_id", req.GetAuthorizationModelID()),
zap.String("tuple_key", req.GetTupleKey().String()))
return resp, nil
}
clonedResp := resp.clone()
c.cache.Set(cacheKey, &CheckResponseCacheEntry{LastModified: time.Now(), CheckResponse: clonedResp}, c.cacheTTL)
return resp, nil
}
func BuildCacheKey(req ResolveCheckRequest) string {
tup := tuple.From(req.GetTupleKey())
cacheKeyString := tup.String() + req.GetInvariantCacheKey()
hasher := xxhash.New()
// Digest.WriteString returns the byte count and a nil error, so both can safely be ignored.
_, _ = hasher.WriteString(cacheKeyString)
return strconv.FormatUint(hasher.Sum64(), 10)
}
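// Wiring sketch (illustrative; sharedCache is a hypothetical server-owned cache):
// passing an existing cache means this resolver neither allocates its own nor
// stops the shared one on Close:
//
//	cached, err := NewCachedCheckResolver(
//	    WithExistingCache(sharedCache),
//	    WithCacheTTL(10*time.Second),
//	)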
package graph
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/emirpasic/gods/sets/hashset"
"github.com/sourcegraph/conc/panics"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/concurrency"
openfgaErrors "github.com/openfga/openfga/internal/errors"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
var tracer = otel.Tracer("internal/graph/check")
type setOperatorType int
var (
ErrUnknownSetOperator = fmt.Errorf("%w: unexpected set operator type encountered", openfgaErrors.ErrUnknown)
ErrPanic = errors.New("panic captured")
)
const (
unionSetOperator setOperatorType = iota
intersectionSetOperator
exclusionSetOperator
)
type checkOutcome struct {
resp *ResolveCheckResponse
err error
}
type LocalChecker struct {
delegate CheckResolver
concurrencyLimit int
upstreamTimeout time.Duration
planner planner.Manager
logger logger.Logger
optimizationsEnabled bool
maxResolutionDepth uint32
}
type LocalCheckerOption func(d *LocalChecker)
// WithResolveNodeBreadthLimit see server.WithResolveNodeBreadthLimit.
func WithResolveNodeBreadthLimit(limit uint32) LocalCheckerOption {
return func(d *LocalChecker) {
d.concurrencyLimit = int(limit)
}
}
// WithOptimizations toggles experimental check optimizations.
func WithOptimizations(enabled bool) LocalCheckerOption {
return func(d *LocalChecker) {
d.optimizationsEnabled = enabled
}
}
// WithPlanner sets the planner used to select check resolution strategies.
func WithPlanner(p planner.Manager) LocalCheckerOption {
return func(d *LocalChecker) {
d.planner = p
}
}
// WithLocalCheckerLogger sets the logger for the LocalChecker.
func WithLocalCheckerLogger(logger logger.Logger) LocalCheckerOption {
return func(d *LocalChecker) {
d.logger = logger
}
}
// WithMaxResolutionDepth sets the maximum dispatch depth before a check is aborted with ErrResolutionDepthExceeded.
func WithMaxResolutionDepth(depth uint32) LocalCheckerOption {
return func(d *LocalChecker) {
d.maxResolutionDepth = depth
}
}
// WithUpstreamTimeout sets the upstream request timeout, used to penalize plans that time out.
func WithUpstreamTimeout(timeout time.Duration) LocalCheckerOption {
return func(d *LocalChecker) {
d.upstreamTimeout = timeout
}
}
// NewLocalChecker constructs a LocalChecker that can be used to evaluate a Check
// request locally.
//
// Developers wanting a LocalChecker with other optional layers (e.g. caching and others)
// are encouraged to use [NewOrderedCheckResolvers] instead.
func NewLocalChecker(opts ...LocalCheckerOption) *LocalChecker {
checker := &LocalChecker{
concurrencyLimit: serverconfig.DefaultResolveNodeBreadthLimit,
maxResolutionDepth: serverconfig.DefaultResolveNodeLimit,
upstreamTimeout: serverconfig.DefaultRequestTimeout,
logger: logger.NewNoopLogger(),
planner: planner.NewNoopPlanner(),
}
// by default, a LocalChecker delegates/dispatches subproblems to itself (e.g. local dispatch) unless otherwise configured.
checker.delegate = checker
for _, opt := range opts {
opt(checker)
}
return checker
}
// SetDelegate sets this LocalChecker's dispatch delegate.
func (c *LocalChecker) SetDelegate(delegate CheckResolver) {
c.delegate = delegate
}
// GetDelegate returns this LocalChecker's dispatch delegate.
func (c *LocalChecker) GetDelegate() CheckResolver {
return c.delegate
}
// CheckHandlerFunc defines a function that evaluates a CheckResponse or returns an error
// otherwise.
type CheckHandlerFunc func(ctx context.Context) (*ResolveCheckResponse, error)
// CheckFuncReducer defines a function that combines or reduces one or more CheckHandlerFunc into
// a single CheckResponse with a maximum limit on the number of concurrent evaluations that can be
// in flight at any given time.
type CheckFuncReducer func(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (*ResolveCheckResponse, error)
// runHandler safely executes a CheckHandlerFunc, recovers from any panics,
// and returns the result as a checkOutcome.
func runHandler(ctx context.Context, handler CheckHandlerFunc) checkOutcome {
var res *ResolveCheckResponse
var err error
recoveredErr := panics.Try(func() {
res, err = handler(ctx)
})
if recoveredErr != nil {
err = fmt.Errorf("%w: %s", ErrPanic, recoveredErr.AsError())
}
return checkOutcome{res, err}
}
// union implements a CheckFuncReducer that requires any of the provided CheckHandlerFunc to resolve
// to an allowed outcome. The first allowed outcome causes premature termination of the reducer.
func union(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (resp *ResolveCheckResponse, err error) {
cancellableCtx, cancel := context.WithCancel(ctx)
defer cancel()
pool := concurrency.NewPool(cancellableCtx, concurrencyLimit)
out := make(chan checkOutcome, len(handlers))
for _, handler := range handlers {
h := handler
pool.Go(func(ctx context.Context) error {
concurrency.TrySendThroughChannel(cancellableCtx, runHandler(ctx, h), out)
return nil
})
}
go func() {
_ = pool.Wait()
close(out)
}()
var finalErr error
finalResult := &ResolveCheckResponse{Allowed: false}
for i := 0; i < len(handlers); i++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case outcome, ok := <-out:
if !ok {
break
}
if outcome.err != nil {
finalErr = outcome.err
continue // Continue to see if we find an 'Allowed: true'
}
if outcome.resp.GetResolutionMetadata().CycleDetected {
finalResult.ResolutionMetadata.CycleDetected = true
}
if outcome.resp.Allowed {
// Short-circuit success. defer cancel() will clean up workers.
return outcome.resp, nil
}
}
}
if finalErr != nil {
return nil, finalErr
}
return finalResult, nil
}
// intersection implements a CheckFuncReducer that requires all of the provided CheckHandlerFunc to resolve
// to an allowed outcome. The first falsey outcome causes premature termination of the reducer. Errors are swallowed if there is a false outcome.
func intersection(ctx context.Context, concurrencyLimit int, handlers ...CheckHandlerFunc) (resp *ResolveCheckResponse, err error) {
if len(handlers) < 2 {
return nil, fmt.Errorf("%w, expected at least two rewrite operands for intersection operator, but got '%d'", openfgaErrors.ErrUnknown, len(handlers))
}
cancellableCtx, cancel := context.WithCancel(ctx)
defer cancel()
pool := concurrency.NewPool(cancellableCtx, concurrencyLimit)
out := make(chan checkOutcome, len(handlers))
for _, handler := range handlers {
h := handler // Capture loop variable for the goroutine
pool.Go(func(ctx context.Context) error {
concurrency.TrySendThroughChannel(cancellableCtx, runHandler(ctx, h), out)
return nil
})
}
go func() {
_ = pool.Wait()
close(out)
}()
var finalErr error
finalResult := &ResolveCheckResponse{
Allowed: true,
}
for i := 0; i < len(handlers); i++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case outcome, ok := <-out:
if !ok {
break
}
if outcome.err != nil {
// Store the first error we see, but don't exit yet.
// A definitive 'false' result from another handler can override this.
if finalErr == nil {
finalErr = outcome.err
}
continue
}
if outcome.resp.GetResolutionMetadata().CycleDetected || !outcome.resp.Allowed {
// Short-circuit failure. defer cancel() will clean up workers.
finalResult.Allowed = false
finalResult.ResolutionMetadata.CycleDetected = outcome.resp.GetResolutionMetadata().CycleDetected
return finalResult, nil
}
}
}
// If we've processed all handlers without a definitive 'false',
// then any error we encountered along the way is the final result.
if finalErr != nil {
return nil, finalErr
}
// If the loop completes without any "false" outcomes or errors, the result is "true".
return finalResult, nil
}
// exclusion implements a CheckFuncReducer that requires a 'base' CheckHandlerFunc to resolve to an allowed
// outcome and a 'sub' CheckHandlerFunc to resolve to a falsey outcome.
func exclusion(ctx context.Context, _ int, handlers ...CheckHandlerFunc) (*ResolveCheckResponse, error) {
if len(handlers) != 2 {
return nil, fmt.Errorf("%w, expected two rewrite operands for exclusion operator, but got '%d'", openfgaErrors.ErrUnknown, len(handlers))
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
baseChan := make(chan checkOutcome, 1)
subChan := make(chan checkOutcome, 1)
go func() {
concurrency.TrySendThroughChannel(ctx, runHandler(ctx, handlers[0]), baseChan)
close(baseChan)
}()
go func() {
concurrency.TrySendThroughChannel(ctx, runHandler(ctx, handlers[1]), subChan)
close(subChan)
}()
var baseErr, subErr error
// Loop until we have received one result from each of the two channels.
resultsReceived := 0
for resultsReceived < 2 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case res, ok := <-baseChan:
if !ok {
baseChan = nil // Stop selecting this case.
continue
}
resultsReceived++
if res.err != nil {
baseErr = res.err
continue
}
// Short-circuit: If base is false, the whole expression is false.
if res.resp.GetCycleDetected() || !res.resp.GetAllowed() {
return &ResolveCheckResponse{Allowed: false, ResolutionMetadata: ResolveCheckResponseMetadata{CycleDetected: res.resp.GetCycleDetected()}}, nil
}
case res, ok := <-subChan:
if !ok {
subChan = nil // Stop selecting this case.
continue
}
resultsReceived++
if res.err != nil {
subErr = res.err
continue
}
// Short-circuit: If subtract is true, the whole expression is false.
if res.resp.GetCycleDetected() || res.resp.GetAllowed() {
return &ResolveCheckResponse{Allowed: false, ResolutionMetadata: ResolveCheckResponseMetadata{CycleDetected: res.resp.GetCycleDetected()}}, nil
}
}
}
// At this point, we are guaranteed to have both results (or to have already short-circuited).
if baseErr != nil {
return nil, baseErr
}
if subErr != nil {
return nil, subErr
}
// The only way to get here is if base was (Allowed: true) and subtract was (Allowed: false).
return &ResolveCheckResponse{Allowed: true}, nil
}
// Close is a noop.
func (c *LocalChecker) Close() {
}
// dispatch clones the parent request, modifies its metadata and tupleKey, and dispatches the new request
// to the CheckResolver this LocalChecker was constructed with.
func (c *LocalChecker) dispatch(_ context.Context, parentReq *ResolveCheckRequest, tk *openfgav1.TupleKey) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
parentReq.GetRequestMetadata().DispatchCounter.Add(1)
childRequest := parentReq.clone()
childRequest.TupleKey = tk
childRequest.GetRequestMetadata().Depth++
resp, err := c.delegate.ResolveCheck(ctx, childRequest)
if err != nil {
return nil, err
}
return resp, nil
}
}
var _ CheckResolver = (*LocalChecker)(nil)
// ResolveCheck implements [CheckResolver.ResolveCheck].
func (c *LocalChecker) ResolveCheck(
ctx context.Context,
req *ResolveCheckRequest,
) (*ResolveCheckResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
ctx, span := tracer.Start(ctx, "ResolveCheck", trace.WithAttributes(
attribute.String("store_id", req.GetStoreID()),
attribute.String("resolver_type", "LocalChecker"),
attribute.String("tuple_key", tuple.TupleKeyWithConditionToString(req.GetTupleKey())),
))
defer span.End()
if req.GetRequestMetadata().Depth == c.maxResolutionDepth {
return nil, ErrResolutionDepthExceeded
}
cycle := c.hasCycle(req)
if cycle {
span.SetAttributes(attribute.Bool("cycle_detected", true))
return &ResolveCheckResponse{
Allowed: false,
ResolutionMetadata: ResolveCheckResponseMetadata{
CycleDetected: true,
},
}, nil
}
tupleKey := req.GetTupleKey()
object := tupleKey.GetObject()
relation := tupleKey.GetRelation()
if tuple.IsSelfDefining(req.GetTupleKey()) {
return &ResolveCheckResponse{
Allowed: true,
}, nil
}
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
_, ok = storage.RelationshipTupleReaderFromContext(ctx)
if !ok {
return nil, fmt.Errorf("%w: relationship tuple reader datastore missing in context", openfgaErrors.ErrUnknown)
}
objectType, _ := tuple.SplitObject(object)
rel, err := typesys.GetRelation(objectType, relation)
if err != nil {
return nil, fmt.Errorf("relation '%s' undefined for object type '%s'", relation, objectType)
}
hasPath, err := typesys.PathExists(tupleKey.GetUser(), relation, objectType)
if err != nil {
return nil, err
}
if !hasPath {
return &ResolveCheckResponse{
Allowed: false,
}, nil
}
resp, err := c.CheckRewrite(ctx, req, rel.GetRewrite())(ctx)
if err != nil {
telemetry.TraceError(span, err)
return nil, err
}
return resp, nil
}
// hasCycle returns true if a cycle has been found. It modifies the request object.
func (c *LocalChecker) hasCycle(req *ResolveCheckRequest) bool {
key := tuple.TupleKeyToString(req.GetTupleKey())
if req.VisitedPaths == nil {
req.VisitedPaths = map[string]struct{}{}
}
_, cycleDetected := req.VisitedPaths[key]
if cycleDetected {
return true
}
req.VisitedPaths[key] = struct{}{}
return false
}
func (c *LocalChecker) checkPublicAssignable(ctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
typesys, _ := typesystem.TypesystemFromContext(ctx)
ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
storeID := req.GetStoreID()
reqTupleKey := req.GetTupleKey()
userType := tuple.GetType(reqTupleKey.GetUser())
wildcardRelationReference := typesystem.WildcardRelationReference(userType)
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkPublicAssignable")
defer span.End()
response := &ResolveCheckResponse{
Allowed: false,
}
opts := storage.ReadUsersetTuplesOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
// We query via ReadUsersetTuples instead of ReadUserTuple to take
// advantage of the storage wrapper cache
// (https://github.com/openfga/openfga/blob/af054d9693bd7ebd0420456b144c2fb6888aaf87/internal/graph/storagewrapper.go#L139).
// In the future, if the storage wrapper cache becomes available for ReadUserTuple, we can switch to it.
iter, err := ds.ReadUsersetTuples(ctx, storeID, storage.ReadUsersetTuplesFilter{
Object: reqTupleKey.GetObject(),
Relation: reqTupleKey.GetRelation(),
AllowedUserTypeRestrictions: []*openfgav1.RelationReference{wildcardRelationReference},
}, opts)
if err != nil {
return nil, err
}
filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
),
checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
)
defer filteredIter.Stop()
_, err = filteredIter.Next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
return response, nil
}
return nil, err
}
// reaching this point means a public wildcard tuple is assigned
span.SetAttributes(attribute.Bool("allowed", true))
response.Allowed = true
return response, nil
}
}
func (c *LocalChecker) checkDirectUserTuple(ctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
typesys, _ := typesystem.TypesystemFromContext(ctx)
reqTupleKey := req.GetTupleKey()
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkDirectUserTuple",
trace.WithAttributes(attribute.String("tuple_key", tuple.TupleKeyWithConditionToString(reqTupleKey))))
defer span.End()
response := &ResolveCheckResponse{
Allowed: false,
}
ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
storeID := req.GetStoreID()
opts := storage.ReadUserTupleOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
t, err := ds.ReadUserTuple(ctx, storeID, reqTupleKey, opts)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return response, nil
}
return nil, err
}
// filter out invalid tuples yielded by the database query
tupleKey := t.GetKey()
err = validation.ValidateTupleForRead(typesys, tupleKey)
if err != nil {
return response, nil
}
tupleKeyConditionFilter := checkutil.BuildTupleKeyConditionFilter(ctx, req.Context, typesys)
conditionMet, err := tupleKeyConditionFilter(tupleKey)
if err != nil {
telemetry.TraceError(span, err)
return nil, err
}
if conditionMet {
span.SetAttributes(attribute.Bool("allowed", true))
response.Allowed = true
}
return response, nil
}
}
// helper function to return whether checkDirectUserTuple should run.
func shouldCheckDirectTuple(ctx context.Context, reqTupleKey *openfgav1.TupleKey) bool {
typesys, _ := typesystem.TypesystemFromContext(ctx)
objectType := tuple.GetType(reqTupleKey.GetObject())
relation := reqTupleKey.GetRelation()
isDirectlyRelated, _ := typesys.IsDirectlyRelated(
typesystem.DirectRelationReference(objectType, relation), // target
typesystem.DirectRelationReference(tuple.GetType(reqTupleKey.GetUser()), tuple.GetRelation(reqTupleKey.GetUser())), // source
)
return isDirectlyRelated
}
// helper function to return whether checkPublicAssignable should run.
func shouldCheckPublicAssignable(ctx context.Context, reqTupleKey *openfgav1.TupleKey) bool {
typesys, _ := typesystem.TypesystemFromContext(ctx)
objectType := tuple.GetType(reqTupleKey.GetObject())
relation := reqTupleKey.GetRelation()
// if the user in the tuple is a userset, by definition it cannot be a wildcard
if tuple.IsObjectRelation(reqTupleKey.GetUser()) {
return false
}
isPubliclyAssignable, _ := typesys.IsPubliclyAssignable(
typesystem.DirectRelationReference(objectType, relation), // target
tuple.GetType(reqTupleKey.GetUser()),
)
return isPubliclyAssignable
}
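// For example (sketch), given a relation `define viewer: [user, user:*, group#member]`:
// a check for user:anne runs both checkDirectUserTuple and checkPublicAssignable,
// while membership granted through group#member tuples is handled by
// checkDirectUsersetTuples below.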
func (c *LocalChecker) profiledCheckHandler(keyPlan planner.Selector, strategy *planner.PlanConfig, resolver CheckHandlerFunc) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
start := time.Now()
res, err := resolver(ctx)
if err != nil {
// penalize plans that timeout from the upstream context
if errors.Is(err, context.DeadlineExceeded) {
keyPlan.UpdateStats(strategy, c.upstreamTimeout)
}
return nil, err
}
keyPlan.UpdateStats(strategy, time.Since(start))
return res, nil
}
}
func (c *LocalChecker) checkDirectUsersetTuples(ctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
typesys, _ := typesystem.TypesystemFromContext(ctx)
reqTupleKey := req.GetTupleKey()
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkDirectUsersetTuples", trace.WithAttributes(
attribute.String("userset", tuple.ToObjectRelationString(reqTupleKey.GetObject(), reqTupleKey.GetRelation())),
))
defer span.End()
objectType, relation := tuple.GetType(reqTupleKey.GetObject()), reqTupleKey.GetRelation()
userType := tuple.GetType(reqTupleKey.GetUser())
directlyRelatedUsersetTypes, _ := typesys.DirectlyRelatedUsersets(objectType, relation)
isUserset := tuple.IsObjectRelation(reqTupleKey.GetUser())
// if the user in the request is a userset, there are no additional strategies to apply
if isUserset {
iter, err := checkutil.IteratorReadUsersetTuples(ctx, req, directlyRelatedUsersetTypes)
if err != nil {
return nil, err
}
defer iter.Stop()
return c.defaultUserset(ctx, req, directlyRelatedUsersetTypes, iter)(ctx)
}
possibleStrategies := map[string]*planner.PlanConfig{
defaultResolver: defaultPlan,
}
var b strings.Builder
b.WriteString("userset|")
b.WriteString(req.GetAuthorizationModelID())
b.WriteString("|")
b.WriteString(objectType)
b.WriteString("|")
b.WriteString(relation)
b.WriteString("|")
b.WriteString(userType)
b.WriteString("|")
// if the type#relation is resolvable recursively, then it can only be resolved recursively
if typesys.UsersetUseRecursiveResolver(objectType, relation, userType) {
iter, err := checkutil.IteratorReadUsersetTuples(ctx, req, directlyRelatedUsersetTypes)
if err != nil {
return nil, err
}
defer iter.Stop()
b.WriteString("infinite")
key := b.String()
keyPlan := c.planner.GetPlanSelector(key)
possibleStrategies[defaultResolver] = defaultRecursivePlan
possibleStrategies[recursiveResolver] = recursivePlan
plan := keyPlan.Select(possibleStrategies)
resolver := c.defaultUserset
if plan.Name == recursiveResolver {
resolver = c.recursiveUserset
}
return c.profiledCheckHandler(keyPlan, plan, resolver(ctx, req, directlyRelatedUsersetTypes, iter))(ctx)
}
var resolvers []CheckHandlerFunc
var remainingUsersetTypes []*openfgav1.RelationReference
keyPlanPrefix := b.String()
possibleStrategies[weightTwoResolver] = weight2Plan
for _, userset := range directlyRelatedUsersetTypes {
if !typesys.UsersetUseWeight2Resolver(objectType, relation, userType, userset) {
remainingUsersetTypes = append(remainingUsersetTypes, userset)
continue
}
usersets := []*openfgav1.RelationReference{userset}
iter, err := checkutil.IteratorReadUsersetTuples(ctx, req, usersets)
if err != nil {
return nil, err
}
// NOTE: we collect defers given that the iterator won't be consumed until `union` resolves at the end.
defer iter.Stop()
var k strings.Builder
k.WriteString(keyPlanPrefix)
k.WriteString("userset|")
k.WriteString(userset.String())
key := k.String()
keyPlan := c.planner.GetPlanSelector(key)
strategy := keyPlan.Select(possibleStrategies)
resolver := c.defaultUserset
if strategy.Name == weightTwoResolver {
resolver = c.weight2Userset
}
resolvers = append(resolvers, c.profiledCheckHandler(keyPlan, strategy, resolver(ctx, req, usersets, iter)))
}
// any usersets that could not be resolved through the weight-2 resolver are resolved through the default resolver,
// as a single group rather than individually.
if len(remainingUsersetTypes) > 0 {
iter, err := checkutil.IteratorReadUsersetTuples(ctx, req, remainingUsersetTypes)
if err != nil {
return nil, err
}
defer iter.Stop()
resolvers = append(resolvers, c.defaultUserset(ctx, req, remainingUsersetTypes, iter))
}
return union(ctx, c.concurrencyLimit, resolvers...)
}
}
// checkDirect composes three CheckHandlerFunc which evaluate direct relationships with the provided
// 'object#relation'. The first handler looks up direct matches on the provided 'object#relation@user',
// the second handler looks up wildcard matches on the provided 'object#relation@user:*',
// while the third handler looks up relationships between the target 'object#relation' and any usersets
// related to it.
func (c *LocalChecker) checkDirect(parentctx context.Context, req *ResolveCheckRequest) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkDirect")
defer span.End()
typesys, _ := typesystem.TypesystemFromContext(parentctx) // note: use of 'parentctx' not 'ctx' - this is important
reqTupleKey := req.GetTupleKey()
objectType := tuple.GetType(reqTupleKey.GetObject())
relation := reqTupleKey.GetRelation()
// directlyRelatedUsersetTypes could be "group#member"
directlyRelatedUsersetTypes, _ := typesys.DirectlyRelatedUsersets(objectType, relation)
var checkFuncs []CheckHandlerFunc
if shouldCheckDirectTuple(ctx, req.GetTupleKey()) {
checkFuncs = []CheckHandlerFunc{c.checkDirectUserTuple(parentctx, req)}
}
if shouldCheckPublicAssignable(ctx, reqTupleKey) {
checkFuncs = append(checkFuncs, c.checkPublicAssignable(parentctx, req))
}
if len(directlyRelatedUsersetTypes) > 0 {
checkFuncs = append(checkFuncs, c.checkDirectUsersetTuples(parentctx, req))
}
resp, err := union(ctx, c.concurrencyLimit, checkFuncs...)
if err != nil {
telemetry.TraceError(span, err)
return nil, err
}
return resp, nil
}
}
// checkComputedUserset evaluates the Check request with the rewritten relation (e.g. the computed userset relation).
func (c *LocalChecker) checkComputedUserset(_ context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) CheckHandlerFunc {
rewrittenTupleKey := tuple.NewTupleKey(
req.GetTupleKey().GetObject(),
rewrite.GetComputedUserset().GetRelation(),
req.GetTupleKey().GetUser(),
)
childRequest := req.clone()
childRequest.TupleKey = rewrittenTupleKey
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkComputedUserset")
defer span.End()
// No dispatch here, as we don't want to increase resolution depth.
return c.ResolveCheck(ctx, childRequest)
}
}
// checkTTU looks up all tuples of the target tupleset relation on the provided object and for each one
// of them evaluates the computed userset of the TTU rewrite rule for them.
func (c *LocalChecker) checkTTU(parentctx context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "checkTTU")
defer span.End()
typesys, _ := typesystem.TypesystemFromContext(parentctx) // note: use of 'parentctx' not 'ctx' - this is important
ds, _ := storage.RelationshipTupleReaderFromContext(parentctx)
objectType, relation := tuple.GetType(req.GetTupleKey().GetObject()), req.GetTupleKey().GetRelation()
userType := tuple.GetType(req.GetTupleKey().GetUser())
ctx = typesystem.ContextWithTypesystem(ctx, typesys)
ctx = storage.ContextWithRelationshipTupleReader(ctx, ds)
tuplesetRelation := rewrite.GetTupleToUserset().GetTupleset().GetRelation()
computedRelation := rewrite.GetTupleToUserset().GetComputedUserset().GetRelation()
tk := req.GetTupleKey()
object := tk.GetObject()
span.SetAttributes(
attribute.String("tupleset_relation", tuple.ToObjectRelationString(tuple.GetType(object), tuplesetRelation)),
attribute.String("computed_relation", computedRelation),
)
opts := storage.ReadOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
storeID := req.GetStoreID()
iter, err := ds.Read(
ctx,
storeID,
storage.ReadFilter{Object: object, Relation: tuplesetRelation, User: ""},
opts,
)
if err != nil {
return nil, err
}
// filter out invalid tuples yielded by the database iterator
filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
),
checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
)
defer filteredIter.Stop()
resolver := c.defaultTTU
possibleStrategies := map[string]*planner.PlanConfig{
defaultResolver: defaultPlan,
}
isUserset := tuple.IsObjectRelation(tk.GetUser())
if !isUserset {
if typesys.TTUUseWeight2Resolver(objectType, relation, userType, rewrite.GetTupleToUserset()) {
possibleStrategies[weightTwoResolver] = weight2Plan
resolver = c.weight2TTU
} else if typesys.TTUUseRecursiveResolver(objectType, relation, userType, rewrite.GetTupleToUserset()) {
possibleStrategies[defaultResolver] = defaultRecursivePlan
possibleStrategies[recursiveResolver] = recursivePlan
resolver = c.recursiveTTU
}
}
if len(possibleStrategies) == 1 {
// short circuit, no additional resolvers are available
return resolver(ctx, req, rewrite, filteredIter)(ctx)
}
var b strings.Builder
b.WriteString("ttu|")
b.WriteString(req.GetAuthorizationModelID())
b.WriteString("|")
b.WriteString(objectType)
b.WriteString("|")
b.WriteString(relation)
b.WriteString("|")
b.WriteString(userType)
b.WriteString("|")
b.WriteString(tuplesetRelation)
b.WriteString("|")
b.WriteString(computedRelation)
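// resulting key shape: "ttu|<model-id>|<object-type>|<relation>|<user-type>|<tupleset-relation>|<computed-relation>"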
planKey := b.String()
keyPlan := c.planner.GetPlanSelector(planKey)
strategy := keyPlan.Select(possibleStrategies)
switch strategy.Name {
case defaultResolver:
resolver = c.defaultTTU
case weightTwoResolver:
resolver = c.weight2TTU
case recursiveResolver:
resolver = c.recursiveTTU
}
return c.profiledCheckHandler(keyPlan, strategy, resolver(ctx, req, rewrite, filteredIter))(ctx)
}
}
func (c *LocalChecker) checkSetOperation(
ctx context.Context,
req *ResolveCheckRequest,
setOpType setOperatorType,
reducer CheckFuncReducer,
children ...*openfgav1.Userset,
) CheckHandlerFunc {
var handlers []CheckHandlerFunc
var reducerKey string
switch setOpType {
case unionSetOperator:
reducerKey = "union"
case intersectionSetOperator:
reducerKey = "intersection"
case exclusionSetOperator:
reducerKey = "exclusion"
default:
return func(ctx context.Context) (*ResolveCheckResponse, error) {
return nil, ErrUnknownSetOperator
}
}
for _, child := range children {
handlers = append(handlers, c.CheckRewrite(ctx, req, child))
}
return func(ctx context.Context) (*ResolveCheckResponse, error) {
var err error
var resp *ResolveCheckResponse
ctx, span := tracer.Start(ctx, reducerKey)
defer func() {
if err != nil {
telemetry.TraceError(span, err)
}
span.End()
}()
resp, err = reducer(ctx, c.concurrencyLimit, handlers...)
return resp, err
}
}
func (c *LocalChecker) CheckRewrite(
ctx context.Context,
req *ResolveCheckRequest,
rewrite *openfgav1.Userset,
) CheckHandlerFunc {
switch rw := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This:
return c.checkDirect(ctx, req)
case *openfgav1.Userset_ComputedUserset:
return c.checkComputedUserset(ctx, req, rewrite)
case *openfgav1.Userset_TupleToUserset:
return c.checkTTU(ctx, req, rewrite)
case *openfgav1.Userset_Union:
return c.checkSetOperation(ctx, req, unionSetOperator, union, rw.Union.GetChild()...)
case *openfgav1.Userset_Intersection:
return c.checkSetOperation(ctx, req, intersectionSetOperator, intersection, rw.Intersection.GetChild()...)
case *openfgav1.Userset_Difference:
return c.checkSetOperation(ctx, req, exclusionSetOperator, exclusion, rw.Difference.GetBase(), rw.Difference.GetSubtract())
default:
return func(ctx context.Context) (*ResolveCheckResponse, error) {
return nil, ErrUnknownSetOperator
}
}
}
// TODO: make these subsequent functions generic and move outside this package.
type usersetMessage struct {
userset string
err error
}
// streamedLookupUsersetFromIterator returns a channel with all the usersets given by the input iterator.
// It closes the channel when it is done.
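// A minimal consumption sketch (mapper and handle are illustrative assumptions):
//
//	for msg := range streamedLookupUsersetFromIterator(ctx, mapper) {
//		if msg.err != nil {
//			return msg.err
//		}
//		handle(msg.userset)
//	}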
func streamedLookupUsersetFromIterator(ctx context.Context, iter storage.TupleMapper) <-chan usersetMessage {
usersetMessageChan := make(chan usersetMessage, 100)
go func() {
defer func() {
if r := recover(); r != nil {
concurrency.TrySendThroughChannel(ctx, usersetMessage{err: fmt.Errorf("%w: %s", ErrPanic, r)}, usersetMessageChan)
}
close(usersetMessageChan)
}()
for {
res, err := iter.Next(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
return
}
concurrency.TrySendThroughChannel(ctx, usersetMessage{err: err}, usersetMessageChan)
return
}
concurrency.TrySendThroughChannel(ctx, usersetMessage{userset: res}, usersetMessageChan)
}
}()
return usersetMessageChan
}
// processUsersetMessage adds the userset to the primarySet.
// In addition, it returns whether the userset exists in the secondarySet.
// This is used to find the intersection between the usersets reached from the user and those reached from the object.
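// A minimal sketch of the two-sided handshake (sets and values are illustrative):
//
//	fromUser, fromObject := hashset.New(), hashset.New()
//	fromUser.Add("group:1#member") // reached from the user side
//	found := processUsersetMessage("group:1#member", fromObject, fromUser)
//	// found == true: the userset was reached from both directions.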
func processUsersetMessage(userset string,
primarySet *hashset.Set,
secondarySet *hashset.Set) bool {
primarySet.Add(userset)
return secondarySet.Contains(userset)
}
package graph
import (
"context"
"errors"
"fmt"
"time"
"github.com/sourcegraph/conc/panics"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const defaultResolver = "default"
var defaultPlan = &planner.PlanConfig{
Name: defaultResolver,
InitialGuess: 50 * time.Millisecond,
// Low Lambda: Represents zero confidence. It's a pure guess.
Lambda: 1,
// With α = 0.5 ≤ 1 we encode maximum uncertainty about the variance; with λ = 1 we also have weak confidence in the mean.
// These values encourage strong exploration of other strategies. Using them for the default strategy helps steer the planner
// toward the "faster" strategies, easing the cold start when we don't have enough data.
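// Concretely, the prior mean precision is E[τ] = α/β = 0.5/0.5 = 1.0, and with λ = 1
// it carries the weight of roughly one pseudo-observation, so real measurements
// quickly dominate the posterior.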
Alpha: 0.5,
Beta: 0.5,
}
var defaultRecursivePlan = &planner.PlanConfig{
Name: defaultResolver,
InitialGuess: 300 * time.Millisecond, // Higher initial guess for recursive checks
// Low Lambda: Represents zero confidence. It's a pure guess.
Lambda: 1,
// With α = 0.5 ≤ 1 we encode maximum uncertainty about the variance; with λ = 1 we also have weak confidence in the mean.
// These values encourage strong exploration of other strategies. Using them for the default strategy helps steer the planner
// toward the "faster" strategies, easing the cold start when we don't have enough data.
Alpha: 0.5,
Beta: 0.5,
}
type dispatchParams struct {
parentReq *ResolveCheckRequest
tk *openfgav1.TupleKey
}
type dispatchMsg struct {
err error
shortCircuit bool
dispatchParams *dispatchParams
}
// defaultUserset will check userset path.
// This is the slow path as it requires dispatch on all its children.
func (c *LocalChecker) defaultUserset(_ context.Context, req *ResolveCheckRequest, _ []*openfgav1.RelationReference, iter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "defaultUserset")
defer span.End()
dispatchChan := make(chan dispatchMsg, c.concurrencyLimit)
cancellableCtx, cancelFunc := context.WithCancel(ctx)
pool := concurrency.NewPool(cancellableCtx, 1)
defer func() {
cancelFunc()
// We always need to wait to avoid a goroutine leak.
_ = pool.Wait()
}()
pool.Go(func(ctx context.Context) error {
c.produceUsersetDispatches(ctx, req, dispatchChan, iter)
return nil
})
return c.consumeDispatches(ctx, c.concurrencyLimit, dispatchChan)
}
}
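// produceUsersetDispatches reads userset tuples from iter and emits one dispatch per userset.
// For example (illustrative types and IDs), the tuple document:1#viewer@group:eng#member
// produces a dispatched check of group:eng#member@<request user>.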
func (c *LocalChecker) produceUsersetDispatches(ctx context.Context, req *ResolveCheckRequest, dispatches chan dispatchMsg, iter storage.TupleKeyIterator) {
defer close(dispatches)
reqTupleKey := req.GetTupleKey()
typesys, _ := typesystem.TypesystemFromContext(ctx)
for {
t, err := iter.Next(ctx)
if err != nil {
// a cancelled iterator doesn't need to flush or send errors back to the main routine
if storage.IterIsDoneOrCancelled(err) {
break
}
concurrency.TrySendThroughChannel(ctx, dispatchMsg{err: err}, dispatches)
break
}
usersetObject, usersetRelation := tuple.SplitObjectRelation(t.GetUser())
// if the user value is a typed wildcard and the type of the wildcard
// matches the target user objectType, then we're done searching
if tuple.IsTypedWildcard(usersetObject) && typesystem.IsSchemaVersionSupported(typesys.GetSchemaVersion()) {
wildcardType := tuple.GetType(usersetObject)
if tuple.GetType(reqTupleKey.GetUser()) == wildcardType {
concurrency.TrySendThroughChannel(ctx, dispatchMsg{shortCircuit: true}, dispatches)
break
}
}
if usersetRelation != "" {
tupleKey := tuple.NewTupleKey(usersetObject, usersetRelation, reqTupleKey.GetUser())
concurrency.TrySendThroughChannel(ctx, dispatchMsg{dispatchParams: &dispatchParams{parentReq: req, tk: tupleKey}}, dispatches)
}
}
}
// defaultTTU is the slow path for checkTTU where we cannot short-circuit TTU evaluation and
// resort to dispatch check on its children.
func (c *LocalChecker) defaultTTU(_ context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset, iter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "defaultTTU")
defer span.End()
computedRelation := rewrite.GetTupleToUserset().GetComputedUserset().GetRelation()
dispatchChan := make(chan dispatchMsg, c.concurrencyLimit)
cancellableCtx, cancelFunc := context.WithCancel(ctx)
// dispatch messages are buffered on the channel (up to concurrencyLimit) and consumed by the membership checks below.
pool := concurrency.NewPool(cancellableCtx, 1)
defer func() {
cancelFunc()
// We always need to wait to avoid a goroutine leak.
_ = pool.Wait()
}()
pool.Go(func(ctx context.Context) error {
c.produceTTUDispatches(ctx, computedRelation, req, dispatchChan, iter)
return nil
})
return c.consumeDispatches(ctx, c.concurrencyLimit, dispatchChan)
}
}
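// produceTTUDispatches reads tupleset tuples from iter and emits one dispatch per parent object.
// For example (illustrative types and IDs), with computedRelation "viewer" the tuple
// document:1#parent@folder:x produces a dispatched check of folder:x#viewer@<request user>.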
func (c *LocalChecker) produceTTUDispatches(ctx context.Context, computedRelation string, req *ResolveCheckRequest, dispatches chan dispatchMsg, iter storage.TupleKeyIterator) {
defer close(dispatches)
reqTupleKey := req.GetTupleKey()
typesys, _ := typesystem.TypesystemFromContext(ctx)
for {
t, err := iter.Next(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
break
}
concurrency.TrySendThroughChannel(ctx, dispatchMsg{err: err}, dispatches)
break
}
userObj, _ := tuple.SplitObjectRelation(t.GetUser())
if _, err := typesys.GetRelation(tuple.GetType(userObj), computedRelation); err != nil {
if errors.Is(err, typesystem.ErrRelationUndefined) {
continue // skip computed relations on tupleset relationships if they are undefined
}
}
tupleKey := &openfgav1.TupleKey{
Object: userObj,
Relation: computedRelation,
User: reqTupleKey.GetUser(),
}
concurrency.TrySendThroughChannel(ctx, dispatchMsg{dispatchParams: &dispatchParams{parentReq: req, tk: tupleKey}}, dispatches)
}
}
func (c *LocalChecker) consumeDispatches(ctx context.Context, limit int, dispatchChan chan dispatchMsg) (*ResolveCheckResponse, error) {
cancellableCtx, cancel := context.WithCancel(ctx)
outcomeChannel := c.processDispatches(cancellableCtx, limit, dispatchChan)
var finalErr error
finalResult := &ResolveCheckResponse{
Allowed: false,
}
ConsumerLoop:
for {
select {
case <-ctx.Done():
break ConsumerLoop
case outcome, ok := <-outcomeChannel:
if !ok {
break ConsumerLoop
}
if outcome.err != nil {
finalErr = outcome.err
break // continue
}
if outcome.resp.GetResolutionMetadata().CycleDetected {
finalResult.ResolutionMetadata.CycleDetected = true
}
if outcome.resp.Allowed {
finalErr = nil
finalResult = outcome.resp
break ConsumerLoop
}
}
}
cancel() // prevent further processing of other checks
// context cancellation from upstream (e.g. client)
if ctx.Err() != nil {
finalErr = ctx.Err()
}
if finalErr != nil {
return nil, finalErr
}
return finalResult, nil
}
// processDispatches returns a channel where the outcomes of the dispatched checks are sent, and begins sending messages to this channel.
func (c *LocalChecker) processDispatches(ctx context.Context, limit int, dispatchChan chan dispatchMsg) <-chan checkOutcome {
outcomes := make(chan checkOutcome, limit)
dispatchPool := concurrency.NewPool(ctx, limit)
go func() {
defer func() {
// We always need to wait to avoid a goroutine leak.
_ = dispatchPool.Wait()
close(outcomes)
}()
for {
select {
case <-ctx.Done():
return
case msg, ok := <-dispatchChan:
if !ok {
return
}
if msg.err != nil {
concurrency.TrySendThroughChannel(ctx, checkOutcome{err: msg.err}, outcomes)
break // continue
}
if msg.shortCircuit {
resp := &ResolveCheckResponse{
Allowed: true,
}
concurrency.TrySendThroughChannel(ctx, checkOutcome{resp: resp}, outcomes)
return
}
if msg.dispatchParams != nil {
dispatchPool.Go(func(ctx context.Context) error {
recoveredError := panics.Try(func() {
resp, err := c.dispatch(ctx, msg.dispatchParams.parentReq, msg.dispatchParams.tk)(ctx)
concurrency.TrySendThroughChannel(ctx, checkOutcome{resp: resp, err: err}, outcomes)
})
if recoveredError != nil {
concurrency.TrySendThroughChannel(
ctx,
checkOutcome{err: fmt.Errorf("%w: %s", ErrPanic, recoveredError.AsError())},
outcomes,
)
}
return nil
})
}
}
}
}()
return outcomes
}
package graph
import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"github.com/openfga/openfga/internal/throttler"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/pkg/server/config"
)
// DispatchThrottlingCheckResolverConfig encapsulates configuration for dispatch throttling check resolver.
type DispatchThrottlingCheckResolverConfig struct {
DefaultThreshold uint32
MaxThreshold uint32
}
// DispatchThrottlingCheckResolver will prioritize requests with fewer dispatches over
// requests with more dispatches.
// Initially, a request's dispatches are not throttled and are processed
// immediately. When the number of request dispatches rises above the DefaultThreshold, the dispatches are placed
// in the throttling queue. One item from the throttling queue is processed per tick.
// This allows a check / list objects request to be gradually throttled.
type DispatchThrottlingCheckResolver struct {
delegate CheckResolver
config *DispatchThrottlingCheckResolverConfig
throttler throttler.Throttler
}
var _ CheckResolver = (*DispatchThrottlingCheckResolver)(nil)
// DispatchThrottlingCheckResolverOpt defines an option that can be used to change the behavior of DispatchThrottlingCheckResolver
// instance.
type DispatchThrottlingCheckResolverOpt func(checkResolver *DispatchThrottlingCheckResolver)
// WithDispatchThrottlingCheckResolverConfig sets the config to be used for DispatchThrottlingCheckResolver.
func WithDispatchThrottlingCheckResolverConfig(config DispatchThrottlingCheckResolverConfig) DispatchThrottlingCheckResolverOpt {
return func(r *DispatchThrottlingCheckResolver) {
r.config = &config
}
}
// WithThrottler sets the throttler to be used for DispatchThrottlingCheckResolver.
func WithThrottler(throttler throttler.Throttler) DispatchThrottlingCheckResolverOpt {
return func(r *DispatchThrottlingCheckResolver) {
r.throttler = throttler
}
}
// WithConstantRateThrottler sets the constant rate throttler to be used for DispatchThrottlingCheckResolver.
func WithConstantRateThrottler(frequency time.Duration, metricLabel string) DispatchThrottlingCheckResolverOpt {
return func(r *DispatchThrottlingCheckResolver) {
r.throttler = throttler.NewConstantRateThrottler(frequency, metricLabel)
}
}
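// NewDispatchThrottlingCheckResolver constructs a DispatchThrottlingCheckResolver.
// A minimal construction sketch (the threshold values and metric label are
// illustrative assumptions):
//
//	r := NewDispatchThrottlingCheckResolver(
//		WithDispatchThrottlingCheckResolverConfig(DispatchThrottlingCheckResolverConfig{
//			DefaultThreshold: 100,
//			MaxThreshold:     1000,
//		}),
//		WithConstantRateThrottler(10*time.Millisecond, "check_dispatch"),
//	)
//	defer r.Close()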
func NewDispatchThrottlingCheckResolver(opts ...DispatchThrottlingCheckResolverOpt) *DispatchThrottlingCheckResolver {
dispatchThrottlingCheckResolver := &DispatchThrottlingCheckResolver{
config: &DispatchThrottlingCheckResolverConfig{
DefaultThreshold: config.DefaultCheckDispatchThrottlingDefaultThreshold,
MaxThreshold: config.DefaultCheckDispatchThrottlingMaxThreshold,
},
throttler: throttler.NewNoopThrottler(),
}
dispatchThrottlingCheckResolver.delegate = dispatchThrottlingCheckResolver
for _, opt := range opts {
opt(dispatchThrottlingCheckResolver)
}
return dispatchThrottlingCheckResolver
}
func (r *DispatchThrottlingCheckResolver) SetDelegate(delegate CheckResolver) {
r.delegate = delegate
}
func (r *DispatchThrottlingCheckResolver) GetDelegate() CheckResolver {
return r.delegate
}
func (r *DispatchThrottlingCheckResolver) Close() {
r.throttler.Close()
}
func (r *DispatchThrottlingCheckResolver) ResolveCheck(ctx context.Context,
req *ResolveCheckRequest,
) (*ResolveCheckResponse, error) {
span := trace.SpanFromContext(ctx)
currentNumDispatch := req.GetRequestMetadata().DispatchCounter.Load()
shouldThrottle := threshold.ShouldThrottle(
ctx,
currentNumDispatch,
r.config.DefaultThreshold,
r.config.MaxThreshold,
)
span.SetAttributes(
attribute.Int("dispatch_count", int(currentNumDispatch)),
attribute.Bool("is_throttled", shouldThrottle))
if shouldThrottle {
req.GetRequestMetadata().WasThrottled.Store(true)
r.throttler.Throttle(ctx)
}
return r.delegate.ResolveCheck(ctx, req)
}
// Package graph contains code related to evaluation of authorization models through graph traversals.
package graph
import (
"context"
"errors"
"fmt"
"strings"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
type ctxKey string
const (
resolutionDepthCtxKey ctxKey = "resolution-depth"
)
var (
ErrResolutionDepthExceeded = errors.New("resolution depth exceeded")
)
type findEdgeOption int
const (
resolveAllEdges findEdgeOption = iota
resolveAnyEdge
)
// ContextWithResolutionDepth attaches the provided graph resolution depth to the parent context.
func ContextWithResolutionDepth(parent context.Context, depth uint32) context.Context {
return context.WithValue(parent, resolutionDepthCtxKey, depth)
}
// ResolutionDepthFromContext returns the current graph resolution depth from the provided context (if any).
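// Example round-trip (illustrative):
//
//	ctx := ContextWithResolutionDepth(context.Background(), 3)
//	depth, ok := ResolutionDepthFromContext(ctx) // depth == 3, ok == true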
func ResolutionDepthFromContext(ctx context.Context) (uint32, bool) {
depth, ok := ctx.Value(resolutionDepthCtxKey).(uint32)
return depth, ok
}
type RelationshipEdgeType int
const (
// DirectEdge defines a direct connection between a source object reference
// and some target user reference.
DirectEdge RelationshipEdgeType = iota
// TupleToUsersetEdge defines a connection between a source object reference
// and some target user reference that is co-dependent upon the lookup of a third object reference.
TupleToUsersetEdge
// ComputedUsersetEdge defines a direct connection between a source object reference
// and some target user reference. The difference with DirectEdge is that DirectEdge will involve
// a read of tuples and this one will not.
ComputedUsersetEdge
)
func (r RelationshipEdgeType) String() string {
switch r {
case DirectEdge:
return "direct"
case ComputedUsersetEdge:
return "computed_userset"
case TupleToUsersetEdge:
return "ttu"
default:
return "undefined"
}
}
type EdgeCondition int
// RelationshipEdge represents a possible relationship between some source object reference
// and a target user reference. The possibility is realized depending on the tuples and on the edge's type.
type RelationshipEdge struct {
Type RelationshipEdgeType
// The edge is directed towards this node, which can be e.g. group:*, group, or group#member
TargetReference *openfgav1.RelationReference
// If the type is TupleToUsersetEdge, this defines the TTU condition
TuplesetRelation string
TargetReferenceInvolvesIntersectionOrExclusion bool
}
func (r RelationshipEdge) String() string {
// TODO also print the condition
var val string
if r.TuplesetRelation != "" {
val = fmt.Sprintf("userset %s, type %s, tupleset %s", r.TargetReference.String(), r.Type.String(), r.TuplesetRelation)
} else {
val = fmt.Sprintf("userset %s, type %s", r.TargetReference.String(), r.Type.String())
}
return strings.ReplaceAll(val, "  ", " ")
}
// RelationshipGraph represents a graph of relationships and the connectivity between
// object and relation references within the graph through direct or indirect relationships.
type RelationshipGraph struct {
typesystem *typesystem.TypeSystem
}
// New returns a RelationshipGraph from an authorization model. The RelationshipGraph should be used to introspect what kind of relationships between
// object types can exist. To visualize this graph, use https://github.com/jon-whit/openfga-graphviz-gen
func New(typesystem *typesystem.TypeSystem) *RelationshipGraph {
return &RelationshipGraph{
typesystem: typesystem,
}
}
// GetRelationshipEdges finds all paths from a source to a target and then returns all the edges at distance 0 or 1 of the source in those paths.
func (g *RelationshipGraph) GetRelationshipEdges(target *openfgav1.RelationReference, source *openfgav1.RelationReference) ([]*RelationshipEdge, error) {
return g.getRelationshipEdges(target, source, map[string]struct{}{}, resolveAllEdges)
}
// GetPrunedRelationshipEdges finds all paths from a source to a target and then returns all the edges at distance 0 or 1 of the source in those paths.
// If the edges from the source to the target pass through a relationship involving intersection or exclusion (directly or indirectly),
// then GetPrunedRelationshipEdges will return just the first edge involved in that rewrite.
//
// Consider the following model:
//
//	type user
//	type document
//	  relations
//	    define allowed: [user]
//	    define viewer: [user] and allowed
//
// The pruned relationship edges from the 'user' type to 'document#viewer' contain only the edge from 'user' to 'document#viewer', with a 'RequiresFurtherEvalCondition'.
// This is because when evaluating relationships involving intersection or exclusion we choose to evaluate only one operand of the rewrite rule, and for each result found
// we call Check on the result to evaluate the sub-condition on the 'and allowed' bit.
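// An illustrative call against the model above (a sketch; typesys is assumed
// to hold that model):
//
//	g := New(typesys)
//	edges, _ := g.GetPrunedRelationshipEdges(
//		typesystem.DirectRelationReference("document", "viewer"), // target
//		typesystem.DirectRelationReference("user", ""),           // source
//	)
//	// edges holds a single DirectEdge with
//	// TargetReferenceInvolvesIntersectionOrExclusion == true.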
func (g *RelationshipGraph) GetPrunedRelationshipEdges(target *openfgav1.RelationReference, source *openfgav1.RelationReference) ([]*RelationshipEdge, error) {
return g.getRelationshipEdges(target, source, map[string]struct{}{}, resolveAnyEdge)
}
func (g *RelationshipGraph) getRelationshipEdges(
target *openfgav1.RelationReference,
source *openfgav1.RelationReference,
visited map[string]struct{},
findEdgeOption findEdgeOption,
) ([]*RelationshipEdge, error) {
key := tuple.ToObjectRelationString(target.GetType(), target.GetRelation())
if _, ok := visited[key]; ok {
// We've already visited the target so no need to do so again.
return nil, nil
}
visited[key] = struct{}{}
relation, err := g.typesystem.GetRelation(target.GetType(), target.GetRelation())
if err != nil {
return nil, err
}
return g.getRelationshipEdgesWithTargetRewrite(
target,
source,
relation.GetRewrite(),
visited,
findEdgeOption,
)
}
// getRelationshipEdgesWithTargetRewrite does a BFS on the graph starting at `target` and trying to reach `source`.
func (g *RelationshipGraph) getRelationshipEdgesWithTargetRewrite(
target *openfgav1.RelationReference,
source *openfgav1.RelationReference,
targetRewrite *openfgav1.Userset,
visited map[string]struct{},
findEdgeOption findEdgeOption,
) ([]*RelationshipEdge, error) {
switch t := targetRewrite.GetUserset().(type) {
case *openfgav1.Userset_This: // e.g. define viewer:[user]
var res []*RelationshipEdge
directlyRelated, _ := g.typesystem.IsDirectlyRelated(target, source)
publiclyAssignable, _ := g.typesystem.IsPubliclyAssignable(target, source.GetType())
if directlyRelated || publiclyAssignable {
// if source=user, or define viewer:[user:*]
res = append(res, &RelationshipEdge{
Type: DirectEdge,
TargetReference: typesystem.DirectRelationReference(target.GetType(), target.GetRelation()),
TargetReferenceInvolvesIntersectionOrExclusion: false,
})
}
typeRestrictions, _ := g.typesystem.GetDirectlyRelatedUserTypes(target.GetType(), target.GetRelation())
for _, typeRestriction := range typeRestrictions {
if typeRestriction.GetRelation() != "" { // e.g. define viewer:[team#member]
// recursively sub-collect any edges for (team#member, source)
edges, err := g.getRelationshipEdges(typeRestriction, source, visited, findEdgeOption)
if err != nil {
return nil, err
}
res = append(res, edges...)
}
}
return res, nil
case *openfgav1.Userset_ComputedUserset: // e.g. target = define viewer: writer
var edges []*RelationshipEdge
// if source=document#writer
sourceRelMatchesRewritten := target.GetType() == source.GetType() && t.ComputedUserset.GetRelation() == source.GetRelation()
if sourceRelMatchesRewritten {
edges = append(edges, &RelationshipEdge{
Type: ComputedUsersetEdge,
TargetReference: typesystem.DirectRelationReference(target.GetType(), target.GetRelation()),
TargetReferenceInvolvesIntersectionOrExclusion: false,
})
}
collected, err := g.getRelationshipEdges(
typesystem.DirectRelationReference(target.GetType(), t.ComputedUserset.GetRelation()),
source,
visited,
findEdgeOption,
)
if err != nil {
return nil, err
}
edges = append(
edges,
collected...,
)
return edges, nil
case *openfgav1.Userset_TupleToUserset: // e.g. type document, define viewer: writer from parent
tupleset := t.TupleToUserset.GetTupleset().GetRelation() // parent
computedUserset := t.TupleToUserset.GetComputedUserset().GetRelation() // writer
var res []*RelationshipEdge
// e.g. type document, define parent:[user, group]
tuplesetTypeRestrictions, _ := g.typesystem.GetDirectlyRelatedUserTypes(target.GetType(), tupleset)
for _, typeRestriction := range tuplesetTypeRestrictions {
r, err := g.typesystem.GetRelation(typeRestriction.GetType(), computedUserset)
if err != nil {
if errors.Is(err, typesystem.ErrRelationUndefined) {
continue
}
return nil, err
}
if typeRestriction.GetType() == source.GetType() && computedUserset == source.GetRelation() {
involvesIntersection, err := g.typesystem.RelationInvolvesIntersection(typeRestriction.GetType(), r.GetName())
if err != nil {
return nil, err
}
involvesExclusion, err := g.typesystem.RelationInvolvesExclusion(typeRestriction.GetType(), r.GetName())
if err != nil {
return nil, err
}
res = append(res, &RelationshipEdge{
Type: TupleToUsersetEdge,
TargetReference: typesystem.DirectRelationReference(target.GetType(), target.GetRelation()),
TuplesetRelation: tupleset,
TargetReferenceInvolvesIntersectionOrExclusion: involvesIntersection || involvesExclusion,
})
}
subResults, err := g.getRelationshipEdges(
typesystem.DirectRelationReference(typeRestriction.GetType(), computedUserset),
source,
visited,
findEdgeOption,
)
if err != nil {
return nil, err
}
res = append(res, subResults...)
}
return res, nil
case *openfgav1.Userset_Union: // e.g. target = define viewer: self or writer
var res []*RelationshipEdge
for _, child := range t.Union.GetChild() {
// we recurse through each child rewrite
childResults, err := g.getRelationshipEdgesWithTargetRewrite(target, source, child, visited, findEdgeOption)
if err != nil {
return nil, err
}
res = append(res, childResults...)
}
return res, nil
case *openfgav1.Userset_Intersection:
if findEdgeOption == resolveAnyEdge {
child := t.Intersection.GetChild()[0]
childresults, err := g.getRelationshipEdgesWithTargetRewrite(target, source, child, visited, findEdgeOption)
if err != nil {
return nil, err
}
for _, childresult := range childresults {
childresult.TargetReferenceInvolvesIntersectionOrExclusion = true
}
return childresults, nil
}
var edges []*RelationshipEdge
for _, child := range t.Intersection.GetChild() {
res, err := g.getRelationshipEdgesWithTargetRewrite(target, source, child, visited, findEdgeOption)
if err != nil {
return nil, err
}
edges = append(edges, res...)
}
if len(edges) > 0 {
edges[0].TargetReferenceInvolvesIntersectionOrExclusion = true
}
return edges, nil
case *openfgav1.Userset_Difference:
if findEdgeOption == resolveAnyEdge {
// if we have 'a but not b', then we prune 'b' and only resolve 'a' with a
// condition that requires further evaluation. It's more likely the blacklist
// on 'but not b' is a larger set than the base set 'a', and so pruning the
// subtracted set is generally going to be a better choice.
child := t.Difference.GetBase()
childresults, err := g.getRelationshipEdgesWithTargetRewrite(target, source, child, visited, findEdgeOption)
if err != nil {
return nil, err
}
for _, childresult := range childresults {
childresult.TargetReferenceInvolvesIntersectionOrExclusion = true
}
return childresults, nil
}
var edges []*RelationshipEdge
baseRewrite := t.Difference.GetBase()
baseEdges, err := g.getRelationshipEdgesWithTargetRewrite(target, source, baseRewrite, visited, findEdgeOption)
if err != nil {
return nil, err
}
if len(baseEdges) > 0 {
baseEdges[0].TargetReferenceInvolvesIntersectionOrExclusion = true
}
edges = append(edges, baseEdges...)
subtractRewrite := t.Difference.GetSubtract()
subEdges, err := g.getRelationshipEdgesWithTargetRewrite(target, source, subtractRewrite, visited, findEdgeOption)
if err != nil {
return nil, err
}
edges = append(edges, subEdges...)
return edges, nil
default:
panic("unexpected userset rewrite encountered")
}
}
// Code generated by MockGen. DO NOT EDIT.
// Source: interface.go
//
// Generated by this command:
//
// mockgen -source interface.go -destination ./mock_check_resolver.go -package graph CheckResolver
//
// Package graph is a generated GoMock package.
package graph
import (
context "context"
reflect "reflect"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
gomock "go.uber.org/mock/gomock"
)
// MockCheckResolver is a mock of CheckResolver interface.
type MockCheckResolver struct {
ctrl *gomock.Controller
recorder *MockCheckResolverMockRecorder
isgomock struct{}
}
// MockCheckResolverMockRecorder is the mock recorder for MockCheckResolver.
type MockCheckResolverMockRecorder struct {
mock *MockCheckResolver
}
// NewMockCheckResolver creates a new mock instance.
func NewMockCheckResolver(ctrl *gomock.Controller) *MockCheckResolver {
mock := &MockCheckResolver{ctrl: ctrl}
mock.recorder = &MockCheckResolverMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCheckResolver) EXPECT() *MockCheckResolverMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockCheckResolver) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockCheckResolverMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCheckResolver)(nil).Close))
}
// GetDelegate mocks base method.
func (m *MockCheckResolver) GetDelegate() CheckResolver {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDelegate")
ret0, _ := ret[0].(CheckResolver)
return ret0
}
// GetDelegate indicates an expected call of GetDelegate.
func (mr *MockCheckResolverMockRecorder) GetDelegate() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegate", reflect.TypeOf((*MockCheckResolver)(nil).GetDelegate))
}
// ResolveCheck mocks base method.
func (m *MockCheckResolver) ResolveCheck(ctx context.Context, req *ResolveCheckRequest) (*ResolveCheckResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveCheck", ctx, req)
ret0, _ := ret[0].(*ResolveCheckResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ResolveCheck indicates an expected call of ResolveCheck.
func (mr *MockCheckResolverMockRecorder) ResolveCheck(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveCheck", reflect.TypeOf((*MockCheckResolver)(nil).ResolveCheck), ctx, req)
}
// SetDelegate mocks base method.
func (m *MockCheckResolver) SetDelegate(delegate CheckResolver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetDelegate", delegate)
}
// SetDelegate indicates an expected call of SetDelegate.
func (mr *MockCheckResolverMockRecorder) SetDelegate(delegate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDelegate", reflect.TypeOf((*MockCheckResolver)(nil).SetDelegate), delegate)
}
// MockCheckRewriteResolver is a mock of CheckRewriteResolver interface.
type MockCheckRewriteResolver struct {
ctrl *gomock.Controller
recorder *MockCheckRewriteResolverMockRecorder
isgomock struct{}
}
// MockCheckRewriteResolverMockRecorder is the mock recorder for MockCheckRewriteResolver.
type MockCheckRewriteResolverMockRecorder struct {
mock *MockCheckRewriteResolver
}
// NewMockCheckRewriteResolver creates a new mock instance.
func NewMockCheckRewriteResolver(ctrl *gomock.Controller) *MockCheckRewriteResolver {
mock := &MockCheckRewriteResolver{ctrl: ctrl}
mock.recorder = &MockCheckRewriteResolverMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCheckRewriteResolver) EXPECT() *MockCheckRewriteResolverMockRecorder {
return m.recorder
}
// CheckRewrite mocks base method.
func (m *MockCheckRewriteResolver) CheckRewrite(ctx context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) CheckHandlerFunc {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckRewrite", ctx, req, rewrite)
ret0, _ := ret[0].(CheckHandlerFunc)
return ret0
}
// CheckRewrite indicates an expected call of CheckRewrite.
func (mr *MockCheckRewriteResolverMockRecorder) CheckRewrite(ctx, req, rewrite any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckRewrite", reflect.TypeOf((*MockCheckRewriteResolver)(nil).CheckRewrite), ctx, req, rewrite)
}
// Close mocks base method.
func (m *MockCheckRewriteResolver) Close() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Close")
}
// Close indicates an expected call of Close.
func (mr *MockCheckRewriteResolverMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCheckRewriteResolver)(nil).Close))
}
// GetDelegate mocks base method.
func (m *MockCheckRewriteResolver) GetDelegate() CheckResolver {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDelegate")
ret0, _ := ret[0].(CheckResolver)
return ret0
}
// GetDelegate indicates an expected call of GetDelegate.
func (mr *MockCheckRewriteResolverMockRecorder) GetDelegate() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDelegate", reflect.TypeOf((*MockCheckRewriteResolver)(nil).GetDelegate))
}
// ResolveCheck mocks base method.
func (m *MockCheckRewriteResolver) ResolveCheck(ctx context.Context, req *ResolveCheckRequest) (*ResolveCheckResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveCheck", ctx, req)
ret0, _ := ret[0].(*ResolveCheckResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ResolveCheck indicates an expected call of ResolveCheck.
func (mr *MockCheckRewriteResolverMockRecorder) ResolveCheck(ctx, req any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveCheck", reflect.TypeOf((*MockCheckRewriteResolver)(nil).ResolveCheck), ctx, req)
}
// SetDelegate mocks base method.
func (m *MockCheckRewriteResolver) SetDelegate(delegate CheckResolver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetDelegate", delegate)
}
// SetDelegate indicates an expected call of SetDelegate.
func (mr *MockCheckRewriteResolverMockRecorder) SetDelegate(delegate any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDelegate", reflect.TypeOf((*MockCheckRewriteResolver)(nil).SetDelegate), delegate)
}
package graph
import (
"context"
"errors"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/iterator"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// objectProvider is an interface that abstracts the building of a channel that holds object IDs or usersets.
// It must close the channel when there are no more results.
type objectProvider interface {
End()
Begin(ctx context.Context, req *ResolveCheckRequest) (<-chan usersetMessage, error)
}
type recursiveTTUObjectProvider struct {
ts *typesystem.TypeSystem
tuplesetRelation string
computedRelation string
cancel context.CancelFunc
}
func newRecursiveTTUObjectProvider(ts *typesystem.TypeSystem, ttu *openfgav1.TupleToUserset) *recursiveTTUObjectProvider {
tuplesetRelation := ttu.GetTupleset().GetRelation()
computedRelation := ttu.GetComputedUserset().GetRelation()
return &recursiveTTUObjectProvider{ts: ts, tuplesetRelation: tuplesetRelation, computedRelation: computedRelation}
}
var _ objectProvider = (*recursiveTTUObjectProvider)(nil)
func (c *recursiveTTUObjectProvider) End() {
if c.cancel != nil {
c.cancel()
}
}
func (c *recursiveTTUObjectProvider) Begin(ctx context.Context, req *ResolveCheckRequest) (<-chan usersetMessage, error) {
objectType := tuple.GetType(req.GetTupleKey().GetObject())
possibleParents, err := c.ts.GetDirectlyRelatedUserTypes(objectType, c.tuplesetRelation)
if err != nil {
return nil, err
}
leftChans, err := produceLeftChannels(ctx, req, possibleParents, checkutil.BuildTTUV2RelationFunc(c.computedRelation))
if err != nil {
return nil, err
}
outChannel := make(chan usersetMessage, len(leftChans))
ctx, cancel := context.WithCancel(ctx)
c.cancel = cancel
go iteratorsToUserset(ctx, leftChans, outChannel)
return outChannel, nil
}
type recursiveUsersetObjectProvider struct {
ts *typesystem.TypeSystem
cancel context.CancelFunc
}
func newRecursiveUsersetObjectProvider(ts *typesystem.TypeSystem) *recursiveUsersetObjectProvider {
return &recursiveUsersetObjectProvider{ts: ts}
}
var _ objectProvider = (*recursiveUsersetObjectProvider)(nil)
func (c *recursiveUsersetObjectProvider) End() {
if c.cancel != nil {
c.cancel()
}
}
func (c *recursiveUsersetObjectProvider) Begin(ctx context.Context, req *ResolveCheckRequest) (<-chan usersetMessage, error) {
objectType := tuple.GetType(req.GetTupleKey().GetObject())
reference := []*openfgav1.RelationReference{{Type: objectType, RelationOrWildcard: &openfgav1.RelationReference_Relation{Relation: req.GetTupleKey().GetRelation()}}}
leftChans, err := produceLeftChannels(ctx, req, reference, checkutil.BuildUsersetV2RelationFunc())
if err != nil {
return nil, err
}
outChannel := make(chan usersetMessage, len(leftChans))
ctx, cancel := context.WithCancel(ctx)
c.cancel = cancel
go iteratorsToUserset(ctx, leftChans, outChannel)
return outChannel, nil
}
// TODO: This should be iteratorsToObjectID since ultimately the mapper was already applied and it's just an ObjectID.
func iteratorsToUserset(ctx context.Context, chans []<-chan *iterator.Msg, out chan usersetMessage) {
if len(chans) == 0 {
close(out)
return
}
pool := concurrency.NewPool(ctx, len(chans))
for _, c := range chans {
pool.Go(func(ctx context.Context) error {
open := true
defer func() {
if open {
iterator.Drain(c)
}
}()
for {
select {
case <-ctx.Done():
return ErrShortCircuit
case msg, ok := <-c:
if !ok {
open = false
return nil
}
if msg.Err != nil {
concurrency.TrySendThroughChannel(ctx, usersetMessage{err: msg.Err}, out)
return ErrShortCircuit
}
for {
t, err := msg.Iter.Next(ctx)
if err != nil {
msg.Iter.Stop()
if storage.IterIsDoneOrCancelled(err) {
if errors.Is(err, storage.ErrIteratorDone) {
break
}
return ErrShortCircuit
}
concurrency.TrySendThroughChannel(ctx, usersetMessage{err: err}, out)
break
}
concurrency.TrySendThroughChannel(ctx, usersetMessage{userset: t}, out)
}
}
}
})
}
go func() {
_ = pool.Wait()
close(out)
}()
}
package graph
import (
"context"
"errors"
"sync"
"time"
"github.com/emirpasic/gods/sets/hashset"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const recursiveResolver = "recursive"
// In general these values tell the query planner that the recursive strategy usually performs around 150 ms but occasionally spikes.
// However, even when it spikes we want to keep using it, or exploring it, despite the variance, rather than over-penalizing single slow runs.
var recursivePlan = &planner.PlanConfig{
Name: recursiveResolver,
InitialGuess: 150 * time.Millisecond,
// Medium Lambda: Represents medium confidence in the initial guess. It's like
// starting with the belief of having already seen 5 good runs.
Lambda: 5.0,
// UNCERTAINTY ABOUT CONSISTENCY: The gap between p50 and p99 is large.
// Low Alpha/Beta values create a wider belief curve, telling the planner
// to expect and not be overly surprised by performance variations.
// Low expected precision: E[τ] = α/β = 2.0/2.5 = 0.8.
// High expected variance: E[σ²] = β/(α−1) = 2.5/1 = 2.5, which allows for relatively bursty / jittery results.
// Wide tolerance for spread: α = 2 allows for considerable uncertainty in how spiky the latency can be.
// When β > α, we expect lower precision and higher variance.
Alpha: 2.0,
Beta: 2.5,
}
type recursiveMapping struct {
kind storage.TupleMapperKind
tuplesetRelation string
allowedUserTypeRestrictions []*openfgav1.RelationReference
}
func (c *LocalChecker) recursiveUserset(_ context.Context, req *ResolveCheckRequest, _ []*openfgav1.RelationReference, rightIter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
directlyRelatedUsersetTypes, _ := typesys.DirectlyRelatedUsersets(tuple.GetType(req.GetTupleKey().GetObject()), req.GetTupleKey().GetRelation())
objectProvider := newRecursiveUsersetObjectProvider(typesys)
return c.recursiveFastPath(ctx, req, rightIter, &recursiveMapping{
kind: storage.UsersetKind,
allowedUserTypeRestrictions: directlyRelatedUsersetTypes,
}, objectProvider)
}
}
// recursiveTTU solves a union relation of the form "{operand1} OR ... {operandN} OR {recursive TTU}"
// rightIter gives the iterator for the recursive TTU.
func (c *LocalChecker) recursiveTTU(_ context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset, rightIter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
ttu := rewrite.GetTupleToUserset()
objectProvider := newRecursiveTTUObjectProvider(typesys, ttu)
return c.recursiveFastPath(ctx, req, rightIter, &recursiveMapping{
kind: storage.TTUKind,
tuplesetRelation: ttu.GetTupleset().GetRelation(),
}, objectProvider)
}
}
func (c *LocalChecker) recursiveFastPath(ctx context.Context, req *ResolveCheckRequest, iter storage.TupleKeyIterator, mapping *recursiveMapping, objectProvider objectProvider) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "recursiveFastPath")
defer span.End()
usersetFromUser := hashset.New()
usersetFromObject := hashset.New()
cancellableCtx, cancel := context.WithCancel(ctx)
defer cancel()
objectToUsersetIter := storage.WrapIterator(mapping.kind, iter)
defer objectToUsersetIter.Stop()
objectToUsersetMessageChan := streamedLookupUsersetFromIterator(cancellableCtx, objectToUsersetIter)
res := &ResolveCheckResponse{
Allowed: false,
}
// check whether any recursive usersets are assigned. If not,
// we don't even need to check the terminal type side.
select {
case <-ctx.Done():
return nil, ctx.Err()
case objectToUsersetMessage, ok := <-objectToUsersetMessageChan:
if !ok {
return res, ctx.Err()
}
if objectToUsersetMessage.err != nil {
return nil, objectToUsersetMessage.err
}
usersetFromObject.Add(objectToUsersetMessage.userset)
}
userToUsersetMessageChan, err := objectProvider.Begin(cancellableCtx, req)
if err != nil {
return nil, err
}
defer objectProvider.End()
userToUsersetDone := false
objectToUsersetDone := false
// NOTE: This loop initializes the terminal type and the first level of depth, as this is a breadth-first traversal.
// To keep things simple the terminal type is fully loaded, though it could arguably be loaded asynchronously.
for !userToUsersetDone || !objectToUsersetDone {
select {
case <-ctx.Done():
return nil, ctx.Err()
case userToUsersetMessage, ok := <-userToUsersetMessageChan:
if !ok {
userToUsersetDone = true
if usersetFromUser.Size() == 0 {
return res, ctx.Err()
}
break
}
if userToUsersetMessage.err != nil {
return nil, userToUsersetMessage.err
}
if processUsersetMessage(userToUsersetMessage.userset, usersetFromUser, usersetFromObject) {
res.Allowed = true
return res, nil
}
case objectToUsersetMessage, ok := <-objectToUsersetMessageChan:
if !ok {
// usersetFromObject must not be empty because we would have caught it earlier.
objectToUsersetDone = true
break
}
if objectToUsersetMessage.err != nil {
return nil, objectToUsersetMessage.err
}
if processUsersetMessage(objectToUsersetMessage.userset, usersetFromObject, usersetFromUser) {
res.Allowed = true
return res, nil
}
}
}
newReq := req.clone()
return c.recursiveMatchUserUserset(ctx, newReq, mapping, usersetFromObject, usersetFromUser)
}
func buildRecursiveMapper(ctx context.Context, req *ResolveCheckRequest, mapping *recursiveMapping) (storage.TupleMapper, error) {
var iter storage.TupleIterator
var err error
typesys, _ := typesystem.TypesystemFromContext(ctx)
ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
consistencyOpts := storage.ConsistencyOptions{
Preference: req.GetConsistency(),
}
switch mapping.kind {
case storage.UsersetKind:
objectType := req.GetTupleKey().GetObject()
relation := req.GetTupleKey().GetRelation()
iter, err = ds.ReadUsersetTuples(ctx, req.GetStoreID(), storage.ReadUsersetTuplesFilter{
Object: objectType,
Relation: relation,
AllowedUserTypeRestrictions: mapping.allowedUserTypeRestrictions,
}, storage.ReadUsersetTuplesOptions{Consistency: consistencyOpts})
case storage.TTUKind:
objectType := req.GetTupleKey().GetObject()
iter, err = ds.Read(ctx, req.GetStoreID(),
storage.ReadFilter{Object: objectType, Relation: mapping.tuplesetRelation, User: ""},
storage.ReadOptions{Consistency: consistencyOpts})
default:
return nil, errors.New("unsupported mapper kind")
}
if err != nil {
return nil, err
}
filteredIter := storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
),
checkutil.BuildTupleKeyConditionFilter(ctx, req.GetContext(), typesys),
)
return storage.WrapIterator(mapping.kind, filteredIter), nil
}
func (c *LocalChecker) recursiveMatchUserUserset(ctx context.Context, req *ResolveCheckRequest, mapping *recursiveMapping, currentLevelFromObject *hashset.Set, usersetFromUser *hashset.Set) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "recursiveMatchUserUserset", trace.WithAttributes(
attribute.Int("first_level_size", currentLevelFromObject.Size()),
attribute.Int("terminal_type_size", usersetFromUser.Size()),
))
defer span.End()
checkOutcomeChan := make(chan checkOutcome, c.concurrencyLimit)
cancellableCtx, cancel := context.WithCancel(ctx)
wg := sync.WaitGroup{}
defer func() {
cancel()
// We always need to wait to avoid a goroutine leak.
wg.Wait()
}()
wg.Add(1)
go func() {
c.breadthFirstRecursiveMatch(cancellableCtx, req, mapping, &sync.Map{}, currentLevelFromObject, usersetFromUser, checkOutcomeChan)
wg.Done()
}()
var finalErr error
finalResult := &ResolveCheckResponse{
Allowed: false,
}
ConsumerLoop:
for {
select {
case <-ctx.Done():
break ConsumerLoop
case outcome, ok := <-checkOutcomeChan:
if !ok {
break ConsumerLoop
}
if outcome.err != nil {
finalErr = outcome.err
break // continue
}
if outcome.resp.Allowed {
finalErr = nil
finalResult = outcome.resp
break ConsumerLoop
}
}
}
// context cancellation from upstream (e.g. client)
if ctx.Err() != nil {
finalErr = ctx.Err()
}
if finalErr != nil {
return nil, finalErr
}
return finalResult, nil
}
// Note that a visited userset does not necessarily mean there is a cycle. For the following model,
//
//	type user
//	type group
//	  relations
//	    define member: [user, group#member]
//
// we may have tuples like
//
//	group:1#member@group:2#member
//	group:1#member@group:3#member
//	group:2#member@group:a#member
//	group:3#member@group:a#member
//
// Note that both group:2#member and group:3#member have group:a#member; however, this is not a cycle.
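// Expanding group:1#member breadth-first over that data gives:
//
//	level 1: {group:2#member, group:3#member}
//	level 2: {group:a#member} // reached twice but visited only once, via visitedUserset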
func (c *LocalChecker) breadthFirstRecursiveMatch(ctx context.Context, req *ResolveCheckRequest, mapping *recursiveMapping, visitedUserset *sync.Map, currentUsersetLevel *hashset.Set, usersetFromUser *hashset.Set, checkOutcomeChan chan checkOutcome) {
req.GetRequestMetadata().Depth++
if req.GetRequestMetadata().Depth == c.maxResolutionDepth {
concurrency.TrySendThroughChannel(ctx, checkOutcome{err: ErrResolutionDepthExceeded}, checkOutcomeChan)
close(checkOutcomeChan)
return
}
if currentUsersetLevel.Size() == 0 || ctx.Err() != nil {
// nothing else to search for or upstream cancellation
close(checkOutcomeChan)
return
}
relation := req.GetTupleKey().GetRelation()
user := req.GetTupleKey().GetUser()
pool := concurrency.NewPool(ctx, c.concurrencyLimit)
mu := &sync.Mutex{}
nextUsersetLevel := hashset.New()
for _, usersetInterface := range currentUsersetLevel.Values() {
userset := usersetInterface.(string)
_, visited := visitedUserset.LoadOrStore(userset, struct{}{})
if visited {
continue
}
newReq := req.clone()
newReq.TupleKey = tuple.NewTupleKey(userset, relation, user)
mapper, err := buildRecursiveMapper(ctx, newReq, mapping)
if err != nil {
concurrency.TrySendThroughChannel(ctx, checkOutcome{err: err}, checkOutcomeChan)
continue
}
// if the pool is short-circuited, the iterator should be stopped
defer mapper.Stop()
pool.Go(func(ctx context.Context) error {
objectToUsersetMessageChan := streamedLookupUsersetFromIterator(ctx, mapper)
for usersetMsg := range objectToUsersetMessageChan {
if usersetMsg.err != nil {
concurrency.TrySendThroughChannel(ctx, checkOutcome{err: usersetMsg.err}, checkOutcomeChan)
return nil
}
userset := usersetMsg.userset
if usersetFromUser.Contains(userset) {
concurrency.TrySendThroughChannel(ctx, checkOutcome{resp: &ResolveCheckResponse{
Allowed: true,
}}, checkOutcomeChan)
return ErrShortCircuit // cancel will be propagated to the remaining goroutines
}
mu.Lock()
nextUsersetLevel.Add(userset)
mu.Unlock()
}
return nil
})
}
// wait for all checks to wrap up
// if a match was found, clean up
if err := pool.Wait(); errors.Is(err, ErrShortCircuit) {
close(checkOutcomeChan)
return
}
c.breadthFirstRecursiveMatch(ctx, req, mapping, visitedUserset, nextUsersetLevel, usersetFromUser, checkOutcomeChan)
}
package graph
import (
"errors"
"strings"
"sync/atomic"
"time"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
)
type ResolveCheckRequest struct {
StoreID string
AuthorizationModelID string // TODO replace with typesystem
TupleKey *openfgav1.TupleKey
ContextualTuples []*openfgav1.TupleKey
Context *structpb.Struct
RequestMetadata *ResolveCheckRequestMetadata
VisitedPaths map[string]struct{}
Consistency openfgav1.ConsistencyPreference
LastCacheInvalidationTime time.Time
// Invariant parts of a check request are those that don't change in sub-problems:
// AuthorizationModelID, StoreID, Context, and ContextualTuples.
// The invariantCacheKey is computed once per request and passed to sub-problems via copy in .clone().
invariantCacheKey string
}
type ResolveCheckRequestMetadata struct {
// Thinking of a Check as a tree of evaluations,
// Depth is the current level in the tree in the current path that we are exploring.
// When we jump one level, we increment it by 1. If it hits maxResolutionDepth (resolveNodeLimit), we return ErrResolutionDepthExceeded.
Depth uint32
// DispatchCounter is the address of a shared counter that keeps track of how many calls to ResolveCheck we had to do
// to solve the root/parent problem.
// The contents of this counter will be written by concurrent goroutines.
// After the root problem has been solved, this value can be read.
DispatchCounter *atomic.Uint32
// WasThrottled indicates whether the request was throttled
WasThrottled *atomic.Bool
}
type ResolveCheckRequestParams struct {
StoreID string
TupleKey *openfgav1.TupleKey
ContextualTuples []*openfgav1.TupleKey
Context *structpb.Struct
Consistency openfgav1.ConsistencyPreference
LastCacheInvalidationTime time.Time
AuthorizationModelID string
}
func NewCheckRequestMetadata() *ResolveCheckRequestMetadata {
return &ResolveCheckRequestMetadata{
DispatchCounter: new(atomic.Uint32),
WasThrottled: new(atomic.Bool),
}
}
func NewResolveCheckRequest(
params ResolveCheckRequestParams,
) (*ResolveCheckRequest, error) {
if params.AuthorizationModelID == "" {
return nil, errors.New("missing authorization_model_id")
}
if params.StoreID == "" {
return nil, errors.New("missing store_id")
}
r := &ResolveCheckRequest{
StoreID: params.StoreID,
AuthorizationModelID: params.AuthorizationModelID,
TupleKey: params.TupleKey,
ContextualTuples: params.ContextualTuples,
Context: params.Context,
VisitedPaths: make(map[string]struct{}),
RequestMetadata: NewCheckRequestMetadata(),
Consistency: params.Consistency,
// propagate the last cache invalidation time so sub-problems don't have to read it from the cache again
LastCacheInvalidationTime: params.LastCacheInvalidationTime,
}
keyBuilder := &strings.Builder{}
err := storage.WriteInvariantCheckCacheKey(keyBuilder, &storage.CheckCacheKeyParams{
StoreID: params.StoreID,
AuthorizationModelID: params.AuthorizationModelID,
ContextualTuples: params.ContextualTuples,
Context: params.Context,
})
if err != nil {
return nil, err
}
r.invariantCacheKey = keyBuilder.String()
return r, nil
}
func (r *ResolveCheckRequest) clone() *ResolveCheckRequest {
var requestMetadata *ResolveCheckRequestMetadata
origRequestMetadata := r.GetRequestMetadata()
if origRequestMetadata != nil {
requestMetadata = &ResolveCheckRequestMetadata{
DispatchCounter: origRequestMetadata.DispatchCounter,
Depth: origRequestMetadata.Depth,
WasThrottled: origRequestMetadata.WasThrottled,
}
}
var tupleKey *openfgav1.TupleKey
if origTupleKey := r.GetTupleKey(); origTupleKey != nil {
tupleKey = proto.Clone(origTupleKey).(*openfgav1.TupleKey)
}
return &ResolveCheckRequest{
StoreID: r.GetStoreID(),
AuthorizationModelID: r.GetAuthorizationModelID(),
TupleKey: tupleKey,
ContextualTuples: r.GetContextualTuples(),
Context: r.GetContext(),
RequestMetadata: requestMetadata,
VisitedPaths: maps.Clone(r.GetVisitedPaths()),
Consistency: r.GetConsistency(),
LastCacheInvalidationTime: r.GetLastCacheInvalidationTime(),
invariantCacheKey: r.GetInvariantCacheKey(),
}
}
func (r *ResolveCheckRequest) GetStoreID() string {
if r == nil {
return ""
}
return r.StoreID
}
func (r *ResolveCheckRequest) GetAuthorizationModelID() string {
if r == nil {
return ""
}
return r.AuthorizationModelID
}
func (r *ResolveCheckRequest) GetTupleKey() *openfgav1.TupleKey {
if r == nil {
return nil
}
return r.TupleKey
}
func (r *ResolveCheckRequest) GetContextualTuples() []*openfgav1.TupleKey {
if r == nil {
return nil
}
return r.ContextualTuples
}
func (r *ResolveCheckRequest) GetRequestMetadata() *ResolveCheckRequestMetadata {
if r == nil {
return nil
}
return r.RequestMetadata
}
func (r *ResolveCheckRequest) GetContext() *structpb.Struct {
if r == nil {
return nil
}
return r.Context
}
func (r *ResolveCheckRequest) GetConsistency() openfgav1.ConsistencyPreference {
if r == nil {
return openfgav1.ConsistencyPreference_UNSPECIFIED
}
return r.Consistency
}
func (r *ResolveCheckRequest) GetVisitedPaths() map[string]struct{} {
if r == nil {
return map[string]struct{}{}
}
return r.VisitedPaths
}
func (r *ResolveCheckRequest) GetLastCacheInvalidationTime() time.Time {
if r == nil {
return time.Time{}
}
return r.LastCacheInvalidationTime
}
func (r *ResolveCheckRequest) GetInvariantCacheKey() string {
if r == nil {
return ""
}
return r.invariantCacheKey
}
package graph
import "time"
type ResolveCheckResponseMetadata struct {
// Number of Read operations accumulated after this request completes.
DatastoreQueryCount uint32
// Number of items read from the database after this request completes.
DatastoreItemCount uint64
// Indicates if the ResolveCheck subproblem that was evaluated involved
// a cycle in the evaluation.
CycleDetected bool
// The total time it took to resolve the check request.
Duration time.Duration
}
// clone clones the provided ResolveCheckResponse.
func (r *ResolveCheckResponse) clone() *ResolveCheckResponse {
return &ResolveCheckResponse{
Allowed: r.GetAllowed(),
ResolutionMetadata: r.GetResolutionMetadata(),
}
}
type ResolveCheckResponse struct {
Allowed bool
ResolutionMetadata ResolveCheckResponseMetadata
}
func (r *ResolveCheckResponse) GetCycleDetected() bool {
if r == nil {
return false
}
return r.GetResolutionMetadata().CycleDetected
}
func (r *ResolveCheckResponse) GetAllowed() bool {
if r == nil {
return false
}
return r.Allowed
}
func (r *ResolveCheckResponse) GetResolutionMetadata() ResolveCheckResponseMetadata {
if r == nil {
return ResolveCheckResponseMetadata{}
}
return r.ResolutionMetadata
}
package graph
import (
"context"
"sync"
"time"
"go.uber.org/zap"
"github.com/openfga/openfga/pkg/logger"
)
const Hundred = 100
type ShadowResolverOpt func(*ShadowResolver)
func ShadowResolverWithName(name string) ShadowResolverOpt {
return func(shadowResolver *ShadowResolver) {
shadowResolver.name = name
}
}
func ShadowResolverWithTimeout(timeout time.Duration) ShadowResolverOpt {
return func(shadowResolver *ShadowResolver) {
shadowResolver.shadowTimeout = timeout
}
}
func ShadowResolverWithLogger(logger logger.Logger) ShadowResolverOpt {
return func(shadowResolver *ShadowResolver) {
shadowResolver.logger = logger
}
}
type ShadowResolver struct {
name string
main CheckResolver
shadow CheckResolver
shadowTimeout time.Duration
logger logger.Logger
// only used for testing signals
wg *sync.WaitGroup
}
var _ CheckResolver = (*ShadowResolver)(nil)
func (s ShadowResolver) ResolveCheck(ctx context.Context, req *ResolveCheckRequest) (*ResolveCheckResponse, error) {
ctxClone := context.WithoutCancel(ctx) // detach from the caller's cancellation but keep context values (typesystem, datastore, etc.)
mainStart := time.Now()
res, err := s.main.ResolveCheck(ctx, req)
mainDuration := time.Since(mainStart)
if err != nil {
return nil, err
}
resClone := res.clone()
reqClone := req.clone()
reqClone.VisitedPaths = nil // reset completely for evaluation
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer func() {
if r := recover(); r != nil {
s.logger.ErrorWithContext(ctx, "panic recovered",
zap.String("resolver", s.name),
zap.Any("error", err),
zap.String("request", reqClone.GetTupleKey().String()),
zap.String("store_id", reqClone.GetStoreID()),
zap.String("model_id", reqClone.GetAuthorizationModelID()),
zap.String("function", "ShadowResolver.ResolveCheck"),
)
}
}()
ctx, cancel := context.WithTimeout(ctxClone, s.shadowTimeout)
defer cancel()
shadowStart := time.Now()
shadowRes, err := s.shadow.ResolveCheck(ctx, reqClone)
shadowDuration := time.Since(shadowStart)
if err != nil {
s.logger.WarnWithContext(ctx, "shadow check errored",
zap.String("resolver", s.name),
zap.Error(err),
zap.String("request", reqClone.GetTupleKey().String()),
zap.String("store_id", reqClone.GetStoreID()),
zap.String("model_id", reqClone.GetAuthorizationModelID()),
)
return
}
if shadowRes.GetAllowed() != resClone.GetAllowed() {
s.logger.InfoWithContext(ctx, "shadow check difference",
zap.String("resolver", s.name),
zap.String("request", reqClone.GetTupleKey().String()),
zap.String("store_id", reqClone.GetStoreID()),
zap.String("model_id", reqClone.GetAuthorizationModelID()),
zap.Bool("main", resClone.GetAllowed()),
zap.Bool("main_cycle", resClone.GetCycleDetected()),
zap.Int64("main_latency", mainDuration.Milliseconds()),
zap.Bool("shadow", shadowRes.GetAllowed()),
zap.Bool("shadow_cycle", shadowRes.GetCycleDetected()),
zap.Int64("shadow_latency", shadowDuration.Milliseconds()),
)
} else {
s.logger.InfoWithContext(ctx, "shadow check match",
zap.Int64("main_latency", mainDuration.Milliseconds()),
zap.Int64("shadow_latency", shadowDuration.Milliseconds()),
)
}
}()
return res, nil
}
func (s ShadowResolver) Close() {
s.main.Close()
s.shadow.Close()
}
func (s ShadowResolver) SetDelegate(delegate CheckResolver) {
s.main.SetDelegate(delegate)
// shadow should result in noop regardless of outcome
}
func (s ShadowResolver) GetDelegate() CheckResolver {
return s.main.GetDelegate()
}
func NewShadowChecker(main CheckResolver, shadow CheckResolver, opts ...ShadowResolverOpt) *ShadowResolver {
r := &ShadowResolver{name: "check", main: main, shadow: shadow, wg: &sync.WaitGroup{}}
for _, opt := range opts {
opt(r)
}
return r
}
package graph
import (
"context"
"errors"
"fmt"
"time"
"github.com/emirpasic/gods/sets/hashset"
"github.com/sourcegraph/conc/panics"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/iterator"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const IteratorMinBatchThreshold = 100
const BaseIndex = 0
const DifferenceIndex = 1
const weightTwoResolver = "weight2"
// This strategy's prior is configured to reflect that it has proven fast and consistent.
var weight2Plan = &planner.PlanConfig{
Name: weightTwoResolver,
InitialGuess: 20 * time.Millisecond,
// High Lambda: Represents strong confidence in the initial guess. It's like
// starting with the belief of having already seen 10 good runs.
Lambda: 10.0,
// High Alpha, Low Beta: Creates a very NARROW belief about variance.
// This tells the planner: "I am very confident that the performance is
// consistently close to the 20ms initial guess". A single slow run will be a huge
// surprise and will dramatically shift this belief.
// High expected precision: E[τ] = α/β = 20/2 = 10
// Low expected variance: E[σ²] = β/(α−1) = 2/19 ≈ 0.105, narrow jitter
// A slow sample will look like an outlier and move the posterior noticeably, but overall this prior exploits.
Alpha: 20,
Beta: 2,
}
var ErrShortCircuit = errors.New("short circuit")
type fastPathSetHandler func(context.Context, *iterator.Streams, chan<- *iterator.Msg)
func (c *LocalChecker) weight2Userset(_ context.Context, req *ResolveCheckRequest, usersets []*openfgav1.RelationReference, iter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
cancellableCtx, cancel := context.WithCancel(ctx)
defer cancel()
leftChans, err := produceLeftChannels(cancellableCtx, req, usersets, checkutil.BuildUsersetV2RelationFunc())
if err != nil {
return nil, err
}
if len(leftChans) == 0 {
return &ResolveCheckResponse{
Allowed: false,
}, nil
}
return c.weight2(ctx, leftChans, storage.WrapIterator(storage.UsersetKind, iter))
}
}
func (c *LocalChecker) weight2TTU(ctx context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset, iter storage.TupleKeyIterator) CheckHandlerFunc {
return func(ctx context.Context) (*ResolveCheckResponse, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
objectType := tuple.GetType(req.GetTupleKey().GetObject())
tuplesetRelation := rewrite.GetTupleToUserset().GetTupleset().GetRelation()
computedRelation := rewrite.GetTupleToUserset().GetComputedUserset().GetRelation()
possibleParents, err := typesys.GetDirectlyRelatedUserTypes(objectType, tuplesetRelation)
if err != nil {
return nil, err
}
cancellableCtx, cancel := context.WithCancel(ctx)
defer cancel()
leftChans, err := produceLeftChannels(cancellableCtx, req, possibleParents, checkutil.BuildTTUV2RelationFunc(computedRelation))
if err != nil {
return nil, err
}
if len(leftChans) == 0 {
return &ResolveCheckResponse{
Allowed: false,
}, nil
}
return c.weight2(ctx, leftChans, storage.WrapIterator(storage.TTUKind, iter))
}
}
// weight2 attempts to find the intersection across 2 producers (channels) of ObjectIDs.
// In the case of a TTU:
// Right channel is the result set of the Read of ObjectID/Relation that yields the User's ObjectID.
// Left channel is the result set of ReadStartingWithUser of User/Relation that yields Object's ObjectID.
// From the perspective of the model, the left hand side of a TTU is the computed relationship being expanded.
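// As an illustrative sketch (model and IDs hypothetical): for `define viewer: viewer from parent`
// where parent refers to folders, checking document:1#viewer@user:anne intersects the parent
// folders of document:1 (right channel) with the folders anne can view (left channel); any
// folder ObjectID present in both sets resolves the check to allowed.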
func (c *LocalChecker) weight2(ctx context.Context, leftChans []<-chan *iterator.Msg, iter storage.TupleMapper) (*ResolveCheckResponse, error) {
ctx, span := tracer.Start(ctx, "weight2")
defer span.End()
cancellableCtx, cancel := context.WithCancel(ctx)
leftChan := iterator.FanInIteratorChannels(cancellableCtx, leftChans)
rightChan := streamedLookupUsersetFromIterator(cancellableCtx, iter)
rightOpen := true
leftOpen := true
defer func() {
cancel()
iter.Stop()
if !leftOpen {
return
}
iterator.Drain(leftChan)
}()
res := &ResolveCheckResponse{
Allowed: false,
}
rightSet := hashset.New()
leftSet := hashset.New()
select {
case <-ctx.Done():
return nil, ctx.Err()
case r, ok := <-rightChan:
if !ok {
return res, ctx.Err()
}
if r.err != nil {
return nil, r.err
}
rightSet.Add(r.userset)
}
var lastErr error
ConsumerLoop:
for leftOpen || rightOpen {
select {
case <-ctx.Done():
lastErr = ctx.Err()
break ConsumerLoop
case msg, ok := <-leftChan:
if !ok {
leftOpen = false
if leftSet.Size() == 0 {
if ctx.Err() != nil {
lastErr = ctx.Err()
}
break ConsumerLoop
}
break
}
if msg.Err != nil {
lastErr = msg.Err
break ConsumerLoop
}
for {
t, err := msg.Iter.Next(ctx)
if err != nil {
msg.Iter.Stop()
if storage.IterIsDoneOrCancelled(err) {
break
}
lastErr = err
continue
}
if processUsersetMessage(t, leftSet, rightSet) {
msg.Iter.Stop()
res.Allowed = true
lastErr = nil
break ConsumerLoop
}
}
case msg, ok := <-rightChan:
if !ok {
rightOpen = false
break
}
if msg.err != nil {
lastErr = msg.err
continue
}
if processUsersetMessage(msg.userset, rightSet, leftSet) {
res.Allowed = true
lastErr = nil
break ConsumerLoop
}
}
}
return res, lastErr
}
func produceLeftChannels(
ctx context.Context,
req *ResolveCheckRequest,
relationReferences []*openfgav1.RelationReference,
relationFunc checkutil.V2RelationFunc,
) ([]<-chan *iterator.Msg, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
leftChans := make([]<-chan *iterator.Msg, 0, len(relationReferences))
for _, parentType := range relationReferences {
relation := relationFunc(parentType)
rel, err := typesys.GetRelation(parentType.GetType(), relation)
if err != nil {
continue
}
r := req.clone()
r.TupleKey = &openfgav1.TupleKey{
Object: tuple.BuildObject(parentType.GetType(), "ignore"),
// depending on relationFunc, it will return the parentType's relation (userset) or computedRelation (TTU)
Relation: relation,
User: r.GetTupleKey().GetUser(),
}
leftChan, err := fastPathRewrite(ctx, r, rel.GetRewrite())
if err != nil {
// if earlier producers have already started, their channels must be drained before returning the error
if len(leftChans) > 0 {
iterator.Drain(iterator.FanInIteratorChannels(ctx, leftChans))
}
return nil, err
}
leftChans = append(leftChans, leftChan)
}
return leftChans, nil
}
func fastPathNoop(_ context.Context, _ *ResolveCheckRequest) (chan *iterator.Msg, error) {
iterChan := make(chan *iterator.Msg)
close(iterChan)
return iterChan, nil
}
// fastPathDirect assumes that req.Object + req.Relation is a directly assignable relation, e.g. define viewer: [user, user:*].
// It returns a channel with one element, and then closes the channel.
// The element is an iterator over all objects that are directly related to the user or the wildcard (if applicable).
func fastPathDirect(ctx context.Context, req *ResolveCheckRequest) (chan *iterator.Msg, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
ds, _ := storage.RelationshipTupleReaderFromContext(ctx)
tk := req.GetTupleKey()
objRel := tuple.ToObjectRelationString(tuple.GetType(tk.GetObject()), tk.GetRelation())
i, err := checkutil.IteratorReadStartingFromUser(ctx, typesys, ds, req, objRel, nil, true)
if err != nil {
return nil, err
}
iterChan := make(chan *iterator.Msg, 1)
iter := storage.WrapIterator(storage.ObjectIDKind, i)
if !concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: iter}, iterChan) {
iter.Stop() // the message was not sent, so no receiver can stop this iterator; stop it here
}
close(iterChan)
return iterChan, nil
}
func fastPathComputed(ctx context.Context, req *ResolveCheckRequest, rewrite *openfgav1.Userset) (chan *iterator.Msg, error) {
typesys, _ := typesystem.TypesystemFromContext(ctx)
computedRelation := rewrite.GetComputedUserset().GetRelation()
childRequest := req.clone()
childRequest.TupleKey.Relation = computedRelation
objectType := tuple.GetType(childRequest.GetTupleKey().GetObject())
rel, err := typesys.GetRelation(objectType, computedRelation)
if err != nil {
return nil, err
}
return fastPathRewrite(ctx, childRequest, rel.GetRewrite())
}
// addNextItemInSliceStreamsToBatch advances the specified streams, appends the yielded item to batch,
// and, once the batch exceeds IteratorMinBatchThreshold, sends it to outChan and starts a new batch.
// If advancing the streams fails, the error is also sent to outChan.
func addNextItemInSliceStreamsToBatch(ctx context.Context, streamSlices []*iterator.Stream, streamsToProcess []int, batch []string, outChan chan<- *iterator.Msg) ([]string, error) {
item, err := iterator.NextItemInSliceStreams(ctx, streamSlices, streamsToProcess)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return nil, err
}
if item != "" {
batch = append(batch, item)
}
if len(batch) > IteratorMinBatchThreshold {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: storage.NewStaticIterator[string](batch)}, outChan)
batch = make([]string, 0)
}
return batch, nil
}
func fastPathUnion(ctx context.Context, streams *iterator.Streams, outChan chan<- *iterator.Msg) {
batch := make([]string, 0)
defer func() {
// flush
if len(batch) > 0 {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: storage.NewStaticIterator[string](batch)}, outChan)
}
close(outChan)
streams.Stop()
}()
/*
collect iterators from all channels, until all are drained.
Perform the union algorithm across the heads; when an iterator is exhausted, poll its source
channel once more to see if it has a new iterator, otherwise consider that stream done.
*/
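// For example (hypothetical IDs): heads ["doc:1","doc:3"] and ["doc:2","doc:3"] merge into
// the sorted, de-duplicated output doc:1, doc:2, doc:3; at each round, every iterator
// holding the minimum head advances together, so a shared object is emitted only once.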
for streams.GetActiveStreamsCount() > 0 {
if ctx.Err() != nil {
return
}
iterStreams, err := streams.CleanDone(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
allIters := true
minObject := ""
itersWithEqualObject := make([]int, 0)
for idx, stream := range iterStreams {
v, err := stream.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
allIters = false
// we need to ensure we have all iterators at all times
break
}
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
// initialize
if idx == 0 {
minObject = v
}
if minObject == v {
itersWithEqualObject = append(itersWithEqualObject, idx)
} else if minObject > v {
minObject = v
itersWithEqualObject = []int{idx}
}
}
if !allIters {
// we need to ensure we have all iterators at all times
continue
}
// all iterators with the same value move forward
batch, err = addNextItemInSliceStreamsToBatch(ctx, iterStreams, itersWithEqualObject, batch, outChan)
if err != nil {
// We are relying on the fact that we have called .Head(ctx) earlier
// and no one else should have called the iterator (especially since it is
// protected by mutex). Therefore, it is impossible for the iterator to return
// Done here. Hence, any error received here should be considered as legitimate
// errors.
return
}
}
}
func fastPathIntersection(ctx context.Context, streams *iterator.Streams, outChan chan<- *iterator.Msg) {
batch := make([]string, 0)
defer func() {
// flush
if len(batch) > 0 {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: storage.NewStaticIterator[string](batch)}, outChan)
}
close(outChan)
streams.Stop()
}()
/*
collect iterators from all channels, proceeding once none are nil.
Perform the intersection algorithm across the heads; when an iterator is drained, poll its
source channel once more to see if it has a new iterator, otherwise consider that stream done.
Exit if any channel closes, since an intersection across all children is no longer possible.
*/
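// For example (hypothetical IDs): heads ["doc:1","doc:2"] and ["doc:2","doc:3"] intersect
// to doc:2 only; an object is emitted when every child stream holds it, and lagging
// streams skip forward to the current maximum.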
childrenTotal := streams.GetActiveStreamsCount()
for streams.GetActiveStreamsCount() == childrenTotal {
if ctx.Err() != nil {
return
}
iterStreams, err := streams.CleanDone(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
if len(iterStreams) != childrenTotal {
// short circuit
return
}
maxObject := ""
itersWithEqualObject := make([]int, 0)
allIters := true
for idx, stream := range iterStreams {
v, err := stream.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
allIters = false
// we need to ensure we have all iterators at all times
break
}
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
if idx == 0 {
maxObject = v
}
if maxObject == v {
itersWithEqualObject = append(itersWithEqualObject, idx)
} else if maxObject < v {
maxObject = v
itersWithEqualObject = []int{idx}
}
}
if !allIters {
// we need to ensure we have all iterators at all times
continue
}
// all children have the same value
if len(itersWithEqualObject) == childrenTotal {
// all iterators have the same value thus flush entry and move iterators
batch, err = addNextItemInSliceStreamsToBatch(ctx, iterStreams, itersWithEqualObject, batch, outChan)
if err != nil {
// We are relying on the fact that we have called .Head(ctx) earlier
// and no one else should have called the iterator (especially since it is
// protected by mutex). Therefore, it is impossible for the iterator to return
// Done here. Hence, any error received here should be considered as legitimate
// errors.
return
}
continue
}
// advance every iterator whose head is less than maxObject so that all heads become >= maxObject
for _, stream := range iterStreams {
err = stream.SkipToTargetObject(ctx, maxObject)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
}
}
}
func fastPathDifference(ctx context.Context, streams *iterator.Streams, outChan chan<- *iterator.Msg) {
batch := make([]string, 0)
defer func() {
// flush
if len(batch) > 0 {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: storage.NewStaticIterator[string](batch)}, outChan)
}
close(outChan)
streams.Stop()
}()
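// Walk both sorted streams in lockstep: equal heads are dropped (the object is subtracted),
// base-only heads (base < diff) are emitted, and when the difference stream lags
// (diff < base) it skips forward to base. Once the difference stream is exhausted, the
// remaining base items are drained after the loop.
// For example (hypothetical IDs): base ["doc:1","doc:2","doc:3"] minus diff ["doc:2"]
// yields doc:1 and doc:3.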
// both base and difference are still remaining
for streams.GetActiveStreamsCount() == 2 {
if ctx.Err() != nil {
return
}
iterStreams, err := streams.CleanDone(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
if len(iterStreams) != 2 {
// short circuit
break
}
allIters := true
base := ""
diff := ""
for idx, stream := range iterStreams {
v, err := stream.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
allIters = false
// we need to ensure we have all iterators at all times
break
}
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
if idx == BaseIndex {
base = v
}
if idx == DifferenceIndex {
diff = v
}
}
if !allIters {
// we need to ensure we have all iterators at all times
continue
}
// move both iterator heads
if base == diff {
_, err = iterator.NextItemInSliceStreams(ctx, iterStreams, []int{BaseIndex, DifferenceIndex})
if err != nil {
// We are relying on the fact that we have called .Head(ctx) earlier
// and no one else should have called the iterator (especially since it is
// protected by mutex). Therefore, it is impossible for the iterator to return
// Done here. Hence, any error received here should be considered as legitimate
// errors.
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
continue
}
if diff > base {
batch, err = addNextItemInSliceStreamsToBatch(ctx, iterStreams, []int{BaseIndex}, batch, outChan)
if err != nil {
// We are relying on the fact that we have called .Head(ctx) earlier
// and no one else should have called the iterator (especially since it is
// protected by mutex). Therefore, it is impossible for the iterator to return
// Done here. Hence, any error received here should be considered as legitimate
// errors.
return
}
continue
}
// diff < base, then move the diff to catch up with base
err = iterStreams[DifferenceIndex].SkipToTargetObject(ctx, base)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
}
iterStreams, err := streams.CleanDone(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
// drain the base
if len(iterStreams) == 1 && iterStreams[BaseIndex].Idx() == BaseIndex {
for len(iterStreams) == 1 {
stream := iterStreams[BaseIndex]
items, err := stream.Drain(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
batch = append(batch, items...)
if len(batch) > IteratorMinBatchThreshold {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Iter: storage.NewStaticIterator[string](batch)}, outChan)
batch = make([]string, 0)
}
iterStreams, err = streams.CleanDone(ctx)
if err != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: err}, outChan)
return
}
}
}
}
// fastPathOperationSetup spawns the given set-operation resolver over the children's
// producer channels and returns a channel of iterator messages.
// The caller must wait until the channel is closed.
func fastPathOperationSetup(ctx context.Context, req *ResolveCheckRequest, resolver fastPathSetHandler, children ...*openfgav1.Userset) (chan *iterator.Msg, error) {
iterStreams := make([]*iterator.Stream, 0, len(children))
for idx, child := range children {
producerChan, err := fastPathRewrite(ctx, req, child)
if err != nil {
return nil, err
}
iterStreams = append(iterStreams, iterator.NewStream(idx, producerChan))
}
outChan := make(chan *iterator.Msg, len(children))
go func() {
recoveredError := panics.Try(func() {
resolver(ctx, iterator.NewStreams(iterStreams), outChan)
})
if recoveredError != nil {
concurrency.TrySendThroughChannel(ctx, &iterator.Msg{Err: fmt.Errorf("%w: %s", ErrPanic, recoveredError.AsError())}, outChan)
}
}()
return outChan, nil
}
// fastPathRewrite returns a channel that will contain an unknown but finite number of elements.
// The channel is closed at the end.
func fastPathRewrite(
ctx context.Context,
req *ResolveCheckRequest,
rewrite *openfgav1.Userset,
) (chan *iterator.Msg, error) {
switch rw := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This:
return fastPathDirect(ctx, req)
case *openfgav1.Userset_ComputedUserset:
return fastPathComputed(ctx, req, rewrite)
case *openfgav1.Userset_Union:
return fastPathOperationSetup(ctx, req, fastPathUnion, rw.Union.GetChild()...)
case *openfgav1.Userset_Intersection:
return fastPathOperationSetup(ctx, req, fastPathIntersection, rw.Intersection.GetChild()...)
case *openfgav1.Userset_Difference:
return fastPathOperationSetup(ctx, req, fastPathDifference, rw.Difference.GetBase(), rw.Difference.GetSubtract())
case *openfgav1.Userset_TupleToUserset:
return fastPathNoop(ctx, req)
default:
return nil, ErrUnknownSetOperator
}
}
package iterator
import (
"context"
"sync"
"github.com/openfga/openfga/internal/concurrency"
)
func Drain(ch <-chan *Msg) *sync.WaitGroup {
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
for msg := range ch {
if msg.Iter != nil {
msg.Iter.Stop()
}
}
wg.Done()
}()
return wg
}
func FanInIteratorChannels(ctx context.Context, chans []<-chan *Msg) <-chan *Msg {
limit := len(chans)
out := make(chan *Msg, limit)
if limit == 0 {
close(out)
return out
}
pool := concurrency.NewPool(ctx, limit)
for _, c := range chans {
pool.Go(func(ctx context.Context) error {
for v := range c {
if !concurrency.TrySendThroughChannel(ctx, v, out) {
if v.Iter != nil {
v.Iter.Stop()
}
}
}
return nil
})
}
go func() {
// NOTE: the consumer of this channel will block waiting for it to close
_ = pool.Wait()
close(out)
}()
return out
}
package iterator
import (
"context"
"fmt"
"slices"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
type Msg struct {
Iter storage.Iterator[string]
Err error
}
// Stream aggregates multiple iterators that are sent to a source channel into one iterator.
type Stream struct {
idx int
buffer storage.Iterator[string]
sourceIsClosed bool // set once the source channel has been closed and no further iterators will arrive
source chan *Msg
}
func NewStream(idx int, source chan *Msg) *Stream {
return &Stream{
idx: idx,
source: source,
}
}
func (s *Stream) Idx() int {
return s.idx
}
// Head returns the first item in the buffer. If the buffer is done or the
// context is cancelled, it stops the buffer and sets it to nil.
func (s *Stream) Head(ctx context.Context) (string, error) {
if s.buffer == nil {
return "", storage.ErrIteratorDone
}
t, err := s.buffer.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
s.buffer.Stop()
s.buffer = nil
}
return "", err
}
return t, nil
}
func (s *Stream) Next(ctx context.Context) (string, error) {
if s.buffer == nil {
return "", storage.ErrIteratorDone
}
t, err := s.buffer.Next(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
s.buffer.Stop()
s.buffer = nil
}
return "", err
}
return t, nil
}
func (s *Stream) Stop() {
if s.buffer != nil {
s.buffer.Stop()
}
for msg := range s.source {
if msg.Iter != nil {
msg.Iter.Stop()
}
}
}
// SkipToTargetObject advances the buffer until its head object is >= the target object.
// If the buffer runs out of items along the way, it is stopped and set to nil.
func (s *Stream) SkipToTargetObject(ctx context.Context, target string) error {
if !tuple.IsValidObject(target) {
return fmt.Errorf("invalid target object: %s", target)
}
t, err := s.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
return nil
}
return err
}
tmpKey := t
for tmpKey < target {
_, _ = s.Next(ctx)
t, err = s.Head(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
break
}
return err
}
tmpKey = t
}
return nil
}
// Drain consumes all remaining items in the stream's buffer and returns them.
func (s *Stream) Drain(ctx context.Context) ([]string, error) {
var batch []string
for {
t, err := s.Next(ctx)
if err != nil {
if storage.IterIsDoneOrCancelled(err) {
break
}
return nil, err
}
batch = append(batch, t)
}
return batch, nil
}
func (s *Stream) fetchSource(ctx context.Context) error {
if s.buffer != nil || s.sourceIsClosed {
// no need to poll further
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case i, ok := <-s.source:
if !ok {
s.sourceIsClosed = true
break
}
if i.Err != nil {
return i.Err
}
s.buffer = i.Iter
}
return nil
}
func (s *Stream) isDone() bool {
return s.sourceIsClosed && s.buffer == nil
}
// NextItemInSliceStreams will advance all streamSlices specified in streamToProcess and return the item advanced.
// Assumption is that the stream slices first item is identical, and we want to advance all these streams.
func NextItemInSliceStreams(ctx context.Context, streamSlices []*Stream, streamToProcess []int) (string, error) {
var item string
var err error
for _, iterIdx := range streamToProcess {
item, err = streamSlices[iterIdx].Next(ctx)
if err != nil {
return "", err
}
}
return item, nil
}
type Streams struct {
streams []*Stream
}
func NewStreams(streams []*Stream) *Streams {
return &Streams{
streams: streams,
}
}
// GetActiveStreamsCount returns the number of streams that were still active as of the last call to CleanDone.
func (s *Streams) GetActiveStreamsCount() int {
return len(s.streams)
}
// Stop will Drain all streams completely to avoid leaving dangling resources
// NOTE: caller should consider running this in a goroutine to not block.
func (s *Streams) Stop() {
for _, stream := range s.streams {
stream.Stop()
}
}
// CleanDone removes the streams whose source channel is closed and whose buffer is drained,
// and returns the remaining active streams.
// A stream is considered active while its source channel is still open or its buffer holds items.
func (s *Streams) CleanDone(ctx context.Context) ([]*Stream, error) {
for _, stream := range s.streams {
err := stream.fetchSource(ctx)
if err != nil {
return nil, err
}
}
// clean up all empty entries that are both sourceIsClosed and drained
s.streams = slices.DeleteFunc(s.streams, func(entry *Stream) bool {
return entry.isDone()
})
return s.streams, nil
}
package authn
import (
"context"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/openfga/openfga/internal/authn"
"github.com/openfga/openfga/pkg/authclaims"
)
func AuthFunc(authenticator authn.Authenticator) grpcauth.AuthFunc {
return func(ctx context.Context) (context.Context, error) {
claims, err := authenticator.Authenticate(ctx)
if err != nil {
return nil, err
}
return authclaims.ContextWithAuthClaims(ctx, claims), nil
}
}
package planner
import (
"math/rand"
"sync"
"sync/atomic"
"time"
)
// keyPlan manages the statistics for a single key and makes decisions about its resolvers.
// The struct is entirely lock-free, using a sync.Map to manage its stats.
type keyPlan struct {
stats sync.Map // Stores map[string]*ThompsonStats
planner *Planner
// lastAccessed stores the UnixNano timestamp of the last access.
// Using atomic guarantees thread-safe updates without a mutex.
lastAccessed atomic.Int64
}
var _ Selector = (*keyPlan)(nil)
// touch updates the lastAccessed timestamp to the current time.
func (kp *keyPlan) touch() {
kp.lastAccessed.Store(time.Now().UnixNano())
}
// getOrCreateStats atomically retrieves or creates the ThompsonStats for a given resolver name.
// This avoids the allocation overhead of calling LoadOrStore directly on the hot path.
func (kp *keyPlan) getOrCreateStats(plan *PlanConfig) *ThompsonStats {
// Fast path: Try a simple load first. This avoids the allocation in the common case.
val, ok := kp.stats.Load(plan.Name)
if ok {
return val.(*ThompsonStats)
}
// Slow path: The stats don't exist. Create a new one.
newTS := NewThompsonStats(plan.InitialGuess, plan.Lambda, plan.Alpha, plan.Beta)
// Use LoadOrStore to handle the race where another goroutine might have created it
// in the time between our Load and now. The newTS object is only stored if
// no entry existed.
actual, _ := kp.stats.LoadOrStore(plan.Name, newTS)
return actual.(*ThompsonStats)
}
// Select implements the Thompson Sampling decision rule.
func (kp *keyPlan) Select(resolvers map[string]*PlanConfig) *PlanConfig {
kp.touch() // Mark this key as recently used.
rng := kp.planner.rngPool.Get().(*rand.Rand)
defer kp.planner.rngPool.Put(rng)
bestResolver := ""
var minSampledTime float64 = -1
for k, plan := range resolvers {
// Use the optimized helper method to get stats without unnecessary allocations.
ts := kp.getOrCreateStats(plan)
sampledTime := ts.Sample(rng)
if bestResolver == "" || sampledTime < minSampledTime {
minSampledTime = sampledTime
bestResolver = k
}
}
return resolvers[bestResolver]
}
// UpdateStats performs the Bayesian update for the given resolver's statistics.
func (kp *keyPlan) UpdateStats(plan *PlanConfig, duration time.Duration) {
kp.touch() // Mark this key as recently used.
// Use the optimized helper method to avoid allocations.
ts := kp.getOrCreateStats(plan)
ts.Update(duration)
}
package planner
import (
"math/rand"
"sync"
"time"
)
// Planner is the top-level entry point for creating and managing plans for different keys.
// It is safe for concurrent use and includes a background routine to evict old keys.
type Planner struct {
keys sync.Map // Stores *keyPlan, ensuring fine-grained concurrency per key.
evictionThreshold time.Duration
// Use a pool of RNGs to reduce allocation overhead and initialization cost on the hot path.
rngPool sync.Pool
wg sync.WaitGroup
stopCleanup chan struct{}
}
var _ Manager = (*Planner)(nil)
// Config holds configuration for the planner.
type Config struct {
EvictionThreshold time.Duration // How long a key can be unused before being evicted. (e.g., 30 * time.Minute)
CleanupInterval time.Duration // How often the planner checks for stale keys. (e.g., 5 * time.Minute)
}
// New creates a new Planner with the specified configuration and starts its cleanup routine.
func New(config *Config) *Planner {
p := &Planner{
evictionThreshold: config.EvictionThreshold,
stopCleanup: make(chan struct{}),
wg: sync.WaitGroup{},
}
p.rngPool.New = func() interface{} {
// Each new RNG is seeded to ensure different sequences.
return rand.New(rand.NewSource(time.Now().UnixNano()))
}
if config.EvictionThreshold > 0 && config.CleanupInterval > 0 {
p.startCleanupRoutine(config.CleanupInterval)
}
return p
}
func NewNoopPlanner() *Planner {
p := &Planner{
evictionThreshold: 0,
stopCleanup: make(chan struct{}),
wg: sync.WaitGroup{},
}
p.rngPool.New = func() interface{} {
// Each new RNG is seeded to ensure different sequences.
return rand.New(rand.NewSource(time.Now().UnixNano()))
}
return p
}
// GetPlanSelector retrieves the plan for a specific key, creating it if it doesn't exist.
func (p *Planner) GetPlanSelector(key string) Selector {
upsertPlan := &keyPlan{planner: p}
upsertPlan.touch()
kp, loaded := p.keys.LoadOrStore(key, upsertPlan)
plan := kp.(*keyPlan)
if loaded {
plan.touch() // Mark as accessed.
}
return plan
}
// startCleanupRoutine runs a background goroutine that periodically evicts stale keys.
func (p *Planner) startCleanupRoutine(interval time.Duration) {
ticker := time.NewTicker(interval)
p.wg.Add(1)
go func() {
for {
select {
case <-ticker.C:
p.evictStaleKeys()
case <-p.stopCleanup:
ticker.Stop()
p.wg.Done()
return
}
}
}()
}
// evictStaleKeys iterates over all keys and removes any that haven't been accessed
// within the evictionThreshold.
func (p *Planner) evictStaleKeys() {
evictionThresholdNano := p.evictionThreshold.Nanoseconds()
nowNano := time.Now().UnixNano()
// NOTE: Consider also bounding the total number of keys stored.
p.keys.Range(func(key, value interface{}) bool {
kp := value.(*keyPlan)
lastAccessed := kp.lastAccessed.Load()
if (nowNano - lastAccessed) > evictionThresholdNano {
p.keys.Delete(key)
}
return true // continue iteration
})
}
// Stop gracefully terminates the background cleanup goroutine.
func (p *Planner) Stop() {
close(p.stopCleanup)
p.wg.Wait()
}
package planner
import (
"math"
"math/rand"
"sync/atomic"
"time"
"unsafe"
"gonum.org/v1/gonum/stat/distuv"
)
// ThompsonStats holds the parameters for the Normal-gamma distribution,
// which models our belief about the performance (execution time) of a strategy.
type ThompsonStats struct {
params unsafe.Pointer // *samplingParams - atomic access
}
type samplingParams struct {
mu float64
lambda float64
alpha float64
beta float64
}
// Sample draws a random execution time from the learned distribution.
// This is the core of Thompson Sampling: we sample from our belief and act greedily on that sample.
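// Concretely, this draws from the Normal-gamma belief:
//
//	tau   ~ Gamma(alpha, beta)      (precision of the execution-time distribution)
//	theta ~ N(mu, 1/(lambda*tau))   (a plausible mean execution time, in ms)
//
// and the caller then acts greedily on theta, picking the resolver with the smallest sample.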
func (ts *ThompsonStats) Sample(r *rand.Rand) float64 {
// Atomically load the current immutable parameter snapshot
params := (*samplingParams)(atomic.LoadPointer(&ts.params))
// Fast path gamma sampling using acceptance-rejection
tau := ts.fastGammaSample(r, params.alpha, params.beta)
// Fast normal sampling
variance := 1.0 / (params.lambda * tau)
if variance <= 0 {
return params.mu
}
// Use standard normal * sqrt(variance) + mean for better performance
stdNormal := r.NormFloat64()
mean := params.mu + stdNormal*math.Sqrt(variance)
return mean
}
// fastGammaSample implements the highly efficient Marsaglia and Tsang acceptance-rejection
// method for generating gamma-distributed random variables for alpha >= 1.
// This avoids the overhead of the more general gonum library for our specific high-performance use case.
// See: G. Marsaglia and W. W. Tsang, "A simple method for generating gamma variables,"
// ACM Trans. Math. Softw. 26, 3 (Sept. 2000), 363-372.
func (ts *ThompsonStats) fastGammaSample(r *rand.Rand, alpha, beta float64) float64 {
// For alpha >= 1, use acceptance-rejection method (faster than gonum)
if alpha >= 1.0 {
d := alpha - 1.0/3.0
c := 1.0 / math.Sqrt(9.0*d)
for {
x := r.NormFloat64()
v := 1.0 + c*x
if v <= 0 {
continue
}
v = v * v * v
u := r.Float64()
if u < 1.0-0.0331*(x*x)*(x*x) {
return d * v / beta
}
if math.Log(u) < 0.5*x*x+d*(1.0-v+math.Log(v)) {
return d * v / beta
}
}
}
// Fallback to gonum for alpha < 1
return distuv.Gamma{Alpha: alpha, Beta: beta, Src: r}.Rand()
}
// Update performs a Bayesian update on the distribution's parameters
// using the new data point (the observed execution duration). It is safe for concurrent
// use: the parameter snapshot is replaced via an atomic compare-and-swap loop.
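// For an observed duration x (in ms), the Normal-gamma conjugate update applied below is:
//
//	lambda' = lambda + 1
//	mu'     = (lambda*mu + x) / lambda'
//	alpha'  = alpha + 1/2
//	beta'   = beta + lambda*(x - mu)^2 / (2*lambda')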
func (ts *ThompsonStats) Update(duration time.Duration) {
x := float64(duration.Nanoseconds()) / 1e6 // Convert to milliseconds with higher precision
for {
// 1. Atomically load the current parameters
oldPtr := atomic.LoadPointer(&ts.params)
currentParams := (*samplingParams)(oldPtr)
// 2. Calculate the new parameters based on the old ones
newLambda := currentParams.lambda + 1
newMu := (currentParams.lambda*currentParams.mu + x) / newLambda
newAlpha := currentParams.alpha + 0.5
diff := x - currentParams.mu
newBeta := currentParams.beta + (currentParams.lambda*diff*diff)/(2*newLambda)
newParams := &samplingParams{
mu: newMu,
lambda: newLambda,
alpha: newAlpha,
beta: newBeta,
}
// 3. Try to atomically swap the old pointer with the new one.
// If another goroutine changed the pointer in the meantime, this will fail,
// and we will loop again to retry the whole operation.
if atomic.CompareAndSwapPointer(&ts.params, oldPtr, unsafe.Pointer(newParams)) {
return
}
}
}
// NewThompsonStats creates a new stats object with the provided prior,
// representing our initial belief about a strategy's performance.
func NewThompsonStats(initialGuess time.Duration, lambda, alpha, beta float64) *ThompsonStats {
initialMs := float64(initialGuess.Nanoseconds()) / 1e6
ts := &ThompsonStats{}
// Create the initial immutable parameter snapshot.
params := &samplingParams{
mu: initialMs,
lambda: lambda,
alpha: alpha,
beta: beta,
}
atomic.StorePointer(&ts.params, unsafe.Pointer(params))
return ts
}
package seq
import "iter"
// Sequence is a function that turns its input into an `iter.Seq[T]` that
// yields values in the order that they were provided to the function.
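// For example:
//
//	for v := range Sequence(1, 2, 3) {
//		fmt.Println(v) // prints 1, 2, 3 in order
//	}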
func Sequence[T any](items ...T) iter.Seq[T] {
return func(yield func(T) bool) {
for _, item := range items {
if !yield(item) {
return
}
}
}
}
// Flatten is a function that merges a set of provided `iter.Seq[T]`
// values into a single `iter.Seq[T]` value. The values of each input are
// yielded in the order yielded by each `iter.Seq[T]`, in the order provided
// to the function.
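// For example:
//
//	Flatten(Sequence(1, 2), Sequence(3)) // yields 1, 2, 3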
func Flatten[T any](seqs ...iter.Seq[T]) iter.Seq[T] {
return func(yield func(T) bool) {
for _, seq := range seqs {
for item := range seq {
if !yield(item) {
return
}
}
}
}
}
// Transform is a function that maps the values yielded by the input `seq`
// to values produced by the input function `fn`, and returns an `iter.Seq`
// that yields those new values.
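// For example:
//
//	Transform(Sequence(1, 2, 3), strconv.Itoa) // yields "1", "2", "3"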
func Transform[T any, U any](seq iter.Seq[T], fn func(T) U) iter.Seq[U] {
return func(yield func(U) bool) {
for item := range seq {
if !yield(fn(item)) {
return
}
}
}
}
// Filter yields only the values for which the predicate
// returns `true`.
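// For example:
//
//	Filter(Sequence(1, 2, 3, 4), func(n int) bool { return n%2 == 0 }) // yields 2, 4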
func Filter[T any](seq iter.Seq[T], fn func(T) bool) iter.Seq[T] {
return func(yield func(T) bool) {
for item := range seq {
if fn(item) {
if !yield(item) {
return
}
}
}
}
}
package shared
import (
"context"
"sync"
"golang.org/x/sync/singleflight"
"github.com/openfga/openfga/internal/cachecontroller"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers/sharediterator"
)
// SharedDatastoreResourcesOpt defines an option that can be used to change the behavior of SharedDatastoreResources
// instance.
type SharedDatastoreResourcesOpt func(*SharedDatastoreResources)
// WithLogger sets the logger for CachedDatastore.
func WithLogger(logger logger.Logger) SharedDatastoreResourcesOpt {
return func(scr *SharedDatastoreResources) {
scr.Logger = logger
}
}
// WithCacheController allows overriding the default cacheController created in NewSharedDatastoreResources().
func WithCacheController(cacheController cachecontroller.CacheController) SharedDatastoreResourcesOpt {
return func(scr *SharedDatastoreResources) {
scr.CacheController = cacheController
}
}
// WithShadowCacheController allows overriding the default shadow cacheController created in NewSharedDatastoreResources().
func WithShadowCacheController(cacheController cachecontroller.CacheController) SharedDatastoreResourcesOpt {
return func(scr *SharedDatastoreResources) {
scr.ShadowCacheController = cacheController
}
}
// SharedDatastoreResources contains resources that can be shared across Check requests.
type SharedDatastoreResources struct {
SingleflightGroup *singleflight.Group
WaitGroup *sync.WaitGroup
ServerCtx context.Context
CheckCache storage.InMemoryCache[any]
CacheController cachecontroller.CacheController
ShadowCheckCache storage.InMemoryCache[any]
ShadowCacheController cachecontroller.CacheController
Logger logger.Logger
SharedIteratorStorage *sharediterator.Storage
}
func NewSharedDatastoreResources(
sharedCtx context.Context,
sharedSf *singleflight.Group,
ds storage.OpenFGADatastore,
settings serverconfig.CacheSettings,
opts ...SharedDatastoreResourcesOpt,
) (*SharedDatastoreResources, error) {
s := &SharedDatastoreResources{
WaitGroup: &sync.WaitGroup{},
SingleflightGroup: sharedSf,
ServerCtx: sharedCtx,
CacheController: cachecontroller.NewNoopCacheController(),
Logger: logger.NewNoopLogger(),
SharedIteratorStorage: sharediterator.NewSharedIteratorDatastoreStorage(
sharediterator.WithSharedIteratorDatastoreStorageLimit(
int(settings.SharedIteratorLimit))),
}
if settings.ShouldCreateNewCache() {
var err error
s.CheckCache, err = storage.NewInMemoryLRUCache([]storage.InMemoryLRUCacheOpt[any]{
storage.WithMaxCacheSize[any](int64(settings.CheckCacheLimit)),
}...)
if err != nil {
return nil, err
}
}
if settings.ShouldCreateCacheController() {
s.CacheController = cachecontroller.NewCacheController(ds, s.CheckCache, settings.CacheControllerTTL, settings.CheckIteratorCacheTTL, cachecontroller.WithLogger(s.Logger))
}
// The default behavior is to use the same cache instance for both the
// check cache and the shadow cache. However, if the user opts in to use a
// separate cache instance for the shadow cache, we need to create new
// instances.
s.ShadowCheckCache = s.CheckCache
s.ShadowCacheController = s.CacheController
if settings.ShouldCreateShadowNewCache() {
var err error
s.ShadowCheckCache, err = storage.NewInMemoryLRUCache([]storage.InMemoryLRUCacheOpt[any]{
storage.WithMaxCacheSize[any](int64(settings.CheckCacheLimit)),
}...)
if err != nil {
return nil, err
}
}
if settings.ShouldCreateShadowCacheController() {
s.ShadowCacheController = cachecontroller.NewCacheController(ds, s.ShadowCheckCache, settings.CacheControllerTTL, settings.CheckIteratorCacheTTL, cachecontroller.WithLogger(s.Logger))
}
for _, opt := range opts {
opt(s)
}
return s, nil
}
func (s *SharedDatastoreResources) Close() {
// wait for any goroutines still in flight before
// closing the cache instance to avoid data races
s.WaitGroup.Wait()
if s.CheckCache != nil {
s.CheckCache.Stop()
}
if s.ShadowCheckCache != nil && s.CheckCache != s.ShadowCheckCache {
s.ShadowCheckCache.Stop()
}
}
package stack
import (
"fmt"
"strings"
)
// Stack is an implementation of a stack based on a linked list.
//
// *Important*: Push and Pop never mutate an existing stack; each returns a new
// Stack value, which keeps shared stacks safe to read concurrently.
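//
// Example:
//
//	var s Stack[int]
//	s = Push(s, 1)
//	s = Push(s, 2)
//	top, rest := Pop(s) // top == 2, Len(rest) == 1, and s itself is unchanged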
type node[T any] struct {
value T
next *node[T]
}
type Stack[T any] *node[T]
func Push[T any](stack Stack[T], value T) Stack[T] {
return Stack[T](&node[T]{value: value, next: (*node[T])(stack)})
}
func Pop[T any](stack Stack[T]) (T, Stack[T]) {
return stack.value, Stack[T](stack.next)
}
func Peek[T any](stack Stack[T]) T {
return stack.value
}
func Len[T any](stack Stack[T]) int {
var ctr int
s := stack
for s != nil {
ctr++
s = s.next
}
return ctr
}
func String[T any](stack Stack[T]) string {
var val T
var sb strings.Builder
for stack != nil {
val, stack = Pop(stack)
sb.WriteString(fmt.Sprintf("%v", val))
}
return sb.String()
}
package threshold
import (
"context"
"github.com/openfga/openfga/internal/throttler"
"github.com/openfga/openfga/pkg/dispatch"
)
type Config struct {
Enabled bool
Throttler throttler.Throttler
Threshold uint32
MaxThreshold uint32
}
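// ShouldThrottle reports whether currentCount exceeds the effective threshold.
// The effective threshold starts at defaultThreshold; a positive per-request
// override from the context replaces it, capped at maxThreshold (which falls
// back to defaultThreshold when zero).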
func ShouldThrottle(ctx context.Context, currentCount uint32, defaultThreshold uint32, maxThreshold uint32) bool {
threshold := defaultThreshold
if maxThreshold == 0 {
maxThreshold = defaultThreshold
}
thresholdInCtx := dispatch.ThrottlingThresholdFromContext(ctx)
if thresholdInCtx > 0 {
threshold = min(thresholdInCtx, maxThreshold)
}
return currentCount > threshold
}
//go:generate mockgen -source throttler.go -destination ../mocks/mock_throttler.go -package mocks
package throttler
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/telemetry"
)
var (
throttlingDelayMsHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "throttling_delay_ms",
Help: "Time spent waiting for dispatch throttling resolver",
Buckets: []float64{1, 3, 5, 10, 25, 50, 100, 1000, 5000}, // Milliseconds. Upper bound is config.UpstreamTimeout.
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"grpc_service", "grpc_method", "throttler_name"})
)
type Throttler interface {
Close()
Throttle(context.Context)
}
type noopThrottler struct{}
var _ Throttler = (*noopThrottler)(nil)
func (r *noopThrottler) Throttle(ctx context.Context) {
}
func (r *noopThrottler) Close() {
}
func NewNoopThrottler() Throttler { return &noopThrottler{} }
// constantRateThrottler implements a throttling mechanism that can be used to control the rate of recursive resource consumption.
// Throttling will release the goroutines from the throttlingQueue based on the configured ticker.
type constantRateThrottler struct {
name string
ticker *time.Ticker
throttlingQueue chan struct{}
done chan struct{}
}
// NewConstantRateThrottler constructs a constantRateThrottler which can be used to control the rate of recursive resource consumption.
func NewConstantRateThrottler(frequency time.Duration, metricLabel string) Throttler {
return newConstantRateThrottler(frequency, metricLabel)
}
// newConstantRateThrottler returns the concrete *constantRateThrottler (rather than the Throttler interface) so internal tests can inspect it.
func newConstantRateThrottler(frequency time.Duration, throttlerName string) *constantRateThrottler {
constantRateThrottler := &constantRateThrottler{
name: throttlerName,
ticker: time.NewTicker(frequency),
throttlingQueue: make(chan struct{}),
done: make(chan struct{}),
}
go constantRateThrottler.runTicker()
return constantRateThrottler
}
func (r *constantRateThrottler) nonBlockingSend(signalChan chan struct{}) {
select {
case signalChan <- struct{}{}:
// message sent
default:
// message dropped
}
}
func (r *constantRateThrottler) runTicker() {
for {
select {
case <-r.done:
return
case <-r.ticker.C:
r.nonBlockingSend(r.throttlingQueue)
}
}
}
func (r *constantRateThrottler) Close() {
r.done <- struct{}{}
r.ticker.Stop()
close(r.done)
close(r.throttlingQueue)
}
// Throttle blocks the calling goroutine until the context is done or a value is received
// from the underlying throttling queue channel, onto which a value is produced at the
// configured ticker frequency. Callers invoke it once the current dispatch count exceeds the configured threshold.
func (r *constantRateThrottler) Throttle(ctx context.Context) {
start := time.Now()
select {
case <-ctx.Done():
case <-r.throttlingQueue:
}
end := time.Now()
timeWaiting := end.Sub(start).Milliseconds()
rpcInfo := telemetry.RPCInfoFromContext(ctx)
throttlingDelayMsHistogram.WithLabelValues(
rpcInfo.Service,
rpcInfo.Method,
r.name,
).Observe(float64(timeWaiting))
}
// Package apimethod provides a type for the API grpc method names.
package apimethod
type APIMethod string
// String converts the APIMethod to its string representation.
func (a APIMethod) String() string {
return string(a)
}
// API methods.
const (
ReadAuthorizationModel APIMethod = "ReadAuthorizationModel"
ReadAuthorizationModels APIMethod = "ReadAuthorizationModels"
Read APIMethod = "Read"
Write APIMethod = "Write"
ListObjects APIMethod = "ListObjects"
StreamedListObjects APIMethod = "StreamedListObjects"
Check APIMethod = "Check"
BatchCheck APIMethod = "BatchCheck"
ListUsers APIMethod = "ListUsers"
WriteAssertions APIMethod = "WriteAssertions"
ReadAssertions APIMethod = "ReadAssertions"
WriteAuthorizationModel APIMethod = "WriteAuthorizationModel"
ListStores APIMethod = "ListStores"
CreateStore APIMethod = "CreateStore"
GetStore APIMethod = "GetStore"
DeleteStore APIMethod = "DeleteStore"
Expand APIMethod = "Expand"
ReadChanges APIMethod = "ReadChanges"
)
package utils
import (
"sort"
"strconv"
)
// Bucketize will put the value of a metric into the correct bucket, and return the label for it.
// It is expected that the buckets are already sorted in increasing order and non-empty.
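// For example, with buckets []uint{5, 10, 25}:
//
//	Bucketize(7, buckets)  // returns "10"
//	Bucketize(30, buckets) // returns "+Inf"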
func Bucketize(value uint, buckets []uint) string {
idx := sort.Search(len(buckets), func(i int) bool {
return value <= buckets[i]
})
if idx == len(buckets) {
return "+Inf"
}
return strconv.Itoa(int(buckets[idx]))
}
// LinearBuckets returns an evenly distributed range of buckets in the closed interval
// [min...max]. The min and max count toward the bucket count since they are included
// in the range.
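// For example:
//
//	LinearBuckets(0, 10, 5) // returns [0 2.5 5 7.5 10]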
func LinearBuckets(minValue, maxValue float64, count int) []float64 {
var buckets []float64
width := (maxValue - minValue) / float64(count-1)
for i := minValue; i <= maxValue; i += width {
buckets = append(buckets, i)
}
return buckets
}
package utils
import "iter"
// Filter functions similar to other language list filter functions.
// It accepts a generic slice and a predicate function to apply to each element,
// returning an iterator sequence containing only the elements for which the predicate returned true.
//
// Example filtering for even numbers:
//
// seq := Filter([]int{1, 2, 3, 4, 5}, func(n int) bool { return n%2 == 0 })
// // To collect results: slices.Collect(seq) returns []int{2, 4}
func Filter[T any](s []T, predicate func(T) bool) iter.Seq[T] {
return func(yield func(T) bool) {
for _, item := range s {
if predicate(item) {
if !yield(item) {
// Stop if yield returns false (the consumer stopped iterating)
return
}
}
}
}
}
package utils
import (
"math/rand"
"time"
)
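// JitterDuration returns baseDuration plus a uniformly distributed random jitter in [0, maxJitter).
// If maxJitter is not positive, baseDuration is returned unchanged.
//
// Example (illustrative): JitterDuration(10*time.Second, 2*time.Second) returns a duration in [10s, 12s).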
func JitterDuration(baseDuration, maxJitter time.Duration) time.Duration {
if maxJitter <= 0 {
return baseDuration
}
jitter := time.Duration(rand.Int63n(int64(maxJitter)))
return baseDuration + jitter
}
package utils
// Reduce accepts a generic slice, an initializer value, and a function.
// It iterates over the slice applying the supplied function to the current accumulated value and each element in the slice,
// reducing the slice to a single value.
//
// Example reducing to a sum:
//
// Reduce([]int{1, 2, 3}, 0, func(accumulator int, currentValue int) int {
// return accumulator + currentValue
// })
//
// returns 6.
func Reduce[S ~[]E, E any, A any](s S, initializer A, f func(A, E) A) A {
i := initializer
for _, item := range s {
i = f(i, item)
}
return i
}
package validation
import (
"errors"
"fmt"
"reflect"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// ValidateUserObjectRelation returns nil if the tuple is well-formed and valid according to the provided model.
func ValidateUserObjectRelation(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
if err := ValidateUser(typesys, tk.GetUser()); err != nil {
return err
}
if err := ValidateObject(typesys, tk); err != nil {
return err
}
if err := ValidateRelation(typesys, tk); err != nil {
return err
}
return nil
}
// ValidateTupleForWrite returns nil if a tuple is well-formed and valid according to the provided model.
// It is a superset of ValidateUserObjectRelation and ValidateTupleForRead, and is
// ONLY meant to be used for Writes and for contextual tuples (since those mimic tuples being written to the datastore).
func ValidateTupleForWrite(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
if err := ValidateUserObjectRelation(typesys, tk); err != nil {
return &tuple.InvalidTupleError{Cause: err, TupleKey: tk}
}
// now that we know the tuple is well-formed, check it
// against the other model and type-restriction constraints
return ValidateTupleForRead(typesys, tk)
}
// ValidateTupleForRead returns nil if a tuple is valid according to the provided model.
// It also validates TTU relations and type restrictions.
func ValidateTupleForRead(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
if err := validateTuplesetRestrictions(typesys, tk); err != nil {
return &tuple.InvalidTupleError{Cause: err, TupleKey: tk}
}
objectType := tuple.GetType(tk.GetObject())
relation := tk.GetRelation()
hasTypeInfo, err := typesys.HasTypeInfo(objectType, relation)
if err != nil {
return err
}
if hasTypeInfo {
err := validateTypeRestrictions(typesys, tk)
if err != nil {
return &tuple.InvalidTupleError{Cause: err, TupleKey: tk}
}
if err := validateCondition(typesys, tk); err != nil {
return err
}
}
return nil
}
// validateTuplesetRestrictions validates the provided TupleKey against tupleset restrictions.
//
// Given a rewrite definition such as 'viewer from parent', the 'parent' relation is known as the
// tupleset. This method ensures the following are *not* possible:
//
// 1. `document:1#parent@folder:1#parent` (cannot evaluate/assign a userset value to a tupleset relation)
// 2. `document:1#parent@*` (cannot evaluate/assign an untyped wildcard to a tupleset relation (1.0 models))
// 3. `document:1#parent@folder:*` (cannot evaluate/assign a typed wildcard to a tupleset relation (1.1 models)).
func validateTuplesetRestrictions(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
objectType := tuple.GetType(tk.GetObject())
relation := tk.GetRelation()
isTupleset, err := typesys.IsTuplesetRelation(objectType, relation)
if err != nil {
return err
}
if !isTupleset {
return nil
}
rel, err := typesys.GetRelation(objectType, relation)
if err != nil {
return err
}
rewrite := rel.GetRewrite().GetUserset()
// tupleset relation involving a rewrite
if rewrite != nil && reflect.TypeOf(rewrite) != reflect.TypeOf(&openfgav1.Userset_This{}) {
return fmt.Errorf("unexpected rewrite encountered with tupleset relation '%s#%s'", objectType, relation)
}
user := tk.GetUser()
// tupleset relation involving a wildcard (covers the '*' and 'type:*' cases);
// this check must precede IsValidObject because old (1.0) models supported untyped wildcards
if tuple.IsWildcard(user) {
return fmt.Errorf("unexpected wildcard relationship with tupleset relation '%s#%s'", objectType, relation)
}
// tupleset relation involving a userset (e.g. object#relation) or a user_id (i.e. not a valid object)
if !tuple.IsValidObject(user) {
return fmt.Errorf("unexpected user '%s' with tupleset relation '%s#%s'", user, objectType, relation)
}
return nil
}
// validateTypeRestrictions makes sure the type restrictions are enforced.
// 1. If the tuple is of the form doc:budget#reader@person:bob, then 'doc#reader' must allow type 'person'.
// 2. If the tuple is of the form doc:budget#reader@group:abc#member, then 'doc#reader' must allow 'group#member'.
// 3. If the tuple is of the form doc:budget#reader@person:*, we allow it only if 'doc#reader' allows the typed wildcard 'person:*'.
func validateTypeRestrictions(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
objectType := tuple.GetType(tk.GetObject()) // e.g. "doc"
userType, _ := tuple.SplitObject(tk.GetUser()) // e.g. (person, bob) or (group, abc#member) or (person, *)
_, userRel := tuple.SplitObjectRelation(tk.GetUser()) // e.g. (person:bob, "") or (group:abc, member) or (person:*, "")
typeDefinitionForObject, ok := typesys.GetTypeDefinition(objectType)
if !ok {
return fmt.Errorf("type '%s' does not exist in the authorization model", objectType)
}
relationsForObject := typeDefinitionForObject.GetMetadata().GetRelations()
relationInformation := relationsForObject[tk.GetRelation()]
user := tk.GetUser()
if tuple.IsObjectRelation(user) {
// case 2 documented above
for _, typeInformation := range relationInformation.GetDirectlyRelatedUserTypes() {
if typeInformation.GetType() == userType && typeInformation.GetRelation() == userRel {
return nil
}
}
return fmt.Errorf("'%s#%s' is not an allowed type restriction for '%s#%s'", userType, userRel, objectType, tk.GetRelation())
}
if tuple.IsTypedWildcard(user) {
// case 3 documented above
for _, typeInformation := range relationInformation.GetDirectlyRelatedUserTypes() {
if typeInformation.GetType() == userType && typeInformation.GetWildcard() != nil {
return nil
}
}
return fmt.Errorf("the typed wildcard '%s' is not an allowed type restriction for '%s#%s'", user, objectType, tk.GetRelation())
}
// the user must be an object (case 1), so check the user's type directly against the allowed type restrictions
for _, typeInformation := range relationInformation.GetDirectlyRelatedUserTypes() {
if typeInformation.GetType() == userType && typeInformation.GetWildcard() == nil && typeInformation.GetRelation() == "" {
return nil
}
}
return fmt.Errorf("type '%s' is not an allowed type restriction for '%s#%s'", userType, objectType, tk.GetRelation())
}
// validateCondition returns an error if the condition of the tuple is required but not present,
// or if the tuple provides a condition but it is invalid according to the model.
func validateCondition(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
objectType := tuple.GetType(tk.GetObject())
userType := tuple.GetType(tk.GetUser())
userRelation := tuple.GetRelation(tk.GetUser())
typeRestrictions, err := typesys.GetDirectlyRelatedUserTypes(objectType, tk.GetRelation())
if err != nil {
return err
}
if tk.GetCondition() == nil {
for _, directlyRelatedType := range typeRestrictions {
if directlyRelatedType.GetCondition() != "" {
continue
}
if directlyRelatedType.GetType() != userType {
continue
}
if directlyRelatedType.GetRelationOrWildcard() != nil {
if directlyRelatedType.GetRelation() != "" && directlyRelatedType.GetRelation() != userRelation {
continue
}
if directlyRelatedType.GetWildcard() != nil && !tuple.IsTypedWildcard(tk.GetUser()) {
continue
}
} else if tuple.IsTypedWildcard(tk.GetUser()) {
// This is a wildcard tuple, but the directlyRelatedType restriction is not for a wildcard.
continue
}
return nil
}
return &tuple.InvalidConditionalTupleError{
Cause: fmt.Errorf("condition is missing"), TupleKey: tk,
}
}
condition, ok := typesys.GetConditions()[tk.GetCondition().GetName()]
if !ok {
return &tuple.InvalidConditionalTupleError{
Cause: fmt.Errorf("undefined condition"), TupleKey: tk,
}
}
validCondition := false
for _, directlyRelatedType := range typeRestrictions {
if directlyRelatedType.GetType() == userType && directlyRelatedType.GetCondition() == tk.GetCondition().GetName() {
validCondition = true
break
}
}
if !validCondition {
return &tuple.InvalidConditionalTupleError{
Cause: fmt.Errorf("invalid condition for type restriction"), TupleKey: tk,
}
}
contextStruct := tk.GetCondition().GetContext()
contextFieldMap := contextStruct.GetFields()
typedParams, err := condition.CastContextToTypedParameters(contextFieldMap)
if err != nil {
return &tuple.InvalidConditionalTupleError{
Cause: err, TupleKey: tk,
}
}
for key := range contextFieldMap {
_, ok := typedParams[key]
if !ok {
return &tuple.InvalidConditionalTupleError{
Cause: fmt.Errorf("found invalid context parameter: %s", key),
TupleKey: tk,
}
}
}
return nil
}
// FilterInvalidTuples filters out tuples that aren't valid according to the provided model.
func FilterInvalidTuples(typesys *typesystem.TypeSystem) storage.TupleKeyFilterFunc {
return func(tupleKey *openfgav1.TupleKey) bool {
err := ValidateTupleForRead(typesys, tupleKey)
return err == nil
}
}
// ValidateObject validates the provided object string 'type:id' against the provided
// model. An object is considered valid if it validates against one of the type
// definitions included in the provided model.
func ValidateObject(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
object := tk.GetObject()
if !tuple.IsValidObject(object) {
return fmt.Errorf("invalid 'object' field format")
}
objectType, id := tuple.SplitObject(object)
if id == tuple.Wildcard {
return fmt.Errorf("the 'object' field cannot reference a typed wildcard")
}
_, ok := typesys.GetTypeDefinition(objectType)
if !ok {
return &tuple.TypeNotFoundError{TypeName: objectType}
}
return nil
}
// ValidateRelation validates the relation on the provided objectType against the given model.
// A relation is valid if it is defined as a relation for the type definition of the given
// objectType.
func ValidateRelation(typesys *typesystem.TypeSystem, tk *openfgav1.TupleKey) error {
object := tk.GetObject()
relation := tk.GetRelation()
// TODO: determine if we can avoid this since just checking for existence in the typesystem is enough
if !tuple.IsValidRelation(relation) {
return fmt.Errorf("the 'relation' field is malformed")
}
objectType := tuple.GetType(object)
_, err := typesys.GetRelation(objectType, relation)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return &tuple.TypeNotFoundError{TypeName: objectType}
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return &tuple.RelationNotFoundError{Relation: relation, TypeName: objectType}
}
return err
}
return nil
}
// ValidateUser validates that the provided 'user' string meets the model constraints.
// For 1.0 and 1.1 models, if the user field is a userset value, then the objectType and
// relation must be defined. For 1.1 models, the user field must be either a userset or an
// object, and if it is an object we verify that the objectType is defined in the model.
func ValidateUser(typesys *typesystem.TypeSystem, user string) error {
if !tuple.IsValidUser(user) {
return fmt.Errorf("the 'user' field is malformed")
}
isValidObject := tuple.IsValidObject(user)
isValidUserset := tuple.IsObjectRelation(user)
userObject, userRelation := tuple.SplitObjectRelation(user)
userObjectType := tuple.GetType(userObject)
schemaVersion := typesys.GetSchemaVersion()
if typesystem.IsSchemaVersionSupported(schemaVersion) {
if !isValidObject && !isValidUserset {
return fmt.Errorf("the 'user' field must be an object (e.g. document:1) or an 'object#relation' or a typed wildcard (e.g. group:*)")
}
_, ok := typesys.GetTypeDefinition(userObjectType)
if !ok {
return &tuple.TypeNotFoundError{TypeName: userObjectType}
}
}
// for 1.0 and 1.1 models if the 'user' field is a userset then we validate the 'object#relation'
// by making sure the user objectType and relation are defined in the model.
if isValidUserset {
_, err := typesys.GetRelation(userObjectType, userRelation)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return &tuple.TypeNotFoundError{TypeName: userObjectType}
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return &tuple.RelationNotFoundError{Relation: userRelation, TypeName: userObjectType}
}
}
}
return nil
}
package authclaims
import (
"context"
)
type ctxKey string
// authClaimsContextKey is the key to store the auth claims in the context.
const authClaimsContextKey = ctxKey("auth-claims")
// skipAuthz is the key to store whether to skip authz check in the context.
const skipAuthz = ctxKey("skip-authz-key")
// AuthClaims contains claims that are included in OIDC standard claims. https://openid.net/specs/openid-connect-core-1_0.html#IDToken
type AuthClaims struct {
Subject string
Scopes map[string]bool
ClientID string
}
// ContextWithAuthClaims creates a copy of the parent context with the provided AuthClaims.
func ContextWithAuthClaims(parent context.Context, claims *AuthClaims) context.Context {
return context.WithValue(parent, authClaimsContextKey, claims)
}
// AuthClaimsFromContext extracts the AuthClaims from the provided ctx (if any).
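//
// Illustrative round-trip:
//
// ctx := ContextWithAuthClaims(context.Background(), &AuthClaims{Subject: "user:anne"})
// claims, ok := AuthClaimsFromContext(ctx) // ok == true, claims.Subject == "user:anne"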
func AuthClaimsFromContext(ctx context.Context) (*AuthClaims, bool) {
claims, ok := ctx.Value(authClaimsContextKey).(*AuthClaims)
if !ok {
return nil, false
}
return claims, true
}
// ContextWithSkipAuthzCheck creates a copy of the parent context and attaches to it whether the authz check should be skipped.
func ContextWithSkipAuthzCheck(parent context.Context, skipAuthzCheck bool) context.Context {
return context.WithValue(parent, skipAuthz, skipAuthzCheck)
}
// SkipAuthzCheckFromContext returns whether the authorize check can be skipped.
func SkipAuthzCheckFromContext(ctx context.Context) bool {
isSkipped, ok := ctx.Value(skipAuthz).(bool)
return isSkipped && ok
}
package dispatch
import "context"
type dispatchThrottlingThresholdType uint32
const (
dispatchThrottlingThreshold dispatchThrottlingThresholdType = iota
)
// ContextWithThrottlingThreshold will save the dispatch throttling threshold in context.
// This can be used to set per request dispatch throttling when OpenFGA is used as library in another Go project.
func ContextWithThrottlingThreshold(ctx context.Context, threshold uint32) context.Context {
return context.WithValue(ctx, dispatchThrottlingThreshold, threshold)
}
// ThrottlingThresholdFromContext returns the dispatch throttling threshold saved in the context.
// It returns 0 if no threshold is found.
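//
// Illustrative round-trip:
//
// ctx := ContextWithThrottlingThreshold(context.Background(), 50)
// threshold := ThrottlingThresholdFromContext(ctx) // threshold == 50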
func ThrottlingThresholdFromContext(ctx context.Context) uint32 {
if threshold, ok := ctx.Value(dispatchThrottlingThreshold).(uint32); ok {
return threshold
}
return 0
}
package encoder
import (
"encoding/base64"
)
// Ensure Base64Encoder implements the Encoder interface.
var _ Encoder = (*Base64Encoder)(nil)
// Base64Encoder adheres to the Encoder interface, using the
// encoding/base64 package for base64 URL encoding.
type Base64Encoder struct{}
// NewBase64Encoder creates a new instance of the Encoder interface that employs
// the base64 encoding strategy provided by the encoding/base64 package.
func NewBase64Encoder() *Base64Encoder {
return &Base64Encoder{}
}
// Decode performs base64 URL decoding on the input string using the encoding/base64 package.
func (e *Base64Encoder) Decode(s string) ([]byte, error) {
return base64.URLEncoding.DecodeString(s)
}
// Encode performs base64 URL encoding on the input byte slice using the encoding/base64 package.
func (e *Base64Encoder) Encode(data []byte) (string, error) {
return base64.URLEncoding.EncodeToString(data), nil
}
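// Illustrative round-trip:
//
// enc := NewBase64Encoder()
// token, _ := enc.Encode([]byte("some-token")) // base64 URL encoding of the input
// raw, _ := enc.Decode(token) // raw == []byte("some-token")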
//go:generate mockgen -source encoder.go -destination ../../internal/mocks/mock_encoder.go -package mocks Encoder
package encoder
// Ensure NoopEncoder implements the Encoder interface.
var _ Encoder = (*NoopEncoder)(nil)
// Encoder is an interface that defines methods for decoding and encoding data.
type Encoder interface {
Decode(string) ([]byte, error)
Encode([]byte) (string, error)
}
// NoopEncoder is an implementation of the Encoder interface
// that performs no actual encoding or decoding.
type NoopEncoder struct{}
// Decode returns the input string as a byte slice.
func (e NoopEncoder) Decode(s string) ([]byte, error) {
return []byte(s), nil
}
// Encode returns the input byte slice as a string.
func (e NoopEncoder) Encode(data []byte) (string, error) {
return string(data), nil
}
package encoder
import (
"github.com/openfga/openfga/pkg/encrypter"
)
// Ensure TokenEncoder implements the Encoder interface.
var _ Encoder = (*TokenEncoder)(nil)
// TokenEncoder combines an encrypter and an encoder to provide
// functionality for encoding and decoding tokens.
type TokenEncoder struct {
encrypter encrypter.Encrypter
encoder Encoder
}
// NewTokenEncoder constructs a TokenEncoder with the provided encrypter and encoder.
func NewTokenEncoder(encrypter encrypter.Encrypter, encoder Encoder) *TokenEncoder {
return &TokenEncoder{
encrypter: encrypter,
encoder: encoder,
}
}
// Decode first decodes the input string using its internal decoder,
// and subsequently decrypts the resulting data using its encrypter.
func (e *TokenEncoder) Decode(s string) ([]byte, error) {
decoded, err := e.encoder.Decode(s)
if err != nil {
return nil, err
}
return e.encrypter.Decrypt(decoded)
}
// Encode first encrypts the provided data using its internal encrypter,
// and then encodes the resulting encrypted data using its encoder.
func (e *TokenEncoder) Encode(data []byte) (string, error) {
encrypted, err := e.encrypter.Encrypt(data)
if err != nil {
return "", err
}
return e.encoder.Encode(encrypted)
}
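// Illustrative composition, using the no-op encrypter and base64 encoder defined in this codebase:
//
// te := NewTokenEncoder(encrypter.NewNoopEncrypter(), NewBase64Encoder())
// token, _ := te.Encode([]byte("data")) // with the no-op encrypter this is just base64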
//go:generate mockgen -source token_serializer.go -destination ../../internal/mocks/mock_token_serializer.go -package mocks ContinuationTokenSerializer
package encoder
import (
"errors"
"fmt"
"strings"
"github.com/openfga/openfga/pkg/storage"
)
// ContinuationTokenSerializer serializes and deserializes the ReadChanges continuation token.
type ContinuationTokenSerializer interface {
// Serialize serializes the ulid and object type into a continuation token readable by ReadChanges.
Serialize(ulid string, objType string) (token []byte, err error)
// Deserialize parses a continuation token back into its ulid and object type parts.
Deserialize(token string) (ulid string, objType string, err error)
}
// StringContinuationTokenSerializer is a ContinuationTokenSerializer that serializes the continuation token as a string.
type StringContinuationTokenSerializer struct{}
// NewStringContinuationTokenSerializer returns a new instance of StringContinuationTokenSerializer.
// Serializes the continuation token into a string, as ulid & type concatenated by a pipe.
func NewStringContinuationTokenSerializer() ContinuationTokenSerializer {
return &StringContinuationTokenSerializer{}
}
// Serialize serializes the continuation token into a string, as ulid & type concatenated by a pipe.
func (ts *StringContinuationTokenSerializer) Serialize(ulid string, objType string) ([]byte, error) {
if ulid == "" {
return nil, errors.New("empty ulid provided for continuation token")
}
return []byte(fmt.Sprintf("%s|%s", ulid, objType)), nil
}
// Deserialize deserializes the continuation token from a string, as ulid & type concatenated by a pipe.
func (ts *StringContinuationTokenSerializer) Deserialize(continuationToken string) (ulid string, objType string, err error) {
if !strings.Contains(continuationToken, "|") {
return "", "", storage.ErrInvalidContinuationToken
}
tokenParts := strings.Split(continuationToken, "|")
return tokenParts[0], tokenParts[1], nil
}
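// Illustrative round-trip (the ulid value is a placeholder):
//
// ts := NewStringContinuationTokenSerializer()
// token, _ := ts.Serialize("01ARZ3NDEKTSV4RRFFQ69G5FAV", "document") // "01ARZ3NDEKTSV4RRFFQ69G5FAV|document"
// ulid, objType, _ := ts.Deserialize(string(token))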
package encrypter
// Ensure NoopEncrypter implements the Encrypter interface.
var _ Encrypter = (*NoopEncrypter)(nil)
// Encrypter is an interface that defines methods for encrypting and decrypting data.
type Encrypter interface {
Decrypt([]byte) ([]byte, error)
Encrypt([]byte) ([]byte, error)
}
// NoopEncrypter is an implementation of the Encrypter interface
// that performs no actual encryption or decryption.
type NoopEncrypter struct{}
// NewNoopEncrypter creates a new instance of NoopEncrypter.
func NewNoopEncrypter() *NoopEncrypter {
return &NoopEncrypter{}
}
// Decrypt returns the input byte slice as is.
func (e *NoopEncrypter) Decrypt(data []byte) ([]byte, error) {
return data, nil
}
// Encrypt returns the input byte slice as is.
func (e *NoopEncrypter) Encrypt(data []byte) ([]byte, error) {
return data, nil
}
package encrypter
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"io"
)
// Ensure GCMEncrypter implements the Encrypter interface.
var _ Encrypter = (*GCMEncrypter)(nil)
// GCMEncrypter is an implementation of the Encrypter interface
// that uses the AES-GCM encryption algorithm.
type GCMEncrypter struct {
cipherMode cipher.AEAD
}
// NewGCMEncrypter creates a new instance of GCMEncrypter with the provided key.
// It initializes the AES-GCM cipher mode for encryption and decryption.
func NewGCMEncrypter(key string) (*GCMEncrypter, error) {
c, err := aes.NewCipher(create32ByteKey(key))
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
return &GCMEncrypter{cipherMode: gcm}, nil
}
// Decrypt decrypts an AES-GCM encrypted byte array.
func (e *GCMEncrypter) Decrypt(data []byte) ([]byte, error) {
if len(data) == 0 {
return data, nil
}
nonceSize := e.cipherMode.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
return e.cipherMode.Open(nil, nonce, ciphertext, nil)
}
// Encrypt encrypts the given byte array using the AES-GCM block cipher.
func (e *GCMEncrypter) Encrypt(data []byte) ([]byte, error) {
if len(data) == 0 {
return data, nil
}
nonce := make([]byte, e.cipherMode.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return e.cipherMode.Seal(nonce, nonce, data, nil), nil
}
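// Illustrative round-trip (the key is a placeholder):
//
// enc, _ := NewGCMEncrypter("my-secret-key")
// ciphertext, _ := enc.Encrypt([]byte("hello"))
// plaintext, _ := enc.Decrypt(ciphertext) // plaintext == []byte("hello")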
// create32ByteKey creates a 32-byte key by taking the raw
// SHA-256 hash of a string (not its hex representation).
func create32ByteKey(s string) []byte {
sum := sha256.Sum256([]byte(s))
return sum[:]
}
package featureflags
// Client is the interface for evaluating feature flags.
type Client interface {
// Boolean reports whether the named flag is enabled, optionally scoped to the given store.
Boolean(flagName string, storeID string) bool
}
type defaultClient struct {
flags map[string]any
}
// NewDefaultClient creates a default feature flag client which takes in a static list of enabled feature flag names
// and stores them as keys in a map.
func NewDefaultClient(flags []string) Client {
enabledFlags := make(map[string]any, len(flags))
for _, flag := range flags {
enabledFlags[flag] = struct{}{}
}
return &defaultClient{
flags: enabledFlags,
}
}
func (c *defaultClient) Boolean(flagName string, storeID string) bool {
_, ok := c.flags[flagName]
return ok
}
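// Illustrative usage (the flag names are hypothetical):
//
// client := NewDefaultClient([]string{"enable-feature-x"})
// client.Boolean("enable-feature-x", "store-id") // true
// client.Boolean("enable-feature-y", "store-id") // false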
type hardcodedBooleanClient struct {
result bool // this client will always return this result
}
// NewHardcodedBooleanClient creates a hardcodedBooleanClient that always returns the given `result`.
// It is used in testing and in shadow code paths where a feature must be force-enabled or force-disabled.
func NewHardcodedBooleanClient(result bool) Client {
return &hardcodedBooleanClient{result: result}
}
func (h *hardcodedBooleanClient) Boolean(flagName string, _ string) bool {
return h.result
}
// NewNoopFeatureFlagClient returns a client that reports every flag as disabled.
func NewNoopFeatureFlagClient() Client {
return NewHardcodedBooleanClient(false)
}
package gateway
import (
"context"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/openfga/openfga/pkg/logger"
)
// Transport is the interface to work with the transport layer.
type Transport interface {
// SetHeader sets a response header with a key and a value.
// It should not be called after a response has been sent.
SetHeader(ctx context.Context, key, value string)
}
// NoopTransport defines a no-op transport.
type NoopTransport struct{}
var _ Transport = (*NoopTransport)(nil)
// NewNoopTransport returns a no-op transport.
func NewNoopTransport() *NoopTransport {
return &NoopTransport{}
}
// SetHeader is a no-op.
func (n *NoopTransport) SetHeader(_ context.Context, _, _ string) {
}
// RPCTransport defines a transport for gRPC.
type RPCTransport struct {
logger logger.Logger
}
var _ Transport = (*RPCTransport)(nil)
// NewRPCTransport returns a transport for gRPC.
func NewRPCTransport(l logger.Logger) *RPCTransport {
return &RPCTransport{logger: l}
}
// SetHeader tries to set a header. If setting it fails, it logs the error.
func (g *RPCTransport) SetHeader(ctx context.Context, key, value string) {
if err := grpc.SetHeader(ctx, metadata.Pairs(key, value)); err != nil {
g.logger.ErrorWithContext(
ctx,
"failed to set grpc header",
zap.Error(err),
zap.String("header", key),
)
}
}
package logger
import (
"context"
"fmt"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/openfga/openfga/internal/build"
)
type Logger interface {
// These ops delegate directly to the underlying zap implementation.
Debug(string, ...zap.Field)
Info(string, ...zap.Field)
Warn(string, ...zap.Field)
Error(string, ...zap.Field)
Panic(string, ...zap.Field)
Fatal(string, ...zap.Field)
With(...zap.Field) Logger
// These are the equivalent logging functions, but with a context provided.
DebugWithContext(context.Context, string, ...zap.Field)
InfoWithContext(context.Context, string, ...zap.Field)
WarnWithContext(context.Context, string, ...zap.Field)
ErrorWithContext(context.Context, string, ...zap.Field)
PanicWithContext(context.Context, string, ...zap.Field)
FatalWithContext(context.Context, string, ...zap.Field)
}
// NewNoopLogger provides a noop logger.
func NewNoopLogger() *ZapLogger {
return &ZapLogger{
zap.NewNop(),
}
}
// ZapLogger is an implementation of Logger that uses the uber/zap logger underneath.
// It provides additional methods, such as ones that log based on context.
type ZapLogger struct {
*zap.Logger
}
var _ Logger = (*ZapLogger)(nil)
// With creates a child logger and adds structured context to it. Fields added
// to the child don't affect the parent, and vice versa. Any fields that
// require evaluation (such as Objects) are evaluated upon invocation of With.
func (l *ZapLogger) With(fields ...zap.Field) Logger {
return &ZapLogger{l.Logger.With(fields...)}
}
func (l *ZapLogger) Debug(msg string, fields ...zap.Field) {
l.Logger.Debug(msg, fields...)
}
func (l *ZapLogger) Info(msg string, fields ...zap.Field) {
l.Logger.Info(msg, fields...)
}
func (l *ZapLogger) Warn(msg string, fields ...zap.Field) {
l.Logger.Warn(msg, fields...)
}
func (l *ZapLogger) Error(msg string, fields ...zap.Field) {
l.Logger.Error(msg, fields...)
}
func (l *ZapLogger) Panic(msg string, fields ...zap.Field) {
l.Logger.Panic(msg, fields...)
}
func (l *ZapLogger) Fatal(msg string, fields ...zap.Field) {
l.Logger.Fatal(msg, fields...)
}
func (l *ZapLogger) DebugWithContext(ctx context.Context, msg string, fields ...zap.Field) {
l.Logger.Debug(msg, fields...)
}
func (l *ZapLogger) InfoWithContext(ctx context.Context, msg string, fields ...zap.Field) {
l.Logger.Info(msg, fields...)
}
func (l *ZapLogger) WarnWithContext(ctx context.Context, msg string, fields ...zap.Field) {
l.Logger.Warn(msg, fields...)
}
func (l *ZapLogger) ErrorWithContext(ctx context.Context, msg string, fields ...zap.Field) {
fields = append(fields, ctxzap.TagsToFields(ctx)...)
l.Logger.Error(msg, fields...)
}
func (l *ZapLogger) PanicWithContext(ctx context.Context, msg string, fields ...zap.Field) {
l.Logger.Panic(msg, fields...)
}
func (l *ZapLogger) FatalWithContext(ctx context.Context, msg string, fields ...zap.Field) {
l.Logger.Fatal(msg, fields...)
}
// OptionsLogger implements options for the logger.
type OptionsLogger struct {
format string
level string
timestampFormat string
outputPaths []string
}
type OptionLogger func(ol *OptionsLogger)
func WithFormat(format string) OptionLogger {
return func(ol *OptionsLogger) {
ol.format = format
}
}
func WithLevel(level string) OptionLogger {
return func(ol *OptionsLogger) {
ol.level = level
}
}
func WithTimestampFormat(timestampFormat string) OptionLogger {
return func(ol *OptionsLogger) {
ol.timestampFormat = timestampFormat
}
}
// WithOutputPaths sets a list of URLs or file paths to write logging output to.
//
// URLs with the "file" scheme must use absolute paths on the local filesystem.
// No user, password, port, fragments, or query parameters are allowed, and the
// hostname must be empty or "localhost".
//
// Since it's common to write logs to the local filesystem, URLs without a scheme
// (e.g., "/var/log/foo.log") are treated as local file paths. Without a scheme,
// the special paths "stdout" and "stderr" are interpreted as os.Stdout and os.Stderr.
// When specified without a scheme, relative file paths also work.
//
// Defaults to "stdout".
func WithOutputPaths(paths ...string) OptionLogger {
return func(ol *OptionsLogger) {
ol.outputPaths = paths
}
}
func NewLogger(options ...OptionLogger) (*ZapLogger, error) {
logOptions := &OptionsLogger{
level: "info",
format: "text",
timestampFormat: "ISO8601",
outputPaths: []string{"stdout"},
}
for _, opt := range options {
opt(logOptions)
}
if logOptions.level == "none" {
return NewNoopLogger(), nil
}
level, err := zap.ParseAtomicLevel(logOptions.level)
if err != nil {
return nil, fmt.Errorf("unknown log level: %s, error: %w", logOptions.level, err)
}
cfg := zap.NewProductionConfig()
cfg.Level = level
cfg.OutputPaths = logOptions.outputPaths
cfg.EncoderConfig.TimeKey = "timestamp"
cfg.EncoderConfig.CallerKey = "" // remove the "caller" field
cfg.DisableStacktrace = true
if logOptions.format == "text" {
cfg.Encoding = "console"
cfg.DisableCaller = true
cfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
} else { // JSON
cfg.EncoderConfig.EncodeTime = zapcore.EpochTimeEncoder // default in JSON for backward compatibility
if logOptions.timestampFormat == "ISO8601" {
cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
}
}
log, err := cfg.Build()
if err != nil {
return nil, err
}
if logOptions.format == "json" {
log = log.With(zap.String("build.version", build.Version), zap.String("build.commit", build.Commit))
}
return &ZapLogger{log}, nil
}
func MustNewLogger(logFormat, logLevel, logTimestampFormat string) *ZapLogger {
logger, err := NewLogger(
WithFormat(logFormat),
WithLevel(logLevel),
WithTimestampFormat(logTimestampFormat))
if err != nil {
panic(err)
}
return logger
}
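// Illustrative usage:
//
// logger := MustNewLogger("json", "info", "ISO8601")
// logger.Info("server starting", zap.String("addr", ":8080"))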
package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/textproto"
"strconv"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"google.golang.org/grpc/grpclog"
"google.golang.org/protobuf/proto"
"github.com/openfga/openfga/pkg/server/errors"
)
// XHttpCode is used to set the header for the response HTTP code.
const XHttpCode = "x-http-code"
// HTTPResponseModifier overrides the HTTP status code of a response when the gRPC handler has set the x-http-code header.
func HTTPResponseModifier(ctx context.Context, w http.ResponseWriter, p proto.Message) error {
md, ok := runtime.ServerMetadataFromContext(ctx)
if !ok {
return nil
}
// Set http status code.
if vals := md.HeaderMD.Get(XHttpCode); len(vals) > 0 {
code, err := strconv.Atoi(vals[0])
if err != nil {
return err
}
// Delete the headers to not expose any grpc-metadata in http response.
delete(md.HeaderMD, XHttpCode)
delete(w.Header(), "Grpc-Metadata-X-Http-Code")
w.WriteHeader(code)
}
return nil
}
func requestAcceptsTrailers(req *http.Request) bool {
te := req.Header.Get("TE")
return strings.Contains(strings.ToLower(te), "trailers")
}
func handleForwardResponseTrailerHeader(w http.ResponseWriter, md runtime.ServerMetadata) {
for k := range md.TrailerMD {
tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", runtime.MetadataTrailerPrefix, k))
w.Header().Add("Trailer", tKey)
}
}
func handleForwardResponseTrailer(w http.ResponseWriter, md runtime.ServerMetadata) {
for k, vs := range md.TrailerMD {
tKey := fmt.Sprintf("%s%s", runtime.MetadataTrailerPrefix, k)
for _, v := range vs {
w.Header().Add(tKey, v)
}
}
}
// CustomHTTPErrorHandler handles custom error objects in the context of HTTP requests.
// It is similar to [runtime.DefaultHTTPErrorHandler] but accepts an [*errors.EncodedError] object.
func CustomHTTPErrorHandler(ctx context.Context, w http.ResponseWriter, r *http.Request, err *errors.EncodedError) {
// Convert to an error object.
pb := err.ActualError
w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")
w.Header().Set("Content-Type", "application/json")
buf := bytes.NewBuffer([]byte{})
jsonEncoder := json.NewEncoder(buf)
jsonEncoder.SetEscapeHTML(false)
if err := jsonEncoder.Encode(pb); err != nil {
grpclog.Errorf("failed to json encode the protobuf error '%v'", pb)
}
md, ok := runtime.ServerMetadataFromContext(ctx)
if !ok {
grpclog.Infof("Failed to extract ServerMetadata from context")
}
for k, val := range md.HeaderMD {
for _, individualVal := range val {
if k != "content-type" {
w.Header().Set(k, individualVal)
}
}
}
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
// Unless the request includes a TE header field indicating "trailers"
// is acceptable, as described in Section 4.3, a server SHOULD NOT
// generate trailer fields that it believes are necessary for the user
// agent to receive.
doForwardTrailers := requestAcceptsTrailers(r)
if doForwardTrailers {
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Transfer-Encoding", "chunked")
}
st := err.HTTPStatusCode
w.WriteHeader(st)
if _, err := w.Write(buf.Bytes()); err != nil { // nosemgrep: no-direct-write-to-responsewriter
grpclog.Infof("Failed to write response: %v", err)
}
if doForwardTrailers {
handleForwardResponseTrailer(w, md)
}
}
package logging
import (
"context"
"encoding/json"
"errors"
"strconv"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
)
const (
grpcServiceKey = "grpc_service"
grpcMethodKey = "grpc_method"
grpcTypeKey = "grpc_type"
grpcCodeKey = "grpc_code"
requestIDKey = "request_id"
traceIDKey = "trace_id"
rawRequestKey = "raw_request"
rawResponseKey = "raw_response"
internalErrorKey = "internal_error"
grpcReqCompleteKey = "grpc_req_complete"
userAgentKey = "user_agent"
queryDurationKey = "query_duration_ms"
gatewayUserAgentHeader string = "grpcgateway-user-agent"
userAgentHeader string = "user-agent"
healthCheckService string = "grpc.health.v1.Health"
)
// NewLoggingInterceptor creates a new logging interceptor for gRPC unary server requests.
func NewLoggingInterceptor(logger logger.Logger) grpc.UnaryServerInterceptor {
return interceptors.UnaryServerInterceptor(reportable(logger))
}
// NewStreamingLoggingInterceptor creates a new streaming logging interceptor for gRPC stream server requests.
func NewStreamingLoggingInterceptor(logger logger.Logger) grpc.StreamServerInterceptor {
return interceptors.StreamServerInterceptor(reportable(logger))
}
type reporter struct {
ctx context.Context
logger logger.Logger
fields []zap.Field
protomarshaler protojson.MarshalOptions
serviceName string
}
// PostCall is invoked after all PostMsgSend operations.
func (r *reporter) PostCall(err error, rpcDuration time.Duration) {
rpcDurationMs := strconv.FormatInt(rpcDuration.Milliseconds(), 10)
r.fields = append(r.fields, zap.String(queryDurationKey, rpcDurationMs))
r.fields = append(r.fields, ctxzap.TagsToFields(r.ctx)...)
code := serverErrors.ConvertToEncodedErrorCode(status.Convert(err))
r.fields = append(r.fields, zap.Int32(grpcCodeKey, code))
if err != nil {
var internalError serverErrors.InternalError
if errors.As(err, &internalError) {
r.fields = append(r.fields, zap.String(internalErrorKey, internalError.Unwrap().Error()))
r.logger.Error(err.Error(), r.fields...)
} else {
r.fields = append(r.fields, zap.Error(err))
r.logger.Info(grpcReqCompleteKey, r.fields...)
}
return
}
if r.serviceName == healthCheckService {
r.logger.Debug(grpcReqCompleteKey, r.fields...)
} else {
r.logger.Info(grpcReqCompleteKey, r.fields...)
}
}
// PostMsgSend is invoked once after a unary response or multiple times in
// streaming requests after each message has been sent.
func (r *reporter) PostMsgSend(msg interface{}, err error, _ time.Duration) {
if err != nil {
// This is the actual error that customers see.
intCode := serverErrors.ConvertToEncodedErrorCode(status.Convert(err))
encodedError := serverErrors.NewEncodedError(intCode, err.Error())
protomsg := encodedError.ActualError
if resp, err := json.Marshal(protomsg); err == nil {
r.fields = append(r.fields, zap.Any(rawResponseKey, json.RawMessage(resp)))
}
return
}
protomsg, ok := msg.(protoreflect.ProtoMessage)
if ok {
if resp, err := r.protomarshaler.Marshal(protomsg); err == nil {
r.fields = append(r.fields, zap.Any(rawResponseKey, json.RawMessage(resp)))
}
}
}
// PostMsgReceive is invoked after receiving a message in streaming requests.
func (r *reporter) PostMsgReceive(msg interface{}, _ error, _ time.Duration) {
protomsg, ok := msg.(protoreflect.ProtoMessage)
if ok {
if req, err := r.protomarshaler.Marshal(protomsg); err == nil {
r.fields = append(r.fields, zap.Any(rawRequestKey, json.RawMessage(req)))
}
}
}
// userAgentFromContext retrieves the user agent field from the provided context.
// If the user agent field is not present in the context, the function returns an empty string and false.
func userAgentFromContext(ctx context.Context) (string, bool) {
if headers, ok := metadata.FromIncomingContext(ctx); ok {
if header := headers.Get(gatewayUserAgentHeader); len(header) > 0 {
return header[0], true
}
if header := headers.Get(userAgentHeader); len(header) > 0 {
return header[0], true
}
}
return "", false
}
func reportable(l logger.Logger) interceptors.CommonReportableFunc {
return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) {
fields := []zap.Field{
zap.String(grpcServiceKey, c.Service),
zap.String(grpcMethodKey, c.Method),
zap.String(grpcTypeKey, string(c.Typ)),
}
spanCtx := trace.SpanContextFromContext(ctx)
if spanCtx.HasTraceID() {
fields = append(fields, zap.String(traceIDKey, spanCtx.TraceID().String()))
}
if userAgent, ok := userAgentFromContext(ctx); ok {
fields = append(fields, zap.String(userAgentKey, userAgent))
}
return &reporter{
ctx: ctx,
logger: l,
fields: fields,
protomarshaler: protojson.MarshalOptions{EmitUnpopulated: true},
serviceName: c.Service,
}, ctx
}
}
package recovery
import (
"context"
"encoding/json"
"fmt"
"net/http"
"runtime/debug"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/errors"
)
// HTTPPanicRecoveryHandler recover from panic for http services.
func HTTPPanicRecoveryHandler(next http.Handler, logger logger.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
logger.Error("HTTPPanicRecoveryHandler has recovered a panic",
zap.Error(fmt.Errorf("%v", err)),
zap.ByteString("stacktrace", debug.Stack()),
)
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
responseBody, err := json.Marshal(map[string]string{
"code": openfgav1.InternalErrorCode_internal_error.String(),
"message": errors.InternalServerErrorMsg,
})
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
_, err = w.Write(responseBody)
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
}()
next.ServeHTTP(w, r)
})
}
// PanicRecoveryHandler recovers from panics for unary/stream services.
func PanicRecoveryHandler(logger logger.Logger) grpc_recovery.RecoveryHandlerFuncContext {
return func(ctx context.Context, p any) error {
logger.Error("PanicRecoveryHandler has recovered a panic",
zap.Error(fmt.Errorf("%v", p)),
zap.ByteString("stacktrace", debug.Stack()),
)
return status.Error(codes.Internal, errors.InternalServerErrorMsg)
}
}
package requestid
import (
"context"
"github.com/google/uuid"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
const (
requestIDKey = "request_id"
requestIDTraceKey = "request_id"
// RequestIDHeader defines the HTTP header that is set in each HTTP response
// for a given request. The value of the header is unique per request.
RequestIDHeader = "X-Request-Id"
)
// InitRequestID returns the ID to be used to identify the request.
// If tracing is enabled, returns trace ID, e.g. "1e20da43269fe07e3d2ac018c0aad2d1".
// Otherwise returns a new UUID, e.g. "38fee7ac-4bfe-4cf6-baa2-8b5ec296b485".
func InitRequestID(ctx context.Context) string {
spanCtx := trace.SpanContextFromContext(ctx)
if spanCtx.TraceID().IsValid() {
return spanCtx.TraceID().String()
}
id, _ := uuid.NewRandom()
return id.String()
}
// NewUnaryInterceptor creates a grpc.UnaryServerInterceptor which must
// come after the trace interceptor and before the logging interceptor.
func NewUnaryInterceptor() grpc.UnaryServerInterceptor {
return interceptors.UnaryServerInterceptor(reportable())
}
// NewStreamingInterceptor creates a grpc.StreamServerInterceptor which must
// come after the trace interceptor and before the logging interceptor.
func NewStreamingInterceptor() grpc.StreamServerInterceptor {
return interceptors.StreamServerInterceptor(reportable())
}
func reportable() interceptors.CommonReportableFunc {
return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) {
requestID := InitRequestID(ctx)
grpc_ctxtags.Extract(ctx).Set(requestIDKey, requestID) // CtxTags used by other middlewares
_ = grpc.SetHeader(ctx, metadata.Pairs(RequestIDHeader, requestID))
trace.SpanFromContext(ctx).SetAttributes(attribute.String(requestIDTraceKey, requestID))
return interceptors.NoopReporter{}, ctx
}
}
package storeid
import (
"context"
"time"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
type ctxKey string
const (
storeIDCtxKey ctxKey = "store-id-context-key"
storeIDKey string = "store_id"
// StoreIDHeader represents the HTTP header name used to
// specify the OpenFGA store identifier in API requests.
StoreIDHeader string = "Openfga-Store-Id"
)
type storeidHandle struct {
storeid string
}
// StoreIDFromContext retrieves the store ID stored in the provided context.
func StoreIDFromContext(ctx context.Context) (string, bool) {
if c := ctx.Value(storeIDCtxKey); c != nil {
handle := c.(*storeidHandle)
return handle.storeid, true
}
return "", false
}
func contextWithHandle(ctx context.Context) context.Context {
return context.WithValue(ctx, storeIDCtxKey, &storeidHandle{})
}
// SetStoreIDInContext sets the store ID in the provided context based on information from the request.
func SetStoreIDInContext(ctx context.Context, req interface{}) {
handle := ctx.Value(storeIDCtxKey)
if handle == nil {
return
}
if r, ok := req.(hasGetStoreID); ok {
handle.(*storeidHandle).storeid = r.GetStoreId()
}
}
type hasGetStoreID interface {
GetStoreId() string
}
// NewUnaryInterceptor creates a grpc.UnaryServerInterceptor which injects
// store_id metadata into the RPC context if an RPC message is received with
// a GetStoreId method.
func NewUnaryInterceptor() grpc.UnaryServerInterceptor {
return interceptors.UnaryServerInterceptor(reportable())
}
// NewStreamingInterceptor creates a grpc.StreamServerInterceptor which injects
// store_id metadata into the RPC context if an RPC message is received with a
// GetStoreId method.
func NewStreamingInterceptor() grpc.StreamServerInterceptor {
return interceptors.StreamServerInterceptor(reportable())
}
type reporter struct {
ctx context.Context
}
// PostCall is a placeholder for handling actions after a gRPC call.
func (r *reporter) PostCall(error, time.Duration) {}
// PostMsgSend is a placeholder for handling actions after sending a message in streaming requests.
func (r *reporter) PostMsgSend(interface{}, error, time.Duration) {}
// PostMsgReceive is invoked after receiving a message in streaming requests.
func (r *reporter) PostMsgReceive(msg interface{}, _ error, _ time.Duration) {
if m, ok := msg.(hasGetStoreID); ok {
storeID := m.GetStoreId()
SetStoreIDInContext(r.ctx, msg)
trace.SpanFromContext(r.ctx).SetAttributes(attribute.String(storeIDKey, storeID))
grpc_ctxtags.Extract(r.ctx).Set(storeIDKey, storeID)
_ = grpc.SetHeader(r.ctx, metadata.Pairs(StoreIDHeader, storeID))
}
}
func reportable() interceptors.CommonReportableFunc {
return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) {
ctx = contextWithHandle(ctx)
r := reporter{ctx}
return &r, r.ctx
}
}
package middleware
import (
"context"
"time"
grpcvalidator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
"google.golang.org/grpc"
"github.com/openfga/openfga/pkg/logger"
)
// TimeoutInterceptor sets the timeout in each request.
type TimeoutInterceptor struct {
timeout time.Duration
logger logger.Logger
}
// NewTimeoutInterceptor returns a new TimeoutInterceptor that times out a request
// if it exceeds the configured timeout value.
func NewTimeoutInterceptor(timeout time.Duration, logger logger.Logger) *TimeoutInterceptor {
return &TimeoutInterceptor{
timeout: timeout,
logger: logger,
}
}
// NewUnaryTimeoutInterceptor returns an interceptor that will timeout according to the configured timeout.
// We need to use this middleware instead of relying on runtime.DefaultContextTimeout so that we
// can return a proper error code.
func (h *TimeoutInterceptor) NewUnaryTimeoutInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ctx, cancel := context.WithTimeout(ctx, h.timeout)
defer cancel()
return handler(ctx, req)
}
}
// NewStreamTimeoutInterceptor returns an interceptor that will timeout according to the configured timeout.
// We need to use this middleware instead of relying on runtime.DefaultContextTimeout so that we
// can return a proper error code.
func (h *TimeoutInterceptor) NewStreamTimeoutInterceptor() grpc.StreamServerInterceptor {
validator := grpcvalidator.StreamServerInterceptor()
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return validator(srv, stream, info, func(srv interface{}, ss grpc.ServerStream) error {
ctx, cancel := context.WithTimeout(stream.Context(), h.timeout)
defer cancel()
return handler(srv, &recvWrapper{
ctx: ctx,
ServerStream: ss,
})
})
}
}
type recvWrapper struct {
ctx context.Context
grpc.ServerStream
}
// Context returns the context associated with the recvWrapper.
func (r *recvWrapper) Context() context.Context {
return r.ctx
}
package validator
import (
"context"
grpcvalidator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
"google.golang.org/grpc"
)
type ctxKey string
var (
requestIsValidatedCtxKey = ctxKey("request-validated")
)
func contextWithRequestIsValidated(ctx context.Context) context.Context {
return context.WithValue(ctx, requestIsValidatedCtxKey, true)
}
// RequestIsValidatedFromContext returns true if the provided context object has the flag
// indicating that the request has been validated and if its value is set to true.
func RequestIsValidatedFromContext(ctx context.Context) bool {
validated, ok := ctx.Value(requestIsValidatedCtxKey).(bool)
return validated && ok
}
// UnaryServerInterceptor returns a new unary server interceptor that runs request validations
// and injects a bool in the context indicating that validation has been run.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
validator := grpcvalidator.UnaryServerInterceptor()
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return validator(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return handler(contextWithRequestIsValidated(ctx), req)
})
}
}
// StreamServerInterceptor returns a new streaming server interceptor that runs request validations
// and injects a bool in the context indicating that validation has been run.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
validator := grpcvalidator.StreamServerInterceptor()
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
return validator(srv, stream, info, func(srv interface{}, ss grpc.ServerStream) error {
return handler(srv, &recvWrapper{
ctx: contextWithRequestIsValidated(stream.Context()),
ServerStream: ss,
})
})
}
}
type recvWrapper struct {
ctx context.Context
grpc.ServerStream
}
// Context returns the context associated with the recvWrapper.
func (r *recvWrapper) Context() context.Context {
return r.ctx
}
package server
import (
"context"
"net/http"
"strconv"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
httpmiddleware "github.com/openfga/openfga/pkg/middleware/http"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) WriteAssertions(ctx context.Context, req *openfgav1.WriteAssertionsRequest) (*openfgav1.WriteAssertionsResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.WriteAssertions.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.WriteAssertions.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.WriteAssertions)
if err != nil {
return nil, err
}
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
c := commands.NewWriteAssertionsCommand(s.datastore, commands.WithWriteAssertCmdLogger(s.logger))
res, err := c.Execute(ctx, &openfgav1.WriteAssertionsRequest{
StoreId: storeID,
AuthorizationModelId: typesys.GetAuthorizationModelID(), // the resolved model id
Assertions: req.GetAssertions(),
})
if err != nil {
return nil, err
}
s.transport.SetHeader(ctx, httpmiddleware.XHttpCode, strconv.Itoa(http.StatusNoContent))
return res, nil
}
func (s *Server) ReadAssertions(ctx context.Context, req *openfgav1.ReadAssertionsRequest) (*openfgav1.ReadAssertionsResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.ReadAssertions.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.ReadAssertions.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ReadAssertions)
if err != nil {
return nil, err
}
typesys, err := s.resolveTypesystem(ctx, req.GetStoreId(), req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
q := commands.NewReadAssertionsQuery(s.datastore, commands.WithReadAssertionsQueryLogger(s.logger))
return q.Execute(ctx, req.GetStoreId(), typesys.GetAuthorizationModelID())
}
package server
import (
"context"
"net/http"
"strconv"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
httpmiddleware "github.com/openfga/openfga/pkg/middleware/http"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) ReadAuthorizationModel(ctx context.Context, req *openfgav1.ReadAuthorizationModelRequest) (*openfgav1.ReadAuthorizationModelResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.ReadAuthorizationModel.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.KeyValue{Key: authorizationModelIDKey, Value: attribute.StringValue(req.GetId())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.ReadAuthorizationModel.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ReadAuthorizationModel)
if err != nil {
return nil, err
}
q := commands.NewReadAuthorizationModelQuery(s.datastore, commands.WithReadAuthModelQueryLogger(s.logger))
return q.Execute(ctx, req)
}
func (s *Server) WriteAuthorizationModel(ctx context.Context, req *openfgav1.WriteAuthorizationModelRequest) (*openfgav1.WriteAuthorizationModelResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.WriteAuthorizationModel.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.WriteAuthorizationModel.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.WriteAuthorizationModel)
if err != nil {
return nil, err
}
c := commands.NewWriteAuthorizationModelCommand(s.datastore,
commands.WithWriteAuthModelLogger(s.logger),
commands.WithWriteAuthModelMaxSizeInBytes(s.maxAuthorizationModelSizeInBytes),
)
res, err := c.Execute(ctx, req)
if err != nil {
return nil, err
}
s.transport.SetHeader(ctx, httpmiddleware.XHttpCode, strconv.Itoa(http.StatusCreated))
return res, nil
}
func (s *Server) ReadAuthorizationModels(ctx context.Context, req *openfgav1.ReadAuthorizationModelsRequest) (*openfgav1.ReadAuthorizationModelsResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.ReadAuthorizationModels.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.ReadAuthorizationModels.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ReadAuthorizationModels)
if err != nil {
return nil, err
}
c := commands.NewReadAuthorizationModelsQuery(s.datastore,
commands.WithReadAuthModelsQueryLogger(s.logger),
commands.WithReadAuthModelsQueryEncoder(s.encoder),
)
return c.Execute(ctx, req)
}
package server
import (
"context"
"errors"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) BatchCheck(ctx context.Context, req *openfgav1.BatchCheckRequest) (*openfgav1.BatchCheckResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.BatchCheck.String(), trace.WithAttributes(
attribute.KeyValue{Key: "store_id", Value: attribute.StringValue(req.GetStoreId())},
attribute.KeyValue{Key: "batch_size", Value: attribute.IntValue(len(req.GetChecks()))},
attribute.KeyValue{Key: "consistency", Value: attribute.StringValue(req.GetConsistency().String())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.BatchCheck.String(),
})
storeID := req.GetStoreId()
err := s.checkAuthz(ctx, storeID, apimethod.BatchCheck)
if err != nil {
return nil, err
}
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
builder := s.getCheckResolverBuilder(req.GetStoreId())
checkResolver, checkResolverCloser, err := builder.Build()
if err != nil {
return nil, err
}
defer checkResolverCloser()
cmd := commands.NewBatchCheckCommand(
s.datastore,
checkResolver,
typesys,
commands.WithBatchCheckCacheOptions(s.sharedDatastoreResources, s.cacheSettings),
commands.WithBatchCheckCommandLogger(s.logger),
commands.WithBatchCheckMaxChecksPerBatch(s.maxChecksPerBatchCheck),
commands.WithBatchCheckMaxConcurrentChecks(s.maxConcurrentChecksPerBatch),
commands.WithBatchCheckDatastoreThrottler(
s.featureFlagClient.Boolean(config.ExperimentalDatastoreThrottling, storeID),
s.checkDatastoreThrottleThreshold,
s.checkDatastoreThrottleDuration,
),
)
result, metadata, err := cmd.Execute(ctx, &commands.BatchCheckCommandParams{
AuthorizationModelID: typesys.GetAuthorizationModelID(),
Checks: req.GetChecks(),
Consistency: req.GetConsistency(),
StoreID: storeID,
})
if err != nil {
telemetry.TraceError(span, err)
var batchValidationError *commands.BatchCheckValidationError
if errors.As(err, &batchValidationError) {
return nil, serverErrors.ValidationError(err)
}
return nil, err
}
methodName := "batchcheck"
dispatchCount := float64(metadata.DispatchCount)
grpc_ctxtags.Extract(ctx).Set(dispatchCountHistogramName, dispatchCount)
span.SetAttributes(attribute.Float64(dispatchCountHistogramName, dispatchCount))
dispatchCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(dispatchCount)
var throttled bool
if metadata.ThrottleCount > 0 {
throttled = true
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Add(float64(metadata.ThrottleCount))
}
grpc_ctxtags.Extract(ctx).Set("request.throttled", throttled)
queryCount := float64(metadata.DatastoreQueryCount)
span.SetAttributes(attribute.Float64(datastoreQueryCountHistogramName, queryCount))
datastoreQueryCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(queryCount)
datastoreItemCount := float64(metadata.DatastoreItemCount)
span.SetAttributes(attribute.Float64(datastoreItemCountHistogramName, datastoreItemCount))
datastoreItemCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreItemCount)
duplicateChecks := "duplicate_checks"
span.SetAttributes(attribute.Int(duplicateChecks, metadata.DuplicateCheckCount))
grpc_ctxtags.Extract(ctx).Set(duplicateChecks, metadata.DuplicateCheckCount)
var batchResult = map[string]*openfgav1.BatchCheckSingleResult{}
for correlationID, outcome := range result {
batchResult[string(correlationID)] = transformCheckResultToProto(outcome)
s.emitCheckDurationMetric(outcome.CheckResponse.GetResolutionMetadata(), methodName)
}
grpc_ctxtags.Extract(ctx).Set(datastoreQueryCountHistogramName, metadata.DatastoreQueryCount)
grpc_ctxtags.Extract(ctx).Set(datastoreItemCountHistogramName, metadata.DatastoreItemCount)
return &openfgav1.BatchCheckResponse{Result: batchResult}, nil
}
// transformCheckResultToProto transforms the internal BatchCheckOutcome into the external-facing
// BatchCheckSingleResult struct for transmission back via the API.
func transformCheckResultToProto(outcome *commands.BatchCheckOutcome) *openfgav1.BatchCheckSingleResult {
singleResult := &openfgav1.BatchCheckSingleResult{}
if outcome.Err != nil {
singleResult.CheckResult = &openfgav1.BatchCheckSingleResult_Error{
Error: transformCheckCommandErrorToBatchCheckError(outcome.Err),
}
} else {
singleResult.CheckResult = &openfgav1.BatchCheckSingleResult_Allowed{
Allowed: outcome.CheckResponse.Allowed,
}
}
return singleResult
}
func transformCheckCommandErrorToBatchCheckError(cmdErr error) *openfgav1.CheckError {
var invalidRelationError *commands.InvalidRelationError
var invalidTupleError *commands.InvalidTupleError
var throttledError *commands.ThrottledError
err := &openfgav1.CheckError{Message: cmdErr.Error()}
// Switch to map the possible errors to their specific gRPC codes in the proto definition.
switch {
case errors.As(cmdErr, &invalidRelationError):
err.Code = &openfgav1.CheckError_InputError{InputError: openfgav1.ErrorCode_validation_error}
case errors.As(cmdErr, &invalidTupleError):
err.Code = &openfgav1.CheckError_InputError{InputError: openfgav1.ErrorCode_invalid_tuple}
case errors.Is(cmdErr, graph.ErrResolutionDepthExceeded):
err.Code = &openfgav1.CheckError_InputError{InputError: openfgav1.ErrorCode_authorization_model_resolution_too_complex}
case errors.Is(cmdErr, condition.ErrEvaluationFailed):
err.Code = &openfgav1.CheckError_InputError{InputError: openfgav1.ErrorCode_validation_error}
case errors.As(cmdErr, &throttledError):
err.Code = &openfgav1.CheckError_InputError{InputError: openfgav1.ErrorCode_validation_error}
case errors.Is(cmdErr, context.DeadlineExceeded):
err.Code = &openfgav1.CheckError_InternalError{InternalError: openfgav1.InternalErrorCode_deadline_exceeded}
default:
err.Code = &openfgav1.CheckError_InternalError{InternalError: openfgav1.InternalErrorCode_internal_error}
}
return err
}
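// Illustrative sketch (not part of the original source): how the mapping above
// plays out for two representative errors.
//
//	e := transformCheckCommandErrorToBatchCheckError(context.DeadlineExceeded)
//	// e.GetInternalError() == openfgav1.InternalErrorCode_deadline_exceeded
//
//	e = transformCheckCommandErrorToBatchCheckError(&commands.InvalidTupleError{Cause: errors.New("bad tuple")})
//	// e.GetInputError() == openfgav1.ErrorCode_invalid_tuple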
package server
import (
"context"
"errors"
"strconv"
"time"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) Check(ctx context.Context, req *openfgav1.CheckRequest) (*openfgav1.CheckResponse, error) {
const methodName = "check"
builder := s.getCheckResolverBuilder(req.GetStoreId())
checkResolver, checkResolverCloser, err := builder.Build()
if err != nil {
return nil, err
}
defer checkResolverCloser()
startTime := time.Now()
tk := req.GetTupleKey()
ctx, span := tracer.Start(ctx, apimethod.Check.String(), trace.WithAttributes(
attribute.KeyValue{Key: "store_id", Value: attribute.StringValue(req.GetStoreId())},
attribute.KeyValue{Key: "object", Value: attribute.StringValue(tk.GetObject())},
attribute.KeyValue{Key: "relation", Value: attribute.StringValue(tk.GetRelation())},
attribute.KeyValue{Key: "user", Value: attribute.StringValue(tk.GetUser())},
attribute.KeyValue{Key: "consistency", Value: attribute.StringValue(req.GetConsistency().String())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.Check.String(),
})
err = s.checkAuthz(ctx, req.GetStoreId(), apimethod.Check)
if err != nil {
return nil, err
}
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
checkQuery := commands.NewCheckCommand(
s.datastore,
checkResolver,
typesys,
commands.WithCheckCommandLogger(s.logger),
commands.WithCheckCommandMaxConcurrentReads(s.maxConcurrentReadsForCheck),
commands.WithCheckCommandCache(s.sharedDatastoreResources, s.cacheSettings),
commands.WithCheckDatastoreThrottler(
s.featureFlagClient.Boolean(serverconfig.ExperimentalDatastoreThrottling, storeID),
s.checkDatastoreThrottleThreshold,
s.checkDatastoreThrottleDuration,
),
)
resp, checkRequestMetadata, err := checkQuery.Execute(ctx, &commands.CheckCommandParams{
StoreID: storeID,
TupleKey: req.GetTupleKey(),
ContextualTuples: req.GetContextualTuples(),
Context: req.GetContext(),
Consistency: req.GetConsistency(),
})
requestDurationMs := time.Since(startTime).Milliseconds()
var (
wasRequestThrottled bool
rawDispatchCount uint32
)
if checkRequestMetadata != nil {
wasRequestThrottled = checkRequestMetadata.WasThrottled.Load()
rawDispatchCount = checkRequestMetadata.DispatchCounter.Load()
dispatchCount := float64(rawDispatchCount)
grpc_ctxtags.Extract(ctx).Set(dispatchCountHistogramName, dispatchCount)
span.SetAttributes(attribute.Float64(dispatchCountHistogramName, dispatchCount))
dispatchCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(dispatchCount)
}
if resp != nil {
queryCount := float64(resp.GetResolutionMetadata().DatastoreQueryCount)
grpc_ctxtags.Extract(ctx).Set(datastoreQueryCountHistogramName, queryCount)
span.SetAttributes(attribute.Float64(datastoreQueryCountHistogramName, queryCount))
datastoreQueryCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(queryCount)
datastoreItemCount := float64(resp.GetResolutionMetadata().DatastoreItemCount)
grpc_ctxtags.Extract(ctx).Set(datastoreItemCountHistogramName, datastoreItemCount)
span.SetAttributes(attribute.Float64(datastoreItemCountHistogramName, datastoreItemCount))
datastoreItemCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreItemCount)
requestDurationHistogram.WithLabelValues(
s.serviceName,
methodName,
utils.Bucketize(uint(queryCount), s.requestDurationByQueryHistogramBuckets),
utils.Bucketize(uint(rawDispatchCount), s.requestDurationByDispatchCountHistogramBuckets),
req.GetConsistency().String(),
).Observe(float64(requestDurationMs))
if s.authorizer.AccessControlStoreID() == req.GetStoreId() {
accessControlStoreCheckDurationHistogram.WithLabelValues(
utils.Bucketize(uint(queryCount), s.requestDurationByQueryHistogramBuckets),
utils.Bucketize(uint(rawDispatchCount), s.requestDurationByDispatchCountHistogramBuckets),
req.GetConsistency().String(),
).Observe(float64(requestDurationMs))
}
if wasRequestThrottled {
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Inc()
}
grpc_ctxtags.Extract(ctx).Set("request.throttled", wasRequestThrottled)
}
if err != nil {
telemetry.TraceError(span, err)
finalErr := commands.CheckCommandErrorToServerError(err)
if errors.Is(finalErr, serverErrors.ErrThrottledTimeout) {
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Inc()
}
// TODO: consider defining all metrics in one place that is accessible from everywhere
// (including LocalChecker) and adding a wrapper helper that automatically injects the service name tag.
return nil, finalErr
}
checkResultCounter.With(prometheus.Labels{allowedLabel: strconv.FormatBool(resp.GetAllowed())}).Inc()
span.SetAttributes(
attribute.Bool("cycle_detected", resp.GetCycleDetected()),
attribute.Bool("allowed", resp.GetAllowed()))
res := &openfgav1.CheckResponse{
Allowed: resp.Allowed,
}
return res, nil
}
func (s *Server) getCheckResolverBuilder(storeID string) *graph.CheckResolverOrderedBuilder {
checkCacheOptions, checkDispatchThrottlingOptions := s.getCheckResolverOptions()
return graph.NewOrderedCheckResolvers([]graph.CheckResolverOrderedBuilderOpt{
graph.WithLocalCheckerOpts([]graph.LocalCheckerOption{
graph.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
graph.WithOptimizations(s.featureFlagClient.Boolean(serverconfig.ExperimentalCheckOptimizations, storeID)),
graph.WithMaxResolutionDepth(s.resolveNodeLimit),
graph.WithPlanner(s.planner),
graph.WithUpstreamTimeout(s.requestTimeout),
graph.WithLocalCheckerLogger(s.logger),
}...),
graph.WithLocalShadowCheckerOpts([]graph.LocalCheckerOption{
graph.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
graph.WithOptimizations(true), // shadow checker always uses optimizations
graph.WithMaxResolutionDepth(s.resolveNodeLimit),
graph.WithPlanner(s.planner),
}...),
graph.WithShadowResolverEnabled(s.featureFlagClient.Boolean(serverconfig.ExperimentalShadowCheck, storeID)),
graph.WithShadowResolverOpts([]graph.ShadowResolverOpt{
graph.ShadowResolverWithLogger(s.logger),
graph.ShadowResolverWithTimeout(s.shadowCheckResolverTimeout),
}...),
graph.WithCachedCheckResolverOpts(s.cacheSettings.ShouldCacheCheckQueries(), checkCacheOptions...),
graph.WithDispatchThrottlingCheckResolverOpts(s.checkDispatchThrottlingEnabled, checkDispatchThrottlingOptions...),
}...)
}
package commands
import (
"context"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/cespare/xxhash/v2"
"go.uber.org/zap"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/cachecontroller"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/shared"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/typesystem"
)
type BatchCheckQuery struct {
sharedCheckResources *shared.SharedDatastoreResources
cacheSettings config.CacheSettings
checkResolver graph.CheckResolver
datastore storage.RelationshipTupleReader
logger logger.Logger
maxChecksAllowed uint32
maxConcurrentChecks uint32
typesys *typesystem.TypeSystem
datastoreThrottlingEnabled bool
datastoreThrottleThreshold int
datastoreThrottleDuration time.Duration
}
type BatchCheckCommandParams struct {
AuthorizationModelID string
Checks []*openfgav1.BatchCheckItem
Consistency openfgav1.ConsistencyPreference
StoreID string
}
type BatchCheckOutcome struct {
CheckResponse *graph.ResolveCheckResponse
Err error
}
type BatchCheckMetadata struct {
ThrottleCount uint32
DispatchCount uint32
DatastoreQueryCount uint32
DatastoreItemCount uint64
DuplicateCheckCount int
}
type BatchCheckValidationError struct {
Message string
}
func (e BatchCheckValidationError) Error() string {
return e.Message
}
type CorrelationID string
type CacheKey string
type checkAndCorrelationIDs struct {
Check *openfgav1.BatchCheckItem
CorrelationIDs []CorrelationID
}
type BatchCheckQueryOption func(*BatchCheckQuery)
func WithBatchCheckCacheOptions(sharedCheckResources *shared.SharedDatastoreResources, cacheSettings config.CacheSettings) BatchCheckQueryOption {
return func(c *BatchCheckQuery) {
c.sharedCheckResources = sharedCheckResources
c.cacheSettings = cacheSettings
}
}
func WithBatchCheckCommandLogger(l logger.Logger) BatchCheckQueryOption {
return func(bq *BatchCheckQuery) {
bq.logger = l
}
}
func WithBatchCheckMaxConcurrentChecks(maxConcurrentChecks uint32) BatchCheckQueryOption {
return func(bq *BatchCheckQuery) {
bq.maxConcurrentChecks = maxConcurrentChecks
}
}
func WithBatchCheckMaxChecksPerBatch(maxChecks uint32) BatchCheckQueryOption {
return func(bq *BatchCheckQuery) {
bq.maxChecksAllowed = maxChecks
}
}
func WithBatchCheckDatastoreThrottler(enabled bool, threshold int, duration time.Duration) BatchCheckQueryOption {
return func(bq *BatchCheckQuery) {
bq.datastoreThrottlingEnabled = enabled
bq.datastoreThrottleThreshold = threshold
bq.datastoreThrottleDuration = duration
}
}
func NewBatchCheckCommand(datastore storage.RelationshipTupleReader, checkResolver graph.CheckResolver, typesys *typesystem.TypeSystem, opts ...BatchCheckQueryOption) *BatchCheckQuery {
cmd := &BatchCheckQuery{
logger: logger.NewNoopLogger(),
datastore: datastore,
checkResolver: checkResolver,
typesys: typesys,
maxChecksAllowed: config.DefaultMaxChecksPerBatchCheck,
maxConcurrentChecks: config.DefaultMaxConcurrentChecksPerBatchCheck,
cacheSettings: config.NewDefaultCacheSettings(),
sharedCheckResources: &shared.SharedDatastoreResources{
CacheController: cachecontroller.NewNoopCacheController(),
},
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
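// Illustrative usage sketch (not part of the original source; assumes the caller
// has already prepared ds, checkResolver, and typesys, as the server's BatchCheck
// handler does):
//
//	cmd := NewBatchCheckCommand(ds, checkResolver, typesys,
//		WithBatchCheckCommandLogger(l),
//		WithBatchCheckMaxChecksPerBatch(50),
//	)
//	results, metadata, err := cmd.Execute(ctx, &BatchCheckCommandParams{
//		AuthorizationModelID: typesys.GetAuthorizationModelID(),
//		Checks:               checks,
//		StoreID:              storeID,
//	})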
func (bq *BatchCheckQuery) Execute(ctx context.Context, params *BatchCheckCommandParams) (map[CorrelationID]*BatchCheckOutcome, *BatchCheckMetadata, error) {
if len(params.Checks) > int(bq.maxChecksAllowed) {
return nil, nil, &BatchCheckValidationError{
Message: "batchCheck received " + strconv.Itoa(len(params.Checks)) + " checks, the maximum allowed is " + strconv.Itoa(int(bq.maxChecksAllowed)),
}
}
if len(params.Checks) == 0 {
return nil, nil, &BatchCheckValidationError{
Message: "batch check requires at least one check to evaluate, no checks were received",
}
}
if err := validateCorrelationIDs(params.Checks); err != nil {
return nil, nil, err
}
// Before processing the batch, deduplicate the checks based on their unique cache key
// After all routines have finished, we will map each individual check response to all associated CorrelationIDs
cacheKeyMap := make(map[CacheKey]*checkAndCorrelationIDs)
for _, check := range params.Checks {
key, err := generateCacheKeyFromCheck(check, params.StoreID, bq.typesys.GetAuthorizationModelID())
if err != nil {
bq.logger.Error("batch check cache key computation failed with error", zap.Error(err))
return nil, nil, err
}
if item, ok := cacheKeyMap[key]; ok {
item.CorrelationIDs = append(item.CorrelationIDs, CorrelationID(check.GetCorrelationId()))
} else {
cacheKeyMap[key] = &checkAndCorrelationIDs{
Check: check,
CorrelationIDs: []CorrelationID{CorrelationID(check.GetCorrelationId())},
}
}
}
var resultMap = new(sync.Map)
var totalQueryCount atomic.Uint32
var totalDispatchCount atomic.Uint32
var totalThrottleCount atomic.Uint32
var totalItemCount atomic.Uint64
pool := concurrency.NewPool(ctx, int(bq.maxConcurrentChecks))
for key, item := range cacheKeyMap {
check := item.Check
pool.Go(func(ctx context.Context) error {
select {
case <-ctx.Done():
resultMap.Store(key, &BatchCheckOutcome{
Err: ctx.Err(),
})
return nil
default:
}
checkQuery := NewCheckCommand(
bq.datastore,
bq.checkResolver,
bq.typesys,
WithCheckCommandLogger(bq.logger),
WithCheckCommandCache(bq.sharedCheckResources, bq.cacheSettings),
WithCheckDatastoreThrottler(
bq.datastoreThrottlingEnabled,
bq.datastoreThrottleThreshold,
bq.datastoreThrottleDuration,
),
)
checkParams := &CheckCommandParams{
StoreID: params.StoreID,
TupleKey: check.GetTupleKey(),
ContextualTuples: check.GetContextualTuples(),
Context: check.GetContext(),
Consistency: params.Consistency,
}
response, metadata, err := checkQuery.Execute(ctx, checkParams)
resultMap.Store(key, &BatchCheckOutcome{
CheckResponse: response,
Err: err,
})
if metadata != nil {
if metadata.WasThrottled.Load() {
totalThrottleCount.Add(1)
}
totalDispatchCount.Add(metadata.DispatchCounter.Load())
}
totalQueryCount.Add(response.GetResolutionMetadata().DatastoreQueryCount)
totalItemCount.Add(response.GetResolutionMetadata().DatastoreItemCount)
return nil
})
}
_ = pool.Wait()
results := map[CorrelationID]*BatchCheckOutcome{}
// Each cacheKey can have > 1 associated CorrelationID
for cacheKey, checkItem := range cacheKeyMap {
res, _ := resultMap.Load(cacheKey)
outcome := res.(*BatchCheckOutcome)
for _, id := range checkItem.CorrelationIDs {
// map all associated CorrelationIDs to this outcome
results[id] = outcome
}
}
return results, &BatchCheckMetadata{
ThrottleCount: totalThrottleCount.Load(),
DatastoreQueryCount: totalQueryCount.Load(),
DatastoreItemCount: totalItemCount.Load(),
DispatchCount: totalDispatchCount.Load(),
DuplicateCheckCount: len(params.Checks) - len(cacheKeyMap),
}, nil
}
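// Worked example of the deduplication above (illustrative, not part of the
// original source). Two items sharing a tuple key (here tk, a
// *openfgav1.CheckRequestTupleKey) but carrying distinct correlation IDs hash to
// the same CacheKey, so only one check is executed:
//
//	checks := []*openfgav1.BatchCheckItem{
//		{TupleKey: tk, CorrelationId: "a"},
//		{TupleKey: tk, CorrelationId: "b"},
//	}
//	results, meta, _ := cmd.Execute(ctx, &BatchCheckCommandParams{StoreID: storeID, Checks: checks})
//	// results[CorrelationID("a")] == results[CorrelationID("b")] (same *BatchCheckOutcome)
//	// meta.DuplicateCheckCount == 1 (len(params.Checks) - len(cacheKeyMap))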
func validateCorrelationIDs(checks []*openfgav1.BatchCheckItem) error {
seen := map[string]struct{}{}
for _, check := range checks {
if check.GetCorrelationId() == "" {
return &BatchCheckValidationError{
Message: "received empty correlation id for tuple: " + check.GetTupleKey().String(),
}
}
_, ok := seen[check.GetCorrelationId()]
if ok {
return &BatchCheckValidationError{
Message: "received duplicate correlation id: " + check.GetCorrelationId(),
}
}
seen[check.GetCorrelationId()] = struct{}{}
}
return nil
}
func generateCacheKeyFromCheck(check *openfgav1.BatchCheckItem, storeID string, authModelID string) (CacheKey, error) {
tupleKey := check.GetTupleKey()
cacheKeyParams := &storage.CheckCacheKeyParams{
StoreID: storeID,
AuthorizationModelID: authModelID,
TupleKey: &openfgav1.TupleKey{
User: tupleKey.GetUser(),
Relation: tupleKey.GetRelation(),
Object: tupleKey.GetObject(),
},
ContextualTuples: check.GetContextualTuples().GetTupleKeys(),
Context: check.GetContext(),
}
hasher := xxhash.New()
err := storage.WriteCheckCacheKey(hasher, cacheKeyParams)
if err != nil {
return "", err
}
keyStr := strconv.FormatUint(hasher.Sum64(), 10)
return CacheKey(keyStr), nil
}
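// Note: the cache key is derived from the store ID, the authorization model ID,
// the check's tuple key, its contextual tuples, and its context. The correlation
// ID deliberately does not participate, which is what allows Execute to collapse
// checks that differ only by correlation ID.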
package commands
import (
"context"
"errors"
"math"
"time"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/cachecontroller"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/shared"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const (
defaultMaxConcurrentReadsForCheck = math.MaxUint32
)
type CheckQuery struct {
logger logger.Logger
checkResolver graph.CheckResolver
typesys *typesystem.TypeSystem
datastore storage.RelationshipTupleReader
sharedCheckResources *shared.SharedDatastoreResources
cacheSettings config.CacheSettings
maxConcurrentReads uint32
shouldCacheIterators bool
datastoreThrottlingEnabled bool
datastoreThrottleThreshold int
datastoreThrottleDuration time.Duration
}
type CheckCommandParams struct {
StoreID string
TupleKey *openfgav1.CheckRequestTupleKey
ContextualTuples *openfgav1.ContextualTupleKeys
Context *structpb.Struct
Consistency openfgav1.ConsistencyPreference
}
type CheckQueryOption func(*CheckQuery)
func WithCheckCommandMaxConcurrentReads(m uint32) CheckQueryOption {
return func(c *CheckQuery) {
c.maxConcurrentReads = m
}
}
func WithCheckCommandLogger(l logger.Logger) CheckQueryOption {
return func(c *CheckQuery) {
c.logger = l
}
}
func WithCheckCommandCache(sharedCheckResources *shared.SharedDatastoreResources, cacheSettings config.CacheSettings) CheckQueryOption {
return func(c *CheckQuery) {
c.sharedCheckResources = sharedCheckResources
c.cacheSettings = cacheSettings
}
}
func WithCheckDatastoreThrottler(enabled bool, threshold int, duration time.Duration) CheckQueryOption {
return func(c *CheckQuery) {
c.datastoreThrottlingEnabled = enabled
c.datastoreThrottleDuration = duration
c.datastoreThrottleThreshold = threshold
}
}
// TODO accept CheckCommandParams so we can build the datastore object right away.
func NewCheckCommand(datastore storage.RelationshipTupleReader, checkResolver graph.CheckResolver, typesys *typesystem.TypeSystem, opts ...CheckQueryOption) *CheckQuery {
cmd := &CheckQuery{
logger: logger.NewNoopLogger(),
datastore: datastore,
checkResolver: checkResolver,
typesys: typesys,
maxConcurrentReads: defaultMaxConcurrentReadsForCheck,
shouldCacheIterators: false,
cacheSettings: config.NewDefaultCacheSettings(),
sharedCheckResources: &shared.SharedDatastoreResources{
CacheController: cachecontroller.NewNoopCacheController(),
},
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
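// Illustrative usage sketch (not part of the original source; assumes ds,
// checkResolver, and typesys are prepared by the caller, as in the server's
// Check handler):
//
//	checkQuery := NewCheckCommand(ds, checkResolver, typesys,
//		WithCheckCommandLogger(l),
//		WithCheckCommandMaxConcurrentReads(100),
//	)
//	resp, reqMetadata, err := checkQuery.Execute(ctx, &CheckCommandParams{
//		StoreID:  storeID,
//		TupleKey: tupleKey, // *openfgav1.CheckRequestTupleKey
//	})
//
// When ResolveCheck itself fails (e.g. a timeout), Execute still returns non-nil
// resp and reqMetadata so callers can record partial resolution metadata; early
// validation failures return nil for both.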
func (c *CheckQuery) Execute(ctx context.Context, params *CheckCommandParams) (*graph.ResolveCheckResponse, *graph.ResolveCheckRequestMetadata, error) {
err := validateCheckRequest(c.typesys, params.TupleKey, params.ContextualTuples)
if err != nil {
return nil, nil, err
}
cacheInvalidationTime := time.Time{}
if params.Consistency != openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
cacheInvalidationTime = c.sharedCheckResources.CacheController.DetermineInvalidationTime(ctx, params.StoreID)
}
resolveCheckRequest, err := graph.NewResolveCheckRequest(
graph.ResolveCheckRequestParams{
StoreID: params.StoreID,
TupleKey: tuple.ConvertCheckRequestTupleKeyToTupleKey(params.TupleKey),
Context: params.Context,
ContextualTuples: params.ContextualTuples.GetTupleKeys(),
Consistency: params.Consistency,
LastCacheInvalidationTime: cacheInvalidationTime,
AuthorizationModelID: c.typesys.GetAuthorizationModelID(),
},
)
if err != nil {
return nil, nil, err
}
datastoreWithTupleCache := storagewrappers.NewRequestStorageWrapperWithCache(
c.datastore,
params.ContextualTuples.GetTupleKeys(),
&storagewrappers.Operation{
Method: apimethod.Check,
Concurrency: c.maxConcurrentReads,
ThrottleThreshold: c.datastoreThrottleThreshold,
ThrottleDuration: c.datastoreThrottleDuration,
},
storagewrappers.DataResourceConfiguration{
Resources: c.sharedCheckResources,
CacheSettings: c.cacheSettings,
UseShadowCache: false,
},
)
ctx = typesystem.ContextWithTypesystem(ctx, c.typesys)
ctx = storage.ContextWithRelationshipTupleReader(ctx, datastoreWithTupleCache)
startTime := time.Now()
resp, err := c.checkResolver.ResolveCheck(ctx, resolveCheckRequest)
duration := time.Since(startTime)
// ResolveCheck might fail halfway through (e.g. due to a timeout) and return a nil response.
// Partial resolution metadata is still useful for observability.
// From here on, we can assume that the request metadata and the response are not nil even if
// an error is present.
if resp == nil {
resp = &graph.ResolveCheckResponse{
Allowed: false,
ResolutionMetadata: graph.ResolveCheckResponseMetadata{},
}
}
resp.ResolutionMetadata.Duration = duration
dsMeta := datastoreWithTupleCache.GetMetadata()
resp.ResolutionMetadata.DatastoreQueryCount = dsMeta.DatastoreQueryCount
resp.ResolutionMetadata.DatastoreItemCount = dsMeta.DatastoreItemCount
// Until dispatch throttling is deprecated, merge the results of both
resolveCheckRequest.GetRequestMetadata().WasThrottled.CompareAndSwap(false, dsMeta.WasThrottled)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) && resolveCheckRequest.GetRequestMetadata().WasThrottled.Load() {
return resp, resolveCheckRequest.GetRequestMetadata(), &ThrottledError{Cause: err}
}
return resp, resolveCheckRequest.GetRequestMetadata(), err
}
return resp, resolveCheckRequest.GetRequestMetadata(), nil
}
func validateCheckRequest(typesys *typesystem.TypeSystem, tupleKey *openfgav1.CheckRequestTupleKey, contextualTuples *openfgav1.ContextualTupleKeys) error {
// The input tuple key should be validated loosely.
if err := validation.ValidateUserObjectRelation(typesys, tuple.ConvertCheckRequestTupleKeyToTupleKey(tupleKey)); err != nil {
return &InvalidRelationError{Cause: err}
}
// But contextual tuples need to be validated more strictly, the same as an input to a Write Tuple request.
for _, ctxTuple := range contextualTuples.GetTupleKeys() {
if err := validation.ValidateTupleForWrite(typesys, ctxTuple); err != nil {
return &InvalidTupleError{Cause: err}
}
}
return nil
}
package commands
import (
"context"
"github.com/oklog/ulid/v2"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type CreateStoreCommand struct {
storesBackend storage.StoresBackend
logger logger.Logger
}
type CreateStoreCmdOption func(*CreateStoreCommand)
func WithCreateStoreCmdLogger(l logger.Logger) CreateStoreCmdOption {
return func(c *CreateStoreCommand) {
c.logger = l
}
}
func NewCreateStoreCommand(
storesBackend storage.StoresBackend,
opts ...CreateStoreCmdOption,
) *CreateStoreCommand {
cmd := &CreateStoreCommand{
storesBackend: storesBackend,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
func (s *CreateStoreCommand) Execute(ctx context.Context, req *openfgav1.CreateStoreRequest) (*openfgav1.CreateStoreResponse, error) {
store, err := s.storesBackend.CreateStore(ctx, &openfgav1.Store{
Id: ulid.Make().String(),
Name: req.GetName(),
// TODO why not pass CreatedAt and UpdatedAt as derived from the ulid?
})
if err != nil {
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.CreateStoreResponse{
Id: store.GetId(),
Name: store.GetName(),
CreatedAt: store.GetCreatedAt(),
UpdatedAt: store.GetUpdatedAt(),
}, nil
}
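// Illustrative usage sketch (not part of the original source):
//
//	cmd := NewCreateStoreCommand(storesBackend, WithCreateStoreCmdLogger(l))
//	resp, err := cmd.Execute(ctx, &openfgav1.CreateStoreRequest{Name: "my-store"})
//	// resp.GetId() is a freshly minted ULID identifying the new store.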
package commands
import (
"context"
"errors"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type DeleteStoreCommand struct {
storesBackend storage.StoresBackend
logger logger.Logger
}
type DeleteStoreCmdOption func(*DeleteStoreCommand)
func WithDeleteStoreCmdLogger(l logger.Logger) DeleteStoreCmdOption {
return func(c *DeleteStoreCommand) {
c.logger = l
}
}
func NewDeleteStoreCommand(
storesBackend storage.StoresBackend,
opts ...DeleteStoreCmdOption,
) *DeleteStoreCommand {
cmd := &DeleteStoreCommand{
storesBackend: storesBackend,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
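// Execute deletes the store identified in the request. Deletion is idempotent:
// if the store does not exist (or was already deleted), Execute returns an empty
// response rather than an error.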
func (s *DeleteStoreCommand) Execute(ctx context.Context, req *openfgav1.DeleteStoreRequest) (*openfgav1.DeleteStoreResponse, error) {
store, err := s.storesBackend.GetStore(ctx, req.GetStoreId())
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return &openfgav1.DeleteStoreResponse{}, nil
}
return nil, serverErrors.HandleError("", err)
}
if err := s.storesBackend.DeleteStore(ctx, store.GetId()); err != nil {
return nil, serverErrors.HandleError("Error deleting store", err)
}
return &openfgav1.DeleteStoreResponse{}, nil
}
package commands
import (
"context"
"errors"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/graph"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/tuple"
)
type InvalidTupleError struct {
Cause error
}
func (e *InvalidTupleError) Unwrap() error {
return e.Cause
}
func (e *InvalidTupleError) Error() string {
return e.Unwrap().Error()
}
type InvalidRelationError struct {
Cause error
}
func (e *InvalidRelationError) Unwrap() error {
return e.Cause
}
func (e *InvalidRelationError) Error() string {
return e.Unwrap().Error()
}
type ThrottledError struct {
Cause error
}
func (e *ThrottledError) Unwrap() error {
return e.Cause
}
func (e *ThrottledError) Error() string {
return e.Unwrap().Error()
}
// CheckCommandErrorToServerError converts internal errors thrown during the
// check_command into consumer-facing errors to be sent over the wire.
func CheckCommandErrorToServerError(err error) error {
var invalidRelationError *InvalidRelationError
if errors.As(err, &invalidRelationError) {
return serverErrors.ValidationError(err)
}
var invalidTupleError *InvalidTupleError
if errors.As(err, &invalidTupleError) {
tupleError := tuple.InvalidTupleError{Cause: err}
return serverErrors.HandleTupleValidateError(&tupleError)
}
if errors.Is(err, graph.ErrResolutionDepthExceeded) {
return serverErrors.ErrAuthorizationModelResolutionTooComplex
}
if errors.Is(err, condition.ErrEvaluationFailed) {
return serverErrors.ValidationError(err)
}
var throttledError *ThrottledError
if errors.As(err, &throttledError) {
return serverErrors.ErrThrottledTimeout
}
if errors.Is(err, context.DeadlineExceeded) {
return serverErrors.ErrRequestDeadlineExceeded
}
return serverErrors.HandleError("", err)
}
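// Illustrative sketch (not part of the original source): representative mappings
// performed by CheckCommandErrorToServerError.
//
//	CheckCommandErrorToServerError(&ThrottledError{Cause: context.DeadlineExceeded})
//	// -> serverErrors.ErrThrottledTimeout (ThrottledError is matched before the
//	//    plain context.DeadlineExceeded case)
//
//	CheckCommandErrorToServerError(graph.ErrResolutionDepthExceeded)
//	// -> serverErrors.ErrAuthorizationModelResolutionTooComplex
//
//	CheckCommandErrorToServerError(context.DeadlineExceeded)
//	// -> serverErrors.ErrRequestDeadlineExceeded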
package commands
import (
"context"
"errors"
"fmt"
"slices"
"golang.org/x/sync/errgroup"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
openfgaErrors "github.com/openfga/openfga/internal/errors"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// ExpandQuery resolves a target TupleKey into a UsersetTree by expanding type definitions.
type ExpandQuery struct {
logger logger.Logger
datastore storage.RelationshipTupleReader
}
type ExpandQueryOption func(*ExpandQuery)
func WithExpandQueryLogger(l logger.Logger) ExpandQueryOption {
return func(eq *ExpandQuery) {
eq.logger = l
}
}
// NewExpandQuery creates a new ExpandQuery using the supplied backends for retrieving data.
func NewExpandQuery(datastore storage.OpenFGADatastore, opts ...ExpandQueryOption) *ExpandQuery {
eq := &ExpandQuery{
datastore: datastore,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(eq)
}
return eq
}
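// Illustrative usage sketch (not part of the original source). Execute requires
// the typesystem to have been injected into the context beforehand:
//
//	ctx = typesystem.ContextWithTypesystem(ctx, typesys)
//	q := NewExpandQuery(ds, WithExpandQueryLogger(l))
//	resp, err := q.Execute(ctx, &openfgav1.ExpandRequest{
//		StoreId:  storeID,
//		TupleKey: &openfgav1.ExpandRequestTupleKey{Object: "document:1", Relation: "viewer"},
//	})
//	// resp.GetTree().GetRoot() is the root of the expanded userset tree.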
func (q *ExpandQuery) Execute(ctx context.Context, req *openfgav1.ExpandRequest) (*openfgav1.ExpandResponse, error) {
store := req.GetStoreId()
tupleKey := req.GetTupleKey()
object := tupleKey.GetObject()
relation := tupleKey.GetRelation()
if object == "" || relation == "" {
return nil, serverErrors.ErrInvalidExpandInput
}
tk := tupleUtils.NewTupleKey(object, relation, "")
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
for _, ctxTuple := range req.GetContextualTuples().GetTupleKeys() {
if err := validation.ValidateTupleForWrite(typesys, ctxTuple); err != nil {
return nil, serverErrors.HandleTupleValidateError(err)
}
}
err := validation.ValidateObject(typesys, tk)
if err != nil {
return nil, serverErrors.ValidationError(err)
}
err = validation.ValidateRelation(typesys, tk)
if err != nil {
return nil, serverErrors.ValidationError(err)
}
q.datastore = storagewrappers.NewCombinedTupleReader(
q.datastore,
req.GetContextualTuples().GetTupleKeys(),
)
objectType := tupleUtils.GetType(object)
rel, err := typesys.GetRelation(objectType, relation)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return nil, serverErrors.TypeNotFound(objectType)
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return nil, serverErrors.RelationNotFound(relation, objectType, tk)
}
return nil, serverErrors.HandleError("", err)
}
userset := rel.GetRewrite()
root, err := q.resolveUserset(ctx, store, userset, tk, typesys, req.GetConsistency())
if err != nil {
return nil, err
}
return &openfgav1.ExpandResponse{
Tree: &openfgav1.UsersetTree{
Root: root,
},
}, nil
}
func (q *ExpandQuery) resolveUserset(
ctx context.Context,
store string,
userset *openfgav1.Userset,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveUserset")
defer span.End()
switch us := userset.GetUserset().(type) {
case nil, *openfgav1.Userset_This:
return q.resolveThis(ctx, store, tk, typesys, consistency)
case *openfgav1.Userset_ComputedUserset:
return q.resolveComputedUserset(ctx, us.ComputedUserset, tk)
case *openfgav1.Userset_TupleToUserset:
return q.resolveTupleToUserset(ctx, store, us.TupleToUserset, tk, typesys, consistency)
case *openfgav1.Userset_Union:
return q.resolveUnionUserset(ctx, store, us.Union, tk, typesys, consistency)
case *openfgav1.Userset_Difference:
return q.resolveDifferenceUserset(ctx, store, us.Difference, tk, typesys, consistency)
case *openfgav1.Userset_Intersection:
return q.resolveIntersectionUserset(ctx, store, us.Intersection, tk, typesys, consistency)
default:
return nil, serverErrors.ErrUnsupportedUserSet
}
}
// resolveThis resolves a DirectUserset into a leaf node containing a distinct set of users with that relation.
func (q *ExpandQuery) resolveThis(ctx context.Context, store string, tk *openfgav1.TupleKey, typesys *typesystem.TypeSystem, consistency openfgav1.ConsistencyPreference) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveThis")
defer span.End()
opts := storage.ReadOptions{
Consistency: storage.ConsistencyOptions{
Preference: consistency,
},
}
filter := storage.ReadFilter{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
tupleIter, err := q.datastore.Read(ctx, store, filter, opts)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
filteredIter := storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(tupleIter),
validation.FilterInvalidTuples(typesys),
)
defer filteredIter.Stop()
distinctUsers := make(map[string]bool)
for {
tk, err := filteredIter.Next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
break
}
return nil, serverErrors.HandleError("", err)
}
distinctUsers[tk.GetUser()] = true
}
users := make([]string, 0, len(distinctUsers))
for u := range distinctUsers {
users = append(users, u)
}
// Sort to make the output array deterministic.
slices.Sort(users)
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Leaf{
Leaf: &openfgav1.UsersetTree_Leaf{
Value: &openfgav1.UsersetTree_Leaf_Users{
Users: &openfgav1.UsersetTree_Users{
Users: users,
},
},
},
},
}, nil
}
// resolveComputedUserset builds a leaf node containing the result of resolving a ComputedUserset rewrite.
func (q *ExpandQuery) resolveComputedUserset(ctx context.Context, userset *openfgav1.ObjectRelation, tk *openfgav1.TupleKey) (*openfgav1.UsersetTree_Node, error) {
_, span := tracer.Start(ctx, "resolveComputedUserset")
defer span.End()
computed := &openfgav1.TupleKey{
Object: userset.GetObject(),
Relation: userset.GetRelation(),
}
if len(computed.GetObject()) == 0 {
computed.Object = tk.GetObject()
}
if len(computed.GetRelation()) == 0 {
computed.Relation = tk.GetRelation()
}
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Leaf{
Leaf: &openfgav1.UsersetTree_Leaf{
Value: &openfgav1.UsersetTree_Leaf_Computed{
Computed: &openfgav1.UsersetTree_Computed{
Userset: toObjectRelation(computed),
},
},
},
},
}, nil
}
// resolveTupleToUserset creates a new leaf node containing the result of expanding a TupleToUserset rewrite.
func (q *ExpandQuery) resolveTupleToUserset(
ctx context.Context,
store string,
userset *openfgav1.TupleToUserset,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveTupleToUserset")
defer span.End()
targetObject := tk.GetObject()
tupleset := userset.GetTupleset().GetRelation()
objectType := tupleUtils.GetType(targetObject)
_, err := typesys.GetRelation(objectType, tupleset)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return nil, serverErrors.TypeNotFound(objectType)
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return nil, serverErrors.RelationNotFound(tupleset, objectType, tupleUtils.NewTupleKey(tk.GetObject(), tupleset, tk.GetUser()))
}
// Any other lookup failure should surface instead of being silently ignored.
return nil, serverErrors.HandleError("", err)
}
tsKey := &openfgav1.TupleKey{
Object: targetObject,
Relation: tupleset,
}
if tsKey.GetRelation() == "" {
tsKey.Relation = tk.GetRelation()
}
opts := storage.ReadOptions{
Consistency: storage.ConsistencyOptions{
Preference: consistency,
},
}
filter := storage.ReadFilter{
Object: tsKey.GetObject(),
Relation: tsKey.GetRelation(),
User: tsKey.GetUser(),
}
tupleIter, err := q.datastore.Read(ctx, store, filter, opts)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
filteredIter := storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(tupleIter),
validation.FilterInvalidTuples(typesys),
)
defer filteredIter.Stop()
var computed []*openfgav1.UsersetTree_Computed
seen := make(map[string]bool)
for {
tk, err := filteredIter.Next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
break
}
return nil, serverErrors.HandleError("", err)
}
user := tk.GetUser()
tObject, tRelation := tupleUtils.SplitObjectRelation(user)
// We only proceed when tRelation == userset.GetComputedUserset().GetRelation().
// tRelation may be empty; in that case, we default it to userset.GetComputedUserset().GetRelation().
if tRelation == "" {
tRelation = userset.GetComputedUserset().GetRelation()
}
cs := &openfgav1.TupleKey{
Object: tObject,
Relation: tRelation,
}
computedRelation := toObjectRelation(cs)
if !seen[computedRelation] {
computed = append(computed, &openfgav1.UsersetTree_Computed{Userset: computedRelation})
seen[computedRelation] = true
}
}
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Leaf{
Leaf: &openfgav1.UsersetTree_Leaf{
Value: &openfgav1.UsersetTree_Leaf_TupleToUserset{
TupleToUserset: &openfgav1.UsersetTree_TupleToUserset{
Tupleset: toObjectRelation(tsKey),
Computed: computed,
},
},
},
},
}, nil
}
// resolveUnionUserset creates an intermediate UsersetTree node containing the union of its children.
func (q *ExpandQuery) resolveUnionUserset(
ctx context.Context,
store string,
usersets *openfgav1.Usersets,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveUnionUserset")
defer span.End()
nodes, err := q.resolveUsersets(ctx, store, usersets.GetChild(), tk, typesys, consistency)
if err != nil {
return nil, err
}
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Union{
Union: &openfgav1.UsersetTree_Nodes{
Nodes: nodes,
},
},
}, nil
}
// resolveIntersectionUserset creates an intermediate UsersetTree node containing the intersection of its children.
func (q *ExpandQuery) resolveIntersectionUserset(
ctx context.Context,
store string,
usersets *openfgav1.Usersets,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveIntersectionUserset")
defer span.End()
nodes, err := q.resolveUsersets(ctx, store, usersets.GetChild(), tk, typesys, consistency)
if err != nil {
return nil, err
}
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Intersection{
Intersection: &openfgav1.UsersetTree_Nodes{
Nodes: nodes,
},
},
}, nil
}
// resolveDifferenceUserset creates an intermediate UsersetTree node containing the difference of its children.
func (q *ExpandQuery) resolveDifferenceUserset(
ctx context.Context,
store string,
userset *openfgav1.Difference,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) (*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveDifferenceUserset")
defer span.End()
nodes, err := q.resolveUsersets(ctx, store, []*openfgav1.Userset{userset.GetBase(), userset.GetSubtract()}, tk, typesys, consistency)
if err != nil {
return nil, err
}
base := nodes[0]
subtract := nodes[1]
return &openfgav1.UsersetTree_Node{
Name: toObjectRelation(tk),
Value: &openfgav1.UsersetTree_Node_Difference{
Difference: &openfgav1.UsersetTree_Difference{
Base: base,
Subtract: subtract,
},
},
}, nil
}
// resolveUsersets creates UsersetTree nodes for multiple Usersets.
func (q *ExpandQuery) resolveUsersets(
ctx context.Context,
store string,
usersets []*openfgav1.Userset,
tk *openfgav1.TupleKey,
typesys *typesystem.TypeSystem,
consistency openfgav1.ConsistencyPreference,
) ([]*openfgav1.UsersetTree_Node, error) {
ctx, span := tracer.Start(ctx, "resolveUsersets")
defer span.End()
out := make([]*openfgav1.UsersetTree_Node, len(usersets))
grp, ctx := errgroup.WithContext(ctx)
for i, us := range usersets {
// https://golang.org/doc/faq#closures_and_goroutines
grp.Go(func() error {
node, err := q.resolveUserset(ctx, store, us, tk, typesys, consistency)
if err != nil {
return err
}
out[i] = node
return nil
})
}
if err := grp.Wait(); err != nil {
return nil, err
}
return out, nil
}
func toObjectRelation(tk *openfgav1.TupleKey) string {
return tupleUtils.ToObjectRelationString(tk.GetObject(), tk.GetRelation())
}
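// Illustrative example (not part of the original source):
//
//	toObjectRelation(&openfgav1.TupleKey{Object: "document:1", Relation: "viewer"})
//	// -> "document:1#viewer"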
package commands
import (
"context"
"errors"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type GetStoreQuery struct {
logger logger.Logger
storesBackend storage.StoresBackend
}
type GetStoreQueryOption func(*GetStoreQuery)
func WithGetStoreQueryLogger(l logger.Logger) GetStoreQueryOption {
return func(q *GetStoreQuery) {
q.logger = l
}
}
func NewGetStoreQuery(storesBackend storage.StoresBackend, opts ...GetStoreQueryOption) *GetStoreQuery {
q := &GetStoreQuery{
storesBackend: storesBackend,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(q)
}
return q
}
func (q *GetStoreQuery) Execute(ctx context.Context, req *openfgav1.GetStoreRequest) (*openfgav1.GetStoreResponse, error) {
storeID := req.GetStoreId()
store, err := q.storesBackend.GetStore(ctx, storeID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, serverErrors.ErrStoreIDNotFound
}
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.GetStoreResponse{
Id: store.GetId(),
Name: store.GetName(),
CreatedAt: store.GetCreatedAt(),
UpdatedAt: store.GetUpdatedAt(),
}, nil
}
package commands
import (
"context"
"errors"
"fmt"
"math"
"strings"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/internal/cachecontroller"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/condition"
openfgaErrors "github.com/openfga/openfga/internal/errors"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/shared"
"github.com/openfga/openfga/internal/throttler"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/featureflags"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/commands/reverseexpand"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const streamedBufferSize = 100
var (
furtherEvalRequiredCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "list_objects_further_eval_required_count",
Help: "Number of objects in a ListObjects call that needed to issue a Check call to determine a final result",
})
noFurtherEvalRequiredCounter = promauto.NewCounter(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "list_objects_no_further_eval_required_count",
Help: "Number of objects in a ListObjects call that needed to issue a Check call to determine a final result",
})
)
type ListObjectsQuery struct {
datastore storage.RelationshipTupleReader
ff featureflags.Client
logger logger.Logger
listObjectsDeadline time.Duration
listObjectsMaxResults uint32
resolveNodeLimit uint32
resolveNodeBreadthLimit uint32
maxConcurrentReads uint32
dispatchThrottlerConfig threshold.Config
datastoreThrottlingEnabled bool
datastoreThrottleThreshold int
datastoreThrottleDuration time.Duration
checkResolver graph.CheckResolver
cacheSettings serverconfig.CacheSettings
sharedDatastoreResources *shared.SharedDatastoreResources
optimizationsEnabled bool // Indicates if experimental optimizations are enabled for ListObjectsResolver
useShadowCache bool // Indicates that the shadow cache should be used instead of the main cache
pipelineEnabled bool // Indicates whether to run with the pipeline optimized code
}
type ListObjectsResolver interface {
// Execute the ListObjectsQuery, returning a list of object IDs up to a maximum of q.listObjectsMaxResults
// or until q.listObjectsDeadline is hit, whichever happens first.
Execute(ctx context.Context, req *openfgav1.ListObjectsRequest) (*ListObjectsResponse, error)
// ExecuteStreamed executes the ListObjectsQuery, returning a stream of object IDs.
// It ignores the value of q.listObjectsMaxResults and returns all available results
// until q.listObjectsDeadline is hit.
ExecuteStreamed(ctx context.Context, req *openfgav1.StreamedListObjectsRequest, srv openfgav1.OpenFGAService_StreamedListObjectsServer) (*ListObjectsResolutionMetadata, error)
}
type ListObjectsResolutionMetadata struct {
// The total number of database reads from reverse_expand and Check (if any) to complete the ListObjects request
DatastoreQueryCount atomic.Uint32
// The total number of items read from the database during a ListObjects request.
DatastoreItemCount atomic.Uint64
// The total number of dispatches aggregated from reverse_expand and check resolutions (if any) to complete the ListObjects request
DispatchCounter atomic.Uint32
// WasThrottled indicates whether the request was throttled
WasThrottled atomic.Bool
// WasWeightedGraphUsed indicates whether the weighted graph was used as the algorithm for the ListObjects request.
WasWeightedGraphUsed atomic.Bool
// CheckCounter is the total number of check requests made during the ListObjects execution for the optimized path
CheckCounter atomic.Uint32
}
type ListObjectsResponse struct {
Objects []string
ResolutionMetadata ListObjectsResolutionMetadata
}
type ListObjectsQueryOption func(d *ListObjectsQuery)
func WithListObjectsDeadline(deadline time.Duration) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.listObjectsDeadline = deadline
}
}
func WithDispatchThrottlerConfig(config threshold.Config) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.dispatchThrottlerConfig = config
}
}
func WithListObjectsMaxResults(maxResults uint32) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.listObjectsMaxResults = maxResults
}
}
// WithResolveNodeLimit see server.WithResolveNodeLimit.
func WithResolveNodeLimit(limit uint32) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.resolveNodeLimit = limit
}
}
// WithResolveNodeBreadthLimit see server.WithResolveNodeBreadthLimit.
func WithResolveNodeBreadthLimit(limit uint32) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.resolveNodeBreadthLimit = limit
}
}
func WithLogger(l logger.Logger) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.logger = l
}
}
// WithMaxConcurrentReads see server.WithMaxConcurrentReadsForListObjects.
func WithMaxConcurrentReads(limit uint32) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.maxConcurrentReads = limit
}
}
func WithListObjectsCache(sharedDatastoreResources *shared.SharedDatastoreResources, cacheSettings serverconfig.CacheSettings) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.cacheSettings = cacheSettings
d.sharedDatastoreResources = sharedDatastoreResources
}
}
func WithListObjectsDatastoreThrottler(enabled bool, threshold int, duration time.Duration) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.datastoreThrottlingEnabled = enabled
d.datastoreThrottleThreshold = threshold
d.datastoreThrottleDuration = duration
}
}
func WithFeatureFlagClient(client featureflags.Client) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
if client != nil {
d.ff = client
return
}
d.ff = featureflags.NewNoopFeatureFlagClient()
}
}
func WithListObjectsUseShadowCache(useShadowCache bool) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.useShadowCache = useShadowCache
}
}
func WithListObjectsPipelineEnabled(value bool) ListObjectsQueryOption {
return func(d *ListObjectsQuery) {
d.pipelineEnabled = value
}
}
func NewListObjectsQuery(
ds storage.RelationshipTupleReader,
checkResolver graph.CheckResolver,
storeID string,
opts ...ListObjectsQueryOption,
) (*ListObjectsQuery, error) {
if ds == nil {
return nil, fmt.Errorf("the provided datastore parameter 'ds' must be non-nil")
}
if checkResolver == nil {
return nil, fmt.Errorf("the provided CheckResolver parameter 'checkResolver' must be non-nil")
}
query := &ListObjectsQuery{
datastore: ds,
logger: logger.NewNoopLogger(),
listObjectsDeadline: serverconfig.DefaultListObjectsDeadline,
listObjectsMaxResults: serverconfig.DefaultListObjectsMaxResults,
resolveNodeLimit: serverconfig.DefaultResolveNodeLimit,
resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
maxConcurrentReads: serverconfig.DefaultMaxConcurrentReadsForListObjects,
dispatchThrottlerConfig: threshold.Config{
Throttler: throttler.NewNoopThrottler(),
Enabled: serverconfig.DefaultListObjectsDispatchThrottlingEnabled,
Threshold: serverconfig.DefaultListObjectsDispatchThrottlingDefaultThreshold,
MaxThreshold: serverconfig.DefaultListObjectsDispatchThrottlingMaxThreshold,
},
checkResolver: checkResolver,
cacheSettings: serverconfig.NewDefaultCacheSettings(),
sharedDatastoreResources: &shared.SharedDatastoreResources{
CacheController: cachecontroller.NewNoopCacheController(),
},
optimizationsEnabled: false,
useShadowCache: false,
ff: featureflags.NewNoopFeatureFlagClient(),
}
for _, opt := range opts {
opt(query)
}
if query.ff.Boolean(serverconfig.ExperimentalListObjectsOptimizations, storeID) {
query.optimizationsEnabled = true
}
return query, nil
}
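// Illustrative usage sketch (not part of the original source; assumes ds,
// checkResolver, and storeID are prepared by the caller). As with ExpandQuery,
// evaluation expects the typesystem to be present in the request context:
//
//	q, err := NewListObjectsQuery(ds, checkResolver, storeID,
//		WithListObjectsDeadline(3*time.Second),
//		WithListObjectsMaxResults(100),
//	)
//	if err != nil {
//		return err
//	}
//	resp, err := q.Execute(ctx, req) // req is an *openfgav1.ListObjectsRequest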
type ListObjectsResult struct {
ObjectID string
Err error
}
// listObjectsRequest captures the RPC request definition interface for the ListObjects API.
// The unary and streaming RPC definitions implement this interface, and so it can be used
// interchangeably for a canonical representation between the two.
type listObjectsRequest interface {
GetStoreId() string
GetAuthorizationModelId() string
GetType() string
GetRelation() string
GetUser() string
GetContextualTuples() *openfgav1.ContextualTupleKeys
GetContext() *structpb.Struct
GetConsistency() openfgav1.ConsistencyPreference
}
// evaluate fires off evaluation of the ListObjects query by delegating to
// [[reverseexpand.ReverseExpand#Execute]] and resolving the results yielded
// from it. If any results yielded by reverse expansion require further evaluation,
// those results are dispatched to Check to resolve the residual outcome.
//
// The resultsChan is **always** closed by evaluate when it is done with its work,
// which is either when all results have been yielded, the deadline has been met,
// or some other terminal error case has occurred.
func (q *ListObjectsQuery) evaluate(
ctx context.Context,
req listObjectsRequest,
resultsChan chan<- ListObjectsResult,
maxResults uint32,
resolutionMetadata *ListObjectsResolutionMetadata,
) error {
targetObjectType := req.GetType()
targetRelation := req.GetRelation()
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
handler := func() {
userObj, userRel := tuple.SplitObjectRelation(req.GetUser())
userObjType, userObjID := tuple.SplitObject(userObj)
var sourceUserRef reverseexpand.IsUserRef
sourceUserRef = &reverseexpand.UserRefObject{
Object: &openfgav1.Object{
Type: userObjType,
Id: userObjID,
},
}
if tuple.IsTypedWildcard(userObj) {
sourceUserRef = &reverseexpand.UserRefTypedWildcard{Type: tuple.GetType(userObj)}
}
if userRel != "" {
sourceUserRef = &reverseexpand.UserRefObjectRelation{
ObjectRelation: &openfgav1.ObjectRelation{
Object: userObj,
Relation: userRel,
},
}
}
var bufferSize uint32
cappedMaxResults := uint32(math.Min(float64(maxResults), 1000)) // cap max results at 1000
bufferSize = uint32(math.Max(float64(cappedMaxResults/10), 10)) // 10% of max results, but make it at least 10
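// Worked example: maxResults=1000 -> bufferSize=100; maxResults=50 -> 50/10=5,
// floored to 10; maxResults=0 (streamed, unbounded) -> 10.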
reverseExpandResultsChan := make(chan *reverseexpand.ReverseExpandResult, bufferSize)
objectsFound := atomic.Uint32{}
ds := storagewrappers.NewRequestStorageWrapperWithCache(
q.datastore,
req.GetContextualTuples().GetTupleKeys(),
&storagewrappers.Operation{
Method: apimethod.ListObjects,
Concurrency: q.maxConcurrentReads,
ThrottleThreshold: q.datastoreThrottleThreshold,
ThrottleDuration: q.datastoreThrottleDuration,
},
storagewrappers.DataResourceConfiguration{
Resources: q.sharedDatastoreResources,
CacheSettings: q.cacheSettings,
UseShadowCache: q.useShadowCache,
},
)
reverseExpandQuery := reverseexpand.NewReverseExpandQuery(
ds,
typesys,
reverseexpand.WithResolveNodeLimit(q.resolveNodeLimit),
reverseexpand.WithDispatchThrottlerConfig(q.dispatchThrottlerConfig),
reverseexpand.WithResolveNodeBreadthLimit(q.resolveNodeBreadthLimit),
reverseexpand.WithLogger(q.logger),
reverseexpand.WithCheckResolver(q.checkResolver),
reverseexpand.WithListObjectOptimizationsEnabled(q.optimizationsEnabled),
)
reverseExpandDoneWithError := make(chan struct{}, 1)
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
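// size the pool for one reverse-expand producer worker plus up to
// resolveNodeBreadthLimit concurrent Check workers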
pool := concurrency.NewPool(cancelCtx, int(1+q.resolveNodeBreadthLimit))
pool.Go(func(ctx context.Context) error {
reverseExpandResolutionMetadata := reverseexpand.NewResolutionMetadata()
err := reverseExpandQuery.Execute(ctx, &reverseexpand.ReverseExpandRequest{
StoreID: req.GetStoreId(),
ObjectType: targetObjectType,
Relation: targetRelation,
User: sourceUserRef,
ContextualTuples: req.GetContextualTuples().GetTupleKeys(),
Context: req.GetContext(),
Consistency: req.GetConsistency(),
}, reverseExpandResultsChan, reverseExpandResolutionMetadata)
if err != nil {
reverseExpandDoneWithError <- struct{}{}
return err
}
resolutionMetadata.DispatchCounter.Add(reverseExpandResolutionMetadata.DispatchCounter.Load())
if !resolutionMetadata.WasThrottled.Load() && reverseExpandResolutionMetadata.WasThrottled.Load() {
resolutionMetadata.WasThrottled.Store(true)
}
resolutionMetadata.CheckCounter.Add(reverseExpandResolutionMetadata.CheckCounter.Load())
resolutionMetadata.WasWeightedGraphUsed.Store(reverseExpandResolutionMetadata.WasWeightedGraphUsed.Load())
return nil
})
ConsumerReadLoop:
for {
select {
case <-reverseExpandDoneWithError:
cancel() // cancel any inflight work if e.g. model too complex
break ConsumerReadLoop
case <-ctx.Done():
cancel() // cancel any inflight work if e.g. deadline exceeded
break ConsumerReadLoop
case res, channelOpen := <-reverseExpandResultsChan:
if !channelOpen {
// Don't cancel here. Reverse expansion has finished finding candidate object
// IDs, but since we haven't collected "maxResults" yet, we need to wait for
// all inflight Checks to finish in the hopes of collecting a few more object
// IDs. If we sent a cancellation now, we might miss those.
break ConsumerReadLoop
}
if (maxResults != 0) && objectsFound.Load() >= maxResults {
cancel() // cancel any inflight work if we already found enough results
break ConsumerReadLoop
}
if res.ResultStatus == reverseexpand.NoFurtherEvalStatus {
noFurtherEvalRequiredCounter.Inc()
trySendObject(ctx, res.Object, &objectsFound, maxResults, resultsChan)
continue
}
furtherEvalRequiredCounter.Inc()
pool.Go(func(ctx context.Context) error {
resp, checkRequestMetadata, err := NewCheckCommand(q.datastore, q.checkResolver, typesys,
WithCheckCommandLogger(q.logger),
WithCheckCommandMaxConcurrentReads(q.maxConcurrentReads),
WithCheckDatastoreThrottler(
q.datastoreThrottlingEnabled,
q.datastoreThrottleThreshold,
q.datastoreThrottleDuration,
),
).
Execute(ctx, &CheckCommandParams{
StoreID: req.GetStoreId(),
TupleKey: tuple.NewCheckRequestTupleKey(res.Object, req.GetRelation(), req.GetUser()),
ContextualTuples: req.GetContextualTuples(),
Context: req.GetContext(),
Consistency: req.GetConsistency(),
})
if err != nil {
return err
}
resolutionMetadata.DatastoreQueryCount.Add(resp.GetResolutionMetadata().DatastoreQueryCount)
resolutionMetadata.DatastoreItemCount.Add(resp.GetResolutionMetadata().DatastoreItemCount)
resolutionMetadata.DispatchCounter.Add(checkRequestMetadata.DispatchCounter.Load())
if !resolutionMetadata.WasThrottled.Load() && checkRequestMetadata.WasThrottled.Load() {
resolutionMetadata.WasThrottled.Store(true)
}
if resp.Allowed {
trySendObject(ctx, res.Object, &objectsFound, maxResults, resultsChan)
}
return nil
})
}
}
err := pool.Wait()
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
resultsChan <- ListObjectsResult{Err: err}
}
// TODO set header to indicate "deadline exceeded"
}
close(resultsChan)
dsMeta := ds.GetMetadata()
resolutionMetadata.DatastoreQueryCount.Add(dsMeta.DatastoreQueryCount)
resolutionMetadata.DatastoreItemCount.Add(dsMeta.DatastoreItemCount)
resolutionMetadata.WasThrottled.CompareAndSwap(false, dsMeta.WasThrottled)
}
go handler()
return nil
}
func trySendObject(ctx context.Context, object string, objectsFound *atomic.Uint32, maxResults uint32, resultsChan chan<- ListObjectsResult) {
if maxResults != 0 {
if objectsFound.Add(1) > maxResults {
return
}
}
concurrency.TrySendThroughChannel(ctx, ListObjectsResult{ObjectID: object}, resultsChan)
}
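// Worked example: with maxResults=2 and three goroutines racing through
// trySendObject, the atomic Add(1) returns 1, 2, and 3 in some order; the
// caller that observes 3 > 2 returns without sending, so at most maxResults
// objects ever reach resultsChan even under concurrency.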
// Execute the ListObjectsQuery, returning a list of object IDs up to a maximum of q.listObjectsMaxResults
// or until q.listObjectsDeadline is hit, whichever happens first.
func (q *ListObjectsQuery) Execute(
ctx context.Context,
req *openfgav1.ListObjectsRequest,
) (*ListObjectsResponse, error) {
maxResults := q.listObjectsMaxResults
timeoutCtx := ctx
if q.listObjectsDeadline != 0 {
var cancel context.CancelFunc
timeoutCtx, cancel = context.WithTimeout(ctx, q.listObjectsDeadline)
defer cancel()
}
targetObjectType := req.GetType()
targetRelation := req.GetRelation()
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
if !typesystem.IsSchemaVersionSupported(typesys.GetSchemaVersion()) {
return nil, serverErrors.ValidationError(typesystem.ErrInvalidSchemaVersion)
}
for _, ctxTuple := range req.GetContextualTuples().GetTupleKeys() {
if err := validation.ValidateTupleForWrite(typesys, ctxTuple); err != nil {
return nil, serverErrors.HandleTupleValidateError(err)
}
}
_, err := typesys.GetRelation(targetObjectType, targetRelation)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return nil, serverErrors.TypeNotFound(targetObjectType)
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return nil, serverErrors.RelationNotFound(targetRelation, targetObjectType, nil)
}
return nil, serverErrors.HandleError("", err)
}
if err := validation.ValidateUser(typesys, req.GetUser()); err != nil {
return nil, serverErrors.ValidationError(fmt.Errorf("invalid 'user' value: %s", err))
}
if req.GetConsistency() != openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
if q.cacheSettings.ShouldCacheListObjectsIterators() {
// Kick off background job to check if cache records are stale, invalidating where needed
q.sharedDatastoreResources.CacheController.InvalidateIfNeeded(ctx, req.GetStoreId())
}
if q.cacheSettings.ShouldShadowCacheListObjectsIterators() {
q.sharedDatastoreResources.ShadowCacheController.InvalidateIfNeeded(ctx, req.GetStoreId())
}
}
wgraph := typesys.GetWeightedGraph()
if wgraph != nil && q.pipelineEnabled {
ds := storagewrappers.NewRequestStorageWrapperWithCache(
q.datastore,
req.GetContextualTuples().GetTupleKeys(),
&storagewrappers.Operation{
Method: apimethod.ListObjects,
Concurrency: q.maxConcurrentReads,
ThrottlingEnabled: q.datastoreThrottlingEnabled,
ThrottleThreshold: q.datastoreThrottleThreshold,
ThrottleDuration: q.datastoreThrottleDuration,
},
storagewrappers.DataResourceConfiguration{
Resources: q.sharedDatastoreResources,
CacheSettings: q.cacheSettings,
UseShadowCache: q.useShadowCache,
},
)
backend := &reverseexpand.Backend{
Datastore: ds,
StoreID: req.GetStoreId(),
TypeSystem: typesys,
Context: req.GetContext(),
Graph: wgraph,
Preference: req.GetConsistency(),
}
pipeline := reverseexpand.NewPipeline(backend)
var source reverseexpand.Source
var target reverseexpand.Target
if source, ok = pipeline.Source(targetObjectType, targetRelation); !ok {
return nil, serverErrors.ValidationError(fmt.Errorf("object: %s relation: %s not in graph", targetObjectType, targetRelation))
}
userParts := strings.Split(req.GetUser(), "#")
objectParts := strings.Split(userParts[0], ":")
objectType := objectParts[0]
objectID := objectParts[1]
if len(userParts) > 1 {
objectType += "#" + userParts[1]
}
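// e.g. "user:anne" parses to objectType "user" and objectID "anne", while
// "group:eng#member" parses to objectType "group#member" and objectID "eng"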
if target, ok = pipeline.Target(objectType, objectID); !ok {
return nil, serverErrors.ValidationError(fmt.Errorf("user: %s relation: %s not in graph", objectType, objectID))
}
seq := pipeline.Build(ctx, source, target)
var res ListObjectsResponse
for obj := range seq {
if timeoutCtx.Err() != nil {
break
}
if obj.Err != nil {
return nil, serverErrors.HandleError("", obj.Err)
}
res.Objects = append(res.Objects, obj.Value)
// Check if we've reached the max results limit
if maxResults > 0 && uint32(len(res.Objects)) >= maxResults {
break
}
}
dsMeta := ds.GetMetadata()
res.ResolutionMetadata.DatastoreQueryCount.Add(dsMeta.DatastoreQueryCount)
res.ResolutionMetadata.DatastoreItemCount.Add(dsMeta.DatastoreItemCount)
return &res, nil
}
// --------- legacy (non-pipeline) evaluation path -----------
resultsChan := make(chan ListObjectsResult, 1)
if maxResults > 0 {
resultsChan = make(chan ListObjectsResult, maxResults)
}
var listObjectsResponse ListObjectsResponse
err = q.evaluate(timeoutCtx, req, resultsChan, maxResults, &listObjectsResponse.ResolutionMetadata)
if err != nil {
return nil, err
}
listObjectsResponse.Objects = make([]string, 0, maxResults)
var errs error
for result := range resultsChan {
if result.Err != nil {
if errors.Is(result.Err, graph.ErrResolutionDepthExceeded) {
return nil, serverErrors.ErrAuthorizationModelResolutionTooComplex
}
if errors.Is(result.Err, condition.ErrEvaluationFailed) {
errs = errors.Join(errs, result.Err)
continue
}
return nil, serverErrors.HandleError("", result.Err)
}
listObjectsResponse.Objects = append(listObjectsResponse.Objects, result.ObjectID)
}
if len(listObjectsResponse.Objects) < int(maxResults) && errs != nil {
return nil, errs
}
return &listObjectsResponse, nil
}
// ExecuteStreamed executes the ListObjectsQuery, returning a stream of object IDs.
// It ignores the value of q.listObjectsMaxResults and returns all available results
// until q.listObjectsDeadline is hit.
func (q *ListObjectsQuery) ExecuteStreamed(ctx context.Context, req *openfgav1.StreamedListObjectsRequest, srv openfgav1.OpenFGAService_StreamedListObjectsServer) (*ListObjectsResolutionMetadata, error) {
maxResults := uint32(math.MaxUint32)
timeoutCtx := ctx
if q.listObjectsDeadline != 0 {
var cancel context.CancelFunc
timeoutCtx, cancel = context.WithTimeout(ctx, q.listObjectsDeadline)
defer cancel()
}
var resolutionMetadata ListObjectsResolutionMetadata
targetObjectType := req.GetType()
targetRelation := req.GetRelation()
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
if !typesystem.IsSchemaVersionSupported(typesys.GetSchemaVersion()) {
return nil, serverErrors.ValidationError(typesystem.ErrInvalidSchemaVersion)
}
for _, ctxTuple := range req.GetContextualTuples().GetTupleKeys() {
if err := validation.ValidateTupleForWrite(typesys, ctxTuple); err != nil {
return nil, serverErrors.HandleTupleValidateError(err)
}
}
_, err := typesys.GetRelation(targetObjectType, targetRelation)
if err != nil {
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return nil, serverErrors.TypeNotFound(targetObjectType)
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return nil, serverErrors.RelationNotFound(targetRelation, targetObjectType, nil)
}
return nil, serverErrors.HandleError("", err)
}
if err := validation.ValidateUser(typesys, req.GetUser()); err != nil {
return nil, serverErrors.ValidationError(fmt.Errorf("invalid 'user' value: %s", err))
}
wgraph := typesys.GetWeightedGraph()
if wgraph != nil && q.pipelineEnabled {
ds := storagewrappers.NewRequestStorageWrapperWithCache(
q.datastore,
req.GetContextualTuples().GetTupleKeys(),
&storagewrappers.Operation{
Method: apimethod.ListObjects,
Concurrency: q.maxConcurrentReads,
ThrottlingEnabled: q.datastoreThrottlingEnabled,
ThrottleThreshold: q.datastoreThrottleThreshold,
ThrottleDuration: q.datastoreThrottleDuration,
},
storagewrappers.DataResourceConfiguration{
Resources: q.sharedDatastoreResources,
CacheSettings: q.cacheSettings,
UseShadowCache: q.useShadowCache,
},
)
backend := &reverseexpand.Backend{
Datastore: ds,
StoreID: req.GetStoreId(),
TypeSystem: typesys,
Context: req.GetContext(),
Graph: wgraph,
Preference: req.GetConsistency(),
}
pipeline := reverseexpand.NewPipeline(backend)
var source reverseexpand.Source
var target reverseexpand.Target
if source, ok = pipeline.Source(targetObjectType, targetRelation); !ok {
return nil, serverErrors.ValidationError(fmt.Errorf("object: %s relation: %s not in graph", targetObjectType, targetRelation))
}
userParts := strings.Split(req.GetUser(), "#")
objectParts := strings.Split(userParts[0], ":")
objectType := objectParts[0]
objectID := objectParts[1]
if len(userParts) > 1 {
objectType += "#" + userParts[1]
}
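// e.g. "group:eng#member" parses to objectType "group#member" and
// objectID "eng"; see the worked example in Execute above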
if target, ok = pipeline.Target(objectType, objectID); !ok {
return nil, serverErrors.ValidationError(fmt.Errorf("user: %s relation: %s not in graph", objectType, objectID))
}
seq := pipeline.Build(ctx, source, target)
var listObjectsCount uint32 = 0
for obj := range seq {
if timeoutCtx.Err() != nil {
break
}
if obj.Err != nil {
if errors.Is(obj.Err, condition.ErrEvaluationFailed) {
return nil, serverErrors.ValidationError(obj.Err)
}
return nil, serverErrors.HandleError("", obj.Err)
}
if err := srv.Send(&openfgav1.StreamedListObjectsResponse{
Object: obj.Value,
}); err != nil {
return nil, serverErrors.HandleError("", err)
}
listObjectsCount++
// Check if we've reached the max results limit
if maxResults > 0 && listObjectsCount >= maxResults {
break
}
}
dsMeta := ds.GetMetadata()
resolutionMetadata.DatastoreQueryCount.Add(dsMeta.DatastoreQueryCount)
resolutionMetadata.DatastoreItemCount.Add(dsMeta.DatastoreItemCount)
return &resolutionMetadata, nil
}
// make a buffered channel so that writer goroutines aren't blocked when attempting to send a result
resultsChan := make(chan ListObjectsResult, streamedBufferSize)
err = q.evaluate(timeoutCtx, req, resultsChan, maxResults, &resolutionMetadata)
if err != nil {
return nil, err
}
for result := range resultsChan {
if result.Err != nil {
if errors.Is(result.Err, graph.ErrResolutionDepthExceeded) {
return nil, serverErrors.ErrAuthorizationModelResolutionTooComplex
}
if errors.Is(result.Err, condition.ErrEvaluationFailed) {
return nil, serverErrors.ValidationError(result.Err)
}
return nil, serverErrors.HandleError("", result.Err)
}
if err := srv.Send(&openfgav1.StreamedListObjectsResponse{
Object: result.ObjectID,
}); err != nil {
return nil, serverErrors.HandleError("", err)
}
}
return &resolutionMetadata, nil
}
package commands
import (
"context"
"errors"
"maps"
"slices"
"sync"
"time"
"go.uber.org/zap"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/typesystem"
)
const ListObjectsShadowExecute = "ShadowedListObjectsQuery.Execute"
type shadowedListObjectsQuery struct {
main ListObjectsResolver
shadow ListObjectsResolver
shadowTimeout time.Duration // A time.Duration specifying the maximum amount of time to wait for the shadow list_objects query to complete. If the shadow query exceeds this shadowTimeout, it will be cancelled, and its result will be ignored, but the shadowTimeout event will be logged.
maxDeltaItems int // The maximum number of items to log in the delta between the main and shadow results. This prevents excessive logging in case of large differences.
logger logger.Logger
// only used for testing signals
wg *sync.WaitGroup
}
type ShadowListObjectsQueryOption func(d *ShadowListObjectsQueryConfig)
// WithShadowListObjectsQueryEnabled enables or disables shadow mode for list_objects queries.
func WithShadowListObjectsQueryEnabled(enabled bool) ShadowListObjectsQueryOption {
return func(c *ShadowListObjectsQueryConfig) {
c.shadowEnabled = enabled
}
}
// WithShadowListObjectsQueryTimeout sets the shadowTimeout for the shadow list_objects query.
func WithShadowListObjectsQueryTimeout(timeout time.Duration) ShadowListObjectsQueryOption {
return func(c *ShadowListObjectsQueryConfig) {
c.shadowTimeout = timeout
}
}
func WithShadowListObjectsQueryLogger(logger logger.Logger) ShadowListObjectsQueryOption {
return func(c *ShadowListObjectsQueryConfig) {
c.logger = logger
}
}
func WithShadowListObjectsQueryMaxDeltaItems(maxDeltaItems int) ShadowListObjectsQueryOption {
return func(c *ShadowListObjectsQueryConfig) {
c.maxDeltaItems = maxDeltaItems
}
}
type ShadowListObjectsQueryConfig struct {
shadowEnabled bool // A boolean flag to globally enable or disable the shadow mode for list_objects queries. When false, the shadow query will not be executed.
shadowTimeout time.Duration // A time.Duration specifying the maximum amount of time to wait for the shadow list_objects query to complete. If the shadow query exceeds this shadowTimeout, it will be cancelled, and its result will be ignored, but the shadowTimeout event will be logged.
maxDeltaItems int // The maximum number of items to log in the delta between the main and shadow results. This prevents excessive logging in case of large differences.
logger logger.Logger
}
func NewShadowListObjectsQueryConfig(opts ...ShadowListObjectsQueryOption) *ShadowListObjectsQueryConfig {
result := &ShadowListObjectsQueryConfig{
shadowEnabled: false, // Disabled by default
shadowTimeout: 1 * time.Second, // Default shadowTimeout for shadow queries
logger: logger.NewNoopLogger(), // Default to a noop logger
maxDeltaItems: 100, // Default max delta items to log
}
for _, opt := range opts {
opt(result)
}
return result
}
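// Usage sketch (illustrative only): enabling shadow mode with a tighter
// timeout than the 1s default.
//
//    cfg := NewShadowListObjectsQueryConfig(
//        WithShadowListObjectsQueryEnabled(true),
//        WithShadowListObjectsQueryTimeout(500*time.Millisecond),
//    )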
// NewListObjectsQueryWithShadowConfig creates a new ListObjectsResolver that can run in shadow mode based on the provided ShadowListObjectsQueryConfig.
func NewListObjectsQueryWithShadowConfig(
ds storage.RelationshipTupleReader,
checkResolver graph.CheckResolver,
shadowConfig *ShadowListObjectsQueryConfig,
storeID string,
opts ...ListObjectsQueryOption,
) (ListObjectsResolver, error) {
if shadowConfig != nil && shadowConfig.shadowEnabled {
return newShadowedListObjectsQuery(ds, checkResolver, shadowConfig, storeID, opts...)
}
return NewListObjectsQuery(ds, checkResolver, storeID, opts...)
}
// newShadowedListObjectsQuery creates a ListObjectsResolver that executes the standard (non-pipeline) query and, in the background, a pipeline-enabled shadow query whose results are compared and logged.
func newShadowedListObjectsQuery(
ds storage.RelationshipTupleReader,
checkResolver graph.CheckResolver,
shadowConfig *ShadowListObjectsQueryConfig,
storeID string,
opts ...ListObjectsQueryOption,
) (ListObjectsResolver, error) {
if shadowConfig == nil {
return nil, errors.New("shadowConfig must be set")
}
standard, err := NewListObjectsQuery(ds, checkResolver, storeID,
// force disable pipeline
slices.Concat(opts, []ListObjectsQueryOption{WithListObjectsPipelineEnabled(false)})...,
)
if err != nil {
return nil, err
}
optimized, err := NewListObjectsQuery(ds, checkResolver, storeID,
// enable pipeline
slices.Concat(opts, []ListObjectsQueryOption{WithListObjectsPipelineEnabled(true), WithListObjectsUseShadowCache(true)})...,
)
if err != nil {
return nil, err
}
result := &shadowedListObjectsQuery{
main: standard,
shadow: optimized,
shadowTimeout: shadowConfig.shadowTimeout,
logger: shadowConfig.logger,
maxDeltaItems: shadowConfig.maxDeltaItems,
wg: &sync.WaitGroup{}, // only used for testing signals
}
return result, nil
}
func (q *shadowedListObjectsQuery) Execute(
ctx context.Context,
req *openfgav1.ListObjectsRequest,
) (*ListObjectsResponse, error) {
cloneCtx := context.WithoutCancel(ctx) // detach from request cancellation but keep context values (typesystem, datastore, etc.)
startTime := time.Now()
res, err := q.main.Execute(ctx, req)
if err != nil {
return nil, err
}
latency := time.Since(startTime)
// Run the shadow query in the background only when the shadow-mode preconditions are met
if q.checkShadowModePreconditions(cloneCtx, req) {
q.wg.Add(1) // only used for testing signals
go func() {
startTime = time.Now()
defer func() {
defer q.wg.Done() // only used for testing signals
if r := recover(); r != nil {
q.logger.ErrorWithContext(cloneCtx, "panic recovered",
loShadowLogFields(req,
zap.Duration("main_latency", latency),
zap.Duration("shadow_latency", time.Since(startTime)),
zap.Int("main_result_count", len(res.Objects)),
zap.Any("error", r),
)...,
)
}
}()
q.executeShadowModeAndCompareResults(cloneCtx, req, res, latency)
}()
}
return res, err
}
func (q *shadowedListObjectsQuery) ExecuteStreamed(ctx context.Context, req *openfgav1.StreamedListObjectsRequest, srv openfgav1.OpenFGAService_StreamedListObjectsServer) (*ListObjectsResolutionMetadata, error) {
return q.main.ExecuteStreamed(ctx, req, srv)
}
// executeShadowModeAndCompareResults runs the shadow query against the already-computed
// main result and compares the two, logging any differences.
// If the shadow query takes longer than shadowTimeout, it is cancelled and its result is
// ignored, but the timeout event is logged.
// This function is designed to be run in a separate goroutine to avoid blocking the main execution flow.
func (q *shadowedListObjectsQuery) executeShadowModeAndCompareResults(parentCtx context.Context, req *openfgav1.ListObjectsRequest, mainResult *ListObjectsResponse, latency time.Duration) {
parentCtx, span := tracer.Start(parentCtx, "shadow")
defer span.End()
shadowCtx, shadowCancel := context.WithTimeout(parentCtx, q.shadowTimeout)
defer shadowCancel()
startTime := time.Now()
shadowRes, errShadow := q.shadow.Execute(shadowCtx, req)
shadowLatency := time.Since(startTime)
var mainQueryCount uint32
var mainItemCount uint64
var mainResultObjects []string
if mainResult != nil {
mainQueryCount = mainResult.ResolutionMetadata.DatastoreQueryCount.Load()
mainItemCount = mainResult.ResolutionMetadata.DatastoreItemCount.Load()
mainResultObjects = mainResult.Objects
}
if errShadow != nil {
q.logger.WarnWithContext(parentCtx, "shadowed list objects error",
loShadowLogFields(req,
zap.Duration("main_latency", latency),
zap.Duration("shadow_latency", shadowLatency),
zap.Int("main_result_count", len(mainResultObjects)),
zap.Any("error", errShadow),
)...,
)
return
}
var resultShadowed []string
var shadowQueryCount uint32
var shadowItemCount uint64
if shadowRes != nil {
resultShadowed = shadowRes.Objects
shadowQueryCount = shadowRes.ResolutionMetadata.DatastoreQueryCount.Load()
shadowItemCount = shadowRes.ResolutionMetadata.DatastoreItemCount.Load()
}
mapResultMain := keyMapFromSlice(mainResultObjects)
mapResultShadow := keyMapFromSlice(resultShadowed)
fields := []zap.Field{
zap.Duration("main_latency", latency),
zap.Duration("shadow_latency", shadowLatency),
zap.Int("main_result_count", len(mainResultObjects)),
zap.Int("shadow_result_count", len(resultShadowed)),
zap.Uint32("main_datastore_query_count", mainQueryCount),
zap.Uint32("shadow_datastore_query_count", shadowQueryCount),
zap.Uint64("main_datastore_item_count", mainItemCount),
zap.Uint64("shadow_datastore_item_count", shadowItemCount),
}
// compare the result sets as key maps; order-insensitive and sufficient for an equality check
if !maps.Equal(mapResultMain, mapResultShadow) {
delta := calculateDelta(mapResultMain, mapResultShadow)
totalDelta := len(delta)
// Limit the delta to maxDeltaItems
if totalDelta > q.maxDeltaItems {
delta = delta[:q.maxDeltaItems]
}
fields = append(
fields,
zap.Bool("is_match", false),
zap.Int("total_delta", totalDelta),
zap.Any("delta", delta),
)
// log the differences when the main and shadow results are not equal
q.logger.WarnWithContext(parentCtx, "shadowed list objects result difference",
loShadowLogFields(req, fields...)...,
)
} else {
fields = append(
fields,
zap.Bool("is_match", true),
)
q.logger.InfoWithContext(parentCtx, "shadowed list objects result matches",
loShadowLogFields(req, fields...)...,
)
}
}
// checkShadowModePreconditions checks if the shadow mode preconditions are met:
// - If the weighted graph does not exist, skip the shadow query.
func (q *shadowedListObjectsQuery) checkShadowModePreconditions(ctx context.Context, req *openfgav1.ListObjectsRequest) bool {
typesys, ok := typesystem.TypesystemFromContext(ctx)
if !ok {
return false
}
if typesys.GetWeightedGraph() == nil {
q.logger.InfoWithContext(ctx, "shadowed list objects query skipped due to missing weighted graph",
loShadowLogFields(req)...,
)
return false
}
return true
}
func loShadowLogFields(req *openfgav1.ListObjectsRequest, fields ...zap.Field) []zap.Field {
return append([]zap.Field{
zap.String("func", ListObjectsShadowExecute),
zap.Any("request", req),
zap.String("store_id", req.GetStoreId()),
zap.String("model_id", req.GetAuthorizationModelId()),
}, fields...)
}
// keyMapFromSlice creates a map from a slice of strings, where each string is a key in the map.
func keyMapFromSlice(slice []string) map[string]struct{} {
result := make(map[string]struct{}, len(slice))
for _, item := range slice {
result[item] = struct{}{}
}
return result
}
// calculateDelta calculates the delta between two maps of string keys.
func calculateDelta(mapResultMain map[string]struct{}, mapResultShadow map[string]struct{}) []string {
delta := make([]string, 0, len(mapResultMain)+len(mapResultShadow))
// Find objects in main but not in shadow
for key := range mapResultMain {
if _, exists := mapResultShadow[key]; !exists {
delta = append(delta, "-"+key) // object in main but not in shadow
}
}
for key := range mapResultShadow {
if _, exists := mapResultMain[key]; !exists {
delta = append(delta, "+"+key) // object in shadow but not in main
}
}
// Sort the delta for consistent result
slices.Sort(delta)
return delta
}
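// Worked example: main = {"document:1", "document:2"} and
// shadow = {"document:2", "document:3"} yield the sorted delta
// ["+document:3", "-document:1"]: "-" marks objects missing from the shadow
// result and "+" marks objects only the shadow result returned.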
package commands
import (
"context"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type ListStoresQuery struct {
storesBackend storage.StoresBackend
logger logger.Logger
encoder encoder.Encoder
}
type ListStoresQueryOption func(*ListStoresQuery)
func WithListStoresQueryLogger(l logger.Logger) ListStoresQueryOption {
return func(q *ListStoresQuery) {
q.logger = l
}
}
func WithListStoresQueryEncoder(e encoder.Encoder) ListStoresQueryOption {
return func(q *ListStoresQuery) {
q.encoder = e
}
}
func NewListStoresQuery(storesBackend storage.StoresBackend, opts ...ListStoresQueryOption) *ListStoresQuery {
q := &ListStoresQuery{
storesBackend: storesBackend,
logger: logger.NewNoopLogger(),
encoder: encoder.NewBase64Encoder(),
}
for _, opt := range opts {
opt(q)
}
return q
}
func (q *ListStoresQuery) Execute(ctx context.Context, req *openfgav1.ListStoresRequest, storeIDs []string) (*openfgav1.ListStoresResponse, error) {
decodedContToken, err := q.encoder.Decode(req.GetContinuationToken())
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
opts := storage.ListStoresOptions{
IDs: storeIDs,
Name: req.GetName(),
Pagination: storage.NewPaginationOptions(req.GetPageSize().GetValue(), string(decodedContToken)),
}
stores, continuationToken, err := q.storesBackend.ListStores(ctx, opts)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
encodedToken, err := q.encoder.Encode([]byte(continuationToken))
if err != nil {
return nil, serverErrors.HandleError("", err)
}
resp := &openfgav1.ListStoresResponse{
Stores: stores,
ContinuationToken: encodedToken,
}
return resp, nil
}
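// Callers typically page through stores by sending an empty continuation
// token on the first request and feeding each response's ContinuationToken
// back into the next request until it comes back empty (a usage convention
// of the API, not enforced by this command).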
package listusers
import (
"maps"
"sync/atomic"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
type listUsersRequest interface {
GetStoreId() string
GetAuthorizationModelId() string
GetObject() *openfgav1.Object
GetRelation() string
GetUserFilters() []*openfgav1.UserTypeFilter
GetContextualTuples() []*openfgav1.TupleKey
GetContext() *structpb.Struct
GetConsistency() openfgav1.ConsistencyPreference
}
type internalListUsersRequest struct {
*openfgav1.ListUsersRequest
// visitedUsersetsMap keeps track of the "path" we've made so far.
// It prevents stack overflows by preventing visiting the same userset twice.
visitedUsersetsMap map[string]struct{}
// depth is the current depth of the traversal expressed as a positive, incrementing integer.
// When expansion of list users recursively traverses one level, we increment it by one. If this
// counter hits the limit, we return ErrResolutionDepthExceeded. This protects against a potentially deep
// or endless cycle of recursion.
depth uint32
dispatchCount *atomic.Uint32
}
var _ listUsersRequest = (*internalListUsersRequest)(nil)
// nolint // it should be GetStoreID, but we want to satisfy the interface listUsersRequest
func (r *internalListUsersRequest) GetStoreId() string {
if r == nil {
return ""
}
return r.StoreId
}
// nolint // it should be GetAuthorizationModelID, but we want to satisfy the interface listUsersRequest
func (r *internalListUsersRequest) GetAuthorizationModelId() string {
if r == nil {
return ""
}
return r.AuthorizationModelId
}
func (r *internalListUsersRequest) GetObject() *openfgav1.Object {
if r == nil {
return nil
}
return r.Object
}
func (r *internalListUsersRequest) GetRelation() string {
if r == nil {
return ""
}
return r.Relation
}
func (r *internalListUsersRequest) GetUserFilters() []*openfgav1.UserTypeFilter {
if r == nil {
return nil
}
return r.UserFilters
}
func (r *internalListUsersRequest) GetContextualTuples() []*openfgav1.TupleKey {
if r == nil {
return nil
}
return r.ContextualTuples
}
func (r *internalListUsersRequest) GetDispatchCount() uint32 {
if r == nil {
return uint32(0)
}
return r.dispatchCount.Load()
}
func (r *internalListUsersRequest) GetContext() *structpb.Struct {
if r == nil {
return nil
}
return r.Context
}
type listUsersResponse struct {
Users []*openfgav1.User
Metadata listUsersResponseMetadata
}
type listUsersResponseMetadata struct {
DatastoreQueryCount uint32
DatastoreItemCount uint64
// DispatchCounter is the number of times we recursively expanded to find users.
// An atomic is used for consistency with Check and ListObjects.
DispatchCounter *atomic.Uint32
// WasThrottled indicates whether the request was throttled
WasThrottled *atomic.Bool
}
func (r *listUsersResponse) GetUsers() []*openfgav1.User {
if r == nil {
return []*openfgav1.User{}
}
return r.Users
}
func (r *listUsersResponse) GetMetadata() listUsersResponseMetadata {
if r == nil {
return listUsersResponseMetadata{}
}
return r.Metadata
}
func fromListUsersRequest(o listUsersRequest, dispatchCount *atomic.Uint32) *internalListUsersRequest {
if dispatchCount == nil {
dispatchCount = new(atomic.Uint32)
}
return &internalListUsersRequest{
ListUsersRequest: &openfgav1.ListUsersRequest{
StoreId: o.GetStoreId(),
AuthorizationModelId: o.GetAuthorizationModelId(),
Object: o.GetObject(),
Relation: o.GetRelation(),
UserFilters: o.GetUserFilters(),
ContextualTuples: o.GetContextualTuples(),
Context: o.GetContext(),
Consistency: o.GetConsistency(),
},
visitedUsersetsMap: make(map[string]struct{}),
depth: 0,
dispatchCount: dispatchCount,
}
}
// clone creates a copy of the request. Note that some fields are not deep-cloned.
func (r *internalListUsersRequest) clone() *internalListUsersRequest {
v := fromListUsersRequest(r, r.dispatchCount)
v.visitedUsersetsMap = maps.Clone(r.visitedUsersetsMap)
v.depth = r.depth
return v
}
package listusers
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/sourcegraph/conc/panics"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/condition/eval"
openfgaErrors "github.com/openfga/openfga/internal/errors"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
var (
tracer = otel.Tracer("openfga/pkg/server/commands/list_users")
ErrPanic = errors.New("panic captured")
)
type listUsersQuery struct {
logger logger.Logger
datastore *storagewrappers.RequestStorageWrapper
resolveNodeBreadthLimit uint32
resolveNodeLimit uint32
maxResults uint32
maxConcurrentReads uint32
deadline time.Duration
dispatchThrottlerConfig threshold.Config
wasThrottled *atomic.Bool
expandDirectDispatch expandDirectDispatchHandler
datastoreThrottleThreshold int
datastoreThrottleDuration time.Duration
}
type expandResponse struct {
hasCycle bool
err error
}
// userRelationshipStatus represents the status of a relationship that a given user/subject has with respect to a specific relation.
//
// A user/subject either explicitly has a relationship with respect to the
// relation being expanded or explicitly does not.
type userRelationshipStatus int
type expandDirectDispatchHandler func(ctx context.Context, listUsersQuery *listUsersQuery, req *internalListUsersRequest, userObjectType string, userObjectID string, userRelation string, resp expandResponse, foundUsersChan chan<- foundUser, hasCycle *atomic.Bool) expandResponse
const (
HasRelationship userRelationshipStatus = iota
NoRelationship
)
type foundUser struct {
user *openfgav1.User
excludedUsers []*openfgav1.User
// relationshipStatus indicates whether the user explicitly does or does not have
// a specific relationship with respect to the relation being expanded. It almost
// exclusively applies to behavior stemming from exclusion rewrite rules.
//
// As users/subjects are being expanded we propagate the relationship status with
// respect to the relation being evaluated so that we can handle subjects which
// have been explicitly excluded from a relationship and where that relation is
// contained under the subtracted branch of another exclusion. This allows us to
// bubble up the subject from the subtracted branch of the exclusion.
relationshipStatus userRelationshipStatus
}
type ListUsersQueryOption func(l *listUsersQuery)
func WithListUsersQueryLogger(l logger.Logger) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.logger = l
}
}
// WithListUsersMaxResults see server.WithListUsersMaxResults.
func WithListUsersMaxResults(maxResults uint32) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.maxResults = maxResults
}
}
// WithListUsersDeadline see server.WithListUsersDeadline.
func WithListUsersDeadline(t time.Duration) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.deadline = t
}
}
// WithResolveNodeLimit see server.WithResolveNodeLimit.
func WithResolveNodeLimit(limit uint32) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.resolveNodeLimit = limit
}
}
// WithResolveNodeBreadthLimit see server.WithResolveNodeBreadthLimit.
func WithResolveNodeBreadthLimit(limit uint32) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.resolveNodeBreadthLimit = limit
}
}
// WithListUsersMaxConcurrentReads see server.WithMaxConcurrentReadsForListUsers.
func WithListUsersMaxConcurrentReads(limit uint32) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.maxConcurrentReads = limit
}
}
func WithListUsersDatastoreThrottler(threshold int, duration time.Duration) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.datastoreThrottleThreshold = threshold
d.datastoreThrottleDuration = duration
}
}
func (l *listUsersQuery) throttle(ctx context.Context, currentNumDispatch uint32) {
span := trace.SpanFromContext(ctx)
shouldThrottle := threshold.ShouldThrottle(
ctx,
currentNumDispatch,
l.dispatchThrottlerConfig.Threshold,
l.dispatchThrottlerConfig.MaxThreshold,
)
span.SetAttributes(
attribute.Int("dispatch_count", int(currentNumDispatch)),
attribute.Bool("is_throttled", shouldThrottle))
if shouldThrottle {
l.wasThrottled.Store(true)
l.dispatchThrottlerConfig.Throttler.Throttle(ctx)
}
}
func WithDispatchThrottlerConfig(config threshold.Config) ListUsersQueryOption {
return func(d *listUsersQuery) {
d.dispatchThrottlerConfig = config
}
}
// TODO accept ListUsersRequest instead of contextualTuples.
func NewListUsersQuery(ds storage.RelationshipTupleReader, contextualTuples []*openfgav1.TupleKey, opts ...ListUsersQueryOption) *listUsersQuery {
l := &listUsersQuery{
logger: logger.NewNoopLogger(),
resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
resolveNodeLimit: serverconfig.DefaultResolveNodeLimit,
deadline: serverconfig.DefaultListUsersDeadline,
maxResults: serverconfig.DefaultListUsersMaxResults,
maxConcurrentReads: serverconfig.DefaultMaxConcurrentReadsForListUsers,
wasThrottled: new(atomic.Bool),
expandDirectDispatch: expandDirectDispatch,
}
for _, opt := range opts {
opt(l)
}
l.datastore = storagewrappers.NewRequestStorageWrapper(ds, contextualTuples, &storagewrappers.Operation{
Method: apimethod.ListUsers,
Concurrency: l.maxConcurrentReads,
})
return l
}
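// Usage sketch (illustrative only): `ds` and `contextualTuples` are assumed
// to be supplied by the caller.
//
//    q := NewListUsersQuery(ds, contextualTuples,
//        WithListUsersMaxResults(100),
//        WithListUsersDeadline(3*time.Second),
//    )
//    resp, err := q.ListUsers(ctx, req)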
// ListUsers assumes that the typesystem is in the context and that the request is valid.
func (l *listUsersQuery) ListUsers(
ctx context.Context,
req *openfgav1.ListUsersRequest,
) (*listUsersResponse, error) {
ctx, span := tracer.Start(ctx, "ListUsers", trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
cancellableCtx, cancelCtx := context.WithCancel(ctx)
if l.deadline != 0 {
cancellableCtx, cancelCtx = context.WithTimeout(cancellableCtx, l.deadline)
defer cancelCtx()
}
defer cancelCtx()
typesys, ok := typesystem.TypesystemFromContext(cancellableCtx)
if !ok {
return nil, fmt.Errorf("%w: typesystem missing in context", openfgaErrors.ErrUnknown)
}
userFilter := req.GetUserFilters()[0]
userset := tuple.ToObjectRelationString(tuple.ObjectKey(req.GetObject()), req.GetRelation())
if !tuple.UsersetMatchTypeAndRelation(userset, userFilter.GetRelation(), userFilter.GetType()) {
hasPossibleEdges, err := doesHavePossibleEdges(typesys, req)
if err != nil {
return nil, err
}
if !hasPossibleEdges {
span.SetAttributes(attribute.Bool("no_possible_edges", true))
return &listUsersResponse{
Users: []*openfgav1.User{},
Metadata: listUsersResponseMetadata{
DispatchCounter: new(atomic.Uint32),
WasThrottled: new(atomic.Bool),
},
}, nil
}
}
dispatchCount := atomic.Uint32{}
foundUsersCh := l.buildResultsChannel()
expandErrCh := make(chan error, 1)
foundUsersUnique := make(map[tuple.UserString]foundUser, 1000)
doneWithFoundUsersCh := make(chan struct{}, 1)
go func() {
for foundUser := range foundUsersCh {
foundUsersUnique[tuple.UserProtoToString(foundUser.user)] = foundUser
if l.maxResults > 0 {
if uint32(len(foundUsersUnique)) >= l.maxResults {
span.SetAttributes(attribute.Bool("max_results_found", true))
break
}
}
}
doneWithFoundUsersCh <- struct{}{}
}()
go func() {
internalRequest := fromListUsersRequest(req, &dispatchCount)
resp := l.expand(cancellableCtx, internalRequest, foundUsersCh)
if resp.err != nil {
expandErrCh <- resp.err
}
close(foundUsersCh)
}()
deadlineExceeded := false
select {
case <-doneWithFoundUsersCh:
break
case <-cancellableCtx.Done():
deadlineExceeded = true
// to avoid a race on the 'foundUsersUnique' map below, wait for the range over the channel to close
<-doneWithFoundUsersCh
break
}
select {
case err := <-expandErrCh:
if deadlineExceeded || errors.Is(err, context.DeadlineExceeded) {
// We skip the error because we want to send at least partial results to the user (but we should probably set response headers)
break
}
telemetry.TraceError(span, err)
return nil, err
default:
break
}
cancelCtx()
foundUsers := make([]*openfgav1.User, 0, len(foundUsersUnique))
for foundUserKey, foundUser := range foundUsersUnique {
if foundUser.relationshipStatus == NoRelationship {
continue
}
foundUsers = append(foundUsers, tuple.StringToUserProto(foundUserKey))
}
span.SetAttributes(attribute.Int("result_count", len(foundUsers)))
dsMeta := l.datastore.GetMetadata()
l.wasThrottled.CompareAndSwap(false, dsMeta.WasThrottled)
return &listUsersResponse{
Users: foundUsers,
Metadata: listUsersResponseMetadata{
DatastoreQueryCount: dsMeta.DatastoreQueryCount,
DatastoreItemCount: dsMeta.DatastoreItemCount,
DispatchCounter: &dispatchCount,
WasThrottled: l.wasThrottled,
},
}, nil
}
func doesHavePossibleEdges(typesys *typesystem.TypeSystem, req *openfgav1.ListUsersRequest) (bool, error) {
g := graph.New(typesys)
userFilters := req.GetUserFilters()
source := typesystem.DirectRelationReference(userFilters[0].GetType(), userFilters[0].GetRelation())
target := typesystem.DirectRelationReference(req.GetObject().GetType(), req.GetRelation())
edges, err := g.GetPrunedRelationshipEdges(target, source)
if err != nil {
return false, err
}
return len(edges) > 0, err
}
func (l *listUsersQuery) dispatch(
ctx context.Context,
req *internalListUsersRequest,
foundUsersChan chan<- foundUser,
) expandResponse {
newcount := req.dispatchCount.Add(1)
if l.dispatchThrottlerConfig.Enabled {
l.throttle(ctx, newcount)
}
return l.expand(ctx, req, foundUsersChan)
}
func (l *listUsersQuery) expand(
ctx context.Context,
req *internalListUsersRequest,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expand")
defer span.End()
span.SetAttributes(attribute.Int("depth", int(req.depth)))
if req.depth >= l.resolveNodeLimit {
return expandResponse{
err: graph.ErrResolutionDepthExceeded,
}
}
req.depth++
if enteredCycle(req) {
span.SetAttributes(attribute.Bool("cycle_detected", true))
return expandResponse{
hasCycle: true,
}
}
reqObjectType := req.GetObject().GetType()
reqObjectID := req.GetObject().GetId()
reqRelation := req.GetRelation()
for _, userFilter := range req.GetUserFilters() {
if reqObjectType == userFilter.GetType() && reqRelation == userFilter.GetRelation() {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: &openfgav1.User{
User: &openfgav1.User_Userset{
Userset: &openfgav1.UsersetUser{
Type: reqObjectType,
Id: reqObjectID,
Relation: reqRelation,
},
},
},
}, foundUsersChan)
}
}
typesys, _ := typesystem.TypesystemFromContext(ctx)
targetObjectType := req.GetObject().GetType()
targetRelation := req.GetRelation()
relation, err := typesys.GetRelation(targetObjectType, targetRelation)
if err != nil {
var relationUndefinedError *typesystem.RelationUndefinedError
if errors.As(err, &relationUndefinedError) {
return expandResponse{}
}
return expandResponse{
err: err,
}
}
relationRewrite := relation.GetRewrite()
resp := l.expandRewrite(ctx, req, relationRewrite, foundUsersChan)
if resp.err != nil {
telemetry.TraceError(span, resp.err)
}
return resp
}
func (l *listUsersQuery) expandRewrite(
ctx context.Context,
req *internalListUsersRequest,
rewrite *openfgav1.Userset,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandRewrite")
defer span.End()
var resp expandResponse
switch rewrite := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This:
resp = l.expandDirect(ctx, req, foundUsersChan)
case *openfgav1.Userset_ComputedUserset:
rewrittenReq := req.clone()
rewrittenReq.Relation = rewrite.ComputedUserset.GetRelation()
resp = l.dispatch(ctx, rewrittenReq, foundUsersChan)
case *openfgav1.Userset_TupleToUserset:
resp = l.expandTTU(ctx, req, rewrite, foundUsersChan)
case *openfgav1.Userset_Intersection:
resp = l.expandIntersection(ctx, req, rewrite, foundUsersChan)
case *openfgav1.Userset_Difference:
resp = l.expandExclusion(ctx, req, rewrite, foundUsersChan)
case *openfgav1.Userset_Union:
resp = l.expandUnion(ctx, req, rewrite, foundUsersChan)
default:
panic("unexpected userset rewrite encountered")
}
if resp.err != nil {
telemetry.TraceError(span, resp.err)
}
return resp
}
func (l *listUsersQuery) expandDirect(
ctx context.Context,
req *internalListUsersRequest,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandDirect")
defer span.End()
typesys, _ := typesystem.TypesystemFromContext(ctx)
opts := storage.ReadOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
iter, err := l.datastore.Read(ctx, req.GetStoreId(), storage.ReadFilter{
Object: tuple.ObjectKey(req.GetObject()),
Relation: req.GetRelation(),
}, opts)
if err != nil {
telemetry.TraceError(span, err)
return expandResponse{
err: err,
}
}
defer iter.Stop()
filteredIter := storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
)
defer filteredIter.Stop()
pool := concurrency.NewPool(ctx, int(l.resolveNodeBreadthLimit))
var errs error
var hasCycle atomic.Bool
LoopOnIterator:
for {
tupleKey, err := filteredIter.Next(ctx)
if err != nil {
if !errors.Is(err, storage.ErrIteratorDone) {
errs = errors.Join(errs, err)
}
break LoopOnIterator
}
cond, _ := typesys.GetCondition(tupleKey.GetCondition().GetName())
condMet, err := eval.EvaluateTupleCondition(ctx, tupleKey, cond, req.Context)
if err != nil {
errs = errors.Join(errs, err)
if !errors.Is(err, condition.ErrEvaluationFailed) {
break LoopOnIterator
}
telemetry.TraceError(span, err)
}
if !condMet {
continue
}
tupleKeyUser := tupleKey.GetUser()
userObject, userRelation := tuple.SplitObjectRelation(tupleKeyUser)
userObjectType, userObjectID := tuple.SplitObject(userObject)
if userRelation == "" {
for _, f := range req.GetUserFilters() {
if f.GetType() == userObjectType {
user := tuple.StringToUserProto(tuple.BuildObject(userObjectType, userObjectID))
concurrency.TrySendThroughChannel(ctx, foundUser{
user: user,
}, foundUsersChan)
}
}
continue
}
pool.Go(func(ctx context.Context) error {
var resp expandResponse
recoveredError := panics.Try(func() {
resp = l.expandDirectDispatch(ctx, l, req, userObjectType, userObjectID, userRelation, resp, foundUsersChan, &hasCycle)
})
if recoveredError != nil {
resp = panicExpanseResponse(recoveredError)
}
return resp.err
})
}
errs = errors.Join(errs, pool.Wait())
if errs != nil {
telemetry.TraceError(span, errs)
}
return expandResponse{
err: errs,
hasCycle: hasCycle.Load(),
}
}
func expandDirectDispatch(ctx context.Context, l *listUsersQuery, req *internalListUsersRequest, userObjectType string, userObjectID string, userRelation string, resp expandResponse, foundUsersChan chan<- foundUser, hasCycle *atomic.Bool) expandResponse {
rewrittenReq := req.clone()
rewrittenReq.Object = &openfgav1.Object{Type: userObjectType, Id: userObjectID}
rewrittenReq.Relation = userRelation
resp = l.dispatch(ctx, rewrittenReq, foundUsersChan)
if resp.hasCycle {
hasCycle.Store(true)
}
return resp
}
func (l *listUsersQuery) expandIntersection(
ctx context.Context,
req *internalListUsersRequest,
rewrite *openfgav1.Userset_Intersection,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandIntersection")
defer span.End()
pool := concurrency.NewPool(ctx, int(l.resolveNodeBreadthLimit))
childOperands := rewrite.Intersection.GetChild()
intersectionFoundUsersChans := make([]chan foundUser, len(childOperands))
for i, rewrite := range childOperands {
intersectionFoundUsersChans[i] = make(chan foundUser, 1)
pool.Go(func(ctx context.Context) error {
resp := l.expandRewrite(ctx, req, rewrite, intersectionFoundUsersChans[i])
return resp.err
})
}
errChan := make(chan error, 1)
go func() {
err := pool.Wait()
for i := range intersectionFoundUsersChans {
close(intersectionFoundUsersChans[i])
}
errChan <- err
close(errChan)
}()
var mu sync.Mutex
var wg sync.WaitGroup
wg.Add(len(childOperands))
wildcardCount := atomic.Uint32{}
wildcardKey := tuple.TypedPublicWildcard(req.GetUserFilters()[0].GetType())
foundUsersCountMap := make(map[string]uint32, 0)
excludedUsersMap := make(map[string]struct{}, 0)
for _, foundUsersChan := range intersectionFoundUsersChans {
go func(foundUsersChan chan foundUser) {
defer wg.Done()
foundUsersMap := make(map[string]uint32, 0)
for foundUser := range foundUsersChan {
key := tuple.UserProtoToString(foundUser.user)
for _, excludedUser := range foundUser.excludedUsers {
key := tuple.UserProtoToString(excludedUser)
mu.Lock()
excludedUsersMap[key] = struct{}{}
mu.Unlock()
}
if foundUser.relationshipStatus == NoRelationship {
continue
}
foundUsersMap[key]++
}
_, wildcardExists := foundUsersMap[wildcardKey]
if wildcardExists {
wildcardCount.Add(1)
}
for userKey := range foundUsersMap {
mu.Lock()
// Increment the count for a user but decrement if a wildcard
// also exists to prevent double counting. This ensures accurate
// tracking for intersection criteria, avoiding inflated counts
// when both a user and a wildcard are present.
foundUsersCountMap[userKey]++
if wildcardExists {
foundUsersCountMap[userKey]--
}
mu.Unlock()
}
}(foundUsersChan)
}
wg.Wait()
excludedUsers := []*openfgav1.User{}
for key := range excludedUsersMap {
excludedUsers = append(excludedUsers, tuple.StringToUserProto(key))
}
for key, count := range foundUsersCountMap {
// Compare the number of times the specific user was returned across
// all intersection operands, plus the number of wildcards.
// If this summed value equals the number of operands, the user satisfies
// the intersection expression and can be sent on `foundUsersChan`.
if (count + wildcardCount.Load()) == uint32(len(childOperands)) {
fu := foundUser{
user: tuple.StringToUserProto(key),
excludedUsers: excludedUsers,
}
concurrency.TrySendThroughChannel(ctx, fu, foundUsersChan)
}
}
return expandResponse{
err: <-errChan,
}
}
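// Worked example: with two operands, a user returned by both gets count 2,
// which equals len(childOperands), so it is sent. If operand A returns only
// the typed wildcard and operand B returns "user:anne", then wildcardCount=1
// and anne's count=1, so 1+1 == 2 still admits anne, while the wildcard
// itself (count 0 after the decrement) is not sent.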
func (l *listUsersQuery) expandUnion(
ctx context.Context,
req *internalListUsersRequest,
rewrite *openfgav1.Userset_Union,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandUnion")
defer span.End()
pool := concurrency.NewPool(ctx, int(l.resolveNodeBreadthLimit))
childOperands := rewrite.Union.GetChild()
unionFoundUsersChans := make([]chan foundUser, len(childOperands))
for i, rewrite := range childOperands {
unionFoundUsersChans[i] = make(chan foundUser, 1)
pool.Go(func(ctx context.Context) error {
resp := l.expandRewrite(ctx, req, rewrite, unionFoundUsersChans[i])
return resp.err
})
}
errChan := make(chan error, 1)
go func() {
err := pool.Wait()
for i := range unionFoundUsersChans {
close(unionFoundUsersChans[i])
}
errChan <- err
close(errChan)
}()
var mu sync.Mutex
var wg sync.WaitGroup
wg.Add(len(childOperands))
foundUsersMap := make(map[string]struct{}, 0)
excludedUsersCountMap := make(map[string]uint32, 0)
for _, foundUsersChan := range unionFoundUsersChans {
go func(foundUsersChan chan foundUser) {
defer wg.Done()
for foundUser := range foundUsersChan {
key := tuple.UserProtoToString(foundUser.user)
for _, excludedUser := range foundUser.excludedUsers {
key := tuple.UserProtoToString(excludedUser)
mu.Lock()
excludedUsersCountMap[key]++
mu.Unlock()
}
if foundUser.relationshipStatus == NoRelationship {
continue
}
mu.Lock()
foundUsersMap[key] = struct{}{}
mu.Unlock()
}
}(foundUsersChan)
}
wg.Wait()
excludedUsers := []*openfgav1.User{}
for key, count := range excludedUsersCountMap {
if count == uint32(len(childOperands)) {
excludedUsers = append(excludedUsers, tuple.StringToUserProto(key))
}
}
for key := range foundUsersMap {
fu := foundUser{
user: tuple.StringToUserProto(key),
excludedUsers: excludedUsers,
}
concurrency.TrySendThroughChannel(ctx, fu, foundUsersChan)
}
return expandResponse{
err: <-errChan,
}
}
func (l *listUsersQuery) expandExclusion(
ctx context.Context,
req *internalListUsersRequest,
rewrite *openfgav1.Userset_Difference,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandExclusion")
defer span.End()
baseFoundUsersCh := make(chan foundUser, 1)
subtractFoundUsersCh := make(chan foundUser, 1)
var baseError error
go func() {
resp := l.expandRewrite(ctx, req, rewrite.Difference.GetBase(), baseFoundUsersCh)
baseError = resp.err
close(baseFoundUsersCh)
}()
var subtractError error
var subtractHasCycle bool
go func() {
resp := l.expandRewrite(ctx, req, rewrite.Difference.GetSubtract(), subtractFoundUsersCh)
subtractError = resp.err
subtractHasCycle = resp.hasCycle
close(subtractFoundUsersCh)
}()
baseFoundUsersMap := make(map[string]foundUser, 0)
for fu := range baseFoundUsersCh {
key := tuple.UserProtoToString(fu.user)
baseFoundUsersMap[key] = fu
}
subtractFoundUsersMap := make(map[string]foundUser, len(baseFoundUsersMap))
for fu := range subtractFoundUsersCh {
key := tuple.UserProtoToString(fu.user)
subtractFoundUsersMap[key] = fu
}
if subtractHasCycle {
// Exclusion is the only rewrite with bespoke cycle handling; everywhere
// else a cycle is treated as a falsy outcome. Once we make a
// determination within the exclusion handler, we can resolve the case
// here and do not need to propagate the existence of a cycle to an
// upstream handler.
return expandResponse{
err: nil,
}
}
wildcardKey := tuple.TypedPublicWildcard(req.GetUserFilters()[0].GetType())
_, baseWildcardExists := baseFoundUsersMap[wildcardKey]
_, subtractWildcardExists := subtractFoundUsersMap[wildcardKey]
for userKey, fu := range baseFoundUsersMap {
subtractedUser, userIsSubtracted := subtractFoundUsersMap[userKey]
_, wildcardSubtracted := subtractFoundUsersMap[wildcardKey]
switch {
case baseWildcardExists:
if !userIsSubtracted && !wildcardSubtracted {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(userKey),
}, foundUsersChan)
}
for subtractedUserKey, subtractedFu := range subtractFoundUsersMap {
if tuple.IsTypedWildcard(subtractedUserKey) {
if !userIsSubtracted {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(userKey),
relationshipStatus: NoRelationship,
}, foundUsersChan)
}
continue
}
if subtractedFu.relationshipStatus == NoRelationship {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(subtractedUserKey),
relationshipStatus: HasRelationship,
}, foundUsersChan)
}
// A user found under the subtracted branch has a negated relationship with
// respect to the base relation and is excluded, since a wildcard is contained
// under the base branch.
if subtractedFu.relationshipStatus == HasRelationship {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(subtractedUserKey),
relationshipStatus: NoRelationship,
excludedUsers: []*openfgav1.User{
tuple.StringToUserProto(subtractedUserKey),
},
}, foundUsersChan)
}
}
case subtractWildcardExists, userIsSubtracted:
if subtractedUser.relationshipStatus == HasRelationship {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(userKey),
relationshipStatus: NoRelationship,
}, foundUsersChan)
}
if subtractedUser.relationshipStatus == NoRelationship {
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(userKey),
relationshipStatus: HasRelationship,
}, foundUsersChan)
}
default:
concurrency.TrySendThroughChannel(ctx, foundUser{
user: tuple.StringToUserProto(userKey),
relationshipStatus: fu.relationshipStatus,
}, foundUsersChan)
}
}
errs := errors.Join(baseError, subtractError)
if errs != nil {
telemetry.TraceError(span, errs)
}
return expandResponse{
err: errs,
}
}
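// Worked example: for a rewrite like `viewer but not blocked`, a base result
// of {user:anne, user:bob} and a subtract result of {user:bob} (bob having
// HasRelationship under the subtracted branch) sends anne with her original
// status and bob with NoRelationship, so bob is dropped when results are
// collected.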
func (l *listUsersQuery) expandTTU(
ctx context.Context,
req *internalListUsersRequest,
rewrite *openfgav1.Userset_TupleToUserset,
foundUsersChan chan<- foundUser,
) expandResponse {
ctx, span := tracer.Start(ctx, "expandTTU")
defer span.End()
tuplesetRelation := rewrite.TupleToUserset.GetTupleset().GetRelation()
computedRelation := rewrite.TupleToUserset.GetComputedUserset().GetRelation()
typesys, _ := typesystem.TypesystemFromContext(ctx)
opts := storage.ReadOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.GetConsistency(),
},
}
iter, err := l.datastore.Read(ctx, req.GetStoreId(), storage.ReadFilter{
Object: tuple.ObjectKey(req.GetObject()),
Relation: tuplesetRelation,
}, opts)
if err != nil {
telemetry.TraceError(span, err)
return expandResponse{
err: err,
}
}
defer iter.Stop()
filteredIter := storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(typesys),
)
defer filteredIter.Stop()
pool := concurrency.NewPool(ctx, int(l.resolveNodeBreadthLimit))
var errs error
LoopOnIterator:
for {
tupleKey, err := filteredIter.Next(ctx)
if err != nil {
if !errors.Is(err, storage.ErrIteratorDone) {
errs = errors.Join(errs, err)
}
break LoopOnIterator
}
cond, _ := typesys.GetCondition(tupleKey.GetCondition().GetName())
condMet, err := eval.EvaluateTupleCondition(ctx, tupleKey, cond, req.Context)
if err != nil {
errs = errors.Join(errs, err)
if !errors.Is(err, condition.ErrEvaluationFailed) {
break LoopOnIterator
}
telemetry.TraceError(span, err)
}
if !condMet {
continue
}
userObject := tupleKey.GetUser()
userObjectType, userObjectID := tuple.SplitObject(userObject)
pool.Go(func(ctx context.Context) error {
rewrittenReq := req.clone()
rewrittenReq.Object = &openfgav1.Object{Type: userObjectType, Id: userObjectID}
rewrittenReq.Relation = computedRelation
resp := l.dispatch(ctx, rewrittenReq, foundUsersChan)
return resp.err
})
}
errs = errors.Join(pool.Wait(), errs)
if errs != nil {
telemetry.TraceError(span, errs)
}
return expandResponse{
err: errs,
}
}
// enteredCycle reports whether the request has already visited the given
// object#relation pair, recording it as visited on first encounter.
func enteredCycle(req *internalListUsersRequest) bool {
key := fmt.Sprintf("%s#%s", tuple.ObjectKey(req.GetObject()), req.Relation)
if _, loaded := req.visitedUsersetsMap[key]; loaded {
return true
}
req.visitedUsersetsMap[key] = struct{}{}
return false
}
// buildResultsChannel sizes the results channel to the query's configured
// maximum number of results, defaulting to the server-wide limit when unset.
func (l *listUsersQuery) buildResultsChannel() chan foundUser {
foundUsersCh := make(chan foundUser, serverconfig.DefaultListUsersMaxResults)
maxResults := l.maxResults
if maxResults > 0 {
foundUsersCh = make(chan foundUser, maxResults)
}
return foundUsersCh
}
func panicError(recovered *panics.Recovered) error {
return fmt.Errorf("%w: %s", ErrPanic, recovered.AsError())
}
func panicExpanseResponse(recovered *panics.Recovered) expandResponse {
return expandResponse{
hasCycle: false,
err: panicError(recovered),
}
}
package listusers
import (
"context"
"errors"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/validation"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/typesystem"
)
func ValidateListUsersRequest(ctx context.Context, req *openfgav1.ListUsersRequest, typesys *typesystem.TypeSystem) error {
_, span := tracer.Start(ctx, "validateListUsersRequest")
defer span.End()
if err := validateContextualTuples(req, typesys); err != nil {
return err
}
if err := validateUsersFilters(req, typesys); err != nil {
return err
}
return validateTargetRelation(req, typesys)
}
func validateContextualTuples(request *openfgav1.ListUsersRequest, typeSystem *typesystem.TypeSystem) error {
for _, contextualTuple := range request.GetContextualTuples() {
if err := validation.ValidateTupleForWrite(typeSystem, contextualTuple); err != nil {
return serverErrors.HandleTupleValidateError(err)
}
}
return nil
}
func validateUsersFilters(request *openfgav1.ListUsersRequest, typeSystem *typesystem.TypeSystem) error {
for _, userFilter := range request.GetUserFilters() {
if err := validateUserFilter(typeSystem, userFilter); err != nil {
return err
}
}
return nil
}
func validateUserFilter(typeSystem *typesystem.TypeSystem, usersFilter *openfgav1.UserTypeFilter) error {
filterObjectType := usersFilter.GetType()
if _, typeExists := typeSystem.GetTypeDefinition(filterObjectType); !typeExists {
return serverErrors.TypeNotFound(filterObjectType)
}
return validateUserFilterRelation(typeSystem, usersFilter, filterObjectType)
}
func validateUserFilterRelation(typeSystem *typesystem.TypeSystem, usersFilter *openfgav1.UserTypeFilter, filterObjectType string) error {
filterObjectRelation := usersFilter.GetRelation()
if filterObjectRelation == "" {
return nil
}
_, err := typeSystem.GetRelation(filterObjectType, filterObjectRelation)
if err == nil {
return nil
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return serverErrors.RelationNotFound(filterObjectRelation, filterObjectType, nil)
}
return serverErrors.HandleError("", err)
}
func validateTargetRelation(request *openfgav1.ListUsersRequest, typeSystem *typesystem.TypeSystem) error {
objectType := request.GetObject().GetType()
targetRelation := request.GetRelation()
_, err := typeSystem.GetRelation(objectType, targetRelation)
if err == nil {
return nil
}
if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
return serverErrors.TypeNotFound(objectType)
}
if errors.Is(err, typesystem.ErrRelationUndefined) {
return serverErrors.RelationNotFound(targetRelation, objectType, nil)
}
return serverErrors.HandleError("", err)
}
package commands
import (
"context"
"fmt"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
// A ReadQuery can be used to read one or many tuplesets.
// Each tupleset specifies keys of a set of relation tuples.
// The set can include a single tuple key, or all tuples with
// a given object ID or userset in a type, optionally
// constrained by a relation name.
type ReadQuery struct {
datastore storage.OpenFGADatastore
logger logger.Logger
encoder encoder.Encoder
tokenSerializer encoder.ContinuationTokenSerializer
}
type ReadQueryOption func(*ReadQuery)
func WithReadQueryLogger(l logger.Logger) ReadQueryOption {
return func(rq *ReadQuery) {
rq.logger = l
}
}
func WithReadQueryEncoder(e encoder.Encoder) ReadQueryOption {
return func(rq *ReadQuery) {
rq.encoder = e
}
}
func WithReadQueryTokenSerializer(serializer encoder.ContinuationTokenSerializer) ReadQueryOption {
return func(rq *ReadQuery) {
rq.tokenSerializer = serializer
}
}
// NewReadQuery creates a ReadQuery using the provided OpenFGA datastore implementation.
func NewReadQuery(datastore storage.OpenFGADatastore, opts ...ReadQueryOption) *ReadQuery {
rq := &ReadQuery{
datastore: datastore,
logger: logger.NewNoopLogger(),
encoder: encoder.NewBase64Encoder(),
tokenSerializer: encoder.NewStringContinuationTokenSerializer(),
}
for _, opt := range opts {
opt(rq)
}
return rq
}
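// Illustrative sketch (not part of the original source): constructing a
// ReadQuery with explicit options. The option values shown here match the
// defaults that NewReadQuery already applies.
func exampleNewReadQuery(ds storage.OpenFGADatastore) *ReadQuery {
return NewReadQuery(
ds,
WithReadQueryLogger(logger.NewNoopLogger()),
WithReadQueryEncoder(encoder.NewBase64Encoder()),
)
}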
// Execute runs the ReadQuery, returning paginated `openfga.Tuple`(s) that match the request's tuple key. All
// tuples are returned when the tuple key is nil or empty.
func (q *ReadQuery) Execute(ctx context.Context, req *openfgav1.ReadRequest) (*openfgav1.ReadResponse, error) {
store := req.GetStoreId()
tk := req.GetTupleKey()
// Restrict our reads due to some compatibility issues in one of our storage implementations.
if tk != nil {
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
if objectType == "" || (objectID == "" && tk.GetUser() == "") {
return nil, serverErrors.ValidationError(
fmt.Errorf("the 'tuple_key' field was provided but the object type field is required and both the object id and user cannot be empty"),
)
}
}
decodedContToken, err := q.encoder.Decode(req.GetContinuationToken())
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
if len(decodedContToken) > 0 {
from, _, err := q.tokenSerializer.Deserialize(string(decodedContToken))
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
decodedContToken = []byte(from)
}
opts := storage.ReadPageOptions{
Pagination: storage.NewPaginationOptions(req.GetPageSize().GetValue(), string(decodedContToken)),
Consistency: storage.ConsistencyOptions{Preference: req.GetConsistency()},
}
filter := storage.ReadFilter{}
if tk != nil {
filter = storage.ReadFilter{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
tuples, contUlid, err := q.datastore.ReadPage(ctx, store, filter, opts)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
if len(contUlid) == 0 {
return &openfgav1.ReadResponse{
Tuples: tuples,
ContinuationToken: "",
}, nil
}
contToken, err := q.tokenSerializer.Serialize(contUlid, "")
if err != nil {
return nil, serverErrors.HandleError("", err)
}
encodedContToken, err := q.encoder.Encode(contToken)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.ReadResponse{
Tuples: tuples,
ContinuationToken: encodedContToken,
}, nil
}
package commands
import (
"context"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type ReadAssertionsQuery struct {
backend storage.AssertionsBackend
logger logger.Logger
}
type ReadAssertionsQueryOption func(*ReadAssertionsQuery)
func WithReadAssertionsQueryLogger(l logger.Logger) ReadAssertionsQueryOption {
return func(rq *ReadAssertionsQuery) {
rq.logger = l
}
}
func NewReadAssertionsQuery(backend storage.AssertionsBackend, opts ...ReadAssertionsQueryOption) *ReadAssertionsQuery {
rq := &ReadAssertionsQuery{
backend: backend,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(rq)
}
return rq
}
func (q *ReadAssertionsQuery) Execute(ctx context.Context, store, authorizationModelID string) (*openfgav1.ReadAssertionsResponse, error) {
assertions, err := q.backend.ReadAssertions(ctx, store, authorizationModelID)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.ReadAssertionsResponse{
AuthorizationModelId: authorizationModelID,
Assertions: assertions,
}, nil
}
package commands
import (
"context"
"errors"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
// ReadAuthorizationModelQuery retrieves a single authorization model from a storage backend.
type ReadAuthorizationModelQuery struct {
backend storage.AuthorizationModelReadBackend
logger logger.Logger
}
type ReadAuthModelQueryOption func(*ReadAuthorizationModelQuery)
func WithReadAuthModelQueryLogger(l logger.Logger) ReadAuthModelQueryOption {
return func(m *ReadAuthorizationModelQuery) {
m.logger = l
}
}
func NewReadAuthorizationModelQuery(backend storage.AuthorizationModelReadBackend, opts ...ReadAuthModelQueryOption) *ReadAuthorizationModelQuery {
m := &ReadAuthorizationModelQuery{
backend: backend,
logger: logger.NewNoopLogger(),
}
for _, opt := range opts {
opt(m)
}
return m
}
func (q *ReadAuthorizationModelQuery) Execute(ctx context.Context, req *openfgav1.ReadAuthorizationModelRequest) (*openfgav1.ReadAuthorizationModelResponse, error) {
modelID := req.GetId()
azm, err := q.backend.ReadAuthorizationModel(ctx, req.GetStoreId(), modelID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, serverErrors.AuthorizationModelNotFound(modelID)
}
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.ReadAuthorizationModelResponse{
AuthorizationModel: azm,
}, nil
}
package commands
import (
"context"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type ReadAuthorizationModelsQuery struct {
backend storage.AuthorizationModelReadBackend
logger logger.Logger
encoder encoder.Encoder
}
type ReadAuthModelsQueryOption func(*ReadAuthorizationModelsQuery)
func WithReadAuthModelsQueryLogger(l logger.Logger) ReadAuthModelsQueryOption {
return func(rm *ReadAuthorizationModelsQuery) {
rm.logger = l
}
}
func WithReadAuthModelsQueryEncoder(e encoder.Encoder) ReadAuthModelsQueryOption {
return func(rm *ReadAuthorizationModelsQuery) {
rm.encoder = e
}
}
func NewReadAuthorizationModelsQuery(backend storage.AuthorizationModelReadBackend, opts ...ReadAuthModelsQueryOption) *ReadAuthorizationModelsQuery {
rm := &ReadAuthorizationModelsQuery{
backend: backend,
logger: logger.NewNoopLogger(),
encoder: encoder.NewBase64Encoder(),
}
for _, opt := range opts {
opt(rm)
}
return rm
}
func (q *ReadAuthorizationModelsQuery) Execute(ctx context.Context, req *openfgav1.ReadAuthorizationModelsRequest) (*openfgav1.ReadAuthorizationModelsResponse, error) {
decodedContToken, err := q.encoder.Decode(req.GetContinuationToken())
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
opts := storage.ReadAuthorizationModelsOptions{
Pagination: storage.NewPaginationOptions(req.GetPageSize().GetValue(), string(decodedContToken)),
}
models, contToken, err := q.backend.ReadAuthorizationModels(ctx, req.GetStoreId(), opts)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
encodedContToken, err := q.encoder.Encode([]byte(contToken))
if err != nil {
return nil, serverErrors.HandleError("", err)
}
resp := &openfgav1.ReadAuthorizationModelsResponse{
AuthorizationModels: models,
ContinuationToken: encodedContToken,
}
return resp, nil
}
package commands
import (
"context"
"errors"
"time"
"github.com/oklog/ulid/v2"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
)
type ReadChangesQuery struct {
backend storage.ChangelogBackend
logger logger.Logger
encoder encoder.Encoder
tokenSerializer encoder.ContinuationTokenSerializer
horizonOffset time.Duration
}
type ReadChangesQueryOption func(*ReadChangesQuery)
func WithReadChangesQueryLogger(l logger.Logger) ReadChangesQueryOption {
return func(rq *ReadChangesQuery) {
rq.logger = l
}
}
func WithReadChangesQueryEncoder(e encoder.Encoder) ReadChangesQueryOption {
return func(rq *ReadChangesQuery) {
rq.encoder = e
}
}
// WithReadChangeQueryHorizonOffset specifies the horizon offset duration, in minutes.
func WithReadChangeQueryHorizonOffset(horizonOffset int) ReadChangesQueryOption {
return func(rq *ReadChangesQuery) {
rq.horizonOffset = time.Duration(horizonOffset) * time.Minute
}
}
// WithContinuationTokenSerializer specifies the token serializer to be used.
func WithContinuationTokenSerializer(tokenSerializer encoder.ContinuationTokenSerializer) ReadChangesQueryOption {
return func(rq *ReadChangesQuery) {
rq.tokenSerializer = tokenSerializer
}
}
// NewReadChangesQuery creates a ReadChangesQuery with specified `ChangelogBackend`.
func NewReadChangesQuery(backend storage.ChangelogBackend, opts ...ReadChangesQueryOption) *ReadChangesQuery {
rq := &ReadChangesQuery{
backend: backend,
logger: logger.NewNoopLogger(),
encoder: encoder.NewBase64Encoder(),
horizonOffset: time.Duration(serverconfig.DefaultChangelogHorizonOffset) * time.Minute,
tokenSerializer: encoder.NewStringContinuationTokenSerializer(),
}
for _, opt := range opts {
opt(rq)
}
return rq
}
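// Illustrative sketch (not part of the original source): constructing a
// ReadChangesQuery with a five-minute horizon offset and the string token
// serializer, both via the options defined above.
func exampleNewReadChangesQuery(backend storage.ChangelogBackend) *ReadChangesQuery {
return NewReadChangesQuery(
backend,
WithReadChangeQueryHorizonOffset(5),
WithContinuationTokenSerializer(encoder.NewStringContinuationTokenSerializer()),
)
}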
// Execute the ReadChangesQuery, returning paginated `openfga.TupleChange`(s) and a possibly non-empty continuation token.
func (q *ReadChangesQuery) Execute(ctx context.Context, req *openfgav1.ReadChangesRequest) (*openfgav1.ReadChangesResponse, error) {
decodedContToken, err := q.encoder.Decode(req.GetContinuationToken())
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
token := string(decodedContToken)
var fromUlid string
var startTime time.Time
if req.GetStartTime() != nil {
startTime = req.GetStartTime().AsTime()
}
if token != "" {
var objType string
fromUlid, objType, err = q.tokenSerializer.Deserialize(token)
if err != nil {
return nil, serverErrors.ErrInvalidContinuationToken
}
if objType != req.GetType() {
return nil, serverErrors.ErrMismatchObjectType
}
} else if !startTime.IsZero() {
tokenUlid, ulidErr := ulid.New(ulid.Timestamp(startTime), nil)
if ulidErr != nil {
return nil, serverErrors.HandleError(ulidErr.Error(), storage.ErrInvalidStartTime)
}
fromUlid = tokenUlid.String()
}
opts := storage.ReadChangesOptions{
Pagination: storage.NewPaginationOptions(
req.GetPageSize().GetValue(),
fromUlid,
),
}
filter := storage.ReadChangesFilter{
ObjectType: req.GetType(),
HorizonOffset: q.horizonOffset,
}
changes, contUlid, err := q.backend.ReadChanges(ctx, req.GetStoreId(), filter, opts)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return &openfgav1.ReadChangesResponse{
ContinuationToken: req.GetContinuationToken(),
}, nil
}
return nil, serverErrors.HandleError("", err)
}
if len(contUlid) == 0 {
return &openfgav1.ReadChangesResponse{
Changes: changes,
ContinuationToken: "",
}, nil
}
contToken, err := q.tokenSerializer.Serialize(contUlid, req.GetType())
if err != nil {
return nil, serverErrors.HandleError("", err)
}
encodedContToken, err := q.encoder.Encode(contToken)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.ReadChangesResponse{
Changes: changes,
ContinuationToken: encodedContToken,
}, nil
}
package reverseexpand
import (
"context"
"iter"
"sync"
)
type message[T any] struct {
Value T
finite func()
}
func (m *message[T]) done() {
if m.finite != nil {
m.finite()
}
}
type producer[T any] interface {
recv(context.Context) (message[T], bool)
seq(context.Context) iter.Seq[message[T]]
}
type consumer[T any] interface {
send(T)
close()
cancel()
}
const maxPipeSize int = 100
// pipe is a bounded FIFO, guarded by condition variables, that connects the
// producing end of one pipeline connection to its consuming end. send blocks
// while the buffer is full; recv blocks while it is empty.
type pipe struct {
data [maxPipeSize]group
head int
tail int
count int
done bool
mu sync.Mutex
full *sync.Cond
empty *sync.Cond
closed *sync.Cond
trk tracker
}
func newPipe(trk tracker) *pipe {
p := pipe{
trk: trk,
}
p.full = sync.NewCond(&p.mu)
p.empty = sync.NewCond(&p.mu)
p.closed = sync.NewCond(&p.mu)
return &p
}
func (p *pipe) seq(ctx context.Context) iter.Seq[message[group]] {
return func(yield func(message[group]) bool) {
defer p.cancel()
for {
msg, ok := p.recv(ctx)
if !ok {
break
}
if !yield(msg) {
break
}
}
}
}
func (p *pipe) send(item group) {
p.mu.Lock()
defer p.mu.Unlock()
if p.done {
return
}
p.trk.Add(1)
// Wait if the buffer is full.
for p.count == maxPipeSize && !p.done {
p.full.Wait()
}
if p.done {
p.trk.Add(-1)
return
}
p.data[p.head] = item
p.head = (p.head + 1) % maxPipeSize
p.count++
// Signal that the buffer is no longer empty.
p.empty.Signal()
}
func (p *pipe) recv(ctx context.Context) (message[group], bool) {
p.mu.Lock()
// Wait while the buffer is empty and the pipe is not yet done.
for p.count == 0 && !p.done && ctx.Err() == nil {
p.empty.Wait()
}
if (p.count == 0 && p.done) || ctx.Err() != nil {
p.mu.Unlock()
return message[group]{}, false
}
item := p.data[p.tail]
p.tail = (p.tail + 1) % maxPipeSize
p.count--
// Signal that the buffer is no longer full.
p.full.Signal()
if p.count == 0 {
p.closed.Broadcast()
}
p.mu.Unlock()
fn := func() {
p.trk.Add(-1)
}
return message[group]{Value: item, finite: sync.OnceFunc(fn)}, true
}
func (p *pipe) close() {
p.mu.Lock()
defer p.mu.Unlock()
p.done = true
p.empty.Broadcast()
p.full.Broadcast()
for p.count > 0 {
p.closed.Wait()
}
}
func (p *pipe) cancel() {
p.mu.Lock()
if p.done {
p.mu.Unlock()
return
}
p.done = true
p.empty.Broadcast()
p.full.Broadcast()
p.mu.Unlock()
m, ok := p.recv(context.Background())
for ok {
m.done()
m, ok = p.recv(context.Background())
}
}
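// examplePipeRoundTrip is an illustrative sketch (not part of the original
// source): one message travels through a pipe, using an echo tracker with no
// parent to count in-flight messages. Calling done() on each received message
// drains the tracker back to zero, which is what allows close() to return.
func examplePipeRoundTrip() {
trk := newEchoTracker(nil)
p := newPipe(trk)
go func() {
p.send(group{Items: []Item{{Value: "document:1"}}})
p.close() // blocks until the buffered message has been received
}()
for msg := range p.seq(context.Background()) {
_ = msg.Value.Items // consume the group
msg.done()          // decrement the tracker
}
}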
type staticProducer struct {
mu sync.Mutex
groups []group
pos int
trk tracker
}
func (p *staticProducer) recv(ctx context.Context) (message[group], bool) {
p.mu.Lock()
defer p.mu.Unlock()
if ctx.Err() != nil {
return message[group]{}, false
}
if p.pos == len(p.groups) {
return message[group]{}, false
}
value := p.groups[p.pos]
p.pos++
fn := func() {
p.trk.Add(-1)
}
return message[group]{Value: value, finite: sync.OnceFunc(fn)}, true
}
func (p *staticProducer) close() {
p.mu.Lock()
defer p.mu.Unlock()
p.pos = len(p.groups)
}
func (p *staticProducer) seq(ctx context.Context) iter.Seq[message[group]] {
return func(yield func(message[group]) bool) {
defer p.close()
for {
msg, ok := p.recv(ctx)
if !ok {
break
}
if !yield(msg) {
break
}
}
}
}
func newStaticProducer(trk tracker, groups ...group) producer[group] {
if trk != nil {
trk.Add(int64(len(groups)))
}
return &staticProducer{
groups: groups,
trk: trk,
}
}
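// exampleStaticProducer is an illustrative sketch (not part of the original
// source): newStaticProducer pre-increments the tracker once per group, and
// each message's done() releases its share.
func exampleStaticProducer() {
trk := newEchoTracker(nil)
prod := newStaticProducer(trk, group{Items: []Item{{Value: "document:1"}}})
for msg := range prod.seq(context.Background()) {
_ = msg.Value // consume the group
msg.done()    // decrement the tracker
}
_ = trk.Load() // 0 once every message has been marked done
}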
package reverseexpand
import (
"context"
"errors"
"iter"
"maps"
"runtime"
"strings"
"sync"
"sync/atomic"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
weightedGraph "github.com/openfga/language/pkg/go/graph"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/seq"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/typesystem"
)
type (
Edge = weightedGraph.WeightedAuthorizationModelEdge
Graph = weightedGraph.WeightedAuthorizationModelGraph
Node = weightedGraph.WeightedAuthorizationModelNode
)
var (
pipelineTracer = otel.Tracer("pipeline")
edgeTypeComputed = weightedGraph.ComputedEdge
edgeTypeDirect = weightedGraph.DirectEdge
edgeTypeDirectLogical = weightedGraph.DirectLogicalEdge
edgeTypeRewrite = weightedGraph.RewriteEdge
edgeTypeTTU = weightedGraph.TTUEdge
edgeTypeTTULogical = weightedGraph.TTULogicalEdge
// emptySequence represents an `iter.Seq[Item]` that does nothing.
emptySequence = func(yield func(Item) bool) {}
nodeTypeLogicalDirectGrouping = weightedGraph.LogicalDirectGrouping
nodeTypeLogicalTTUGrouping = weightedGraph.LogicalTTUGrouping
nodeTypeOperator = weightedGraph.OperatorNode
nodeTypeSpecificType = weightedGraph.SpecificType
nodeTypeSpecificTypeAndRelation = weightedGraph.SpecificTypeAndRelation
nodeTypeSpecificTypeWildcard = weightedGraph.SpecificTypeWildcard
)
func handleIdentity(_ context.Context, _ *Edge, items []Item) iter.Seq[Item] {
return seq.Sequence(items...)
}
func handleLeafNode(node *Node) edgeHandler {
return func(_ context.Context, _ *Edge, items []Item) iter.Seq[Item] {
objectParts := strings.Split(node.GetLabel(), "#")
if len(objectParts) < 1 {
return seq.Sequence(Item{Err: errors.New("empty label in node")})
}
objectType := objectParts[0]
results := seq.Transform(seq.Sequence(items...), func(item Item) Item {
var value string
switch node.GetNodeType() {
case nodeTypeSpecificTypeWildcard:
value = ""
case nodeTypeSpecificType, nodeTypeSpecificTypeAndRelation:
value = ":" + item.Value
default:
return Item{Err: errors.New("unsupported leaf node type")}
}
item.Value = objectType + value
return item
})
return results
}
}
func handleUnsupported(_ context.Context, _ *Edge, _ []Item) iter.Seq[Item] {
return seq.Sequence(Item{Err: errors.New("unsupported state")})
}
func NewPipeline(backend *Backend, options ...PipelineOption) *Pipeline {
p := &Pipeline{
backend: backend,
chunkSize: 100,
numProcs: 3,
}
for _, option := range options {
option(p)
}
return p
}
func WithChunkSize(size int) PipelineOption {
if size < 0 {
size = 0
}
return func(p *Pipeline) {
p.chunkSize = size
}
}
func WithNumProcs(num int) PipelineOption {
if num < 1 {
num = 1
}
return func(p *Pipeline) {
p.numProcs = num
}
}
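// Illustrative sketch (not part of the original source): constructing a
// Pipeline that emits groups of at most 50 items and runs four processors
// per sender, assuming the caller has already assembled a Backend.
func examplePipelineConstruction(backend *Backend) *Pipeline {
return NewPipeline(backend, WithChunkSize(50), WithNumProcs(4))
}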
// Backend is a struct that serves as a container for all backend elements
// necessary for creating and running a `Pipeline`.
type Backend struct {
Datastore storage.RelationshipTupleReader
StoreID string
TypeSystem *typesystem.TypeSystem
Context *structpb.Struct
Graph *Graph
Preference openfgav1.ConsistencyPreference
}
// handleDirectEdge is a function that interprets input on a direct edge and provides output from
// a query to the backend datastore.
func (b *Backend) handleDirectEdge(ctx context.Context, edge *Edge, items []Item) iter.Seq[Item] {
parts := strings.Split(edge.GetRelationDefinition(), "#")
nodeType := parts[0]
nodeRelation := parts[1]
userParts := strings.Split(edge.GetTo().GetLabel(), "#")
var userRelation string
if len(userParts) > 1 {
userRelation = userParts[1]
}
var userFilter []*openfgav1.ObjectRelation
var errs []Item
for _, item := range items {
if item.Err != nil {
errs = append(errs, item)
continue
}
userFilter = append(userFilter, &openfgav1.ObjectRelation{
Object: item.Value,
Relation: userRelation,
})
}
var results iter.Seq[Item]
if len(userFilter) > 0 {
input := queryInput{
objectType: nodeType,
objectRelation: nodeRelation,
userFilter: userFilter,
conditions: edge.GetConditions(),
}
results = b.query(ctx, input)
} else {
results = emptySequence
}
if len(errs) > 0 {
results = seq.Flatten(seq.Sequence(errs...), results)
}
return results
}
// handleTTUEdge is a function that interprets input on a TTU edge and provides output from
// a query to the backend datastore.
func (b *Backend) handleTTUEdge(ctx context.Context, edge *Edge, items []Item) iter.Seq[Item] {
parts := strings.Split(edge.GetTuplesetRelation(), "#")
if len(parts) < 2 {
return seq.Sequence(Item{Err: errors.New("invalid tupleset relation")})
}
tuplesetType := parts[0]
tuplesetRelation := parts[1]
tuplesetNode, ok := b.Graph.GetNodeByID(edge.GetTuplesetRelation())
if !ok {
return seq.Sequence(Item{Err: errors.New("tupleset node not in graph")})
}
edges, ok := b.Graph.GetEdgesFromNode(tuplesetNode)
if !ok {
return seq.Sequence(Item{Err: errors.New("no edges found for tupleset node")})
}
targetParts := strings.Split(edge.GetTo().GetLabel(), "#")
if len(targetParts) < 1 {
return seq.Sequence(Item{Err: errors.New("empty edge label")})
}
targetType := targetParts[0]
var targetEdge *Edge
for _, e := range edges {
if e.GetTo().GetLabel() == targetType {
targetEdge = e
break
}
}
if targetEdge == nil {
return seq.Sequence(Item{Err: errors.New("ttu target type is not an edge of tupleset")})
}
var userFilter []*openfgav1.ObjectRelation
var errs []Item
for _, item := range items {
if item.Err != nil {
errs = append(errs, item)
continue
}
userFilter = append(userFilter, &openfgav1.ObjectRelation{
Object: item.Value,
Relation: "",
})
}
var results iter.Seq[Item]
if len(userFilter) > 0 {
input := queryInput{
objectType: tuplesetType,
objectRelation: tuplesetRelation,
userFilter: userFilter,
conditions: targetEdge.GetConditions(),
}
results = b.query(ctx, input)
} else {
results = emptySequence
}
if len(errs) > 0 {
results = seq.Flatten(seq.Sequence(errs...), results)
}
return results
}
func (b *Backend) query(ctx context.Context, input queryInput) iter.Seq[Item] {
ctx, cancel := context.WithCancel(ctx)
it, err := b.Datastore.ReadStartingWithUser(
ctx,
b.StoreID,
storage.ReadStartingWithUserFilter{
ObjectType: input.objectType,
Relation: input.objectRelation,
UserFilter: input.userFilter,
Conditions: input.conditions,
},
storage.ReadStartingWithUserOptions{
Consistency: storage.ConsistencyOptions{
Preference: b.Preference,
},
},
)
if err != nil {
cancel()
return seq.Sequence(Item{Err: err})
}
// If more than one element exists, at least one element is guaranteed to be a condition.
// OR
// If only one element exists, and it is not `NoCond`, then it is guaranteed to be a condition.
hasConditions := len(input.conditions) > 1 || (len(input.conditions) > 0 && input.conditions[0] != weightedGraph.NoCond)
var itr storage.TupleKeyIterator
if hasConditions {
itr = storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(it),
validation.FilterInvalidTuples(b.TypeSystem),
),
checkutil.BuildTupleKeyConditionFilter(ctx, b.Context, b.TypeSystem),
)
} else {
itr = storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(it),
validation.FilterInvalidTuples(b.TypeSystem),
)
}
return func(yield func(Item) bool) {
defer cancel()
defer itr.Stop()
for ctx.Err() == nil {
t, err := itr.Next(ctx)
var item Item
if err != nil {
if err == storage.ErrIteratorDone {
break
}
item.Err = err
yield(item)
break
}
if t == nil {
continue
}
item.Value = t.GetObject()
if !yield(item) {
break
}
}
}
}
// baseResolver is a struct that implements the `resolver` interface and acts as the standard resolver for most
// workers. A baseResolver handles both recursive and non-recursive edges concurrently. The baseResolver's "ready"
// status will remain `true` until all of its senders that produce input from external sources have finished, and
// there exist no more in-flight messages for the parent worker. When recursive edges exist, the parent worker for
// this resolver type requires its internal watchdog process to initiate a shutdown.
type baseResolver struct {
// id is an identifier provided when registering with the baseResolver's StatusPool. The registration happens
// once, when the baseResolver is created, so this value will remain constant for the lifetime of the instance.
id int
ctx context.Context
// interpreter is an `interpreter` that transforms a sender's input into output which it broadcasts to all
// of the parent worker's listeners.
interpreter interpreter
// mutexes each protect map access for a buffer within inBuffers at the same index.
mutexes []sync.Mutex
// inBuffers contains a slice of maps, used as hash sets, for deduplicating each individual sender's
// input feed. Each buffer's index corresponds to its associated sender's index. Each sender needs
// a separate deduplication buffer because it is valid for the same object to be received on multiple
// edges producing to the same node. This is specifically true in the case of recursive edges, where
// a single resolver may have multiple recursive edges that must receive the same objects.
inBuffers []map[string]struct{}
errBuffers []map[string]struct{}
// outMu protects map access to outBuffer.
outMu sync.Mutex
// outBuffer contains a map, used as a hash set, for deduplicating each resolver's output. A single
// resolver should only output an object once.
outBuffer map[string]struct{}
// status is a *StatusPool instance that tracks the status of the baseResolver instance. This *StatusPool value
// may or may not be shared with other resolver instances and workers. When the current resolver is part of a
// recursive chain, then this *StatusPool value is shared with each of the participating resolvers. The status
// of a baseResolver is assumed to be `true` from the point of initialization, until all "standard" senders have completed.
// A baseResolver's status can be `false` while "recursive" senders are still actively processing messages. In that
// case, the parent worker is kept alive by the overall status of the *StatusPool instance, and the count of messages
// in-flight.
status *StatusPool
}
func (r *baseResolver) process(ndx int, snd *sender, listeners []*listener) loopFunc {
return func(msg message[group]) bool {
// Loop while the sender has a potential to yield a message.
var results iter.Seq[Item]
var outGroup group
var items, unseen []Item
attrs := []attribute.KeyValue{
attribute.String("resolver", "base"),
attribute.Int("item.count", len(msg.Value.Items)),
attribute.Int("sender.index", ndx),
}
edgeTo := "nil"
edgeFrom := "nil"
if snd.edge() != nil {
edgeTo = snd.edge().GetTo().GetUniqueLabel()
edgeFrom = snd.edge().GetFrom().GetUniqueLabel()
}
attrs = append(
attrs,
attribute.String("edge.to", edgeTo),
attribute.String("edge.from", edgeFrom),
)
ctx, span := pipelineTracer.Start(r.ctx, "message.received", trace.WithAttributes(attrs...))
defer span.End()
// Deduplicate items within this group based on the buffer for this sender
for _, item := range msg.Value.Items {
if item.Err != nil {
r.mutexes[ndx].Lock()
if _, ok := r.errBuffers[ndx][item.Err.Error()]; ok {
r.mutexes[ndx].Unlock()
continue
}
r.errBuffers[ndx][item.Err.Error()] = struct{}{}
r.mutexes[ndx].Unlock()
unseen = append(unseen, item)
continue
}
r.mutexes[ndx].Lock()
if _, ok := r.inBuffers[ndx][item.Value]; ok {
r.mutexes[ndx].Unlock()
continue
}
r.inBuffers[ndx][item.Value] = struct{}{}
r.mutexes[ndx].Unlock()
unseen = append(unseen, item)
}
// If there are no unseen items, skip processing
if len(unseen) == 0 {
msg.done()
return true
}
results = r.interpreter.interpret(ctx, snd.edge(), unseen)
// Deduplicate the output and potentially send in chunks.
for item := range results {
if r.ctx.Err() != nil {
break
}
if item.Err != nil {
goto AfterDedup
}
r.outMu.Lock()
if _, ok := r.outBuffer[item.Value]; ok {
r.outMu.Unlock()
continue
}
r.outBuffer[item.Value] = struct{}{}
r.outMu.Unlock()
AfterDedup:
items = append(items, item)
if len(items) < snd.chunks() || snd.chunks() == 0 {
continue
}
g := group{
Items: items,
}
items = nil
for _, lst := range listeners {
lst.send(g)
}
}
if len(items) == 0 {
msg.done()
return true
}
outGroup.Items = items
for _, lst := range listeners {
lst.send(outGroup)
}
msg.done()
return true
}
}
func (r *baseResolver) resolve(senders []*sender, listeners []*listener) {
r.mutexes = make([]sync.Mutex, len(senders))
r.inBuffers = make([]map[string]struct{}, len(senders))
r.errBuffers = make([]map[string]struct{}, len(senders))
r.outBuffer = make(map[string]struct{})
for ndx := range len(senders) {
r.inBuffers[ndx] = make(map[string]struct{})
r.errBuffers[ndx] = make(map[string]struct{})
}
// Any senders with a non-recursive edge will be processed in the "standard" queue.
var standard []func()
// Any senders with a recursive edge will be processed in the "recursive" queue.
var recursive []func()
for ndx := range len(senders) {
snd := senders[ndx]
// Any sender with an edge that has a value for its recursive relation will be treated
// as recursive, so long as it is not part of a tuple cycle. When the edge is part of
// a tuple cycle, treating it as recursive would cause the parent worker to be closed
// too early.
isRecursive := snd.edge() != nil && len(snd.edge().GetRecursiveRelation()) > 0 && !snd.edge().IsPartOfTupleCycle()
for range snd.procs() {
proc := func() {
snd.loop(r.process(ndx, snd, listeners))
}
if isRecursive {
recursive = append(recursive, proc)
continue
}
standard = append(standard, proc)
}
}
var wgStandard sync.WaitGroup
for _, proc := range standard {
wgStandard.Add(1)
go func() {
defer wgStandard.Done()
proc()
}()
}
var wgRecursive sync.WaitGroup
for _, proc := range recursive {
wgRecursive.Add(1)
go func() {
defer wgRecursive.Done()
proc()
}()
}
// All standard senders are guaranteed to end at some point, unless a tuple
// cycle is present.
wgStandard.Wait()
// Once the standard senders have all finished processing, we set the resolver's status to `false`
// indicating that the parent is ready for cleanup once all messages have finished processing.
r.status.Set(r.id, false)
// Recursive senders will process infinitely until the parent worker's watchdog goroutine kills
// them.
wgRecursive.Wait()
}
type edgeHandler func(context.Context, *Edge, []Item) iter.Seq[Item]
type exclusionResolver struct {
id int
ctx context.Context
interpreter interpreter
status *StatusPool
trk tracker
}
func (r *exclusionResolver) resolve(senders []*sender, listeners []*listener) {
defer func() {
r.trk.Add(-1)
r.status.Set(r.id, false)
}()
r.trk.Add(1)
if len(senders) != 2 {
panic("exclusion resolver requires two senders")
}
var wg sync.WaitGroup
included := make(map[string]struct{})
excluded := make(map[string]struct{})
var includedErrs []Item
var excludedErrs []Item
var procIncluded loopFunc
var procExcluded loopFunc
procIncluded = func(msg message[group]) bool {
attrs := []attribute.KeyValue{
attribute.String("resolver", "exclusion"),
attribute.Int("item.count", len(msg.Value.Items)),
attribute.Int("sender.index", 0),
}
edgeTo := "nil"
edgeFrom := "nil"
if senders[0].edge() != nil {
edgeTo = senders[0].edge().GetTo().GetUniqueLabel()
edgeFrom = senders[0].edge().GetFrom().GetUniqueLabel()
}
attrs = append(
attrs,
attribute.String("edge.to", edgeTo),
attribute.String("edge.from", edgeFrom),
)
ctx, span := pipelineTracer.Start(r.ctx, "message.received", trace.WithAttributes(attrs...))
defer span.End()
results := r.interpreter.interpret(ctx, senders[0].edge(), msg.Value.Items)
for item := range results {
if r.ctx.Err() != nil {
break
}
if item.Err != nil {
includedErrs = append(includedErrs, item)
continue
}
included[item.Value] = struct{}{}
}
msg.done()
return true
}
procExcluded = func(msg message[group]) bool {
attrs := []attribute.KeyValue{
attribute.String("resolver", "exclusion"),
attribute.Int("item.count", len(msg.Value.Items)),
attribute.Int("sender.index", 1),
}
edgeTo := "nil"
edgeFrom := "nil"
if senders[1].edge() != nil {
edgeTo = senders[1].edge().GetTo().GetUniqueLabel()
edgeFrom = senders[1].edge().GetFrom().GetUniqueLabel()
}
attrs = append(
attrs,
attribute.String("edge.to", edgeTo),
attribute.String("edge.from", edgeFrom),
)
ctx, span := pipelineTracer.Start(r.ctx, "message.received", trace.WithAttributes(attrs...))
defer span.End()
results := r.interpreter.interpret(ctx, senders[1].edge(), msg.Value.Items)
for item := range results {
if item.Err != nil {
excludedErrs = append(excludedErrs, item)
continue
}
excluded[item.Value] = struct{}{}
}
msg.done()
return true
}
wg.Add(1)
go func() {
defer wg.Done()
senders[0].loop(procIncluded)
}()
wg.Add(1)
go func() {
defer wg.Done()
senders[1].loop(procExcluded)
}()
wg.Wait()
var allErrs []Item
allErrs = append(allErrs, includedErrs...)
allErrs = append(allErrs, excludedErrs...)
filteredSeq := seq.Filter(maps.Keys(included), func(v string) bool {
_, ok := excluded[v]
return !ok
})
flattenedSeq := seq.Flatten(seq.Sequence(allErrs...), seq.Transform(filteredSeq, strToItem))
var items []Item
for item := range flattenedSeq {
items = append(items, item)
}
outGroup := group{
Items: items,
}
for _, lst := range listeners {
lst.send(outGroup)
}
}
// group is a struct that acts as a container for a set of `Item` values.
type group struct {
Items []Item
}
// Item is a struct that contains an object `string` as its `Value` or an
// encountered error as its `Err`. Item is the primary container used to
// communicate values as they pass through a `Pipeline`.
type Item struct {
Value string
Err error
}
// interpreter is an interface that exposes a method for interpreting input for an edge into output.
type interpreter interface {
interpret(ctx context.Context, edge *Edge, items []Item) iter.Seq[Item]
}
type intersectionResolver struct {
id int
ctx context.Context
interpreter interpreter
status *StatusPool
trk tracker
}
func (r *intersectionResolver) resolve(senders []*sender, listeners []*listener) {
defer func() {
r.trk.Add(-1)
r.status.Set(r.id, false)
}()
r.trk.Add(1)
var wg sync.WaitGroup
objects := make(map[string]struct{})
buffers := make([]map[string]struct{}, len(senders))
output := make([]map[string]struct{}, len(senders))
for i := range senders {
buffers[i] = make(map[string]struct{})
output[i] = make(map[string]struct{})
}
errs := make([][]Item, len(senders))
for i, snd := range senders {
wg.Add(1)
go func(i int, snd *sender) {
defer wg.Done()
snd.loop(func(msg message[group]) bool {
attrs := []attribute.KeyValue{
attribute.String("resolver", "intersection"),
attribute.Int("item.count", len(msg.Value.Items)),
attribute.Int("sender.index", i),
}
edgeTo := "nil"
edgeFrom := "nil"
if snd.edge() != nil {
edgeTo = snd.edge().GetTo().GetUniqueLabel()
edgeFrom = snd.edge().GetFrom().GetUniqueLabel()
}
attrs = append(
attrs,
attribute.String("edge.to", edgeTo),
attribute.String("edge.from", edgeFrom),
)
ctx, span := pipelineTracer.Start(r.ctx, "message.received", trace.WithAttributes(attrs...))
defer span.End()
var unseen []Item
// Deduplicate items within this group based on the buffer for this sender
for _, item := range msg.Value.Items {
if item.Err != nil {
unseen = append(unseen, item)
continue
}
if _, ok := buffers[i][item.Value]; ok {
continue
}
unseen = append(unseen, item)
buffers[i][item.Value] = struct{}{}
}
// If there are no unseen items, skip processing
if len(unseen) == 0 {
msg.done()
return true
}
results := r.interpreter.interpret(ctx, snd.edge(), unseen)
for item := range results {
if r.ctx.Err() != nil {
break
}
if item.Err != nil {
errs[i] = append(errs[i], item)
}
output[i][item.Value] = struct{}{}
}
msg.done()
return true
})
}(i, snd)
}
wg.Wait()
OutputLoop:
for obj := range output[0] {
for i := 1; i < len(output); i++ {
if _, ok := output[i][obj]; !ok {
continue OutputLoop
}
}
objects[obj] = struct{}{}
}
var allErrs []Item
for _, errList := range errs {
allErrs = append(allErrs, errList...)
}
flattenedSeq := seq.Flatten(seq.Sequence(allErrs...), seq.Transform(maps.Keys(objects), strToItem))
var items []Item
for item := range flattenedSeq {
items = append(items, item)
}
outGroup := group{
Items: items,
}
for _, lst := range listeners {
lst.send(outGroup)
}
}
// listener is a struct that contains fields relevant to the listening
// end of a pipeline connection.
type listener struct {
cons consumer[group]
// node is the weighted graph node that is listening.
node *Node
}
func (lst *listener) send(g group) {
lst.cons.send(g)
}
func (lst *listener) cancel() {
lst.cons.cancel()
}
func (lst *listener) close() {
lst.cons.close()
}
type loopFunc func(message[group]) bool
type omniInterpreter struct {
hndNil edgeHandler
hndDirect edgeHandler
hndTTU edgeHandler
hndComputed edgeHandler
hndRewrite edgeHandler
hndDirectLogical edgeHandler
hndTTULogical edgeHandler
}
func (o *omniInterpreter) interpret(ctx context.Context, edge *Edge, items []Item) iter.Seq[Item] {
var results iter.Seq[Item]
if edge == nil {
results = o.hndNil(ctx, edge, items)
return results
}
switch edge.GetEdgeType() {
case edgeTypeDirect:
results = o.hndDirect(ctx, edge, items)
case edgeTypeTTU:
results = o.hndTTU(ctx, edge, items)
case edgeTypeComputed:
results = o.hndComputed(ctx, edge, items)
case edgeTypeRewrite:
results = o.hndRewrite(ctx, edge, items)
case edgeTypeDirectLogical:
results = o.hndDirectLogical(ctx, edge, items)
case edgeTypeTTULogical:
results = o.hndTTULogical(ctx, edge, items)
default:
return seq.Sequence(Item{Err: errors.New("unexpected edge type")})
}
return results
}
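// Illustrative sketch (not part of the original source): an omniInterpreter
// whose handlers all pass items through unchanged. interpret then routes by
// edge type without transforming the stream, mirroring how computed and
// rewrite edges are configured for relation nodes below.
func examplePassthroughInterpreter() interpreter {
return &omniInterpreter{
hndNil:           handleIdentity,
hndDirect:        handleIdentity,
hndTTU:           handleIdentity,
hndComputed:      handleIdentity,
hndRewrite:       handleIdentity,
hndDirectLogical: handleIdentity,
hndTTULogical:    handleIdentity,
}
}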
// Pipeline builds and runs a graph of workers, one per weighted-graph node,
// that stream candidate objects from a target user toward a source relation.
type Pipeline struct {
backend *Backend
chunkSize int
numProcs int
}
// Build wires up a worker for every node reachable from source, starts them,
// and returns a sequence that yields result items until the pipeline drains
// or the caller stops consuming.
func (p *Pipeline) Build(ctx context.Context, source Source, target Target) iter.Seq[Item] {
ctxParent, span := pipelineTracer.Start(ctx, "Pipeline.Build")
defer span.End()
ctxNoCancel := context.WithoutCancel(ctxParent)
ctxCancel, cancel := context.WithCancel(ctxNoCancel)
pth := path{
// Removing the cancel on this context ensures that
// workers created during `path::resolve` will not
// end when the parent context cancels. This is important
// because it ensures that the watchdog goroutine is
// responsible for shutting down workers.
ctx: ctxNoCancel,
pipe: p,
workers: make(map[*Node]*worker),
}
pth.resolve((*Node)(source), target, nil, nil)
sourceWorker, ok := pth.workers[(*Node)(source)]
if !ok {
panic("no such source worker")
}
var wg sync.WaitGroup
results := sourceWorker.subscribe(nil)
for _, w := range pth.workers {
w.start()
}
wg.Add(1)
go func() {
defer wg.Done()
for {
var inactiveCount int
for _, w := range pth.workers {
if !w.active() {
w.close()
inactiveCount++
}
}
if inactiveCount == len(pth.workers) {
break
}
messageCount := pth.trk.Load()
if messageCount < 1 || ctxCancel.Err() != nil {
// cancel all running workers
for _, w := range pth.workers {
w.cancel()
}
// wait for all workers to finish
for _, w := range pth.workers {
w.wait()
}
break
}
runtime.Gosched()
}
}()
return func(yield func(Item) bool) {
defer wg.Wait()
defer cancel()
for msg := range results.seq(ctxParent) {
for _, item := range msg.Value.Items {
if !yield(item) {
msg.done()
return
}
}
msg.done()
}
}
}
// Source resolves the weighted-graph node for the given type and relation.
func (p *Pipeline) Source(name, relation string) (Source, bool) {
sourceNode, ok := p.backend.Graph.GetNodeByID(name + "#" + relation)
return (Source)(sourceNode), ok
}
// Target resolves the weighted-graph node for the given object type,
// treating the identifier "*" as the typed wildcard.
func (p *Pipeline) Target(name, identifier string) (Target, bool) {
if identifier == "*" {
name += ":*"
identifier = ""
}
targetNode, ok := p.backend.Graph.GetNodeByID(name)
return Target{
node: targetNode,
id: identifier,
}, ok
}
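// Illustrative sketch (not part of the original source): running a Pipeline
// end to end, assuming the model defines document#viewer and that user:anne
// is the target. Build streams back every matching document object.
func examplePipelineRun(ctx context.Context, p *Pipeline) ([]string, error) {
source, ok := p.Source("document", "viewer")
if !ok {
return nil, errors.New("source relation not in graph")
}
target, ok := p.Target("user", "anne")
if !ok {
return nil, errors.New("target node not in graph")
}
var objects []string
for item := range p.Build(ctx, source, target) {
if item.Err != nil {
return nil, item.Err
}
objects = append(objects, item.Value)
}
return objects, nil
}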
type PipelineOption func(*Pipeline)
type path struct {
ctx context.Context
pipe *Pipeline
workers map[*Node]*worker
trk atomic.Int64
}
func (p *path) resolve(source *Node, target Target, trk tracker, status *StatusPool) {
if _, ok := p.workers[source]; ok {
return
}
if trk == nil {
trk = newEchoTracker(&p.trk)
}
if status == nil {
status = new(StatusPool)
}
w := p.worker(source, trk, status)
p.workers[source] = w
switch source.GetNodeType() {
case nodeTypeSpecificType, nodeTypeSpecificTypeAndRelation:
if source == target.node {
// source node is the target node.
var grp group
grp.Items = []Item{{Value: target.id}}
w.listen(nil, newStaticProducer(&p.trk, grp), p.pipe.chunkSize, 1) // only one value to consume, so only one processor necessary.
}
case nodeTypeSpecificTypeWildcard:
label := source.GetLabel()
typePart := strings.Split(label, ":")[0]
if source == target.node || typePart == target.node.GetLabel() {
// source node is the target node or has the same type as the target.
var grp group
grp.Items = []Item{{Value: "*"}}
w.listen(nil, newStaticProducer(&p.trk, grp), p.pipe.chunkSize, 1) // only one value to consume, so only one processor necessary.
}
}
edges, ok := p.pipe.backend.Graph.GetEdgesFromNode(source)
if !ok {
return
}
for _, edge := range edges {
var track tracker
var stat *StatusPool
isRecursive := len(edge.GetRecursiveRelation()) > 0
if isRecursive {
track = w.trk
stat = status
}
p.resolve(edge.GetTo(), target, track, stat)
w.listen(edge, p.workers[edge.GetTo()].subscribe(source), p.pipe.chunkSize, p.pipe.numProcs)
}
}
func (p *path) worker(node *Node, trk tracker, status *StatusPool) *worker {
var w worker
ctx, cancel := context.WithCancel(p.ctx)
w.finite = sync.OnceFunc(func() {
cancel()
for _, lst := range w.listeners {
lst.close()
}
})
w.node = node
w.status = status
w.trk = trk
var r resolver
id := status.Register()
status.Set(id, true)
switch node.GetNodeType() {
case nodeTypeSpecificTypeAndRelation:
omni := &omniInterpreter{
hndNil: handleLeafNode(node),
hndDirect: p.pipe.backend.handleDirectEdge,
hndTTU: p.pipe.backend.handleTTUEdge,
hndComputed: handleIdentity,
hndRewrite: handleIdentity,
hndDirectLogical: handleUnsupported,
hndTTULogical: handleUnsupported,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
case nodeTypeSpecificType:
omni := &omniInterpreter{
hndNil: handleLeafNode(node),
hndDirect: handleUnsupported,
hndTTU: handleUnsupported,
hndComputed: handleUnsupported,
hndRewrite: handleUnsupported,
hndDirectLogical: handleUnsupported,
hndTTULogical: handleUnsupported,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
case nodeTypeSpecificTypeWildcard:
omni := &omniInterpreter{
hndNil: handleLeafNode(node),
hndDirect: handleUnsupported,
hndTTU: handleUnsupported,
hndComputed: handleUnsupported,
hndRewrite: handleUnsupported,
hndDirectLogical: handleUnsupported,
hndTTULogical: handleUnsupported,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
case nodeTypeOperator:
switch node.GetLabel() {
case weightedGraph.IntersectionOperator:
omni := &omniInterpreter{
hndNil: handleUnsupported,
hndDirect: p.pipe.backend.handleDirectEdge,
hndTTU: p.pipe.backend.handleTTUEdge,
hndComputed: handleIdentity,
hndRewrite: handleIdentity,
hndDirectLogical: handleIdentity,
hndTTULogical: handleIdentity,
}
r = &intersectionResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
trk: trk,
}
case weightedGraph.UnionOperator:
omni := &omniInterpreter{
hndNil: handleUnsupported,
hndDirect: p.pipe.backend.handleDirectEdge,
hndTTU: p.pipe.backend.handleTTUEdge,
hndComputed: handleIdentity,
hndRewrite: handleIdentity,
hndDirectLogical: handleIdentity,
hndTTULogical: handleIdentity,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
case weightedGraph.ExclusionOperator:
omni := &omniInterpreter{
hndNil: handleUnsupported,
hndDirect: p.pipe.backend.handleDirectEdge,
hndTTU: p.pipe.backend.handleTTUEdge,
hndComputed: handleIdentity,
hndRewrite: handleIdentity,
hndDirectLogical: handleIdentity,
hndTTULogical: handleIdentity,
}
r = &exclusionResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
trk: trk,
}
default:
panic("unsupported operator node for reverse expand worker")
}
case nodeTypeLogicalDirectGrouping:
omni := &omniInterpreter{
hndNil: handleUnsupported,
hndDirect: p.pipe.backend.handleDirectEdge,
hndTTU: handleUnsupported,
hndComputed: handleUnsupported,
hndRewrite: handleUnsupported,
hndDirectLogical: handleUnsupported,
hndTTULogical: handleUnsupported,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
case nodeTypeLogicalTTUGrouping:
omni := &omniInterpreter{
hndNil: handleUnsupported,
hndDirect: handleUnsupported,
hndTTU: p.pipe.backend.handleTTUEdge,
hndComputed: handleUnsupported,
hndRewrite: handleUnsupported,
hndDirectLogical: handleUnsupported,
hndTTULogical: handleUnsupported,
}
r = &baseResolver{
id: id,
ctx: ctx,
interpreter: omni,
status: status,
}
default:
panic("unsupported node type for reverse expand worker")
}
w.resolver = r
pw := &w
p.workers[node] = pw
return pw
}
type queryInput struct {
objectType string
objectRelation string
userFilter []*openfgav1.ObjectRelation
conditions []string
}
// resolver is an interface that is consumed by a worker struct.
// A resolver is responsible for consuming messages from a worker's
// senders and broadcasting the result of processing the consumed
// messages to the worker's listeners.
type resolver interface {
// resolve is a function that consumes messages from the
// provided senders, and broadcasts the results of processing
// the consumed messages to the provided listeners.
resolve(senders []*sender, listeners []*listener)
}
type Source *Node
// sender is a struct that contains fields relevant to the producing
// end of a pipeline connection.
type sender struct {
// edge is the weighted graph edge that is producing.
e *Edge
prod producer[group]
// chunkSize is the target number of items to include in each
// outbound message. A value less than 1 indicates an unlimited
// number of items per message.
chunkSize int
numProcs int
}
func (s *sender) chunks() int {
return s.chunkSize
}
func (s *sender) edge() *Edge {
return s.e
}
func (s *sender) loop(fn loopFunc) {
// A background context is used here because the sender's underlying pipe
// relies on a cancelation from its consuming end for a shutdown signal.
// This sequence does not need an additional cancelation signal for that
// reason.
for msg := range s.prod.seq(context.Background()) {
if !fn(msg) {
break
}
}
}
func (s *sender) procs() int {
return s.numProcs
}
// strToItem is a function that accepts a string input and returns an Item
// that contains the input as its `Value`.
func strToItem(s string) Item {
return Item{Value: s}
}
type Target struct {
node *Node
id string
}
// worker drives a single node's resolver, fanning messages in from its
// senders and broadcasting results out to its listeners.
type worker struct {
node *Node
senders []*sender
listeners []*listener
resolver resolver
trk tracker
status *StatusPool
finite func()
wg sync.WaitGroup
}
func (w *worker) active() bool {
return w.status.Status() || w.trk.Load() != 0
}
func (w *worker) cancel() {
for _, lst := range w.listeners {
lst.cancel()
}
}
func (w *worker) close() {
w.finite()
}
func (w *worker) listen(edge *Edge, p producer[group], chunkSize int, numProcs int) {
w.senders = append(w.senders, &sender{
e: edge,
prod: p,
chunkSize: chunkSize,
numProcs: numProcs,
})
}
func (w *worker) start() {
w.wg.Add(1)
go func() {
defer w.wg.Done()
defer w.close()
w.resolver.resolve(w.senders, w.listeners)
}()
}
func (w *worker) subscribe(node *Node) producer[group] {
p := newPipe(w.trk)
w.listeners = append(w.listeners, &listener{
cons: p,
node: node,
})
return p
}
func (w *worker) wait() {
w.wg.Wait()
}
package reverseexpand
import (
"sync"
"sync/atomic"
)
// tracker counts in-flight messages: Add adjusts the count and returns the
// updated value, and Load reports the current count.
type tracker interface {
Add(int64) int64
Load() int64
}
type echoTracker struct {
local atomic.Int64
parent tracker
}
func (t *echoTracker) Add(i int64) int64 {
value := t.local.Add(i)
if t.parent != nil {
t.parent.Add(i)
}
return value
}
func (t *echoTracker) Load() int64 {
return t.local.Load()
}
func newEchoTracker(parent tracker) tracker {
return &echoTracker{
parent: parent,
}
}
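// Illustrative sketch (not part of the original source): an echoTracker
// mirrors every Add into its parent, so a per-worker count also feeds the
// pipeline-wide total that the watchdog inspects.
func exampleEchoTracker() {
var global atomic.Int64
local := newEchoTracker(&global)
local.Add(2)
_ = local.Load()  // 2
_ = global.Load() // 2 as well, echoed from the child
local.Add(-2)
_ = global.Load() // back to 0
}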
// StatusPool is a struct that aggregates status values, as booleans, from multiple sources
// into a single boolean status value. Each source must register itself using the `Register`
// method and supply the returned value in each call to `Set` when updating the source's status
// value. The default state of a StatusPool is `false` for all sources. All StatusPool methods
// are thread safe.
type StatusPool struct {
mu sync.Mutex
pool []uint64
top int
}
// Register is a function that creates a new entry in the StatusPool for a source and returns
// an identifier that is unique within the context of the StatusPool instance. The returned
// integer identifier values are predictable incrementing values beginning at 0. The `Register`
// method is thread safe.
func (sp *StatusPool) Register() int {
sp.mu.Lock()
defer sp.mu.Unlock()
capacity := len(sp.pool)
if sp.top/64 >= capacity {
sp.pool = append(sp.pool, 0)
}
id := sp.top
sp.top++
return id
}
// Set is a function that accepts a registered identifier and a boolean status. The caller must
// provide an integer identifier returned from an initial call to the `Register` function associated
// with the desired source. The `Set` function is thread safe.
func (sp *StatusPool) Set(id int, status bool) {
sp.mu.Lock()
defer sp.mu.Unlock()
ndx := id / 64
pos := uint64(1 << (id % 64))
if status {
sp.pool[ndx] |= pos
return
}
sp.pool[ndx] &^= pos
}
// Status is a function that returns the cumulative status of all sources registered within the pool.
// If any registered source's status is set to `true`, the return value of the `Status` function will
// be `true`. The default value is `false`. The `Status` function is thread safe.
func (sp *StatusPool) Status() bool {
sp.mu.Lock()
defer sp.mu.Unlock()
var status uint64
for _, s := range sp.pool {
status |= s
}
return status != 0
}
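// Illustrative sketch (not part of the original source): two sources share a
// StatusPool; the pool reports true until both have been set to false.
func exampleStatusPool() {
var sp StatusPool
a := sp.Register()
b := sp.Register()
sp.Set(a, true)
sp.Set(b, true)
sp.Set(a, false)
_ = sp.Status() // true: source b is still active
sp.Set(b, false)
_ = sp.Status() // false: every registered source is done
}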
// Package reverseexpand contains the code that handles the ReverseExpand API
package reverseexpand
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
weightedGraph "github.com/openfga/language/pkg/go/graph"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/condition/eval"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/stack"
"github.com/openfga/openfga/internal/throttler"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
var tracer = otel.Tracer("openfga/pkg/server/commands/reverse_expand")
type ReverseExpandRequest struct {
StoreID string
ObjectType string
Relation string
User IsUserRef
ContextualTuples []*openfgav1.TupleKey // TODO remove
Context *structpb.Struct
Consistency openfgav1.ConsistencyPreference
edge *graph.RelationshipEdge
skipWeightedGraph bool
weightedEdge *weightedGraph.WeightedAuthorizationModelEdge
relationStack stack.Stack[typeRelEntry]
}
func (r *ReverseExpandRequest) clone() *ReverseExpandRequest {
if r == nil {
return nil
}
copyRequest := *r
return &copyRequest
}
type IsUserRef interface {
isUserRef()
GetObjectType() string
String() string
}
type UserRefObject struct {
Object *openfgav1.Object
}
var _ IsUserRef = (*UserRefObject)(nil)
func (u *UserRefObject) isUserRef() {}
func (u *UserRefObject) GetObjectType() string {
return u.Object.GetType()
}
func (u *UserRefObject) String() string {
return tuple.BuildObject(u.Object.GetType(), u.Object.GetId())
}
type UserRefTypedWildcard struct {
Type string
}
var _ IsUserRef = (*UserRefTypedWildcard)(nil)
func (*UserRefTypedWildcard) isUserRef() {}
func (u *UserRefTypedWildcard) GetObjectType() string {
return u.Type
}
func (u *UserRefTypedWildcard) String() string {
return tuple.TypedPublicWildcard(u.Type)
}
type UserRefObjectRelation struct {
ObjectRelation *openfgav1.ObjectRelation
Condition *openfgav1.RelationshipCondition
}
func (*UserRefObjectRelation) isUserRef() {}
func (u *UserRefObjectRelation) GetObjectType() string {
return tuple.GetType(u.ObjectRelation.GetObject())
}
func (u *UserRefObjectRelation) String() string {
return tuple.ToObjectRelationString(
u.ObjectRelation.GetObject(),
u.ObjectRelation.GetRelation(),
)
}
type UserRef struct {
// Types that are assignable to Ref
// *UserRef_Object
// *UserRef_TypedWildcard
// *UserRef_ObjectRelation
Ref IsUserRef
}
type ReverseExpandQuery struct {
logger logger.Logger
datastore storage.RelationshipTupleReader
typesystem *typesystem.TypeSystem
resolveNodeLimit uint32
resolveNodeBreadthLimit uint32
dispatchThrottlerConfig threshold.Config
// visitedUsersetsMap map prevents visiting the same userset through the same edge twice
visitedUsersetsMap *sync.Map
// candidateObjectsMap map prevents returning the same object twice
candidateObjectsMap *sync.Map
// queryDedupeMap prevents multiple branches of exploration from running
// the same queries, since multiple leaf nodes can have a common ancestor
queryDedupeMap *sync.Map
// localCheckResolver allows reverse expand to call check locally
localCheckResolver graph.CheckRewriteResolver
optimizationsEnabled bool
}
type ReverseExpandQueryOption func(d *ReverseExpandQuery)
func WithResolveNodeLimit(limit uint32) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
d.resolveNodeLimit = limit
}
}
func WithDispatchThrottlerConfig(config threshold.Config) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
d.dispatchThrottlerConfig = config
}
}
func WithResolveNodeBreadthLimit(limit uint32) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
d.resolveNodeBreadthLimit = limit
}
}
func WithCheckResolver(resolver graph.CheckResolver) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
localCheckResolver, found := graph.LocalCheckResolver(resolver)
if found {
d.localCheckResolver = localCheckResolver
}
}
}
func WithListObjectOptimizationsEnabled(enabled bool) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
d.optimizationsEnabled = enabled
}
}
// TODO accept ReverseExpandRequest so we can build the datastore object right away.
func NewReverseExpandQuery(ds storage.RelationshipTupleReader, ts *typesystem.TypeSystem, opts ...ReverseExpandQueryOption) *ReverseExpandQuery {
query := &ReverseExpandQuery{
logger: logger.NewNoopLogger(),
datastore: ds,
typesystem: ts,
resolveNodeLimit: serverconfig.DefaultResolveNodeLimit,
resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
dispatchThrottlerConfig: threshold.Config{
Throttler: throttler.NewNoopThrottler(),
Enabled: serverconfig.DefaultListObjectsDispatchThrottlingEnabled,
Threshold: serverconfig.DefaultListObjectsDispatchThrottlingDefaultThreshold,
MaxThreshold: serverconfig.DefaultListObjectsDispatchThrottlingMaxThreshold,
},
candidateObjectsMap: new(sync.Map),
visitedUsersetsMap: new(sync.Map),
queryDedupeMap: new(sync.Map),
localCheckResolver: graph.NewLocalChecker(),
}
for _, opt := range opts {
opt(query)
}
return query
}
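// A minimal sketch (hypothetical limit values) of wiring a ReverseExpandQuery with
// the functional options above; any option omitted keeps the default set in
// NewReverseExpandQuery.
func exampleNewReverseExpandQuery(ds storage.RelationshipTupleReader, ts *typesystem.TypeSystem) *ReverseExpandQuery {
return NewReverseExpandQuery(
ds,
ts,
WithResolveNodeLimit(25),        // cap recursion depth
WithResolveNodeBreadthLimit(50), // cap concurrent branches per node
WithListObjectOptimizationsEnabled(true),
)
}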
type ConditionalResultStatus int
const (
RequiresFurtherEvalStatus ConditionalResultStatus = iota
NoFurtherEvalStatus
)
type ReverseExpandResult struct {
Object string
ResultStatus ConditionalResultStatus
}
type ResolutionMetadata struct {
// The number of times we expand from each node to find the set of objects
DispatchCounter *atomic.Uint32
// WasThrottled indicates whether the request was throttled
WasThrottled *atomic.Bool
// WasWeightedGraphUsed indicates whether the weighted graph was used as the algorithm for the ReverseExpand request.
WasWeightedGraphUsed *atomic.Bool
// The number of times internal check was called for the optimization path
CheckCounter *atomic.Uint32
}
func NewResolutionMetadata() *ResolutionMetadata {
return &ResolutionMetadata{
DispatchCounter: new(atomic.Uint32),
WasThrottled: new(atomic.Bool),
WasWeightedGraphUsed: new(atomic.Bool),
CheckCounter: new(atomic.Uint32),
}
}
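// A minimal sketch (illustrative helper) of reading the per-request counters after
// a request completes; callers create one ResolutionMetadata per Execute call.
func exampleReadResolutionMetadata(meta *ResolutionMetadata) (dispatches uint32, throttled bool) {
return meta.DispatchCounter.Load(), meta.WasThrottled.Load()
}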
func WithLogger(logger logger.Logger) ReverseExpandQueryOption {
return func(d *ReverseExpandQuery) {
d.logger = logger
}
}
// shallowClone creates an identical copy of ReverseExpandQuery except for
// candidateObjectsMap, because list object candidates still need to be
// validated via check.
func (c *ReverseExpandQuery) shallowClone() *ReverseExpandQuery {
if c == nil {
return nil
}
copy := *c
copy.candidateObjectsMap = new(sync.Map)
return &copy
}
// Execute yields all the objects of the provided objectType that the
// given user possibly has a specific relation with, and sends those
// objects to resultChan. It MUST guarantee no duplicate objects sent.
//
// This function respects context timeouts and cancellations. If an
// error is encountered (e.g. context timeout) before resolving all
// objects, then the provided channel will NOT be closed, and it will
// return the error.
//
// If no errors occur, then Execute will yield all of the objects on
// the provided channel and then close the channel to signal that it
// is done.
func (c *ReverseExpandQuery) Execute(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
resolutionMetadata *ResolutionMetadata,
) error {
ctx = storage.ContextWithRelationshipTupleReader(ctx, c.datastore)
err := c.execute(ctx, req, resultChan, false, resolutionMetadata)
if err != nil {
return err
}
close(resultChan)
return nil
}
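// exampleDrainExecute is an illustrative sketch (a simplified caller, not part of
// the original source) of the channel contract documented on Execute: run Execute
// in its own goroutine, collect results until the channel is closed, and surface
// any error, remembering that Execute does NOT close the channel on error.
func exampleDrainExecute(ctx context.Context, q *ReverseExpandQuery, req *ReverseExpandRequest) ([]string, error) {
resultChan := make(chan *ReverseExpandResult)
errChan := make(chan error, 1)
go func() {
errChan <- q.Execute(ctx, req, resultChan, NewResolutionMetadata())
}()
var objects []string
for {
select {
case res, ok := <-resultChan:
if !ok {
// Channel closed: Execute finished without error.
return objects, nil
}
objects = append(objects, res.Object)
case err := <-errChan:
if err != nil {
return objects, err
}
// A nil error means the channel was closed; keep draining.
}
}
}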
func (c *ReverseExpandQuery) dispatch(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionOrExclusionInPreviousEdges bool,
resolutionMetadata *ResolutionMetadata,
) error {
newcount := resolutionMetadata.DispatchCounter.Add(1)
if c.dispatchThrottlerConfig.Enabled {
c.throttle(ctx, newcount, resolutionMetadata)
}
return c.execute(ctx, req, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
}
func (c *ReverseExpandQuery) execute(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionOrExclusionInPreviousEdges bool,
resolutionMetadata *ResolutionMetadata,
) error {
if ctx.Err() != nil {
return ctx.Err()
}
ctx, span := tracer.Start(ctx, "reverseExpand.Execute", trace.WithAttributes(
attribute.String("target_type", req.ObjectType),
attribute.String("target_relation", req.Relation),
attribute.String("source", req.User.String()),
))
defer span.End()
if req.edge != nil {
span.SetAttributes(attribute.String("edge", req.edge.String()))
}
depth, ok := graph.ResolutionDepthFromContext(ctx)
if !ok {
ctx = graph.ContextWithResolutionDepth(ctx, 0)
} else {
if depth >= c.resolveNodeLimit {
return graph.ErrResolutionDepthExceeded
}
ctx = graph.ContextWithResolutionDepth(ctx, depth+1)
}
var sourceUserRef *openfgav1.RelationReference
var sourceUserType, sourceUserObj string
// e.g. 'user:bob'
if val, ok := req.User.(*UserRefObject); ok {
sourceUserType = val.Object.GetType()
sourceUserObj = tuple.BuildObject(sourceUserType, val.Object.GetId())
sourceUserRef = typesystem.DirectRelationReference(sourceUserType, "")
}
// e.g. 'user:*'
if val, ok := req.User.(*UserRefTypedWildcard); ok {
sourceUserType = val.Type
sourceUserRef = typesystem.WildcardRelationReference(sourceUserType)
}
// e.g. 'group:eng#member'
if userset, ok := req.User.(*UserRefObjectRelation); ok {
sourceUserType = tuple.GetType(userset.ObjectRelation.GetObject())
sourceUserObj = userset.ObjectRelation.GetObject()
sourceUserRef = typesystem.DirectRelationReference(sourceUserType, userset.ObjectRelation.GetRelation())
// Queries that come in explicitly looking for userset relations will skip weighted graph for now.
// e.g. ListObjects(document, viewer, team:fga#member)
req.skipWeightedGraph = true
if req.edge != nil {
key := fmt.Sprintf("%s#%s", sourceUserObj, req.edge.String())
if _, loaded := c.visitedUsersetsMap.LoadOrStore(key, struct{}{}); loaded {
// we've already visited this userset through this edge, exit to avoid an infinite cycle
return nil
}
}
// ReverseExpand(type=document, rel=viewer, user=document:1#viewer) will return "document:1"
if tuple.UsersetMatchTypeAndRelation(userset.String(), req.Relation, req.ObjectType) {
c.trySendCandidate(ctx, intersectionOrExclusionInPreviousEdges, sourceUserObj, resultChan)
}
}
targetObjRef := typesystem.DirectRelationReference(req.ObjectType, req.Relation)
if c.optimizationsEnabled && !req.skipWeightedGraph {
var typeRel string
if req.weightedEdge != nil {
typeRel = req.weightedEdge.GetTo().GetUniqueLabel()
} else { // true on first call to ReverseExpand
typeRel = tuple.ToObjectRelationString(targetObjRef.GetType(), targetObjRef.GetRelation())
node, ok := c.typesystem.GetNode(typeRel)
if !ok {
// The weighted graph is not guaranteed to be present.
// If there's no weighted graph, which can happen for models with disconnected types,
// we log below and then fall back to the non-weighted version of reverse_expand.
c.logger.InfoWithContext(ctx, "unable to find node in weighted graph", zap.String("node_id", typeRel))
req.skipWeightedGraph = true
} else {
weight, _ := node.GetWeight(sourceUserType)
if weight == weightedGraph.Infinite {
c.logger.InfoWithContext(ctx, "reverse_expand graph may contain cycle, skipping weighted graph", zap.String("node_id", typeRel))
req.skipWeightedGraph = true
}
}
}
if !req.skipWeightedGraph {
if req.weightedEdge == nil { // true on the first invocation only
req.relationStack = stack.Push(nil, typeRelEntry{typeRel: typeRel})
}
// The error can safely be ignored: if the weighted graph had failed to build,
// req.skipWeightedGraph would have prevented us from entering this block.
edges, _ := c.typesystem.GetConnectedEdges(
typeRel,
sourceUserType,
)
// Record that the weighted graph was used as the algorithm for this request.
resolutionMetadata.WasWeightedGraphUsed.Store(true)
return c.loopOverEdges(
ctx,
req,
edges,
intersectionOrExclusionInPreviousEdges,
resolutionMetadata,
resultChan,
sourceUserType,
)
}
}
g := graph.New(c.typesystem)
edges, err := g.GetPrunedRelationshipEdges(targetObjRef, sourceUserRef)
if err != nil {
return err
}
pool := concurrency.NewPool(ctx, int(c.resolveNodeBreadthLimit))
var errs error
LoopOnEdges:
for _, edge := range edges {
innerLoopEdge := edge
intersectionOrExclusionInPreviousEdges := intersectionOrExclusionInPreviousEdges || innerLoopEdge.TargetReferenceInvolvesIntersectionOrExclusion
r := &ReverseExpandRequest{
StoreID: req.StoreID,
ObjectType: req.ObjectType,
Relation: req.Relation,
User: req.User,
ContextualTuples: req.ContextualTuples,
Context: req.Context,
edge: innerLoopEdge,
Consistency: req.Consistency,
skipWeightedGraph: req.skipWeightedGraph,
}
switch innerLoopEdge.Type {
case graph.DirectEdge:
pool.Go(func(ctx context.Context) error {
return c.reverseExpandDirect(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
})
case graph.ComputedUsersetEdge:
// follow the computed_userset edge, no new goroutine needed since it's not I/O intensive
r.User = &UserRefObjectRelation{
ObjectRelation: &openfgav1.ObjectRelation{
Object: sourceUserObj,
Relation: innerLoopEdge.TargetReference.GetRelation(),
},
}
err = c.dispatch(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
if err != nil {
errs = errors.Join(errs, err)
break LoopOnEdges
}
case graph.TupleToUsersetEdge:
pool.Go(func(ctx context.Context) error {
return c.reverseExpandTupleToUserset(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
})
default:
return fmt.Errorf("unsupported edge type: %v", innerLoopEdge.Type)
}
}
errs = errors.Join(errs, pool.Wait())
if errs != nil {
telemetry.TraceError(span, errs)
return errs
}
return nil
}
func (c *ReverseExpandQuery) reverseExpandTupleToUserset(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionOrExclusionInPreviousEdges bool,
resolutionMetadata *ResolutionMetadata,
) error {
ctx, span := tracer.Start(ctx, "reverseExpandTupleToUserset", trace.WithAttributes(
attribute.String("edge", req.edge.String()),
attribute.String("source.user", req.User.String()),
))
var err error
defer func() {
if err != nil {
telemetry.TraceError(span, err)
}
span.End()
}()
err = c.readTuplesAndExecute(ctx, req, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
return err
}
func (c *ReverseExpandQuery) reverseExpandDirect(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionOrExclusionInPreviousEdges bool,
resolutionMetadata *ResolutionMetadata,
) error {
ctx, span := tracer.Start(ctx, "reverseExpandDirect", trace.WithAttributes(
attribute.String("edge", req.edge.String()),
attribute.String("source.user", req.User.String()),
))
var err error
defer func() {
if err != nil {
telemetry.TraceError(span, err)
}
span.End()
}()
err = c.readTuplesAndExecute(ctx, req, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
return err
}
func (c *ReverseExpandQuery) shouldCheckPublicAssignable(targetReference *openfgav1.RelationReference, userRef IsUserRef) (bool, error) {
_, userIsUserset := userRef.(*UserRefObjectRelation)
if userIsUserset {
// if the user is a userset, by definition it is not publicly assignable
return false, nil
}
publiclyAssignable, err := c.typesystem.IsPubliclyAssignable(targetReference, userRef.GetObjectType())
if err != nil {
return false, err
}
return publiclyAssignable, nil
}
func (c *ReverseExpandQuery) readTuplesAndExecute(
ctx context.Context,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionOrExclusionInPreviousEdges bool,
resolutionMetadata *ResolutionMetadata,
) error {
if ctx.Err() != nil {
return ctx.Err()
}
ctx, span := tracer.Start(ctx, "readTuplesAndExecute")
defer span.End()
var userFilter []*openfgav1.ObjectRelation
var relationFilter string
switch req.edge.Type {
case graph.DirectEdge:
relationFilter = req.edge.TargetReference.GetRelation()
targetUserObjectType := req.User.GetObjectType()
publiclyAssignable, err := c.shouldCheckPublicAssignable(req.edge.TargetReference, req.User)
if err != nil {
return err
}
if publiclyAssignable {
// e.g. 'user:*'
userFilter = append(userFilter, &openfgav1.ObjectRelation{
Object: tuple.TypedPublicWildcard(targetUserObjectType),
})
}
// e.g. 'user:bob'
if val, ok := req.User.(*UserRefObject); ok {
userFilter = append(userFilter, &openfgav1.ObjectRelation{
Object: tuple.BuildObject(val.Object.GetType(), val.Object.GetId()),
})
}
// e.g. 'group:eng#member'
if val, ok := req.User.(*UserRefObjectRelation); ok {
userFilter = append(userFilter, val.ObjectRelation)
}
case graph.TupleToUsersetEdge:
relationFilter = req.edge.TuplesetRelation
// a TTU edge can only have a userset as a source node
// e.g. 'group:eng#member'
if val, ok := req.User.(*UserRefObjectRelation); ok {
userFilter = append(userFilter, &openfgav1.ObjectRelation{
Object: val.ObjectRelation.GetObject(),
})
} else {
panic("unexpected source for reverse expansion of tuple to userset")
}
default:
panic("unsupported edge type")
}
// find all tuples of the form req.edge.TargetReference.Type:...#relationFilter@userFilter
iter, err := c.datastore.ReadStartingWithUser(ctx, req.StoreID, storage.ReadStartingWithUserFilter{
ObjectType: req.edge.TargetReference.GetType(),
Relation: relationFilter,
UserFilter: userFilter,
}, storage.ReadStartingWithUserOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.Consistency,
},
})
if err != nil {
return err
}
// filter out invalid tuples yielded by the database iterator
filteredIter := storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(c.typesystem),
)
defer filteredIter.Stop()
pool := concurrency.NewPool(ctx, int(c.resolveNodeBreadthLimit))
var errs error
LoopOnIterator:
for {
tk, err := filteredIter.Next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
break
}
errs = errors.Join(errs, err)
break LoopOnIterator
}
cond, _ := c.typesystem.GetCondition(tk.GetCondition().GetName())
condMet, err := eval.EvaluateTupleCondition(ctx, tk, cond, req.Context)
if err != nil {
errs = errors.Join(errs, err)
continue
}
if !condMet {
continue
}
foundObject := tk.GetObject()
var newRelation string
switch req.edge.Type {
case graph.DirectEdge:
newRelation = tk.GetRelation()
case graph.TupleToUsersetEdge:
newRelation = req.edge.TargetReference.GetRelation()
default:
panic("unsupported edge type")
}
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, &ReverseExpandRequest{
StoreID: req.StoreID,
ObjectType: req.ObjectType,
Relation: req.Relation,
User: &UserRefObjectRelation{
ObjectRelation: &openfgav1.ObjectRelation{
Object: foundObject,
Relation: newRelation,
},
Condition: tk.GetCondition(),
},
ContextualTuples: req.ContextualTuples,
Context: req.Context,
edge: req.edge,
Consistency: req.Consistency,
}, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
})
}
errs = errors.Join(errs, pool.Wait())
if errs != nil {
telemetry.TraceError(span, errs)
return errs
}
return nil
}
func (c *ReverseExpandQuery) trySendCandidate(
ctx context.Context,
intersectionOrExclusionInPreviousEdges bool,
candidateObject string,
candidateChan chan<- *ReverseExpandResult,
) {
_, span := tracer.Start(ctx, "trySendCandidate", trace.WithAttributes(
attribute.String("object", candidateObject),
attribute.Bool("sent", false),
))
defer span.End()
if _, ok := c.candidateObjectsMap.LoadOrStore(candidateObject, struct{}{}); !ok {
resultStatus := NoFurtherEvalStatus
if intersectionOrExclusionInPreviousEdges {
span.SetAttributes(attribute.Bool("requires_further_eval", true))
resultStatus = RequiresFurtherEvalStatus
}
result := &ReverseExpandResult{Object: candidateObject, ResultStatus: resultStatus}
ok = concurrency.TrySendThroughChannel(ctx, result, candidateChan)
if ok {
span.SetAttributes(attribute.Bool("sent", true))
}
}
}
func (c *ReverseExpandQuery) throttle(ctx context.Context, currentNumDispatch uint32, metadata *ResolutionMetadata) {
span := trace.SpanFromContext(ctx)
shouldThrottle := threshold.ShouldThrottle(
ctx,
currentNumDispatch,
c.dispatchThrottlerConfig.Threshold,
c.dispatchThrottlerConfig.MaxThreshold,
)
span.SetAttributes(
attribute.Int("dispatch_count", int(currentNumDispatch)),
attribute.Bool("is_throttled", shouldThrottle))
if shouldThrottle {
metadata.WasThrottled.Store(true)
c.dispatchThrottlerConfig.Throttler.Throttle(ctx)
}
}
package reverseexpand
import (
"context"
"errors"
"fmt"
"sync"
aq "github.com/emirpasic/gods/queues/arrayqueue"
"go.opentelemetry.io/otel/trace"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
weightedGraph "github.com/openfga/language/pkg/go/graph"
"github.com/openfga/openfga/internal/checkutil"
"github.com/openfga/openfga/internal/concurrency"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/stack"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
const (
listObjectsResultChannelLength = 100
)
var (
ErrEmptyStack = errors.New("unexpected empty stack")
ErrLowestWeightFail = errors.New("failed to get lowest weight edge")
ErrConstructUsersetFail = errors.New("failed to construct userset")
)
type ExecutionError struct {
operation string
object string
relation string
user string
cause error
}
func (e *ExecutionError) Error() string {
return fmt.Sprintf("failed to execute: operation: %s: object: %s: relation: %s: user: %s: cause: %s",
e.operation,
e.object,
e.relation,
e.user,
e.cause.Error(),
)
}
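// A minimal sketch (illustrative helper, not part of the original source) mirroring
// how loopOverEdges treats an ExecutionError below: a context cancellation or
// deadline in the cause is swallowed so ListObjects can return partial results.
func exampleIsPartialResultTimeout(err error) bool {
var execErr *ExecutionError
if errors.As(err, &execErr) {
return errors.Is(execErr.cause, context.Canceled) || errors.Is(execErr.cause, context.DeadlineExceeded)
}
return false
}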
// typeRelEntry represents a step in the path taken to reach a leaf node.
// As reverse expand traverses from a requested type#rel to its leaf nodes, it pushes typeRelEntry structs onto a stack.
// After reaching a leaf, this stack is consumed by the `queryForTuples` function to build the precise chain of
// database queries needed to find the resulting objects.
type typeRelEntry struct {
typeRel string // e.g. "organization#admin"
// Only present for userset relations. Will be the userset relation string itself.
// For `define admin: [team#member]`, usersetRelation is "member".
usersetRelation string
}
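// A minimal sketch (assumed model, illustrative only) of the stack built for a
// tuple-to-userset relation:
//
// type folder
//   define admin: [user]
// type document
//   define parent: [folder]
//   define viewer: admin from parent
//
// Reverse expanding document#viewer pushes the tupleset relation and then the target
// relation, so queryForTuples pops and queries folder#admin first, then document#parent.
func exampleTTURelationStack() stack.Stack[typeRelEntry] {
s := stack.Push(nil, typeRelEntry{typeRel: "document#parent"})
return stack.Push(s, typeRelEntry{typeRel: "folder#admin"})
}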
// queryJob represents a single task in the reverse expansion process.
// It holds the `foundObject` from a previous step in the traversal
// and the `ReverseExpandRequest` containing the current state of the request.
type queryJob struct {
foundObject string
req *ReverseExpandRequest
}
// jobQueue is a thread-safe queue for managing `queryJob` instances.
// It's used to hold jobs that need to be processed during the recursive
// `queryForTuples` operation, allowing concurrent processing of branches
// in the authorization graph.
type jobQueue struct {
queue aq.Queue
mu sync.Mutex
}
func newJobQueue() *jobQueue {
return &jobQueue{queue: *aq.New()}
}
func (q *jobQueue) Empty() bool {
q.mu.Lock()
defer q.mu.Unlock()
return q.queue.Empty()
}
func (q *jobQueue) enqueue(value ...queryJob) {
q.mu.Lock()
defer q.mu.Unlock()
for _, item := range value {
q.queue.Enqueue(item)
}
}
func (q *jobQueue) dequeue() (queryJob, bool) {
q.mu.Lock()
defer q.mu.Unlock()
val, ok := q.queue.Dequeue()
if !ok {
return queryJob{}, false
}
job, ok := val.(queryJob)
if !ok {
return queryJob{}, false
}
return job, true
}
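// A minimal sketch (hypothetical objects) of the jobQueue contract used by
// queryForTuples below: FIFO ordering, with the boolean return reporting whether a
// job was available.
func exampleJobQueue() {
q := newJobQueue()
q.enqueue(queryJob{foundObject: "folder:budgets"}, queryJob{foundObject: "folder:plans"})
for !q.Empty() {
job, ok := q.dequeue()
if !ok {
break // drained concurrently, or a non-queryJob value was stored
}
fmt.Println(job.foundObject) // prints folder:budgets, then folder:plans
}
}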
// loopOverEdges iterates over a set of weightedGraphEdges and acts as a dispatcher,
// processing each edge according to its type to continue the reverse expansion process.
//
// While traversing, loopOverEdges appends relation entries to a stack for use in querying after traversal is complete.
// It will continue to dispatch and traverse the graph until it reaches a DirectEdge, which
// leads to a leaf node in the authorization graph. Once a DirectEdge is found, loopOverEdges invokes
// queryForTuples, passing it the stack of relations it constructed on the way to that particular leaf.
//
// For each edge, it creates a new ReverseExpandRequest, preserving the context of the overall query
// but updating the traversal state (the 'stack') based on the edge being processed.
//
// The behavior is determined by the edge type:
//
// - DirectEdge: This represents a direct path to data. Here we initiate a call to
// `queryForTuples` to query the datastore for tuples that match the relationship path
// accumulated in the stack. This is the end of the traversal.
//
// - ComputedEdge, RewriteEdge, and TTUEdge: These represent indirections in the authorization model.
// The function modifies the traversal 'stack' to reflect the next relationship that needs to be resolved.
// It then calls `dispatch` to continue traversing the graph with this new state until it reaches a DirectEdge.
func (c *ReverseExpandQuery) loopOverEdges(
ctx context.Context,
req *ReverseExpandRequest,
edges []*weightedGraph.WeightedAuthorizationModelEdge,
needsCheck bool,
resolutionMetadata *ResolutionMetadata,
resultChan chan<- *ReverseExpandResult,
sourceUserType string,
) error {
pool := concurrency.NewPool(ctx, int(c.resolveNodeBreadthLimit))
for _, edge := range edges {
newReq := req.clone()
newReq.weightedEdge = edge
toNode := edge.GetTo()
goingToUserset := toNode.GetNodeType() == weightedGraph.SpecificTypeAndRelation
// Going to a userset presents a risk of an infinite loop. Checking the edge and the traversal stack
// ensures we don't perform the same traversal multiple times.
if goingToUserset {
key := edge.GetFrom().GetUniqueLabel() + toNode.GetUniqueLabel() + edge.GetTuplesetRelation() + stack.String(newReq.relationStack)
_, loaded := c.visitedUsersetsMap.LoadOrStore(key, struct{}{})
if loaded {
// we've already visited this userset through this edge, exit to avoid an infinite cycle
continue
}
}
switch edge.GetEdgeType() {
case weightedGraph.DirectEdge:
if goingToUserset {
// Attach the userset relation to the previous stack entry
// type team:
// define member: [user]
// type org:
// define teammate: [team#member]
// A direct edge here is org#teammate --> team#member
// so if we find team:fga for this user, we need to know to check for
// team:fga#member when we check org#teammate
if newReq.relationStack == nil {
return ErrEmptyStack
}
entry, newStack := stack.Pop(newReq.relationStack)
entry.usersetRelation = tuple.GetRelation(toNode.GetUniqueLabel())
newStack = stack.Push(newStack, entry)
newStack = stack.Push(newStack, typeRelEntry{typeRel: toNode.GetUniqueLabel()})
newReq.relationStack = newStack
// Now continue traversing
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
continue
}
// We have reached a leaf node in the graph (e.g. `user` or `user:*`),
// and the traversal for this path is complete. Now we use the stack of relations
// we've built to query the datastore for matching tuples.
pool.Go(func(ctx context.Context) error {
return c.queryForTuples(
ctx,
newReq,
needsCheck,
resultChan,
"",
)
})
case weightedGraph.ComputedEdge:
// A computed edge is an alias (e.g., `define viewer: editor`).
// We replace the current relation on the stack (`viewer`) with the computed one (`editor`),
// as tuples are only written against `editor`.
if toNode.GetNodeType() != weightedGraph.OperatorNode {
if newReq.relationStack == nil {
return ErrEmptyStack
}
_, newStack := stack.Pop(newReq.relationStack)
newStack = stack.Push(newStack, typeRelEntry{typeRel: toNode.GetUniqueLabel()})
newReq.relationStack = newStack
}
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
case weightedGraph.TTUEdge:
// Replace the existing type#rel on the stack with the tuple-to-userset relation:
//
// type document
// define parent: [folder]
// define viewer: admin from parent
//
// We need to remove document#viewer from the stack and replace it with the tupleset relation (`document#parent`).
// Then we have to add the .To() relation `folder#admin`.
// The stack becomes `[document#parent, folder#admin]`, and on evaluation we will first
// query for folder#admin, then if folders exist we will see if they are related to
// any documents as #parent.
if newReq.relationStack == nil {
return ErrEmptyStack
}
_, newStack := stack.Pop(newReq.relationStack)
// stack.Push tupleset relation (`document#parent`)
tuplesetRel := typeRelEntry{typeRel: edge.GetTuplesetRelation()}
newStack = stack.Push(newStack, tuplesetRel)
// stack.Push target type#rel (`folder#admin`)
newStack = stack.Push(newStack, typeRelEntry{typeRel: toNode.GetUniqueLabel()})
newReq.relationStack = newStack
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
case weightedGraph.RewriteEdge:
// Behaves just like ComputedEdge above
// Operator nodes (union, intersection, exclusion) are not real types; they never get added
// to the stack.
if toNode.GetNodeType() != weightedGraph.OperatorNode {
if newReq.relationStack == nil {
return ErrEmptyStack
}
_, newStack := stack.Pop(newReq.relationStack)
newStack = stack.Push(newStack, typeRelEntry{typeRel: toNode.GetUniqueLabel()})
newReq.relationStack = newStack
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
// continue to the next edge
break
}
// If the target node is an operator node, we need to handle it differently.
switch toNode.GetLabel() {
case weightedGraph.IntersectionOperator:
err := c.intersectionHandler(pool, newReq, resultChan, toNode, sourceUserType, resolutionMetadata)
if err != nil {
return err
}
case weightedGraph.ExclusionOperator:
err := c.exclusionHandler(ctx, pool, newReq, resultChan, toNode, sourceUserType, resolutionMetadata)
if err != nil {
return err
}
case weightedGraph.UnionOperator:
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
default:
return fmt.Errorf("unsupported operator node: %s", toNode.GetLabel())
}
case weightedGraph.TTULogicalEdge, weightedGraph.DirectLogicalEdge:
pool.Go(func(ctx context.Context) error {
return c.dispatch(ctx, newReq, resultChan, needsCheck, resolutionMetadata)
})
default:
return fmt.Errorf("unsupported edge type: %v", edge.GetEdgeType())
}
}
// To maintain the current ListObjects behavior, we return partial results when
// the weighted reverse expand is cancelled or times out.
// For more detail, see here: https://openfga.dev/api/service#/Relationship%20Queries/ListObjects
err := pool.Wait()
if err != nil {
var executionError *ExecutionError
if errors.As(err, &executionError) {
if errors.Is(executionError.cause, context.Canceled) || errors.Is(executionError.cause, context.DeadlineExceeded) {
return nil
}
}
}
return err
}
// queryForTuples performs all datastore-related reverse expansion logic. After a leaf node has been found in loopOverEdges,
// this function works backwards from a specified user (using the stack created in loopOverEdges)
// and an initial relationship edge to find all the objects that the given user has the given relationship with.
//
// This function orchestrates the concurrent execution of individual query jobs. It initializes a memoization
// map (`jobDedupeMap`) to prevent redundant database queries and a job queue to manage pending tasks.
// It kicks off the initial query and then continuously processes jobs from the queue using a concurrency pool
// until all branches leading up from the leaf have been explored.
func (c *ReverseExpandQuery) queryForTuples(
ctx context.Context,
req *ReverseExpandRequest,
needsCheck bool,
resultChan chan<- *ReverseExpandResult,
foundObject string,
) error {
span := trace.SpanFromContext(ctx)
queryJobQueue := newJobQueue()
// Now kick off the chain of queries
items, err := c.executeQueryJob(ctx, queryJob{req: req, foundObject: foundObject}, resultChan, needsCheck)
if err != nil {
telemetry.TraceError(span, err)
return err
}
// Populate the jobQueue with the initial jobs
queryJobQueue.enqueue(items...)
// We could potentially have c.resolveNodeBreadthLimit active goroutines reaching this point.
// Limit the number of querying goroutines to avoid an explosion of goroutines.
pool := concurrency.NewPool(ctx, int(c.resolveNodeBreadthLimit))
for !queryJobQueue.Empty() {
job, ok := queryJobQueue.dequeue()
if !ok {
// this shouldn't be possible
return nil
}
// Each goroutine will take its first job from the original queue above
// and then continue generating and processing jobs until there are no more.
pool.Go(func(ctx context.Context) error {
localQueue := newJobQueue()
localQueue.enqueue(job)
// While this goroutine's queue has items, keep looking for more
for !localQueue.Empty() {
nextJob, ok := localQueue.dequeue()
if !ok {
break
}
newItems, err := c.executeQueryJob(ctx, nextJob, resultChan, needsCheck)
if err != nil {
return err
}
localQueue.enqueue(newItems...)
}
return nil
})
}
err = pool.Wait()
if err != nil {
telemetry.TraceError(span, err)
return err
}
return nil
}
// executeQueryJob represents a single recursive step in the reverse expansion query process.
// It takes a `queryJob`, which encapsulates the current state of the traversal (found object,
// and the reverse expand request with its relation stack).
// The method constructs a database query based on the current relation at the top of the stack
// and the `foundObject` from the previous step. It queries the datastore, and for each result:
// - If the relation stack is empty, it means a candidate object has been found, which is then sent to `resultChan`.
// - If matching tuples are found, it prepares new `queryJob` instances to continue the traversal further up the graph,
// using the newly found object as the `foundObject` for the next step.
// - If no matching objects are found in the datastore, this branch of reverse expand is a dead end, and no more jobs are needed.
func (c *ReverseExpandQuery) executeQueryJob(
ctx context.Context,
job queryJob,
resultChan chan<- *ReverseExpandResult,
needsCheck bool,
) ([]queryJob, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Ensure we're always working with a copy
currentReq := job.req.clone()
userFilter, err := buildUserFilter(currentReq, job.foundObject)
if err != nil {
return nil, err
}
if currentReq.relationStack == nil {
return nil, ErrEmptyStack
}
// Now pop the top relation off of the stack for querying
entry, newStack := stack.Pop(currentReq.relationStack)
typeRel := entry.typeRel
currentReq.relationStack = newStack
objectType, relation := tuple.SplitObjectRelation(typeRel)
filteredIter, err := c.buildFilteredIterator(ctx, currentReq, objectType, relation, userFilter)
if err != nil {
return nil, err
}
defer filteredIter.Stop()
var nextJobs []queryJob
for {
tupleKey, err := filteredIter.Next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
break
}
return nil, err
}
// This will be a "type:id" e.g. "document:roadmap"
foundObject := tupleKey.GetObject()
// If there are no more type#rel entries to look for in the stack, we have hit the base case,
// and this object is a candidate for return to the user.
if currentReq.relationStack == nil {
c.trySendCandidate(ctx, needsCheck, foundObject, resultChan)
continue
}
// For non-recursive relations (majority of cases), if there are more items on the stack, we continue
// the evaluation one level higher up the tree with the `foundObject`.
nextJobs = append(nextJobs, queryJob{foundObject: foundObject, req: currentReq})
}
return nextJobs, nil
}
func buildUserFilter(
req *ReverseExpandRequest,
object string,
) ([]*openfgav1.ObjectRelation, error) {
var filter *openfgav1.ObjectRelation
// This is true on every call to executeQueryJob except the first, since we only trigger
// subsequent calls if we successfully found an object.
if object != "" {
if req.relationStack == nil {
return nil, ErrEmptyStack
}
entry := stack.Peek(req.relationStack)
filter = &openfgav1.ObjectRelation{Object: object}
if entry.usersetRelation != "" {
filter.Relation = entry.usersetRelation
}
} else {
// This else block ONLY hits on the first call to executeQueryJob.
toNode := req.weightedEdge.GetTo()
switch toNode.GetNodeType() {
case weightedGraph.SpecificType: // Direct User Reference. To() -> "user"
// req.User will always be either a UserRefObject or UserRefTypedWildcard here. Queries that come in for
// pure usersets do not take this code path, e.g. ListObjects(document, viewer, team:fga#member) will not make it here.
var userID string
val, ok := req.User.(*UserRefObject)
if ok {
userID = val.Object.GetId()
} else {
// It might be a wildcard user, which is ok
_, ok = req.User.(*UserRefTypedWildcard)
if !ok {
return nil, fmt.Errorf("unexpected user type when building User filter: %T", val)
}
return []*openfgav1.ObjectRelation{}, nil
}
filter = &openfgav1.ObjectRelation{Object: tuple.BuildObject(toNode.GetUniqueLabel(), userID)}
case weightedGraph.SpecificTypeWildcard: // Wildcard Reference. To() -> "user:*"
filter = &openfgav1.ObjectRelation{Object: toNode.GetUniqueLabel()}
default:
// Defensive: a direct-edge leaf should only ever be a specific type or a typed wildcard.
return nil, fmt.Errorf("unexpected node type when building user filter: %v", toNode.GetNodeType())
}
}
return []*openfgav1.ObjectRelation{filter}, nil
}
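// A minimal sketch (hypothetical stack state) of the found-object mode above: once
// an object has been found, the filter targets that object plus the userset
// relation recorded on the top stack entry, e.g. team:fga#member.
func exampleBuildUserFilterForFoundObject() ([]*openfgav1.ObjectRelation, error) {
req := &ReverseExpandRequest{
relationStack: stack.Push(nil, typeRelEntry{typeRel: "org#teammate", usersetRelation: "member"}),
}
// Returns a single filter: {Object: "team:fga", Relation: "member"}.
return buildUserFilter(req, "team:fga")
}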
// buildFilteredIterator constructs the iterator used when reverse_expand queries for tuples.
// The returned iterator MUST have .Stop() called on it.
func (c *ReverseExpandQuery) buildFilteredIterator(
ctx context.Context,
req *ReverseExpandRequest,
objectType string,
relation string,
userFilter []*openfgav1.ObjectRelation,
) (storage.TupleKeyIterator, error) {
iter, err := c.datastore.ReadStartingWithUser(ctx, req.StoreID, storage.ReadStartingWithUserFilter{
ObjectType: objectType,
Relation: relation,
UserFilter: userFilter,
}, storage.ReadStartingWithUserOptions{
Consistency: storage.ConsistencyOptions{
Preference: req.Consistency,
},
})
if err != nil {
return nil, err
}
// filter out invalid tuples yielded by the database iterator
return storage.NewConditionsFilteredTupleKeyIterator(
storage.NewFilteredTupleKeyIterator(
storage.NewTupleKeyIteratorFromTupleIterator(iter),
validation.FilterInvalidTuples(c.typesystem),
),
checkutil.BuildTupleKeyConditionFilter(ctx, req.Context, c.typesystem),
), nil
}
// findCandidatesForLowestWeightEdge finds the candidate objects for the lowest weight edge for intersection or exclusion.
func (c *ReverseExpandQuery) findCandidatesForLowestWeightEdge(
pool *concurrency.Pool,
req *ReverseExpandRequest,
tmpResultChan chan<- *ReverseExpandResult,
edge *weightedGraph.WeightedAuthorizationModelEdge,
sourceUserType string,
resolutionMetadata *ResolutionMetadata,
) {
// We need to create a new stack with the top item from the original request's stack
// and use it to get the candidates for the lowest weight edge.
// If the edge is a tuple to userset edge, we need to later check the candidates against the
// original relationStack with the top item removed.
var topItemStack stack.Stack[typeRelEntry]
if req.relationStack != nil {
topItem, newStack := stack.Pop(req.relationStack)
req.relationStack = newStack
topItemStack = stack.Push(nil, topItem)
}
edges, err := c.typesystem.GetInternalEdges(edge, sourceUserType)
if err != nil {
// Close the channel so the consumer ranging over tmpResultChan does not block forever.
close(tmpResultChan)
return
}
// Get list object candidates from the lowest weight edge and pass its results
// through tmpResultChan.
pool.Go(func(ctx context.Context) error {
defer close(tmpResultChan)
// stack with only the top item in it
newReq := req.clone()
newReq.relationStack = topItemStack
err := c.shallowClone().loopOverEdges(
ctx,
newReq,
edges,
false,
resolutionMetadata,
tmpResultChan,
sourceUserType,
)
return err
})
}
// checkCandidateInfo holds the information (req, userset, relation) needed to construct a check request on a candidate object.
type checkCandidateInfo struct {
req *ReverseExpandRequest
userset *openfgav1.Userset
relation string
isAllowed bool
resolutionMetadata *ResolutionMetadata
}
// callCheckForCandidate calls check on a single list objects candidate against the non-lowest-weight edges.
func (c *ReverseExpandQuery) callCheckForCandidate(
ctx context.Context,
tmpResult *ReverseExpandResult,
resultChan chan<- *ReverseExpandResult,
info checkCandidateInfo,
) error {
info.resolutionMetadata.CheckCounter.Add(1)
handlerFunc := c.localCheckResolver.CheckRewrite(ctx,
&graph.ResolveCheckRequest{
StoreID: info.req.StoreID,
AuthorizationModelID: c.typesystem.GetAuthorizationModelID(),
TupleKey: tuple.NewTupleKey(tmpResult.Object, info.relation, info.req.User.String()),
ContextualTuples: info.req.ContextualTuples,
Context: info.req.Context,
Consistency: info.req.Consistency,
RequestMetadata: graph.NewCheckRequestMetadata(),
}, info.userset)
tmpCheckResult, err := handlerFunc(ctx)
if err != nil {
operation := "intersection"
if !info.isAllowed {
operation = "exclusion"
}
return &ExecutionError{
operation: operation,
object: tmpResult.Object,
relation: info.relation,
user: info.req.User.String(),
cause: err,
}
}
// If the allowed value does not match what we expect, we skip this candidate.
// e.g., for intersection we expect the check result to be true
// and for exclusion we expect the check result to be false.
if tmpCheckResult.GetAllowed() != info.isAllowed {
return nil
}
// If the original stack only had 1 value, we can trySendCandidate right away (nothing more to check)
if stack.Len(info.req.relationStack) == 0 {
c.trySendCandidate(ctx, false, tmpResult.Object, resultChan)
return nil
}
// If the original stack had more than one value, we still need to query the parent
// relations, using the verified candidate as the found object.
err = c.queryForTuples(ctx, info.req, false, resultChan, tmpResult.Object)
if err != nil {
return err
}
return nil
}
// callCheckForCandidates calls check on the list objects candidates against the non-lowest-weight edges.
func (c *ReverseExpandQuery) callCheckForCandidates(
pool *concurrency.Pool,
tmpResultChan <-chan *ReverseExpandResult,
resultChan chan<- *ReverseExpandResult,
info checkCandidateInfo,
) {
pool.Go(func(ctx context.Context) error {
// note that we create a separate goroutine pool instead of the main pool
// to avoid starvation on the main pool as there could be many candidates
// arriving concurrently.
tmpResultPool := concurrency.NewPool(ctx, int(c.resolveNodeBreadthLimit))
for tmpResult := range tmpResultChan {
tmpResultPool.Go(func(ctx context.Context) error {
return c.callCheckForCandidate(ctx, tmpResult, resultChan, info)
})
}
return tmpResultPool.Wait()
})
}
// intersectionHandler invokes loopOverEdges to get list objects candidates. Check
// is then invoked on the non-lowest-weight edges against these candidates. If
// check returns true, the candidates are true candidates and are returned via
// resultChan. If check returns false, the candidates are invalid because they do
// not satisfy all paths of the intersection.
func (c *ReverseExpandQuery) intersectionHandler(
pool *concurrency.Pool,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
intersectionNode *weightedGraph.WeightedAuthorizationModelNode,
sourceUserType string,
resolutionMetadata *ResolutionMetadata,
) error {
if intersectionNode == nil || intersectionNode.GetNodeType() != weightedGraph.OperatorNode || intersectionNode.GetLabel() != weightedGraph.IntersectionOperator {
return fmt.Errorf("%w: operation: intersection: %s", errors.ErrUnsupported, "invalid intersection node")
}
// verify if the node has weight to the sourceUserType
edges, err := c.typesystem.GetEdgesFromNode(intersectionNode, sourceUserType)
if err != nil {
return err
}
// when the intersection node has a weight to the sourceUserType, all of its grouped edges have a weight to the sourceUserType
intersectionEdges, err := typesystem.GetEdgesForIntersection(edges, sourceUserType)
if err != nil {
return fmt.Errorf("%w: operation: intersection: %s", ErrLowestWeightFail, err.Error())
}
// Note that we should never see a case where there is no edge to run list objects on
// (a missing intersectionEdges.LowestEdge) or none to call check against
// (len(intersectionEdges.SiblingEdges) == 0), because typesystem.GetEdgesFromNode
// would have returned an error.
tmpResultChan := make(chan *ReverseExpandResult, listObjectsResultChannelLength)
intersectEdges := intersectionEdges.SiblingEdges
usersets := make([]*openfgav1.Userset, 0, len(intersectEdges))
// The check relation should be the same for all intersect edges;
// it is derived from the relation definition of each intersect edge.
checkRelation := ""
for _, intersectEdge := range intersectEdges {
// No matter how many direct or TTU edges are grouped under this edge, the typesystem
// creates only one edge per logical grouping (e.g. a single TTU edge for the same
// `X from parent` relation across all parent types). Logical groupings only occur
// for direct and TTU edges; every other case has exactly one edge.
userset, err := c.typesystem.ConstructUserset(intersectEdge, sourceUserType)
if err != nil {
// this should never happen
return fmt.Errorf("%w: operation: intersection: %s", ErrConstructUsersetFail, err.Error())
}
usersets = append(usersets, userset)
_, intersectRelation := tuple.SplitObjectRelation(intersectEdge.GetRelationDefinition())
if checkRelation != "" && checkRelation != intersectRelation {
// this should never happen
return fmt.Errorf("%w: operation: intersection: %s", errors.ErrUnsupported, "multiple relations in intersection is not supported")
}
checkRelation = intersectRelation
}
var userset *openfgav1.Userset
switch len(usersets) {
case 0:
return fmt.Errorf("%w: empty connected edges", ErrConstructUsersetFail) // defensive; should be handled by the early return above
case 1:
userset = usersets[0]
default:
userset = typesystem.Intersection(usersets...)
}
// Concurrently find candidates and call check on them as they are found
c.findCandidatesForLowestWeightEdge(pool, req, tmpResultChan, intersectionEdges.LowestEdge, sourceUserType, resolutionMetadata)
c.callCheckForCandidates(pool, tmpResultChan, resultChan,
checkCandidateInfo{req: req, userset: userset, relation: checkRelation, isAllowed: true, resolutionMetadata: resolutionMetadata})
return nil
}
// exclusionHandler invokes loopOverEdges to get list objects candidates. Check is
// then invoked on the excluded edge against these candidates. If check returns
// false, the candidates are true candidates and are returned via resultChan. If
// check returns true, the candidates are invalid because they do not satisfy the
// exclusion.
func (c *ReverseExpandQuery) exclusionHandler(
ctx context.Context,
pool *concurrency.Pool,
req *ReverseExpandRequest,
resultChan chan<- *ReverseExpandResult,
exclusionNode *weightedGraph.WeightedAuthorizationModelNode,
sourceUserType string,
resolutionMetadata *ResolutionMetadata,
) error {
if exclusionNode == nil || exclusionNode.GetNodeType() != weightedGraph.OperatorNode || exclusionNode.GetLabel() != weightedGraph.ExclusionOperator {
return fmt.Errorf("%w: operation: exclusion: %s", errors.ErrUnsupported, "invalid exclusion node")
}
// verify if the node has weight to the sourceUserType
exclusionEdges, err := c.typesystem.GetEdgesFromNode(exclusionNode, sourceUserType)
if err != nil {
return err
}
edges, err := typesystem.GetEdgesForExclusion(exclusionEdges, sourceUserType)
if err != nil {
return fmt.Errorf("%w: operation: exclusion: %s", ErrLowestWeightFail, err.Error())
}
// This means the exclusion edge does not have a path to the terminal type.
// e.g. `B` in `A but not B` is not relevant to this query.
if edges.ExcludedEdge == nil {
baseEdges, err := c.typesystem.GetInternalEdges(edges.BaseEdge, sourceUserType)
if err != nil {
return fmt.Errorf("%w: operation: exclusion: failed to get base edges: %s", ErrLowestWeightFail, err.Error())
}
newReq := req.clone()
return c.shallowClone().loopOverEdges(
ctx,
newReq,
baseEdges,
false,
resolutionMetadata,
resultChan,
sourceUserType,
)
}
tmpResultChan := make(chan *ReverseExpandResult, listObjectsResultChannelLength)
_, checkRelation := tuple.SplitObjectRelation(edges.ExcludedEdge.GetRelationDefinition())
userset, err := c.typesystem.ConstructUserset(edges.ExcludedEdge, sourceUserType)
if err != nil {
// This should never happen.
return fmt.Errorf("%w: operation: exclusion: %s", ErrConstructUsersetFail, err.Error())
}
// Concurrently find candidates and call check on them as they are found
c.findCandidatesForLowestWeightEdge(pool, req, tmpResultChan, edges.BaseEdge, sourceUserType, resolutionMetadata)
c.callCheckForCandidates(pool, tmpResultChan, resultChan,
checkCandidateInfo{req: req, userset: userset, relation: checkRelation, isAllowed: false, resolutionMetadata: resolutionMetadata})
return nil
}
package commands
import (
"context"
"errors"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// WriteCommand is used to Write and Delete tuples. Instances may be safely shared by multiple goroutines.
type WriteCommand struct {
logger logger.Logger
datastore storage.OpenFGADatastore
conditionContextByteLimit int
}
type WriteCommandOption func(*WriteCommand)
func WithWriteCmdLogger(l logger.Logger) WriteCommandOption {
return func(wc *WriteCommand) {
wc.logger = l
}
}
func WithConditionContextByteLimit(limit int) WriteCommandOption {
return func(wc *WriteCommand) {
wc.conditionContextByteLimit = limit
}
}
// NewWriteCommand creates a WriteCommand with specified storage.OpenFGADatastore to use for storage.
func NewWriteCommand(datastore storage.OpenFGADatastore, opts ...WriteCommandOption) *WriteCommand {
cmd := &WriteCommand{
datastore: datastore,
logger: logger.NewNoopLogger(),
conditionContextByteLimit: config.DefaultWriteContextByteLimit,
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
func parseOptionOnDuplicate(wr *openfgav1.WriteRequestWrites) (storage.OnDuplicateInsert, error) {
switch wr.GetOnDuplicate() {
case "", "error":
return storage.OnDuplicateInsertError, nil
case "ignore":
return storage.OnDuplicateInsertIgnore, nil
default:
return storage.OnDuplicateInsertError, serverErrors.ValidationError(fmt.Errorf("invalid on_duplicate option: %s", wr.GetOnDuplicate()))
}
}
func parseOptionOnMissing(wr *openfgav1.WriteRequestDeletes) (storage.OnMissingDelete, error) {
switch wr.GetOnMissing() {
case "", "error":
return storage.OnMissingDeleteError, nil
case "ignore":
return storage.OnMissingDeleteIgnore, nil
default:
return storage.OnMissingDeleteError, serverErrors.ValidationError(fmt.Errorf("invalid on_missing option: %s", wr.GetOnMissing()))
}
}
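// A minimal sketch (hypothetical request fragments; struct field names assumed from
// the generated getters) of the two option parsers: empty and "error" keep the
// strict behavior, while "ignore" tolerates duplicate writes and missing deletes.
func exampleParseWriteOptions() (storage.OnDuplicateInsert, storage.OnMissingDelete, error) {
onDuplicate, err := parseOptionOnDuplicate(&openfgav1.WriteRequestWrites{OnDuplicate: "ignore"})
if err != nil {
return onDuplicate, storage.OnMissingDeleteError, err
}
onMissing, err := parseOptionOnMissing(&openfgav1.WriteRequestDeletes{OnMissing: "ignore"})
return onDuplicate, onMissing, err
}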
// Execute deletes and writes the specified tuples. Deletes are applied first, then writes.
func (c *WriteCommand) Execute(ctx context.Context, req *openfgav1.WriteRequest) (*openfgav1.WriteResponse, error) {
if err := c.validateWriteRequest(ctx, req); err != nil {
return nil, err
}
onDuplicateInsert, err := parseOptionOnDuplicate(req.GetWrites())
if err != nil {
return nil, err
}
onMissingDelete, err := parseOptionOnMissing(req.GetDeletes())
if err != nil {
return nil, err
}
err = c.datastore.Write(
ctx,
req.GetStoreId(),
req.GetDeletes().GetTupleKeys(),
req.GetWrites().GetTupleKeys(),
storage.WithOnMissingDelete(onMissingDelete),
storage.WithOnDuplicateInsert(onDuplicateInsert),
)
if err != nil {
if errors.Is(err, storage.ErrTransactionalWriteFailed) {
return nil, status.Error(codes.Aborted, err.Error())
}
if errors.Is(err, storage.ErrInvalidWriteInput) {
return nil, serverErrors.WriteFailedDueToInvalidInput(err)
}
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.WriteResponse{}, nil
}
func (c *WriteCommand) validateWriteRequest(ctx context.Context, req *openfgav1.WriteRequest) error {
ctx, span := tracer.Start(ctx, "validateWriteRequest")
defer span.End()
store := req.GetStoreId()
modelID := req.GetAuthorizationModelId()
deletes := req.GetDeletes().GetTupleKeys()
writes := req.GetWrites().GetTupleKeys()
if len(deletes) == 0 && len(writes) == 0 {
return serverErrors.ErrInvalidWriteInput
}
if len(writes) > 0 {
authModel, err := c.datastore.ReadAuthorizationModel(ctx, store, modelID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return serverErrors.AuthorizationModelNotFound(modelID)
}
return serverErrors.HandleError("", err)
}
if !typesystem.IsSchemaVersionSupported(authModel.GetSchemaVersion()) {
return serverErrors.ValidationError(typesystem.ErrInvalidSchemaVersion)
}
typesys, err := typesystem.New(authModel)
if err != nil {
return err
}
for _, tk := range writes {
err := validation.ValidateTupleForWrite(typesys, tk)
if err != nil {
return serverErrors.ValidationError(err)
}
err = c.validateNotImplicit(tk)
if err != nil {
return err
}
contextSize := proto.Size(tk.GetCondition().GetContext())
if contextSize > c.conditionContextByteLimit {
return serverErrors.ValidationError(&tupleUtils.InvalidTupleError{
Cause: fmt.Errorf("condition context size limit exceeded: %d bytes exceeds %d bytes", contextSize, c.conditionContextByteLimit),
TupleKey: tk,
})
}
}
}
for _, tk := range deletes {
// TODO validate relation format and object format
if ok := tupleUtils.IsValidUser(tk.GetUser()); !ok {
return serverErrors.ValidationError(
&tupleUtils.InvalidTupleError{
Cause: fmt.Errorf("the 'user' field is malformed"),
TupleKey: tk,
},
)
}
}
if err := c.validateNoDuplicatesAndCorrectSize(deletes, writes); err != nil {
return err
}
return nil
}
// validateNoDuplicatesAndCorrectSize ensures the deletes and writes contain no duplicates
// and that their total count fits within the datastore's per-write limit.
func (c *WriteCommand) validateNoDuplicatesAndCorrectSize(
deletes []*openfgav1.TupleKeyWithoutCondition,
writes []*openfgav1.TupleKey,
) error {
tuples := map[string]struct{}{}
for _, tk := range deletes {
key := tupleUtils.TupleKeyToString(tk)
if _, ok := tuples[key]; ok {
return serverErrors.DuplicateTupleInWrite(tk)
}
tuples[key] = struct{}{}
}
for _, tk := range writes {
key := tupleUtils.TupleKeyToString(tk)
if _, ok := tuples[key]; ok {
return serverErrors.DuplicateTupleInWrite(tk)
}
tuples[key] = struct{}{}
}
if len(tuples) > c.datastore.MaxTuplesPerWrite() {
return serverErrors.ExceededEntityLimit("write operations", c.datastore.MaxTuplesPerWrite())
}
return nil
}
// validateNotImplicit ensures the tuple to be written (not deleted) is not of the form `object:id # relation @ object:id#relation`.
func (c *WriteCommand) validateNotImplicit(
tk *openfgav1.TupleKey,
) error {
userObject, userRelation := tupleUtils.SplitObjectRelation(tk.GetUser())
if tk.GetRelation() == userRelation && tk.GetObject() == userObject {
return serverErrors.ValidationError(&tupleUtils.InvalidTupleError{
Cause: fmt.Errorf("cannot write a tuple that is implicit"),
TupleKey: tk,
})
}
return nil
}
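// A minimal sketch (hypothetical tuples) of what validateNotImplicit rejects: a
// tuple whose user is the tuple's own object#relation adds no information.
func exampleValidateNotImplicit(c *WriteCommand) (rejected, accepted error) {
// Rejected: document:1#viewer@document:1#viewer is implicit.
rejected = c.validateNotImplicit(tupleUtils.NewTupleKey("document:1", "viewer", "document:1#viewer"))
// Accepted: a concrete user is fine.
accepted = c.validateNotImplicit(tupleUtils.NewTupleKey("document:1", "viewer", "user:anne"))
return rejected, accepted
}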
package commands
import (
"context"
"errors"
"google.golang.org/protobuf/proto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/validation"
"github.com/openfga/openfga/pkg/logger"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// DefaultMaxAssertionSizeInBytes is 64KB because MySQL supports up to 64 KB in one BLOB.
// In the future we may want to make it a LONGBLOB (4 GB) and/or make this value configurable
// based on the datastore.
var DefaultMaxAssertionSizeInBytes = 64000 // 64KB
type WriteAssertionsCommand struct {
datastore storage.OpenFGADatastore
logger logger.Logger
maxAssertionSizeInBytes int
}
type WriteAssertionsCmdOption func(*WriteAssertionsCommand)
func WithWriteAssertCmdLogger(l logger.Logger) WriteAssertionsCmdOption {
return func(c *WriteAssertionsCommand) {
c.logger = l
}
}
func NewWriteAssertionsCommand(
datastore storage.OpenFGADatastore, opts ...WriteAssertionsCmdOption) *WriteAssertionsCommand {
cmd := &WriteAssertionsCommand{
datastore: datastore,
logger: logger.NewNoopLogger(),
maxAssertionSizeInBytes: DefaultMaxAssertionSizeInBytes,
}
for _, opt := range opts {
opt(cmd)
}
return cmd
}
func (w *WriteAssertionsCommand) Execute(ctx context.Context, req *openfgav1.WriteAssertionsRequest) (*openfgav1.WriteAssertionsResponse, error) {
store := req.GetStoreId()
modelID := req.GetAuthorizationModelId()
assertions := req.GetAssertions()
model, err := w.datastore.ReadAuthorizationModel(ctx, store, modelID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, serverErrors.AuthorizationModelNotFound(req.GetAuthorizationModelId())
}
return nil, serverErrors.HandleError("", err)
}
if !typesystem.IsSchemaVersionSupported(model.GetSchemaVersion()) {
return nil, serverErrors.ValidationError(typesystem.ErrInvalidSchemaVersion)
}
typesys, err := typesystem.New(model)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
assertionSizeInBytes := 0
for _, assertion := range assertions {
assertionSizeInBytes += proto.Size(assertion)
}
if assertionSizeInBytes > w.maxAssertionSizeInBytes {
return nil, serverErrors.ExceededEntityLimit("bytes", w.maxAssertionSizeInBytes)
}
for _, assertion := range assertions {
// an assertion should be validated the same as the input tuple key to a Check request
if err := validation.ValidateUserObjectRelation(typesys, tupleUtils.ConvertAssertionTupleKeyToTupleKey(assertion.GetTupleKey())); err != nil {
return nil, serverErrors.ValidationError(err)
}
for _, ct := range assertion.GetContextualTuples() {
// but contextual tuples need to be validated the same as an input to a Write Tuple request
if err = validation.ValidateTupleForWrite(typesys, ct); err != nil {
return nil, serverErrors.ValidationError(err)
}
}
}
err = w.datastore.WriteAssertions(ctx, store, modelID, assertions)
if err != nil {
return nil, serverErrors.HandleError("", err)
}
return &openfgav1.WriteAssertionsResponse{}, nil
}
package commands
import (
"context"
"fmt"
"github.com/oklog/ulid/v2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/typesystem"
)
// WriteAuthorizationModelCommand performs updates of the store authorization model.
type WriteAuthorizationModelCommand struct {
backend storage.TypeDefinitionWriteBackend
logger logger.Logger
maxAuthorizationModelSizeInBytes int
}
type WriteAuthModelOption func(*WriteAuthorizationModelCommand)
func WithWriteAuthModelLogger(l logger.Logger) WriteAuthModelOption {
return func(m *WriteAuthorizationModelCommand) {
m.logger = l
}
}
func WithWriteAuthModelMaxSizeInBytes(size int) WriteAuthModelOption {
return func(m *WriteAuthorizationModelCommand) {
m.maxAuthorizationModelSizeInBytes = size
}
}
func NewWriteAuthorizationModelCommand(backend storage.TypeDefinitionWriteBackend, opts ...WriteAuthModelOption) *WriteAuthorizationModelCommand {
model := &WriteAuthorizationModelCommand{
backend: backend,
logger: logger.NewNoopLogger(),
maxAuthorizationModelSizeInBytes: serverconfig.DefaultMaxAuthorizationModelSizeInBytes,
}
for _, opt := range opts {
opt(model)
}
return model
}
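// A minimal sketch (hypothetical size) of constructing the command with a custom
// model size limit via the functional options above.
func exampleNewWriteAuthorizationModelCommand(backend storage.TypeDefinitionWriteBackend) *WriteAuthorizationModelCommand {
return NewWriteAuthorizationModelCommand(backend, WithWriteAuthModelMaxSizeInBytes(512*1024))
}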
// Execute the command using the supplied request.
func (w *WriteAuthorizationModelCommand) Execute(ctx context.Context, req *openfgav1.WriteAuthorizationModelRequest) (*openfgav1.WriteAuthorizationModelResponse, error) {
// Until this is solved: https://github.com/envoyproxy/protoc-gen-validate/issues/74
if len(req.GetTypeDefinitions()) > w.backend.MaxTypesPerAuthorizationModel() {
return nil, serverErrors.ExceededEntityLimit("type definitions in an authorization model", w.backend.MaxTypesPerAuthorizationModel())
}
// Fill in the schema version for old requests, which don't contain it, while we migrate to the new schema version.
if req.GetSchemaVersion() == "" {
req.SchemaVersion = typesystem.SchemaVersion1_1
}
model := &openfgav1.AuthorizationModel{
Id: ulid.Make().String(),
SchemaVersion: req.GetSchemaVersion(),
TypeDefinitions: req.GetTypeDefinitions(),
Conditions: req.GetConditions(),
}
// Validate the size in bytes of the wire-format encoding of the authorization model.
modelSize := proto.Size(model)
if modelSize > w.maxAuthorizationModelSizeInBytes {
// TODO: consider returning serverErrors.ExceededEntityLimit here for consistency with other entity limits.
return nil, status.Error(
codes.Code(openfgav1.ErrorCode_exceeded_entity_limit),
fmt.Sprintf("model exceeds size limit: %d bytes vs %d bytes", modelSize, w.maxAuthorizationModelSizeInBytes),
)
}
_, err := typesystem.NewAndValidate(ctx, model)
if err != nil {
return nil, serverErrors.InvalidAuthorizationModelInput(err)
}
err = w.backend.WriteAuthorizationModel(ctx, req.GetStoreId(), model)
if err != nil {
return nil, serverErrors.
HandleError("Error writing authorization model configuration", err)
}
return &openfgav1.WriteAuthorizationModelResponse{
AuthorizationModelId: model.GetId(),
}, nil
}
package config
import (
"time"
)
type CacheSettings struct {
CheckCacheLimit uint32
CacheControllerEnabled bool
CacheControllerTTL time.Duration
CheckQueryCacheEnabled bool
CheckQueryCacheTTL time.Duration
CheckIteratorCacheEnabled bool
CheckIteratorCacheMaxResults uint32
CheckIteratorCacheTTL time.Duration
ListObjectsIteratorCacheEnabled bool
ListObjectsIteratorCacheMaxResults uint32
ListObjectsIteratorCacheTTL time.Duration
SharedIteratorEnabled bool
SharedIteratorLimit uint32
SharedIteratorTTL time.Duration
}
func NewDefaultCacheSettings() CacheSettings {
return CacheSettings{
CheckCacheLimit: DefaultCheckCacheLimit,
CacheControllerEnabled: DefaultCacheControllerEnabled,
CacheControllerTTL: DefaultCacheControllerTTL,
CheckQueryCacheEnabled: DefaultCheckQueryCacheEnabled,
CheckQueryCacheTTL: DefaultCheckQueryCacheTTL,
CheckIteratorCacheEnabled: DefaultCheckIteratorCacheEnabled,
CheckIteratorCacheMaxResults: DefaultCheckIteratorCacheMaxResults,
CheckIteratorCacheTTL: DefaultCheckIteratorCacheTTL,
ListObjectsIteratorCacheEnabled: DefaultListObjectsIteratorCacheEnabled,
ListObjectsIteratorCacheMaxResults: DefaultListObjectsIteratorCacheMaxResults,
ListObjectsIteratorCacheTTL: DefaultListObjectsIteratorCacheTTL,
SharedIteratorEnabled: DefaultSharedIteratorEnabled,
SharedIteratorLimit: DefaultSharedIteratorLimit,
SharedIteratorTTL: DefaultSharedIteratorTTL,
}
}
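// ShouldCreateNewCache reports whether any of the check or list-objects caches is
// enabled, i.e. whether a shared cache instance needs to be constructed at all.
// A worked sketch of how the predicates below combine (values illustrative):
//
//	s := NewDefaultCacheSettings() // CheckCacheLimit defaults to a value > 0
//	s.CheckQueryCacheEnabled = true
//	s.CacheControllerEnabled = true
//	_ = s.ShouldCacheCheckQueries()     // true: limit > 0 and query cache enabled
//	_ = s.ShouldCreateNewCache()        // true: at least one cache predicate holds
//	_ = s.ShouldCreateCacheController() // true: a cache exists and the controller is enabled
//
// Setting s.CheckCacheLimit = 0 makes all three report false again
// (list-objects iterator caching is off by default).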
func (c CacheSettings) ShouldCreateNewCache() bool {
return c.ShouldCacheCheckQueries() || c.ShouldCacheCheckIterators() || c.ShouldCacheListObjectsIterators()
}
func (c CacheSettings) ShouldCreateCacheController() bool {
return c.ShouldCreateNewCache() && c.CacheControllerEnabled
}
func (c CacheSettings) ShouldCacheCheckQueries() bool {
return c.CheckCacheLimit > 0 && c.CheckQueryCacheEnabled
}
func (c CacheSettings) ShouldCacheCheckIterators() bool {
return c.CheckCacheLimit > 0 && c.CheckIteratorCacheEnabled
}
func (c CacheSettings) ShouldCacheListObjectsIterators() bool {
return c.ListObjectsIteratorCacheEnabled && c.ListObjectsIteratorCacheMaxResults > 0
}
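// ShouldCreateShadowNewCache determines if a new shadow cache should be created.
// A shadow cache is created whenever the primary cache would be created.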
func (c CacheSettings) ShouldCreateShadowNewCache() bool {
return c.ShouldCreateNewCache()
}
// ShouldCreateShadowCacheController determines if a new shadow cache controller should be created.
// A shadow cache controller is created if the cache controller is enabled.
func (c CacheSettings) ShouldCreateShadowCacheController() bool {
return c.ShouldCreateCacheController()
}
// ShouldShadowCacheListObjectsIterators returns true if a shadow cache for list objects iterators should be created.
// A shadow cache for list objects iterators is created if list objects iterators caching is enabled.
func (c CacheSettings) ShouldShadowCacheListObjectsIterators() bool {
return c.ShouldCacheListObjectsIterators()
}
// Package config contains all knobs and defaults used to configure features of
// OpenFGA when running as a standalone server.
package config
import (
"errors"
"fmt"
"math"
"strconv"
"time"
"github.com/spf13/viper"
)
const (
DefaultMaxRPCMessageSizeInBytes = 512 * 1_024 // 512 KB
DefaultMaxTuplesPerWrite = 100
DefaultMaxTypesPerAuthorizationModel = 100
DefaultMaxAuthorizationModelSizeInBytes = 256 * 1_024
DefaultMaxAuthorizationModelCacheSize = 100000
DefaultChangelogHorizonOffset = 0
DefaultResolveNodeLimit = 25
DefaultResolveNodeBreadthLimit = 10
DefaultListObjectsDeadline = 3 * time.Second
DefaultListObjectsMaxResults = 1000
DefaultMaxConcurrentReadsForCheck = math.MaxUint32
DefaultMaxConcurrentReadsForListObjects = math.MaxUint32
DefaultListUsersDeadline = 3 * time.Second
DefaultListUsersMaxResults = 1000
DefaultMaxConcurrentReadsForListUsers = math.MaxUint32
DefaultWriteContextByteLimit = 32 * 1_024 // 32KB
DefaultCheckCacheLimit = 10000
DefaultCacheControllerEnabled = false
DefaultCacheControllerTTL = 10 * time.Second
DefaultCheckQueryCacheEnabled = false
DefaultCheckQueryCacheTTL = 10 * time.Second
DefaultCheckIteratorCacheEnabled = false
DefaultCheckIteratorCacheMaxResults = 10000
DefaultCheckIteratorCacheTTL = 10 * time.Second
DefaultListObjectsIteratorCacheEnabled = false
DefaultListObjectsIteratorCacheMaxResults = 10000
DefaultListObjectsIteratorCacheTTL = 10 * time.Second
DefaultListObjectsOptimizationsEnabled = false
DefaultCacheControllerConfigEnabled = false
DefaultCacheControllerConfigTTL = 10 * time.Second
DefaultShadowCheckResolverTimeout = 1 * time.Second
DefaultShadowListObjectsQueryTimeout = 1 * time.Second
DefaultShadowListObjectsQueryMaxDeltaItems = 100
// Care should be taken here: decreasing this value can cause API compatibility problems with Conditions.
DefaultMaxConditionEvaluationCost = 100
DefaultInterruptCheckFrequency = 100
DefaultCheckDispatchThrottlingEnabled = false
DefaultCheckDispatchThrottlingFrequency = 10 * time.Microsecond
DefaultCheckDispatchThrottlingDefaultThreshold = 100
DefaultCheckDispatchThrottlingMaxThreshold = 0 // 0 means use the default threshold as max
// Batch Check.
DefaultMaxChecksPerBatchCheck = 50
DefaultMaxConcurrentChecksPerBatchCheck = 50
DefaultListObjectsDispatchThrottlingEnabled = false
DefaultListObjectsDispatchThrottlingFrequency = 10 * time.Microsecond
DefaultListObjectsDispatchThrottlingDefaultThreshold = 100
DefaultListObjectsDispatchThrottlingMaxThreshold = 0 // 0 means use the default threshold as max
DefaultListUsersDispatchThrottlingEnabled = false
DefaultListUsersDispatchThrottlingFrequency = 10 * time.Microsecond
DefaultListUsersDispatchThrottlingDefaultThreshold = 100
DefaultListUsersDispatchThrottlingMaxThreshold = 0 // 0 means use the default threshold as max
DefaultRequestTimeout = 3 * time.Second
additionalUpstreamTimeout = 3 * time.Second
DefaultSharedIteratorEnabled = false
DefaultSharedIteratorLimit = 1000000
DefaultSharedIteratorTTL = 4 * time.Minute
DefaultSharedIteratorMaxAdmissionTime = 10 * time.Second
DefaultSharedIteratorMaxIdleTime = 1 * time.Second
DefaultPlannerEvictionThreshold = 0
DefaultPlannerCleanupInterval = 0
ExperimentalCheckOptimizations = "enable-check-optimizations"
ExperimentalListObjectsOptimizations = "enable-list-objects-optimizations"
ExperimentalAccessControlParams = "enable-access-control"
// Moving forward, all experimental flags should follow the naming convention below:
// 1. Avoid using enable/disable prefixes.
// 2. Flag names should have only numbers, letters and underscores.
ExperimentalShadowCheck = "shadow_check"
ExperimentalShadowListObjects = "shadow_list_objects"
ExperimentalDatastoreThrottling = "datastore_throttling"
)
type DatastoreMetricsConfig struct {
// Enabled enables export of the Datastore metrics.
Enabled bool
}
// DatastoreConfig defines OpenFGA server configurations for datastore specific settings.
type DatastoreConfig struct {
// Engine is the datastore engine to use (e.g. 'memory', 'postgres', 'mysql', 'sqlite')
Engine string
URI string `json:"-"` // private field, won't be logged
SecondaryURI string `json:"-"` // private field, won't be logged
Username string
Password string `json:"-"` // private field, won't be logged
SecondaryUsername string
SecondaryPassword string `json:"-"` // private field, won't be logged
// MaxCacheSize is the maximum number of authorization models that will be cached in memory.
MaxCacheSize int
// MaxOpenConns is the maximum number of open connections to the database.
MaxOpenConns int
// MinOpenConns is the minimum number of open connections to the database.
// This is only available for PostgreSQL.
MinOpenConns int
// MaxIdleConns is the maximum number of connections to the datastore in the idle connection
// pool. This is only used by non-PostgreSQL engines that go through sql.DB.
MaxIdleConns int
// MinIdleConns is the minimum number of connections to the datastore in the idle connection
// pool. This is only available for PostgreSQL.
MinIdleConns int
// ConnMaxIdleTime is the maximum amount of time a connection to the datastore may be idle.
ConnMaxIdleTime time.Duration
// ConnMaxLifetime is the maximum amount of time a connection to the datastore may be reused.
ConnMaxLifetime time.Duration
// Metrics is configuration for the Datastore metrics.
Metrics DatastoreMetricsConfig
}
// GRPCConfig defines OpenFGA server configurations for grpc server specific settings.
type GRPCConfig struct {
Addr string
TLS *TLSConfig
}
// HTTPConfig defines OpenFGA server configurations for HTTP server specific settings.
type HTTPConfig struct {
Enabled bool
Addr string
TLS *TLSConfig
// UpstreamTimeout is the timeout duration for proxying HTTP requests upstream
// to the grpc endpoint. It cannot be smaller than Config.ListObjectsDeadline.
UpstreamTimeout time.Duration
CORSAllowedOrigins []string
CORSAllowedHeaders []string
}
// TLSConfig defines configuration specific to Transport Layer Security (TLS) settings.
type TLSConfig struct {
Enabled bool
CertPath string `mapstructure:"cert"`
KeyPath string `mapstructure:"key"`
}
// AuthnConfig defines OpenFGA server configurations for authentication specific settings.
type AuthnConfig struct {
// Method is the authentication method that should be enforced (e.g. 'none', 'preshared',
// 'oidc')
Method string
*AuthnOIDCConfig `mapstructure:"oidc"`
*AuthnPresharedKeyConfig `mapstructure:"preshared"`
}
// AuthnOIDCConfig defines configurations for the 'oidc' method of authentication.
type AuthnOIDCConfig struct {
Issuer string
IssuerAliases []string
Subjects []string
Audience string
ClientIDClaims []string
}
// AuthnPresharedKeyConfig defines configurations for the 'preshared' method of authentication.
type AuthnPresharedKeyConfig struct {
// Keys define the preshared keys to verify authn tokens against.
Keys []string `json:"-"` // private field, won't be logged
}
// LogConfig defines OpenFGA server configurations for log specific settings. For production, we
// recommend using the 'json' log format.
type LogConfig struct {
// Format is the log format to use in the log output (e.g. 'text' or 'json')
Format string
// Level is the log level to use in the log output (e.g. 'none', 'debug', or 'info')
Level string
// Format of the timestamp in the log output (e.g. 'Unix'(default) or 'ISO8601')
TimestampFormat string
}
type TraceConfig struct {
Enabled bool
OTLP OTLPTraceConfig `mapstructure:"otlp"`
SampleRatio float64
ServiceName string
}
type OTLPTraceConfig struct {
Endpoint string
TLS OTLPTraceTLSConfig
}
type OTLPTraceTLSConfig struct {
Enabled bool
}
// PlaygroundConfig defines OpenFGA server configurations for the Playground specific settings.
type PlaygroundConfig struct {
Enabled bool
Port int
}
// ProfilerConfig defines server configurations specific to pprof profiling.
type ProfilerConfig struct {
Enabled bool
Addr string
}
// MetricConfig defines configurations for serving custom metrics from OpenFGA.
type MetricConfig struct {
Enabled bool
Addr string
EnableRPCHistograms bool
}
// CheckQueryCache defines configuration for caching when resolving Check requests.
type CheckQueryCache struct {
Enabled bool
TTL time.Duration
}
// CheckCacheConfig defines configuration for a cache that is shared across Check requests.
type CheckCacheConfig struct {
Limit uint32
}
// IteratorCacheConfig defines configuration to cache storage iterator results.
type IteratorCacheConfig struct {
Enabled bool
MaxResults uint32
TTL time.Duration
}
// SharedIteratorConfig defines configuration for sharing storage iterators across requests.
type SharedIteratorConfig struct {
Enabled bool
Limit uint32
}
// CacheControllerConfig defines configuration to manage cache invalidation dynamically by
// observing whether there are recent tuple changes to the specified store.
type CacheControllerConfig struct {
Enabled bool
TTL time.Duration
}
// DispatchThrottlingConfig defines configurations for dispatch throttling.
type DispatchThrottlingConfig struct {
Enabled bool
Frequency time.Duration
Threshold uint32
MaxThreshold uint32
}
// DatastoreThrottleConfig defines configurations for database throttling.
// A threshold <= 0 means DatastoreThrottling is not enabled.
type DatastoreThrottleConfig struct {
Threshold int
Duration time.Duration
}
// AccessControlConfig is the configuration for the access control feature.
type AccessControlConfig struct {
Enabled bool
StoreID string
ModelID string
}
type PlannerConfig struct {
EvictionThreshold time.Duration
CleanupInterval time.Duration
}
type Config struct {
// If you change any of these settings, please update the documentation at
// https://github.com/openfga/openfga.dev/blob/main/docs/content/intro/setup-openfga.mdx
// ListObjectsDeadline defines the maximum amount of time to accumulate ListObjects results
// before the server will respond. This is to protect the server from misuse of the
// ListObjects endpoints. It cannot be larger than HTTPConfig.UpstreamTimeout.
ListObjectsDeadline time.Duration
// ListObjectsMaxResults defines the maximum number of results to accumulate
// before the non-streaming ListObjects API will respond to the client.
// This is to protect the server from misuse of the ListObjects endpoints.
ListObjectsMaxResults uint32
// ListUsersDeadline defines the maximum amount of time to accumulate ListUsers results
// before the server will respond. This is to protect the server from misuse of the
// ListUsers endpoints. It cannot be larger than the configured server's request timeout (RequestTimeout or HTTPConfig.UpstreamTimeout).
ListUsersDeadline time.Duration
// ListUsersMaxResults defines the maximum number of results to accumulate
// before the non-streaming ListUsers API will respond to the client.
// This is to protect the server from misuse of the ListUsers endpoints.
ListUsersMaxResults uint32
// MaxTuplesPerWrite defines the maximum number of tuples per Write endpoint.
MaxTuplesPerWrite int
// MaxChecksPerBatchCheck defines the maximum number of checks
// that can be passed in each BatchCheck request.
MaxChecksPerBatchCheck uint32
// MaxConcurrentChecksPerBatchCheck defines the maximum number of checks
// that can be run simultaneously within a single BatchCheck request.
MaxConcurrentChecksPerBatchCheck uint32
// MaxTypesPerAuthorizationModel defines the maximum number of type definitions per
// authorization model for the WriteAuthorizationModel endpoint.
MaxTypesPerAuthorizationModel int
// MaxAuthorizationModelSizeInBytes defines the maximum size in bytes allowed for
// persisting an Authorization Model.
MaxAuthorizationModelSizeInBytes int
// MaxConcurrentReadsForListObjects defines the maximum number of concurrent database reads
// allowed in ListObjects queries
MaxConcurrentReadsForListObjects uint32
// MaxConcurrentReadsForCheck defines the maximum number of concurrent database reads allowed in
// Check queries
MaxConcurrentReadsForCheck uint32
// MaxConcurrentReadsForListUsers defines the maximum number of concurrent database reads
// allowed in ListUsers queries
MaxConcurrentReadsForListUsers uint32
// MaxConditionEvaluationCost defines the maximum cost for CEL condition evaluation before a request returns an error
MaxConditionEvaluationCost uint64
// ChangelogHorizonOffset is an offset in minutes from the current time. Changes that occur
// after this offset will not be included in the response of ReadChanges.
ChangelogHorizonOffset int
// Experimentals is a list of the experimental features to enable in the OpenFGA server.
Experimentals []string
// AccessControl is the configuration for the access control feature.
AccessControl AccessControlConfig
// ResolveNodeLimit indicates how deeply nested an authorization model can be before a query
// errors out.
ResolveNodeLimit uint32
// ResolveNodeBreadthLimit indicates how many nodes on a given level can be evaluated
// concurrently in a query
ResolveNodeBreadthLimit uint32
// RequestTimeout configures the request timeout. If both the HTTP upstream timeout and the
// request timeout are specified, the request timeout takes precedence.
RequestTimeout time.Duration
// ContextPropagationToDatastore enables propagation of a request's context to the datastore,
// so that API cancellation signals reach in-flight queries.
ContextPropagationToDatastore bool
Datastore DatastoreConfig
GRPC GRPCConfig
HTTP HTTPConfig
Authn AuthnConfig
Log LogConfig
Trace TraceConfig
Playground PlaygroundConfig
Profiler ProfilerConfig
Metrics MetricConfig
CheckCache CheckCacheConfig
CheckIteratorCache IteratorCacheConfig
CheckQueryCache CheckQueryCache
CacheController CacheControllerConfig
CheckDispatchThrottling DispatchThrottlingConfig
ListObjectsDispatchThrottling DispatchThrottlingConfig
ListUsersDispatchThrottling DispatchThrottlingConfig
CheckDatastoreThrottle DatastoreThrottleConfig
ListObjectsDatastoreThrottle DatastoreThrottleConfig
ListUsersDatastoreThrottle DatastoreThrottleConfig
ListObjectsIteratorCache IteratorCacheConfig
SharedIterator SharedIteratorConfig
Planner PlannerConfig
RequestDurationDatastoreQueryCountBuckets []string
RequestDurationDispatchCountBuckets []string
}
func (cfg *Config) Verify() error {
if err := cfg.VerifyServerSettings(); err != nil {
return err
}
return cfg.VerifyBinarySettings()
}
func (cfg *Config) VerifyServerSettings() error {
if err := cfg.verifyDeadline(); err != nil {
return err
}
if cfg.MaxConcurrentReadsForListUsers == 0 {
return fmt.Errorf("config 'maxConcurrentReadsForListUsers' cannot be 0")
}
if err := cfg.verifyRequestDurationDatastoreQueryCountBuckets(); err != nil {
return err
}
if err := cfg.verifyCacheConfig(); err != nil {
return err
}
if len(cfg.RequestDurationDispatchCountBuckets) == 0 {
return errors.New("request duration datastore dispatch count buckets must not be empty")
}
for _, val := range cfg.RequestDurationDispatchCountBuckets {
valInt, err := strconv.Atoi(val)
if err != nil || valInt < 0 {
return errors.New(
"request duration dispatch count bucket items must be non-negative integers",
)
}
}
err := cfg.VerifyDispatchThrottlingConfig()
if err != nil {
return err
}
err = cfg.VerifyDatastoreThrottlesConfig()
if err != nil {
return err
}
if cfg.ListObjectsDeadline < 0 {
return errors.New("listObjectsDeadline must be non-negative time duration")
}
if cfg.ListUsersDeadline < 0 {
return errors.New("listUsersDeadline must be non-negative time duration")
}
if cfg.MaxConditionEvaluationCost < 100 {
return errors.New("maxConditionsEvaluationCosts less than 100 can cause API compatibility problems with Conditions")
}
if cfg.Datastore.MaxOpenConns < cfg.Datastore.MinOpenConns {
return errors.New("datastore MaxOpenConns must not be less than datastore MinOpenConns")
}
if cfg.Datastore.MinOpenConns < cfg.Datastore.MinIdleConns {
return errors.New("datastore MinOpenConns must not be less than datastore MinIdleConns")
}
return nil
}
func (cfg *Config) VerifyBinarySettings() error {
if cfg.Log.Format != "text" && cfg.Log.Format != "json" {
return fmt.Errorf("config 'log.format' must be one of ['text', 'json']")
}
if cfg.Log.Level != "none" &&
cfg.Log.Level != "debug" &&
cfg.Log.Level != "info" &&
cfg.Log.Level != "warn" &&
cfg.Log.Level != "error" &&
cfg.Log.Level != "panic" &&
cfg.Log.Level != "fatal" {
return fmt.Errorf(
"config 'log.level' must be one of ['none', 'debug', 'info', 'warn', 'error', 'panic', 'fatal']",
)
}
if cfg.Log.Level == "none" {
fmt.Println("WARNING: Logging is not enabled. It is highly recommended to enable logging in production environments to avoid masking attacker operations.")
}
if cfg.Log.TimestampFormat != "Unix" && cfg.Log.TimestampFormat != "ISO8601" {
return fmt.Errorf("config 'log.TimestampFormat' must be one of ['Unix', 'ISO8601']")
}
if cfg.Playground.Enabled {
if !cfg.HTTP.Enabled {
return errors.New("the HTTP server must be enabled to run the openfga playground")
}
if cfg.Authn.Method != "none" && cfg.Authn.Method != "preshared" {
return errors.New("the playground only supports authn methods 'none' and 'preshared'")
}
}
if cfg.HTTP.TLS.Enabled {
if cfg.HTTP.TLS.CertPath == "" || cfg.HTTP.TLS.KeyPath == "" {
return errors.New("'http.tls.cert' and 'http.tls.key' configs must be set")
}
}
if cfg.GRPC.TLS.Enabled {
if cfg.GRPC.TLS.CertPath == "" || cfg.GRPC.TLS.KeyPath == "" {
return errors.New("'grpc.tls.cert' and 'grpc.tls.key' configs must be set")
}
}
if cfg.RequestTimeout < 0 {
return errors.New("requestTimeout must be a non-negative time duration")
}
if cfg.RequestTimeout == 0 && cfg.HTTP.Enabled && cfg.HTTP.UpstreamTimeout < 0 {
return errors.New("http.upstreamTimeout must be a non-negative time duration")
}
if viper.IsSet("cache.limit") && !viper.IsSet("checkCache.limit") {
fmt.Println("WARNING: flag `check-query-cache-limit` is deprecated. Please set --check-cache-limit instead.")
}
return nil
}
// DefaultContextTimeout returns the context timeout the runtime should apply.
// If RequestTimeout > 0, the timeout middleware enforces it, and the value returned
// here (RequestTimeout plus a small buffer) acts as a last resort.
// Otherwise, the HTTP upstream timeout is used when the HTTP server is enabled.
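// For example (values illustrative): with RequestTimeout = 3s this returns
// 3s + additionalUpstreamTimeout; with RequestTimeout = 0 and the HTTP server enabled
// with UpstreamTimeout = 5s it returns 5s; with neither set it returns 0, meaning no
// context timeout is applied at this layer.
//
//	cfg := DefaultConfig()
//	cfg.RequestTimeout = 0
//	_ = DefaultContextTimeout(cfg) // 5s, from the default HTTP upstream timeout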
func DefaultContextTimeout(config *Config) time.Duration {
if config.RequestTimeout > 0 {
return config.RequestTimeout + additionalUpstreamTimeout
}
if config.HTTP.Enabled && config.HTTP.UpstreamTimeout > 0 {
return config.HTTP.UpstreamTimeout
}
return 0
}
// VerifyDispatchThrottlingConfig ensures DispatchThrottlingConfigs are valid.
func (cfg *Config) VerifyDispatchThrottlingConfig() error {
if cfg.CheckDispatchThrottling.Enabled {
if cfg.CheckDispatchThrottling.Frequency <= 0 {
return errors.New("'checkDispatchThrottling.frequency' must be non-negative time duration")
}
if cfg.CheckDispatchThrottling.Threshold <= 0 {
return errors.New("'checkDispatchThrottling.threshold' must be non-negative integer")
}
if cfg.CheckDispatchThrottling.MaxThreshold != 0 && cfg.CheckDispatchThrottling.Threshold > cfg.CheckDispatchThrottling.MaxThreshold {
return errors.New("'checkDispatchThrottling.threshold' must be less than or equal to 'checkDispatchThrottling.maxThreshold' respectively")
}
}
if cfg.ListObjectsDispatchThrottling.Enabled {
if cfg.ListObjectsDispatchThrottling.Frequency <= 0 {
return errors.New("'listObjectsDispatchThrottling.frequency' must be non-negative time duration")
}
if cfg.ListObjectsDispatchThrottling.Threshold <= 0 {
return errors.New("'listObjectsDispatchThrottling.threshold' must be non-negative integer")
}
if cfg.ListObjectsDispatchThrottling.MaxThreshold != 0 && cfg.ListObjectsDispatchThrottling.Threshold > cfg.ListObjectsDispatchThrottling.MaxThreshold {
return errors.New("'listObjectsDispatchThrottling.threshold' must be less than or equal to 'listObjectsDispatchThrottling.maxThreshold'")
}
}
if cfg.ListUsersDispatchThrottling.Enabled {
if cfg.ListUsersDispatchThrottling.Frequency <= 0 {
return errors.New("'listUsersDispatchThrottling.frequency' must be non-negative time duration")
}
if cfg.ListUsersDispatchThrottling.Threshold <= 0 {
return errors.New("'listUsersDispatchThrottling.threshold' must be non-negative integer")
}
if cfg.ListUsersDispatchThrottling.MaxThreshold != 0 && cfg.ListUsersDispatchThrottling.Threshold > cfg.ListUsersDispatchThrottling.MaxThreshold {
return errors.New("'listUsersDispatchThrottling.threshold' must be less than or equal to 'listUsersDispatchThrottling.maxThreshold'")
}
}
return nil
}
// VerifyDatastoreThrottlesConfig ensures the datastore throttle configurations are valid.
func (cfg *Config) VerifyDatastoreThrottlesConfig() error {
if cfg.CheckDatastoreThrottle.Threshold > 0 && cfg.CheckDatastoreThrottle.Duration <= 0 {
return errors.New("'checkDatastoreThrottler.duration' must be greater than zero if threshold > 0")
}
if cfg.ListObjectsDatastoreThrottle.Threshold > 0 && cfg.ListObjectsDatastoreThrottle.Duration <= 0 {
return errors.New("'listObjectsDatastoreThrottler.duration' must be greater than zero if threshold > 0")
}
if cfg.ListUsersDatastoreThrottle.Threshold > 0 && cfg.ListUsersDatastoreThrottle.Duration <= 0 {
return errors.New("'listUsersDatastoreThrottler.duration' must be greater than zero if threshold > 0")
}
return nil
}
func (cfg *Config) verifyDeadline() error {
configuredTimeout := DefaultContextTimeout(cfg)
if cfg.ListObjectsDeadline > configuredTimeout {
return fmt.Errorf(
"configured request timeout (%s) cannot be lower than 'listObjectsDeadline' config (%s)",
configuredTimeout,
cfg.ListObjectsDeadline,
)
}
if cfg.ListUsersDeadline > configuredTimeout {
return fmt.Errorf(
"configured request timeout (%s) cannot be lower than 'listUsersDeadline' config (%s)",
configuredTimeout,
cfg.ListUsersDeadline,
)
}
return nil
}
func (cfg *Config) verifyRequestDurationDatastoreQueryCountBuckets() error {
if len(cfg.RequestDurationDatastoreQueryCountBuckets) == 0 {
return errors.New("request duration datastore query count buckets must not be empty")
}
for _, val := range cfg.RequestDurationDatastoreQueryCountBuckets {
valInt, err := strconv.Atoi(val)
if err != nil || valInt < 0 {
return errors.New(
"request duration datastore query count bucket items must be non-negative integers",
)
}
}
return nil
}
func (cfg *Config) verifyCacheConfig() error {
if cfg.CheckQueryCache.Enabled && cfg.CheckQueryCache.TTL <= 0 {
return errors.New("'checkQueryCache.ttl' must be greater than zero")
}
if cfg.CheckIteratorCache.Enabled {
if cfg.CheckIteratorCache.TTL <= 0 {
return errors.New("'checkIteratorCache.ttl' must be greater than zero")
}
if cfg.CheckIteratorCache.MaxResults <= 0 {
return errors.New("'checkIteratorCache.maxResults' must be greater than zero")
}
}
if cfg.ListObjectsIteratorCache.Enabled {
if cfg.ListObjectsIteratorCache.TTL <= 0 {
return errors.New("'listObjectsIteratorCache.ttl' must be greater than zero")
}
if cfg.ListObjectsIteratorCache.MaxResults <= 0 {
return errors.New("'listObjectsIteratorCache.maxResults' must be greater than zero")
}
}
if cfg.CacheController.Enabled && cfg.CacheController.TTL <= 0 {
return errors.New("'cacheController.ttl' must be greater than zero")
}
return nil
}
// MaxConditionEvaluationCost returns the maximum CEL condition evaluation cost,
// clamped so it is never lower than DefaultMaxConditionEvaluationCost.
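// For example (hypothetical values): with 'maxConditionEvaluationCost' set to 50 via
// configuration this returns 100 (the default floor); set to 500, it returns 500.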
func MaxConditionEvaluationCost() uint64 {
return max(DefaultMaxConditionEvaluationCost, viper.GetUint64("maxConditionEvaluationCost"))
}
// DefaultConfig is the OpenFGA server default configurations.
func DefaultConfig() *Config {
return &Config{
MaxTuplesPerWrite: DefaultMaxTuplesPerWrite,
MaxTypesPerAuthorizationModel: DefaultMaxTypesPerAuthorizationModel,
MaxAuthorizationModelSizeInBytes: DefaultMaxAuthorizationModelSizeInBytes,
MaxChecksPerBatchCheck: DefaultMaxChecksPerBatchCheck,
MaxConcurrentChecksPerBatchCheck: DefaultMaxConcurrentChecksPerBatchCheck,
MaxConcurrentReadsForCheck: DefaultMaxConcurrentReadsForCheck,
MaxConcurrentReadsForListObjects: DefaultMaxConcurrentReadsForListObjects,
MaxConcurrentReadsForListUsers: DefaultMaxConcurrentReadsForListUsers,
MaxConditionEvaluationCost: DefaultMaxConditionEvaluationCost,
ChangelogHorizonOffset: DefaultChangelogHorizonOffset,
ResolveNodeLimit: DefaultResolveNodeLimit,
ResolveNodeBreadthLimit: DefaultResolveNodeBreadthLimit,
Experimentals: []string{},
AccessControl: AccessControlConfig{Enabled: false, StoreID: "", ModelID: ""},
ListObjectsDeadline: DefaultListObjectsDeadline,
ListObjectsMaxResults: DefaultListObjectsMaxResults,
ListUsersMaxResults: DefaultListUsersMaxResults,
ListUsersDeadline: DefaultListUsersDeadline,
RequestDurationDatastoreQueryCountBuckets: []string{"50", "200"},
RequestDurationDispatchCountBuckets: []string{"50", "200"},
Datastore: DatastoreConfig{
Engine: "memory",
MaxCacheSize: DefaultMaxAuthorizationModelCacheSize,
MinIdleConns: 0,
MaxIdleConns: 10,
MinOpenConns: 0,
MaxOpenConns: 30,
},
GRPC: GRPCConfig{
Addr: "0.0.0.0:8081",
TLS: &TLSConfig{Enabled: false},
},
HTTP: HTTPConfig{
Enabled: true,
Addr: "0.0.0.0:8080",
TLS: &TLSConfig{Enabled: false},
UpstreamTimeout: 5 * time.Second,
CORSAllowedOrigins: []string{"*"},
CORSAllowedHeaders: []string{"*"},
},
Authn: AuthnConfig{
Method: "none",
AuthnPresharedKeyConfig: &AuthnPresharedKeyConfig{},
AuthnOIDCConfig: &AuthnOIDCConfig{},
},
Log: LogConfig{
Format: "text",
Level: "info",
TimestampFormat: "Unix",
},
Trace: TraceConfig{
Enabled: false,
OTLP: OTLPTraceConfig{
Endpoint: "0.0.0.0:4317",
TLS: OTLPTraceTLSConfig{
Enabled: false,
},
},
SampleRatio: 0.2,
ServiceName: "openfga",
},
Playground: PlaygroundConfig{
Enabled: true,
Port: 3000,
},
Profiler: ProfilerConfig{
Enabled: false,
Addr: ":3001",
},
Metrics: MetricConfig{
Enabled: true,
Addr: "0.0.0.0:2112",
EnableRPCHistograms: false,
},
CheckIteratorCache: IteratorCacheConfig{
Enabled: DefaultCheckIteratorCacheEnabled,
MaxResults: DefaultCheckIteratorCacheMaxResults,
TTL: DefaultCheckIteratorCacheTTL,
},
CheckQueryCache: CheckQueryCache{
Enabled: DefaultCheckQueryCacheEnabled,
TTL: DefaultCheckQueryCacheTTL,
},
CheckCache: CheckCacheConfig{
Limit: DefaultCheckCacheLimit,
},
SharedIterator: SharedIteratorConfig{
Enabled: DefaultSharedIteratorEnabled,
Limit: DefaultSharedIteratorLimit,
},
CacheController: CacheControllerConfig{
Enabled: DefaultCacheControllerConfigEnabled,
TTL: DefaultCacheControllerConfigTTL,
},
CheckDispatchThrottling: DispatchThrottlingConfig{
Enabled: DefaultCheckDispatchThrottlingEnabled,
Frequency: DefaultCheckDispatchThrottlingFrequency,
Threshold: DefaultCheckDispatchThrottlingDefaultThreshold,
MaxThreshold: DefaultCheckDispatchThrottlingMaxThreshold,
},
ListObjectsDispatchThrottling: DispatchThrottlingConfig{
Enabled: DefaultListObjectsDispatchThrottlingEnabled,
Frequency: DefaultListObjectsDispatchThrottlingFrequency,
Threshold: DefaultListObjectsDispatchThrottlingDefaultThreshold,
MaxThreshold: DefaultListObjectsDispatchThrottlingMaxThreshold,
},
ListUsersDispatchThrottling: DispatchThrottlingConfig{
Enabled: DefaultListUsersDispatchThrottlingEnabled,
Frequency: DefaultListUsersDispatchThrottlingFrequency,
Threshold: DefaultListUsersDispatchThrottlingDefaultThreshold,
MaxThreshold: DefaultListUsersDispatchThrottlingMaxThreshold,
},
ListObjectsIteratorCache: IteratorCacheConfig{
Enabled: DefaultListObjectsIteratorCacheEnabled,
MaxResults: DefaultListObjectsIteratorCacheMaxResults,
TTL: DefaultListObjectsIteratorCacheTTL,
},
CheckDatastoreThrottle: DatastoreThrottleConfig{
Threshold: 0,
Duration: 0,
},
ListObjectsDatastoreThrottle: DatastoreThrottleConfig{
Threshold: 0,
Duration: 0,
},
ListUsersDatastoreThrottle: DatastoreThrottleConfig{
Threshold: 0,
Duration: 0,
},
RequestTimeout: DefaultRequestTimeout,
ContextPropagationToDatastore: false,
Planner: PlannerConfig{
EvictionThreshold: DefaultPlannerEvictionThreshold,
CleanupInterval: DefaultPlannerCleanupInterval,
},
}
}
// MustDefaultConfig returns the default server config with the playground and metrics turned off.
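// A minimal override sketch (the DSN below is a placeholder, not a working value):
//
//	cfg := MustDefaultConfig()
//	cfg.Datastore.Engine = "postgres"
//	cfg.Datastore.URI = "postgres://user:password@localhost:5432/openfga"
//	if err := cfg.Verify(); err != nil {
//		// handle invalid configuration
//	}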
func MustDefaultConfig() *Config {
config := DefaultConfig()
config.Playground.Enabled = false
config.Metrics.Enabled = false
return config
}
package errors
import (
"net/http"
"regexp"
"strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
const (
cFirstAuthenticationErrorCode int32 = 1000
cFirstValidationErrorCode int32 = 2000
cFirstThrottlingErrorCode int32 = 3500
cFirstInternalErrorCode int32 = 4000
cFirstUnknownEndpointErrorCode int32 = 5000
)
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
codeInt int32
}
// EncodedError is a customized error that carries a string code and an explicit HTTP status.
type EncodedError struct {
HTTPStatusCode int
GRPCStatusCode codes.Code
ActualError ErrorResponse
}
// Error returns the encoded message.
func (e *EncodedError) Error() string {
return e.ActualError.Message
}
// CodeValue returns the encoded code as an integer.
func (e *EncodedError) CodeValue() int32 {
return e.ActualError.codeInt
}
// HTTPStatus returns the HTTP Status code.
func (e *EncodedError) HTTPStatus() int {
return e.HTTPStatusCode
}
func (e *EncodedError) GRPCStatus() *status.Status {
return status.New(e.GRPCStatusCode, e.Error())
}
// Code returns the encoded code as a string.
func (e *EncodedError) Code() string {
return e.ActualError.Code
}
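// sanitizedMessage keeps only the text after the last "| caused by:" separator,
// rewrites "unexpected EOF" to "malformed JSON", and strips any
// "rpc error: code = ... desc = " fragment. For instance, a hypothetical input:
//
//	sanitizedMessage("foo | caused by: rpc error: code = Internal desc = unexpected EOF")
//
// yields just "malformed JSON".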
func sanitizedMessage(message string) string {
parsedMessages := strings.Split(message, "| caused by:")
lastMessage := parsedMessages[len(parsedMessages)-1]
lastMessage = strings.TrimSpace(lastMessage)
sanitizedErrorMessage := regexp.MustCompile(`unexpected EOF`).ReplaceAllString(lastMessage, "malformed JSON")
sanitizedErrorMessage = regexp.MustCompile(`rpc error: code = [a-zA-Z0-9\(\)]* desc = `).ReplaceAllString(sanitizedErrorMessage, "")
return strings.TrimSpace(strings.TrimPrefix(sanitizedErrorMessage, "proto:"))
}
// NewEncodedError builds an EncodedError whose HTTP and gRPC status codes are derived from the error code.
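// A usage sketch (openfgav1.ErrorCode_validation_error sits in the validation range):
//
//	e := NewEncodedError(int32(openfgav1.ErrorCode_validation_error), "invalid input")
//	_ = e.HTTPStatus() // 400: validation codes map to http.StatusBadRequest
//	_ = e.GRPCStatus() // carries codes.InvalidArgument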
func NewEncodedError(errorCode int32, message string) *EncodedError {
if !IsValidEncodedError(errorCode) {
if errorCode == int32(codes.Aborted) {
return &EncodedError{
HTTPStatusCode: http.StatusConflict,
GRPCStatusCode: codes.Aborted,
ActualError: ErrorResponse{
Code: codes.Aborted.String(),
Message: sanitizedMessage(message),
codeInt: errorCode,
},
}
}
return &EncodedError{
HTTPStatusCode: http.StatusInternalServerError,
GRPCStatusCode: codes.Internal,
ActualError: ErrorResponse{
Code: openfgav1.InternalErrorCode(errorCode).String(),
Message: sanitizedMessage(message),
codeInt: errorCode,
},
}
}
var httpStatusCode int
var grpcStatusCode codes.Code
var code string
switch {
case errorCode >= cFirstAuthenticationErrorCode && errorCode < cFirstValidationErrorCode:
httpStatusCode = http.StatusUnauthorized
code = openfgav1.AuthErrorCode(errorCode).String()
grpcStatusCode = codes.Unauthenticated
case errorCode >= cFirstValidationErrorCode && errorCode < cFirstThrottlingErrorCode:
httpStatusCode = http.StatusBadRequest
code = openfgav1.ErrorCode(errorCode).String()
grpcStatusCode = codes.InvalidArgument
case errorCode >= cFirstThrottlingErrorCode && errorCode < cFirstInternalErrorCode:
httpStatusCode = http.StatusUnprocessableEntity
code = openfgav1.UnprocessableContentErrorCode(errorCode).String()
grpcStatusCode = codes.ResourceExhausted
case errorCode >= cFirstInternalErrorCode && errorCode < cFirstUnknownEndpointErrorCode:
httpStatusCode = http.StatusInternalServerError
code = openfgav1.InternalErrorCode(errorCode).String()
grpcStatusCode = codes.Internal
default:
httpStatusCode = http.StatusNotFound
code = openfgav1.NotFoundErrorCode(errorCode).String()
grpcStatusCode = codes.NotFound
}
return &EncodedError{
HTTPStatusCode: httpStatusCode,
GRPCStatusCode: grpcStatusCode,
ActualError: ErrorResponse{
Code: code,
Message: sanitizedMessage(message),
codeInt: errorCode,
},
}
}
// IsValidEncodedError returns whether the error code is a valid encoded error.
func IsValidEncodedError(errorCode int32) bool {
return errorCode >= cFirstAuthenticationErrorCode
}
func getCustomizedErrorCode(field string, reason string) int32 {
switch field {
case "Assertions":
if strings.HasPrefix(reason, "value must contain no more than") {
return int32(openfgav1.ErrorCode_assertions_too_many_items)
}
case "AuthorizationModelId":
if strings.HasPrefix(reason, "value length must be at most") {
return int32(openfgav1.ErrorCode_authorization_model_id_too_long)
}
case "Base":
if strings.HasPrefix(reason, "value is required") {
return int32(openfgav1.ErrorCode_difference_base_missing_value)
}
case "Id":
if strings.HasPrefix(reason, "value length must be at most") {
return int32(openfgav1.ErrorCode_id_too_long)
}
case "Object":
if strings.HasPrefix(reason, "value length must be at most") {
return int32(openfgav1.ErrorCode_object_too_long)
}
case "PageSize":
if strings.HasPrefix(reason, "value must be inside range") {
return int32(openfgav1.ErrorCode_page_size_invalid)
}
case "Params":
if strings.HasPrefix(reason, "value is required") {
return int32(openfgav1.ErrorCode_param_missing_value)
}
case "Relation":
if strings.HasPrefix(reason, "value length must be at most") {
return int32(openfgav1.ErrorCode_relation_too_long)
}
case "Relations":
if strings.HasPrefix(reason, "value must contain at least") {
return int32(openfgav1.ErrorCode_relations_too_few_items)
}
case "Subtract":
if strings.HasPrefix(reason, "value is required") {
return int32(openfgav1.ErrorCode_subtract_base_missing_value)
}
case "StoreId":
if strings.HasPrefix(reason, "value length must be") {
return int32(openfgav1.ErrorCode_store_id_invalid_length)
}
case "TupleKey":
if strings.HasPrefix(reason, "value is required") {
return int32(openfgav1.ErrorCode_tuple_key_value_not_specified)
}
case "TupleKeys":
if strings.HasPrefix(reason, "value must contain between") {
return int32(openfgav1.ErrorCode_tuple_keys_too_many_or_too_few_items)
}
case "Type":
if strings.HasPrefix(reason, "value length must be at") {
return int32(openfgav1.ErrorCode_type_invalid_length)
}
if strings.HasPrefix(reason, "value does not match regex pattern") {
return int32(openfgav1.ErrorCode_type_invalid_pattern)
}
case "TypeDefinitions":
if strings.HasPrefix(reason, "value must contain at least") {
return int32(openfgav1.ErrorCode_type_definitions_too_few_items)
}
}
// Indexed fields such as "Relations[...]" need their own prefix and regex-pattern checks.
if strings.HasPrefix(field, "Relations[") {
if strings.HasPrefix(reason, "value length must be at most") {
return int32(openfgav1.ErrorCode_relations_too_long)
}
if strings.HasPrefix(reason, "value does not match regex pattern") {
return int32(openfgav1.ErrorCode_relations_invalid_pattern)
}
}
// When we get here, the field/reason pair is not one we recognize,
// so we return the generic validation error.
return int32(openfgav1.ErrorCode_validation_error)
}
func ConvertToEncodedErrorCode(statusError *status.Status) int32 {
code := int32(statusError.Code())
if code >= cFirstAuthenticationErrorCode {
return code
}
switch statusError.Code() {
case codes.OK:
return int32(codes.OK)
case codes.Unauthenticated:
return int32(openfgav1.AuthErrorCode_unauthenticated)
case codes.Canceled:
return int32(openfgav1.ErrorCode_cancelled)
case codes.Unknown:
// our implementation of InternalError does not carry a gRPC status code,
// so it surfaces as codes.Unknown; map it back to an internal error
return int32(openfgav1.InternalErrorCode_internal_error)
case codes.DeadlineExceeded:
return int32(openfgav1.InternalErrorCode_deadline_exceeded)
case codes.NotFound:
return int32(openfgav1.NotFoundErrorCode_undefined_endpoint)
case codes.AlreadyExists:
return int32(openfgav1.InternalErrorCode_already_exists)
case codes.ResourceExhausted:
return int32(openfgav1.InternalErrorCode_resource_exhausted)
case codes.FailedPrecondition:
return int32(openfgav1.InternalErrorCode_failed_precondition)
case codes.Aborted:
return int32(codes.Aborted)
case codes.OutOfRange:
return int32(openfgav1.InternalErrorCode_out_of_range)
case codes.Unimplemented:
return int32(openfgav1.NotFoundErrorCode_unimplemented)
case codes.Internal:
return int32(openfgav1.InternalErrorCode_internal_error)
case codes.Unavailable:
return int32(openfgav1.InternalErrorCode_unavailable)
case codes.DataLoss:
return int32(openfgav1.InternalErrorCode_data_loss)
case codes.InvalidArgument:
break
default:
// Unknown code - internal error
return int32(openfgav1.InternalErrorCode_internal_error)
}
// When we get here, the cause is InvalidArgument (likely flagged by the framework's validator).
// We will try to find out the actual cause if possible. Otherwise, the default response will
// be openfgav1.ErrorCode_validation_error
lastMessage := sanitizedMessage(statusError.Message())
lastMessageSplitted := strings.SplitN(lastMessage, ": ", 2)
if len(lastMessageSplitted) < 2 {
// The message doesn't have the expected "object.field: reason" shape.
// The safest thing is to return the generic validation error.
return int32(openfgav1.ErrorCode_validation_error)
}
errorObjectSplitted := strings.Split(lastMessageSplitted[0], ".")
if len(errorObjectSplitted) != 2 {
// The object part is not in the expected "Type.Field" form.
// Return the generic validation error.
return int32(openfgav1.ErrorCode_validation_error)
}
return getCustomizedErrorCode(errorObjectSplitted[1], lastMessageSplitted[1])
}
// Package errors contains custom error codes that are sent to clients.
package errors
import (
"context"
"errors"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
const InternalServerErrorMsg = "Internal Server Error"
var (
// ErrAuthorizationModelResolutionTooComplex is used to avoid stack overflows.
ErrAuthorizationModelResolutionTooComplex = status.Error(codes.Code(openfgav1.ErrorCode_authorization_model_resolution_too_complex), "Authorization Model resolution required too many rewrite rules to be resolved. Check your authorization model for infinite recursion or too much nesting")
ErrInvalidWriteInput = status.Error(codes.Code(openfgav1.ErrorCode_invalid_write_input), "Invalid input. Make sure you provide at least one write, or at least one delete")
ErrInvalidContinuationToken = status.Error(codes.Code(openfgav1.ErrorCode_invalid_continuation_token), "Invalid continuation token")
ErrInvalidStartTime = status.Error(codes.Code(openfgav1.ErrorCode_invalid_start_time), "Invalid start time")
ErrInvalidExpandInput = status.Error(codes.Code(openfgav1.ErrorCode_invalid_expand_input), "Invalid input. Make sure you provide an object and a relation")
ErrUnsupportedUserSet = status.Error(codes.Code(openfgav1.ErrorCode_unsupported_user_set), "Userset is not supported (right now)")
ErrStoreIDNotFound = status.Error(codes.Code(openfgav1.NotFoundErrorCode_store_id_not_found), "Store ID not found")
ErrMismatchObjectType = status.Error(codes.Code(openfgav1.ErrorCode_query_string_type_continuation_token_mismatch), "The type in the querystring and the continuation token don't match")
ErrRequestCancelled = status.Error(codes.Code(openfgav1.ErrorCode_cancelled), "Request Cancelled")
ErrRequestDeadlineExceeded = status.Error(codes.Code(openfgav1.InternalErrorCode_deadline_exceeded), "Request Deadline Exceeded")
ErrThrottledTimeout = status.Error(codes.Code(openfgav1.UnprocessableContentErrorCode_throttled_timeout_error), "timeout due to throttling on complex request")
// ErrTransactionThrottled can apply when a limit is hit at the database level.
ErrTransactionThrottled = status.Error(codes.ResourceExhausted, "transaction was throttled by the datastore")
)
type InternalError struct {
public error
internal error
}
func (e InternalError) Error() string {
// hide the internal error in the message
return e.public.Error()
}
// Unwrap is used by errors.Is and errors.As. It returns the underlying internal error.
func (e InternalError) Unwrap() error {
return e.internal
}
func (e InternalError) GRPCStatus() *status.Status {
st, ok := status.FromError(e.public)
if ok {
return st
}
return status.New(codes.Unknown, e.public.Error())
}
// NewInternalError returns an error that is decorated with a public-facing error message.
// It is only meant to be called by HandleError.
func NewInternalError(public string, internal error) InternalError {
if public == "" {
public = InternalServerErrorMsg
}
return InternalError{
public: status.Error(codes.Code(openfgav1.InternalErrorCode_internal_error), public),
internal: internal,
}
}
func ValidationError(cause error) error {
return status.Error(codes.Code(openfgav1.ErrorCode_validation_error), cause.Error())
}
func AssertionsNotForAuthorizationModelFound(modelID string) error {
return status.Error(codes.Code(openfgav1.ErrorCode_authorization_model_assertions_not_found), fmt.Sprintf("No assertions found for authorization model '%s'", modelID))
}
func AuthorizationModelNotFound(modelID string) error {
return status.Error(codes.Code(openfgav1.ErrorCode_authorization_model_not_found), fmt.Sprintf("Authorization Model '%s' not found", modelID))
}
func LatestAuthorizationModelNotFound(store string) error {
return status.Error(codes.Code(openfgav1.ErrorCode_latest_authorization_model_not_found), fmt.Sprintf("No authorization models found for store '%s'", store))
}
func TypeNotFound(objectType string) error {
return status.Error(codes.Code(openfgav1.ErrorCode_type_not_found), fmt.Sprintf("type '%s' not found", objectType))
}
func RelationNotFound(relation string, objectType string, tk *openfgav1.TupleKey) error {
msg := fmt.Sprintf("relation '%s#%s' not found", objectType, relation)
if tk != nil {
msg += fmt.Sprintf(" for tuple '%s'", tuple.TupleKeyToString(tk))
}
return status.Error(codes.Code(openfgav1.ErrorCode_relation_not_found), msg)
}
func ExceededEntityLimit(entity string, limit int) error {
return status.Error(codes.Code(openfgav1.ErrorCode_exceeded_entity_limit),
fmt.Sprintf("The number of %s exceeds the allowed limit of %d", entity, limit))
}
func DuplicateTupleInWrite(tk tuple.TupleWithoutCondition) error {
return status.Error(codes.Code(openfgav1.ErrorCode_cannot_allow_duplicate_tuples_in_one_request), fmt.Sprintf("duplicate tuple in write: user: '%s', relation: '%s', object: '%s'", tk.GetUser(), tk.GetRelation(), tk.GetObject()))
}
func WriteFailedDueToInvalidInput(err error) error {
return status.Error(codes.Code(openfgav1.ErrorCode_write_failed_due_to_invalid_input), err.Error())
}
func InvalidAuthorizationModelInput(err error) error {
return status.Error(codes.Code(openfgav1.ErrorCode_invalid_authorization_model), err.Error())
}
// HandleError is used to surface some errors, and hide others.
// Use `public` if you want to return a useful error message to the user.
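// A minimal call-pattern sketch (hypothetical datastore call):
//
//	if err := ds.Write(ctx, store, deletes, writes); err != nil {
//		// context cancellations and known storage errors map to stable codes;
//		// anything else is wrapped as an internal error
//		return HandleError("", err)
//	}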
func HandleError(public string, err error) error {
switch {
case errors.Is(err, storage.ErrTransactionThrottled):
return ErrTransactionThrottled
case errors.Is(err, context.Canceled):
// cancel by a client is not an "internal server error"
return ErrRequestCancelled
case errors.Is(err, context.DeadlineExceeded):
return ErrRequestDeadlineExceeded
case errors.Is(err, storage.ErrInvalidStartTime):
return ErrInvalidStartTime
case errors.Is(err, storage.ErrInvalidContinuationToken):
return ErrInvalidContinuationToken
default:
return NewInternalError(public, err)
}
}
// HandleTupleValidateError provides common handling for tuple validation errors.
func HandleTupleValidateError(err error) error {
switch t := err.(type) {
case *tuple.InvalidTupleError:
return status.Error(
codes.Code(openfgav1.ErrorCode_invalid_tuple),
fmt.Sprintf("Invalid tuple '%s'. Reason: %s", t.TupleKey, t.Cause.Error()),
)
case *tuple.TypeNotFoundError:
return TypeNotFound(t.TypeName)
case *tuple.RelationNotFoundError:
return RelationNotFound(t.Relation, t.TypeName, t.TupleKey)
case *tuple.InvalidConditionalTupleError:
return status.Error(
codes.Code(openfgav1.ErrorCode_validation_error),
err.Error(),
)
}
return HandleError("", err)
}
package server
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/typesystem"
)
func (s *Server) Expand(ctx context.Context, req *openfgav1.ExpandRequest) (*openfgav1.ExpandResponse, error) {
tk := req.GetTupleKey()
ctx, span := tracer.Start(ctx, apimethod.Expand.String(), trace.WithAttributes(
attribute.KeyValue{Key: "store_id", Value: attribute.StringValue(req.GetStoreId())},
attribute.KeyValue{Key: "object", Value: attribute.StringValue(tk.GetObject())},
attribute.KeyValue{Key: "relation", Value: attribute.StringValue(tk.GetRelation())},
attribute.KeyValue{Key: "consistency", Value: attribute.StringValue(req.GetConsistency().String())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.Expand.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.Expand)
if err != nil {
return nil, err
}
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
q := commands.NewExpandQuery(s.datastore, commands.WithExpandQueryLogger(s.logger))
return q.Execute(
typesystem.ContextWithTypesystem(ctx, typesys),
&openfgav1.ExpandRequest{
StoreId: storeID,
TupleKey: tk,
Consistency: req.GetConsistency(),
ContextualTuples: req.GetContextualTuples(),
})
}
// Package health contains the service that checks the health of an OpenFGA server.
package health
import (
"context"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"google.golang.org/grpc/codes"
healthv1pb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)
// TargetService defines an interface that services can implement for server health checks.
type TargetService interface {
IsReady(ctx context.Context) (bool, error)
}
type Checker struct {
healthv1pb.UnimplementedHealthServer
TargetService
TargetServiceName string
}
var _ grpcauth.ServiceAuthFuncOverride = (*Checker)(nil)
// AuthFuncOverride implements the grpc_auth.ServiceAuthFuncOverride interface by bypassing authn middleware.
func (o *Checker) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
return ctx, nil
}
func (o *Checker) Check(ctx context.Context, req *healthv1pb.HealthCheckRequest) (*healthv1pb.HealthCheckResponse, error) {
requestedService := req.GetService()
if requestedService == "" || requestedService == o.TargetServiceName {
ready, err := o.IsReady(ctx)
if err != nil {
return &healthv1pb.HealthCheckResponse{Status: healthv1pb.HealthCheckResponse_NOT_SERVING}, err
}
if !ready {
return &healthv1pb.HealthCheckResponse{Status: healthv1pb.HealthCheckResponse_NOT_SERVING}, nil
}
return &healthv1pb.HealthCheckResponse{Status: healthv1pb.HealthCheckResponse_SERVING}, nil
}
return nil, status.Errorf(codes.NotFound, "service '%s' is not registered with the Health server", requestedService)
}
func (o *Checker) Watch(req *healthv1pb.HealthCheckRequest, server healthv1pb.Health_WatchServer) error {
return status.Error(codes.Unimplemented, "unimplemented streaming endpoint")
}
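// A minimal wiring sketch (assumes an existing *grpc.Server named srv and a
// TargetService implementation named svc; the service name is illustrative):
//
//	healthv1pb.RegisterHealthServer(srv, &Checker{
//		TargetService:     svc,
//		TargetServiceName: "openfga.v1.OpenFGAService",
//	})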
package server
import (
"context"
"errors"
"time"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/typesystem"
)
func (s *Server) ListObjects(ctx context.Context, req *openfgav1.ListObjectsRequest) (*openfgav1.ListObjectsResponse, error) {
start := time.Now()
targetObjectType := req.GetType()
ctx, span := tracer.Start(ctx, apimethod.ListObjects.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.String("object_type", targetObjectType),
attribute.String("relation", req.GetRelation()),
attribute.String("user", req.GetUser()),
attribute.String("consistency", req.GetConsistency().String()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
// TODO: This should be apimethod.ListObjects, but is it considered a breaking change to move?
const methodName = "listobjects"
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: methodName,
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ListObjects)
if err != nil {
return nil, err
}
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
builder := s.getListObjectsCheckResolverBuilder(req.GetStoreId())
checkResolver, checkResolverCloser, err := builder.Build()
if err != nil {
return nil, err
}
defer checkResolverCloser()
q, err := commands.NewListObjectsQueryWithShadowConfig(
s.datastore,
checkResolver,
commands.NewShadowListObjectsQueryConfig(
commands.WithShadowListObjectsQueryEnabled(s.featureFlagClient.Boolean(serverconfig.ExperimentalShadowListObjects, req.GetStoreId())),
commands.WithShadowListObjectsQueryTimeout(s.shadowListObjectsQueryTimeout),
commands.WithShadowListObjectsQueryMaxDeltaItems(s.shadowListObjectsQueryMaxDeltaItems),
commands.WithShadowListObjectsQueryLogger(s.logger),
),
req.GetStoreId(),
commands.WithLogger(s.logger),
commands.WithListObjectsDeadline(s.listObjectsDeadline),
commands.WithListObjectsMaxResults(s.listObjectsMaxResults),
commands.WithDispatchThrottlerConfig(threshold.Config{
Throttler: s.listObjectsDispatchThrottler,
Enabled: s.listObjectsDispatchThrottlingEnabled,
Threshold: s.listObjectsDispatchDefaultThreshold,
MaxThreshold: s.listObjectsDispatchThrottlingMaxThreshold,
}),
commands.WithResolveNodeLimit(s.resolveNodeLimit),
commands.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
commands.WithMaxConcurrentReads(s.maxConcurrentReadsForListObjects),
commands.WithListObjectsCache(s.sharedDatastoreResources, s.cacheSettings),
commands.WithListObjectsDatastoreThrottler(
s.featureFlagClient.Boolean(serverconfig.ExperimentalDatastoreThrottling, storeID),
s.listObjectsDatastoreThrottleThreshold,
s.listObjectsDatastoreThrottleDuration,
),
commands.WithFeatureFlagClient(s.featureFlagClient),
)
if err != nil {
return nil, serverErrors.NewInternalError("", err)
}
result, err := q.Execute(
typesystem.ContextWithTypesystem(ctx, typesys),
&openfgav1.ListObjectsRequest{
StoreId: storeID,
ContextualTuples: req.GetContextualTuples(),
AuthorizationModelId: typesys.GetAuthorizationModelID(), // the resolved model id
Type: targetObjectType,
Relation: req.GetRelation(),
User: req.GetUser(),
Context: req.GetContext(),
Consistency: req.GetConsistency(),
},
)
if err != nil {
telemetry.TraceError(span, err)
if errors.Is(err, condition.ErrEvaluationFailed) {
return nil, serverErrors.ValidationError(err)
}
return nil, err
}
datastoreQueryCount := float64(result.ResolutionMetadata.DatastoreQueryCount.Load())
grpc_ctxtags.Extract(ctx).Set(datastoreQueryCountHistogramName, datastoreQueryCount)
span.SetAttributes(attribute.Float64(datastoreQueryCountHistogramName, datastoreQueryCount))
datastoreQueryCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreQueryCount)
datastoreItemCount := float64(result.ResolutionMetadata.DatastoreItemCount.Load())
grpc_ctxtags.Extract(ctx).Set(datastoreItemCountHistogramName, datastoreItemCount)
span.SetAttributes(attribute.Float64(datastoreItemCountHistogramName, datastoreItemCount))
datastoreItemCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreItemCount)
dispatchCount := float64(result.ResolutionMetadata.DispatchCounter.Load())
grpc_ctxtags.Extract(ctx).Set(dispatchCountHistogramName, dispatchCount)
span.SetAttributes(attribute.Float64(dispatchCountHistogramName, dispatchCount))
dispatchCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(dispatchCount)
requestDurationHistogram.WithLabelValues(
s.serviceName,
methodName,
utils.Bucketize(uint(datastoreQueryCount), s.requestDurationByQueryHistogramBuckets),
utils.Bucketize(uint(result.ResolutionMetadata.DispatchCounter.Load()), s.requestDurationByDispatchCountHistogramBuckets),
req.GetConsistency().String(),
).Observe(float64(time.Since(start).Milliseconds()))
wasRequestThrottled := result.ResolutionMetadata.WasThrottled.Load()
if wasRequestThrottled {
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Inc()
}
listObjectsOptimizationLabel := "non-weighted"
if result.ResolutionMetadata.WasWeightedGraphUsed.Load() {
listObjectsOptimizationLabel = "weighted"
}
listObjectsOptimizationCounter.WithLabelValues(listObjectsOptimizationLabel).Inc()
checkCounter := float64(result.ResolutionMetadata.CheckCounter.Load())
grpc_ctxtags.Extract(ctx).Set(listObjectsCheckCountName, checkCounter)
return &openfgav1.ListObjectsResponse{
Objects: result.Objects,
}, nil
}
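// A minimal sketch of calling ListObjects on a constructed *Server (the store,
// model, tuples, and IDs below are illustrative assumptions):
//
//	resp, err := s.ListObjects(ctx, &openfgav1.ListObjectsRequest{
//		StoreId:  storeID,
//		Type:     "document",
//		Relation: "viewer",
//		User:     "user:anne",
//	})
//	if err != nil {
//		// handle the error
//	}
//	_ = resp.GetObjects() // e.g. ["document:budget"]
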
func (s *Server) StreamedListObjects(req *openfgav1.StreamedListObjectsRequest, srv openfgav1.OpenFGAService_StreamedListObjectsServer) error {
start := time.Now()
ctx := srv.Context()
ctx, span := tracer.Start(ctx, apimethod.StreamedListObjects.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.String("object_type", req.GetType()),
attribute.String("relation", req.GetRelation()),
attribute.String("user", req.GetUser()),
attribute.String("consistency", req.GetConsistency().String()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
}
// TODO: This should be apimethod.StreamedListObjects, but is it considered a breaking change to move?
const methodName = "streamedlistobjects"
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: methodName,
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.StreamedListObjects)
if err != nil {
return err
}
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return err
}
builder := s.getListObjectsCheckResolverBuilder(storeID)
checkResolver, checkResolverCloser, err := builder.Build()
if err != nil {
return err
}
defer checkResolverCloser()
q, err := commands.NewListObjectsQueryWithShadowConfig(
s.datastore,
checkResolver,
commands.NewShadowListObjectsQueryConfig(
commands.WithShadowListObjectsQueryEnabled(s.featureFlagClient.Boolean(serverconfig.ExperimentalShadowListObjects, storeID)),
commands.WithShadowListObjectsQueryTimeout(s.shadowListObjectsQueryTimeout),
commands.WithShadowListObjectsQueryMaxDeltaItems(s.shadowListObjectsQueryMaxDeltaItems),
commands.WithShadowListObjectsQueryLogger(s.logger),
),
storeID,
commands.WithLogger(s.logger),
commands.WithListObjectsDeadline(s.listObjectsDeadline),
commands.WithDispatchThrottlerConfig(threshold.Config{
Throttler: s.listObjectsDispatchThrottler,
Enabled: s.listObjectsDispatchThrottlingEnabled,
Threshold: s.listObjectsDispatchDefaultThreshold,
MaxThreshold: s.listObjectsDispatchThrottlingMaxThreshold,
}),
commands.WithListObjectsMaxResults(s.listObjectsMaxResults),
commands.WithResolveNodeLimit(s.resolveNodeLimit),
commands.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
commands.WithMaxConcurrentReads(s.maxConcurrentReadsForListObjects),
commands.WithFeatureFlagClient(s.featureFlagClient),
)
if err != nil {
return serverErrors.NewInternalError("", err)
}
req.AuthorizationModelId = typesys.GetAuthorizationModelID() // the resolved model id
resolutionMetadata, err := q.ExecuteStreamed(
typesystem.ContextWithTypesystem(ctx, typesys),
req,
srv,
)
if err != nil {
telemetry.TraceError(span, err)
return err
}
datastoreQueryCount := float64(resolutionMetadata.DatastoreQueryCount.Load())
grpc_ctxtags.Extract(ctx).Set(datastoreQueryCountHistogramName, datastoreQueryCount)
span.SetAttributes(attribute.Float64(datastoreQueryCountHistogramName, datastoreQueryCount))
datastoreQueryCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreQueryCount)
datastoreItemCount := float64(resolutionMetadata.DatastoreItemCount.Load())
grpc_ctxtags.Extract(ctx).Set(datastoreItemCountHistogramName, datastoreItemCount)
span.SetAttributes(attribute.Float64(datastoreItemCountHistogramName, datastoreItemCount))
datastoreItemCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreItemCount)
dispatchCount := float64(resolutionMetadata.DispatchCounter.Load())
grpc_ctxtags.Extract(ctx).Set(dispatchCountHistogramName, dispatchCount)
span.SetAttributes(attribute.Float64(dispatchCountHistogramName, dispatchCount))
dispatchCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(dispatchCount)
requestDurationHistogram.WithLabelValues(
s.serviceName,
methodName,
utils.Bucketize(uint(datastoreQueryCount), s.requestDurationByQueryHistogramBuckets),
utils.Bucketize(uint(resolutionMetadata.DispatchCounter.Load()), s.requestDurationByDispatchCountHistogramBuckets),
req.GetConsistency().String(),
).Observe(float64(time.Since(start).Milliseconds()))
wasRequestThrottled := resolutionMetadata.WasThrottled.Load()
if wasRequestThrottled {
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Inc()
}
return nil
}
func (s *Server) getListObjectsCheckResolverBuilder(storeID string) *graph.CheckResolverOrderedBuilder {
checkCacheOptions, checkDispatchThrottlingOptions := s.getCheckResolverOptions()
return graph.NewOrderedCheckResolvers([]graph.CheckResolverOrderedBuilderOpt{
graph.WithLocalCheckerOpts([]graph.LocalCheckerOption{
graph.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
graph.WithOptimizations(s.featureFlagClient.Boolean(serverconfig.ExperimentalCheckOptimizations, storeID)),
graph.WithMaxResolutionDepth(s.resolveNodeLimit),
}...),
graph.WithCachedCheckResolverOpts(s.cacheSettings.ShouldCacheCheckQueries(), checkCacheOptions...),
graph.WithDispatchThrottlingCheckResolverOpts(s.checkDispatchThrottlingEnabled, checkDispatchThrottlingOptions...),
}...)
}
package server
import (
"context"
"errors"
"strings"
"time"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/throttler/threshold"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands/listusers"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/tuple"
"github.com/openfga/openfga/pkg/typesystem"
)
// ListUsers returns all users (e.g. subjects) that match the given user filter criteria
// and have a specific relation with some object.
func (s *Server) ListUsers(
ctx context.Context,
req *openfgav1.ListUsersRequest,
) (*openfgav1.ListUsersResponse, error) {
start := time.Now()
ctx, span := tracer.Start(ctx, apimethod.ListUsers.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.String("object", tuple.BuildObject(req.GetObject().GetType(), req.GetObject().GetId())),
attribute.String("relation", req.GetRelation()),
attribute.String("user_filters", userFiltersToString(req.GetUserFilters())),
attribute.String("consistency", req.GetConsistency().String()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
// TODO: This should be apimethod.ListUsers, but is it considered a breaking change to move?
const methodName = "listusers"
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: methodName,
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ListUsers)
if err != nil {
return nil, err
}
typesys, err := s.resolveTypesystem(ctx, req.GetStoreId(), req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
err = listusers.ValidateListUsersRequest(ctx, req, typesys)
if err != nil {
return nil, err
}
ctx = typesystem.ContextWithTypesystem(ctx, typesys)
listUsersQuery := listusers.NewListUsersQuery(s.datastore,
req.GetContextualTuples(),
listusers.WithResolveNodeLimit(s.resolveNodeLimit),
listusers.WithResolveNodeBreadthLimit(s.resolveNodeBreadthLimit),
listusers.WithListUsersQueryLogger(s.logger),
listusers.WithListUsersMaxResults(s.listUsersMaxResults),
listusers.WithListUsersDeadline(s.listUsersDeadline),
listusers.WithListUsersMaxConcurrentReads(s.maxConcurrentReadsForListUsers),
listusers.WithDispatchThrottlerConfig(threshold.Config{
Throttler: s.listUsersDispatchThrottler,
Enabled: s.listUsersDispatchThrottlingEnabled,
Threshold: s.listUsersDispatchDefaultThreshold,
MaxThreshold: s.listUsersDispatchThrottlingMaxThreshold,
}),
listusers.WithListUsersDatastoreThrottler(s.listUsersDatastoreThrottleThreshold, s.listUsersDatastoreThrottleDuration),
)
resp, err := listUsersQuery.ListUsers(ctx, req)
if err != nil {
telemetry.TraceError(span, err)
switch {
case errors.Is(err, graph.ErrResolutionDepthExceeded):
return nil, serverErrors.ErrAuthorizationModelResolutionTooComplex
case errors.Is(err, condition.ErrEvaluationFailed):
return nil, serverErrors.ValidationError(err)
default:
return nil, serverErrors.HandleError("", err)
}
}
datastoreQueryCount := float64(resp.Metadata.DatastoreQueryCount)
grpc_ctxtags.Extract(ctx).Set(datastoreQueryCountHistogramName, datastoreQueryCount)
span.SetAttributes(attribute.Float64(datastoreQueryCountHistogramName, datastoreQueryCount))
datastoreQueryCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreQueryCount)
datastoreItemCount := float64(resp.Metadata.DatastoreItemCount)
grpc_ctxtags.Extract(ctx).Set(datastoreItemCountHistogramName, datastoreItemCount)
span.SetAttributes(attribute.Float64(datastoreItemCountHistogramName, datastoreItemCount))
datastoreItemCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(datastoreItemCount)
dispatchCount := float64(resp.Metadata.DispatchCounter.Load())
grpc_ctxtags.Extract(ctx).Set(dispatchCountHistogramName, dispatchCount)
span.SetAttributes(attribute.Float64(dispatchCountHistogramName, dispatchCount))
dispatchCountHistogram.WithLabelValues(
s.serviceName,
methodName,
).Observe(dispatchCount)
requestDurationHistogram.WithLabelValues(
s.serviceName,
methodName,
utils.Bucketize(uint(datastoreQueryCount), s.requestDurationByQueryHistogramBuckets),
utils.Bucketize(uint(dispatchCount), s.requestDurationByDispatchCountHistogramBuckets),
req.GetConsistency().String(),
).Observe(float64(time.Since(start).Milliseconds()))
wasRequestThrottled := resp.GetMetadata().WasThrottled.Load()
if wasRequestThrottled {
throttledRequestCounter.WithLabelValues(s.serviceName, methodName).Inc()
}
return &openfgav1.ListUsersResponse{
Users: resp.GetUsers(),
}, nil
}
func userFiltersToString(filter []*openfgav1.UserTypeFilter) string {
var s strings.Builder
for _, f := range filter {
s.WriteString(f.GetType())
if f.GetRelation() != "" {
s.WriteString("#" + f.GetRelation())
}
}
return s.String()
}
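// For illustration, a sketch of the attribute value this produces for two filters
// (note the filters are concatenated without a separator):
//
//	userFiltersToString([]*openfgav1.UserTypeFilter{
//		{Type: "user"},
//		{Type: "group", Relation: "member"},
//	}) // "usergroup#member"
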
package server
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) Read(ctx context.Context, req *openfgav1.ReadRequest) (*openfgav1.ReadResponse, error) {
tk := req.GetTupleKey()
ctx, span := tracer.Start(ctx, apimethod.Read.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.KeyValue{Key: "object", Value: attribute.StringValue(tk.GetObject())},
attribute.KeyValue{Key: "relation", Value: attribute.StringValue(tk.GetRelation())},
attribute.KeyValue{Key: "user", Value: attribute.StringValue(tk.GetUser())},
attribute.KeyValue{Key: "consistency", Value: attribute.StringValue(req.GetConsistency().String())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.Read.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.Read)
if err != nil {
return nil, err
}
q := commands.NewReadQuery(s.datastore,
commands.WithReadQueryLogger(s.logger),
commands.WithReadQueryEncoder(s.encoder),
commands.WithReadQueryTokenSerializer(s.tokenSerializer),
)
return q.Execute(ctx, &openfgav1.ReadRequest{
StoreId: req.GetStoreId(),
TupleKey: tk,
PageSize: req.GetPageSize(),
ContinuationToken: req.GetContinuationToken(),
Consistency: req.GetConsistency(),
})
}
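// A sketch of paging through Read results via the continuation token
// (storeID and the tuple key filter tk are assumed to exist):
//
//	token := ""
//	for {
//		resp, err := s.Read(ctx, &openfgav1.ReadRequest{
//			StoreId:           storeID,
//			TupleKey:          tk,
//			ContinuationToken: token,
//		})
//		if err != nil {
//			break // handle the error
//		}
//		// consume resp.GetTuples()
//		if token = resp.GetContinuationToken(); token == "" {
//			break
//		}
//	}
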
package server
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) ReadChanges(ctx context.Context, req *openfgav1.ReadChangesRequest) (*openfgav1.ReadChangesResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.ReadChanges.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
attribute.KeyValue{Key: "type", Value: attribute.StringValue(req.GetType())},
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.ReadChanges.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.ReadChanges)
if err != nil {
return nil, err
}
q := commands.NewReadChangesQuery(s.datastore,
commands.WithReadChangesQueryLogger(s.logger),
commands.WithReadChangesQueryEncoder(s.encoder),
commands.WithContinuationTokenSerializer(s.tokenSerializer),
commands.WithReadChangeQueryHorizonOffset(s.changelogHorizonOffset),
)
return q.Execute(ctx, req)
}
// Package server contains the endpoint handlers.
package server
import (
"context"
"errors"
"fmt"
"sort"
"time"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/oklog/ulid/v2"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"golang.org/x/sync/singleflight"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/authz"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/internal/graph"
"github.com/openfga/openfga/internal/planner"
"github.com/openfga/openfga/internal/shared"
"github.com/openfga/openfga/internal/throttler"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/authclaims"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/featureflags"
"github.com/openfga/openfga/pkg/gateway"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers"
"github.com/openfga/openfga/pkg/telemetry"
"github.com/openfga/openfga/pkg/typesystem"
)
const (
AuthorizationModelIDHeader = "Openfga-Authorization-Model-Id"
authorizationModelIDKey = "authorization_model_id"
allowedLabel = "allowed"
)
var tracer = otel.Tracer("openfga/pkg/server")
var (
dispatchCountHistogramName = "dispatch_count"
dispatchCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: dispatchCountHistogramName,
Help: "The number of dispatches required to resolve a query (e.g. Check).",
Buckets: []float64{1, 5, 20, 50, 100, 150, 225, 400, 500, 750, 1000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"grpc_service", "grpc_method"})
datastoreQueryCountHistogramName = "datastore_query_count"
datastoreQueryCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: datastoreQueryCountHistogramName,
Help: "The number of database queries required to resolve a query (e.g. Check, ListObjects or ListUsers).",
Buckets: []float64{1, 5, 20, 50, 100, 150, 225, 400, 500, 750, 1000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"grpc_service", "grpc_method"})
datastoreItemCountHistogramName = "datastore_item_count"
datastoreItemCountHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: datastoreItemCountHistogramName,
Help: "The number of items returned from the database required to resolve a query (e.g. Check, ListObjects or ListUsers).",
Buckets: []float64{1, 5, 20, 50, 100, 150, 225, 400, 500, 750, 1000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"grpc_service", "grpc_method"})
requestDurationHistogramName = "request_duration_ms"
requestDurationHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: requestDurationHistogramName,
Help: "The request duration (in ms) labeled by method and buckets of datastore query counts and number of dispatches. This allows for reporting percentiles based on the number of datastore queries and number of dispatches required to resolve the request.",
Buckets: []float64{1, 5, 10, 25, 50, 80, 100, 150, 200, 300, 1000, 2000, 5000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"grpc_service", "grpc_method", "datastore_query_count", "dispatch_count", "consistency"})
listObjectsOptimizationCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "list_objects_optimization_count",
Help: "The total number of requests that have been processed by the weighted graph vs non-weighted graph.",
}, []string{"strategy"})
listObjectsCheckCountName = "check_count"
throttledRequestCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "throttled_requests_count",
Help: "The total number of requests that have been throttled.",
}, []string{"grpc_service", "grpc_method"})
checkResultCounterName = "check_result_count"
checkResultCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: checkResultCounterName,
Help: "The total number of check requests by response result",
}, []string{allowedLabel})
accessControlStoreCheckDurationHistogramName = "access_control_store_check_request_duration_ms"
accessControlStoreCheckDurationHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: accessControlStoreCheckDurationHistogramName,
Help: "The request duration (in ms) for access control store's check duration labeled by method and buckets of datastore query counts and number of dispatches.",
Buckets: []float64{1, 5, 10, 25, 50, 80, 100, 150, 200, 300, 1000, 2000, 5000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"datastore_query_count", "dispatch_count", "consistency"})
writeDurationHistogramName = "write_duration_ms"
writeDurationHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: writeDurationHistogramName,
Help: "The request duration (in ms) for write API labeled by whether an authorizer check is required or not.",
Buckets: []float64{1, 5, 10, 25, 50, 80, 100, 150, 200, 300, 1000, 2000, 5000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"require_authorize_check", "on_duplicate_write", "on_missing_delete"})
checkDurationHistogramName = "check_duration_ms"
checkDurationHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: checkDurationHistogramName,
Help: "The duration of check command resolution, labeled by parent_method and datastore_query_count (in buckets)",
Buckets: []float64{1, 5, 10, 25, 50, 80, 100, 150, 200, 300, 1000, 2000, 5000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"datastore_query_count", "caller"})
)
// A Server implements the OpenFGA service backend as both
// a gRPC and an HTTP server.
type Server struct {
openfgav1.UnimplementedOpenFGAServiceServer
logger logger.Logger
datastore storage.OpenFGADatastore
tokenSerializer encoder.ContinuationTokenSerializer
encoder encoder.Encoder
transport gateway.Transport
resolveNodeLimit uint32
resolveNodeBreadthLimit uint32
changelogHorizonOffset int
listObjectsDeadline time.Duration
listObjectsMaxResults uint32
listUsersDeadline time.Duration
listUsersMaxResults uint32
maxChecksPerBatchCheck uint32
maxConcurrentChecksPerBatch uint32
maxConcurrentReadsForListObjects uint32
maxConcurrentReadsForCheck uint32
maxConcurrentReadsForListUsers uint32
maxAuthorizationModelCacheSize int
maxAuthorizationModelSizeInBytes int
experimentals []string
AccessControl serverconfig.AccessControlConfig
AuthnMethod string
serviceName string
featureFlagClient featureflags.Client
// NOTE don't use this directly, use function resolveTypesystem. See https://github.com/openfga/openfga/issues/1527
typesystemResolver typesystem.TypesystemResolverFunc
typesystemResolverStop func()
// cacheSettings are given by the user
cacheSettings serverconfig.CacheSettings
// sharedDatastoreResources are created by the server
sharedDatastoreResources *shared.SharedDatastoreResources
shadowCheckResolverTimeout time.Duration
shadowListObjectsQueryTimeout time.Duration
shadowListObjectsQueryMaxDeltaItems int
requestDurationByQueryHistogramBuckets []uint
requestDurationByDispatchCountHistogramBuckets []uint
checkDispatchThrottlingEnabled bool
checkDispatchThrottlingFrequency time.Duration
checkDispatchThrottlingDefaultThreshold uint32
checkDispatchThrottlingMaxThreshold uint32
listObjectsDispatchThrottlingEnabled bool
listObjectsDispatchThrottlingFrequency time.Duration
listObjectsDispatchDefaultThreshold uint32
listObjectsDispatchThrottlingMaxThreshold uint32
listUsersDispatchThrottlingEnabled bool
listUsersDispatchThrottlingFrequency time.Duration
listUsersDispatchDefaultThreshold uint32
listUsersDispatchThrottlingMaxThreshold uint32
listObjectsDispatchThrottler throttler.Throttler
listUsersDispatchThrottler throttler.Throttler
checkDatastoreThrottleThreshold int
checkDatastoreThrottleDuration time.Duration
listObjectsDatastoreThrottleThreshold int
listObjectsDatastoreThrottleDuration time.Duration
listUsersDatastoreThrottleThreshold int
listUsersDatastoreThrottleDuration time.Duration
authorizer authz.AuthorizerInterface
ctx context.Context
contextPropagationToDatastore bool
// singleflightGroup can be shared across caches, deduplicators, etc.
singleflightGroup *singleflight.Group
planner *planner.Planner
requestTimeout time.Duration
}
type OpenFGAServiceV1Option func(s *Server)
// WithDatastore passes a datastore to the Server.
// You must call [storage.OpenFGADatastore.Close] on it after you have stopped using it.
func WithDatastore(ds storage.OpenFGADatastore) OpenFGAServiceV1Option {
return func(s *Server) {
s.datastore = ds
}
}
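// WithContinuationTokenSerializer sets the serializer used to encode and decode
// continuation tokens.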
func WithContinuationTokenSerializer(serializer encoder.ContinuationTokenSerializer) OpenFGAServiceV1Option {
return func(s *Server) {
s.tokenSerializer = serializer
}
}
// WithContext passes the server context to allow for graceful shutdowns.
func WithContext(ctx context.Context) OpenFGAServiceV1Option {
return func(s *Server) {
s.ctx = ctx
}
}
// WithAuthorizationModelCacheSize sets the maximum number of authorization models that will be cached in memory.
func WithAuthorizationModelCacheSize(maxAuthorizationModelCacheSize int) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxAuthorizationModelCacheSize = maxAuthorizationModelCacheSize
}
}
func WithLogger(l logger.Logger) OpenFGAServiceV1Option {
return func(s *Server) {
s.logger = l
}
}
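// WithTokenEncoder sets the encoder used for continuation tokens.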
func WithTokenEncoder(encoder encoder.Encoder) OpenFGAServiceV1Option {
return func(s *Server) {
s.encoder = encoder
}
}
// WithTransport sets the connection transport.
func WithTransport(t gateway.Transport) OpenFGAServiceV1Option {
return func(s *Server) {
s.transport = t
}
}
// WithResolveNodeLimit sets a limit on the number of recursive calls that one Check, ListObjects or ListUsers call will allow.
// Thinking of a request as a tree of evaluations, this option controls
// how many levels we will evaluate before throwing an error that the authorization model is too complex.
func WithResolveNodeLimit(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.resolveNodeLimit = limit
}
}
// WithResolveNodeBreadthLimit sets a limit on the number of goroutines that can be created
// when evaluating a subtree of a Check, ListObjects or ListUsers call.
// Thinking of a Check request as a tree of evaluations, this option controls,
// on a given level of the tree, the maximum number of nodes that can be evaluated concurrently (the breadth).
// If your authorization models are very complex (e.g. one relation is a union of many relations, or one relation
// is deeply nested), or if you have lots of users for (object, relation) pairs,
// you should set this option to be a low number (e.g. 1000).
func WithResolveNodeBreadthLimit(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.resolveNodeBreadthLimit = limit
}
}
// WithChangelogHorizonOffset sets an offset (in minutes) from the current time.
// Changes that occur after this offset will not be included in the response of the ReadChanges API.
// If your datastore is eventually consistent or if you have a database with replication delay, we recommend setting this (e.g. 1 minute).
func WithChangelogHorizonOffset(offset int) OpenFGAServiceV1Option {
return func(s *Server) {
s.changelogHorizonOffset = offset
}
}
// WithListObjectsDeadline affects the ListObjects and StreamedListObjects APIs only.
// It sets the maximum amount of time that the server will spend gathering results.
func WithListObjectsDeadline(deadline time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDeadline = deadline
}
}
// WithListObjectsMaxResults affects the ListObjects API only.
// It sets the maximum number of results that this API will return.
func WithListObjectsMaxResults(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsMaxResults = limit
}
}
// WithListUsersDeadline affects the ListUsers API only.
// It sets the maximum amount of time that the server will spend gathering results.
func WithListUsersDeadline(deadline time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDeadline = deadline
}
}
// WithListUsersMaxResults affects the ListUsers API only.
// It sets the maximum number of results that this API will return.
// If it is zero, the server attempts to return all results.
func WithListUsersMaxResults(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersMaxResults = limit
}
}
// WithMaxConcurrentReadsForListObjects sets a limit on the number of datastore reads that can be in flight for a given ListObjects call.
// This number should be set depending on the RPS expected for Check and ListObjects APIs, the number of OpenFGA replicas running,
// and the number of connections the datastore allows.
// E.g. If Datastore.MaxOpenConns = 100 and assuming that each ListObjects call takes 1 second and no traffic to Check API:
// - One OpenFGA replica and expected traffic of 100 RPS => set it to 1.
// - One OpenFGA replica and expected traffic of 1 RPS => set it to 100.
// - Two OpenFGA replicas and expected traffic of 1 RPS => set it to 50.
func WithMaxConcurrentReadsForListObjects(maxConcurrentReads uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxConcurrentReadsForListObjects = maxConcurrentReads
}
}
// WithMaxConcurrentReadsForCheck sets a limit on the number of datastore reads that can be in flight for a given Check call.
// This number should be set depending on the RPS expected for Check and ListObjects APIs, the number of OpenFGA replicas running,
// and the number of connections the datastore allows.
// E.g. If Datastore.MaxOpenConns = 100 and assuming that each Check call takes 1 second and no traffic to ListObjects API:
// - One OpenFGA replica and expected traffic of 100 RPS => set it to 1.
// - One OpenFGA replica and expected traffic of 1 RPS => set it to 100.
// - Two OpenFGA replicas and expected traffic of 1 RPS => set it to 50.
func WithMaxConcurrentReadsForCheck(maxConcurrentReadsForCheck uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxConcurrentReadsForCheck = maxConcurrentReadsForCheck
}
}
// WithMaxConcurrentReadsForListUsers sets a limit on the number of datastore reads that can be in flight for a given ListUsers call.
// This number should be set depending on the RPS expected for all query APIs, the number of OpenFGA replicas running,
// and the number of connections the datastore allows.
// E.g. If Datastore.MaxOpenConns = 100 and assuming that each ListUsers call takes 1 second and no traffic to other query APIs:
// - One OpenFGA replica and expected traffic of 100 RPS => set it to 1.
// - One OpenFGA replica and expected traffic of 1 RPS => set it to 100.
// - Two OpenFGA replicas and expected traffic of 1 RPS => set it to 50.
func WithMaxConcurrentReadsForListUsers(maxConcurrentReadsForListUsers uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxConcurrentReadsForListUsers = maxConcurrentReadsForListUsers
}
}
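// WithExperimentals sets the experimental features enabled on the server.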
func WithExperimentals(experimentals ...string) OpenFGAServiceV1Option {
return func(s *Server) {
s.experimentals = experimentals
}
}
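// WithFeatureFlagClient sets the feature flag client used by the server.
// If client is nil, a no-op feature flag client is used.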
func WithFeatureFlagClient(client featureflags.Client) OpenFGAServiceV1Option {
return func(s *Server) {
if client != nil {
s.featureFlagClient = client
return
}
s.featureFlagClient = featureflags.NewNoopFeatureFlagClient()
}
}
// WithAccessControlParams sets whether the access control feature is enabled, the store ID and model ID it uses, and the authentication method.
func WithAccessControlParams(enabled bool, storeID string, modelID string, authnMethod string) OpenFGAServiceV1Option {
return func(s *Server) {
s.AccessControl = serverconfig.AccessControlConfig{
Enabled: enabled,
StoreID: storeID,
ModelID: modelID,
}
s.AuthnMethod = authnMethod
}
}
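// A sketch of enabling the feature (storeID and modelID must be valid ULIDs,
// and OIDC authentication must be configured):
//
//	s := MustNewServerWithOpts(
//		WithDatastore(ds),
//		WithExperimentals("enable-access-control"),
//		WithAccessControlParams(true, storeID, modelID, "oidc"),
//	)
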
// WithCheckQueryCacheEnabled enables caching of Check results for the Check and ListObjects APIs.
// This cache is shared for all requests.
// See also WithCheckCacheLimit and WithCheckQueryCacheTTL.
func WithCheckQueryCacheEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckQueryCacheEnabled = enabled
}
}
// WithCheckCacheLimit sets the check cache size limit (in items).
func WithCheckCacheLimit(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckCacheLimit = limit
}
}
// WithCacheControllerEnabled enables cache invalidation of different cache entities.
func WithCacheControllerEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CacheControllerEnabled = enabled
}
}
// WithCacheControllerTTL sets the frequency for the controller to execute.
func WithCacheControllerTTL(ttl time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CacheControllerTTL = ttl
}
}
// WithCheckQueryCacheTTL sets the TTL of cached Check and ListObjects partial results.
// Needs WithCheckQueryCacheEnabled set to true.
func WithCheckQueryCacheTTL(ttl time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckQueryCacheTTL = ttl
}
}
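// A sketch of enabling the check query cache (values are illustrative):
//
//	s := MustNewServerWithOpts(
//		WithDatastore(ds),
//		WithCheckQueryCacheEnabled(true),
//		WithCheckCacheLimit(10_000),
//		WithCheckQueryCacheTTL(10*time.Second),
//	)
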
// WithCheckIteratorCacheEnabled enables caching of iterators produced within Check for subsequent requests.
func WithCheckIteratorCacheEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckIteratorCacheEnabled = enabled
}
}
// WithCheckIteratorCacheMaxResults sets the maximum size (in items) of an iterator that can be cached.
// Needs WithCheckIteratorCacheEnabled set to true.
func WithCheckIteratorCacheMaxResults(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckIteratorCacheMaxResults = limit
}
}
// WithCheckIteratorCacheTTL sets the TTL of iterator caches.
// Needs WithCheckIteratorCacheEnabled set to true.
func WithCheckIteratorCacheTTL(ttl time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.CheckIteratorCacheTTL = ttl
}
}
// WithListObjectsIteratorCacheEnabled enables caching of iterators produced within ListObjects for subsequent requests.
func WithListObjectsIteratorCacheEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.ListObjectsIteratorCacheEnabled = enabled
}
}
// WithListObjectsIteratorCacheMaxResults sets the maximum size (in items) of an iterator that can be cached.
// Needs WithListObjectsIteratorCacheEnabled set to true.
func WithListObjectsIteratorCacheMaxResults(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.ListObjectsIteratorCacheMaxResults = limit
}
}
// WithListObjectsIteratorCacheTTL sets the TTL of iterator caches.
// Needs WithListObjectsIteratorCacheEnabled set to true.
func WithListObjectsIteratorCacheTTL(ttl time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.ListObjectsIteratorCacheTTL = ttl
}
}
// WithRequestDurationByQueryHistogramBuckets sets the buckets used to label requestDurationHistogram by datastore query count.
func WithRequestDurationByQueryHistogramBuckets(buckets []uint) OpenFGAServiceV1Option {
return func(s *Server) {
sort.Slice(buckets, func(i, j int) bool { return buckets[i] < buckets[j] })
s.requestDurationByQueryHistogramBuckets = buckets
}
}
// WithRequestDurationByDispatchCountHistogramBuckets sets the buckets used to label requestDurationHistogram by dispatch count.
func WithRequestDurationByDispatchCountHistogramBuckets(buckets []uint) OpenFGAServiceV1Option {
return func(s *Server) {
sort.Slice(buckets, func(i, j int) bool { return buckets[i] < buckets[j] })
s.requestDurationByDispatchCountHistogramBuckets = buckets
}
}
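// For illustration, assuming utils.Bucketize labels a value with the smallest
// configured bucket that is greater than or equal to it, and with "+Inf" when it
// exceeds every bucket, the default buckets {50, 200} label requests as follows:
//
//	3 datastore queries -> "50", 75 -> "200", 500 -> "+Inf"

// WithMaxAuthorizationModelSizeInBytes sets the maximum size (in bytes) of an
// authorization model that can be written.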
func WithMaxAuthorizationModelSizeInBytes(size int) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxAuthorizationModelSizeInBytes = size
}
}
// WithDispatchThrottlingCheckResolverEnabled sets whether dispatch throttling is enabled for Check requests.
// Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch
// threshold over requests whose dispatch count exceeds the configured threshold.
func WithDispatchThrottlingCheckResolverEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.checkDispatchThrottlingEnabled = enabled
}
}
// WithDispatchThrottlingCheckResolverFrequency defines how frequently dispatch throttling
// will be evaluated for Check requests.
// Frequency controls how often throttled dispatch requests are re-evaluated to determine whether
// they can be processed.
// This value should not be too small (i.e., in the ns range) as i) there are limitations in timer resolution
// and ii) a very small value will result in a higher frequency of processing dispatches,
// which diminishes the value of the throttling.
func WithDispatchThrottlingCheckResolverFrequency(frequency time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.checkDispatchThrottlingFrequency = frequency
}
}
// WithDispatchThrottlingCheckResolverThreshold defines the default number of dispatches above which Check requests are throttled.
// In addition, it will update checkDispatchThrottlingMaxThreshold if required.
func WithDispatchThrottlingCheckResolverThreshold(defaultThreshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.checkDispatchThrottlingDefaultThreshold = defaultThreshold
}
}
// WithDispatchThrottlingCheckResolverMaxThreshold defines the maximum threshold value allowed.
// It ensures checkDispatchThrottlingMaxThreshold will never be smaller than the threshold.
func WithDispatchThrottlingCheckResolverMaxThreshold(maxThreshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.checkDispatchThrottlingMaxThreshold = maxThreshold
}
}
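// A sketch that combines the dispatch throttling options for Check
// (values are illustrative):
//
//	s := MustNewServerWithOpts(
//		WithDatastore(ds),
//		WithDispatchThrottlingCheckResolverEnabled(true),
//		WithDispatchThrottlingCheckResolverFrequency(10*time.Microsecond),
//		WithDispatchThrottlingCheckResolverThreshold(100),
//		WithDispatchThrottlingCheckResolverMaxThreshold(300),
//	)
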
// WithContextPropagationToDatastore determines whether the request context is propagated to the datastore.
// When enabled, the datastore receives cancellation signals when an API request is cancelled.
// When disabled, datastore operations continue even if the original request context is cancelled.
// Disabling context propagation is normally desirable to avoid unnecessary database connection churn.
// If not specified, the default value is false (separate storage and request contexts).
func WithContextPropagationToDatastore(enable bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.contextPropagationToDatastore = enable
}
}
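// WithPlanner sets the planner used by the server.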
func WithPlanner(planner *planner.Planner) OpenFGAServiceV1Option {
return func(s *Server) {
s.planner = planner
}
}
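// WithRequestTimeout sets the timeout applied to requests handled by the server.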
func WithRequestTimeout(timeout time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.requestTimeout = timeout
}
}
// MustNewServerWithOpts is like NewServerWithOpts but panics if an error occurs.
func MustNewServerWithOpts(opts ...OpenFGAServiceV1Option) *Server {
s, err := NewServerWithOpts(opts...)
if err != nil {
panic(fmt.Errorf("failed to construct the OpenFGA server: %w", err))
}
return s
}
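// A sketch of constructing and shutting down a server (assuming the in-memory
// datastore in pkg/storage/memory):
//
//	ds := memory.New()
//	defer ds.Close()
//	s := MustNewServerWithOpts(
//		WithDatastore(ds),
//		WithLogger(logger.NewNoopLogger()),
//	)
//	defer s.Close()
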
// IsAccessControlEnabled returns true if the access control feature is enabled.
func (s *Server) IsAccessControlEnabled() bool {
isEnabled := s.featureFlagClient.Boolean(serverconfig.ExperimentalAccessControlParams, "")
return isEnabled && s.AccessControl.Enabled
}
// WithListObjectsDispatchThrottlingEnabled sets whether dispatch throttling is enabled for List Objects requests.
// Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch
// threshold over requests whose dispatch count exceeds the configured threshold.
func WithListObjectsDispatchThrottlingEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDispatchThrottlingEnabled = enabled
}
}
// WithListObjectsDispatchThrottlingFrequency defines how frequently dispatch throttling
// will be evaluated for List Objects requests.
// Frequency controls how often throttled dispatch requests are re-evaluated to determine whether
// they can be processed.
// This value should not be too small (i.e., in the ns range) as i) there are limitations in timer resolution
// and ii) a very small value will result in a higher frequency of processing dispatches,
// which diminishes the value of the throttling.
func WithListObjectsDispatchThrottlingFrequency(frequency time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDispatchThrottlingFrequency = frequency
}
}
// WithListObjectsDispatchThrottlingThreshold defines the number of dispatches above which
// List Objects requests are throttled.
func WithListObjectsDispatchThrottlingThreshold(threshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDispatchDefaultThreshold = threshold
}
}
// WithListObjectsDispatchThrottlingMaxThreshold defines the maximum threshold value allowed.
// It ensures listObjectsDispatchThrottlingMaxThreshold will never be smaller than the threshold.
func WithListObjectsDispatchThrottlingMaxThreshold(maxThreshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDispatchThrottlingMaxThreshold = maxThreshold
}
}
// WithListUsersDispatchThrottlingEnabled sets whether dispatch throttling is enabled for ListUsers requests.
// Enabling this feature will prioritize dispatched requests requiring less than the configured dispatch
// threshold over requests whose dispatch count exceeds the configured threshold.
func WithListUsersDispatchThrottlingEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDispatchThrottlingEnabled = enabled
}
}
// WithListUsersDispatchThrottlingFrequency defines how frequently dispatch throttling
// will be evaluated for ListUsers requests.
// Frequency controls how often throttled dispatch requests are re-evaluated to determine whether
// they can be processed.
// This value should not be too small (i.e., in the ns range) as i) there are limitations in timer resolution
// and ii) a very small value will result in a higher frequency of processing dispatches,
// which diminishes the value of the throttling.
func WithListUsersDispatchThrottlingFrequency(frequency time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDispatchThrottlingFrequency = frequency
}
}
// WithListUsersDispatchThrottlingThreshold defines the number of dispatches above which
// ListUsers requests are throttled.
func WithListUsersDispatchThrottlingThreshold(threshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDispatchDefaultThreshold = threshold
}
}
// WithListUsersDispatchThrottlingMaxThreshold defines the maximum threshold value allowed.
// It ensures listUsersDispatchThrottlingMaxThreshold will never be smaller than the threshold.
func WithListUsersDispatchThrottlingMaxThreshold(maxThreshold uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDispatchThrottlingMaxThreshold = maxThreshold
}
}
// WithMaxConcurrentChecksPerBatchCheck defines the maximum number of checks
// allowed to be processed concurrently in a single batch request.
func WithMaxConcurrentChecksPerBatchCheck(maxConcurrentChecks uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxConcurrentChecksPerBatch = maxConcurrentChecks
}
}
// WithMaxChecksPerBatchCheck defines the maximum number of checks allowed to be sent
// in a single BatchCheck request.
func WithMaxChecksPerBatchCheck(maxChecks uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.maxChecksPerBatchCheck = maxChecks
}
}
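// WithCheckDatabaseThrottle configures datastore throttling for Check requests:
// the threshold of datastore reads after which throttling applies, and the
// duration of each throttling delay.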
func WithCheckDatabaseThrottle(threshold int, duration time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.checkDatastoreThrottleThreshold = threshold
s.checkDatastoreThrottleDuration = duration
}
}
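// WithListObjectsDatabaseThrottle configures datastore throttling for ListObjects
// requests: the threshold of datastore reads after which throttling applies, and
// the duration of each throttling delay.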
func WithListObjectsDatabaseThrottle(threshold int, duration time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listObjectsDatastoreThrottleThreshold = threshold
s.listObjectsDatastoreThrottleDuration = duration
}
}
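// WithListUsersDatabaseThrottle configures datastore throttling for ListUsers
// requests: the threshold of datastore reads after which throttling applies, and
// the duration of each throttling delay.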
func WithListUsersDatabaseThrottle(threshold int, duration time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.listUsersDatastoreThrottleThreshold = threshold
s.listUsersDatastoreThrottleDuration = duration
}
}
// WithShadowCheckResolverTimeout sets the amount of time to wait for the shadow Check evaluation response.
func WithShadowCheckResolverTimeout(threshold time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.shadowCheckResolverTimeout = threshold
}
}
// WithShadowListObjectsQueryTimeout sets the amount of time to wait for the shadow ListObjects evaluation response.
func WithShadowListObjectsQueryTimeout(threshold time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.shadowListObjectsQueryTimeout = threshold
}
}
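// WithShadowListObjectsQueryMaxDeltaItems sets the maximum number of differing
// items reported when comparing shadow ListObjects results against the main results.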
func WithShadowListObjectsQueryMaxDeltaItems(maxDeltaItems int) OpenFGAServiceV1Option {
return func(s *Server) {
s.shadowListObjectsQueryMaxDeltaItems = maxDeltaItems
}
}
// WithSharedIteratorEnabled enables iterators to be shared across different consumers.
func WithSharedIteratorEnabled(enabled bool) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.SharedIteratorEnabled = enabled
}
}
// WithSharedIteratorLimit sets the number of items that can be shared.
func WithSharedIteratorLimit(limit uint32) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.SharedIteratorLimit = limit
}
}
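// WithSharedIteratorTTL sets the TTL of shared iterators.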
func WithSharedIteratorTTL(ttl time.Duration) OpenFGAServiceV1Option {
return func(s *Server) {
s.cacheSettings.SharedIteratorTTL = ttl
}
}
// NewServerWithOpts returns a new server.
// You must call Close on it after you are done using it.
func NewServerWithOpts(opts ...OpenFGAServiceV1Option) (*Server, error) {
s := &Server{
ctx: context.Background(),
logger: logger.NewNoopLogger(),
encoder: encoder.NewBase64Encoder(),
transport: gateway.NewNoopTransport(),
changelogHorizonOffset: serverconfig.DefaultChangelogHorizonOffset,
resolveNodeLimit: serverconfig.DefaultResolveNodeLimit,
resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
listObjectsDeadline: serverconfig.DefaultListObjectsDeadline,
listObjectsMaxResults: serverconfig.DefaultListObjectsMaxResults,
listUsersDeadline: serverconfig.DefaultListUsersDeadline,
listUsersMaxResults: serverconfig.DefaultListUsersMaxResults,
maxChecksPerBatchCheck: serverconfig.DefaultMaxChecksPerBatchCheck,
maxConcurrentChecksPerBatch: serverconfig.DefaultMaxConcurrentChecksPerBatchCheck,
maxConcurrentReadsForCheck: serverconfig.DefaultMaxConcurrentReadsForCheck,
maxConcurrentReadsForListObjects: serverconfig.DefaultMaxConcurrentReadsForListObjects,
maxConcurrentReadsForListUsers: serverconfig.DefaultMaxConcurrentReadsForListUsers,
maxAuthorizationModelSizeInBytes: serverconfig.DefaultMaxAuthorizationModelSizeInBytes,
maxAuthorizationModelCacheSize: serverconfig.DefaultMaxAuthorizationModelCacheSize,
experimentals: make([]string, 0, 10),
AccessControl: serverconfig.AccessControlConfig{Enabled: false, StoreID: "", ModelID: ""},
cacheSettings: serverconfig.NewDefaultCacheSettings(),
shadowCheckResolverTimeout: serverconfig.DefaultShadowCheckResolverTimeout,
shadowListObjectsQueryTimeout: serverconfig.DefaultShadowListObjectsQueryTimeout,
shadowListObjectsQueryMaxDeltaItems: serverconfig.DefaultShadowListObjectsQueryMaxDeltaItems,
requestDurationByQueryHistogramBuckets: []uint{50, 200},
requestDurationByDispatchCountHistogramBuckets: []uint{50, 200},
serviceName: openfgav1.OpenFGAService_ServiceDesc.ServiceName,
checkDispatchThrottlingEnabled: serverconfig.DefaultCheckDispatchThrottlingEnabled,
checkDispatchThrottlingFrequency: serverconfig.DefaultCheckDispatchThrottlingFrequency,
checkDispatchThrottlingDefaultThreshold: serverconfig.DefaultCheckDispatchThrottlingDefaultThreshold,
listObjectsDispatchThrottlingEnabled: serverconfig.DefaultListObjectsDispatchThrottlingEnabled,
listObjectsDispatchThrottlingFrequency: serverconfig.DefaultListObjectsDispatchThrottlingFrequency,
listObjectsDispatchDefaultThreshold: serverconfig.DefaultListObjectsDispatchThrottlingDefaultThreshold,
listObjectsDispatchThrottlingMaxThreshold: serverconfig.DefaultListObjectsDispatchThrottlingMaxThreshold,
listUsersDispatchThrottlingEnabled: serverconfig.DefaultListUsersDispatchThrottlingEnabled,
listUsersDispatchThrottlingFrequency: serverconfig.DefaultListUsersDispatchThrottlingFrequency,
listUsersDispatchDefaultThreshold: serverconfig.DefaultListUsersDispatchThrottlingDefaultThreshold,
listUsersDispatchThrottlingMaxThreshold: serverconfig.DefaultListUsersDispatchThrottlingMaxThreshold,
tokenSerializer: encoder.NewStringContinuationTokenSerializer(),
singleflightGroup: &singleflight.Group{},
authorizer: authz.NewAuthorizerNoop(),
planner: planner.New(&planner.Config{
EvictionThreshold: serverconfig.DefaultPlannerEvictionThreshold,
CleanupInterval: serverconfig.DefaultPlannerCleanupInterval,
}),
requestTimeout: serverconfig.DefaultRequestTimeout,
}
for _, opt := range opts {
opt(s)
}
if s.datastore == nil {
return nil, fmt.Errorf("a datastore option must be provided")
}
// ctx can be nil despite the default above if WithContext() was called
if s.ctx == nil {
return nil, fmt.Errorf("server cannot be started with nil context")
}
if len(s.requestDurationByQueryHistogramBuckets) == 0 {
return nil, fmt.Errorf("request duration datastore count buckets must not be empty")
}
if len(s.requestDurationByDispatchCountHistogramBuckets) == 0 {
return nil, fmt.Errorf("request duration by dispatch count buckets must not be empty")
}
if s.checkDispatchThrottlingEnabled && s.checkDispatchThrottlingMaxThreshold != 0 && s.checkDispatchThrottlingDefaultThreshold > s.checkDispatchThrottlingMaxThreshold {
return nil, fmt.Errorf("check default dispatch throttling threshold must be equal or smaller than max dispatch threshold for Check")
}
if s.listObjectsDispatchThrottlingMaxThreshold != 0 && s.listObjectsDispatchDefaultThreshold > s.listObjectsDispatchThrottlingMaxThreshold {
return nil, fmt.Errorf("ListObjects default dispatch throttling threshold must be equal or smaller than max dispatch threshold for ListObjects")
}
if s.listUsersDispatchThrottlingMaxThreshold != 0 && s.listUsersDispatchDefaultThreshold > s.listUsersDispatchThrottlingMaxThreshold {
return nil, fmt.Errorf("ListUsers default dispatch throttling threshold must be equal or smaller than max dispatch threshold for ListUsers")
}
if s.featureFlagClient == nil {
s.featureFlagClient = featureflags.NewDefaultClient(s.experimentals)
}
err := s.validateAccessControlEnabled()
if err != nil {
return nil, err
}
// below this point, don't return errors or we may leak resources in tests
if !s.contextPropagationToDatastore {
// Creates a new [storagewrappers.ContextTracerWrapper] that will execute datastore queries using
// a new background context with the current trace context.
s.datastore = storagewrappers.NewContextWrapper(s.datastore)
}
s.datastore, err = storagewrappers.NewCachedOpenFGADatastore(s.datastore, s.maxAuthorizationModelCacheSize)
if err != nil {
return nil, err
}
s.sharedDatastoreResources, err = shared.NewSharedDatastoreResources(s.ctx, s.singleflightGroup, s.datastore, s.cacheSettings, []shared.SharedDatastoreResourcesOpt{shared.WithLogger(s.logger)}...)
if err != nil {
return nil, err
}
if s.listObjectsDispatchThrottlingEnabled {
s.listObjectsDispatchThrottler = throttler.NewConstantRateThrottler(s.listObjectsDispatchThrottlingFrequency, "list_objects_dispatch_throttle")
}
if s.listUsersDispatchThrottlingEnabled {
s.listUsersDispatchThrottler = throttler.NewConstantRateThrottler(s.listUsersDispatchThrottlingFrequency, "list_users_dispatch_throttle")
}
s.typesystemResolver, s.typesystemResolverStop, err = typesystem.MemoizedTypesystemResolverFunc(s.datastore)
if err != nil {
return nil, err
}
if s.IsAccessControlEnabled() {
s.authorizer = authz.NewAuthorizer(&authz.Config{StoreID: s.AccessControl.StoreID, ModelID: s.AccessControl.ModelID}, s, s.logger)
}
return s, nil
}
// Close releases the server resources.
func (s *Server) Close() {
if s.planner != nil {
s.planner.Stop()
}
s.typesystemResolverStop()
if s.listObjectsDispatchThrottler != nil {
s.listObjectsDispatchThrottler.Close()
}
if s.listUsersDispatchThrottler != nil {
s.listUsersDispatchThrottler.Close()
}
s.sharedDatastoreResources.Close()
s.datastore.Close()
}
// IsReady reports whether the datastore is ready. Please see the implementation of [storage.OpenFGADatastore.IsReady]
// for your datastore.
func (s *Server) IsReady(ctx context.Context) (bool, error) {
// for now we only depend on the datastore being ready, but in the future
// server readiness may also depend on other criteria in addition to the
// datastore being ready.
status, err := s.datastore.IsReady(ctx)
if err != nil {
return false, err
}
if status.IsReady {
return true, nil
}
s.logger.WarnWithContext(ctx, "datastore is not ready", zap.Any("status", status.Message))
return false, nil
}
// resolveTypesystem resolves the underlying TypeSystem for the given storeID and modelID,
// and sets some response metadata based on the model resolution.
func (s *Server) resolveTypesystem(ctx context.Context, storeID, modelID string) (*typesystem.TypeSystem, error) {
parentSpan := trace.SpanFromContext(ctx)
typesys, err := s.typesystemResolver(ctx, storeID, modelID)
if err != nil {
if errors.Is(err, typesystem.ErrModelNotFound) {
if modelID == "" {
return nil, serverErrors.LatestAuthorizationModelNotFound(storeID)
}
return nil, serverErrors.AuthorizationModelNotFound(modelID)
}
if errors.Is(err, typesystem.ErrInvalidModel) {
return nil, serverErrors.ValidationError(err)
}
telemetry.TraceError(parentSpan, err)
err = serverErrors.HandleError("", err)
return nil, err
}
resolvedModelID := typesys.GetAuthorizationModelID()
parentSpan.SetAttributes(attribute.String(authorizationModelIDKey, resolvedModelID))
grpc_ctxtags.Extract(ctx).Set(authorizationModelIDKey, resolvedModelID)
s.transport.SetHeader(ctx, AuthorizationModelIDHeader, resolvedModelID)
return typesys, nil
}
// validateAccessControlEnabled validates the access control parameters.
func (s *Server) validateAccessControlEnabled() error {
if s.IsAccessControlEnabled() {
if (s.AccessControl == serverconfig.AccessControlConfig{} || s.AccessControl.StoreID == "" || s.AccessControl.ModelID == "") {
return fmt.Errorf("access control parameters are not enabled. They can be enabled for experimental use by passing the `--experimentals enable-access-control` configuration option when running OpenFGA server. Additionally, the `--access-control-store-id` and `--access-control-model-id` parameters must not be empty")
}
if s.AuthnMethod != "oidc" {
return fmt.Errorf("access control is enabled, but the authentication method is not OIDC. Access control is only supported with OIDC authentication")
}
_, err := ulid.Parse(s.AccessControl.StoreID)
if err != nil {
return fmt.Errorf("config '--access-control-store-id' must be a valid ULID")
}
_, err = ulid.Parse(s.AccessControl.ModelID)
if err != nil {
return fmt.Errorf("config '--access-control-model-id' must be a valid ULID")
}
}
return nil
}
// checkAuthz checks the authorization for calling an API method.
func (s *Server) checkAuthz(ctx context.Context, storeID string, apiMethod apimethod.APIMethod, modules ...string) error {
if authclaims.SkipAuthzCheckFromContext(ctx) {
return nil
}
err := s.authorizer.Authorize(ctx, storeID, apiMethod, modules...)
if err != nil {
s.logger.Info("authorization failed", zap.Error(err))
return authz.ErrUnauthorizedResponse
}
return nil
}
// checkCreateStoreAuthz checks the authorization for creating a store.
func (s *Server) checkCreateStoreAuthz(ctx context.Context) error {
if authclaims.SkipAuthzCheckFromContext(ctx) {
return nil
}
err := s.authorizer.AuthorizeCreateStore(ctx)
if err != nil {
s.logger.Info("authorization failed", zap.Error(err))
return authz.ErrUnauthorizedResponse
}
return nil
}
// getAccessibleStores checks whether the caller has permission to list stores and if so,
// returns the list of stores that the user has access to.
func (s *Server) getAccessibleStores(ctx context.Context) ([]string, error) {
if authclaims.SkipAuthzCheckFromContext(ctx) {
return nil, nil
}
err := s.authorizer.AuthorizeListStores(ctx)
if err != nil {
s.logger.Info("authorization failed", zap.Error(err))
return nil, authz.ErrUnauthorizedResponse
}
stores, err := s.authorizer.ListAuthorizedStores(ctx)
if err != nil {
s.logger.Info("authorization failed", zap.Error(err))
return nil, authz.ErrUnauthorizedResponse
}
return stores, nil
}
func (s *Server) getCheckResolverOptions() ([]graph.CachedCheckResolverOpt, []graph.DispatchThrottlingCheckResolverOpt) {
var checkCacheOptions []graph.CachedCheckResolverOpt
if s.cacheSettings.ShouldCacheCheckQueries() {
checkCacheOptions = append(checkCacheOptions,
graph.WithExistingCache(s.sharedDatastoreResources.CheckCache),
graph.WithLogger(s.logger),
graph.WithCacheTTL(s.cacheSettings.CheckQueryCacheTTL),
)
}
var checkDispatchThrottlingOptions []graph.DispatchThrottlingCheckResolverOpt
if s.checkDispatchThrottlingEnabled {
checkDispatchThrottlingOptions = []graph.DispatchThrottlingCheckResolverOpt{
graph.WithDispatchThrottlingCheckResolverConfig(graph.DispatchThrottlingCheckResolverConfig{
DefaultThreshold: s.checkDispatchThrottlingDefaultThreshold,
MaxThreshold: s.checkDispatchThrottlingMaxThreshold,
}),
// only create the throttler if the feature is enabled, so that we can clean it afterward
graph.WithConstantRateThrottler(s.checkDispatchThrottlingFrequency,
"check_dispatch_throttle"),
}
}
return checkCacheOptions, checkDispatchThrottlingOptions
}
// checkWriteAuthz checks authorization for Write requests: against the modules involved when they exist, otherwise against the store.
func (s *Server) checkWriteAuthz(ctx context.Context, req *openfgav1.WriteRequest, typesys *typesystem.TypeSystem) error {
if authclaims.SkipAuthzCheckFromContext(ctx) {
return nil
}
modules, err := s.authorizer.GetModulesForWriteRequest(ctx, req, typesys)
if err != nil {
s.logger.Info("authorization failed", zap.Error(err))
return authz.ErrUnauthorizedResponse
}
return s.checkAuthz(ctx, req.GetStoreId(), apimethod.Write, modules...)
}
func (s *Server) emitCheckDurationMetric(checkMetadata graph.ResolveCheckResponseMetadata, caller string) {
checkDurationHistogram.WithLabelValues(
utils.Bucketize(uint(checkMetadata.DatastoreQueryCount), s.requestDurationByQueryHistogramBuckets),
caller,
).Observe(float64(checkMetadata.Duration.Milliseconds()))
}
package server
import (
"context"
"net/http"
"strconv"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
httpmiddleware "github.com/openfga/openfga/pkg/middleware/http"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) CreateStore(ctx context.Context, req *openfgav1.CreateStoreRequest) (*openfgav1.CreateStoreResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.CreateStore.String())
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.CreateStore.String(),
})
err := s.checkCreateStoreAuthz(ctx)
if err != nil {
return nil, err
}
c := commands.NewCreateStoreCommand(s.datastore, commands.WithCreateStoreCmdLogger(s.logger))
res, err := c.Execute(ctx, req)
if err != nil {
return nil, err
}
s.transport.SetHeader(ctx, httpmiddleware.XHttpCode, strconv.Itoa(http.StatusCreated))
return res, nil
}
func (s *Server) DeleteStore(ctx context.Context, req *openfgav1.DeleteStoreRequest) (*openfgav1.DeleteStoreResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.DeleteStore.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.DeleteStore.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.DeleteStore)
if err != nil {
return nil, err
}
cmd := commands.NewDeleteStoreCommand(s.datastore, commands.WithDeleteStoreCmdLogger(s.logger))
res, err := cmd.Execute(ctx, req)
if err != nil {
return nil, err
}
s.transport.SetHeader(ctx, httpmiddleware.XHttpCode, strconv.Itoa(http.StatusNoContent))
return res, nil
}
func (s *Server) GetStore(ctx context.Context, req *openfgav1.GetStoreRequest) (*openfgav1.GetStoreResponse, error) {
ctx, span := tracer.Start(ctx, apimethod.GetStore.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.GetStore.String(),
})
err := s.checkAuthz(ctx, req.GetStoreId(), apimethod.GetStore)
if err != nil {
return nil, err
}
q := commands.NewGetStoreQuery(s.datastore, commands.WithGetStoreQueryLogger(s.logger))
return q.Execute(ctx, req)
}
func (s *Server) ListStores(ctx context.Context, req *openfgav1.ListStoresRequest) (*openfgav1.ListStoresResponse, error) {
method := "ListStores"
ctx, span := tracer.Start(ctx, method)
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: method,
})
storeIDs, err := s.getAccessibleStores(ctx)
if err != nil {
return nil, err
}
// even though we already have the list of store IDs, we still call ListStoresQuery to fetch the full metadata for each store.
q := commands.NewListStoresQuery(s.datastore,
commands.WithListStoresQueryLogger(s.logger),
commands.WithListStoresQueryEncoder(s.encoder),
)
return q.Execute(ctx, req, storeIDs)
}
package server
import (
"context"
"strconv"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/authclaims"
"github.com/openfga/openfga/pkg/middleware/validator"
"github.com/openfga/openfga/pkg/server/commands"
"github.com/openfga/openfga/pkg/telemetry"
)
func (s *Server) Write(ctx context.Context, req *openfgav1.WriteRequest) (*openfgav1.WriteResponse, error) {
start := time.Now()
ctx, span := tracer.Start(ctx, apimethod.Write.String(), trace.WithAttributes(
attribute.String("store_id", req.GetStoreId()),
))
defer span.End()
if !validator.RequestIsValidatedFromContext(ctx) {
if err := req.Validate(); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
ctx = telemetry.ContextWithRPCInfo(ctx, telemetry.RPCInfo{
Service: s.serviceName,
Method: apimethod.Write.String(),
})
storeID := req.GetStoreId()
typesys, err := s.resolveTypesystem(ctx, storeID, req.GetAuthorizationModelId())
if err != nil {
return nil, err
}
err = s.checkWriteAuthz(ctx, req, typesys)
if err != nil {
return nil, err
}
cmd := commands.NewWriteCommand(
s.datastore,
commands.WithWriteCmdLogger(s.logger),
)
resp, err := cmd.Execute(ctx, &openfgav1.WriteRequest{
StoreId: storeID,
AuthorizationModelId: typesys.GetAuthorizationModelID(), // the resolved model id
Writes: req.GetWrites(),
Deletes: req.GetDeletes(),
})
// For now, we only measure the duration when the request passes the authz step,
// so that the comparison is apples to apples.
writeDurationHistogram.WithLabelValues(
strconv.FormatBool(s.IsAccessControlEnabled() && !authclaims.SkipAuthzCheckFromContext(ctx)),
req.GetWrites().GetOnDuplicate(),
req.GetDeletes().GetOnMissing(),
).Observe(float64(time.Since(start).Milliseconds()))
return resp, err
}
//go:generate mockgen -source cache.go -destination ../../internal/mocks/mock_cache.go -package mocks cache
package storage
import (
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"github.com/Yiling-J/theine-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/tuple"
)
var (
cacheItemCount = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: build.ProjectName,
Name: "cache_item_count",
Help: "The total number of items stored in the cache",
}, []string{"entity"})
cacheItemRemovedCount = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "cache_item_removed_count",
Help: "The total number of items removed from the cache",
}, []string{"entity", "reason"})
)
const (
SubproblemCachePrefix = "sp."
iteratorCachePrefix = "ic."
changelogCachePrefix = "cc."
invalidIteratorCachePrefix = "iq."
defaultMaxCacheSize = 10000
oneYear = time.Hour * 24 * 365
removedLabel = "removed"
evictedLabel = "evicted"
expiredLabel = "expired"
unspecifiedLabel = "unspecified"
)
type CacheItem interface {
CacheEntityType() string
}
// InMemoryCache is a general purpose cache to store things in memory.
type InMemoryCache[T any] interface {
// Get returns the value if the key exists; otherwise it returns the zero value of T.
Get(key string) T
Set(key string, value T, ttl time.Duration)
Delete(key string)
// Stop releases any resources held by the cache.
Stop()
}
// InMemoryLRUCache is a theine-backed LRU implementation of InMemoryCache.
type InMemoryLRUCache[T any] struct {
client *theine.Cache[string, T]
maxElements int64
stopOnce *sync.Once
}
type InMemoryLRUCacheOpt[T any] func(i *InMemoryLRUCache[T])
func WithMaxCacheSize[T any](maxElements int64) InMemoryLRUCacheOpt[T] {
return func(i *InMemoryLRUCache[T]) {
i.maxElements = maxElements
}
}
var _ InMemoryCache[any] = (*InMemoryLRUCache[any])(nil)
func NewInMemoryLRUCache[T any](opts ...InMemoryLRUCacheOpt[T]) (*InMemoryLRUCache[T], error) {
t := &InMemoryLRUCache[T]{
maxElements: defaultMaxCacheSize,
stopOnce: &sync.Once{},
}
for _, opt := range opts {
opt(t)
}
cacheBuilder := theine.NewBuilder[string, T](t.maxElements)
cacheBuilder.RemovalListener(func(key string, value T, reason theine.RemoveReason) {
var (
reasonLabel string
entityLabel string
)
switch reason {
case theine.EVICTED:
reasonLabel = evictedLabel
case theine.EXPIRED:
reasonLabel = expiredLabel
case theine.REMOVED:
reasonLabel = removedLabel
default:
reasonLabel = unspecifiedLabel
}
if item, ok := any(value).(CacheItem); ok {
entityLabel = item.CacheEntityType()
} else {
entityLabel = unspecifiedLabel
}
cacheItemCount.WithLabelValues(entityLabel).Dec()
cacheItemRemovedCount.WithLabelValues(entityLabel, reasonLabel).Inc()
})
var err error
t.client, err = cacheBuilder.Build()
if err != nil {
return nil, err
}
return t, nil
}
func (i InMemoryLRUCache[T]) Get(key string) T {
var zero T
item, ok := i.client.Get(key)
if !ok {
return zero
}
return item
}
// Set stores the value for the duration of the ttl.
// Note that the ttl is capped at one year so it cannot be misinterpreted as a negative value.
// A negative ttl is a no-op.
func (i InMemoryLRUCache[T]) Set(key string, value T, ttl time.Duration) {
if ttl >= oneYear {
ttl = oneYear
}
i.client.SetWithTTL(key, value, 1, ttl)
if item, ok := any(value).(CacheItem); ok {
cacheItemCount.WithLabelValues(item.CacheEntityType()).Inc()
} else {
cacheItemCount.WithLabelValues(unspecifiedLabel).Inc()
}
}
func (i InMemoryLRUCache[T]) Delete(key string) {
i.client.Delete(key)
}
func (i InMemoryLRUCache[T]) Stop() {
i.stopOnce.Do(func() {
i.client.Close()
})
}
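// A minimal usage sketch (illustrative, not part of the package API surface):
// construct an InMemoryLRUCache with a custom maximum size, store a value with
// a TTL, and read it back. The key, value, and TTL below are placeholders.
func exampleInMemoryLRUCacheUsage() error {
	cache, err := NewInMemoryLRUCache(WithMaxCacheSize[string](500))
	if err != nil {
		return err
	}
	defer cache.Stop()
	cache.Set("greeting", "hello", 10*time.Second)
	_ = cache.Get("greeting") // "hello" until the TTL elapses or the entry is evicted
	return nil
}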
var (
_ CacheItem = (*ChangelogCacheEntry)(nil)
_ CacheItem = (*InvalidEntityCacheEntry)(nil)
_ CacheItem = (*TupleIteratorCacheEntry)(nil)
)
type ChangelogCacheEntry struct {
LastModified time.Time
}
func (c *ChangelogCacheEntry) CacheEntityType() string {
return "changelog"
}
func GetChangelogCacheKey(storeID string) string {
return changelogCachePrefix + storeID
}
type InvalidEntityCacheEntry struct {
LastModified time.Time
}
func (i *InvalidEntityCacheEntry) CacheEntityType() string {
return "invalid_entity"
}
func GetInvalidIteratorCacheKey(storeID string) string {
return invalidIteratorCachePrefix + storeID
}
func GetInvalidIteratorByObjectRelationCacheKey(storeID, object, relation string) string {
return invalidIteratorCachePrefix + storeID + "-or/" + object + "#" + relation
}
func GetInvalidIteratorByUserObjectTypeCacheKeys(storeID string, users []string, objectType string) []string {
res := make([]string, len(users))
var i int
for _, user := range users {
res[i] = invalidIteratorCachePrefix + storeID + "-otr/" + user + "|" + objectType
i++
}
return res
}
type TupleIteratorCacheEntry struct {
Tuples []*TupleRecord
LastModified time.Time
}
func (t *TupleIteratorCacheEntry) CacheEntityType() string {
return "tuple_iterator"
}
func GetReadUsersetTuplesCacheKeyPrefix(store, object, relation string) string {
return iteratorCachePrefix + "rut/" + store + "/" + object + "#" + relation
}
func GetReadStartingWithUserCacheKeyPrefix(store, objectType, relation string) string {
return iteratorCachePrefix + "rtwu/" + store + "/" + objectType + "#" + relation
}
func GetReadCacheKey(store, tuple string) string {
return iteratorCachePrefix + "r/" + store + "/" + tuple
}
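// For illustration, the helpers above produce namespaced keys such as:
//
//	GetReadUsersetTuplesCacheKeyPrefix("s1", "document:1", "viewer") == "ic.rut/s1/document:1#viewer"
//	GetReadCacheKey("s1", "document:1#viewer@user:anne")             == "ic.r/s1/document:1#viewer@user:anne"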
// ErrUnexpectedStructValue is an error used to indicate that
// an unexpected structpb.Value kind was encountered.
var ErrUnexpectedStructValue = errors.New("unexpected structpb value encountered")
// writeValue writes value v to the writer w. An error
// is returned only when the underlying writer returns
// an error or an unexpected value kind is encountered.
func writeValue(w io.StringWriter, v *structpb.Value) (err error) {
switch val := v.GetKind().(type) {
case *structpb.Value_BoolValue:
_, err = w.WriteString(strconv.FormatBool(val.BoolValue))
case *structpb.Value_NullValue:
_, err = w.WriteString("null")
case *structpb.Value_StringValue:
_, err = w.WriteString(val.StringValue)
case *structpb.Value_NumberValue:
_, err = w.WriteString(strconv.FormatFloat(val.NumberValue, 'f', -1, 64)) // -1 precision ensures we represent the 64-bit value with the maximum precision needed to represent it, see strconv#FormatFloat for more info.
case *structpb.Value_ListValue:
values := val.ListValue.GetValues()
for n, vv := range values {
if err = writeValue(w, vv); err != nil {
return
}
if n < len(values)-1 {
if _, err = w.WriteString(","); err != nil {
return
}
}
}
case *structpb.Value_StructValue:
err = writeStruct(w, val.StructValue)
default:
err = ErrUnexpectedStructValue
}
return
}
// keys accepts a map m and returns a slice of its keys.
// When this project is updated to Go version 1.23 or greater,
// `slices.Collect(maps.Keys(m))` should be preferred.
func keys[T comparable, U any](m map[T]U) []T {
n := make([]T, len(m))
var i int
for k := range m {
n[i] = k
i++
}
return n
}
// writeStruct writes Struct value s to writer w. When s is nil, a
// nil error is returned. An error is returned only when the underlying
// writer returns an error. The struct fields are written in the sorted
// order of their names. A comma separates fields.
func writeStruct(w io.StringWriter, s *structpb.Struct) (err error) {
if s == nil {
return
}
fields := s.GetFields()
keys := keys(fields)
sort.Strings(keys)
for _, key := range keys {
if _, err = w.WriteString(fmt.Sprintf("'%s:'", key)); err != nil {
return
}
if err = writeValue(w, fields[key]); err != nil {
return
}
if _, err = w.WriteString(","); err != nil {
return
}
}
return
}
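// Illustrative sketch: because writeStruct emits fields in sorted key order,
// two structs with the same contents always serialize identically. This helper
// is an example only and assumes a "strings" import.
func exampleCanonicalStruct() (string, error) {
	s, err := structpb.NewStruct(map[string]interface{}{"b": 2, "a": true})
	if err != nil {
		return "", err
	}
	var sb strings.Builder // strings.Builder satisfies io.StringWriter
	if err := writeStruct(&sb, s); err != nil {
		return "", err
	}
	return sb.String(), nil // "'a:'true,'b:'2,"
}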
// writeTuples writes the set of tuples to writer w in ascending sorted order.
// The intention of this function is to write the tuples as a unique string.
// Tuples are separated by commas, and when present, conditions are included
// in the tuple string representation. Returns an error only when
// the underlying writer returns an error.
func writeTuples(w io.StringWriter, tuples ...*openfgav1.TupleKey) (err error) {
sortedTuples := make(tuple.TupleKeys, len(tuples))
// copy tuples slice to avoid mutating the original slice during sorting.
copy(sortedTuples, tuples)
// sort tuples for a deterministic write
sort.Sort(sortedTuples)
// prefix to avoid overlap with previous strings written
_, err = w.WriteString("/")
if err != nil {
return
}
for n, tupleKey := range sortedTuples {
_, err = w.WriteString(tupleKey.GetObject() + "#" + tupleKey.GetRelation())
if err != nil {
return
}
cond := tupleKey.GetCondition()
if cond != nil {
// " with " is separated by spaces as those are invalid in relation names
// and we need to ensure this cache key is unique
// resultant cache key format is "object:object_id#relation with {condition} {context}@user:user_id"
_, err = w.WriteString(" with " + cond.GetName())
if err != nil {
return
}
// if the condition also has context, we need an additional separator
// which cannot be present in condition names
if cond.GetContext() != nil {
_, err = w.WriteString(" ")
if err != nil {
return
}
}
// now write context to hash. Is a noop if context is nil.
if err = writeStruct(w, cond.GetContext()); err != nil {
return
}
}
if _, err = w.WriteString("@" + tupleKey.GetUser()); err != nil {
return
}
if n < len(tuples)-1 {
if _, err = w.WriteString(","); err != nil {
return
}
}
}
return
}
// CheckCacheKeyParams is all the necessary pieces to create a unique-per-check cache key.
type CheckCacheKeyParams struct {
StoreID string
AuthorizationModelID string
TupleKey *openfgav1.TupleKey
ContextualTuples []*openfgav1.TupleKey
Context *structpb.Struct
}
// WriteCheckCacheKey converts the elements of a Check into a canonical cache key that can be
// used for Check resolution cache key lookups in a stable way, and writes it to the provided writer.
//
// For one store and model ID, the same tuple provided with the same contextual tuples and context
// should produce the same cache key. Contextual tuple order and context parameter order are ignored;
// only the contents are compared.
func WriteCheckCacheKey(w io.StringWriter, params *CheckCacheKeyParams) error {
t := tuple.From(params.TupleKey)
_, err := w.WriteString(t.String())
if err != nil {
return err
}
err = WriteInvariantCheckCacheKey(w, params)
if err != nil {
return err
}
return nil
}
func WriteInvariantCheckCacheKey(w io.StringWriter, params *CheckCacheKeyParams) error {
_, err := w.WriteString(
" " + // space to separate from user in the TupleCacheKey, where spaces cannot be present
SubproblemCachePrefix +
params.StoreID +
"/" +
params.AuthorizationModelID,
)
if err != nil {
return err
}
// here, and for context below, avoid hashing if we don't need to
if len(params.ContextualTuples) > 0 {
if err = writeTuples(w, params.ContextualTuples...); err != nil {
return err
}
}
if params.Context != nil {
if err = writeStruct(w, params.Context); err != nil {
return err
}
}
return nil
}
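// Hedged usage sketch: assemble a Check cache key into a strings.Builder
// (which satisfies io.StringWriter). The store and model IDs are placeholder
// ULIDs, and a "strings" import is assumed.
func exampleWriteCheckCacheKey() (string, error) {
	var sb strings.Builder
	err := WriteCheckCacheKey(&sb, &CheckCacheKeyParams{
		StoreID:              "01HQXYZAAAAAAAAAAAAAAAAAAA", // hypothetical store ID
		AuthorizationModelID: "01HQXYZBBBBBBBBBBBBBBBBBBB", // hypothetical model ID
		TupleKey:             tuple.NewTupleKey("document:1", "viewer", "user:anne"),
	})
	if err != nil {
		return "", err
	}
	// Identical inputs always produce the identical key string.
	return sb.String(), nil
}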
package storage
import (
"errors"
"fmt"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/tuple"
)
var (
// ErrCollision is returned when an item already exists within the store.
ErrCollision = errors.New("item already exists")
// ErrInvalidContinuationToken is returned when the continuation token is invalid.
ErrInvalidContinuationToken = errors.New("invalid continuation token")
// ErrInvalidStartTime is returned when start time param for ReadChanges API is invalid.
ErrInvalidStartTime = errors.New("invalid start time")
// ErrInvalidWriteInput is returned when the tuple to be written
// already existed or the tuple to be deleted did not exist.
ErrInvalidWriteInput = errors.New("tuple to be written already existed or the tuple to be deleted did not exist")
// ErrWriteConflictOnInsert is returned when two writes attempt to insert the same tuple at the same time.
ErrWriteConflictOnInsert = fmt.Errorf("%w: one or more tuples to write were inserted by another transaction", ErrTransactionalWriteFailed)
// ErrWriteConflictOnDelete is returned when two writes attempt to delete the same tuple at the same time.
ErrWriteConflictOnDelete = fmt.Errorf("%w: one or more tuples to delete were deleted by another transaction", ErrTransactionalWriteFailed)
// ErrTransactionalWriteFailed is returned when two writes attempt to write the same tuple at the same time.
ErrTransactionalWriteFailed = errors.New("transactional write failed due to conflict")
// ErrTransactionThrottled is returned when throttling is applied at the datastore level.
ErrTransactionThrottled = errors.New("transaction throttled")
// ErrNotFound is returned when the object does not exist.
ErrNotFound = errors.New("not found")
)
// InvalidWriteInputError generates an error for invalid operations in a tuple store.
// This function is invoked when an attempt is made to write or delete a tuple with invalid conditions.
// Specifically, it addresses two scenarios:
// 1. Attempting to delete a non-existent tuple.
// 2. Attempting to write a tuple that already exists.
func InvalidWriteInputError(tk tuple.TupleWithoutCondition, operation openfgav1.TupleOperation) error {
switch operation {
case openfgav1.TupleOperation_TUPLE_OPERATION_DELETE:
return fmt.Errorf(
"cannot delete a tuple which does not exist: user: '%s', relation: '%s', object: '%s': %w",
tk.GetUser(),
tk.GetRelation(),
tk.GetObject(),
ErrInvalidWriteInput,
)
case openfgav1.TupleOperation_TUPLE_OPERATION_WRITE:
return fmt.Errorf(
"cannot write a tuple which already exists: user: '%s', relation: '%s', object: '%s': %w",
tk.GetUser(),
tk.GetRelation(),
tk.GetObject(),
ErrInvalidWriteInput,
)
default:
return nil
}
}
// TupleConditionConflictError generates an error for a write that attempts to
// add a tuple which already exists with a different condition.
func TupleConditionConflictError(tk tuple.TupleWithoutCondition) error {
return fmt.Errorf(
"%w: attempted to write a tuple which already exists with a different condition: user: '%s', relation: '%s', object: '%s'",
ErrTransactionalWriteFailed, // mapped to 409 Conflict in the API layer
tk.GetUser(),
tk.GetRelation(),
tk.GetObject(),
)
}
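// Hedged sketch: callers typically classify write failures with errors.Is
// against the sentinel errors above; the labels returned here are illustrative.
func classifyWriteError(err error) string {
	switch {
	case errors.Is(err, ErrInvalidWriteInput):
		return "invalid_write_input" // duplicate write or delete of a missing tuple
	case errors.Is(err, ErrTransactionalWriteFailed):
		return "write_conflict" // includes ErrWriteConflictOnInsert and ErrWriteConflictOnDelete
	case errors.Is(err, ErrTransactionThrottled):
		return "throttled"
	default:
		return "unknown"
	}
}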
package memory
import (
"context"
"fmt"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/oklog/ulid/v2"
"go.opentelemetry.io/otel"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/telemetry"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/storage/memory")
type staticIterator struct {
records []*storage.TupleRecord
continuationToken string
mu sync.Mutex
}
// match returns true if all the fields in t [*storage.TupleRecord] are equal to
// the corresponding fields in the target [*openfgav1.TupleKey]. If the input Object
// doesn't specify an ID, only the Object Types are compared. If a field
// in the input parameter is empty, it is ignored in the comparison.
func match(t *storage.TupleRecord, target *openfgav1.TupleKey) bool {
if target.GetObject() != "" {
td, objectid := tupleUtils.SplitObject(target.GetObject())
if objectid == "" {
if td != t.ObjectType {
return false
}
} else {
if td != t.ObjectType || objectid != t.ObjectID {
return false
}
}
}
if target.GetRelation() != "" && t.Relation != target.GetRelation() {
return false
}
if target.GetUser() != "" {
userType, userID, _ := tupleUtils.ToUserParts(target.GetUser())
if userID != "" && t.User != target.GetUser() {
return false
} else if userID == "" && !strings.HasPrefix(t.User, userType+":") {
return false
}
}
return true
}
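// Illustrative sketch: empty filter fields act as wildcards, so a TupleKey
// carrying only an object type matches every record of that type. The record
// below is hypothetical.
func exampleMatchSemantics() bool {
	rec := &storage.TupleRecord{
		ObjectType: "document",
		ObjectID:   "1",
		Relation:   "viewer",
		User:       "user:anne",
	}
	// An object without an ID compares object types only; empty relation and user are ignored.
	return match(rec, &openfgav1.TupleKey{Object: "document"}) // true
}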
// Next see [storage.Iterator].Next.
func (s *staticIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mu.Lock()
defer s.mu.Unlock()
if len(s.records) == 0 {
return nil, storage.ErrIteratorDone
}
next, rest := s.records[0], s.records[1:]
s.records = rest
return next.AsTuple(), nil
}
// Stop does not do anything for staticIterator.
func (s *staticIterator) Stop() {}
// Head see [storage.Iterator].Head.
func (s *staticIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mu.Lock()
defer s.mu.Unlock()
if len(s.records) == 0 {
return nil, storage.ErrIteratorDone
}
rec := s.records[0]
return rec.AsTuple(), nil
}
// ToArray converts the entire sequence of tuples in the staticIterator to an array format.
func (s *staticIterator) ToArray(ctx context.Context) ([]*openfgav1.Tuple, string, error) {
var res []*openfgav1.Tuple
for range s.records {
t, err := s.Next(ctx)
if err != nil {
return nil, "", err
}
res = append(res, t)
}
return res, s.continuationToken, nil
}
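// Hedged sketch: draining any [storage.TupleIterator] with Next until
// storage.ErrIteratorDone, the generic counterpart of ToArray above.
// Assumes an "errors" import.
func exampleDrainIterator(ctx context.Context, it storage.TupleIterator) ([]*openfgav1.Tuple, error) {
	defer it.Stop()
	var out []*openfgav1.Tuple
	for {
		t, err := it.Next(ctx)
		if err != nil {
			if errors.Is(err, storage.ErrIteratorDone) {
				return out, nil // iterator exhausted
			}
			return nil, err
		}
		out = append(out, t)
	}
}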
// StorageOption defines a function type used for configuring a [MemoryBackend] instance.
type StorageOption func(dataStore *MemoryBackend)
const (
defaultMaxTuplesPerWrite = 100
defaultMaxTypesPerAuthorizationModel = 100
)
// MemoryBackend provides an ephemeral memory-backed implementation of [storage.OpenFGADatastore].
// These instances may be safely shared by multiple go-routines.
type MemoryBackend struct {
maxTuplesPerWrite int
maxTypesPerAuthorizationModel int
// TupleBackend
// map: store => set of tuples
tuples map[string][]*storage.TupleRecord // GUARDED_BY(mutexTuples).
mutexTuples sync.RWMutex
// ChangelogBackend
// map: store => set of changes
changes map[string][]*tupleChangeRec // GUARDED_BY(mutexTuples).
// AuthorizationModelBackend
// map: store => map: authorization model id => authorization model entry
authorizationModels map[string]map[string]*AuthorizationModelEntry // GUARDED_BY(mutexModels).
mutexModels sync.RWMutex
// map: store id => store data
stores map[string]*openfgav1.Store // GUARDED_BY(mutexStores).
mutexStores sync.RWMutex
// map: store id | authz model id => assertions
assertions map[string][]*openfgav1.Assertion // GUARDED_BY(mutexAssertions).
mutexAssertions sync.RWMutex
}
// Ensures that [MemoryBackend] implements the [storage.OpenFGADatastore] interface.
var _ storage.OpenFGADatastore = (*MemoryBackend)(nil)
// AuthorizationModelEntry represents an entry in a storage system
// that holds information about an authorization model.
type AuthorizationModelEntry struct {
model *openfgav1.AuthorizationModel
latest bool
}
// New creates a new [MemoryBackend] given the options.
func New(opts ...StorageOption) storage.OpenFGADatastore {
ds := &MemoryBackend{
maxTuplesPerWrite: defaultMaxTuplesPerWrite,
maxTypesPerAuthorizationModel: defaultMaxTypesPerAuthorizationModel,
tuples: make(map[string][]*storage.TupleRecord),
changes: make(map[string][]*tupleChangeRec),
authorizationModels: make(map[string]map[string]*AuthorizationModelEntry),
stores: make(map[string]*openfgav1.Store),
assertions: make(map[string][]*openfgav1.Assertion),
}
for _, opt := range opts {
opt(ds)
}
return ds
}
// WithMaxTuplesPerWrite returns a [StorageOption] that sets the maximum number of tuples allowed in a single write operation.
// This option is used to configure a [MemoryBackend] instance, providing a limit to the number of tuples that can be written at once.
// This helps in managing and controlling the load and performance of the memory storage during bulk write operations.
func WithMaxTuplesPerWrite(n int) StorageOption {
return func(ds *MemoryBackend) { ds.maxTuplesPerWrite = n }
}
// WithMaxTypesPerAuthorizationModel returns a [StorageOption] that sets the maximum number of types allowed per authorization model.
// This configuration is particularly useful for limiting the complexity or size of an authorization model in a MemoryBackend instance,
// ensuring that models remain manageable and within predefined resource constraints.
func WithMaxTypesPerAuthorizationModel(n int) StorageOption {
return func(ds *MemoryBackend) { ds.maxTypesPerAuthorizationModel = n }
}
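// Minimal configuration sketch: an ephemeral datastore for tests with lowered
// limits. The values are illustrative.
func exampleNewMemoryBackend() storage.OpenFGADatastore {
	return New(
		WithMaxTuplesPerWrite(10),
		WithMaxTypesPerAuthorizationModel(50),
	)
}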
// Close does not do anything for [MemoryBackend].
func (s *MemoryBackend) Close() {}
// Read see [storage.RelationshipTupleReader].Read.
func (s *MemoryBackend) Read(ctx context.Context, store string, filter storage.ReadFilter, _ storage.ReadOptions) (storage.TupleIterator, error) {
ctx, span := tracer.Start(ctx, "memory.Read")
defer span.End()
return s.read(ctx, store, filter, nil)
}
// ReadPage see [storage.RelationshipTupleReader].ReadPage.
func (s *MemoryBackend) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
ctx, span := tracer.Start(ctx, "memory.ReadPage")
defer span.End()
it, err := s.read(ctx, store, filter, &options)
if err != nil {
return nil, "", err
}
return it.ToArray(ctx)
}
// ReadChanges see [storage.ChangelogBackend].ReadChanges.
func (s *MemoryBackend) ReadChanges(ctx context.Context, store string, filter storage.ReadChangesFilter, options storage.ReadChangesOptions) ([]*openfgav1.TupleChange, string, error) {
_, span := tracer.Start(ctx, "memory.ReadChanges")
defer span.End()
s.mutexTuples.RLock()
defer s.mutexTuples.RUnlock()
var from *ulid.ULID
if options.Pagination.From != "" {
parsed, err := ulid.Parse(options.Pagination.From)
if err != nil {
return nil, "", storage.ErrInvalidContinuationToken
}
from = &parsed
}
objectType := filter.ObjectType
horizonOffset := filter.HorizonOffset
var allChanges []*tupleChangeRec
now := time.Now().UTC()
for _, changeRec := range s.changes[store] {
if objectType == "" || (strings.HasPrefix(changeRec.Change.GetTupleKey().GetObject(), objectType+":")) {
if changeRec.Change.GetTimestamp().AsTime().After(now.Add(-horizonOffset)) {
break
}
if from != nil {
if !options.SortDesc && changeRec.Ulid.Compare(*from) <= 0 {
continue
} else if options.SortDesc && changeRec.Ulid.Compare(*from) >= 0 {
continue
}
}
allChanges = append(allChanges, changeRec)
}
}
if len(allChanges) == 0 {
return nil, "", storage.ErrNotFound
}
pageSize := storage.DefaultPageSize
if options.Pagination.PageSize > 0 {
pageSize = options.Pagination.PageSize
}
if options.SortDesc {
slices.Reverse(allChanges)
}
to := pageSize
if len(allChanges) < to {
to = len(allChanges)
}
if to == 0 {
return nil, "", storage.ErrNotFound
}
res := make([]*openfgav1.TupleChange, 0, to)
var last ulid.ULID
for _, change := range allChanges[:to] {
res = append(res, change.Change)
last = change.Ulid
}
return res, last.String(), nil
}
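// Hedged paging sketch: the continuation token returned by one ReadChanges
// call feeds the next call's Pagination.From, and storage.ErrNotFound signals
// the end of the changelog. Assumes an "errors" import.
func exampleReadChangesPaging(ctx context.Context, ds *MemoryBackend, store string) error {
	var opts storage.ReadChangesOptions
	opts.Pagination.PageSize = 50
	for {
		changes, token, err := ds.ReadChanges(ctx, store, storage.ReadChangesFilter{}, opts)
		if err != nil {
			if errors.Is(err, storage.ErrNotFound) {
				return nil // no further changes
			}
			return err
		}
		_ = changes // process the page
		opts.Pagination.From = token
	}
}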
// read returns an iterator over a store's tuples, filtered by the given filter.
// A nil options input means the returned iterator will iterate through all matching values.
func (s *MemoryBackend) read(ctx context.Context, store string, filter storage.ReadFilter, options *storage.ReadPageOptions) (*staticIterator, error) {
_, span := tracer.Start(ctx, "memory.read")
defer span.End()
s.mutexTuples.RLock()
defer s.mutexTuples.RUnlock()
var matches []*storage.TupleRecord
if filter.Object == "" && filter.Relation == "" && filter.User == "" {
matches = make([]*storage.TupleRecord, len(s.tuples[store]))
copy(matches, s.tuples[store])
} else {
for _, t := range s.tuples[store] {
if match(t, &openfgav1.TupleKey{
Object: filter.Object,
Relation: filter.Relation,
User: filter.User,
}) && (len(filter.Conditions) == 0 || slices.Contains(filter.Conditions, t.ConditionName)) {
matches = append(matches, t)
}
}
}
var err error
var from int
if options != nil && options.Pagination.From != "" {
from, err = strconv.Atoi(options.Pagination.From)
if err != nil {
telemetry.TraceError(span, err)
return nil, err
}
}
if from <= len(matches) {
matches = matches[from:]
}
to := 0 // fetch everything
if options != nil {
to = options.Pagination.PageSize
}
if to != 0 && to < len(matches) {
return &staticIterator{records: matches[:to], continuationToken: strconv.Itoa(from + to)}, nil
}
return &staticIterator{records: matches}, nil
}
type tupleChangeRec struct {
Change *openfgav1.TupleChange
Ulid ulid.ULID
}
// Write see [storage.RelationshipTupleWriter].Write.
func (s *MemoryBackend) Write(ctx context.Context, store string, deletes storage.Deletes, writes storage.Writes, opts ...storage.TupleWriteOption) error {
_, span := tracer.Start(ctx, "memory.Write")
defer span.End()
s.mutexTuples.Lock()
defer s.mutexTuples.Unlock()
now := timestamppb.Now()
duplicateDeletes, _, err := sanitizeTuplesWriteDelete(s.tuples[store], deletes, writes, storage.NewTupleWriteOptions(opts...))
if err != nil {
return err
}
var records []*storage.TupleRecord
entropy := ulid.DefaultEntropy()
Delete:
for _, tr := range s.tuples[store] {
t := tr.AsTuple()
tk := t.GetKey()
for i, k := range deletes {
if match(tr, tupleUtils.TupleKeyWithoutConditionToTupleKey(k)) {
if slices.Contains(duplicateDeletes, i) {
// noop for duplicate delete
continue
}
s.changes[store] = append(
s.changes[store],
&tupleChangeRec{
Change: &openfgav1.TupleChange{
TupleKey: tupleUtils.NewTupleKey(tk.GetObject(), tk.GetRelation(), tk.GetUser()), // Redact the condition info.
Operation: openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
Timestamp: now,
},
Ulid: ulid.MustNew(ulid.Timestamp(now.AsTime()), entropy),
},
)
continue Delete
}
}
records = append(records, tr)
}
Write:
for _, t := range writes {
for _, et := range records {
if match(et, t) {
// note that we don't need to check duplicateWrites here: the record matched an
// existing tuple and passed sanitizeTuplesWriteDelete, so it is a valid duplicate write.
continue Write
}
}
var conditionName string
var conditionContext *structpb.Struct
if condition := t.GetCondition(); condition != nil {
conditionName = condition.GetName()
conditionContext = condition.GetContext()
}
objectType, objectID := tupleUtils.SplitObject(t.GetObject())
records = append(records, &storage.TupleRecord{
Store: store,
ObjectType: objectType,
ObjectID: objectID,
Relation: t.GetRelation(),
User: t.GetUser(),
ConditionName: conditionName,
ConditionContext: conditionContext,
Ulid: ulid.MustNew(ulid.Timestamp(now.AsTime()), ulid.DefaultEntropy()).String(),
InsertedAt: now.AsTime(),
})
tk := tupleUtils.NewTupleKeyWithCondition(
tupleUtils.BuildObject(objectType, objectID),
t.GetRelation(),
t.GetUser(),
conditionName,
conditionContext,
)
s.changes[store] = append(s.changes[store], &tupleChangeRec{
Change: &openfgav1.TupleChange{
TupleKey: tk,
Operation: openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
Timestamp: now,
},
Ulid: ulid.MustNew(ulid.Timestamp(now.AsTime()), entropy),
})
}
s.tuples[store] = records
return nil
}
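// Hedged round-trip sketch: write a tuple into the in-memory backend and read
// it back. The store ID and tuple are illustrative.
func exampleWriteThenRead(ctx context.Context) (*openfgav1.Tuple, error) {
	ds := New()
	const store = "01-example-store" // hypothetical store ID
	tk := tupleUtils.NewTupleKey("document:1", "viewer", "user:anne")
	if err := ds.Write(ctx, store, nil, []*openfgav1.TupleKey{tk}); err != nil {
		return nil, err
	}
	return ds.ReadUserTuple(ctx, store, tk, storage.ReadUserTupleOptions{})
}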
// sanitizeTuplesWriteDelete validates the deletes and writes against the existing
// records, honoring the write options: it returns the indexes of deletes and writes
// that may be skipped as duplicates, or an error on an invalid operation.
func sanitizeTuplesWriteDelete(
records []*storage.TupleRecord,
deletes []*openfgav1.TupleKeyWithoutCondition,
writes []*openfgav1.TupleKey,
opts storage.TupleWriteOptions,
) ([]int, []int, error) {
var duplicateDeletes []int
var duplicateWrites []int
for i, tk := range deletes {
if find(records, tupleUtils.TupleKeyWithoutConditionToTupleKey(tk)) == nil {
if opts.OnMissingDelete == storage.OnMissingDeleteIgnore {
duplicateDeletes = append(duplicateDeletes, i)
continue
}
return nil, nil, storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_DELETE)
}
}
for i, tk := range writes {
record := find(records, tk)
if record != nil {
if opts.OnDuplicateInsert == storage.OnDuplicateInsertIgnore {
// need to validate against condition and context
if record.ConditionName == tk.GetCondition().GetName() && record.ConditionContext.String() == tk.GetCondition().GetContext().String() {
duplicateWrites = append(duplicateWrites, i)
continue
}
return nil, nil, storage.TupleConditionConflictError(tk)
}
return nil, nil, storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
}
}
return duplicateDeletes, duplicateWrites, nil
}
// find returns the first [*storage.TupleRecord] that matches the given tuple key,
// or nil if no record matches.
func find(records []*storage.TupleRecord, tupleKey *openfgav1.TupleKey) *storage.TupleRecord {
for _, tr := range records {
if match(tr, tupleKey) {
return tr
}
}
return nil
}
// ReadUserTuple see [storage.RelationshipTupleReader].ReadUserTuple.
func (s *MemoryBackend) ReadUserTuple(ctx context.Context, store string, key *openfgav1.TupleKey, _ storage.ReadUserTupleOptions) (*openfgav1.Tuple, error) {
_, span := tracer.Start(ctx, "memory.ReadUserTuple")
defer span.End()
s.mutexTuples.RLock()
defer s.mutexTuples.RUnlock()
for _, t := range s.tuples[store] {
if match(t, key) {
return t.AsTuple(), nil
}
}
telemetry.TraceError(span, storage.ErrNotFound)
return nil, storage.ErrNotFound
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (s *MemoryBackend) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
_ storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
_, span := tracer.Start(ctx, "memory.ReadUsersetTuples")
defer span.End()
s.mutexTuples.RLock()
defer s.mutexTuples.RUnlock()
var matches []*storage.TupleRecord
for _, t := range s.tuples[store] {
if match(t, &openfgav1.TupleKey{
Object: filter.Object,
Relation: filter.Relation,
}) && tupleUtils.GetUserTypeFromUser(t.User) == tupleUtils.UserSet {
// apply the optional condition filter before any matches are recorded.
if len(filter.Conditions) > 0 && !slices.Contains(filter.Conditions, t.ConditionName) {
continue
}
if len(filter.AllowedUserTypeRestrictions) == 0 { // 1.0 model.
matches = append(matches, t)
continue
}
// 1.1 model: see if the tuple found is of an allowed type.
userType := tupleUtils.GetType(t.User)
_, userRelation := tupleUtils.SplitObjectRelation(t.User)
for _, allowedType := range filter.AllowedUserTypeRestrictions {
if allowedType.GetType() == userType && allowedType.GetRelation() == userRelation {
matches = append(matches, t)
break
}
}
}
}
return &staticIterator{records: matches}, nil
}
// ReadStartingWithUser see [storage.RelationshipTupleReader].ReadStartingWithUser.
func (s *MemoryBackend) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
_, span := tracer.Start(ctx, "memory.ReadStartingWithUser")
defer span.End()
s.mutexTuples.RLock()
defer s.mutexTuples.RUnlock()
var matches []*storage.TupleRecord
for _, t := range s.tuples[store] {
if t.ObjectType != filter.ObjectType {
continue
}
if t.Relation != filter.Relation {
continue
}
if filter.ObjectIDs != nil && !filter.ObjectIDs.Exists(t.ObjectID) {
continue
}
if len(filter.Conditions) > 0 && !slices.Contains(filter.Conditions, t.ConditionName) {
continue
}
for _, userFilter := range filter.UserFilter {
targetUser := userFilter.GetObject()
if userFilter.GetRelation() != "" {
targetUser = tupleUtils.GetObjectRelationAsString(userFilter)
}
if targetUser != t.User {
continue
}
matches = append(matches, t)
}
}
sort.Slice(matches, func(i, j int) bool {
return matches[i].ObjectID < matches[j].ObjectID
})
return &staticIterator{records: matches}, nil
}
func findAuthorizationModelByID(
id string,
configurations map[string]*AuthorizationModelEntry,
) (*openfgav1.AuthorizationModel, bool) {
if id != "" {
if entry, ok := configurations[id]; ok {
return entry.model, true
}
return nil, false
}
for _, entry := range configurations {
if entry.latest {
return entry.model, true
}
}
return nil, false
}
// ReadAuthorizationModel see [storage.AuthorizationModelReadBackend].ReadAuthorizationModel.
func (s *MemoryBackend) ReadAuthorizationModel(
ctx context.Context,
store string,
id string,
) (*openfgav1.AuthorizationModel, error) {
_, span := tracer.Start(ctx, "memory.ReadAuthorizationModel")
defer span.End()
s.mutexModels.RLock()
defer s.mutexModels.RUnlock()
tm, ok := s.authorizationModels[store]
if !ok {
telemetry.TraceError(span, storage.ErrNotFound)
return nil, storage.ErrNotFound
}
if model, ok := findAuthorizationModelByID(id, tm); ok {
if len(model.GetTypeDefinitions()) == 0 {
return nil, storage.ErrNotFound
}
return model, nil
}
telemetry.TraceError(span, storage.ErrNotFound)
return nil, storage.ErrNotFound
}
// ReadAuthorizationModels see [storage.AuthorizationModelReadBackend].ReadAuthorizationModels.
func (s *MemoryBackend) ReadAuthorizationModels(ctx context.Context, store string, options storage.ReadAuthorizationModelsOptions) ([]*openfgav1.AuthorizationModel, string, error) {
_, span := tracer.Start(ctx, "memory.ReadAuthorizationModels")
defer span.End()
s.mutexModels.RLock()
defer s.mutexModels.RUnlock()
models := make([]*openfgav1.AuthorizationModel, 0, len(s.authorizationModels[store]))
for _, entry := range s.authorizationModels[store] {
models = append(models, entry.model)
}
// From newest to oldest.
sort.Slice(models, func(i, j int) bool {
return models[i].GetId() > models[j].GetId()
})
var from int64
continuationToken := ""
var err error
pageSize := storage.DefaultPageSize
if options.Pagination.PageSize > 0 {
pageSize = options.Pagination.PageSize
}
if options.Pagination.From != "" {
from, err = strconv.ParseInt(options.Pagination.From, 10, 32)
if err != nil {
return nil, "", err
}
}
to := int(from) + pageSize
if len(models) < to {
to = len(models)
}
res := models[from:to]
if to != len(models) {
continuationToken = strconv.Itoa(to)
}
return res, continuationToken, nil
}
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (s *MemoryBackend) FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error) {
_, span := tracer.Start(ctx, "memory.FindLatestAuthorizationModel")
defer span.End()
s.mutexModels.RLock()
defer s.mutexModels.RUnlock()
tm, ok := s.authorizationModels[store]
if !ok {
telemetry.TraceError(span, storage.ErrNotFound)
return nil, storage.ErrNotFound
}
// Find latest model.
nsc, ok := findAuthorizationModelByID("", tm)
if !ok {
telemetry.TraceError(span, storage.ErrNotFound)
return nil, storage.ErrNotFound
}
return nsc, nil
}
// WriteAuthorizationModel see [storage.TypeDefinitionWriteBackend].WriteAuthorizationModel.
func (s *MemoryBackend) WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error {
_, span := tracer.Start(ctx, "memory.WriteAuthorizationModel")
defer span.End()
s.mutexModels.Lock()
defer s.mutexModels.Unlock()
if _, ok := s.authorizationModels[store]; !ok {
s.authorizationModels[store] = make(map[string]*AuthorizationModelEntry)
}
for _, entry := range s.authorizationModels[store] {
entry.latest = false
}
s.authorizationModels[store][model.GetId()] = &AuthorizationModelEntry{
model: model,
latest: true,
}
return nil
}
// CreateStore adds a new store to the [MemoryBackend].
func (s *MemoryBackend) CreateStore(ctx context.Context, newStore *openfgav1.Store) (*openfgav1.Store, error) {
_, span := tracer.Start(ctx, "memory.CreateStore")
defer span.End()
s.mutexStores.Lock()
defer s.mutexStores.Unlock()
if _, ok := s.stores[newStore.GetId()]; ok {
return nil, storage.ErrCollision
}
now := timestamppb.New(time.Now().UTC())
s.stores[newStore.GetId()] = &openfgav1.Store{
Id: newStore.GetId(),
Name: newStore.GetName(),
CreatedAt: now,
UpdatedAt: now,
}
return s.stores[newStore.GetId()], nil
}
// DeleteStore removes a store from the [MemoryBackend].
func (s *MemoryBackend) DeleteStore(ctx context.Context, id string) error {
_, span := tracer.Start(ctx, "memory.DeleteStore")
defer span.End()
s.mutexStores.Lock()
defer s.mutexStores.Unlock()
delete(s.stores, id)
return nil
}
// WriteAssertions see [storage.AssertionsBackend].WriteAssertions.
func (s *MemoryBackend) WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error {
_, span := tracer.Start(ctx, "memory.WriteAssertions")
defer span.End()
s.mutexAssertions.Lock()
defer s.mutexAssertions.Unlock()
assertionsID := fmt.Sprintf("%s|%s", store, modelID)
s.assertions[assertionsID] = assertions
return nil
}
// ReadAssertions see [storage.AssertionsBackend].ReadAssertions.
func (s *MemoryBackend) ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error) {
_, span := tracer.Start(ctx, "memory.ReadAssertions")
defer span.End()
s.mutexAssertions.RLock()
defer s.mutexAssertions.RUnlock()
assertionsID := fmt.Sprintf("%s|%s", store, modelID)
assertions, ok := s.assertions[assertionsID]
if !ok {
return []*openfgav1.Assertion{}, nil
}
return assertions, nil
}
// MaxTuplesPerWrite see [storage.RelationshipTupleWriter].MaxTuplesPerWrite.
func (s *MemoryBackend) MaxTuplesPerWrite() int {
return s.maxTuplesPerWrite
}
// MaxTypesPerAuthorizationModel see [storage.TypeDefinitionWriteBackend].MaxTypesPerAuthorizationModel.
func (s *MemoryBackend) MaxTypesPerAuthorizationModel() int {
return s.maxTypesPerAuthorizationModel
}
// GetStore retrieves the details of a specific store from the MemoryBackend using its storeID.
func (s *MemoryBackend) GetStore(ctx context.Context, storeID string) (*openfgav1.Store, error) {
_, span := tracer.Start(ctx, "memory.GetStore")
defer span.End()
s.mutexStores.RLock()
defer s.mutexStores.RUnlock()
if s.stores[storeID] == nil {
return nil, storage.ErrNotFound
}
return s.stores[storeID], nil
}
// ListStores provides a paginated list of all stores present in the MemoryBackend.
func (s *MemoryBackend) ListStores(ctx context.Context, options storage.ListStoresOptions) ([]*openfgav1.Store, string, error) {
_, span := tracer.Start(ctx, "memory.ListStores")
defer span.End()
s.mutexStores.RLock()
defer s.mutexStores.RUnlock()
stores := make([]*openfgav1.Store, 0, len(s.stores))
for _, t := range s.stores {
stores = append(stores, t)
}
if len(options.IDs) > 0 {
filteredStores := make([]*openfgav1.Store, 0, len(stores))
for _, storeID := range options.IDs {
for _, store := range stores {
if store.GetId() == storeID {
filteredStores = append(filteredStores, store)
}
}
}
stores = filteredStores
}
if options.Name != "" {
filteredStores := make([]*openfgav1.Store, 0, len(stores))
for _, store := range stores {
if store.GetName() == options.Name {
filteredStores = append(filteredStores, store)
}
}
stores = filteredStores
}
// From oldest to newest.
sort.SliceStable(stores, func(i, j int) bool {
return stores[i].GetId() < stores[j].GetId()
})
var err error
var from int64
if options.Pagination.From != "" {
from, err = strconv.ParseInt(options.Pagination.From, 10, 32)
if err != nil {
return nil, "", err
}
}
pageSize := storage.DefaultPageSize
if options.Pagination.PageSize > 0 {
pageSize = options.Pagination.PageSize
}
to := int(from) + pageSize
if len(stores) < to {
to = len(stores)
}
res := stores[from:to]
if len(res) == 0 {
return nil, "", nil
}
continuationToken := ""
if to != len(stores) {
continuationToken = strconv.Itoa(to)
}
return res, continuationToken, nil
}
// IsReady see [storage.OpenFGADatastore].IsReady.
func (s *MemoryBackend) IsReady(context.Context) (storage.ReadinessStatus, error) {
return storage.ReadinessStatus{IsReady: true}, nil
}
package mysql
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/cenkalti/backoff/v4"
"github.com/go-sql-driver/mysql"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/storage/mysql")
func startTrace(ctx context.Context, name string) (context.Context, trace.Span) {
return tracer.Start(ctx, "mysql."+name)
}
// Datastore provides a MySQL based implementation of [storage.OpenFGADatastore].
type Datastore struct {
stbl sq.StatementBuilderType
db *sql.DB
dbInfo *sqlcommon.DBInfo
logger logger.Logger
dbStatsCollector prometheus.Collector
maxTuplesPerWriteField int
maxTypesPerModelField int
versionReady bool
}
// Ensures that Datastore implements the OpenFGADatastore interface.
var _ storage.OpenFGADatastore = (*Datastore)(nil)
// New creates a new [Datastore] storage.
func New(uri string, cfg *sqlcommon.Config) (*Datastore, error) {
if cfg.Username != "" || cfg.Password != "" {
dsnCfg, err := mysql.ParseDSN(uri)
if err != nil {
return nil, fmt.Errorf("parse mysql connection dsn: %w", err)
}
if cfg.Username != "" {
dsnCfg.User = cfg.Username
}
if cfg.Password != "" {
dsnCfg.Passwd = cfg.Password
}
dsnCfg.AllowNativePasswords = true
uri = dsnCfg.FormatDSN()
}
db, err := sql.Open("mysql", uri)
if err != nil {
return nil, fmt.Errorf("initialize mysql connection: %w", err)
}
return NewWithDB(db, cfg)
}
// NewWithDB creates a new [Datastore] storage with the provided database connection.
func NewWithDB(db *sql.DB, cfg *sqlcommon.Config) (*Datastore, error) {
if cfg.MaxIdleConns != 0 {
db.SetMaxIdleConns(cfg.MaxIdleConns) // the default is 2; retaining no idle connections (0) would be detrimental to performance
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)
db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
policy := backoff.NewExponentialBackOff()
policy.MaxElapsedTime = 1 * time.Minute
attempt := 1
err := backoff.Retry(func() error {
err := db.PingContext(context.Background())
if err != nil {
cfg.Logger.Info("waiting for database", zap.Int("attempt", attempt))
attempt++
return err
}
return nil
}, policy)
if err != nil {
return nil, fmt.Errorf("ping db: %w", err)
}
var collector prometheus.Collector
if cfg.ExportMetrics {
collector = collectors.NewDBStatsCollector(db, "openfga")
if err := prometheus.Register(collector); err != nil {
return nil, fmt.Errorf("initialize metrics: %w", err)
}
}
stbl := sq.StatementBuilder.RunWith(db)
dbInfo := sqlcommon.NewDBInfo(stbl, HandleSQLError, "mysql")
return &Datastore{
stbl: stbl,
db: db,
dbInfo: dbInfo,
logger: cfg.Logger,
dbStatsCollector: collector,
maxTuplesPerWriteField: cfg.MaxTuplesPerWriteField,
maxTypesPerModelField: cfg.MaxTypesPerModelField,
versionReady: false,
}, nil
}
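// Hedged construction sketch: the DSN and credentials below are placeholders.
// Per New above, credentials set on the config override any embedded in the DSN.
func exampleNewMySQLDatastore(log logger.Logger) (*Datastore, error) {
	cfg := &sqlcommon.Config{
		Username:     "openfga", // overrides the DSN user
		Password:     "secret",  // placeholder
		MaxOpenConns: 30,
		Logger:       log,
	}
	return New("openfga:secret@tcp(localhost:3306)/openfga?parseTime=true", cfg)
}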
// Close see [storage.OpenFGADatastore].Close.
func (s *Datastore) Close() {
if s.dbStatsCollector != nil {
prometheus.Unregister(s.dbStatsCollector)
}
s.db.Close()
}
// Read see [storage.RelationshipTupleReader].Read.
func (s *Datastore) Read(
ctx context.Context,
store string,
filter storage.ReadFilter,
_ storage.ReadOptions,
) (storage.TupleIterator, error) {
ctx, span := startTrace(ctx, "Read")
defer span.End()
return s.read(ctx, store, filter, nil)
}
// ReadPage see [storage.RelationshipTupleReader].ReadPage.
func (s *Datastore) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
ctx, span := startTrace(ctx, "ReadPage")
defer span.End()
iter, err := s.read(ctx, store, filter, &options)
if err != nil {
return nil, "", err
}
defer iter.Stop()
return iter.ToArray(ctx, options.Pagination)
}
func (s *Datastore) read(ctx context.Context, store string, filter storage.ReadFilter, options *storage.ReadPageOptions) (*sqlcommon.SQLTupleIterator, error) {
_, span := startTrace(ctx, "read")
defer span.End()
sb := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store})
if options != nil {
sb = sb.OrderBy("ulid")
}
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if filter.User != "" {
userType, userID, _ := tupleUtils.ToUserParts(filter.User)
if userID != "" {
sb = sb.Where(sq.Eq{"_user": filter.User})
} else {
sb = sb.Where(sq.Like{"_user": userType + ":%"})
}
}
if len(filter.Conditions) > 0 {
// Use COALESCE to treat NULL and '' as the same value (empty string).
// This allows filtering for "no condition" (e.g., filter.Conditions = [""])
// to correctly match rows where condition_name is either '' OR NULL.
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
if options != nil && options.Pagination.From != "" {
token := options.Pagination.From
sb = sb.Where(sq.GtOrEq{"ulid": token})
}
if options != nil && options.Pagination.PageSize != 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
return sqlcommon.NewSQLTupleIterator(sqlcommon.NewSBIteratorQuery(sb), HandleSQLError), nil
}
// Write see [storage.RelationshipTupleWriter].Write.
func (s *Datastore) Write(
ctx context.Context,
store string,
deletes storage.Deletes,
writes storage.Writes,
opts ...storage.TupleWriteOption,
) error {
ctx, span := startTrace(ctx, "Write")
defer span.End()
return sqlcommon.Write(ctx, s.dbInfo, s.db, store,
sqlcommon.WriteData{
Deletes: deletes,
Writes: writes,
Opts: storage.NewTupleWriteOptions(opts...),
Now: time.Now().UTC(),
})
}
// ReadUserTuple see [storage.RelationshipTupleReader].ReadUserTuple.
func (s *Datastore) ReadUserTuple(ctx context.Context, store string, tupleKey *openfgav1.TupleKey, _ storage.ReadUserTupleOptions) (*openfgav1.Tuple, error) {
ctx, span := startTrace(ctx, "ReadUserTuple")
defer span.End()
objectType, objectID := tupleUtils.SplitObject(tupleKey.GetObject())
userType := tupleUtils.GetUserTypeFromUser(tupleKey.GetUser())
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := s.stbl.
Select(
"object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": objectType,
"object_id": objectID,
"relation": tupleKey.GetRelation(),
"_user": tupleKey.GetUser(),
"user_type": userType,
}).
QueryRowContext(ctx).
Scan(
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.User,
&conditionName,
&conditionContext,
)
if err != nil {
return nil, HandleSQLError(err)
}
if conditionName.String != "" {
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
}
return record.AsTuple(), nil
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (s *Datastore) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
_ storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadUsersetTuples")
defer span.End()
sb := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store}).
Where(sq.Eq{"user_type": tupleUtils.UserSet})
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if len(filter.AllowedUserTypeRestrictions) > 0 {
orConditions := sq.Or{}
for _, userset := range filter.AllowedUserTypeRestrictions {
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
orConditions = append(orConditions, sq.Like{
"_user": userset.GetType() + ":%#" + userset.GetRelation(),
})
}
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
orConditions = append(orConditions, sq.Eq{
"_user": userset.GetType() + ":*",
})
}
}
sb = sb.Where(orConditions)
}
if len(filter.Conditions) > 0 {
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
return sqlcommon.NewSQLTupleIterator(sqlcommon.NewSBIteratorQuery(sb), HandleSQLError), nil
}
// ReadStartingWithUser see [storage.RelationshipTupleReader].ReadStartingWithUser.
func (s *Datastore) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
_ storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadStartingWithUser")
defer span.End()
var targetUsersArg []string
for _, u := range filter.UserFilter {
targetUser := u.GetObject()
if u.GetRelation() != "" {
targetUser = strings.Join([]string{u.GetObject(), u.GetRelation()}, "#")
}
targetUsersArg = append(targetUsersArg, targetUser)
}
builder := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": filter.ObjectType,
"relation": filter.Relation,
"_user": targetUsersArg,
}).OrderBy("object_id")
if filter.ObjectIDs != nil && filter.ObjectIDs.Size() > 0 {
builder = builder.Where(sq.Eq{"object_id": filter.ObjectIDs.Values()})
}
if len(filter.Conditions) > 0 {
builder = builder.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
return sqlcommon.NewSQLTupleIterator(sqlcommon.NewSBIteratorQuery(builder), HandleSQLError), nil
}
// MaxTuplesPerWrite see [storage.RelationshipTupleWriter].MaxTuplesPerWrite.
func (s *Datastore) MaxTuplesPerWrite() int {
return s.maxTuplesPerWriteField
}
// ReadAuthorizationModel see [storage.AuthorizationModelReadBackend].ReadAuthorizationModel.
func (s *Datastore) ReadAuthorizationModel(ctx context.Context, store string, modelID string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModel")
defer span.End()
return sqlcommon.ReadAuthorizationModel(ctx, s.dbInfo, store, modelID)
}
// ReadAuthorizationModels see [storage.AuthorizationModelReadBackend].ReadAuthorizationModels.
func (s *Datastore) ReadAuthorizationModels(ctx context.Context, store string, options storage.ReadAuthorizationModelsOptions) ([]*openfgav1.AuthorizationModel, string, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModels")
defer span.End()
sb := s.stbl.
Select("authorization_model_id").
Distinct().
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc")
if options.Pagination.From != "" {
token := options.Pagination.From
sb = sb.Where(sq.LtOrEq{"authorization_model_id": token})
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var modelIDs []string
var modelID string
for rows.Next() {
err = rows.Scan(&modelID)
if err != nil {
return nil, "", HandleSQLError(err)
}
modelIDs = append(modelIDs, modelID)
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
var token string
numModelIDs := len(modelIDs)
if len(modelIDs) > options.Pagination.PageSize {
numModelIDs = options.Pagination.PageSize
token = modelID
}
// TODO: make this concurrent with a maximum of 5 goroutines. This may be helpful:
// https://stackoverflow.com/questions/25306073/always-have-x-number-of-goroutines-running-at-any-time
models := make([]*openfgav1.AuthorizationModel, 0, numModelIDs)
// Iterate up to numModelIDs to avoid retrieving the one extra model fetched for pagination.
for i := 0; i < numModelIDs; i++ {
model, err := s.ReadAuthorizationModel(ctx, store, modelIDs[i])
if err != nil {
return nil, "", err
}
models = append(models, model)
}
return models, token, nil
}
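// Example (illustrative): with options.Pagination.PageSize == 2 and three
// matching model IDs, the query above fetches up to three rows; since 3 > 2,
// only the first two models are returned and the last scanned ID becomes the
// continuation token for the next page.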
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (s *Datastore) FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "FindLatestAuthorizationModel")
defer span.End()
return sqlcommon.FindLatestAuthorizationModel(ctx, s.dbInfo, store)
}
// MaxTypesPerAuthorizationModel see [storage.TypeDefinitionWriteBackend].MaxTypesPerAuthorizationModel.
func (s *Datastore) MaxTypesPerAuthorizationModel() int {
return s.maxTypesPerModelField
}
// WriteAuthorizationModel see [storage.TypeDefinitionWriteBackend].WriteAuthorizationModel.
func (s *Datastore) WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error {
ctx, span := startTrace(ctx, "WriteAuthorizationModel")
defer span.End()
return sqlcommon.WriteAuthorizationModel(ctx, s.dbInfo, store, model)
}
// CreateStore adds a new store to storage.
func (s *Datastore) CreateStore(ctx context.Context, store *openfgav1.Store) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "CreateStore")
defer span.End()
var id, name string
var createdAt, updatedAt time.Time
txn, err := s.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return nil, HandleSQLError(err)
}
defer func() {
_ = txn.Rollback()
}()
_, err = s.stbl.
Insert("store").
Columns("id", "name", "created_at", "updated_at").
Values(store.GetId(), store.GetName(), sq.Expr("NOW()"), sq.Expr("NOW()")).
RunWith(txn).
ExecContext(ctx)
if err != nil {
return nil, HandleSQLError(err)
}
err = s.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{"id": store.GetId()}).
RunWith(txn).
QueryRowContext(ctx).
Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, HandleSQLError(err)
}
err = txn.Commit()
if err != nil {
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// GetStore retrieves the details of a specific store using its storeID.
func (s *Datastore) GetStore(ctx context.Context, id string) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "GetStore")
defer span.End()
row := s.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{
"id": id,
"deleted_at": nil,
}).
QueryRowContext(ctx)
var storeID, name string
var createdAt, updatedAt time.Time
err := row.Scan(&storeID, &name, &createdAt, &updatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, storage.ErrNotFound
}
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: storeID,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// ListStores provides a paginated list of all stores present in the storage.
func (s *Datastore) ListStores(ctx context.Context, options storage.ListStoresOptions) ([]*openfgav1.Store, string, error) {
ctx, span := startTrace(ctx, "ListStores")
defer span.End()
whereClause := sq.And{
sq.Eq{"deleted_at": nil},
}
if len(options.IDs) > 0 {
whereClause = append(whereClause, sq.Eq{"id": options.IDs})
}
if options.Name != "" {
whereClause = append(whereClause, sq.Eq{"name": options.Name})
}
if options.Pagination.From != "" {
whereClause = append(whereClause, sq.GtOrEq{"id": options.Pagination.From})
}
sb := s.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(whereClause).
OrderBy("id")
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var stores []*openfgav1.Store
var id string
for rows.Next() {
var name string
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, "", HandleSQLError(err)
}
stores = append(stores, &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
})
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
if len(stores) > options.Pagination.PageSize {
return stores[:options.Pagination.PageSize], id, nil
}
return stores, "", nil
}
// DeleteStore removes a store from storage.
func (s *Datastore) DeleteStore(ctx context.Context, id string) error {
ctx, span := startTrace(ctx, "DeleteStore")
defer span.End()
_, err := s.stbl.
Update("store").
Set("deleted_at", sq.Expr("NOW()")).
Where(sq.Eq{"id": id}).
ExecContext(ctx)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// WriteAssertions see [storage.AssertionsBackend].WriteAssertions.
func (s *Datastore) WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error {
ctx, span := startTrace(ctx, "WriteAssertions")
defer span.End()
marshalledAssertions, err := proto.Marshal(&openfgav1.Assertions{Assertions: assertions})
if err != nil {
return err
}
_, err = s.stbl.
Insert("assertion").
Columns("store", "authorization_model_id", "assertions").
Values(store, modelID, marshalledAssertions).
Suffix("ON DUPLICATE KEY UPDATE assertions = ?", marshalledAssertions).
ExecContext(ctx)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// ReadAssertions see [storage.AssertionsBackend].ReadAssertions.
func (s *Datastore) ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error) {
ctx, span := startTrace(ctx, "ReadAssertions")
defer span.End()
var marshalledAssertions []byte
err := s.stbl.
Select("assertions").
From("assertion").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).
QueryRowContext(ctx).
Scan(&marshalledAssertions)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []*openfgav1.Assertion{}, nil
}
return nil, HandleSQLError(err)
}
var assertions openfgav1.Assertions
err = proto.Unmarshal(marshalledAssertions, &assertions)
if err != nil {
return nil, err
}
return assertions.GetAssertions(), nil
}
// ReadChanges see [storage.ChangelogBackend].ReadChanges.
func (s *Datastore) ReadChanges(ctx context.Context, store string, filter storage.ReadChangesFilter, options storage.ReadChangesOptions) ([]*openfgav1.TupleChange, string, error) {
ctx, span := startTrace(ctx, "ReadChanges")
defer span.End()
objectTypeFilter := filter.ObjectType
horizonOffset := filter.HorizonOffset
orderBy := "ulid asc"
if options.SortDesc {
orderBy = "ulid desc"
}
sb := s.stbl.
Select(
"ulid", "object_type", "object_id", "relation",
"_user",
"operation",
"condition_name", "condition_context", "inserted_at",
).
From("changelog").
Where(sq.Eq{"store": store}).
Where(fmt.Sprintf("inserted_at <= NOW() - INTERVAL %d MICROSECOND", horizonOffset.Microseconds())).
OrderBy(orderBy)
if objectTypeFilter != "" {
sb = sb.Where(sq.Eq{"object_type": objectTypeFilter})
}
if options.Pagination.From != "" {
sb = sqlcommon.AddFromUlid(sb, options.Pagination.From, options.SortDesc)
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize)) // + 1 is NOT used here as we always return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var changes []*openfgav1.TupleChange
var ulid string
for rows.Next() {
var objectType, objectID, relation, user string
var operation int
var insertedAt time.Time
var conditionName sql.NullString
var conditionContext []byte
err = rows.Scan(
&ulid,
&objectType,
&objectID,
&relation,
&user,
&operation,
&conditionName,
&conditionContext,
&insertedAt,
)
if err != nil {
return nil, "", HandleSQLError(err)
}
var conditionContextStruct structpb.Struct
if conditionName.String != "" {
if conditionContext != nil {
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, "", err
}
}
}
tk := tupleUtils.NewTupleKeyWithCondition(
tupleUtils.BuildObject(objectType, objectID),
relation,
user,
conditionName.String,
&conditionContextStruct,
)
changes = append(changes, &openfgav1.TupleChange{
TupleKey: tk,
Operation: openfgav1.TupleOperation(operation),
Timestamp: timestamppb.New(insertedAt.UTC()),
})
}
if len(changes) == 0 {
return nil, "", storage.ErrNotFound
}
return changes, ulid, nil
}
// IsReady see [sqlcommon.IsReady].
func (s *Datastore) IsReady(ctx context.Context) (storage.ReadinessStatus, error) {
versionReady, err := sqlcommon.IsReady(ctx, s.versionReady, s.db)
if err != nil {
return versionReady, err
}
s.versionReady = versionReady.IsReady
return versionReady, nil
}
// HandleSQLError processes an SQL error and converts it into a more
// specific error type based on the nature of the SQL error.
func HandleSQLError(err error, args ...interface{}) error {
if errors.Is(err, sql.ErrNoRows) {
return storage.ErrNotFound
}
var me *mysql.MySQLError
if errors.As(err, &me) && me.Number == 1062 {
if len(args) > 0 {
if tk, ok := args[0].(*openfgav1.TupleKey); ok {
return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
}
}
return storage.ErrCollision
}
return fmt.Errorf("sql error: %w", err)
}
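// Example (illustrative sketch; the tuple key below is hypothetical): a MySQL
// duplicate-key error (number 1062) raised while writing a tuple surfaces as
// an invalid-write-input error when the offending key is passed along:
//
//	tk := tupleUtils.NewTupleKey("document:1", "viewer", "user:anne")
//	err := HandleSQLError(&mysql.MySQLError{Number: 1062}, tk)
//	// err is the storage.InvalidWriteInputError for tk with TUPLE_OPERATION_WRITE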
package postgres
import (
"context"
sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
)
// PgxQuery interface allows Query that returns pgx.Rows.
type PgxQuery interface {
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
}
// PgxExec interface allows pgx Exec functionality.
type PgxExec interface {
Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
}
// PgxTxnIterQuery is a helper that runs queries through pgx for use with the sqlcommon iterator.
type PgxTxnIterQuery struct {
txn PgxQuery
query string
args []interface{}
}
var _ sqlcommon.SQLIteratorRowGetter = (*PgxTxnIterQuery)(nil)
// NewPgxTxnGetRows creates a PgxTxnIterQuery that provides the GetRows functionality via the specified PgxQuery txn.
func NewPgxTxnGetRows(txn PgxQuery, sb sq.SelectBuilder) (*PgxTxnIterQuery, error) {
stmt, args, err := sb.ToSql()
if err != nil {
return nil, err
}
return &PgxTxnIterQuery{
txn: txn,
query: stmt,
args: args,
}, nil
}
// GetRows executes the txn query and returns the sqlcommon.Rows.
func (p *PgxTxnIterQuery) GetRows(ctx context.Context) (sqlcommon.Rows, error) {
rows, err := p.txn.Query(ctx, p.query, p.args...)
if err != nil {
return nil, HandleSQLError(err)
}
return &pgxRowsWrapper{rows: rows}, nil
}
// pgxRowsWrapper wraps pgx.Rows to implement sqlcommon.Rows interface.
type pgxRowsWrapper struct {
rows pgx.Rows
}
func (r *pgxRowsWrapper) Err() error {
return r.rows.Err()
}
func (r *pgxRowsWrapper) Next() bool {
return r.rows.Next()
}
func (r *pgxRowsWrapper) Scan(dest ...any) error {
return r.rows.Scan(dest...)
}
func (r *pgxRowsWrapper) Close() error {
r.rows.Close()
return nil
}
var _ sqlcommon.Rows = (*pgxRowsWrapper)(nil)
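// Example usage (illustrative sketch; pool and store are hypothetical, with
// pool being a *pgxpool.Pool, which satisfies PgxQuery):
//
//	sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
//		Select("ulid").From("tuple").Where(sq.Eq{"store": store})
//	getter, err := NewPgxTxnGetRows(pool, sb)
//	if err != nil {
//		return err
//	}
//	rows, err := getter.GetRows(ctx) // lazily executes the query via pool.Query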
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/IBM/pgxpoolprometheus"
sq "github.com/Masterminds/squirrel"
"github.com/cenkalti/backoff/v4"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/storage/postgres")
func startTrace(ctx context.Context, name string) (context.Context, trace.Span) {
return tracer.Start(ctx, "postgres."+name)
}
// Datastore provides a PostgreSQL based implementation of [storage.OpenFGADatastore].
type Datastore struct {
primaryDB *pgxpool.Pool
secondaryDB *pgxpool.Pool
logger logger.Logger
primaryDBStatsCollector prometheus.Collector
secondaryDBStatsCollector prometheus.Collector
maxTuplesPerWriteField int
maxTypesPerModelField int
versionReady bool
}
// Ensures that Datastore implements the OpenFGADatastore interface.
var _ storage.OpenFGADatastore = (*Datastore)(nil)
func parseConfig(uri string, override bool, cfg *sqlcommon.Config) (*pgxpool.Config, error) {
c, err := pgxpool.ParseConfig(uri)
if err != nil {
return nil, fmt.Errorf("pgxpool parse postgres connection uri: %w", err)
}
if override {
parsed, err := url.Parse(uri)
if err != nil {
return nil, fmt.Errorf("url parse postgres connection uri: %w", err)
}
if cfg.Username != "" {
c.ConnConfig.User = cfg.Username
} else if parsed.User != nil {
c.ConnConfig.User = parsed.User.Username()
}
switch {
case cfg.Password != "":
c.ConnConfig.Password = cfg.Password
case parsed.User != nil:
if password, ok := parsed.User.Password(); ok {
c.ConnConfig.Password = password
}
}
}
if cfg.MaxOpenConns != 0 {
c.MaxConns = int32(cfg.MaxOpenConns)
}
if cfg.MinIdleConns != 0 {
c.MinIdleConns = int32(cfg.MinIdleConns)
}
if cfg.MinOpenConns != 0 {
c.MinConns = int32(cfg.MinOpenConns)
}
if cfg.ConnMaxLifetime != 0 {
c.MaxConnLifetime = cfg.ConnMaxLifetime
c.MaxConnLifetimeJitter = cfg.ConnMaxLifetime / 10 // Add 10% jitter to avoid thundering herd
}
if cfg.ConnMaxIdleTime != 0 {
c.MaxConnIdleTime = cfg.ConnMaxIdleTime
}
return c, nil
}
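// Example (illustrative; the URI and credentials are hypothetical): when
// override is true, an explicit username in the Config takes precedence over
// the one embedded in the connection URI:
//
//	cfg := &sqlcommon.Config{Username: "svc_openfga", MaxOpenConns: 30}
//	c, _ := parseConfig("postgres://ignored:secret@localhost:5432/openfga", true, cfg)
//	// c.ConnConfig.User == "svc_openfga" and c.MaxConns == 30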
// initDB initializes a new postgres database connection.
func initDB(uri string, override bool, cfg *sqlcommon.Config) (*pgxpool.Pool, error) {
c, err := parseConfig(uri, override, cfg)
if err != nil {
return nil, err
}
db, err := pgxpool.NewWithConfig(context.Background(), c)
if err != nil {
return nil, fmt.Errorf("failed to establish connection: %w", err)
}
return db, nil
}
// New creates a new [Datastore] storage.
func New(uri string, cfg *sqlcommon.Config) (*Datastore, error) {
primaryDB, err := initDB(uri, cfg.Username != "" || cfg.Password != "", cfg)
if err != nil {
return nil, fmt.Errorf("initialize postgres connection: %w", err)
}
var secondaryDB *pgxpool.Pool
if cfg.SecondaryURI != "" {
secondaryDB, err = initDB(cfg.SecondaryURI, cfg.SecondaryUsername != "" || cfg.SecondaryPassword != "", cfg)
if err != nil {
return nil, fmt.Errorf("initialize postgres connection: %w", err)
}
}
return NewWithDB(primaryDB, secondaryDB, cfg)
}
func configureDB(db *pgxpool.Pool, cfg *sqlcommon.Config, dbName string) (prometheus.Collector, error) {
policy := backoff.NewExponentialBackOff()
policy.MaxElapsedTime = 1 * time.Minute
attempt := 1
err := backoff.Retry(func() error {
err := db.Ping(context.Background())
if err != nil {
cfg.Logger.Info("waiting for database", zap.Int("attempt", attempt))
attempt++
return err
}
return nil
}, policy)
if err != nil {
return nil, fmt.Errorf("ping db: %w", err)
}
var collector prometheus.Collector
if cfg.ExportMetrics {
collector = pgxpoolprometheus.NewCollector(db, map[string]string{"db_name": dbName})
if err := prometheus.Register(collector); err != nil {
return nil, fmt.Errorf("initialize metrics: %w", err)
}
}
return collector, nil
}
// NewWithDB creates a new [Datastore] storage with the provided database connection.
func NewWithDB(primaryDB, secondaryDB *pgxpool.Pool, cfg *sqlcommon.Config) (*Datastore, error) {
primaryCollector, err := configureDB(primaryDB, cfg, "openfga")
if err != nil {
return nil, fmt.Errorf("configure primary db: %w", err)
}
var secondaryCollector prometheus.Collector
if secondaryDB != nil {
secondaryCollector, err = configureDB(secondaryDB, cfg, "openfga_secondary")
if err != nil {
return nil, fmt.Errorf("configure secondary db: %w", err)
}
}
return &Datastore{
primaryDB: primaryDB,
secondaryDB: secondaryDB,
logger: cfg.Logger,
primaryDBStatsCollector: primaryCollector,
secondaryDBStatsCollector: secondaryCollector,
maxTuplesPerWriteField: cfg.MaxTuplesPerWriteField,
maxTypesPerModelField: cfg.MaxTypesPerModelField,
versionReady: false,
}, nil
}
func (s *Datastore) isSecondaryConfigured() bool {
return s.secondaryDB != nil
}
// Close see [storage.OpenFGADatastore].Close.
func (s *Datastore) Close() {
if s.primaryDBStatsCollector != nil {
prometheus.Unregister(s.primaryDBStatsCollector)
}
s.primaryDB.Close()
if s.isSecondaryConfigured() {
if s.secondaryDBStatsCollector != nil {
prometheus.Unregister(s.secondaryDBStatsCollector)
}
s.secondaryDB.Close()
}
}
// getPgxPool returns the pgxpool.Pool based on consistency options.
func (s *Datastore) getPgxPool(consistency openfgav1.ConsistencyPreference) *pgxpool.Pool {
if consistency == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
// If we are using higher consistency, we need to use the write database.
return s.primaryDB
}
if s.isSecondaryConfigured() {
// If we are using lower consistency, we can use the read database.
return s.secondaryDB
}
// If we are not using a secondary database, we can only use the primary database.
return s.primaryDB
}
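// Example (illustrative): HIGHER_CONSISTENCY always routes to the primary
// pool, while any other preference uses the secondary read pool when one is
// configured:
//
//	pool := s.getPgxPool(openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY) // s.primaryDB
//	pool = s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)    // s.secondaryDB, if configured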
// Read see [storage.RelationshipTupleReader].Read.
func (s *Datastore) Read(
ctx context.Context,
store string,
filter storage.ReadFilter,
options storage.ReadOptions,
) (storage.TupleIterator, error) {
ctx, span := startTrace(ctx, "Read")
defer span.End()
readPool := s.getPgxPool(options.Consistency.Preference)
return s.read(ctx, store, filter, nil, readPool)
}
// ReadPage see [storage.RelationshipTupleReader].ReadPage.
func (s *Datastore) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
ctx, span := startTrace(ctx, "ReadPage")
defer span.End()
readPool := s.getPgxPool(options.Consistency.Preference)
iter, err := s.read(ctx, store, filter, &options, readPool)
if err != nil {
return nil, "", err
}
defer iter.Stop()
return iter.ToArray(ctx, options.Pagination)
}
func (s *Datastore) read(ctx context.Context, store string, filter storage.ReadFilter, options *storage.ReadPageOptions, db *pgxpool.Pool) (*sqlcommon.SQLTupleIterator, error) {
_, span := startTrace(ctx, "read")
defer span.End()
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store})
if options != nil {
sb = sb.OrderBy("ulid")
}
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if filter.User != "" {
userType, userID, _ := tupleUtils.ToUserParts(filter.User)
if userID != "" {
sb = sb.Where(sq.Eq{"_user": filter.User})
} else {
sb = sb.Where(sq.Like{"_user": userType + ":%"})
}
}
if len(filter.Conditions) > 0 {
// Use COALESCE to treat NULL and '' as the same value (empty string).
// This allows filtering for "no condition" (e.g., filter.Conditions = [""])
// to correctly match rows where condition_name is either '' OR NULL.
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
if options != nil && options.Pagination.From != "" {
sb = sb.Where(sq.GtOrEq{"ulid": options.Pagination.From})
}
if options != nil && options.Pagination.PageSize != 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
poolGetRows, err := NewPgxTxnGetRows(db, sb)
if err != nil {
return nil, HandleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(poolGetRows, HandleSQLError), nil
}
// Write see [storage.RelationshipTupleWriter].Write.
func (s *Datastore) Write(
ctx context.Context,
store string,
deletes storage.Deletes,
writes storage.Writes,
opts ...storage.TupleWriteOption,
) error {
ctx, span := startTrace(ctx, "Write")
defer span.End()
return s.write(ctx, store, deletes, writes, storage.NewTupleWriteOptions(opts...), time.Now().UTC())
}
// selectAllExistingRowsForUpdate executes a SELECT … FOR UPDATE statement, in batches, for all the rows
// indicated by lockKeys and returns a map of the existing tuples keyed by their tuple key.
func selectAllExistingRowsForUpdate(ctx context.Context,
lockKeys []sqlcommon.TupleLockKey,
txn PgxQuery,
store string) (map[string]*openfgav1.Tuple, error) {
total := len(lockKeys)
stbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
existing := make(map[string]*openfgav1.Tuple, total)
for start := 0; start < total; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > total {
end = total
}
keys := lockKeys[start:end]
if err := selectExistingRowsForWrite(ctx, stbl, txn, store, keys, existing); err != nil {
return nil, err
}
}
return existing, nil
}
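// Example (illustrative): assuming storage.DefaultMaxTuplesPerWrite is 100,
// 250 lock keys are processed by selectAllExistingRowsForUpdate as three
// SELECT … FOR UPDATE batches of 100, 100 and 50 keys, so no single statement
// grows unboundedly with the size of the write.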
// executeDeleteTuples deletes, in batches, the tuples matching the prepared deleteConditions.
func executeDeleteTuples(ctx context.Context, txn PgxExec, store string, deleteConditions sq.Or) error {
stbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
for start, totalDeletes := 0, len(deleteConditions); start < totalDeletes; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalDeletes {
end = totalDeletes
}
deleteConditionsBatch := deleteConditions[start:end]
stmt, args, err := stbl.Delete("tuple").Where(sq.Eq{"store": store}).
Where(deleteConditionsBatch).ToSql()
if err != nil {
// Should never happen because we craft the delete statement
return HandleSQLError(err)
}
res, err := txn.Exec(ctx, stmt, args...)
if err != nil {
return HandleSQLError(err)
}
rowsAffected := res.RowsAffected()
if rowsAffected != int64(len(deleteConditionsBatch)) {
// If we deleted fewer rows than planned (between the read and the write), we hit a race condition: someone else deleted the same row(s).
return storage.ErrWriteConflictOnDelete
}
}
return nil
}
// executeWriteTuples inserts the prepared writeItems in batches.
func executeWriteTuples(ctx context.Context, txn PgxExec, writeItems [][]interface{}) error {
stbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
for start, totalWrites := 0, len(writeItems); start < totalWrites; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalWrites {
end = totalWrites
}
writesBatch := writeItems[start:end]
insertBuilder := stbl.
Insert("tuple").
Columns(
"store",
"object_type",
"object_id",
"relation",
"_user",
"user_type",
"condition_name",
"condition_context",
"ulid",
"inserted_at",
)
for _, item := range writesBatch {
insertBuilder = insertBuilder.Values(item...)
}
stmt, args, err := insertBuilder.ToSql()
if err != nil {
// Should never happen because we craft the insert statement
return HandleSQLError(err)
}
_, err = txn.Exec(ctx, stmt, args...)
if err != nil {
dberr := HandleSQLError(err)
if errors.Is(dberr, storage.ErrCollision) {
// ErrCollision is returned on a duplicate write (constraint violation), meaning we hit a race condition: someone else inserted the same row(s).
return storage.ErrWriteConflictOnInsert
}
return dberr
}
}
return nil
}
// executeInsertChanges inserts the prepared changelog items in batches.
func executeInsertChanges(ctx context.Context, txn PgxExec, changeLogItems [][]interface{}) error {
stbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
for start, totalItems := 0, len(changeLogItems); start < totalItems; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalItems {
end = totalItems
}
changeLogBatch := changeLogItems[start:end]
changelogBuilder := stbl.
Insert("changelog").
Columns(
"store",
"object_type",
"object_id",
"relation",
"_user",
"condition_name",
"condition_context",
"operation",
"ulid",
"inserted_at",
)
for _, item := range changeLogBatch {
changelogBuilder = changelogBuilder.Values(item...)
}
stmt, args, err := changelogBuilder.ToSql()
if err != nil {
// Should never happen because we craft the insert statement
return HandleSQLError(err)
}
_, err = txn.Exec(ctx, stmt, args...)
if err != nil {
return HandleSQLError(err)
}
}
return nil
}
func (s *Datastore) write(
ctx context.Context,
store string,
deletes storage.Deletes,
writes storage.Writes,
opts storage.TupleWriteOptions,
now time.Time,
) error {
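// 1. Begin a read-committed transaction so the reads, writes and changelog
// inserts below are applied atomically.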
txn, err := s.primaryDB.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.ReadCommitted})
if err != nil {
return HandleSQLError(err)
}
// Important: use the same txn (rather than going through db) so that all of the work below is performed within a single transaction.
defer func() { _ = txn.Rollback(ctx) }()
// 2. Compile a SELECT … FOR UPDATE statement to read the tuples for writes and lock the tuples for deletes.
// Build a deduped, sorted list of keys to lock.
lockKeys := sqlcommon.MakeTupleLockKeys(deletes, writes)
if len(lockKeys) == 0 {
// Nothing to do.
return nil
}
// 3. If list compiled in step 2 is not empty, execute SELECT … FOR UPDATE statement
existing, err := selectAllExistingRowsForUpdate(ctx, lockKeys, txn, store)
if err != nil {
return err
}
// 4. Construct the deleteConditions, write and changelog items to be written
deleteConditions, writeItems, changeLogItems, err := sqlcommon.GetDeleteWriteChangelogItems(store, existing,
sqlcommon.WriteData{
Deletes: deletes,
Writes: writes,
Opts: opts,
Now: now,
})
if err != nil {
return err
}
err = executeDeleteTuples(ctx, txn, store, deleteConditions)
if err != nil {
return err
}
err = executeWriteTuples(ctx, txn, writeItems)
if err != nil {
return err
}
// 5. Execute INSERT changelog statements
err = executeInsertChanges(ctx, txn, changeLogItems)
if err != nil {
return err
}
// 6. Commit Transaction
if err := txn.Commit(ctx); err != nil {
return HandleSQLError(err)
}
return nil
}
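// Example usage (illustrative sketch; the store ID and tuple are hypothetical):
//
//	err := ds.Write(ctx, "01HSTOREIDEXAMPLE",
//		nil, // no deletes
//		storage.Writes{tupleUtils.NewTupleKey("document:1", "viewer", "user:anne")},
//	)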
// ReadUserTuple see [storage.RelationshipTupleReader].ReadUserTuple.
func (s *Datastore) ReadUserTuple(ctx context.Context, store string, tupleKey *openfgav1.TupleKey, options storage.ReadUserTupleOptions) (*openfgav1.Tuple, error) {
ctx, span := startTrace(ctx, "ReadUserTuple")
defer span.End()
readStbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
objectType, objectID := tupleUtils.SplitObject(tupleKey.GetObject())
userType := tupleUtils.GetUserTypeFromUser(tupleKey.GetUser())
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
stbl := readStbl.
Select(
"object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": objectType,
"object_id": objectID,
"relation": tupleKey.GetRelation(),
"_user": tupleKey.GetUser(),
"user_type": userType,
})
stmt, args, err := stbl.ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
db := s.getPgxPool(options.Consistency.Preference)
row := db.QueryRow(ctx, stmt, args...)
err = row.Scan(
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.User,
&conditionName,
&conditionContext,
)
if err != nil {
return nil, HandleSQLError(err)
}
if conditionName.String != "" {
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
}
return record.AsTuple(), nil
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (s *Datastore) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
options storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadUsersetTuples")
defer span.End()
db := s.getPgxPool(options.Consistency.Preference)
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store}).
Where(sq.Eq{"user_type": tupleUtils.UserSet})
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if len(filter.AllowedUserTypeRestrictions) > 0 {
orConditions := sq.Or{}
for _, userset := range filter.AllowedUserTypeRestrictions {
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
orConditions = append(orConditions, sq.Like{
"_user": userset.GetType() + ":%#" + userset.GetRelation(),
})
}
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
orConditions = append(orConditions, sq.Eq{
"_user": userset.GetType() + ":*",
})
}
}
sb = sb.Where(orConditions)
}
if len(filter.Conditions) > 0 {
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
poolGetRows, err := NewPgxTxnGetRows(db, sb)
if err != nil {
return nil, HandleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(poolGetRows, HandleSQLError), nil
}
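// Example (illustrative): an AllowedUserTypeRestrictions entry of type "group"
// with relation "member" contributes the pattern `_user LIKE 'group:%#member'`,
// matching userset users such as "group:eng#member", while a wildcard
// reference of type "user" contributes the equality `_user = 'user:*'`.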
// ReadStartingWithUser see [storage.RelationshipTupleReader].ReadStartingWithUser.
func (s *Datastore) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadStartingWithUser")
defer span.End()
db := s.getPgxPool(options.Consistency.Preference)
readStbl := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
var targetUsersArg []string
for _, u := range filter.UserFilter {
targetUser := u.GetObject()
if u.GetRelation() != "" {
targetUser = strings.Join([]string{u.GetObject(), u.GetRelation()}, "#")
}
targetUsersArg = append(targetUsersArg, targetUser)
}
builder := readStbl.
Select(
"store", "object_type", "object_id", "relation",
"_user",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": filter.ObjectType,
"relation": filter.Relation,
"_user": targetUsersArg,
}).OrderBy("object_id collate \"C\"")
if filter.ObjectIDs != nil && filter.ObjectIDs.Size() > 0 {
builder = builder.Where(sq.Eq{"object_id": filter.ObjectIDs.Values()})
}
if len(filter.Conditions) > 0 {
builder = builder.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
poolGetRows, err := NewPgxTxnGetRows(db, builder)
if err != nil {
return nil, HandleSQLError(err)
}
return sqlcommon.NewSQLTupleIterator(poolGetRows, HandleSQLError), nil
}
// MaxTuplesPerWrite see [storage.RelationshipTupleWriter].MaxTuplesPerWrite.
func (s *Datastore) MaxTuplesPerWrite() int {
return s.maxTuplesPerWriteField
}
// ReadAuthorizationModel see [storage.AuthorizationModelReadBackend].ReadAuthorizationModel.
func (s *Datastore) ReadAuthorizationModel(ctx context.Context, store string, modelID string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModel")
defer span.End()
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
rows, err := db.Query(ctx, stmt, args...)
if err != nil {
return nil, HandleSQLError(err)
}
defer rows.Close()
ret, err := sqlcommon.ConstructAuthorizationModelFromSQLRows(&pgxRowsWrapper{rows})
if err != nil {
return nil, HandleSQLError(err)
}
return ret, nil
}
// ReadAuthorizationModels see [storage.AuthorizationModelReadBackend].ReadAuthorizationModels.
func (s *Datastore) ReadAuthorizationModels(ctx context.Context, store string, options storage.ReadAuthorizationModelsOptions) ([]*openfgav1.AuthorizationModel, string, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModels")
defer span.End()
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("authorization_model_id").
Distinct().
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc")
if options.Pagination.From != "" {
sb = sb.Where(sq.LtOrEq{"authorization_model_id": options.Pagination.From})
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
stmt, args, err := sb.ToSql()
if err != nil {
return nil, "", HandleSQLError(err)
}
rows, err := db.Query(ctx, stmt, args...)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var modelIDs []string
var modelID string
for rows.Next() {
err = rows.Scan(&modelID)
if err != nil {
return nil, "", HandleSQLError(err)
}
modelIDs = append(modelIDs, modelID)
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
var token string
numModelIDs := len(modelIDs)
if len(modelIDs) > options.Pagination.PageSize {
numModelIDs = options.Pagination.PageSize
token = modelID
}
// TODO: make this concurrent with a maximum of 5 goroutines. This may be helpful:
// https://stackoverflow.com/questions/25306073/always-have-x-number-of-goroutines-running-at-any-time
models := make([]*openfgav1.AuthorizationModel, 0, numModelIDs)
// Iterate up to numModelIDs to avoid retrieving the one extra model fetched for pagination.
for i := 0; i < numModelIDs; i++ {
model, err := s.ReadAuthorizationModel(ctx, store, modelIDs[i])
if err != nil {
return nil, "", err
}
models = append(models, model)
}
return models, token, nil
}
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (s *Datastore) FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "FindLatestAuthorizationModel")
defer span.End()
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc").ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
rows, err := db.Query(ctx, stmt, args...)
if err != nil {
return nil, HandleSQLError(err)
}
defer rows.Close()
ret, err := sqlcommon.ConstructAuthorizationModelFromSQLRows(&pgxRowsWrapper{rows})
if err != nil {
return nil, HandleSQLError(err)
}
return ret, nil
}
// MaxTypesPerAuthorizationModel see [storage.TypeDefinitionWriteBackend].MaxTypesPerAuthorizationModel.
func (s *Datastore) MaxTypesPerAuthorizationModel() int {
return s.maxTypesPerModelField
}
// WriteAuthorizationModel see [storage.TypeDefinitionWriteBackend].WriteAuthorizationModel.
func (s *Datastore) WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error {
ctx, span := startTrace(ctx, "WriteAuthorizationModel")
defer span.End()
schemaVersion := model.GetSchemaVersion()
typeDefinitions := model.GetTypeDefinitions()
if len(typeDefinitions) < 1 {
return nil
}
pbdata, err := proto.Marshal(model)
if err != nil {
return err
}
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Insert("authorization_model").
Columns("store", "authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
Values(store, model.GetId(), schemaVersion, "", nil, pbdata).
ToSql()
if err != nil {
return HandleSQLError(err)
}
_, err = s.primaryDB.Exec(ctx, stmt, args...)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// CreateStore adds a new store to storage.
func (s *Datastore) CreateStore(ctx context.Context, store *openfgav1.Store) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "CreateStore")
defer span.End()
var id, name string
var createdAt, updatedAt time.Time
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Insert("store").
Columns("id", "name", "created_at", "updated_at").
Values(store.GetId(), store.GetName(), sq.Expr("NOW()"), sq.Expr("NOW()")).
Suffix("returning id, name, created_at, updated_at").ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
row := s.primaryDB.QueryRow(ctx, stmt, args...)
err = row.Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// GetStore retrieves the details of a specific store using its storeID.
func (s *Datastore) GetStore(ctx context.Context, id string) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "GetStore")
defer span.End()
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{
"id": id,
"deleted_at": nil,
}).ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
row := db.QueryRow(ctx, stmt, args...)
var storeID, name string
var createdAt, updatedAt time.Time
err = row.Scan(&storeID, &name, &createdAt, &updatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, storage.ErrNotFound
}
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: storeID,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// ListStores provides a paginated list of all stores present in the storage.
func (s *Datastore) ListStores(ctx context.Context, options storage.ListStoresOptions) ([]*openfgav1.Store, string, error) {
ctx, span := startTrace(ctx, "ListStores")
defer span.End()
whereClause := sq.And{
sq.Eq{"deleted_at": nil},
}
if len(options.IDs) > 0 {
whereClause = append(whereClause, sq.Eq{"id": options.IDs})
}
if options.Name != "" {
whereClause = append(whereClause, sq.Eq{"name": options.Name})
}
if options.Pagination.From != "" {
whereClause = append(whereClause, sq.GtOrEq{"id": options.Pagination.From})
}
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("id", "name", "created_at", "updated_at").
From("store").
Where(whereClause).
OrderBy("id")
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
stmt, args, err := sb.ToSql()
if err != nil {
return nil, "", HandleSQLError(err)
}
rows, err := db.Query(ctx, stmt, args...)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var stores []*openfgav1.Store
var id string
for rows.Next() {
var name string
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, "", HandleSQLError(err)
}
stores = append(stores, &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
})
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
if len(stores) > options.Pagination.PageSize {
return stores[:options.Pagination.PageSize], id, nil
}
return stores, "", nil
}
// DeleteStore removes a store from storage.
func (s *Datastore) DeleteStore(ctx context.Context, id string) error {
ctx, span := startTrace(ctx, "DeleteStore")
defer span.End()
db := s.primaryDB
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Update("store").
Set("deleted_at", sq.Expr("NOW()")).
Where(sq.Eq{"id": id}).ToSql()
if err != nil {
return HandleSQLError(err)
}
_, err = db.Exec(ctx, stmt, args...)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// WriteAssertions see [storage.AssertionsBackend].WriteAssertions.
func (s *Datastore) WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error {
ctx, span := startTrace(ctx, "WriteAssertions")
defer span.End()
marshalledAssertions, err := proto.Marshal(&openfgav1.Assertions{Assertions: assertions})
if err != nil {
return err
}
db := s.primaryDB
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Insert("assertion").
Columns("store", "authorization_model_id", "assertions").
Values(store, modelID, marshalledAssertions).
Suffix("ON CONFLICT (store, authorization_model_id) DO UPDATE SET assertions = ?", marshalledAssertions).
ToSql()
if err != nil {
return HandleSQLError(err)
}
_, err = db.Exec(ctx, stmt, args...)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// ReadAssertions see [storage.AssertionsBackend].ReadAssertions.
func (s *Datastore) ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error) {
ctx, span := startTrace(ctx, "ReadAssertions")
defer span.End()
var marshalledAssertions []byte
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
stmt, args, err := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select("assertions").
From("assertion").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).ToSql()
if err != nil {
return nil, HandleSQLError(err)
}
err = db.QueryRow(ctx, stmt, args...).Scan(&marshalledAssertions)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return []*openfgav1.Assertion{}, nil
}
return nil, HandleSQLError(err)
}
var assertions openfgav1.Assertions
err = proto.Unmarshal(marshalledAssertions, &assertions)
if err != nil {
return nil, err
}
return assertions.GetAssertions(), nil
}
// ReadChanges see [storage.ChangelogBackend].ReadChanges.
func (s *Datastore) ReadChanges(ctx context.Context, store string, filter storage.ReadChangesFilter, options storage.ReadChangesOptions) ([]*openfgav1.TupleChange, string, error) {
ctx, span := startTrace(ctx, "ReadChanges")
defer span.End()
objectTypeFilter := filter.ObjectType
horizonOffset := filter.HorizonOffset
orderBy := "ulid asc"
if options.SortDesc {
orderBy = "ulid desc"
}
db := s.getPgxPool(openfgav1.ConsistencyPreference_MINIMIZE_LATENCY)
sb := sq.StatementBuilder.PlaceholderFormat(sq.Dollar).
Select(
"ulid", "object_type", "object_id", "relation",
"_user",
"operation",
"condition_name", "condition_context", "inserted_at",
).
From("changelog").
Where(sq.Eq{"store": store}).
Where(fmt.Sprintf("inserted_at < NOW() - interval '%dms'", horizonOffset.Milliseconds())).
OrderBy(orderBy)
if objectTypeFilter != "" {
sb = sb.Where(sq.Eq{"object_type": objectTypeFilter})
}
if options.Pagination.From != "" {
sb = sqlcommon.AddFromUlid(sb, options.Pagination.From, options.SortDesc)
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize)) // + 1 is NOT used here as we always return a continuation token.
}
stmt, args, err := sb.ToSql()
if err != nil {
return nil, "", HandleSQLError(err)
}
rows, err := db.Query(ctx, stmt, args...)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var changes []*openfgav1.TupleChange
var ulid string
for rows.Next() {
var objectType, objectID, relation, user string
var operation int
var insertedAt time.Time
var conditionName sql.NullString
var conditionContext []byte
err = rows.Scan(
&ulid,
&objectType,
&objectID,
&relation,
&user,
&operation,
&conditionName,
&conditionContext,
&insertedAt,
)
if err != nil {
return nil, "", HandleSQLError(err)
}
var conditionContextStruct structpb.Struct
if conditionName.String != "" {
if conditionContext != nil {
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, "", err
}
}
}
tk := tupleUtils.NewTupleKeyWithCondition(
tupleUtils.BuildObject(objectType, objectID),
relation,
user,
conditionName.String,
&conditionContextStruct,
)
changes = append(changes, &openfgav1.TupleChange{
TupleKey: tk,
Operation: openfgav1.TupleOperation(operation),
Timestamp: timestamppb.New(insertedAt.UTC()),
})
}
if len(changes) == 0 {
return nil, "", storage.ErrNotFound
}
return changes, ulid, nil
}
func isDBReady(ctx context.Context, versionReady bool, db *pgxpool.Pool) (storage.ReadinessStatus, error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
err := db.Ping(ctx)
if err != nil {
return storage.ReadinessStatus{}, err
}
sqlDB := stdlib.OpenDBFromPool(db)
defer func() {
_ = sqlDB.Close()
}()
return sqlcommon.IsVersionReady(ctx, versionReady, sqlDB)
}
// IsReady see [sqlcommon.IsReady].
func (s *Datastore) IsReady(ctx context.Context) (storage.ReadinessStatus, error) {
primaryStatus, err := isDBReady(ctx, s.versionReady, s.primaryDB)
if err != nil {
return primaryStatus, err
}
// If the secondary is not configured, return the primary status only.
if !s.isSecondaryConfigured() {
s.versionReady = primaryStatus.IsReady
return primaryStatus, nil
}
if primaryStatus.IsReady && primaryStatus.Message == "" {
primaryStatus.Message = "ready"
}
secondaryStatus, err := isDBReady(ctx, s.versionReady, s.secondaryDB)
if err != nil {
secondaryStatus.Message = err.Error()
secondaryStatus.IsReady = false
}
if secondaryStatus.IsReady && secondaryStatus.Message == "" {
secondaryStatus.Message = "ready"
}
multipleReadyStatus := storage.ReadinessStatus{}
messageTpl := "primary: %s, secondary: %s"
multipleReadyStatus.IsReady = primaryStatus.IsReady && secondaryStatus.IsReady
multipleReadyStatus.Message = fmt.Sprintf(messageTpl, primaryStatus.Message, secondaryStatus.Message)
s.versionReady = multipleReadyStatus.IsReady
return multipleReadyStatus, nil
}
// HandleSQLError processes an SQL error and converts it into a more
// specific error type based on the nature of the SQL error.
func HandleSQLError(err error, args ...interface{}) error {
if errors.Is(err, sql.ErrNoRows) {
return storage.ErrNotFound
}
if strings.Contains(err.Error(), "duplicate key value") {
if len(args) > 0 {
if tk, ok := args[0].(*openfgav1.TupleKey); ok {
return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
}
}
return storage.ErrCollision
}
return fmt.Errorf("sql error: %w", err)
}
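// Example (illustrative; the constraint name is hypothetical): a Postgres
// unique-violation message maps to storage.ErrCollision when no tuple key is
// supplied as an argument:
//
//	err := HandleSQLError(errors.New(`duplicate key value violates unique constraint "tuple_pkey"`))
//	// errors.Is(err, storage.ErrCollision) == true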
// selectExistingRowsForWrite selects existing rows for the given keys and locks them FOR UPDATE.
// The existing rows are added to the existing map.
func selectExistingRowsForWrite(ctx context.Context, stbl sq.StatementBuilderType, txn PgxQuery, store string, keys []sqlcommon.TupleLockKey, existing map[string]*openfgav1.Tuple) error {
inExpr, args := sqlcommon.BuildRowConstructorIN(keys)
sb := stbl.
Select(sqlcommon.SQLIteratorColumns()...).
From("tuple").
Where(sq.Eq{"store": store}).
// Row-constructor IN on full composite key for precise point locks.
Where(sq.Expr("(object_type, object_id, relation, _user, user_type) IN "+inExpr, args...)).
Suffix("FOR UPDATE")
poolGetRows, err := NewPgxTxnGetRows(txn, sb)
if err != nil {
return HandleSQLError(err)
}
iter := sqlcommon.NewSQLTupleIterator(poolGetRows, HandleSQLError)
defer iter.Stop()
items, _, err := iter.ToArray(ctx, storage.PaginationOptions{PageSize: len(keys)})
if err != nil {
return err
}
for _, tuple := range items {
existing[tupleUtils.TupleKeyToString(tuple.GetKey())] = tuple
}
return nil
}
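// Illustrative sketch (the exact SQL shape is an assumption based on the code
// above): for two lock keys, the row-constructor IN expression built by
// sqlcommon.BuildRowConstructorIN expands to something like
//
//	(object_type, object_id, relation, _user, user_type) IN ((?,?,?,?,?), (?,?,?,?,?))
//
// with the placeholders later rewritten to $1..$10 by the dollar PlaceholderFormat.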
package storage
import (
"time"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
tupleutils "github.com/openfga/openfga/pkg/tuple"
)
// TupleRecord represents a record structure used
// to store information about a specific tuple.
type TupleRecord struct {
Store string
ObjectType string
ObjectID string
Relation string
User string // Deprecated: Use UserObjectType, UserObjectID & UserRelation instead.
UserObjectType string
UserObjectID string
UserRelation string
ConditionName string
ConditionContext *structpb.Struct
Ulid string
InsertedAt time.Time
}
// AsTuple converts a [TupleRecord] into a [*openfgav1.Tuple].
func (t *TupleRecord) AsTuple() *openfgav1.Tuple {
user := t.User
if t.User == "" {
user = tupleutils.FromUserParts(t.UserObjectType, t.UserObjectID, t.UserRelation)
}
return &openfgav1.Tuple{
Key: tupleutils.NewTupleKeyWithCondition(
tupleutils.BuildObject(t.ObjectType, t.ObjectID),
t.Relation,
user,
t.ConditionName,
t.ConditionContext,
),
Timestamp: timestamppb.New(t.InsertedAt),
}
}
package storage
import "github.com/emirpasic/gods/trees/redblacktree"
// SortedSet stores a set (no duplicates allowed) of string IDs in memory
// in a way that also provides fast sorted access.
type SortedSet interface {
Size() int
// Min returns an empty string if the set is empty.
Min() string
// Max returns an empty string if the set is empty.
Max() string
Add(value string)
Exists(value string) bool
// Values returns the elements in the set in sorted order (ascending).
Values() []string
}
type RedBlackTreeSet struct {
inner *redblacktree.Tree
}
var _ SortedSet = (*RedBlackTreeSet)(nil)
func NewSortedSet(vals ...string) *RedBlackTreeSet {
c := &RedBlackTreeSet{
inner: redblacktree.NewWithStringComparator(),
}
for _, val := range vals {
c.Add(val)
}
return c
}
func (r *RedBlackTreeSet) Min() string {
if r.Size() == 0 {
return ""
}
return r.inner.Left().Key.(string)
}
func (r *RedBlackTreeSet) Max() string {
if r.Size() == 0 {
return ""
}
return r.inner.Right().Key.(string)
}
func (r *RedBlackTreeSet) Add(value string) {
r.inner.Put(value, nil)
}
func (r *RedBlackTreeSet) Exists(value string) bool {
_, ok := r.inner.Get(value)
return ok
}
func (r *RedBlackTreeSet) Size() int {
return r.inner.Size()
}
func (r *RedBlackTreeSet) Values() []string {
values := make([]string, 0, r.inner.Size())
for _, v := range r.inner.Keys() {
values = append(values, v.(string))
}
return values
}
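// Example usage (illustrative):
//
//	set := NewSortedSet("b", "a", "b")
//	set.Add("c")
//	// set.Size() == 3, set.Min() == "a", set.Max() == "c"
//	// set.Values() == []string{"a", "b", "c"}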
package sqlcommon
import (
"google.golang.org/protobuf/proto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
func MarshalRelationshipCondition(
rel *openfgav1.RelationshipCondition,
) (name string, context []byte, err error) {
if rel != nil {
// Normalize empty context to nil.
if rel.GetContext() != nil && len(rel.GetContext().GetFields()) > 0 {
context, err = proto.Marshal(rel.GetContext())
if err != nil {
return name, context, err
}
}
return rel.GetName(), context, err
}
return name, context, err
}
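// Example (illustrative): a condition without context yields a nil context
// payload, while a populated context is proto-encoded:
//
//	name, ctxBytes, err := MarshalRelationshipCondition(&openfgav1.RelationshipCondition{Name: "in_region"})
//	// name == "in_region", ctxBytes == nil, err == nil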
package sqlcommon
import (
"context"
"database/sql"
"encoding/json"
"errors"
"sort"
"strconv"
"strings"
"sync"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
"github.com/pressly/goose/v3"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/encoder"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("pkg/storage/sqlcommon")
// Config defines the configuration parameters
// for setting up and managing a sql connection.
type Config struct {
SecondaryURI string
Username string
Password string
SecondaryUsername string
SecondaryPassword string
Logger logger.Logger
MaxTuplesPerWriteField int
MaxTypesPerModelField int
MaxOpenConns int
MinOpenConns int
MaxIdleConns int
MinIdleConns int
ConnMaxIdleTime time.Duration
ConnMaxLifetime time.Duration
ExportMetrics bool
}
// DatastoreOption defines a function type
// used for configuring a Config object.
type DatastoreOption func(*Config)
// WithSecondaryURI returns a DatastoreOption that sets the secondary URI in the Config.
func WithSecondaryURI(uri string) DatastoreOption {
return func(config *Config) {
config.SecondaryURI = uri
}
}
// WithUsername returns a DatastoreOption that sets the username in the Config.
func WithUsername(username string) DatastoreOption {
return func(config *Config) {
config.Username = username
}
}
// WithPassword returns a DatastoreOption that sets the password in the Config.
func WithPassword(password string) DatastoreOption {
return func(config *Config) {
config.Password = password
}
}
// WithSecondaryUsername returns a DatastoreOption that sets the secondary username in the Config.
func WithSecondaryUsername(username string) DatastoreOption {
return func(config *Config) {
config.SecondaryUsername = username
}
}
// WithSecondaryPassword returns a DatastoreOption that sets the secondary password in the Config.
func WithSecondaryPassword(password string) DatastoreOption {
return func(config *Config) {
config.SecondaryPassword = password
}
}
// WithLogger returns a DatastoreOption that sets the Logger in the Config.
func WithLogger(l logger.Logger) DatastoreOption {
return func(cfg *Config) {
cfg.Logger = l
}
}
// WithMaxTuplesPerWrite returns a DatastoreOption that sets
// the maximum number of tuples per write in the Config.
func WithMaxTuplesPerWrite(maxTuples int) DatastoreOption {
return func(cfg *Config) {
cfg.MaxTuplesPerWriteField = maxTuples
}
}
// WithMaxTypesPerAuthorizationModel returns a DatastoreOption that sets
// the maximum number of types per authorization model in the Config.
func WithMaxTypesPerAuthorizationModel(maxTypes int) DatastoreOption {
return func(cfg *Config) {
cfg.MaxTypesPerModelField = maxTypes
}
}
// WithMaxOpenConns returns a DatastoreOption that sets the
// maximum number of open connections in the Config.
func WithMaxOpenConns(c int) DatastoreOption {
return func(cfg *Config) {
cfg.MaxOpenConns = c
}
}
// WithMinOpenConns returns a DatastoreOption that sets the
// minimum number of open connections in the Config.
// This is only used by some SQL drivers (e.g., pgxpool, which backs the
// PostgreSQL datastore).
func WithMinOpenConns(c int) DatastoreOption {
return func(cfg *Config) {
cfg.MinOpenConns = c
}
}
// WithMaxIdleConns returns a DatastoreOption that sets the
// maximum number of idle connections in the Config.
func WithMaxIdleConns(c int) DatastoreOption {
return func(cfg *Config) {
cfg.MaxIdleConns = c
}
}
// WithMinIdleConns returns a DatastoreOption that sets the
// minimum number of idle connections in the Config.
// This is only used by some SQL drivers (e.g., pgxpool, which backs the
// PostgreSQL datastore).
func WithMinIdleConns(c int) DatastoreOption {
return func(cfg *Config) {
cfg.MinIdleConns = c
}
}
// WithConnMaxIdleTime returns a DatastoreOption that sets
// the maximum idle time for a connection in the Config.
func WithConnMaxIdleTime(d time.Duration) DatastoreOption {
return func(cfg *Config) {
cfg.ConnMaxIdleTime = d
}
}
// WithConnMaxLifetime returns a DatastoreOption that sets
// the maximum lifetime for a connection in the Config.
func WithConnMaxLifetime(d time.Duration) DatastoreOption {
return func(cfg *Config) {
cfg.ConnMaxLifetime = d
}
}
// WithMetrics returns a DatastoreOption that
// enables the export of metrics in the Config.
func WithMetrics() DatastoreOption {
return func(cfg *Config) {
cfg.ExportMetrics = true
}
}
// NewConfig creates a new Config instance with default values
// and applies any provided DatastoreOption modifications.
func NewConfig(opts ...DatastoreOption) *Config {
cfg := &Config{}
for _, opt := range opts {
opt(cfg)
}
if cfg.Logger == nil {
cfg.Logger = logger.NewNoopLogger()
}
if cfg.MaxTuplesPerWriteField == 0 {
cfg.MaxTuplesPerWriteField = storage.DefaultMaxTuplesPerWrite
}
if cfg.MaxTypesPerModelField == 0 {
cfg.MaxTypesPerModelField = storage.DefaultMaxTypesPerAuthorizationModel
}
return cfg
}
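// Example usage (illustrative):
//
//	cfg := NewConfig(
//		WithMaxOpenConns(30),
//		WithConnMaxLifetime(30*time.Minute),
//		WithMetrics(),
//	)
//	// MaxTuplesPerWriteField and MaxTypesPerModelField fall back to the
//	// storage package defaults because they were not set explicitly.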
// ContToken represents a continuation token structure used in pagination.
type ContToken struct {
Ulid string `json:"ulid"`
ObjectType string `json:"ObjectType"`
}
// NewContToken creates a new instance of ContToken
// with the provided ULID and object type.
func NewContToken(ulid, objectType string) *ContToken {
return &ContToken{
Ulid: ulid,
ObjectType: objectType,
}
}
// NewSQLContinuationTokenSerializer returns a ContinuationTokenSerializer that encodes
// continuation tokens as JSON-serialized ContToken values.
func NewSQLContinuationTokenSerializer() encoder.ContinuationTokenSerializer {
return &SQLContinuationTokenSerializer{}
}
// SQLContinuationTokenSerializer serializes continuation tokens for SQL datastores.
type SQLContinuationTokenSerializer struct{}
// Serialize encodes the given ULID and object type into a JSON continuation token.
func (s *SQLContinuationTokenSerializer) Serialize(ulid string, objType string) ([]byte, error) {
if ulid == "" {
return nil, errors.New("empty ulid provided for continuation token")
}
return json.Marshal(NewContToken(ulid, objType))
}
// Deserialize decodes a JSON continuation token back into its ULID and object type.
func (s *SQLContinuationTokenSerializer) Deserialize(continuationToken string) (ulid string, objType string, err error) {
var token ContToken
if err := json.Unmarshal([]byte(continuationToken), &token); err != nil {
return "", "", storage.ErrInvalidContinuationToken
}
return token.Ulid, token.ObjectType, nil
}
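// Example round-trip (illustrative; the ULID below is hypothetical):
//
//	s := NewSQLContinuationTokenSerializer()
//	tok, _ := s.Serialize("01ARZ3NDEKTSV4RRFFQ69G5FAV", "document")
//	u, objType, _ := s.Deserialize(string(tok))
//	// u == "01ARZ3NDEKTSV4RRFFQ69G5FAV", objType == "document"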
// SQLIteratorRowGetter is an interface for retrieving rows from a SQL query.
// Implementations should provide the GetRows method, which executes a query
// and returns a Rows object for iteration.
//
// GetRows executes the query and returns the resulting Rows or an error.
type SQLIteratorRowGetter interface {
GetRows(context.Context) (Rows, error)
}
type SBIteratorQuery struct {
sb sq.SelectBuilder
}
func NewSBIteratorQuery(sb sq.SelectBuilder) *SBIteratorQuery {
return &SBIteratorQuery{sb: sb}
}
func (q *SBIteratorQuery) GetRows(ctx context.Context) (Rows, error) {
return q.sb.QueryContext(ctx)
}
// Rows is an interface that abstracts the iteration over SQL query results.
// It provides methods to close the result set, check for errors, advance to the next row,
// and scan the current row's columns into provided destinations.
// It is intended as a subset of *sql.Rows to facilitate compatibility with *pgx.Rows as well.
//
// Methods:
// - Close(): Closes the rows iterator and releases any resources.
// - Err(): Returns the error, if any, that was encountered during iteration.
// - Next(): Advances to the next row, returning true if there is another row available.
// - Scan(dest ...any): Scans the columns of the current row into the provided destination variables.
type Rows interface {
Close() error
Err() error
Next() bool
Scan(dest ...any) error
}
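// As an illustration (this assertion is not part of the original source),
// *sql.Rows satisfies this interface directly:
var _ Rows = (*sql.Rows)(nil)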
// SQLTupleIterator is a struct that implements the storage.TupleIterator
// interface for iterating over tuples fetched from a SQL database.
type SQLTupleIterator struct {
rows Rows // GUARDED_BY(mu)
handleSQLError errorHandlerFn
// firstRow temporarily stores the first row when Head is called.
// If firstRow is nil when Head is called, rows.Next() returns the first item and advances
// the iterator, so that first item must be stored here for future Head() and Next() calls
// to use. Otherwise, the first item would be lost.
firstRow *storage.TupleRecord // GUARDED_BY(mu)
mu sync.Mutex
rowGetter SQLIteratorRowGetter
}
// Ensures that SQLTupleIterator implements the TupleIterator interface.
var _ storage.TupleIterator = (*SQLTupleIterator)(nil)
// sqlIteratorColumns required for the SQL tuple iterator scanner.
var sqlIteratorColumns = []string{
"store",
"object_type",
"object_id",
"relation",
"_user",
"condition_name",
"condition_context",
"ulid",
"inserted_at",
}
// SQLIteratorColumns returns the columns used in the SQL tuple iterator.
func SQLIteratorColumns() []string {
return sqlIteratorColumns
}
// NewSQLTupleIterator returns a SQL tuple iterator. The underlying query is
// executed lazily, on the first call to Next, Head, or ToArray.
func NewSQLTupleIterator(rowGetter SQLIteratorRowGetter, errHandler errorHandlerFn) *SQLTupleIterator {
return &SQLTupleIterator{
rows: nil,
handleSQLError: errHandler,
firstRow: nil,
mu: sync.Mutex{},
rowGetter: rowGetter,
}
}
func (t *SQLTupleIterator) fetchBuffer(ctx context.Context) error {
ctx, span := tracer.Start(ctx, "sqlcommon.fetchBuffer", trace.WithAttributes())
defer span.End()
// Detach from the caller's cancellation so that the rows fetched here remain
// usable even if the original request context is canceled mid-iteration.
ctx = context.WithoutCancel(ctx)
curRows, err := t.rowGetter.GetRows(ctx)
if err != nil {
return t.handleSQLError(err)
}
t.rows = curRows
return nil
}
func (t *SQLTupleIterator) next(ctx context.Context) (*storage.TupleRecord, error) {
t.mu.Lock()
if t.rows == nil {
if err := t.fetchBuffer(ctx); err != nil {
t.mu.Unlock()
return nil, err
}
}
if t.firstRow != nil {
// If head was called previously, the first row has already been scanned and
// the internal iterator has already been advanced via t.rows.Next().
// Calling t.rows.Next() again here would skip that first row.
//
// For example, suppose there are 3 items [1,2,3].
// If Head() is called while t.firstRow is nil, rows is left with [2,3],
// so item [1] must be saved in firstRow. This lets future next() and head()
// calls consume [1] first.
// If head() was never called, t.firstRow is nil and the t.rows.Next() logic below applies.
firstRow := t.firstRow
t.firstRow = nil
t.mu.Unlock()
return firstRow, nil
}
if !t.rows.Next() {
err := t.rows.Err()
t.mu.Unlock()
if err != nil {
return nil, t.handleSQLError(err)
}
return nil, storage.ErrIteratorDone
}
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := t.rows.Scan(
&record.Store,
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.User,
&conditionName,
&conditionContext,
&record.Ulid,
&record.InsertedAt,
)
t.mu.Unlock()
if err != nil {
return nil, t.handleSQLError(err)
}
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
return &record, nil
}
func (t *SQLTupleIterator) head(ctx context.Context) (*storage.TupleRecord, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.rows == nil {
if err := t.fetchBuffer(ctx); err != nil {
return nil, err
}
}
if t.firstRow != nil {
// If head was called previously, the first row has already been scanned and
// the internal iterator has already been advanced via t.rows.Next().
// Calling t.rows.Next() again here would skip that first row.
//
// For example, suppose there are 3 items [1,2,3].
// If Head() is called while t.firstRow is nil, rows is left with [2,3],
// so item [1] must be saved in firstRow. This lets future next() and head()
// calls return [1] first. Note that head() does not unset t.firstRow, so
// calling head() multiple times yields the same result.
// If head() was never called, t.firstRow is nil and the t.rows.Next() logic below applies.
return t.firstRow, nil
}
if !t.rows.Next() {
if err := t.rows.Err(); err != nil {
return nil, t.handleSQLError(err)
}
return nil, storage.ErrIteratorDone
}
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := t.rows.Scan(
&record.Store,
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.User,
&conditionName,
&conditionContext,
&record.Ulid,
&record.InsertedAt,
)
if err != nil {
return nil, t.handleSQLError(err)
}
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
t.firstRow = &record
return &record, nil
}
// ToArray converts the tupleIterator to an []*openfgav1.Tuple and a possibly empty continuation token.
// If a continuation token is returned, it is the ulid of the first element after the returned
// array (the extra row fetched thanks to LIMIT+1), so a query filtered with ulid >= token resumes there.
func (t *SQLTupleIterator) ToArray(ctx context.Context,
opts storage.PaginationOptions,
) ([]*openfgav1.Tuple, string, error) {
var res []*openfgav1.Tuple
for i := 0; i < opts.PageSize; i++ {
tupleRecord, err := t.next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
return res, "", nil
}
return nil, "", err
}
res = append(res, tupleRecord.AsTuple())
}
// Check if we are at the end of the iterator.
// If we are then we do not need to return a continuation token.
// This is why we have LIMIT+1 in the query.
tupleRecord, err := t.next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
return res, "", nil
}
return nil, "", err
}
return res, tupleRecord.Ulid, nil
}
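// A sketch of the pagination contract: with PageSize = 2 over rows [A, B, C],
// ToArray returns [A, B] plus the continuation token C.Ulid (C itself is not
// returned); a follow-up query filtered with ulid >= C.Ulid resumes at C.
// With exactly [A, B], the extra peek hits ErrIteratorDone and the token is empty.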
// Next will return the next available item.
func (t *SQLTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
record, err := t.next(ctx)
if err != nil {
return nil, err
}
return record.AsTuple(), nil
}
// Head will return the first available item.
func (t *SQLTupleIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
record, err := t.head(ctx)
if err != nil {
return nil, err
}
return record.AsTuple(), nil
}
// Stop terminates iteration.
func (t *SQLTupleIterator) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
if t.rows != nil {
_ = t.rows.Close()
}
}
// DBInfo encapsulates DB information for use in the common methods.
type DBInfo struct {
stbl sq.StatementBuilderType
HandleSQLError errorHandlerFn
}
// errorHandlerFn converts a driver-specific SQL error into a storage-level error.
type errorHandlerFn func(error, ...interface{}) error
// NewDBInfo constructs a [DBInfo] object.
func NewDBInfo(stbl sq.StatementBuilderType, errorHandler errorHandlerFn, dialect string) *DBInfo {
if err := goose.SetDialect(dialect); err != nil {
panic("failed to set database dialect: " + err.Error())
}
return &DBInfo{
stbl: stbl,
HandleSQLError: errorHandler,
}
}
// TupleLockKey represents the composite key we lock on.
type TupleLockKey struct {
objectType string
objectID string
relation string
user string
userType string
}
// MakeTupleLockKeys flattens deletes+writes into a deduped, sorted slice to ensure a stable lock order.
func MakeTupleLockKeys(deletes storage.Deletes, writes storage.Writes) []TupleLockKey {
keys := make([]TupleLockKey, 0, len(deletes)+len(writes))
seen := make(map[string]struct{}, cap(keys))
add := func(tk *openfgav1.TupleKey) {
ot, oid := tupleUtils.SplitObject(tk.GetObject())
k := TupleLockKey{
objectType: ot,
objectID: oid,
relation: tk.GetRelation(),
user: tk.GetUser(),
userType: string(tupleUtils.GetUserTypeFromUser(tk.GetUser())),
}
s := strings.Join([]string{
k.objectType,
k.objectID,
k.relation,
k.user,
k.userType,
}, "\x00")
if _, ok := seen[s]; ok {
return
}
seen[s] = struct{}{}
keys = append(keys, k)
}
for _, tk := range deletes {
add(tupleUtils.TupleKeyWithoutConditionToTupleKey(tk))
}
for _, tk := range writes {
add(tk)
}
// Sort deterministically by the composite key to keep lock order stable.
sort.Slice(keys, func(i, j int) bool {
a, b := keys[i], keys[j]
if a.objectType != b.objectType {
return a.objectType < b.objectType
}
if a.objectID != b.objectID {
return a.objectID < b.objectID
}
if a.relation != b.relation {
return a.relation < b.relation
}
if a.user != b.user {
return a.user < b.user
}
return a.userType < b.userType
})
return keys
}
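// An illustrative sketch of the guarantees: if the same tuple key appears in
// both deletes and writes (e.g., document:1#viewer@user:anne), it yields a
// single lock key, and the deterministic sort means concurrent Write calls
// acquire their row locks in the same order, avoiding lock-ordering deadlocks.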
// BuildRowConstructorIN builds "((?,?,?,?,?),(?,?,?,?,?),...)" and arg list for row-constructor IN.
func BuildRowConstructorIN(keys []TupleLockKey) (string, []interface{}) {
if len(keys) == 0 {
return "", nil
}
var sb strings.Builder
args := make([]interface{}, 0, len(keys)*5)
sb.WriteByte('(')
for i, k := range keys {
if i > 0 {
sb.WriteByte(',')
}
sb.WriteString("(?,?,?,?,?)")
args = append(args, k.objectType, k.objectID, k.relation, k.user, k.userType)
}
sb.WriteByte(')')
return sb.String(), args
}
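// For example, two keys produce the fragment
//
//	((?,?,?,?,?),(?,?,?,?,?))
//
// with 10 arguments appended key by key in (objectType, objectID, relation,
// user, userType) order, ready for use with sq.Expr.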
// selectExistingRowsForWrite selects existing rows for the given keys and locks them FOR UPDATE.
// The existing rows are added to the existing map.
func selectExistingRowsForWrite(ctx context.Context, dbInfo *DBInfo, store string, keys []TupleLockKey, txn *sql.Tx, existing map[string]*openfgav1.Tuple) error {
inExpr, args := BuildRowConstructorIN(keys)
selectBuilder := dbInfo.stbl.
Select(SQLIteratorColumns()...).
From("tuple").
Where(sq.Eq{"store": store}).
// Row-constructor IN on full composite key for precise point locks.
Where(sq.Expr("(object_type, object_id, relation, _user, user_type) IN "+inExpr, args...)).
Suffix("FOR UPDATE").
RunWith(txn) // make sure to run in the same transaction
iter := NewSQLTupleIterator(NewSBIteratorQuery(selectBuilder), dbInfo.HandleSQLError)
defer iter.Stop()
items, _, err := iter.ToArray(ctx, storage.PaginationOptions{PageSize: len(keys)})
if err != nil {
return err
}
for _, tuple := range items {
existing[tupleUtils.TupleKeyToString(tuple.GetKey())] = tuple
}
return nil
}
// GetDeleteWriteChangelogItems constructs the delete conditions, write items, and changelog items.
func GetDeleteWriteChangelogItems(
store string,
existing map[string]*openfgav1.Tuple,
writeData WriteData) (sq.Or, [][]interface{}, [][]interface{}, error) {
changeLogItems := make([][]interface{}, 0, len(writeData.Deletes)+len(writeData.Writes))
// DefaultEntropy provides monotonically increasing (and therefore unique) ULID entropy within the process.
entropy := ulid.DefaultEntropy()
deleteConditions := sq.Or{}
// 1. For deletes:
// a. If on_missing: error (default behavior):
// - Execute DELETEs as a single statement.
// On conflict (row count != delete count), rollback & return an error.
// b. If on_missing: ignore, use the result from Step 3.a:
// - Based on the results from step 3.a, which identified and locked existing rows,
// generate DELETE tuple and INSERT changelog statements only for those specific tuples.
// - Rows that don't exist in the DB are ignored (no-op).
// - Execute DELETEs as a single statement.
// On conflict (row count != delete count), rollback & return an HTTP 409 Conflict error.
for _, tk := range writeData.Deletes {
if _, ok := existing[tupleUtils.TupleKeyToString(tk)]; !ok {
// If the tuple does not exist, we cannot delete it.
switch writeData.Opts.OnMissingDelete {
case storage.OnMissingDeleteIgnore:
continue
case storage.OnMissingDeleteError:
fallthrough
default:
return nil, nil, nil, storage.InvalidWriteInputError(
tk,
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
)
}
}
id := ulid.MustNew(ulid.Timestamp(writeData.Now), entropy).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
deleteConditions = append(deleteConditions, sq.Eq{
"object_type": objectType,
"object_id": objectID,
"relation": tk.GetRelation(),
"_user": tk.GetUser(),
"user_type": tupleUtils.GetUserTypeFromUser(tk.GetUser()),
})
changeLogItems = append(changeLogItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
tk.GetUser(),
"",
nil, // Redact condition info for Deletes since we only need the base triplet (object, relation, user).
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
id,
sq.Expr("NOW()"),
})
}
writeItems := make([][]interface{}, 0, len(writeData.Writes))
// 2. For writes:
// a. If on_duplicate: error (default behavior):
// - Execute INSERTs as a single statement.
// A duplicate insert raises a CONSTRAINT VIOLATION error; return 400 Bad Request.
// b. If on_duplicate: ignore:
// - Based on the results from step 3.a, which identified and locked existing rows, compare the stored values to the ones being inserted.
// - On conflict (values not identical), return a 409 Conflict error.
// - For rows that do NOT exist in the DB, create both INSERT tuple & INSERT changelog statements.
// c. Execute INSERTs as a single statement.
// On error, return 409 Conflict.
for _, tk := range writeData.Writes {
if existingTuple, ok := existing[tupleUtils.TupleKeyToString(tk)]; ok {
// If the tuple exists, we cannot write it.
switch writeData.Opts.OnDuplicateInsert {
case storage.OnDuplicateInsertIgnore:
// If the tuple exists and its condition is the same, we can ignore it.
// proto.Equal is used rather than reflect.DeepEqual so that protobuf-internal fields are not compared.
if proto.Equal(existingTuple.GetKey().GetCondition(), tk.GetCondition()) {
continue
}
// If the tuple conditions differ, we return an error.
return nil, nil, nil, storage.TupleConditionConflictError(tk)
case storage.OnDuplicateInsertError:
fallthrough
default:
return nil, nil, nil, storage.InvalidWriteInputError(
tk,
openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
)
}
}
id := ulid.MustNew(ulid.Timestamp(writeData.Now), entropy).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
conditionName, conditionContext, err := MarshalRelationshipCondition(tk.GetCondition())
if err != nil {
return nil, nil, nil, err
}
writeItems = append(writeItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
tk.GetUser(),
tupleUtils.GetUserTypeFromUser(tk.GetUser()),
conditionName,
conditionContext,
id,
sq.Expr("NOW()"),
})
changeLogItems = append(changeLogItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
tk.GetUser(),
conditionName,
conditionContext,
openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
id,
sq.Expr("NOW()"),
})
}
return deleteConditions, writeItems, changeLogItems, nil
}
type WriteData struct {
Deletes storage.Deletes
Writes storage.Writes
Opts storage.TupleWriteOptions
Now time.Time
}
// Write provides the common method for writing tuples to the database across SQL storage implementations.
func Write(
ctx context.Context,
dbInfo *DBInfo,
db *sql.DB,
store string,
writeData WriteData,
) error {
// 1. Begin Transaction ( Isolation Level = READ COMMITTED )
txn, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return dbInfo.HandleSQLError(err)
}
defer func() { _ = txn.Rollback() }()
// 2. Compile a SELECT … FOR UPDATE statement to read the tuples for writes and lock the tuples for deletes.
// Build a deduped, sorted list of keys to lock.
lockKeys := MakeTupleLockKeys(writeData.Deletes, writeData.Writes)
total := len(lockKeys)
if total == 0 {
// Nothing to do.
return nil
}
existing := make(map[string]*openfgav1.Tuple, total)
// 3. If the list compiled in step 2 is not empty, execute the SELECT … FOR UPDATE statement.
for start := 0; start < total; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > total {
end = total
}
keys := lockKeys[start:end]
if err := selectExistingRowsForWrite(ctx, dbInfo, store, keys, txn, existing); err != nil {
return err
}
}
// 4. Construct the deleteConditions, write and changelog items to be written
deleteConditions, writeItems, changeLogItems, err := GetDeleteWriteChangelogItems(store, existing, writeData)
if err != nil {
return err
}
for start, totalDeletes := 0, len(deleteConditions); start < totalDeletes; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalDeletes {
end = totalDeletes
}
deleteConditionsBatch := deleteConditions[start:end]
res, err := dbInfo.stbl.Delete("tuple").Where(sq.Eq{"store": store}).
Where(deleteConditionsBatch).
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
return dbInfo.HandleSQLError(err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return dbInfo.HandleSQLError(err)
}
if rowsAffected != int64(len(deleteConditionsBatch)) {
// If we deleted fewer rows than planned (rows changed between our read and this write), we hit a race condition: someone else deleted the same row(s).
return storage.ErrWriteConflictOnDelete
}
}
for start, totalWrites := 0, len(writeItems); start < totalWrites; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalWrites {
end = totalWrites
}
writesBatch := writeItems[start:end]
insertBuilder := dbInfo.stbl.
Insert("tuple").
Columns(
"store",
"object_type",
"object_id",
"relation",
"_user",
"user_type",
"condition_name",
"condition_context",
"ulid",
"inserted_at",
)
for _, item := range writesBatch {
insertBuilder = insertBuilder.Values(item...)
}
_, err = insertBuilder.
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
dberr := dbInfo.HandleSQLError(err)
if errors.Is(dberr, storage.ErrCollision) {
// ErrCollision is returned on duplicate write (constraint violation), meaning we hit a race condition: someone else inserted the same row(s).
return storage.ErrWriteConflictOnInsert
}
return dberr
}
}
// 5. Execute INSERT changelog statements
for start, totalItems := 0, len(changeLogItems); start < totalItems; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalItems {
end = totalItems
}
changeLogBatch := changeLogItems[start:end]
changelogBuilder := dbInfo.stbl.
Insert("changelog").
Columns(
"store",
"object_type",
"object_id",
"relation",
"_user",
"condition_name",
"condition_context",
"operation",
"ulid",
"inserted_at",
)
for _, item := range changeLogBatch {
changelogBuilder = changelogBuilder.Values(item...)
}
_, err = changelogBuilder.RunWith(txn).ExecContext(ctx) // Part of a txn.
if err != nil {
return dbInfo.HandleSQLError(err)
}
}
// 6. Commit Transaction
if err := txn.Commit(); err != nil {
return dbInfo.HandleSQLError(err)
}
return nil
}
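// A hypothetical usage sketch (the store ID and tuple values are made up;
// ctx, dbInfo, and db are assumed to be initialized elsewhere):
//
//	err := Write(ctx, dbInfo, db, "store-id", WriteData{
//		Writes: storage.Writes{
//			tupleUtils.NewTupleKey("document:1", "viewer", "user:anne"),
//		},
//		Opts: storage.NewTupleWriteOptions(),
//		Now:  time.Now().UTC(),
//	})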
// WriteAuthorizationModel writes an authorization model for the given store in one row.
func WriteAuthorizationModel(
ctx context.Context,
dbInfo *DBInfo,
store string,
model *openfgav1.AuthorizationModel,
) error {
schemaVersion := model.GetSchemaVersion()
typeDefinitions := model.GetTypeDefinitions()
if len(typeDefinitions) < 1 {
return nil
}
pbdata, err := proto.Marshal(model)
if err != nil {
return err
}
_, err = dbInfo.stbl.
Insert("authorization_model").
Columns("store", "authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
Values(store, model.GetId(), schemaVersion, "", nil, pbdata).
ExecContext(ctx)
if err != nil {
return dbInfo.HandleSQLError(err)
}
return nil
}
// ConstructAuthorizationModelFromSQLRows tries first to read and return a model that was written in one row (the new format).
// If it can't find one, it will then look for a model that was written across multiple rows (the old format).
func ConstructAuthorizationModelFromSQLRows(rows Rows) (*openfgav1.AuthorizationModel, error) {
var modelID string
var schemaVersion string
var typeDefs []*openfgav1.TypeDefinition
if rows.Next() {
var typeName string
var marshalledTypeDef []byte
var marshalledModel []byte
err := rows.Scan(&modelID, &schemaVersion, &typeName, &marshalledTypeDef, &marshalledModel)
if err != nil {
return nil, err
}
if len(marshalledModel) > 0 {
// Prefer building an authorization model from the first row that has it available.
var model openfgav1.AuthorizationModel
if err := proto.Unmarshal(marshalledModel, &model); err != nil {
return nil, err
}
return &model, nil
}
var typeDef openfgav1.TypeDefinition
if err := proto.Unmarshal(marshalledTypeDef, &typeDef); err != nil {
return nil, err
}
typeDefs = append(typeDefs, &typeDef)
}
for rows.Next() {
var scannedModelID string
var typeName string
var marshalledTypeDef []byte
var marshalledModel []byte
err := rows.Scan(&scannedModelID, &schemaVersion, &typeName, &marshalledTypeDef, &marshalledModel)
if err != nil {
return nil, err
}
if scannedModelID != modelID {
break
}
var typeDef openfgav1.TypeDefinition
if err := proto.Unmarshal(marshalledTypeDef, &typeDef); err != nil {
return nil, err
}
typeDefs = append(typeDefs, &typeDef)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(typeDefs) == 0 {
return nil, storage.ErrNotFound
}
return &openfgav1.AuthorizationModel{
SchemaVersion: schemaVersion,
Id: modelID,
TypeDefinitions: typeDefs,
// Conditions don't exist in the old data format
}, nil
}
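// Illustrative decoding flow: if the first row carries a non-empty
// serialized_protobuf value, the whole model is unmarshalled from it (the
// one-row format); otherwise each row contributes one type_definition, and
// rows are consumed until the authorization_model_id changes (the legacy
// multi-row format).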
// FindLatestAuthorizationModel reads the latest authorization model corresponding to the store.
func FindLatestAuthorizationModel(
ctx context.Context,
dbInfo *DBInfo,
store string,
) (*openfgav1.AuthorizationModel, error) {
rows, err := dbInfo.stbl.
Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc").
QueryContext(ctx)
if err != nil {
return nil, dbInfo.HandleSQLError(err)
}
defer rows.Close()
ret, err := ConstructAuthorizationModelFromSQLRows(rows)
if err != nil {
return nil, dbInfo.HandleSQLError(err)
}
return ret, nil
}
// ReadAuthorizationModel reads the model corresponding to store and model ID.
func ReadAuthorizationModel(
ctx context.Context,
dbInfo *DBInfo,
store, modelID string,
) (*openfgav1.AuthorizationModel, error) {
rows, err := dbInfo.stbl.
Select("authorization_model_id", "schema_version", "type", "type_definition", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).
QueryContext(ctx)
if err != nil {
return nil, dbInfo.HandleSQLError(err)
}
defer rows.Close()
ret, err := ConstructAuthorizationModelFromSQLRows(rows)
if err != nil {
return nil, dbInfo.HandleSQLError(err)
}
return ret, nil
}
// IsVersionReady checks if the database schema revision is at least the minimum supported revision.
// The passed in context should have a timeout.
func IsVersionReady(ctx context.Context, skipVersionCheck bool, db *sql.DB) (storage.ReadinessStatus, error) {
if skipVersionCheck {
return storage.ReadinessStatus{
IsReady: true,
}, nil
}
revision, err := goose.GetDBVersionContext(ctx, db)
if err != nil {
return storage.ReadinessStatus{}, err
}
if revision < build.MinimumSupportedDatastoreSchemaRevision {
return storage.ReadinessStatus{
Message: "datastore requires migrations: at revision '" +
strconv.FormatInt(revision, 10) +
"', but requires '" +
strconv.FormatInt(build.MinimumSupportedDatastoreSchemaRevision, 10) +
"'. Run 'openfga migrate'.",
IsReady: false,
}, nil
}
return storage.ReadinessStatus{
IsReady: true,
}, nil
}
// IsReady returns true if connection to datastore is successful AND
// (the datastore has the latest migration applied OR skipVersionCheck).
func IsReady(ctx context.Context, skipVersionCheck bool, db *sql.DB) (storage.ReadinessStatus, error) {
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
// do ping first to ensure we have better error message
// if error is due to connection issue.
if pingErr := db.PingContext(ctx); pingErr != nil {
return storage.ReadinessStatus{}, pingErr
}
return IsVersionReady(ctx, skipVersionCheck, db)
}
// AddFromUlid adds a ulid-based pagination predicate to the select builder:
// ulid < fromUlid when sorting descending, ulid > fromUlid otherwise.
func AddFromUlid(sb sq.SelectBuilder, fromUlid string, sortDescending bool) sq.SelectBuilder {
if sortDescending {
return sb.Where(sq.Lt{"ulid": fromUlid})
}
return sb.Where(sq.Gt{"ulid": fromUlid})
}
package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"sort"
"strings"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/oklog/ulid/v2"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"modernc.org/sqlite"
sqlite3 "modernc.org/sqlite/lib"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/sqlcommon"
tupleUtils "github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/storage/sqlite")
func startTrace(ctx context.Context, name string) (context.Context, trace.Span) {
return tracer.Start(ctx, "sqlite."+name)
}
var tupleColumns = []string{
"store", "object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"condition_name", "condition_context", "ulid", "inserted_at",
}
// Datastore provides a SQLite based implementation of [storage.OpenFGADatastore].
type Datastore struct {
stbl sq.StatementBuilderType
db *sql.DB
dbInfo *sqlcommon.DBInfo
logger logger.Logger
dbStatsCollector prometheus.Collector
maxTuplesPerWriteField int
maxTypesPerModelField int
versionReady bool
}
// Ensures that SQLite implements the OpenFGADatastore interface.
var _ storage.OpenFGADatastore = (*Datastore)(nil)
// PrepareDSN prepares a raw DSN from config for use with SQLite, specifying defaults for the journal mode and busy timeout pragmas.
func PrepareDSN(uri string) (string, error) {
// Set journal mode and busy timeout pragmas if not specified.
query := url.Values{}
var err error
if i := strings.Index(uri, "?"); i != -1 {
query, err = url.ParseQuery(uri[i+1:])
if err != nil {
return uri, fmt.Errorf("error parsing dsn: %w", err)
}
uri = uri[:i]
}
foundJournalMode := false
foundBusyTimeout := false
for _, val := range query["_pragma"] {
if strings.HasPrefix(val, "journal_mode") {
foundJournalMode = true
} else if strings.HasPrefix(val, "busy_timeout") {
foundBusyTimeout = true
}
}
if !foundJournalMode {
query.Add("_pragma", "journal_mode(WAL)")
}
if !foundBusyTimeout {
query.Add("_pragma", "busy_timeout(100)")
}
// Set transaction mode to immediate if not specified
if !query.Has("_txlock") {
query.Set("_txlock", "immediate")
}
uri += "?" + query.Encode()
return uri, nil
}
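// An illustrative sketch (the file name is made up):
//
//	dsn, _ := PrepareDSN("file:openfga.db")
//	// dsn == "file:openfga.db?_pragma=journal_mode%28WAL%29&_pragma=busy_timeout%28100%29&_txlock=immediate"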
// New creates a new [Datastore] storage.
func New(uri string, cfg *sqlcommon.Config) (*Datastore, error) {
uri, err := PrepareDSN(uri)
if err != nil {
return nil, err
}
db, err := sql.Open("sqlite", uri)
if err != nil {
return nil, fmt.Errorf("initialize sqlite connection: %w", err)
}
return NewWithDB(db, cfg)
}
// NewWithDB creates a new [Datastore] storage with the provided database connection.
func NewWithDB(db *sql.DB, cfg *sqlcommon.Config) (*Datastore, error) {
var collector prometheus.Collector
if cfg.ExportMetrics {
collector = collectors.NewDBStatsCollector(db, "openfga")
if err := prometheus.Register(collector); err != nil {
return nil, fmt.Errorf("initialize metrics: %w", err)
}
}
stbl := sq.StatementBuilder.RunWith(db)
dbInfo := sqlcommon.NewDBInfo(stbl, HandleSQLError, "sqlite")
return &Datastore{
stbl: stbl,
db: db,
dbInfo: dbInfo,
logger: cfg.Logger,
dbStatsCollector: collector,
maxTuplesPerWriteField: cfg.MaxTuplesPerWriteField,
maxTypesPerModelField: cfg.MaxTypesPerModelField,
versionReady: false,
}, nil
}
// Close see [storage.OpenFGADatastore].Close.
func (s *Datastore) Close() {
if s.dbStatsCollector != nil {
prometheus.Unregister(s.dbStatsCollector)
}
_ = s.db.Close()
}
// Read see [storage.RelationshipTupleReader].Read.
func (s *Datastore) Read(
ctx context.Context,
store string,
filter storage.ReadFilter,
_ storage.ReadOptions,
) (storage.TupleIterator, error) {
ctx, span := startTrace(ctx, "Read")
defer span.End()
return s.read(ctx, store, filter, nil)
}
// ReadPage see [storage.RelationshipTupleReader].ReadPage.
func (s *Datastore) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
ctx, span := startTrace(ctx, "ReadPage")
defer span.End()
iter, err := s.read(ctx, store, filter, &options)
if err != nil {
return nil, "", err
}
defer iter.Stop()
return iter.ToArray(ctx, options.Pagination)
}
func (s *Datastore) read(ctx context.Context, store string, filter storage.ReadFilter, options *storage.ReadPageOptions) (*SQLTupleIterator, error) {
_, span := startTrace(ctx, "read")
defer span.End()
sb := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store})
if options != nil {
sb = sb.OrderBy("ulid")
}
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if filter.User != "" {
userObjectType, userObjectID, userRelation := tupleUtils.ToUserParts(filter.User)
if userObjectType != "" {
sb = sb.Where(sq.Eq{
"user_object_type": userObjectType,
})
}
if userObjectID != "" {
sb = sb.Where(sq.Eq{
"user_object_id": userObjectID,
})
}
if userRelation != "" {
sb = sb.Where(sq.Eq{
"user_relation": userRelation,
})
}
}
if len(filter.Conditions) > 0 {
// Use COALESCE to treat NULL and '' as the same value (empty string).
// This allows filtering for "no condition" (e.g., filter.Conditions = [""])
// to correctly match rows where condition_name is either '' OR NULL.
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
if options != nil && options.Pagination.From != "" {
token := options.Pagination.From
sb = sb.Where(sq.GtOrEq{"ulid": token})
}
if options != nil && options.Pagination.PageSize != 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
return NewSQLTupleIterator(sb, HandleSQLError), nil
}
// Write see [storage.RelationshipTupleWriter].Write.
func (s *Datastore) Write(
ctx context.Context,
store string,
deletes storage.Deletes,
writes storage.Writes,
opts ...storage.TupleWriteOption,
) error {
ctx, span := startTrace(ctx, "Write")
defer span.End()
return s.write(ctx, store, deletes, writes, storage.NewTupleWriteOptions(opts...), time.Now().UTC())
}
// tupleLockKey represents the composite key we lock on.
type tupleLockKey struct {
objectType string
objectID string
relation string
userObjectType string
userObjectID string
userRelation string
userType tupleUtils.UserType
}
// makeTupleLockKeys flattens deletes+writes into a deduped, sorted slice to ensure stable lock order.
func makeTupleLockKeys(deletes storage.Deletes, writes storage.Writes) []tupleLockKey {
keys := make([]tupleLockKey, 0, len(deletes)+len(writes))
seen := make(map[string]struct{}, cap(keys))
add := func(tk *openfgav1.TupleKey) {
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
userObjectType, userObjectID, userRelation := tupleUtils.ToUserParts(tk.GetUser())
k := tupleLockKey{
objectType: objectType,
objectID: objectID,
relation: tk.GetRelation(),
userObjectType: userObjectType,
userObjectID: userObjectID,
userRelation: userRelation,
userType: tupleUtils.GetUserTypeFromUser(tk.GetUser()),
}
s := strings.Join([]string{
k.objectType,
k.objectID,
k.relation,
k.userObjectType,
k.userObjectID,
k.userRelation,
string(k.userType),
}, "\x00")
if _, ok := seen[s]; ok {
return
}
seen[s] = struct{}{}
keys = append(keys, k)
}
for _, tk := range deletes {
add(tupleUtils.TupleKeyWithoutConditionToTupleKey(tk))
}
for _, tk := range writes {
add(tk)
}
// Sort deterministically by the composite key to keep lock order stable.
sort.Slice(keys, func(i, j int) bool {
a, b := keys[i], keys[j]
if a.objectType != b.objectType {
return a.objectType < b.objectType
}
if a.objectID != b.objectID {
return a.objectID < b.objectID
}
if a.relation != b.relation {
return a.relation < b.relation
}
if a.userObjectType != b.userObjectType {
return a.userObjectType < b.userObjectType
}
if a.userObjectID != b.userObjectID {
return a.userObjectID < b.userObjectID
}
if a.userRelation != b.userRelation {
return a.userRelation < b.userRelation
}
return a.userType < b.userType
})
return keys
}
// buildRowConstructorIN builds "((?,?,?,?,?,?,?),(?,?,?,?,?,?,?),...)" and arg list for row-constructor IN.
func buildRowConstructorIN(keys []tupleLockKey) (string, []interface{}) {
if len(keys) == 0 {
return "", nil
}
var sb strings.Builder
args := make([]interface{}, 0, len(keys)*7)
sb.WriteByte('(')
for i, k := range keys {
if i > 0 {
sb.WriteByte(',')
}
sb.WriteString("(?,?,?,?,?,?,?)")
args = append(args,
k.objectType,
k.objectID,
k.relation,
k.userObjectType,
k.userObjectID,
k.userRelation,
k.userType,
)
}
sb.WriteByte(')')
return sb.String(), args
}
// selectExistingRowsForWrite selects existing rows for the given keys.
// Unlike the sqlcommon variant, no FOR UPDATE suffix is used: SQLite does not
// support SELECT ... FOR UPDATE, and writers are serialized by the
// database-level lock instead (see the _txlock=immediate default in PrepareDSN).
// The existing rows are added to the existing map.
func (s *Datastore) selectExistingRowsForWrite(ctx context.Context, store string, keys []tupleLockKey, txn *sql.Tx, existing map[string]*openfgav1.Tuple) error {
inExpr, args := buildRowConstructorIN(keys)
selectBuilder := s.stbl.
Select(tupleColumns...).
Where(sq.Eq{"store": store}).
From("tuple").
// Row-constructor IN on full composite key for precise point locks.
Where(sq.Expr("(object_type, object_id, relation, user_object_type, user_object_id, user_relation, user_type) IN "+inExpr, args...)).
RunWith(txn) // make sure to run in the same transaction
iter := NewSQLTupleIterator(selectBuilder, HandleSQLError)
defer iter.Stop()
items, _, err := iter.ToArray(ctx, storage.PaginationOptions{PageSize: len(keys)})
if err != nil {
return err
}
for _, tuple := range items {
existing[tupleUtils.TupleKeyToString(tuple.GetKey())] = tuple
}
return nil
}
// write implements the tuple write path for the SQLite datastore inside a single transaction.
func (s *Datastore) write(
ctx context.Context,
store string,
deletes storage.Deletes,
writes storage.Writes,
opts storage.TupleWriteOptions,
now time.Time,
) error {
// 1. Begin Transaction ( Isolation Level = READ COMMITTED )
var txn *sql.Tx
err := busyRetry(func() error {
var err error
txn, err = s.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
return err
})
if err != nil {
return HandleSQLError(err)
}
defer func() {
_ = txn.Rollback()
}()
// 2. Compile a SELECT statement to read the existing tuples for writes and deletes.
// Build a deduped, sorted list of keys to lock.
lockKeys := makeTupleLockKeys(deletes, writes)
total := len(lockKeys)
if total == 0 {
// Nothing to do.
return nil
}
existing := make(map[string]*openfgav1.Tuple, total)
// 3. If the list compiled in step 2 is not empty, execute the SELECT statement.
for start := 0; start < total; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > total {
end = total
}
keys := lockKeys[start:end]
if err = s.selectExistingRowsForWrite(ctx, store, keys, txn, existing); err != nil {
return err
}
}
changeLogItems := make([][]interface{}, 0, len(deletes)+len(writes))
// DefaultEntropy provides monotonically increasing (and therefore unique) ULID entropy within the process.
entropy := ulid.DefaultEntropy()
deleteConditions := sq.Or{}
// 4. For deletes:
// a. If on_missing: error (default behavior):
// - Execute DELETEs as a single statement.
// On conflict (row count != delete count), rollback & return an error.
// b. If on_missing: ignore, use the result from Step 3.a:
// - Based on the results from step 3.a, which identified and locked existing rows,
// generate DELETE tuple and INSERT changelog statements only for those specific tuples.
// - Rows that don't exist in the DB are ignored (no-op).
// - Execute DELETEs as a single statement.
// On conflict (row count != delete count), rollback & return an HTTP 409 Conflict error.
for _, tk := range deletes {
if _, ok := existing[tupleUtils.TupleKeyToString(tk)]; !ok {
// If the tuple does not exist, we cannot delete it.
switch opts.OnMissingDelete {
case storage.OnMissingDeleteIgnore:
continue
case storage.OnMissingDeleteError:
fallthrough
default:
return storage.InvalidWriteInputError(
tk,
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
)
}
}
id := ulid.MustNew(ulid.Timestamp(now), entropy).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
userObjectType, userObjectID, userRelation := tupleUtils.ToUserParts(tk.GetUser())
deleteConditions = append(deleteConditions, sq.Eq{
"object_type": objectType,
"object_id": objectID,
"relation": tk.GetRelation(),
"user_object_type": userObjectType,
"user_object_id": userObjectID,
"user_relation": userRelation,
"user_type": tupleUtils.GetUserTypeFromUser(tk.GetUser()),
})
changeLogItems = append(changeLogItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
userObjectType,
userObjectID,
userRelation,
"",
nil, // Redact condition info for deletes since we only need the base triplet (object, relation, user).
openfgav1.TupleOperation_TUPLE_OPERATION_DELETE,
id,
sq.Expr("datetime('subsec')"),
})
}
writeItems := make([][]interface{}, 0, len(writes))
// 5. For writes:
// a. If on_duplicate: error (default behavior):
// - Execute INSERTs as a single statement.
// A duplicate insert raises a CONSTRAINT VIOLATION error; return 400 Bad Request.
// b. If on_duplicate: ignore:
// - Based on the results from step 3.a, which identified and locked existing rows, compare the stored values to the ones being inserted.
// - On conflict (values not identical), return a 409 Conflict error.
// - For rows that do NOT exist in the DB, create both INSERT tuple & INSERT changelog statements.
// c. Execute INSERTs as a single statement.
// On error, return 409 Conflict.
for _, tk := range writes {
if existingTuple, ok := existing[tupleUtils.TupleKeyToString(tk)]; ok {
// If the tuple exists, we cannot write it.
switch opts.OnDuplicateInsert {
case storage.OnDuplicateInsertIgnore:
// If the tuple exists and its condition is the same, we can ignore it.
// proto.Equal is used rather than reflect.DeepEqual so that protobuf-internal fields are not compared.
if proto.Equal(existingTuple.GetKey().GetCondition(), tk.GetCondition()) {
continue
}
// If the tuple conditions differ, we return an error.
return storage.TupleConditionConflictError(tk)
case storage.OnDuplicateInsertError:
fallthrough
default:
return storage.InvalidWriteInputError(
tk,
openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
)
}
}
id := ulid.MustNew(ulid.Timestamp(now), entropy).String()
objectType, objectID := tupleUtils.SplitObject(tk.GetObject())
userObjectType, userObjectID, userRelation := tupleUtils.ToUserParts(tk.GetUser())
conditionName, conditionContext, err := sqlcommon.MarshalRelationshipCondition(tk.GetCondition())
if err != nil {
return err
}
writeItems = append(writeItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
userObjectType,
userObjectID,
userRelation,
tupleUtils.GetUserTypeFromUser(tk.GetUser()),
conditionName,
conditionContext,
id,
sq.Expr("datetime('subsec')"),
})
changeLogItems = append(changeLogItems, []interface{}{
store,
objectType,
objectID,
tk.GetRelation(),
userObjectType,
userObjectID,
userRelation,
conditionName,
conditionContext,
openfgav1.TupleOperation_TUPLE_OPERATION_WRITE,
id,
sq.Expr("datetime('subsec')"),
})
}
for start, totalDeletes := 0, len(deleteConditions); start < totalDeletes; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalDeletes {
end = totalDeletes
}
deleteConditionsBatch := deleteConditions[start:end]
res, err := s.stbl.Delete("tuple").Where(sq.Eq{"store": store}).
Where(deleteConditionsBatch).
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
return HandleSQLError(err)
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return HandleSQLError(err)
}
if rowsAffected != int64(len(deleteConditionsBatch)) {
// If we deleted fewer rows than planned (rows changed between our read and this write), we hit a race condition: someone else deleted the same row(s).
return storage.ErrWriteConflictOnDelete
}
}
for start, totalWrites := 0, len(writeItems); start < totalWrites; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalWrites {
end = totalWrites
}
writesBatch := writeItems[start:end]
insertBuilder := s.stbl.
Insert("tuple").
Columns(
"store",
"object_type",
"object_id",
"relation",
"user_object_type",
"user_object_id",
"user_relation",
"user_type",
"condition_name",
"condition_context",
"ulid",
"inserted_at",
)
for _, item := range writesBatch {
insertBuilder = insertBuilder.Values(item...)
}
_, err = insertBuilder.
RunWith(txn). // Part of a txn.
ExecContext(ctx)
if err != nil {
dberr := HandleSQLError(err)
if errors.Is(dberr, storage.ErrCollision) {
// ErrCollision is returned on duplicate write (constraint violation), meaning we hit a race condition: someone else inserted the same row(s).
return storage.ErrWriteConflictOnInsert
}
return dberr
}
}
// 6. Execute INSERT changelog statements
for start, totalItems := 0, len(changeLogItems); start < totalItems; start += storage.DefaultMaxTuplesPerWrite {
end := start + storage.DefaultMaxTuplesPerWrite
if end > totalItems {
end = totalItems
}
changeLogBatch := changeLogItems[start:end]
changelogBuilder := s.stbl.
Insert("changelog").
Columns(
"store",
"object_type",
"object_id",
"relation",
"user_object_type",
"user_object_id",
"user_relation",
"condition_name",
"condition_context",
"operation",
"ulid",
"inserted_at",
)
for _, item := range changeLogBatch {
changelogBuilder = changelogBuilder.Values(item...)
}
_, err = changelogBuilder.RunWith(txn).ExecContext(ctx) // Part of a txn.
if err != nil {
return HandleSQLError(err)
}
}
err = busyRetry(func() error {
return txn.Commit()
})
if err != nil {
return HandleSQLError(err)
}
return nil
}
// ReadUserTuple see [storage.RelationshipTupleReader].ReadUserTuple.
func (s *Datastore) ReadUserTuple(ctx context.Context, store string, tupleKey *openfgav1.TupleKey, _ storage.ReadUserTupleOptions) (*openfgav1.Tuple, error) {
ctx, span := startTrace(ctx, "ReadUserTuple")
defer span.End()
objectType, objectID := tupleUtils.SplitObject(tupleKey.GetObject())
userType := tupleUtils.GetUserTypeFromUser(tupleKey.GetUser())
userObjectType, userObjectID, userRelation := tupleUtils.ToUserParts(tupleKey.GetUser())
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := s.stbl.
Select(
"object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"condition_name", "condition_context",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": objectType,
"object_id": objectID,
"relation": tupleKey.GetRelation(),
"user_object_type": userObjectType,
"user_object_id": userObjectID,
"user_relation": userRelation,
"user_type": userType,
}).
QueryRowContext(ctx).
Scan(
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.UserObjectType,
&record.UserObjectID,
&record.UserRelation,
&conditionName,
&conditionContext,
)
if err != nil {
return nil, HandleSQLError(err)
}
if conditionName.String != "" {
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
}
return record.AsTuple(), nil
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (s *Datastore) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
_ storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadUsersetTuples")
defer span.End()
sb := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{"store": store}).
Where(sq.Eq{"user_type": tupleUtils.UserSet})
objectType, objectID := tupleUtils.SplitObject(filter.Object)
if objectType != "" {
sb = sb.Where(sq.Eq{"object_type": objectType})
}
if objectID != "" {
sb = sb.Where(sq.Eq{"object_id": objectID})
}
if filter.Relation != "" {
sb = sb.Where(sq.Eq{"relation": filter.Relation})
}
if len(filter.AllowedUserTypeRestrictions) > 0 {
orConditions := sq.Or{}
for _, userset := range filter.AllowedUserTypeRestrictions {
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
orConditions = append(orConditions, sq.Eq{
"user_object_type": userset.GetType(),
"user_relation": userset.GetRelation(),
})
}
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
orConditions = append(orConditions, sq.Eq{
"user_object_type": userset.GetType(),
"user_object_id": "*",
})
}
}
sb = sb.Where(orConditions)
}
if len(filter.Conditions) > 0 {
sb = sb.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
return NewSQLTupleIterator(sb, HandleSQLError), nil
}
// ReadStartingWithUser see [storage.RelationshipTupleReader].ReadStartingWithUser.
func (s *Datastore) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
_ storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
_, span := startTrace(ctx, "ReadStartingWithUser")
defer span.End()
var targetUsersArg sq.Or
for _, u := range filter.UserFilter {
userObjectType, userObjectID, userRelation := tupleUtils.ToUserPartsFromObjectRelation(u)
targetUser := sq.Eq{
"user_object_type": userObjectType,
"user_object_id": userObjectID,
}
if userRelation != "" {
targetUser["user_relation"] = userRelation
}
targetUsersArg = append(targetUsersArg, targetUser)
}
builder := s.stbl.
Select(
"store", "object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"condition_name", "condition_context", "ulid", "inserted_at",
).
From("tuple").
Where(sq.Eq{
"store": store,
"object_type": filter.ObjectType,
"relation": filter.Relation,
}).
Where(targetUsersArg).OrderBy("object_id")
if filter.ObjectIDs != nil && filter.ObjectIDs.Size() > 0 {
builder = builder.Where(sq.Eq{"object_id": filter.ObjectIDs.Values()})
}
if len(filter.Conditions) > 0 {
builder = builder.Where(sq.Eq{"COALESCE(condition_name, '')": filter.Conditions})
}
return NewSQLTupleIterator(builder, HandleSQLError), nil
}
// MaxTuplesPerWrite see [storage.RelationshipTupleWriter].MaxTuplesPerWrite.
func (s *Datastore) MaxTuplesPerWrite() int {
return s.maxTuplesPerWriteField
}
func constructAuthorizationModelFromSQLRows(rows *sql.Rows) (*openfgav1.AuthorizationModel, error) {
if rows.Next() {
var modelID string
var schemaVersion string
var marshalledModel []byte
err := rows.Scan(&modelID, &schemaVersion, &marshalledModel)
if err != nil {
return nil, HandleSQLError(err)
}
var model openfgav1.AuthorizationModel
if err := proto.Unmarshal(marshalledModel, &model); err != nil {
return nil, err
}
return &model, nil
}
if err := rows.Err(); err != nil {
return nil, HandleSQLError(err)
}
return nil, storage.ErrNotFound
}
// ReadAuthorizationModel see [storage.AuthorizationModelReadBackend].ReadAuthorizationModel.
func (s *Datastore) ReadAuthorizationModel(ctx context.Context, store string, modelID string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModel")
defer span.End()
rows, err := s.stbl.
Select("authorization_model_id", "schema_version", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).
QueryContext(ctx)
if err != nil {
return nil, HandleSQLError(err)
}
defer rows.Close()
return constructAuthorizationModelFromSQLRows(rows)
}
// ReadAuthorizationModels see [storage.AuthorizationModelReadBackend].ReadAuthorizationModels.
func (s *Datastore) ReadAuthorizationModels(ctx context.Context, store string, options storage.ReadAuthorizationModelsOptions) ([]*openfgav1.AuthorizationModel, string, error) {
ctx, span := startTrace(ctx, "ReadAuthorizationModels")
defer span.End()
sb := s.stbl.
Select("authorization_model_id", "schema_version", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc")
if options.Pagination.From != "" {
sb = sb.Where(sq.LtOrEq{"authorization_model_id": options.Pagination.From})
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var modelID string
var schemaVersion string
var marshalledModel []byte
models := make([]*openfgav1.AuthorizationModel, 0, options.Pagination.PageSize)
var token string
for rows.Next() {
err = rows.Scan(&modelID, &schemaVersion, &marshalledModel)
if err != nil {
return nil, "", HandleSQLError(err)
}
if options.Pagination.PageSize > 0 && len(models) >= options.Pagination.PageSize {
return models, modelID, nil
}
var model openfgav1.AuthorizationModel
if err := proto.Unmarshal(marshalledModel, &model); err != nil {
return nil, "", err
}
models = append(models, &model)
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
return models, token, nil
}
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (s *Datastore) FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error) {
ctx, span := startTrace(ctx, "FindLatestAuthorizationModel")
defer span.End()
rows, err := s.stbl.
Select("authorization_model_id", "schema_version", "serialized_protobuf").
From("authorization_model").
Where(sq.Eq{"store": store}).
OrderBy("authorization_model_id desc").
Limit(1).
QueryContext(ctx)
if err != nil {
return nil, HandleSQLError(err)
}
defer rows.Close()
return constructAuthorizationModelFromSQLRows(rows)
}
// MaxTypesPerAuthorizationModel see [storage.TypeDefinitionWriteBackend].MaxTypesPerAuthorizationModel.
func (s *Datastore) MaxTypesPerAuthorizationModel() int {
return s.maxTypesPerModelField
}
// WriteAuthorizationModel see [storage.TypeDefinitionWriteBackend].WriteAuthorizationModel.
func (s *Datastore) WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error {
ctx, span := startTrace(ctx, "WriteAuthorizationModel")
defer span.End()
schemaVersion := model.GetSchemaVersion()
typeDefinitions := model.GetTypeDefinitions()
if len(typeDefinitions) < 1 {
return nil
}
pbdata, err := proto.Marshal(model)
if err != nil {
return err
}
err = busyRetry(func() error {
_, err := s.stbl.
Insert("authorization_model").
Columns("store", "authorization_model_id", "schema_version", "serialized_protobuf").
Values(store, model.GetId(), schemaVersion, pbdata).
ExecContext(ctx)
return err
})
if err != nil {
return HandleSQLError(err)
}
return nil
}
// CreateStore adds a new store to storage.
func (s *Datastore) CreateStore(ctx context.Context, store *openfgav1.Store) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "CreateStore")
defer span.End()
var id, name string
var createdAt, updatedAt time.Time
err := busyRetry(func() error {
return s.stbl.
Insert("store").
Columns("id", "name", "created_at", "updated_at").
Values(store.GetId(), store.GetName(), sq.Expr("datetime('subsec')"), sq.Expr("datetime('subsec')")).
Suffix("returning id, name, created_at, updated_at").
QueryRowContext(ctx).
Scan(&id, &name, &createdAt, &updatedAt)
})
if err != nil {
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// GetStore retrieves the details of a specific store using its storeID.
func (s *Datastore) GetStore(ctx context.Context, id string) (*openfgav1.Store, error) {
ctx, span := startTrace(ctx, "GetStore")
defer span.End()
row := s.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(sq.Eq{
"id": id,
"deleted_at": nil,
}).
QueryRowContext(ctx)
var storeID, name string
var createdAt, updatedAt time.Time
err := row.Scan(&storeID, &name, &createdAt, &updatedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, storage.ErrNotFound
}
return nil, HandleSQLError(err)
}
return &openfgav1.Store{
Id: storeID,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
}, nil
}
// ListStores provides a paginated list of all stores present in the storage.
func (s *Datastore) ListStores(ctx context.Context, options storage.ListStoresOptions) ([]*openfgav1.Store, string, error) {
ctx, span := startTrace(ctx, "ListStores")
defer span.End()
whereClause := sq.And{
sq.Eq{"deleted_at": nil},
}
if len(options.IDs) > 0 {
whereClause = append(whereClause, sq.Eq{"id": options.IDs})
}
if options.Name != "" {
whereClause = append(whereClause, sq.Eq{"name": options.Name})
}
if options.Pagination.From != "" {
whereClause = append(whereClause, sq.GtOrEq{"id": options.Pagination.From})
}
sb := s.stbl.
Select("id", "name", "created_at", "updated_at").
From("store").
Where(whereClause).
OrderBy("id")
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize + 1)) // + 1 is used to determine whether to return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var stores []*openfgav1.Store
var id string
for rows.Next() {
var name string
var createdAt, updatedAt time.Time
err := rows.Scan(&id, &name, &createdAt, &updatedAt)
if err != nil {
return nil, "", HandleSQLError(err)
}
stores = append(stores, &openfgav1.Store{
Id: id,
Name: name,
CreatedAt: timestamppb.New(createdAt),
UpdatedAt: timestamppb.New(updatedAt),
})
}
if err := rows.Err(); err != nil {
return nil, "", HandleSQLError(err)
}
if len(stores) > options.Pagination.PageSize {
return stores[:options.Pagination.PageSize], id, nil
}
return stores, "", nil
}
// DeleteStore removes a store from storage.
func (s *Datastore) DeleteStore(ctx context.Context, id string) error {
ctx, span := startTrace(ctx, "DeleteStore")
defer span.End()
_, err := s.stbl.
Update("store").
Set("deleted_at", sq.Expr("datetime('subsec')")).
Where(sq.Eq{"id": id}).
ExecContext(ctx)
if err != nil {
return HandleSQLError(err)
}
return nil
}
// WriteAssertions see [storage.AssertionsBackend].WriteAssertions.
func (s *Datastore) WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error {
ctx, span := startTrace(ctx, "WriteAssertions")
defer span.End()
marshalledAssertions, err := proto.Marshal(&openfgav1.Assertions{Assertions: assertions})
if err != nil {
return err
}
err = busyRetry(func() error {
_, err := s.stbl.
Insert("assertion").
Columns("store", "authorization_model_id", "assertions").
Values(store, modelID, marshalledAssertions).
Suffix("ON CONFLICT (store, authorization_model_id) DO UPDATE SET assertions = ?", marshalledAssertions).
ExecContext(ctx)
return err
})
if err != nil {
return HandleSQLError(err)
}
return nil
}
// ReadAssertions see [storage.AssertionsBackend].ReadAssertions.
func (s *Datastore) ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error) {
ctx, span := startTrace(ctx, "ReadAssertions")
defer span.End()
var marshalledAssertions []byte
err := s.stbl.
Select("assertions").
From("assertion").
Where(sq.Eq{
"store": store,
"authorization_model_id": modelID,
}).
QueryRowContext(ctx).
Scan(&marshalledAssertions)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []*openfgav1.Assertion{}, nil
}
return nil, HandleSQLError(err)
}
var assertions openfgav1.Assertions
err = proto.Unmarshal(marshalledAssertions, &assertions)
if err != nil {
return nil, err
}
return assertions.GetAssertions(), nil
}
// ReadChanges see [storage.ChangelogBackend].ReadChanges.
func (s *Datastore) ReadChanges(ctx context.Context, store string, filter storage.ReadChangesFilter, options storage.ReadChangesOptions) ([]*openfgav1.TupleChange, string, error) {
ctx, span := startTrace(ctx, "ReadChanges")
defer span.End()
objectTypeFilter := filter.ObjectType
horizonOffset := filter.HorizonOffset
orderBy := "ulid asc"
if options.SortDesc {
orderBy = "ulid desc"
}
sb := s.stbl.
Select(
"ulid", "object_type", "object_id", "relation",
"user_object_type", "user_object_id", "user_relation",
"operation",
"condition_name", "condition_context", "inserted_at",
).
From("changelog").
Where(sq.Eq{"store": store}).
Where(fmt.Sprintf("inserted_at <= datetime('subsec','-%f seconds')", horizonOffset.Seconds())).
OrderBy(orderBy)
if objectTypeFilter != "" {
sb = sb.Where(sq.Eq{"object_type": objectTypeFilter})
}
if options.Pagination.From != "" {
sb = sqlcommon.AddFromUlid(sb, options.Pagination.From, options.SortDesc)
}
if options.Pagination.PageSize > 0 {
sb = sb.Limit(uint64(options.Pagination.PageSize)) // + 1 is NOT used here as we always return a continuation token.
}
rows, err := sb.QueryContext(ctx)
if err != nil {
return nil, "", HandleSQLError(err)
}
defer rows.Close()
var changes []*openfgav1.TupleChange
var ulid string
for rows.Next() {
var objectType, objectID, relation, userObjectType, userObjectID, userRelation string
var operation int
var insertedAt time.Time
var conditionName sql.NullString
var conditionContext []byte
err = rows.Scan(
&ulid,
&objectType,
&objectID,
&relation,
&userObjectType,
&userObjectID,
&userRelation,
&operation,
&conditionName,
&conditionContext,
&insertedAt,
)
if err != nil {
return nil, "", HandleSQLError(err)
}
var conditionContextStruct structpb.Struct
if conditionName.String != "" {
if conditionContext != nil {
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, "", err
}
}
}
tk := tupleUtils.NewTupleKeyWithCondition(
tupleUtils.BuildObject(objectType, objectID),
relation,
tupleUtils.FromUserParts(userObjectType, userObjectID, userRelation),
conditionName.String,
&conditionContextStruct,
)
changes = append(changes, &openfgav1.TupleChange{
TupleKey: tk,
Operation: openfgav1.TupleOperation(operation),
Timestamp: timestamppb.New(insertedAt.UTC()),
})
}
if len(changes) == 0 {
return nil, "", storage.ErrNotFound
}
return changes, ulid, nil
}
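// Consumption sketch (illustrative, not part of the original file): a poller
// pages through the changelog by feeding the returned ULID back in as the
// continuation token:
//
//	changes, token, err := ds.ReadChanges(ctx, storeID, filter, options)
//	if errors.Is(err, storage.ErrNotFound) {
//		// no new changes; retry later with the previous token
//	}
//	// otherwise process changes, then set options.Pagination.From = token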
// IsReady see [sqlcommon.IsReady].
func (s *Datastore) IsReady(ctx context.Context) (storage.ReadinessStatus, error) {
versionReady, err := sqlcommon.IsReady(ctx, s.versionReady, s.db)
if err != nil {
return versionReady, err
}
s.versionReady = versionReady.IsReady
return versionReady, nil
}
// HandleSQLError processes an SQL error and converts it into a more
// specific error type based on the nature of the SQL error.
func HandleSQLError(err error, args ...interface{}) error {
if errors.Is(err, sql.ErrNoRows) {
return storage.ErrNotFound
}
var sqliteErr *sqlite.Error
if errors.As(err, &sqliteErr) {
if sqliteErr.Code()&0xFF == sqlite3.SQLITE_CONSTRAINT {
if len(args) > 0 {
if tk, ok := args[0].(*openfgav1.TupleKey); ok {
return storage.InvalidWriteInputError(tk, openfgav1.TupleOperation_TUPLE_OPERATION_WRITE)
}
}
return storage.ErrCollision
}
}
return fmt.Errorf("sql error: %w", err)
}
// busyRetry retries the given operation when SQLite reports a busy error: SQLite returns SQLITE_BUSY
// when the database is locked rather than waiting for the lock to be released.
// The operation is retried up to maxRetries times before the error is returned.
func busyRetry(fn func() error) error {
const maxRetries = 10
for retries := 0; ; retries++ {
err := fn()
if err == nil {
return nil
}
if isBusyError(err) {
if retries < maxRetries {
continue
}
return fmt.Errorf("sqlite busy error after %d retries: %w", maxRetries, err)
}
return err
}
}
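// Usage sketch (illustrative): any statement that may hit a locked database can
// be wrapped the same way WriteAssertions does above:
//
//	err := busyRetry(func() error {
//		_, err := stmt.ExecContext(ctx)
//		return err
//	})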
var busyErrors = map[int]struct{}{
sqlite3.SQLITE_BUSY_RECOVERY: {},
sqlite3.SQLITE_BUSY_SNAPSHOT: {},
sqlite3.SQLITE_BUSY_TIMEOUT: {},
sqlite3.SQLITE_BUSY: {},
sqlite3.SQLITE_LOCKED_SHAREDCACHE: {},
sqlite3.SQLITE_LOCKED: {},
}
func isBusyError(err error) bool {
var sqliteErr *sqlite.Error
if !errors.As(err, &sqliteErr) {
return false
}
_, ok := busyErrors[sqliteErr.Code()]
return ok
}
package sqlite
import (
"context"
"database/sql"
"errors"
"sync"
sq "github.com/Masterminds/squirrel"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
)
type errorHandlerFn func(error, ...interface{}) error
// SQLTupleIterator is a struct that implements the storage.TupleIterator
// interface for iterating over tuples fetched from a SQL database.
type SQLTupleIterator struct {
rows *sql.Rows // GUARDED_BY(mu)
sb sq.SelectBuilder
handleSQLError errorHandlerFn
// firstRow is used as temporary storage when Head is called.
// If firstRow is nil and Head is called, rows.Next() will return the first item and advance
// the iterator. Thus, we need to store this first item so that future Head() and Next()
// calls use it instead; otherwise, the first item would be lost.
firstRow *storage.TupleRecord // GUARDED_BY(mu)
mu sync.Mutex
}
// Ensures that SQLTupleIterator implements the TupleIterator interface.
var _ storage.TupleIterator = (*SQLTupleIterator)(nil)
// NewSQLTupleIterator returns a SQL tuple iterator.
func NewSQLTupleIterator(sb sq.SelectBuilder, errHandler errorHandlerFn) *SQLTupleIterator {
return &SQLTupleIterator{
sb: sb,
rows: nil,
handleSQLError: errHandler,
firstRow: nil,
mu: sync.Mutex{},
}
}
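// Usage sketch (illustrative, not part of the original file): the query is only
// executed on the first Head/Next call, and Stop must always be called to
// release the underlying rows:
//
//	it := NewSQLTupleIterator(sb, HandleSQLError)
//	defer it.Stop()
//	for {
//		tp, err := it.Next(ctx)
//		if errors.Is(err, storage.ErrIteratorDone) {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		_ = tp // process the tuple
//	}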
func (t *SQLTupleIterator) fetchBuffer(ctx context.Context) error {
ctx, span := tracer.Start(ctx, "sqlite.fetchBuffer", trace.WithAttributes())
defer span.End()
ctx = context.WithoutCancel(ctx)
rows, err := t.sb.QueryContext(ctx)
if err != nil {
return t.handleSQLError(err)
}
t.rows = rows
return nil
}
func (t *SQLTupleIterator) next(ctx context.Context) (*storage.TupleRecord, error) {
t.mu.Lock()
if t.rows == nil {
if err := t.fetchBuffer(ctx); err != nil {
t.mu.Unlock()
return nil, err
}
}
if t.firstRow != nil {
// If head was called previously, we don't need to scan / next
// again as the data is already there and the internal iterator would be advanced via `t.rows.Next()`.
// Calling t.rows.Next() in this case would lose the first row data.
//
// For example, let's say there are 3 items [1,2,3]
// If we called Head() and t.firstRow is empty, the rows will only be left with [2,3].
// Thus, we will need to save item [1] in firstRow. This allows future next() and head() to consume
// [1] first.
// If head() was not called, t.firstRow would be nil and we can follow the t.rows.Next() logic below.
firstRow := t.firstRow
t.firstRow = nil
t.mu.Unlock()
return firstRow, nil
}
if !t.rows.Next() {
err := t.rows.Err()
t.mu.Unlock()
if err != nil {
return nil, t.handleSQLError(err)
}
return nil, storage.ErrIteratorDone
}
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := t.rows.Scan(
&record.Store,
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.UserObjectType,
&record.UserObjectID,
&record.UserRelation,
&conditionName,
&conditionContext,
&record.Ulid,
&record.InsertedAt,
)
t.mu.Unlock()
if err != nil {
return nil, t.handleSQLError(err)
}
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
return &record, nil
}
func (t *SQLTupleIterator) head(ctx context.Context) (*storage.TupleRecord, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.rows == nil {
if err := t.fetchBuffer(ctx); err != nil {
return nil, err
}
}
if t.firstRow != nil {
// If head was called previously, we don't need to scan or advance
// again: the data is already buffered, and the internal iterator was already advanced via `t.rows.Next()`.
// Calling t.rows.Next() again here would lose the first row.
//
// For example, say there are 3 items [1,2,3].
// When Head() was called while t.firstRow was empty, rows was left with only [2,3],
// so item [1] was saved in firstRow. This allows future next() and head() calls to return
// [1] first. Note that head() does not unset t.firstRow, so calling head() multiple times
// yields the same result.
// If head() was not called, t.firstRow is nil and we fall through to the t.rows.Next() logic below.
return t.firstRow, nil
}
if !t.rows.Next() {
if err := t.rows.Err(); err != nil {
return nil, t.handleSQLError(err)
}
return nil, storage.ErrIteratorDone
}
var conditionName sql.NullString
var conditionContext []byte
var record storage.TupleRecord
err := t.rows.Scan(
&record.Store,
&record.ObjectType,
&record.ObjectID,
&record.Relation,
&record.UserObjectType,
&record.UserObjectID,
&record.UserRelation,
&conditionName,
&conditionContext,
&record.Ulid,
&record.InsertedAt,
)
if err != nil {
return nil, t.handleSQLError(err)
}
record.ConditionName = conditionName.String
if conditionContext != nil {
var conditionContextStruct structpb.Struct
if err := proto.Unmarshal(conditionContext, &conditionContextStruct); err != nil {
return nil, err
}
record.ConditionContext = &conditionContextStruct
}
t.firstRow = &record
return &record, nil
}
// ToArray converts the tupleIterator to an []*openfgav1.Tuple and a possibly empty continuation token.
// If a continuation token is returned, it is the ULID of the record immediately after the last element
// of the returned array, i.e. the point at which the next page should resume.
func (t *SQLTupleIterator) ToArray(
ctx context.Context,
opts storage.PaginationOptions,
) ([]*openfgav1.Tuple, string, error) {
var res []*openfgav1.Tuple
for i := 0; i < opts.PageSize; i++ {
tupleRecord, err := t.next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
return res, "", nil
}
return nil, "", err
}
res = append(res, tupleRecord.AsTuple())
}
// Check if we are at the end of the iterator.
// If we are then we do not need to return a continuation token.
// This is why we have LIMIT+1 in the query.
tupleRecord, err := t.next(ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
return res, "", nil
}
return nil, "", err
}
return res, tupleRecord.Ulid, nil
}
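// Sketch (illustrative): with opts.PageSize = 2 over tuples [t1, t2, t3],
// ToArray returns ([t1, t2], ulid(t3)): the extra read both detects whether the
// iterator is exhausted and supplies the continuation point for the next page.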
// Next will return the next available item.
func (t *SQLTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
record, err := t.next(ctx)
if err != nil {
return nil, err
}
return record.AsTuple(), nil
}
// Head will return the first available item.
func (t *SQLTupleIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
record, err := t.head(ctx)
if err != nil {
return nil, err
}
return record.AsTuple(), nil
}
// Stop terminates iteration.
func (t *SQLTupleIterator) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
if t.rows != nil {
_ = t.rows.Close()
}
}
//go:generate mockgen -source storage.go -destination ../../internal/mocks/mock_storage.go -package mocks OpenFGADatastore
package storage
import (
"context"
"time"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
type ctxKey string
const (
// DefaultMaxTuplesPerWrite specifies the default maximum number of tuples that can be written
// in a single write operation. This constant is used to limit the batch size in write operations
// to maintain performance and avoid overloading the system. The value is set to 100 tuples,
// which is a balance between efficiency and resource usage.
DefaultMaxTuplesPerWrite = 100
// DefaultMaxTypesPerAuthorizationModel defines the default upper limit on the number of distinct
// types that can be included in a single authorization model. This constraint helps in managing
// the complexity and ensuring the maintainability of the authorization models. The limit is
// set to 100 types, providing ample flexibility while keeping the model manageable.
DefaultMaxTypesPerAuthorizationModel = 100
// DefaultPageSize sets the default number of items to be returned in a single page when paginating
// through a set of results. This constant is used to standardize the pagination size across various
// parts of the system, ensuring a consistent and manageable volume of data per page. The default
// value is set to 50, balancing detail per page with the overall number of pages.
DefaultPageSize = 50
relationshipTupleReaderCtxKey ctxKey = "relationship-tuple-reader-context-key"
)
// ContextWithRelationshipTupleReader sets the provided [RelationshipTupleReader]
// in the context. The context returned is a new context derived from the parent
// context provided.
func ContextWithRelationshipTupleReader(
parent context.Context,
reader RelationshipTupleReader,
) context.Context {
return context.WithValue(parent, relationshipTupleReaderCtxKey, reader)
}
// RelationshipTupleReaderFromContext extracts a [RelationshipTupleReader] from the
// provided context (if any). If no such value is in the context, a boolean false is
// returned; otherwise, the RelationshipTupleReader is returned.
func RelationshipTupleReaderFromContext(ctx context.Context) (RelationshipTupleReader, bool) {
ctxValue := ctx.Value(relationshipTupleReaderCtxKey)
reader, ok := ctxValue.(RelationshipTupleReader)
return reader, ok
}
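// exampleRelationshipTupleReaderContext is an illustrative sketch (hypothetical,
// not part of the original file) showing the round trip through a context.
func exampleRelationshipTupleReaderContext(ctx context.Context, reader RelationshipTupleReader) {
	ctx = ContextWithRelationshipTupleReader(ctx, reader)
	if r, ok := RelationshipTupleReaderFromContext(ctx); ok {
		_ = r // callers use the request-scoped reader instead of the default one
	}
}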
// PaginationOptions should not be instantiated directly. Use NewPaginationOptions.
type PaginationOptions struct {
PageSize int
// From is a continuation token that can be used to retrieve the next page of results. Its contents will depend on the API.
From string
}
// NewPaginationOptions creates a new [PaginationOptions] instance
// with a specified page size and continuation token. If the input page size is empty,
// it uses DefaultPageSize.
// The continuation token is used to retrieve the next page of results, OR the first page based on start time.
func NewPaginationOptions(ps int32, contToken string) PaginationOptions {
pageSize := DefaultPageSize
if ps > 0 {
pageSize = int(ps)
}
return PaginationOptions{
PageSize: pageSize,
From: contToken,
}
}
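// exampleNewPaginationOptionsUsage is an illustrative sketch (hypothetical, not
// part of the original file) showing how defaults are applied.
func exampleNewPaginationOptionsUsage() {
	opts := NewPaginationOptions(0, "token")
	_ = opts.PageSize // == DefaultPageSize (50), since the requested size was not positive
	_ = opts.From     // == "token"
}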
// ReadAuthorizationModelsOptions represents the options that can
// be used with the ReadAuthorizationModels method.
type ReadAuthorizationModelsOptions struct {
Pagination PaginationOptions
}
// ListStoresOptions represents the options that can
// be used with the ListStores method.
type ListStoresOptions struct {
// IDs is a list of store IDs to filter the results.
IDs []string
// Name is used to filter the results. If left empty no filter is applied.
Name string
Pagination PaginationOptions
}
// ReadChangesOptions represents the options that can
// be used with the ReadChanges method.
type ReadChangesOptions struct {
Pagination PaginationOptions
SortDesc bool
}
// ReadPageOptions represents the options that can
// be used with the ReadPage method.
type ReadPageOptions struct {
Pagination PaginationOptions
Consistency ConsistencyOptions
}
// ConsistencyOptions represents the options that can
// be used for methods that accept a consistency preference.
type ConsistencyOptions struct {
Preference openfgav1.ConsistencyPreference
}
// ReadOptions represents the options that can
// be used with the Read method.
type ReadOptions struct {
Consistency ConsistencyOptions
}
// ReadUserTupleOptions represents the options that can
// be used with the ReadUserTuple method.
type ReadUserTupleOptions struct {
Consistency ConsistencyOptions
}
// ReadUsersetTuplesOptions represents the options that can
// be used with the ReadUsersetTuples method.
type ReadUsersetTuplesOptions struct {
Consistency ConsistencyOptions
}
// ReadStartingWithUserOptions represents the options that can
// be used with the ReadStartingWithUser method.
type ReadStartingWithUserOptions struct {
Consistency ConsistencyOptions
WithResultsSortedAscending bool
}
// Writes is a typesafe alias for Write arguments.
type Writes = []*openfgav1.TupleKey
// Deletes is a typesafe alias for Delete arguments.
type Deletes = []*openfgav1.TupleKeyWithoutCondition
// A TupleBackend provides a read/write interface for managing tuples.
type TupleBackend interface {
RelationshipTupleReader
RelationshipTupleWriter
}
// RelationshipTupleReader is an interface that defines the set of
// methods required to read relationship tuples from a data store.
type RelationshipTupleReader interface {
// Read returns an iterator over the set of tuples in `store` that match `filter`, which may be empty or
// partially filled. If the filter is empty, Read returns an iterator over all the tuples in the given
// `store`. If the filter is partially filled, it returns an iterator over those tuples that match the
// filter. Note that at least one of `Object` or `User` (or both) must be specified in this case.
//
// The caller must be careful to close the [TupleIterator], either by consuming the entire iterator or by calling Stop().
// There is NO guarantee on the order of the tuples returned on the iterator.
Read(ctx context.Context, store string, filter ReadFilter, options ReadOptions) (TupleIterator, error)
// ReadPage functions similarly to Read but includes support for pagination. It takes
// mandatory ReadPageOptions options. PageSize will always be greater than zero.
// It returns a slice of tuples along with a continuation token. This token can be used for retrieving subsequent pages of data.
// There is NO guarantee on the order of the tuples in one page.
ReadPage(ctx context.Context, store string, filter ReadFilter, options ReadPageOptions) ([]*openfgav1.Tuple, string, error)
// ReadUserTuple tries to return one tuple that matches the provided key exactly.
// If none is found, it must return [ErrNotFound].
ReadUserTuple(
ctx context.Context,
store string,
tupleKey *openfgav1.TupleKey,
options ReadUserTupleOptions,
) (*openfgav1.Tuple, error)
// ReadUsersetTuples returns all userset tuples for a specified object and relation.
// For example, given the following relationship tuples:
// document:doc1, viewer, user:*
// document:doc1, viewer, group:eng#member
// and the filter
// object=document:doc1, relation=viewer, allowedTypesForUser=[group#member]
// this method would return the tuple (document:doc1, viewer, group:eng#member)
// If allowedTypesForUser is empty, both tuples would be returned.
// There is NO guarantee on the order returned on the iterator.
ReadUsersetTuples(
ctx context.Context,
store string,
filter ReadUsersetTuplesFilter,
options ReadUsersetTuplesOptions,
) (TupleIterator, error)
// ReadStartingWithUser performs a reverse read of relationship tuples starting at one or
// more user(s) or userset(s) and filtered by object type and relation and possibly a list of object IDs.
//
// For example, given the following relationship tuples:
// document:doc1, viewer, user:jon
// document:doc2, viewer, group:eng#member
// document:doc3, editor, user:jon
// document:doc4, viewer, group:eng#member
//
// ReadStartingWithUser for ['user:jon', 'group:eng#member'] filtered by 'document#viewer'
// and 'document:doc1, document:doc2' would
// return ['document:doc1#viewer@user:jon', 'document:doc2#viewer@group:eng#member'].
// If ReadStartingWithUserOptions.WithResultsSortedAscending bool is enabled, the tuples returned must be sorted by one or more fields in them.
ReadStartingWithUser(
ctx context.Context,
store string,
filter ReadStartingWithUserFilter,
options ReadStartingWithUserOptions,
) (TupleIterator, error)
}
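// exampleReadStartingWithUserFilter is an illustrative sketch (hypothetical, not
// part of the original file): the filter for the ReadStartingWithUser example
// documented above, matching 'user:jon' and the userset 'group:eng#member' on
// 'document#viewer'.
func exampleReadStartingWithUserFilter() ReadStartingWithUserFilter {
	return ReadStartingWithUserFilter{
		ObjectType: "document",
		Relation:   "viewer",
		UserFilter: []*openfgav1.ObjectRelation{
			{Object: "user:jon"},
			{Object: "group:eng", Relation: "member"},
		},
	}
}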
// OnMissingDelete defines the behavior of delete operation when the tuple to be deleted does not exist.
type OnMissingDelete int32
// OnDuplicateInsert defines the behavior of insert operation when the tuple to be inserted already exists.
type OnDuplicateInsert int32
const (
// OnMissingDeleteError indicates that if a delete operation is attempted on a tuple that does
// not exist, an error should be returned.
OnMissingDeleteError OnMissingDelete = 0
// OnMissingDeleteIgnore indicates that if a delete operation is attempted on a tuple that does
// not exist, it should be ignored as no-op and no error should be returned.
OnMissingDeleteIgnore OnMissingDelete = 1
// OnDuplicateInsertError indicates that if an insert operation is attempted on a tuple that already exists,
// an error should be returned.
OnDuplicateInsertError OnDuplicateInsert = 0
// OnDuplicateInsertIgnore indicates that if an insert operation is attempted on a tuple that already exists,
// it should be ignored as a no-op and no error should be returned.
OnDuplicateInsertIgnore OnDuplicateInsert = 1
)
// TupleWriteOptions defines the options that can be used when writing tuples.
// It allows customization of the behavior when a delete operation is attempted on a tuple that does not
// exist, or when an insert operation is attempted on a tuple that already exists.
type TupleWriteOptions struct {
OnMissingDelete OnMissingDelete
OnDuplicateInsert OnDuplicateInsert
}
type TupleWriteOption func(*TupleWriteOptions)
func WithOnMissingDelete(onMissingDelete OnMissingDelete) TupleWriteOption {
return func(opts *TupleWriteOptions) {
opts.OnMissingDelete = onMissingDelete
}
}
func WithOnDuplicateInsert(onDuplicateInsert OnDuplicateInsert) TupleWriteOption {
return func(opts *TupleWriteOptions) {
opts.OnDuplicateInsert = onDuplicateInsert
}
}
func NewTupleWriteOptions(opts ...TupleWriteOption) TupleWriteOptions {
res := TupleWriteOptions{
OnMissingDelete: OnMissingDeleteError,
OnDuplicateInsert: OnDuplicateInsertError,
}
for _, opt := range opts {
opt(&res)
}
return res
}
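// exampleNewTupleWriteOptionsUsage is an illustrative sketch (hypothetical, not
// part of the original file): unset options keep the error-returning defaults.
func exampleNewTupleWriteOptionsUsage() {
	opts := NewTupleWriteOptions(
		WithOnMissingDelete(OnMissingDeleteIgnore),
	)
	_ = opts.OnMissingDelete   // OnMissingDeleteIgnore (overridden)
	_ = opts.OnDuplicateInsert // OnDuplicateInsertError (default)
}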
// RelationshipTupleWriter is an interface that defines the set of methods
// required for writing relationship tuples in a data store.
type RelationshipTupleWriter interface {
// Write updates data in the tuple backend, performing all delete operations in
// `deletes` before adding new values in `writes`.
// It must also write to the changelog.
// If two concurrent requests attempt to write the same tuple at the same time, it must return ErrTransactionalWriteFailed. TODO write test
// If the tuple to be written already existed or the tuple to be deleted didn't exist, it must return InvalidWriteInputError. TODO write test
// opts are optional and can be used to customize the behavior of the write operation.
Write(ctx context.Context, store string, d Deletes, w Writes, opts ...TupleWriteOption) error
// MaxTuplesPerWrite returns the maximum number of items (writes and deletes combined)
// allowed in a single write transaction.
MaxTuplesPerWrite() int
}
// ReadStartingWithUserFilter specifies the filter options that will be used
// to constrain the [RelationshipTupleReader.ReadStartingWithUser] query.
type ReadStartingWithUserFilter struct {
// Mandatory.
ObjectType string
// Mandatory.
Relation string
// Mandatory.
UserFilter []*openfgav1.ObjectRelation
// Optional. It can be nil. If present, it will be sorted in ascending order.
// The datastore should return the intersection between this filter and what is in the database.
ObjectIDs SortedSet
// Optional. It can be nil. If present, it will be used to filter the results; the list may include the empty string to match tuples without a condition.
Conditions []string
}
// ReadFilter specifies the filter options that will be used
// to constrain the [RelationshipTupleReader.Read] query.
type ReadFilter struct {
// Mandatory.
Object string
// Mandatory.
Relation string
// Mandatory.
User string
// Optional. It can be nil. If present, it will be used to filter the results; the list may include the empty string to match tuples without a condition.
Conditions []string
}
// ReadUsersetTuplesFilter specifies the filter options that
// will be used to constrain the ReadUsersetTuples query.
type ReadUsersetTuplesFilter struct {
Object string // Required.
Relation string // Required.
AllowedUserTypeRestrictions []*openfgav1.RelationReference // Optional.
Conditions []string // Optional. It can be nil. If present, it will be used to filter the results; the list may include the empty string to match tuples without a condition.
}
// AuthorizationModelReadBackend provides a read interface for managing type definitions.
type AuthorizationModelReadBackend interface {
// ReadAuthorizationModel reads the model corresponding to store and model ID.
// If it's not found, or if the model has zero types, it must return ErrNotFound.
ReadAuthorizationModel(ctx context.Context, store string, id string) (*openfgav1.AuthorizationModel, error)
// ReadAuthorizationModels reads all models for the supplied store and returns them in descending order of ULID (from newest to oldest).
// In addition to the models, it returns a continuation token that can be used to fetch the next page of results.
ReadAuthorizationModels(ctx context.Context, store string, options ReadAuthorizationModelsOptions) ([]*openfgav1.AuthorizationModel, string, error)
// FindLatestAuthorizationModel returns the last model for the store.
// If none were ever written, it must return ErrNotFound.
FindLatestAuthorizationModel(ctx context.Context, store string) (*openfgav1.AuthorizationModel, error)
}
// TypeDefinitionWriteBackend provides a write interface for managing type definitions.
type TypeDefinitionWriteBackend interface {
// MaxTypesPerAuthorizationModel returns the maximum number of type definition rows/items per model.
MaxTypesPerAuthorizationModel() int
// WriteAuthorizationModel writes an authorization model for the given store.
// If the model has zero types, the datastore may choose to do nothing and return no error.
WriteAuthorizationModel(ctx context.Context, store string, model *openfgav1.AuthorizationModel) error
}
// AuthorizationModelBackend provides a read/write interface for managing models and their type definitions.
type AuthorizationModelBackend interface {
AuthorizationModelReadBackend
TypeDefinitionWriteBackend
}
type StoresBackend interface {
// CreateStore must return an error if the store ID or the name aren't set. TODO write test.
// If the store ID already existed it must return ErrCollision.
CreateStore(ctx context.Context, store *openfgav1.Store) (*openfgav1.Store, error)
// DeleteStore must delete the store by either setting its DeletedAt field or removing the entry.
DeleteStore(ctx context.Context, id string) error
// GetStore must return ErrNotFound if the store is not found or its DeletedAt is set.
GetStore(ctx context.Context, id string) (*openfgav1.Store, error)
// ListStores returns a list of non-deleted stores that match the provided options.
// In addition to the stores, it returns a continuation token that can be used to fetch the next page of results.
// If no stores are found, it is expected to return an empty list and an empty continuation token.
ListStores(ctx context.Context, options ListStoresOptions) ([]*openfgav1.Store, string, error)
}
// AssertionsBackend is an interface that defines the set of methods for reading and writing assertions.
type AssertionsBackend interface {
// WriteAssertions overwrites the assertions for a store and modelID.
WriteAssertions(ctx context.Context, store, modelID string, assertions []*openfgav1.Assertion) error
// ReadAssertions returns the assertions for a store and modelID.
// If no assertions were ever written, it must return an empty list.
ReadAssertions(ctx context.Context, store, modelID string) ([]*openfgav1.Assertion, error)
}
type ReadChangesFilter struct {
ObjectType string
HorizonOffset time.Duration
}
// ChangelogBackend is an interface for interacting with and managing changelogs.
type ChangelogBackend interface {
// ReadChanges returns the writes and deletes that have occurred for tuples within a store,
// in the order that they occurred.
// You can optionally provide a filter to filter out changes for objects of a specific type.
// The horizonOffset should be specified using a unit no more granular than a millisecond.
// It should always return a ULID as a continuation token so readers can resume reading later, except
// when no changes are found, in which case it should return storage.ErrNotFound and an empty continuation token.
// It is important that the continuation token is a ULID so that it can be generated from a timestamp.
ReadChanges(ctx context.Context, store string, filter ReadChangesFilter, options ReadChangesOptions) ([]*openfgav1.TupleChange, string, error)
}
// OpenFGADatastore is an interface that defines a set of methods for interacting
// with and managing data in an OpenFGA (Fine-Grained Authorization) system.
type OpenFGADatastore interface {
TupleBackend
AuthorizationModelBackend
StoresBackend
AssertionsBackend
ChangelogBackend
// IsReady reports whether the datastore is ready to accept traffic.
IsReady(ctx context.Context) (ReadinessStatus, error)
// Close closes the datastore and cleans up any residual resources.
Close()
}
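// Implementations conventionally add a compile-time assertion in their own
// package, e.g. (illustrative sketch for a hypothetical Datastore type):
//
//	var _ storage.OpenFGADatastore = (*Datastore)(nil)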
// ReadinessStatus represents the readiness status of the datastore.
type ReadinessStatus struct {
// Message is a human-friendly status message for the current datastore status.
Message string
IsReady bool
}
package storagewrappers
import (
"context"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers/storagewrappersutil"
)
const timeWaitingAttribute = "datastore_time_waiting"
const concurrentTimeWaitingThreshold = 1 * time.Millisecond
type StorageInstrumentation interface {
GetMetadata() Metadata
}
type Metadata struct {
DatastoreQueryCount uint32
DatastoreItemCount uint64
WasThrottled bool
}
type countingTupleIterator struct {
storage.TupleIterator
counter *atomic.Uint64
}
func (itr *countingTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
i, err := itr.TupleIterator.Next(ctx)
if err != nil {
return i, err
}
itr.counter.Add(1)
return i, nil
}
var (
_ storage.RelationshipTupleReader = (*BoundedTupleReader)(nil)
_ StorageInstrumentation = (*BoundedTupleReader)(nil)
_ storage.TupleIterator = (*countingTupleIterator)(nil)
concurrentReadDelayMsHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "datastore_bounded_read_delay_ms",
Help: "Time spent waiting for any relevant Tuple read calls to the datastore",
Buckets: []float64{1, 3, 5, 10, 25, 50, 100, 1000, 5000}, // Milliseconds. Upper bound is config.UpstreamTimeout.
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"operation", "method"})
throttledReadDelayMsHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "datastore_throttled_read_delay_ms",
Help: "Time spent waiting for any relevant Tuple read calls to the datastore",
Buckets: []float64{1, 3, 5, 10, 25, 50, 100, 1000, 5000}, // Milliseconds. Upper bound is config.UpstreamTimeout.
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"operation", "method"})
)
type BoundedTupleReader struct {
storage.RelationshipTupleReader
limiter chan struct{} // bound concurrency
countReads atomic.Uint32
countItems atomic.Uint64
method string
throttlingEnabled bool
threshold int
throttleTime time.Duration
throttled atomic.Bool
}
// NewBoundedTupleReader returns a wrapper over a datastore that makes sure that there are, at most,
// "concurrency" concurrent calls to Read, ReadUserTuple, ReadUsersetTuples, and ReadStartingWithUser.
// Consumers can then rest assured that one client will not hoard all the database connections available.
func NewBoundedTupleReader(wrapped storage.RelationshipTupleReader, op *Operation) *BoundedTupleReader {
return &BoundedTupleReader{
RelationshipTupleReader: wrapped,
limiter: make(chan struct{}, op.Concurrency),
countReads: atomic.Uint32{},
method: string(op.Method),
throttlingEnabled: op.ThrottlingEnabled,
threshold: op.ThrottleThreshold,
throttleTime: op.ThrottleDuration,
}
}
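// Usage sketch (illustrative, with assumed values; the Operation fields are
// inferred from their use above): bound a request to 10 concurrent reads and
// throttle each read after the first 100:
//
//	reader := NewBoundedTupleReader(ds, &Operation{
//		Concurrency:       10,
//		ThrottlingEnabled: true,
//		ThrottleThreshold: 100,
//		ThrottleDuration:  3 * time.Millisecond,
//	})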
func (b *BoundedTupleReader) GetMetadata() Metadata {
return Metadata{
DatastoreQueryCount: b.countReads.Load(),
DatastoreItemCount: b.countItems.Load(),
WasThrottled: b.throttled.Load(),
}
}
// ReadUserTuple tries to return one tuple that matches the provided key exactly.
func (b *BoundedTupleReader) ReadUserTuple(
ctx context.Context,
store string,
tupleKey *openfgav1.TupleKey,
options storage.ReadUserTupleOptions,
) (*openfgav1.Tuple, error) {
err := b.bound(ctx, storagewrappersutil.OperationReadUserTuple)
if err != nil {
return nil, err
}
defer b.done()
t, err := b.RelationshipTupleReader.ReadUserTuple(ctx, store, tupleKey, options)
if t == nil || err != nil {
return t, err
}
b.countItems.Add(1)
return t, nil
}
// Read returns an iterator over the set of tuples in `store` that match `filter`, which may be empty or partially filled.
func (b *BoundedTupleReader) Read(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadOptions) (storage.TupleIterator, error) {
err := b.bound(ctx, storagewrappersutil.OperationRead)
if err != nil {
return nil, err
}
defer b.done()
itr, err := b.RelationshipTupleReader.Read(ctx, store, filter, options)
if itr == nil || err != nil {
return itr, err
}
return &countingTupleIterator{itr, &b.countItems}, nil
}
// ReadUsersetTuples returns all userset tuples for a specified object and relation.
func (b *BoundedTupleReader) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
options storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
err := b.bound(ctx, storagewrappersutil.OperationReadUsersetTuples)
if err != nil {
return nil, err
}
defer b.done()
itr, err := b.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
if itr == nil || err != nil {
return itr, err
}
return &countingTupleIterator{itr, &b.countItems}, nil
}
// ReadStartingWithUser performs a reverse read of relationship tuples starting at one or
// more user(s) or userset(s) and filtered by object type and relation.
func (b *BoundedTupleReader) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
err := b.bound(ctx, storagewrappersutil.OperationReadStartingWithUser)
if err != nil {
return nil, err
}
defer b.done()
itr, err := b.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
if itr == nil || err != nil {
return itr, err
}
return &countingTupleIterator{itr, &b.countItems}, nil
}
func (b *BoundedTupleReader) instrument(ctx context.Context, op string, d time.Duration, vec *prometheus.HistogramVec) {
vec.WithLabelValues(op, b.method).Observe(float64(d))
span := trace.SpanFromContext(ctx)
span.SetAttributes(attribute.Int64(timeWaitingAttribute, d.Milliseconds()))
}
// bound allows the request at most a fixed number of concurrent accesses to the downstream datastore.
// After a threshold of accesses has been granted, an artificial amount of latency is added to each further access.
func (b *BoundedTupleReader) bound(ctx context.Context, op string) error {
startTime := time.Now()
if err := b.waitForLimiter(ctx); err != nil {
return err
}
if c := time.Since(startTime); c > concurrentTimeWaitingThreshold {
b.instrument(ctx, op, c, concurrentReadDelayMsHistogram)
}
reads := b.increaseReads()
if b.throttlingEnabled && b.threshold > 0 && reads > b.threshold {
b.throttled.Store(true)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(b.throttleTime):
break
}
b.instrument(ctx, op, time.Since(startTime), throttledReadDelayMsHistogram)
}
return nil
}
// waitForLimiter respects context errors and returns an error only if it couldn't send an item to the channel.
func (b *BoundedTupleReader) waitForLimiter(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case b.limiter <- struct{}{}:
break
}
return nil
}
func (b *BoundedTupleReader) done() {
select {
case <-b.limiter:
default:
}
}
func (b *BoundedTupleReader) increaseReads() int {
return int(b.countReads.Add(1))
}
package storagewrappers
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"golang.org/x/sync/singleflight"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers/storagewrappersutil"
"github.com/openfga/openfga/pkg/tuple"
)
var (
tracer = otel.Tracer("openfga/pkg/storagewrappers/cached_datastore")
_ storage.RelationshipTupleReader = (*CachedDatastore)(nil)
tuplesCacheTotalCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "tuples_cache_total_count",
Help: "The total number of created cached iterator instances.",
}, []string{"operation", "method"})
tuplesCacheHitCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "tuples_cache_hit_count",
Help: "The total number of cache hits from cached iterator instances.",
}, []string{"operation", "method"})
tuplesCacheDiscardCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "tuples_cache_discard_count",
Help: "The total number of discards from cached iterator instances.",
}, []string{"operation", "method"})
tuplesCacheSizeHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "tuples_cache_size",
Help: "The number of tuples cached.",
Buckets: []float64{0, 1, 10, 100, 1000, 5000, 10000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"operation", "method"})
currentIteratorCacheCount = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: build.ProjectName,
Name: "current_iterator_cache_count",
Help: "The current number of items of cache iterator.",
}, []string{"cached"})
)
// iterFunc is a function closure that returns an iterator
// from the underlying datastore.
type iterFunc func(ctx context.Context) (storage.TupleIterator, error)
type CachedDatastoreOpt func(*CachedDatastore)
// WithCachedDatastoreLogger sets the logger for the CachedDatastore.
func WithCachedDatastoreLogger(logger logger.Logger) CachedDatastoreOpt {
return func(b *CachedDatastore) {
b.logger = logger
}
}
// WithCachedDatastoreMethodName sets the method name used to differentiate metrics, e.g. whether the caller was Check or ListObjects.
func WithCachedDatastoreMethodName(method string) CachedDatastoreOpt {
return func(b *CachedDatastore) {
b.method = method
}
}
// CachedDatastore is a wrapper over a datastore that caches iterators in memory.
type CachedDatastore struct {
storage.RelationshipTupleReader
ctx context.Context
cache storage.InMemoryCache[any]
maxResultSize int
ttl time.Duration
// sf is used to prevent draining the same iterator
// across multiple requests.
sf *singleflight.Group
// wg is used to synchronize inflight goroutines from underlying
// cached iterators.
wg *sync.WaitGroup
logger logger.Logger
method string // Whether this datastore is for Check or ListObjects
}
// NewCachedDatastore returns a wrapper over a datastore that caches iterators in memory.
func NewCachedDatastore(
ctx context.Context,
inner storage.RelationshipTupleReader,
cache storage.InMemoryCache[any],
maxSize int,
ttl time.Duration,
sf *singleflight.Group,
wg *sync.WaitGroup,
opts ...CachedDatastoreOpt,
) *CachedDatastore {
c := &CachedDatastore{
ctx: ctx,
RelationshipTupleReader: inner,
cache: cache,
maxResultSize: maxSize,
ttl: ttl,
sf: sf,
wg: wg,
logger: logger.NewNoopLogger(),
method: "",
}
for _, opt := range opts {
opt(c)
}
return c
}
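// Usage sketch (illustrative, with assumed values): wrap a datastore so that
// iterator results for a Check request are cached for 10 seconds:
//
//	var sf singleflight.Group
//	var wg sync.WaitGroup
//	cds := NewCachedDatastore(serverCtx, ds, cache, 1000, 10*time.Second, &sf, &wg,
//		WithCachedDatastoreMethodName("Check"))
//	// ... serve requests ...
//	wg.Wait() // wait for background cache writes before shutting down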
func (c *CachedDatastore) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
ctx, span := tracer.Start(
ctx,
"cache.ReadStartingWithUser",
trace.WithAttributes(attribute.Bool("cached", false)),
)
defer span.End()
iter := func(ctx context.Context) (storage.TupleIterator, error) {
return c.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
}
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
return iter(ctx)
}
cacheKey, err := storagewrappersutil.ReadStartingWithUserKey(store, filter)
if err != nil {
return nil, err
}
// NOTE: There is no need to limit the length of this slice,
// since it will have at most 2 entries (the user, plus a wildcard where applicable).
subjects := make([]string, 0, len(filter.UserFilter))
for _, objectRel := range filter.UserFilter {
subject := objectRel.GetObject()
if objectRel.GetRelation() != "" {
subject = tuple.ToObjectRelationString(objectRel.GetObject(), objectRel.GetRelation())
}
subjects = append(subjects, subject)
}
return c.newCachedIteratorByUserObjectType(ctx, storagewrappersutil.OperationReadStartingWithUser, store, iter, cacheKey, subjects, filter.ObjectType)
}
// ReadUsersetTuples see [storage.RelationshipTupleReader].ReadUsersetTuples.
func (c *CachedDatastore) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
options storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
ctx, span := tracer.Start(
ctx,
"cache.ReadUsersetTuples",
trace.WithAttributes(attribute.Bool("cached", false)),
)
defer span.End()
iter := func(ctx context.Context) (storage.TupleIterator, error) {
return c.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
}
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
return iter(ctx)
}
return c.newCachedIteratorByObjectRelation(ctx,
storagewrappersutil.OperationReadUsersetTuples,
store,
iter,
storagewrappersutil.ReadUsersetTuplesKey(store, filter),
filter.Object,
filter.Relation)
}
// Read see [storage.RelationshipTupleReader].Read.
func (c *CachedDatastore) Read(
ctx context.Context,
store string,
filter storage.ReadFilter,
options storage.ReadOptions,
) (storage.TupleIterator, error) {
ctx, span := tracer.Start(
ctx,
"cache.Read",
trace.WithAttributes(attribute.Bool("cached", false)),
)
defer span.End()
iter := func(ctx context.Context) (storage.TupleIterator, error) {
return c.RelationshipTupleReader.Read(ctx, store, filter, options)
}
tupleKey := &openfgav1.TupleKey{
Object: filter.Object,
Relation: filter.Relation,
User: filter.User,
}
// this instance of Read is only called from the TTU resolution path, which always includes Object/Relation
if filter.Relation == "" || !tuple.IsValidObject(filter.Object) {
return iter(ctx)
}
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
return iter(ctx)
}
return c.newCachedIteratorByObjectRelation(ctx,
storagewrappersutil.OperationRead,
store,
iter,
storagewrappersutil.ReadKey(store, tupleKey),
tupleKey.GetObject(),
tupleKey.GetRelation())
}
func isInvalidAt(cache storage.InMemoryCache[any], ts time.Time, invalidStore string, invalidEntityKeys []string) bool {
if res := cache.Get(invalidStore); res != nil {
invalidEntry, ok := res.(*storage.InvalidEntityCacheEntry)
// the entry is stale only if the invalidation is newer than the entry's timestamp
if ok && ts.Before(invalidEntry.LastModified) {
return true
}
}
for _, invalidEntityKey := range invalidEntityKeys {
if res := cache.Get(invalidEntityKey); res != nil {
invalidEntry, ok := res.(*storage.InvalidEntityCacheEntry)
// the entry is stale only if the invalidation is newer than the entry's timestamp
if ok && ts.Before(invalidEntry.LastModified) {
return true
}
}
}
return false
}
// findInCache tries to find a key in the cache.
// It returns true if and only if:
// the key is present, and
// the cache key satisfies TS(key) >= TS(store), and
// all of the invalidEntityKeys satisfy TS(key) >= TS(invalid).
func findInCache(cache storage.InMemoryCache[any], key, storeKey string, invalidEntityKeys []string) (*storage.TupleIteratorCacheEntry, bool) {
var tupleEntry *storage.TupleIteratorCacheEntry
var ok bool
res := cache.Get(key)
if res == nil {
return nil, false
}
tupleEntry, ok = res.(*storage.TupleIteratorCacheEntry)
if !ok {
return nil, false
}
invalid := isInvalidAt(cache, tupleEntry.LastModified, storeKey, invalidEntityKeys)
if invalid {
cache.Delete(key)
return nil, false
}
return tupleEntry, true
}
func (c *CachedDatastore) newCachedIteratorByObjectRelation(
ctx context.Context,
operation string,
store string,
dsIterFunc iterFunc,
cacheKey string,
object string,
relation string,
) (storage.TupleIterator, error) {
objectType, objectID := tuple.SplitObject(object)
invalidEntityKey := storage.GetInvalidIteratorByObjectRelationCacheKey(store, object, relation)
return c.newCachedIterator(ctx, operation, store, dsIterFunc, cacheKey, []string{invalidEntityKey}, objectType, objectID, relation, "")
}
func (c *CachedDatastore) newCachedIteratorByUserObjectType(
ctx context.Context,
operation string,
store string,
dsIterFunc iterFunc,
cacheKey string,
users []string,
objectType string,
) (storage.TupleIterator, error) {
// If all users in the filter share the same type, that type can be omitted from the cached records.
var userType string
for _, user := range users {
userObjectType, _ := tuple.SplitObject(user)
if userType == "" {
userType = userObjectType
} else if userType != userObjectType {
userType = ""
break
}
}
invalidEntityKeys := storage.GetInvalidIteratorByUserObjectTypeCacheKeys(store, users, objectType)
return c.newCachedIterator(ctx, operation, store, dsIterFunc, cacheKey, invalidEntityKeys, objectType, "", "", userType)
}
// newCachedIterator either returns a cached static iterator for a cache hit, or
// returns a new iterator that attempts to cache the results.
func (c *CachedDatastore) newCachedIterator(
ctx context.Context,
operation string,
store string,
dsIterFunc iterFunc,
cacheKey string,
invalidEntityKeys []string,
objectType string,
objectID string,
relation string,
userType string,
) (storage.TupleIterator, error) {
span := trace.SpanFromContext(ctx)
span.SetAttributes(attribute.String("cache_key", cacheKey))
tuplesCacheTotalCounter.WithLabelValues(operation, c.method).Inc()
invalidStoreKey := storage.GetInvalidIteratorCacheKey(store)
if cacheEntry, ok := findInCache(c.cache, cacheKey, invalidStoreKey, invalidEntityKeys); ok {
tuplesCacheHitCounter.WithLabelValues(operation, c.method).Inc()
span.SetAttributes(attribute.Bool("cached", true))
staticIter := storage.NewStaticIterator[*storage.TupleRecord](cacheEntry.Tuples)
currentIteratorCacheCount.WithLabelValues("true").Inc()
return &cachedTupleIterator{
objectID: objectID,
objectType: objectType,
relation: relation,
userType: userType,
iter: staticIter,
}, nil
}
iter, err := dsIterFunc(ctx)
if err != nil {
return nil, err
}
currentIteratorCacheCount.WithLabelValues("false").Inc()
return &cachedIterator{
ctx: c.ctx,
iter: iter,
store: store,
operation: operation,
method: c.method,
// set an initial fractional capacity to balance reallocation cost against memory usage
tuples: make([]*openfgav1.Tuple, 0, c.maxResultSize/2),
cacheKey: cacheKey,
invalidStoreKey: invalidStoreKey,
invalidEntityKeys: invalidEntityKeys,
cache: c.cache,
maxResultSize: c.maxResultSize,
ttl: c.ttl,
initializedAt: time.Now(),
sf: c.sf,
objectType: objectType,
objectID: objectID,
relation: relation,
userType: userType,
wg: c.wg,
logger: c.logger,
}, nil
}
type cachedIterator struct {
ctx context.Context
iter storage.TupleIterator
store string
operation string
method string
cacheKey string
invalidStoreKey string
invalidEntityKeys []string
cache storage.InMemoryCache[any]
ttl time.Duration
initializedAt time.Time
objectID string
objectType string
relation string
userType string
// tuples is used to buffer tuples as they are read from `iter`.
tuples []*openfgav1.Tuple
// records is used to buffer tuples that might end up in cache.
records []*storage.TupleRecord
// maxResultSize is the maximum number of tuples to cache. If the number
// of tuples found exceeds this value, it will not be cached.
maxResultSize int
// sf is used to prevent draining the same iterator
// across multiple requests.
sf *singleflight.Group
// closing is used to synchronize `.Stop()` and ensure that the iterator
// stops producing tuples only once.
closing atomic.Bool
// mu is used to synchronize access to the iterator.
mu sync.Mutex
// wg is used to synchronize inflight goroutines spawned
// when stopping the iterator.
wg *sync.WaitGroup
logger logger.Logger
stopped bool
}
// Next will return the next available tuple from the underlying iterator and
// will attempt to add it to the buffer if the buffer is not yet full. To store the
// buffered tuples in the cache, you must call .Stop().
func (c *cachedIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closing.Load() {
return nil, storage.ErrIteratorDone
}
t, err := c.iter.Next(ctx)
if err != nil {
if !storage.IterIsDoneOrCancelled(err) {
c.tuples = nil // don't store results that are incomplete
}
return nil, err
}
if c.tuples != nil {
c.tuples = append(c.tuples, t)
if len(c.tuples) >= c.maxResultSize {
tuplesCacheDiscardCounter.WithLabelValues(c.operation, c.method).Inc()
c.tuples = nil // don't store results that are incomplete
}
}
return t, nil
}
// Stop terminates iteration over the underlying iterator.
// - If there are incomplete results, they will not be cached.
// - If the iterator is already fully consumed, it will be cached in the foreground.
// - If the iterator is not fully consumed, it will be drained in the background,
// and an attempt will be made to cache its results.
func (c *cachedIterator) Stop() {
c.mu.Lock()
defer c.mu.Unlock()
newStop := !c.stopped
c.stopped = true
if newStop {
currentIteratorCacheCount.WithLabelValues("false").Dec()
}
swapped := c.closing.CompareAndSwap(false, true)
if !swapped {
return
}
if c.tuples == nil || c.ctx.Err() != nil {
c.iter.Stop()
return
}
c.wg.Add(1)
go func() {
defer c.wg.Done()
defer c.iter.Stop()
// if cache is already set by another instance, we don't need to drain the iterator
_, ok := findInCache(c.cache, c.cacheKey, c.invalidStoreKey, c.invalidEntityKeys)
if ok {
c.iter.Stop()
c.tuples = nil
return
}
// if there was an invalidation _after_ the initialization, it shouldn't be stored
if isInvalidAt(c.cache, c.initializedAt, c.invalidStoreKey, c.invalidEntityKeys) {
c.iter.Stop()
c.tuples = nil
return
}
c.records = make([]*storage.TupleRecord, 0, len(c.tuples))
for _, t := range c.tuples {
c.addToBuffer(t)
}
// prevent goroutine if iterator was already consumed
if _, err := c.iter.Head(c.ctx); errors.Is(err, storage.ErrIteratorDone) {
c.flush()
return
}
// prevent draining on the same iterator across multiple requests
_, _, _ = c.sf.Do(c.cacheKey, func() (interface{}, error) {
for {
// attempt to drain the iterator to have it ready for subsequent calls
t, err := c.iter.Next(c.ctx)
if err != nil {
if errors.Is(err, storage.ErrIteratorDone) {
c.flush()
}
break
}
// if the size is exceeded we don't add anymore and exit
if !c.addToBuffer(t) {
break
}
}
return nil, nil
})
}()
}
// Head see [storage.Iterator].Head.
func (c *cachedIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closing.Load() {
return nil, storage.ErrIteratorDone
}
return c.iter.Head(ctx)
}
// addToBuffer converts a proto tuple into a simpler storage.TupleRecord, removes
// any already known fields and adds it to the buffer if not yet full.
func (c *cachedIterator) addToBuffer(t *openfgav1.Tuple) bool {
if c.tuples == nil {
return false
}
tk := t.GetKey()
object := tk.GetObject()
objectType, objectID := tuple.SplitObject(object)
userObjectType, userObjectID, userRelation := tuple.ToUserParts(tk.GetUser())
record := &storage.TupleRecord{
ObjectID: objectID,
ObjectType: objectType,
Relation: tk.GetRelation(),
UserObjectType: userObjectType,
UserObjectID: userObjectID,
UserRelation: userRelation,
}
if timestamp := t.GetTimestamp(); timestamp != nil {
record.InsertedAt = timestamp.AsTime()
}
if condition := tk.GetCondition(); condition != nil {
record.ConditionName = condition.GetName()
record.ConditionContext = condition.GetContext()
}
// Remove any fields that are duplicated and known by iterator
if c.objectID != "" && c.objectID == record.ObjectID {
record.ObjectID = ""
}
if c.objectType != "" && c.objectType == record.ObjectType {
record.ObjectType = ""
}
if c.relation != "" && c.relation == record.Relation {
record.Relation = ""
}
if c.userType != "" && c.userType == record.UserObjectType {
record.UserObjectType = ""
}
c.records = append(c.records, record)
if len(c.records) >= c.maxResultSize {
tuplesCacheDiscardCounter.WithLabelValues(c.operation, c.method).Inc()
c.tuples = nil
c.records = nil
}
return true
}
// flush stores the buffered records in the cache and deletes invalidEntityKeys from the cache.
func (c *cachedIterator) flush() {
if c.tuples == nil || c.ctx.Err() != nil {
c.logger.Debug("cachedIterator flush noop due to empty tuples or c.ctx.Err",
zap.String("key", c.cacheKey),
zap.Bool("nil_tuples", c.tuples == nil),
zap.Error(c.ctx.Err()))
return
}
// Hand the records buffer off to a local variable and clear the iterator's
// fields before storing it in the cache, so the iterator no longer holds
// pointers to the cached data. This should also help with garbage collection.
records := c.records
c.tuples = nil
c.records = nil
c.logger.Debug("cachedIterator flush and update cache for ", zap.String("cacheKey", c.cacheKey))
c.cache.Set(c.cacheKey, &storage.TupleIteratorCacheEntry{Tuples: records, LastModified: time.Now()}, c.ttl)
for _, k := range c.invalidEntityKeys {
c.cache.Delete(k)
}
tuplesCacheSizeHistogram.WithLabelValues(c.operation, c.method).Observe(float64(len(records)))
}
package storagewrappers
import (
"context"
"google.golang.org/protobuf/types/known/timestamppb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
// cachedTupleIterator is a wrapper around an iterator
// for a given object/relation.
type cachedTupleIterator struct {
objectID string
objectType string
relation string
userType string
iter storage.Iterator[*storage.TupleRecord]
stopped bool
}
var _ storage.TupleIterator = (*cachedTupleIterator)(nil)
// Next will return the next available minimal cached tuple
// as a well-formed [openfgav1.Tuple].
func (c *cachedTupleIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
t, err := c.iter.Next(ctx)
if err != nil {
return nil, err
}
return c.buildTuple(t), nil
}
// Stop see [storage.Iterator].Stop.
func (c *cachedTupleIterator) Stop() {
newStop := !c.stopped
c.stopped = true
c.iter.Stop()
if newStop {
currentIteratorCacheCount.WithLabelValues("true").Dec()
}
}
// Head will return the first minimal cached tuple of the iterator as
// a well-formed [openfgav1.Tuple].
func (c *cachedTupleIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
t, err := c.iter.Head(ctx)
if err != nil {
return nil, err
}
return c.buildTuple(t), nil
}
func (c *cachedTupleIterator) buildTuple(t *storage.TupleRecord) *openfgav1.Tuple {
objectType := t.ObjectType
objectID := t.ObjectID
relation := t.Relation
userType := t.UserObjectType
if c.objectType != "" {
objectType = c.objectType
}
if c.objectID != "" {
objectID = c.objectID
}
if c.relation != "" {
relation = c.relation
}
if c.userType != "" {
userType = c.userType
}
return &openfgav1.Tuple{
Key: tuple.NewTupleKeyWithCondition(
tuple.BuildObject(objectType, objectID),
relation,
tuple.FromUserParts(userType, t.UserObjectID, t.UserRelation),
t.ConditionName,
t.ConditionContext,
),
Timestamp: timestamppb.New(t.InsertedAt),
}
}
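// Rehydration sketch (illustrative): a record cached under the iterator for
// document:doc1#viewer may be stored with its ObjectType, ObjectID, and
// Relation stripped (and UserObjectType stripped when all users share a type);
// buildTuple re-applies those iterator-level values to produce a full tuple.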
package storagewrappers
import (
"context"
"slices"
"strings"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
// NewCombinedTupleReader returns a [storage.RelationshipTupleReader] that reads from
// a persistent datastore and from the contextual tuples specified in the request.
func NewCombinedTupleReader(
ds storage.RelationshipTupleReader,
contextualTuples []*openfgav1.TupleKey,
) *CombinedTupleReader {
ctr := &CombinedTupleReader{
RelationshipTupleReader: ds,
}
cu := make([]*openfgav1.TupleKey, len(contextualTuples))
for i, t := range contextualTuples {
cu[i] = tuple.NewTupleKeyWithCondition(t.GetObject(), t.GetRelation(), t.GetUser(), t.GetCondition().GetName(), t.GetCondition().GetContext())
}
slices.SortFunc(cu, func(a *openfgav1.TupleKey, b *openfgav1.TupleKey) int {
return strings.Compare(a.GetObject(), b.GetObject())
})
ctr.contextualTuplesOrderedByObjectID = cu
return ctr
}
type CombinedTupleReader struct {
storage.RelationshipTupleReader
contextualTuplesOrderedByObjectID []*openfgav1.TupleKey
}
var _ storage.RelationshipTupleReader = (*CombinedTupleReader)(nil)
// filterTuples filters the provided tuples, removing any that don't match the
// given target object, relation, or users; an empty target matches everything.
//
//nolint:unparam
func filterTuples(tuples []*openfgav1.TupleKey, targetObject, targetRelation string, targetUsers []string) []*openfgav1.Tuple {
var filtered []*openfgav1.Tuple
for _, tk := range tuples {
if (targetObject == "" || tk.GetObject() == targetObject) &&
(targetRelation == "" || tk.GetRelation() == targetRelation) &&
(len(targetUsers) == 0 || slices.Contains(targetUsers, tk.GetUser())) {
filtered = append(filtered, &openfgav1.Tuple{
Key: tk,
})
}
}
return filtered
}
// Read see [storage.RelationshipTupleReader.Read].
func (c *CombinedTupleReader) Read(
ctx context.Context,
storeID string,
filter storage.ReadFilter,
options storage.ReadOptions,
) (storage.TupleIterator, error) {
filteredTuples := filterTuples(c.contextualTuplesOrderedByObjectID, filter.Object, filter.Relation, []string{})
iter1 := storage.NewStaticTupleIterator(filteredTuples)
iter2, err := c.RelationshipTupleReader.Read(ctx, storeID, filter, options)
if err != nil {
return nil, err
}
return storage.NewCombinedIterator(iter1, iter2), nil
}
// ReadPage see [storage.RelationshipTupleReader.ReadPage].
func (c *CombinedTupleReader) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
// No reading from contextual tuples.
return c.RelationshipTupleReader.ReadPage(ctx, store, filter, options)
}
// ReadUserTuple see [storage.RelationshipTupleReader.ReadUserTuple].
func (c *CombinedTupleReader) ReadUserTuple(
ctx context.Context,
store string,
tk *openfgav1.TupleKey,
options storage.ReadUserTupleOptions,
) (*openfgav1.Tuple, error) {
targetUsers := []string{tk.GetUser()}
filteredContextualTuples := filterTuples(c.contextualTuplesOrderedByObjectID, tk.GetObject(), tk.GetRelation(), targetUsers)
for _, t := range filteredContextualTuples {
if t.GetKey().GetUser() == tk.GetUser() {
return t, nil
}
}
return c.RelationshipTupleReader.ReadUserTuple(ctx, store, tk, options)
}
func tupleMatchesAllowedUserTypeRestrictions(t *openfgav1.Tuple,
allowedUserTypeRestrictions []*openfgav1.RelationReference) bool {
tupleUser := t.GetKey().GetUser()
if tuple.GetUserTypeFromUser(tupleUser) != tuple.UserSet {
return false
}
// We expect allowedUserTypeRestrictions to always be present. If none is
// specified, the request itself is unexpected and the safe thing to do is
// to not return the contextual tuples.
for _, allowedUserType := range allowedUserTypeRestrictions {
if _, ok := allowedUserType.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
if tuple.IsTypedWildcard(tupleUser) && tuple.GetType(tupleUser) == allowedUserType.GetType() {
return true
}
}
if _, ok := allowedUserType.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
if tuple.IsObjectRelation(tupleUser) &&
tuple.GetType(tupleUser) == allowedUserType.GetType() &&
tuple.GetRelation(tupleUser) == allowedUserType.GetRelation() {
return true
}
}
}
return false
}
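// Example (illustrative sketch, not part of the original source): a
// hypothetical helper showing the two restriction shapes accepted by
// tupleMatchesAllowedUserTypeRestrictions. A tuple whose user is a userset
// ("group:eng#member") matches a relation restriction; a typed wildcard
// ("group:*") matches a wildcard restriction.
func exampleUserTypeRestrictions() (matchesRelation, matchesWildcard bool) {
	relationRestriction := &openfgav1.RelationReference{
		Type:               "group",
		RelationOrWildcard: &openfgav1.RelationReference_Relation{Relation: "member"},
	}
	wildcardRestriction := &openfgav1.RelationReference{
		Type:               "group",
		RelationOrWildcard: &openfgav1.RelationReference_Wildcard{Wildcard: &openfgav1.Wildcard{}},
	}
	usersetTuple := &openfgav1.Tuple{Key: tuple.NewTupleKey("document:1", "viewer", "group:eng#member")}
	wildcardTuple := &openfgav1.Tuple{Key: tuple.NewTupleKey("document:1", "viewer", "group:*")}
	matchesRelation = tupleMatchesAllowedUserTypeRestrictions(usersetTuple, []*openfgav1.RelationReference{relationRestriction}) // true
	matchesWildcard = tupleMatchesAllowedUserTypeRestrictions(wildcardTuple, []*openfgav1.RelationReference{wildcardRestriction}) // true
	return matchesRelation, matchesWildcard
}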
// ReadUsersetTuples see [storage.RelationshipTupleReader.ReadUsersetTuples].
func (c *CombinedTupleReader) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
options storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
var usersetTuples []*openfgav1.Tuple
for _, t := range filterTuples(c.contextualTuplesOrderedByObjectID, filter.Object, filter.Relation, []string{}) {
if tupleMatchesAllowedUserTypeRestrictions(t, filter.AllowedUserTypeRestrictions) {
usersetTuples = append(usersetTuples, t)
}
}
iter1 := storage.NewStaticTupleIterator(usersetTuples)
iter2, err := c.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
if err != nil {
return nil, err
}
return storage.NewCombinedIterator(iter1, iter2), nil
}
// ReadStartingWithUser see [storage.RelationshipTupleReader.ReadStartingWithUser].
func (c *CombinedTupleReader) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
var userFilters []string
for _, u := range filter.UserFilter {
uf := u.GetObject()
if u.GetRelation() != "" {
uf = tuple.ToObjectRelationString(uf, u.GetRelation())
}
userFilters = append(userFilters, uf)
}
filteredTuples := make([]*openfgav1.Tuple, 0, len(c.contextualTuplesOrderedByObjectID))
for _, t := range filterTuples(c.contextualTuplesOrderedByObjectID, "", filter.Relation, userFilters) {
if tuple.GetType(t.GetKey().GetObject()) != filter.ObjectType {
continue
}
filteredTuples = append(filteredTuples, t)
}
iter1 := storage.NewStaticTupleIterator(filteredTuples)
iter2, err := c.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
if err != nil {
return nil, err
}
if options.WithResultsSortedAscending {
// Note that both iter1 and iter2 return results sorted by object ID.
return storage.NewOrderedCombinedIterator(storage.ObjectMapper(), iter1, iter2), nil
}
return storage.NewCombinedIterator(iter1, iter2), nil
}
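// Example (illustrative sketch, not part of the original source): a
// hypothetical helper showing how CombinedTupleReader layers request-scoped
// contextual tuples on top of a persistent datastore read.
func exampleCombinedRead(ctx context.Context, inner storage.RelationshipTupleReader) error {
	contextual := []*openfgav1.TupleKey{
		tuple.NewTupleKey("document:budget", "viewer", "user:anne"),
	}
	reader := NewCombinedTupleReader(inner, contextual)
	// The returned iterator yields matching contextual tuples first, followed
	// by whatever the datastore returns for the same filter.
	iter, err := reader.Read(ctx, "store-id", storage.ReadFilter{
		Object:   "document:budget",
		Relation: "viewer",
	}, storage.ReadOptions{})
	if err != nil {
		return err
	}
	defer iter.Stop()
	return nil
}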
package storagewrappers
import (
"context"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
)
// ContextTracerWrapper is a wrapper for a datastore that introduces a new context to the underlying datastore methods.
// Its purpose is to prevent the closure of the underlying database connection in case the original context is cancelled,
// such as when a client cancels the context. This ensures that ongoing queries are allowed to complete even if the
// original context is cancelled, helping to avoid unnecessary database connection churn.
type ContextTracerWrapper struct {
storage.OpenFGADatastore
}
var _ storage.OpenFGADatastore = (*ContextTracerWrapper)(nil)
// NewContextWrapper creates a new instance of [ContextTracerWrapper], wrapping the specified datastore. It is crucial
// for [ContextTracerWrapper] to be the first wrapper around the datastore for traces to function correctly.
func NewContextWrapper(inner storage.OpenFGADatastore) *ContextTracerWrapper {
return &ContextTracerWrapper{inner}
}
// queryContext derives a new context that preserves the values of the
// provided context but is independent of its cancellation and deadline.
func queryContext(ctx context.Context) context.Context {
return context.WithoutCancel(ctx)
}
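// Example (illustrative sketch, not part of the original source): a
// hypothetical helper demonstrating the property queryContext relies on.
// Cancelling the parent context does not cancel the derived context, so
// in-flight queries can run to completion.
func exampleWithoutCancel() (parentErr, detachedErr error) {
	parent, cancel := context.WithCancel(context.Background())
	detached := queryContext(parent)
	cancel() // simulate the client cancelling the request
	// parent.Err() == context.Canceled, while detached.Err() == nil.
	return parent.Err(), detached.Err()
}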
// Close ensures proper cleanup and closure of resources associated with the OpenFGADatastore.
func (c *ContextTracerWrapper) Close() {
c.OpenFGADatastore.Close()
}
// Read see [storage.RelationshipTupleReader.Read].
func (c *ContextTracerWrapper) Read(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadOptions) (storage.TupleIterator, error) {
queryCtx := queryContext(ctx)
return c.OpenFGADatastore.Read(queryCtx, store, filter, options)
}
// ReadPage see [storage.RelationshipTupleReader.ReadPage].
func (c *ContextTracerWrapper) ReadPage(ctx context.Context, store string, filter storage.ReadFilter, options storage.ReadPageOptions) ([]*openfgav1.Tuple, string, error) {
queryCtx := queryContext(ctx)
return c.OpenFGADatastore.ReadPage(queryCtx, store, filter, options)
}
// ReadUserTuple see [storage.RelationshipTupleReader.ReadUserTuple].
func (c *ContextTracerWrapper) ReadUserTuple(ctx context.Context, store string, tupleKey *openfgav1.TupleKey, options storage.ReadUserTupleOptions) (*openfgav1.Tuple, error) {
queryCtx := queryContext(ctx)
return c.OpenFGADatastore.ReadUserTuple(queryCtx, store, tupleKey, options)
}
// ReadUsersetTuples see [storage.RelationshipTupleReader.ReadUsersetTuples].
func (c *ContextTracerWrapper) ReadUsersetTuples(ctx context.Context, store string, filter storage.ReadUsersetTuplesFilter, options storage.ReadUsersetTuplesOptions) (storage.TupleIterator, error) {
queryCtx := queryContext(ctx)
return c.OpenFGADatastore.ReadUsersetTuples(queryCtx, store, filter, options)
}
// ReadStartingWithUser see [storage.RelationshipTupleReader.ReadStartingWithUser].
func (c *ContextTracerWrapper) ReadStartingWithUser(ctx context.Context, store string, opts storage.ReadStartingWithUserFilter, options storage.ReadStartingWithUserOptions) (storage.TupleIterator, error) {
queryCtx := queryContext(ctx)
return c.OpenFGADatastore.ReadStartingWithUser(queryCtx, store, opts, options)
}
package storagewrappers
import (
"context"
"fmt"
"time"
"golang.org/x/sync/singleflight"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
)
// ttl is the lifetime of cached authorization models (one week).
const ttl = time.Hour * 168
var (
_ storage.OpenFGADatastore = (*cachedOpenFGADatastore)(nil)
_ storage.CacheItem = (*cachedAuthorizationModel)(nil)
)
type cachedAuthorizationModel struct {
*openfgav1.AuthorizationModel
}
func (c *cachedAuthorizationModel) CacheEntityType() string {
return "authz_model"
}
type cachedOpenFGADatastore struct {
storage.OpenFGADatastore
lookupGroup singleflight.Group
cache storage.InMemoryCache[*cachedAuthorizationModel]
}
// NewCachedOpenFGADatastore returns a wrapper over a datastore that caches up to maxSize
// [*openfgav1.AuthorizationModel] on every call to storage.ReadAuthorizationModel.
// Because models are immutable, entries are cached with a long TTL and evicted via LRU.
func NewCachedOpenFGADatastore(inner storage.OpenFGADatastore, maxSize int) (*cachedOpenFGADatastore, error) {
cache, err := storage.NewInMemoryLRUCache[*cachedAuthorizationModel](storage.WithMaxCacheSize[*cachedAuthorizationModel](int64(maxSize)))
if err != nil {
return nil, err
}
return &cachedOpenFGADatastore{
OpenFGADatastore: inner,
cache: *cache,
}, nil
}
// ReadAuthorizationModel reads the model corresponding to store and model ID.
func (c *cachedOpenFGADatastore) ReadAuthorizationModel(ctx context.Context, storeID, modelID string) (*openfgav1.AuthorizationModel, error) {
cacheKey := fmt.Sprintf("%s:%s", storeID, modelID)
cachedEntry := c.cache.Get(cacheKey)
if cachedEntry != nil {
return cachedEntry.AuthorizationModel, nil
}
model, err := c.OpenFGADatastore.ReadAuthorizationModel(ctx, storeID, modelID)
if err != nil {
return nil, err
}
c.cache.Set(cacheKey, &cachedAuthorizationModel{model}, ttl) // Models are immutable; once created they cannot be edited, so a long TTL is safe.
return model, nil
}
// FindLatestAuthorizationModel see [storage.AuthorizationModelReadBackend].FindLatestAuthorizationModel.
func (c *cachedOpenFGADatastore) FindLatestAuthorizationModel(ctx context.Context, storeID string) (*openfgav1.AuthorizationModel, error) {
v, err, _ := c.lookupGroup.Do("FindLatestAuthorizationModel:"+storeID, func() (interface{}, error) {
return c.OpenFGADatastore.FindLatestAuthorizationModel(ctx, storeID)
})
if err != nil {
return nil, err
}
return v.(*openfgav1.AuthorizationModel), nil
}
// Close closes the datastore and cleans up any residual resources.
func (c *cachedOpenFGADatastore) Close() {
c.cache.Stop()
c.OpenFGADatastore.Close()
}
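// Example (illustrative sketch, not part of the original source): a
// hypothetical helper wrapping an inner datastore with the model cache. The
// first ReadAuthorizationModel call hits the inner store; the second is
// served from the in-memory LRU because models are immutable.
func exampleCachedModelReads(ctx context.Context, inner storage.OpenFGADatastore, storeID, modelID string) error {
	ds, err := NewCachedOpenFGADatastore(inner, 100 /* max cached models */)
	if err != nil {
		return err
	}
	defer ds.Close()
	if _, err := ds.ReadAuthorizationModel(ctx, storeID, modelID); err != nil { // cache miss
		return err
	}
	_, err = ds.ReadAuthorizationModel(ctx, storeID, modelID) // cache hit
	return err
}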
package storagewrappers
import (
"time"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/shared"
"github.com/openfga/openfga/internal/utils/apimethod"
"github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers/sharediterator"
)
type OperationType int
type Operation struct {
Method apimethod.APIMethod
Concurrency uint32
ThrottlingEnabled bool
ThrottleThreshold int
ThrottleDuration time.Duration
}
// RequestStorageWrapper uses the decorator pattern to wrap a RelationshipTupleReader with various
// functionality, including exposing metrics.
type RequestStorageWrapper struct {
storage.RelationshipTupleReader
StorageInstrumentation
}
type DataResourceConfiguration struct {
Resources *shared.SharedDatastoreResources
CacheSettings config.CacheSettings
UseShadowCache bool
}
var _ StorageInstrumentation = (*RequestStorageWrapper)(nil)
// NewRequestStorageWrapperWithCache wraps the existing datastore to enable caching of iterators.
func NewRequestStorageWrapperWithCache(
ds storage.RelationshipTupleReader,
requestContextualTuples []*openfgav1.TupleKey,
op *Operation,
dataResourceConfiguration DataResourceConfiguration,
) *RequestStorageWrapper {
instrumented := NewBoundedTupleReader(ds, op) // to rate-limit reads
var tupleReader storage.RelationshipTupleReader
tupleReader = instrumented
if op.Method == apimethod.Check && dataResourceConfiguration.CacheSettings.ShouldCacheCheckIterators() {
// Reads tuples from cache where possible
tupleReader = NewCachedDatastore(
dataResourceConfiguration.Resources.ServerCtx,
tupleReader,
dataResourceConfiguration.Resources.CheckCache,
int(dataResourceConfiguration.CacheSettings.CheckIteratorCacheMaxResults),
dataResourceConfiguration.CacheSettings.CheckIteratorCacheTTL,
dataResourceConfiguration.Resources.SingleflightGroup,
dataResourceConfiguration.Resources.WaitGroup,
WithCachedDatastoreLogger(dataResourceConfiguration.Resources.Logger),
WithCachedDatastoreMethodName(string(op.Method)),
)
} else if op.Method == apimethod.ListObjects && dataResourceConfiguration.CacheSettings.ShouldCacheListObjectsIterators() {
checkCache := dataResourceConfiguration.Resources.CheckCache
if dataResourceConfiguration.UseShadowCache {
checkCache = dataResourceConfiguration.Resources.ShadowCheckCache
}
tupleReader = NewCachedDatastore(
dataResourceConfiguration.Resources.ServerCtx,
tupleReader,
checkCache,
int(dataResourceConfiguration.CacheSettings.ListObjectsIteratorCacheMaxResults),
dataResourceConfiguration.CacheSettings.ListObjectsIteratorCacheTTL,
dataResourceConfiguration.Resources.SingleflightGroup,
dataResourceConfiguration.Resources.WaitGroup,
WithCachedDatastoreLogger(dataResourceConfiguration.Resources.Logger),
WithCachedDatastoreMethodName(string(op.Method)),
)
}
if dataResourceConfiguration.CacheSettings.SharedIteratorEnabled {
tupleReader = sharediterator.NewSharedIteratorDatastore(tupleReader, dataResourceConfiguration.Resources.SharedIteratorStorage,
sharediterator.WithSharedIteratorDatastoreLogger(dataResourceConfiguration.Resources.Logger),
sharediterator.WithMethod(string(op.Method)))
}
combinedTupleReader := NewCombinedTupleReader(tupleReader, requestContextualTuples) // to read the contextual tuples
return &RequestStorageWrapper{
RelationshipTupleReader: combinedTupleReader,
StorageInstrumentation: instrumented,
}
}
// NewRequestStorageWrapper is used for ListUsers.
func NewRequestStorageWrapper(ds storage.RelationshipTupleReader, requestContextualTuples []*openfgav1.TupleKey, op *Operation) *RequestStorageWrapper {
instrumented := NewBoundedTupleReader(ds, op)
return &RequestStorageWrapper{
RelationshipTupleReader: NewCombinedTupleReader(instrumented, requestContextualTuples),
StorageInstrumentation: instrumented,
}
}
func (s *RequestStorageWrapper) GetMetadata() Metadata {
return s.StorageInstrumentation.GetMetadata()
}
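// Example (illustrative sketch, not part of the original source): a
// hypothetical helper building the non-caching request wrapper. The Operation
// values are placeholders; real callers derive them from server configuration.
func exampleRequestWrapper(ds storage.RelationshipTupleReader, contextualTuples []*openfgav1.TupleKey) *RequestStorageWrapper {
	op := &Operation{
		Method:      apimethod.Check, // placeholder API method
		Concurrency: 50,              // placeholder concurrency limit for reads
	}
	wrapper := NewRequestStorageWrapper(ds, contextualTuples, op)
	// After serving the request, instrumentation metadata (e.g., datastore
	// query counts) can be collected from the wrapper.
	_ = wrapper.GetMetadata()
	return wrapper
}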
package sharediterator
import (
"context"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/internal/build"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/storagewrappers/storagewrappersutil"
)
var (
_ storage.RelationshipTupleReader = (*IteratorDatastore)(nil)
_ storage.TupleIterator = (*sharedIterator)(nil)
sharedIteratorQueryHistogram = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: build.ProjectName,
Name: "shared_iterator_query_ms",
Help: "The duration (in ms) of a shared iterator query labeled by operation and shared.",
Buckets: []float64{1, 5, 10, 25, 50, 100, 200, 300, 1000},
NativeHistogramBucketFactor: 1.1,
NativeHistogramMaxBucketNumber: 100,
NativeHistogramMinResetDuration: time.Hour,
}, []string{"operation", "method", "shared"})
sharedIteratorBypassed = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: build.ProjectName,
Name: "shared_iterator_bypassed",
Help: "Total number of iterators bypassed by the shared iterator layer because the internal map size exceed specified limit OR max admission time has passed.",
}, []string{"operation"})
sharedIteratorCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: build.ProjectName,
Name: "shared_iterator_count",
Help: "The current number of items of shared iterator.",
})
sharedIteratorCloneCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: build.ProjectName,
Name: "shared_iterator_clone_count",
Help: "The current number of clones of shared iterators.",
})
)
const (
defaultSharedIteratorLimit = 1000000
defaultIteratorTargetSize = 1000
)
// Storage is a simple in-memory storage for shared iterators.
// It uses a sync.Map to store iterators and an atomic counter to keep track of the number of items.
// The limit is set to defaultSharedIteratorLimit, which can be overridden by the user.
// The storage is used to share iterators across multiple requests, allowing for efficient reuse of iterators.
type Storage struct {
// read stores shared iterators for the Read operation.
read sync.Map
// rswu stores shared iterators for the ReadStartingWithUser operation.
rswu sync.Map
// rut stores shared iterators for the ReadUsersetTuples operation.
rut sync.Map
// limit is the maximum number of items that can be stored in the storage.
// If the number of items exceeds this limit, new iterators will not be created and the request will be bypassed.
// This is to prevent memory exhaustion and ensure that the storage does not grow indefinitely.
limit int64
// ctr is an atomic counter that keeps track of the number of items in the storage.
// It is incremented when a new iterator is created and decremented when an iterator is removed.
// This counter is used to determine if the storage is full and if new iterators should be bypassed.
// It is also used to monitor the number of active iterators in the system.
ctr atomic.Int64
}
// storageItem is a wrapper around a shared iterator that provides thread-safe access to the iterator.
// The unwrap method is used to get a clone of the iterator, and if the iterator is not yet created,
// it will call the producer function to create a new iterator.
type storageItem struct {
// iter is the shared iterator that is being wrapped.
// It is set to nil when the iterator is not yet created, and will be created by the producer function.
iter *sharedIterator
// err is an error that is set when the iterator creation fails.
err error
admissionDuration time.Duration
idleDuration time.Duration
idleTimer *time.Timer
// producer is a function that creates a new shared iterator.
// It is called when the iterator is not yet created, and it should return a new shared iterator or an error.
// This allows the storageItem to lazily create the iterator when it is first accessed.
// This is useful for cases where the iterator creation is expensive or requires context.
producer func() (*sharedIterator, error)
cleanup func()
// once is a sync.Once to ensure that the producer function is only called once.
once sync.Once
}
// unwrap returns a clone of the shared iterator.
// If the iterator has not yet been created, it calls the producer function to create it
// and returns true to indicate that the iterator was newly created.
// If the iterator already exists, it returns a clone of the existing iterator.
// If an error occurred while creating the iterator, it returns nil and the error.
func (s *storageItem) unwrap() (*sharedIterator, bool, error) {
var created bool
s.once.Do(func() {
s.iter, s.err = s.producer()
if s.err != nil {
return
}
admissionTimer := time.AfterFunc(s.admissionDuration, func() {
s.iter.Stop()
s.cleanup()
})
s.idleTimer = time.AfterFunc(s.idleDuration, func() {
if admissionTimer.Stop() {
s.iter.Stop()
s.cleanup()
}
})
created = true
})
if s.err != nil {
return nil, false, s.err
}
clone := s.iter.clone()
s.idleTimer.Reset(s.idleDuration)
return clone, created, s.err
}
type DatastoreStorageOpt func(*Storage)
// WithSharedIteratorDatastoreStorageLimit sets the limit on the number of shared iterators that can be stored.
func WithSharedIteratorDatastoreStorageLimit(limit int) DatastoreStorageOpt {
return func(b *Storage) {
b.limit = int64(limit)
}
}
func NewSharedIteratorDatastoreStorage(opts ...DatastoreStorageOpt) *Storage {
newStorage := &Storage{
limit: defaultSharedIteratorLimit,
}
for _, opt := range opts {
opt(newStorage)
}
return newStorage
}
type IteratorDatastoreOpt func(*IteratorDatastore)
// WithSharedIteratorDatastoreLogger sets the logger for the IteratorDatastore.
func WithSharedIteratorDatastoreLogger(logger logger.Logger) IteratorDatastoreOpt {
return func(b *IteratorDatastore) {
b.logger = logger
}
}
// WithMaxAdmissionTime sets the maximum duration for which a shared iterator may be reused.
// After this period, the shared iterator will be removed from the cache. This is done
// to prevent stale data if there are very long-running requests.
func WithMaxAdmissionTime(maxAdmissionTime time.Duration) IteratorDatastoreOpt {
return func(b *IteratorDatastore) {
b.maxAdmissionTime = maxAdmissionTime
}
}
// WithMaxIdleTime sets the maximum duration for which a shared iterator can remain idle before it is no longer reusable.
// After this period, the shared iterator will be removed from the cache. This is done
// to prevent unused shared iterators from being cached for longer than necessary.
func WithMaxIdleTime(maxIdleTime time.Duration) IteratorDatastoreOpt {
return func(b *IteratorDatastore) {
b.maxIdleTime = maxIdleTime
}
}
// WithMethod sets the method label (e.g., check or list objects) used in shared iterator metrics.
func WithMethod(method string) IteratorDatastoreOpt {
return func(b *IteratorDatastore) {
b.method = method
}
}
// IteratorDatastore is a wrapper around a storage.RelationshipTupleReader that provides shared iterators.
// It uses an internal storage to manage shared iterators and provides methods to read tuples starting with a user,
// read userset tuples, and read tuples by key.
// The shared iterators are created lazily and are shared across multiple requests to improve performance.
type IteratorDatastore struct {
// RelationshipTupleReader is the inner datastore that provides the actual implementation of reading tuples.
storage.RelationshipTupleReader
// logger is used for logging messages related to the shared iterator operations.
logger logger.Logger
// method is used to specify the type of operation being performed (e.g., "check" or "list objects").
method string
// internalStorage is used to store shared iterators and manage their lifecycle.
internalStorage *Storage
// maxAdmissionTime is the maximum duration during which a shared iterator may be cloned.
maxAdmissionTime time.Duration
// maxIdleTime is the maximum duration a shared iterator may remain idle before it is removed.
maxIdleTime time.Duration
}
// NewSharedIteratorDatastore creates a new IteratorDatastore with the given inner RelationshipTupleReader and internal storage.
// It also accepts optional configuration options to customize the behavior of the datastore.
// The datastore will use shared iterators to improve performance by reusing iterators across multiple requests.
// If the number of active iterators exceeds the specified limit, new requests will be bypassed and handled by the inner reader.
func NewSharedIteratorDatastore(inner storage.RelationshipTupleReader, internalStorage *Storage, opts ...IteratorDatastoreOpt) *IteratorDatastore {
sf := &IteratorDatastore{
RelationshipTupleReader: inner,
logger: logger.NewNoopLogger(),
internalStorage: internalStorage,
method: "",
maxAdmissionTime: config.DefaultSharedIteratorMaxAdmissionTime,
maxIdleTime: config.DefaultSharedIteratorMaxIdleTime,
}
for _, opt := range opts {
opt(sf)
}
return sf
}
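// Example (illustrative sketch, not part of the original source): a
// hypothetical helper wiring the shared iterator layer around an inner
// reader. Concurrent, identical reads then share one underlying datastore
// iterator instead of issuing one query each. The durations are placeholders.
func exampleSharedIteratorWiring(inner storage.RelationshipTupleReader) storage.RelationshipTupleReader {
	internalStorage := NewSharedIteratorDatastoreStorage(
		WithSharedIteratorDatastoreStorageLimit(10000), // cap the number of concurrently shared iterators
	)
	return NewSharedIteratorDatastore(inner, internalStorage,
		WithMethod("check"),                  // metrics label
		WithMaxIdleTime(30*time.Second),      // drop iterators nobody is reusing
		WithMaxAdmissionTime(10*time.Second), // stop admitting clones to bound staleness
	)
}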
// ReadStartingWithUser reads tuples starting with a user using shared iterators.
// If the request is for higher consistency, it will fall back to the inner RelationshipTupleReader.
func (sf *IteratorDatastore) ReadStartingWithUser(
ctx context.Context,
store string,
filter storage.ReadStartingWithUserFilter,
options storage.ReadStartingWithUserOptions,
) (storage.TupleIterator, error) {
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
// for now, we will skip shared iterator since there is a possibility that the request
// may be slightly stale. In the future, consider whether we should have shared iterator
// for higher consistency request. This may mean having separate cache.
return sf.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
}
start := time.Now()
cacheKey, err := storagewrappersutil.ReadStartingWithUserKey(store, filter)
if err != nil {
return nil, err
}
// Bypass the shared iterator when the limit is zero or the number of active iterators has reached the limit.
full := sf.internalStorage.limit == 0 || sf.internalStorage.ctr.Load() >= sf.internalStorage.limit
if full {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationReadStartingWithUser).Inc()
return sf.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
}
// Create a new storage item to hold the shared iterator.
// This item will be stored in the internal storage map and will be used to share the iterator across requests.
newStorageItem := new(storageItem)
newStorageItem.admissionDuration = sf.maxAdmissionTime
newStorageItem.idleDuration = sf.maxIdleTime
// The producer function is called to create a new shared iterator when it is first accessed.
newStorageItem.producer = func() (*sharedIterator, error) {
it, err := sf.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
if err != nil {
return nil, err
}
si := newSharedIterator(it)
sharedIteratorCount.Inc()
return si, nil
}
newStorageItem.cleanup = func() {
if sf.internalStorage.rswu.CompareAndDelete(cacheKey, newStorageItem) {
sf.internalStorage.ctr.Add(-1)
sharedIteratorCount.Dec()
}
}
// Load or store the new storage item in the internal storage map.
// If the item is not already present, it will be added to the map and the counter will be incremented.
value, loaded := sf.internalStorage.rswu.LoadOrStore(cacheKey, newStorageItem)
if !loaded {
sf.internalStorage.ctr.Add(1)
}
item, _ := value.(*storageItem)
// Unwrap the storage item to get the shared iterator.
// If there is an error while unwrapping, we will remove the item from the internal storage and return the error.
// If this is the first time the iterator is accessed, it will call the producer function to create a new iterator.
it, created, err := item.unwrap()
if err != nil {
sf.internalStorage.rswu.CompareAndDelete(cacheKey, newStorageItem)
return nil, err
}
// If the iterator is nil, we will fall back to the inner RelationshipTupleReader.
// This can happen if the cloned shared iterator is already stopped and all references have been cleaned up.
if it == nil {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationReadStartingWithUser).Inc()
return sf.RelationshipTupleReader.ReadStartingWithUser(ctx, store, filter, options)
}
sharedIteratorQueryHistogram.WithLabelValues(
storagewrappersutil.OperationReadStartingWithUser, sf.method, strconv.FormatBool(!created),
).Observe(float64(time.Since(start).Milliseconds()))
return it, nil
}
// ReadUsersetTuples reads userset tuples using shared iterators.
// If the request is for higher consistency, it will fall back to the inner RelationshipTupleReader.
func (sf *IteratorDatastore) ReadUsersetTuples(
ctx context.Context,
store string,
filter storage.ReadUsersetTuplesFilter,
options storage.ReadUsersetTuplesOptions,
) (storage.TupleIterator, error) {
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
return sf.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
}
start := time.Now()
cacheKey := storagewrappersutil.ReadUsersetTuplesKey(store, filter)
// Bypass the shared iterator when the limit is zero or the number of active iterators has reached the limit.
full := sf.internalStorage.limit == 0 || sf.internalStorage.ctr.Load() >= sf.internalStorage.limit
if full {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationReadUsersetTuples).Inc()
return sf.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
}
// Create a new storage item to hold the shared iterator.
// This item will be stored in the internal storage map and will be used to share the iterator across requests.
newStorageItem := new(storageItem)
newStorageItem.admissionDuration = sf.maxAdmissionTime
newStorageItem.idleDuration = sf.maxIdleTime
// The producer function is called to create a new shared iterator when it is first accessed.
newStorageItem.producer = func() (*sharedIterator, error) {
it, err := sf.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
if err != nil {
return nil, err
}
si := newSharedIterator(it)
sharedIteratorCount.Inc()
return si, nil
}
newStorageItem.cleanup = func() {
if sf.internalStorage.rut.CompareAndDelete(cacheKey, newStorageItem) {
sf.internalStorage.ctr.Add(-1)
sharedIteratorCount.Dec()
}
}
// Load or store the new storage item in the internal storage map.
// If the item is not already present, it will be added to the map and the counter will be incremented.
value, loaded := sf.internalStorage.rut.LoadOrStore(cacheKey, newStorageItem)
if !loaded {
sf.internalStorage.ctr.Add(1)
}
item, _ := value.(*storageItem)
// Unwrap the storage item to get the shared iterator.
// If there is an error while unwrapping, we will remove the item from the internal storage and return the error.
// If this is the first time the iterator is accessed, it will call the producer function to create a new iterator.
it, created, err := item.unwrap()
if err != nil {
sf.internalStorage.rut.CompareAndDelete(cacheKey, newStorageItem)
return nil, err
}
// If the iterator is nil, we will fall back to the inner RelationshipTupleReader.
// This can happen if the cloned shared iterator is already stopped and all references have been cleaned up.
if it == nil {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationReadUsersetTuples).Inc()
return sf.RelationshipTupleReader.ReadUsersetTuples(ctx, store, filter, options)
}
sharedIteratorQueryHistogram.WithLabelValues(
storagewrappersutil.OperationReadUsersetTuples, sf.method, strconv.FormatBool(!created),
).Observe(float64(time.Since(start).Milliseconds()))
return it, nil
}
// Read reads tuples by key using shared iterators.
// If the request is for higher consistency, it will fall back to the inner RelationshipTupleReader.
func (sf *IteratorDatastore) Read(
ctx context.Context,
store string,
filter storage.ReadFilter,
options storage.ReadOptions) (storage.TupleIterator, error) {
if options.Consistency.Preference == openfgav1.ConsistencyPreference_HIGHER_CONSISTENCY {
return sf.RelationshipTupleReader.Read(ctx, store, filter, options)
}
start := time.Now()
tupleKey := &openfgav1.TupleKey{
Object: filter.Object,
Relation: filter.Relation,
User: filter.User,
}
cacheKey := storagewrappersutil.ReadKey(store, tupleKey)
// Bypass the shared iterator when the limit is zero or the number of active iterators has reached the limit.
full := sf.internalStorage.limit == 0 || sf.internalStorage.ctr.Load() >= sf.internalStorage.limit
if full {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationRead).Inc()
return sf.RelationshipTupleReader.Read(ctx, store, filter, options)
}
// Create a new storage item to hold the shared iterator.
// This item will be stored in the internal storage map and will be used to share the iterator across requests.
newStorageItem := new(storageItem)
newStorageItem.admissionDuration = sf.maxAdmissionTime
newStorageItem.idleDuration = sf.maxIdleTime
// The producer function is called to create a new shared iterator when it is first accessed.
newStorageItem.producer = func() (*sharedIterator, error) {
it, err := sf.RelationshipTupleReader.Read(ctx, store, filter, options)
if err != nil {
return nil, err
}
si := newSharedIterator(it)
sharedIteratorCount.Inc()
return si, nil
}
newStorageItem.cleanup = func() {
if sf.internalStorage.read.CompareAndDelete(cacheKey, newStorageItem) {
sf.internalStorage.ctr.Add(-1)
sharedIteratorCount.Dec()
}
}
// Load or store the new storage item in the internal storage map.
// If the item is not already present, it will be added to the map and the counter will be incremented.
value, loaded := sf.internalStorage.read.LoadOrStore(cacheKey, newStorageItem)
if !loaded {
sf.internalStorage.ctr.Add(1)
}
item, _ := value.(*storageItem)
// Unwrap the storage item to get the shared iterator.
// If there is an error while unwrapping, we will remove the item from the internal storage and return the error.
// If this is the first time the iterator is accessed, it will call the producer function to create a new iterator.
it, created, err := item.unwrap()
if err != nil {
sf.internalStorage.read.CompareAndDelete(cacheKey, newStorageItem)
return nil, err
}
// If the iterator is nil, we will fall back to the inner RelationshipTupleReader.
// This can happen if the cloned shared iterator is already stopped and all references have been cleaned up.
if it == nil {
sharedIteratorBypassed.WithLabelValues(storagewrappersutil.OperationRead).Inc()
return sf.RelationshipTupleReader.Read(ctx, store, filter, options)
}
sharedIteratorQueryHistogram.WithLabelValues(
storagewrappersutil.OperationRead, sf.method, strconv.FormatBool(!created),
).Observe(float64(time.Since(start).Milliseconds()))
return it, nil
}
// bufferSize is the number of items to fetch at a time when reading from the shared iterator.
const bufferSize = 100
// await is an object that executes an action exactly once at a time.
//
// The singleflight.Group type was used as a comparison to the await type, but was found to be ~59% slower than await
// in concurrent stress test benchmarks.
type await struct {
active bool
wg *sync.WaitGroup
mu sync.Mutex
}
// Do executes the provided function fn if it is not already being executed.
// The first goroutine to call Do runs the function; subsequent callers block until that run completes.
// This guarantees that fn is never executed concurrently.
func (a *await) Do(fn func()) {
a.mu.Lock()
if a.active {
wg := a.wg
a.mu.Unlock()
wg.Wait()
return
}
a.active = true
a.wg = new(sync.WaitGroup)
a.wg.Add(1)
a.mu.Unlock()
fn()
a.wg.Done()
a.mu.Lock()
a.active = false
a.mu.Unlock()
}
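// Example (illustrative sketch, not part of the original source): a
// hypothetical helper showing the await contract. Many goroutines call Do
// concurrently, but fn never runs concurrently; callers that arrive while a
// run is in flight block until it finishes.
func exampleAwait() int64 {
	var a await
	var runs atomic.Int64
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			a.Do(func() { runs.Add(1) }) // serialized, never concurrent
		}()
	}
	wg.Wait()
	// runs is between 1 and 8: callers overlapping an in-flight run share it.
	return runs.Load()
}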
// iteratorReader is a wrapper around a storage.Iterator that implements the reader interface.
type iteratorReader[T any] struct {
storage.Iterator[T]
}
// Read reads items from the iterator into the provided buffer.
// The method will read up to the length of the buffer, and if there are fewer items available,
// it will return the number of items read and an error if any occurred.
func (ir *iteratorReader[T]) Read(ctx context.Context, buf []T) (int, error) {
for i := range buf {
t, err := ir.Next(ctx)
if err != nil {
return i, err
}
buf[i] = t
}
return len(buf), nil
}
// iteratorState holds the state of the shared iterator.
// It contains a slice of items that have been fetched and any error encountered during the iteration.
// This state is shared between all clones of the iterator and is updated atomically to ensure thread-safety.
type iteratorState struct {
items []*openfgav1.Tuple
err error
}
// sharedIterator is a thread-safe iterator that allows multiple goroutines to share the same iterator.
// It uses a mutex to ensure that only one goroutine can access an individual iterator instance at a time.
// Atomic variables are used to manage data shared between clones.
type sharedIterator struct {
// mu is a mutex to ensure that only one goroutine can access the current iterator instance at a time.
// mu is a value type because it is not shared between iterator instances.
mu sync.RWMutex
// head is the index of the next item to be returned by the current iterator instance.
// It is incremented each time an item is returned by the iterator in a call to Next.
// head is a value type because it is not shared between iterator instances.
head int
// stopped is a boolean that indicates whether the current iterator instance has been stopped.
// stopped is a value type because it is not shared between iterator instances.
stopped bool
// await ensures that only one goroutine can fetch items at a time.
// It is used to block until new items are available or the context is done.
await *await
// ir is the underlying iterator reader that provides the actual implementation of reading tuples.
ir *iteratorReader[*openfgav1.Tuple]
// state is a shared atomic pointer to the iterator state, which contains the items and any error encountered during iteration.
state *atomic.Pointer[iteratorState]
// refs is a shared atomic counter that keeps track of the number of shared instances of the iterator.
refs *atomic.Int64
}
// newSharedIterator creates a new shared iterator from the given storage.TupleIterator.
// It initializes the shared reference counter, iterator state, and reader.
func newSharedIterator(it storage.TupleIterator) *sharedIterator {
var aw await
// Initialize the reference counter to 1, indicating that there is one active instance of the iterator.
var refs atomic.Int64
refs.Store(1)
ir := &iteratorReader[*openfgav1.Tuple]{
Iterator: it,
}
// Initialize the iterator state with an empty slice of items and no error.
var state iteratorState
var pstate atomic.Pointer[iteratorState]
pstate.Store(&state)
return &sharedIterator{
await: &aw,
ir: ir,
state: &pstate,
refs: &refs,
}
}
// clone increments an internal reference count and returns a new instance of the shared iterator.
// If the original iterator has been stopped, it returns nil.
// This allows multiple goroutines to share the same iterator instance without interfering with each other.
// The clone method is thread-safe and ensures that the reference count is incremented atomically.
func (s *sharedIterator) clone() *sharedIterator {
s.mu.RLock()
if s.stopped {
s.mu.RUnlock()
return nil
}
s.refs.Add(1)
s.mu.RUnlock()
sharedIteratorCloneCount.Inc()
return &sharedIterator{
await: s.await,
ir: s.ir,
state: s.state,
refs: s.refs,
}
}
// fetchMore is a method that fetches more items from the underlying storage.TupleIterator.
// It reads a fixed number of items (bufferSize) from the iterator and appends them
// to the shared items slice in the iterator state.
// If an error occurs during the read operation, it updates the error in the iterator state.
func (s *sharedIterator) fetchMore() {
var buf [bufferSize]*openfgav1.Tuple
read, e := s.ir.Read(context.Background(), buf[:])
// Load the current items from the shared items pointer and append the newly fetched items to it.
state := s.state.Load()
newState := &iteratorState{
items: make([]*openfgav1.Tuple, len(state.items)+read),
err: state.err,
}
copy(newState.items, state.items)
copy(newState.items[len(state.items):], buf[:read])
if e != nil {
newState.err = e
}
s.state.Store(newState)
}
// fetchAndWait is a method that fetches items from the underlying storage.TupleIterator and waits for new items to be available.
// It blocks until new items are fetched or an error occurs.
// The items and err pointers are updated with the fetched items and any error encountered.
func (s *sharedIterator) fetchAndWait(items *[]*openfgav1.Tuple, err *error) {
for {
state := s.state.Load()
if s.head < len(state.items) || state.err != nil {
*items = state.items
*err = state.err
return
}
s.await.Do(s.fetchMore)
}
}
// currentLocked returns the current item in the shared iterator without advancing it.
// It is used to peek at the next item without consuming it.
// If the iterator is stopped or there are no items available, it returns an error.
// It also fetches new items if the current head is beyond the available items.
// Callers must hold mu.
func (s *sharedIterator) currentLocked(ctx context.Context) (*openfgav1.Tuple, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if s.stopped {
return nil, storage.ErrIteratorDone
}
var items []*openfgav1.Tuple
var err error
s.fetchAndWait(&items, &err)
if ctx.Err() != nil {
return nil, ctx.Err()
}
if s.head >= len(items) {
if err != nil {
return nil, err
}
// This is a guard clause to ensure we do not access out of bounds.
// If we reach here, it means there is a bug in the underlying iterator.
return nil, storage.ErrIteratorDone
}
return items[s.head], nil
}
// Head returns the first item in the shared iterator without advancing the iterator.
// It is used to peek at the next item without consuming it.
// If the iterator is stopped or there are no items available, it returns an error.
// It also handles fetching new items if the current head is beyond the available items.
func (s *sharedIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
s.mu.Lock()
defer s.mu.Unlock()
return s.currentLocked(ctx)
}
// Next returns the next item in the shared iterator and advances the iterator.
// It is used to consume the next item in the iterator.
// If the iterator is stopped or there are no items available, it returns an error.
// It also handles fetching new items if the current head is beyond the available items.
func (s *sharedIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
s.mu.Lock()
defer s.mu.Unlock()
result, err := s.currentLocked(ctx)
if err != nil {
return nil, err
}
s.head++
return result, nil
}
// Stop stops this instance of the shared iterator and decrements the shared reference count.
// When the reference count reaches zero, the underlying iterator is stopped as well.
func (s *sharedIterator) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.stopped {
s.stopped = true
if s.refs.Add(-1) == 0 {
s.ir.Stop()
}
sharedIteratorCloneCount.Dec()
}
}
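// Example (illustrative sketch, not part of the original source): a
// hypothetical helper showing clone semantics. Each clone keeps its own read
// position while the fetched items and the underlying datastore iterator are
// shared, so two consumers can independently traverse one backend query.
func exampleSharedIteratorClones(ctx context.Context, tuples []*openfgav1.Tuple) error {
	original := newSharedIterator(storage.NewStaticTupleIterator(tuples))
	defer original.Stop()
	clone := original.clone()
	if clone == nil {
		return storage.ErrIteratorDone // original was already stopped
	}
	defer clone.Stop()
	// Both observe the first tuple: positions are per-instance, data is shared.
	if _, err := original.Next(ctx); err != nil {
		return err
	}
	_, err := clone.Next(ctx)
	return err
}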
package storagewrappersutil
import (
"strconv"
"strings"
"github.com/cespare/xxhash/v2"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
const (
OperationRead = "Read"
OperationReadStartingWithUser = "ReadStartingWithUser"
OperationReadUsersetTuples = "ReadUsersetTuples"
OperationReadUserTuple = "ReadUserTuple"
)
func ReadStartingWithUserKey(
store string,
filter storage.ReadStartingWithUserFilter,
) (string, error) {
var b strings.Builder
b.WriteString(
storage.GetReadStartingWithUserCacheKeyPrefix(store, filter.ObjectType, filter.Relation),
)
// NOTE: There is no need to limit the length of this
// since at most it will have 2 entries (user and wildcard if possible)
for _, objectRel := range filter.UserFilter {
subject := objectRel.GetObject()
if objectRel.GetRelation() != "" {
subject = tuple.ToObjectRelationString(objectRel.GetObject(), objectRel.GetRelation())
}
b.WriteString("/" + subject)
}
if filter.ObjectIDs != nil {
hasher := xxhash.New()
for _, oid := range filter.ObjectIDs.Values() {
if _, err := hasher.WriteString(oid); err != nil {
return "", err
}
}
b.WriteString("/" + strconv.FormatUint(hasher.Sum64(), 10))
}
return b.String(), nil
}
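// Example (illustrative sketch, not part of the original source): a
// hypothetical helper building a cache key for a ReadStartingWithUser call.
// The exact layout (a store/type/relation prefix plus one "/subject" segment
// per user filter entry) is an implementation detail; treat keys as opaque.
func exampleReadStartingWithUserKey(store string) (string, error) {
	return ReadStartingWithUserKey(store, storage.ReadStartingWithUserFilter{
		ObjectType: "document",
		Relation:   "viewer",
		UserFilter: []*openfgav1.ObjectRelation{
			{Object: "user:anne"},
			{Object: "user:*"}, // wildcard subject, when the model allows it
		},
	})
}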
func ReadUsersetTuplesKey(store string, filter storage.ReadUsersetTuplesFilter) string {
var b strings.Builder
b.WriteString(
storage.GetReadUsersetTuplesCacheKeyPrefix(store, filter.Object, filter.Relation),
)
var rb strings.Builder
var wb strings.Builder
for _, userset := range filter.AllowedUserTypeRestrictions {
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
rb.WriteString("/" + userset.GetType() + "#" + userset.GetRelation())
}
if _, ok := userset.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
wb.WriteString("/" + userset.GetType() + ":*")
}
}
// wildcard should have precedence
if wb.Len() > 0 {
b.WriteString(wb.String())
}
if rb.Len() > 0 {
b.WriteString(rb.String())
}
return b.String()
}
func ReadKey(store string, tupleKey *openfgav1.TupleKey) string {
var b strings.Builder
b.WriteString(
storage.GetReadCacheKey(store, tuple.TupleKeyToString(tupleKey)),
)
return b.String()
}
package storage
import (
"context"
"errors"
"fmt"
"slices"
"sync"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
var ErrIteratorDone = errors.New("iterator done")
// Iterator is a generic interface defining methods for
// iterating over a collection of items of type T.
type Iterator[T any] interface {
// Next will return the next available
// item or ErrIteratorDone if no more
// items are available.
Next(ctx context.Context) (T, error)
// Stop terminates iteration. Any subsequent calls to Next must return ErrIteratorDone.
Stop()
// Head will return the first item or ErrIteratorDone if the iterator is finished or empty.
// It's possible for this method to advance the iterator internally, but a subsequent call to Next will not miss any results.
// Calling Head() continuously without calling Next() will yield the same result (the first one) over and over.
Head(ctx context.Context) (T, error)
}
type TupleIterator = Iterator[*openfgav1.Tuple]
type TupleKeyIterator = Iterator[*openfgav1.TupleKey]
// combinedIterator is a thread-safe iterator that merges multiple iterators.
// Duplicates can be returned.
type combinedIterator[T any] struct {
mu *sync.Mutex
once *sync.Once
pending []Iterator[T] // GUARDED_BY(mu)
}
var _ Iterator[any] = (*combinedIterator[any])(nil)
// Next see [Iterator.Next].
func (c *combinedIterator[T]) Next(ctx context.Context) (T, error) {
c.mu.Lock() // no defer of Unlock because of the recursive call
if len(c.pending) == 0 {
// All iterators ended.
var val T
c.mu.Unlock()
return val, ErrIteratorDone
}
iter := c.pending[0]
val, err := iter.Next(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
c.pending = c.pending[1:]
iter.Stop() // clean up before dropping the reference
c.mu.Unlock()
return c.Next(ctx)
}
c.mu.Unlock()
return val, err
}
c.mu.Unlock()
return val, nil
}
// Stop see [Iterator.Stop].
func (c *combinedIterator[T]) Stop() {
c.once.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
for _, iter := range c.pending {
iter.Stop()
}
})
}
// Head see [Iterator.Head].
func (c *combinedIterator[T]) Head(ctx context.Context) (T, error) {
c.mu.Lock() // no defer of Unlock because of the recursive call
if len(c.pending) == 0 {
// All iterators ended.
var val T
c.mu.Unlock()
return val, ErrIteratorDone
}
iter := c.pending[0]
val, err := iter.Head(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
c.pending = c.pending[1:]
iter.Stop()
c.mu.Unlock()
return c.Head(ctx)
}
c.mu.Unlock()
return val, err
}
c.mu.Unlock()
return val, nil
}
// NewCombinedIterator is a thread-safe iterator that takes generic iterators of a given type T
// and combines them into a single iterator that yields all the
// values from all iterators. Duplicates can be returned.
func NewCombinedIterator[T any](iters ...Iterator[T]) Iterator[T] {
pending := make([]Iterator[T], 0, len(iters))
for _, iter := range iters {
if iter != nil {
pending = append(pending, iter)
}
}
return &combinedIterator[T]{pending: pending, once: &sync.Once{}, mu: &sync.Mutex{}}
}
// NewStaticTupleIterator returns a [TupleIterator] that iterates over the provided slice.
func NewStaticTupleIterator(tuples []*openfgav1.Tuple) TupleIterator {
iter := &StaticIterator[*openfgav1.Tuple]{
items: tuples,
mu: &sync.Mutex{},
}
return iter
}
// NewStaticTupleKeyIterator returns a [TupleKeyIterator] that iterates over the provided slice.
func NewStaticTupleKeyIterator(tupleKeys []*openfgav1.TupleKey) TupleKeyIterator {
iter := &StaticIterator[*openfgav1.TupleKey]{
items: tupleKeys,
mu: &sync.Mutex{},
}
return iter
}
type tupleKeyIterator struct {
iter TupleIterator
once *sync.Once
}
var _ TupleKeyIterator = (*tupleKeyIterator)(nil)
// Next see [Iterator.Next].
func (t *tupleKeyIterator) Next(ctx context.Context) (*openfgav1.TupleKey, error) {
tuple, err := t.iter.Next(ctx)
if err != nil {
return nil, err
}
return tuple.GetKey(), nil
}
// Stop see [Iterator.Stop].
func (t *tupleKeyIterator) Stop() {
t.once.Do(func() {
t.iter.Stop()
})
}
// Head see [Iterator.Head].
func (t *tupleKeyIterator) Head(ctx context.Context) (*openfgav1.TupleKey, error) {
tuple, err := t.iter.Head(ctx)
if err != nil {
return nil, err
}
return tuple.GetKey(), nil
}
// NewTupleKeyIteratorFromTupleIterator takes a [TupleIterator] and yields
// all the [*openfgav1.TupleKey](s) from it as a [TupleKeyIterator].
func NewTupleKeyIteratorFromTupleIterator(iter TupleIterator) TupleKeyIterator {
return &tupleKeyIterator{iter, &sync.Once{}}
}
type StaticIterator[T any] struct {
items []T // GUARDED_BY(mu)
mu *sync.Mutex
}
var _ Iterator[any] = (*StaticIterator[any])(nil)
// Next see [Iterator.Next].
func (s *StaticIterator[T]) Next(ctx context.Context) (T, error) {
var val T
if ctx.Err() != nil {
return val, ctx.Err()
}
s.mu.Lock()
defer s.mu.Unlock()
if len(s.items) == 0 {
return val, ErrIteratorDone
}
next, rest := s.items[0], s.items[1:]
s.items = rest
return next, nil
}
// Stop see [Iterator.Stop].
func (s *StaticIterator[T]) Stop() {
s.mu.Lock()
defer s.mu.Unlock()
s.items = nil
}
// Head see [Iterator.Head].
func (s *StaticIterator[T]) Head(ctx context.Context) (T, error) {
var val T
if ctx.Err() != nil {
return val, ctx.Err()
}
s.mu.Lock()
defer s.mu.Unlock()
if len(s.items) == 0 {
return val, ErrIteratorDone
}
return s.items[0], nil
}
func NewStaticIterator[T any](items []T) Iterator[T] {
return &StaticIterator[T]{items: items, mu: &sync.Mutex{}}
}
// TupleKeyFilterFunc is a filter function that is used to filter out
// tuples from a [TupleKeyIterator] that don't meet certain criteria.
// Implementations should return true if the tuple should be returned
// and false if it should be filtered out.
type TupleKeyFilterFunc func(tupleKey *openfgav1.TupleKey) bool
type filteredTupleKeyIterator struct {
iter TupleKeyIterator
filter TupleKeyFilterFunc
once *sync.Once
}
var _ TupleKeyIterator = (*filteredTupleKeyIterator)(nil)
// Next returns the next tuple in the underlying iterator that meets
// the filter function this iterator was constructed with.
func (f *filteredTupleKeyIterator) Next(ctx context.Context) (*openfgav1.TupleKey, error) {
for {
tuple, err := f.iter.Next(ctx)
if err != nil {
return nil, err
}
if f.filter(tuple) {
return tuple, nil
}
}
}
// Stop see [Iterator.Stop].
func (f *filteredTupleKeyIterator) Stop() {
f.once.Do(func() {
f.iter.Stop()
})
}
// Head returns the next tuple in the underlying iterator that meets
// the filter function this iterator was constructed with.
// Note: the underlying iterator will advance until the filter is satisfied.
func (f *filteredTupleKeyIterator) Head(ctx context.Context) (*openfgav1.TupleKey, error) {
for {
tuple, err := f.iter.Head(ctx)
if err != nil {
return nil, err
}
if f.filter(tuple) {
return tuple, nil
}
_, err = f.iter.Next(ctx)
if err != nil {
return nil, err
}
}
}
// NewFilteredTupleKeyIterator returns a [TupleKeyIterator] that filters out all
// [*openfgav1.Tuple](s) that don't meet the conditions of the provided [TupleKeyFilterFunc].
func NewFilteredTupleKeyIterator(iter TupleKeyIterator, filter TupleKeyFilterFunc) TupleKeyIterator {
return &filteredTupleKeyIterator{
iter,
filter,
&sync.Once{},
}
}
// TupleKeyConditionFilterFunc is a filter function that is used to filter out
// tuples from a [TupleKeyIterator] that don't meet the conditions provided by the request.
// Implementations should return true if the tuple should be returned
// and false if it should be filtered out.
// Errors will be treated as false. If none of the tuples are valid AND there are errors, Next() will return
// the last error.
type TupleKeyConditionFilterFunc func(tupleKey *openfgav1.TupleKey) (bool, error)
type ConditionsFilteredTupleKeyIterator struct {
iter TupleKeyIterator
filter TupleKeyConditionFilterFunc
lastError error
onceValid bool
once *sync.Once
}
var _ TupleKeyIterator = (*ConditionsFilteredTupleKeyIterator)(nil)
// Next returns the next tuple in the underlying iterator that meets
// the filter function this iterator was constructed with.
// This function is not thread-safe.
func (f *ConditionsFilteredTupleKeyIterator) Next(ctx context.Context) (*openfgav1.TupleKey, error) {
for {
tuple, err := f.iter.Next(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
if f.onceValid || f.lastError == nil {
return nil, ErrIteratorDone
}
lastError := f.lastError
f.lastError = nil
return nil, lastError
}
return nil, err
}
valid, err := f.filter(tuple)
if err != nil {
f.lastError = err
continue
}
if !valid {
continue
}
f.onceValid = true
return tuple, nil
}
}
// Stop see [Iterator.Stop].
func (f *ConditionsFilteredTupleKeyIterator) Stop() {
f.once.Do(func() {
f.iter.Stop()
})
}
// Head returns the next tuple in the underlying iterator that meets
// the filter function this iterator was constructed with.
// The underlying iterator may advance, but consecutive calls to Head will yield a consistent result.
// Further, calling Head followed by Next will also yield consistent results.
// This function is not thread-safe.
func (f *ConditionsFilteredTupleKeyIterator) Head(ctx context.Context) (*openfgav1.TupleKey, error) {
for {
tuple, err := f.iter.Head(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
if f.onceValid || f.lastError == nil {
return nil, ErrIteratorDone
}
return nil, f.lastError
}
return nil, err
}
valid, err := f.filter(tuple)
if err != nil || !valid {
if err != nil {
f.lastError = err
}
// Note that we don't care about the item returned by Next() as this is already via Head(). We call Next() solely
// for the purpose of getting rid of the first item.
_, err = f.iter.Next(ctx)
if err != nil {
// This should never happen except if the underlying ds has error. This is because f.iter.Head() had already
// checked whether we are at the end of list. For example, in a list of [1] (all invalid),
// Head() will return 1. If it is invalid, Next() will return 1 and move the pointer to end of list.
// Thus, Head() will return ErrIteratorDone next time being called.
return nil, err
}
continue
}
f.onceValid = true
return tuple, nil
}
}
// NewConditionsFilteredTupleKeyIterator returns a [TupleKeyIterator] that filters out all
// [*openfgav1.Tuple](s) that don't meet the conditions of the provided [TupleKeyFilterFunc].
func NewConditionsFilteredTupleKeyIterator(iter TupleKeyIterator, filter TupleKeyConditionFilterFunc) TupleKeyIterator {
return &ConditionsFilteredTupleKeyIterator{
iter: iter,
filter: filter,
once: &sync.Once{},
}
}
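// Example (illustrative sketch, not part of the original source): a
// hypothetical helper showing the error semantics described above. Every
// tuple fails its condition with an error and none are valid, so Next
// surfaces the last condition error instead of ErrIteratorDone.
func exampleConditionFilterErrors(ctx context.Context) error {
	iter := NewConditionsFilteredTupleKeyIterator(
		NewStaticTupleKeyIterator([]*openfgav1.TupleKey{
			{Object: "document:1", Relation: "viewer", User: "user:anne"},
		}),
		func(tk *openfgav1.TupleKey) (bool, error) {
			return false, fmt.Errorf("condition evaluation failed for %s", tk.GetObject())
		},
	)
	defer iter.Stop()
	_, err := iter.Next(ctx)
	return err // the condition error, not ErrIteratorDone
}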
type OrderedCombinedIterator struct {
mu *sync.Mutex
once *sync.Once
mapper TupleMapperFunc
pending []TupleIterator // GUARDED_BY(mu)
lastHead *openfgav1.Tuple // GUARDED_BY(mu)
lastYielded *openfgav1.Tuple // GUARDED_BY(mu)
}
var _ TupleIterator = (*OrderedCombinedIterator)(nil)
// NewOrderedCombinedIterator is a thread-safe iterator that combines a list of iterators into a single ordered iterator.
// All the input iterators must be individually ordered already according to mapper.
// Iterators can yield the same value (as defined by mapper) multiple times, but it will only be returned once.
func NewOrderedCombinedIterator(mapper TupleMapperFunc, sortedIters ...TupleIterator) *OrderedCombinedIterator {
pending := make([]TupleIterator, 0, len(sortedIters))
for _, sortedIter := range sortedIters {
if sortedIter != nil {
pending = append(pending, sortedIter)
}
}
return &OrderedCombinedIterator{pending: pending, once: &sync.Once{}, mu: &sync.Mutex{}, mapper: mapper}
}
func (c *OrderedCombinedIterator) Next(ctx context.Context) (*openfgav1.Tuple, error) {
c.mu.Lock()
defer c.mu.Unlock()
idx, err := c.head(ctx)
if err != nil {
return nil, err
}
c.lastHead = nil // invalidate it
t, err := c.pending[idx].Next(ctx)
if err != nil {
return nil, err
}
c.lastYielded = t
return t, nil
}
func (c *OrderedCombinedIterator) Head(ctx context.Context) (*openfgav1.Tuple, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.lastHead != nil {
return c.lastHead, nil
}
idx, err := c.head(ctx)
if err != nil {
return nil, err
}
lastHead, err := c.pending[idx].Head(ctx)
c.lastHead = lastHead
return c.lastHead, err
}
// head returns the index (within the pending array) that has the next smallest element, or -1 if there was an error.
// The pending array is mutated, and there may be nil elements in it after this function runs.
// Callers must use the returned index immediately, without mutating the pending array.
// NOTE: callers must hold mu.
func (c *OrderedCombinedIterator) head(ctx context.Context) (int, error) {
c.clearPendingThatAreNil()
var headMin *openfgav1.Tuple
minIdx := -1
IterateOverPending:
for pendingIdx, iter := range c.pending {
head, err := iter.Head(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
iter.Stop()
c.pending[pendingIdx] = nil
continue
}
return -1, err
}
if c.lastYielded != nil {
if c.mapper(head) < c.mapper(c.lastYielded) {
return -1, fmt.Errorf("iterator %d is not in ascending order", pendingIdx)
}
// Discard duplicate values.
// For performance reasons, we deduplicate against the last yielded value rather than on every Head() call.
// If we discarded duplicates on every call to Head(), we would need to iterate twice over pending:
// once to find minIdx, and once to advance the corresponding iterators.
for c.mapper(head) == c.mapper(c.lastYielded) {
head, err = iter.Next(ctx)
if err != nil {
if errors.Is(err, ErrIteratorDone) {
iter.Stop()
c.pending[pendingIdx] = nil
continue IterateOverPending
}
return -1, err
}
}
}
// Initialize the minimum, or record a newly found lower value at head.
if headMin == nil || c.mapper(headMin) > c.mapper(head) {
headMin = head
minIdx = pendingIdx
}
}
if minIdx == -1 {
return minIdx, ErrIteratorDone
}
return minIdx, nil
}
func (c *OrderedCombinedIterator) clearPendingThatAreNil() {
c.pending = slices.DeleteFunc(c.pending, func(t TupleIterator) bool {
return t == nil
})
}
func (c *OrderedCombinedIterator) Stop() {
c.once.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
c.clearPendingThatAreNil()
for _, iter := range c.pending {
iter.Stop()
}
})
}
// IterIsDoneOrCancelled returns true if the error is ErrIteratorDone, context.Canceled, or context.DeadlineExceeded.
func IterIsDoneOrCancelled(err error) bool {
return errors.Is(err, ErrIteratorDone) || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}
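// Illustrative sketch (not part of the original source): merging two
// object-ordered iterators into one ordered, de-duplicated stream. It assumes
// the package's NewStaticTupleIterator helper as the backing iterators.
func exampleOrderedCombinedIteration() {
	ctx := context.Background()
	left := NewStaticTupleIterator([]*openfgav1.Tuple{
		{Key: &openfgav1.TupleKey{Object: "document:1", Relation: "viewer", User: "user:anne"}},
		{Key: &openfgav1.TupleKey{Object: "document:3", Relation: "viewer", User: "user:anne"}},
	})
	right := NewStaticTupleIterator([]*openfgav1.Tuple{
		{Key: &openfgav1.TupleKey{Object: "document:1", Relation: "viewer", User: "user:anne"}},
		{Key: &openfgav1.TupleKey{Object: "document:2", Relation: "viewer", User: "user:anne"}},
	})
	combined := NewOrderedCombinedIterator(ObjectMapper(), left, right)
	defer combined.Stop()
	for {
		t, err := combined.Next(ctx)
		if err != nil {
			break // ErrIteratorDone once both inputs are exhausted
		}
		// document:1 is yielded once even though both inputs contain it;
		// output order is document:1, document:2, document:3.
		fmt.Println(t.GetKey().GetObject())
	}
}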
package storage
import (
"context"
"fmt"
"sync"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/tuple"
)
type TupleMapperFunc func(t *openfgav1.Tuple) string
func UserMapper() TupleMapperFunc {
return func(t *openfgav1.Tuple) string {
return t.GetKey().GetUser()
}
}
func ObjectMapper() TupleMapperFunc {
return func(t *openfgav1.Tuple) string {
return t.GetKey().GetObject()
}
}
type TupleMapperKind int64
const (
// UsersetKind is a mapper that returns the userset ID from the tuple's user field.
UsersetKind TupleMapperKind = iota
// TTUKind is a mapper that returns the user field of the tuple.
TTUKind
// ObjectIDKind is a mapper that returns the object field of the tuple.
ObjectIDKind
)
// TupleMapper is an iterator that, on calls to Next and Head, returns a mapping of the tuple.
type TupleMapper interface {
Iterator[string]
}
type UsersetMapper struct {
iter TupleKeyIterator
once *sync.Once
}
var _ TupleMapper = (*UsersetMapper)(nil)
func (n UsersetMapper) Next(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Next(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n UsersetMapper) Stop() {
n.once.Do(func() {
n.iter.Stop()
})
}
func (n UsersetMapper) Head(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Head(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n UsersetMapper) doMap(t *openfgav1.TupleKey) (string, error) {
usersetName, relation := tuple.SplitObjectRelation(t.GetUser())
if relation == "" && !tuple.IsWildcard(usersetName) {
// This should never happen because ReadUsersetTuples only returns usersets as users.
return "", fmt.Errorf("unexpected userset %s with no relation", t.GetUser())
}
return usersetName, nil
}
type TTUMapper struct {
iter TupleKeyIterator
once *sync.Once
}
var _ TupleMapper = (*TTUMapper)(nil)
func (n TTUMapper) Next(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Next(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n TTUMapper) Stop() {
n.once.Do(func() {
n.iter.Stop()
})
}
func (n TTUMapper) Head(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Head(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n TTUMapper) doMap(t *openfgav1.TupleKey) (string, error) {
return t.GetUser(), nil
}
type ObjectIDMapper struct {
iter TupleKeyIterator
once *sync.Once
}
var _ TupleMapper = (*ObjectIDMapper)(nil)
func (n ObjectIDMapper) Next(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Next(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n ObjectIDMapper) Stop() {
n.once.Do(func() {
n.iter.Stop()
})
}
func (n ObjectIDMapper) Head(ctx context.Context) (string, error) {
tupleRes, err := n.iter.Head(ctx)
if err != nil {
return "", err
}
return n.doMap(tupleRes)
}
func (n ObjectIDMapper) doMap(t *openfgav1.TupleKey) (string, error) {
return t.GetObject(), nil
}
func WrapIterator(kind TupleMapperKind, iter TupleKeyIterator) TupleMapper {
switch kind {
case UsersetKind:
return &UsersetMapper{iter: iter, once: &sync.Once{}}
case TTUKind:
return &TTUMapper{iter: iter, once: &sync.Once{}}
case ObjectIDKind:
return &ObjectIDMapper{iter: iter, once: &sync.Once{}}
}
return nil
}
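// Illustrative sketch (not part of the original source): wrapping a tuple-key
// iterator so Next yields the mapped value instead of the full tuple key. It
// assumes the package's NewStaticTupleKeyIterator helper.
func exampleWrapIterator() {
	ctx := context.Background()
	mapper := WrapIterator(UsersetKind, NewStaticTupleKeyIterator([]*openfgav1.TupleKey{
		{Object: "document:1", Relation: "viewer", User: "group:fga#member"},
	}))
	defer mapper.Stop()
	userset, err := mapper.Next(ctx)
	if err == nil {
		fmt.Println(userset) // "group:fga", the object part of the userset user
	}
}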
package telemetry
import (
"context"
)
type rpcContextName string
const (
rpcInfoContextName rpcContextName = "rpcInfo"
)
type RPCInfo struct {
Method string
Service string
}
// ContextWithRPCInfo will save the rpc method and service information in context.
func ContextWithRPCInfo(ctx context.Context, rpcInfo RPCInfo) context.Context {
return context.WithValue(ctx, rpcInfoContextName, rpcInfo)
}
// RPCInfoFromContext returns method and service stored in context.
func RPCInfoFromContext(ctx context.Context) RPCInfo {
rpcInfo, ok := ctx.Value(rpcInfoContextName).(RPCInfo)
if ok {
return rpcInfo
}
return RPCInfo{
Method: "unknown",
Service: "unknown",
}
}
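// Illustrative sketch (not part of the original source): an interceptor-style
// round trip. Store the RPC info once at the edge, then read it anywhere below;
// callers that never stored anything get the "unknown" fallback.
func exampleRPCInfoRoundTrip() {
	ctx := ContextWithRPCInfo(context.Background(), RPCInfo{
		Method:  "Check",
		Service: "openfga.v1.OpenFGAService",
	})
	info := RPCInfoFromContext(ctx) // Method: "Check", Service: "openfga.v1.OpenFGAService"
	fallback := RPCInfoFromContext(context.Background())
	_, _ = info, fallback // fallback.Method and fallback.Service are both "unknown"
}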
// Package telemetry contains code that emits telemetry (logging, metrics, tracing).
package telemetry
import (
"context"
"errors"
"fmt"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
)
type TracerOption func(d *customTracer)
func WithOTLPEndpoint(endpoint string) TracerOption {
return func(d *customTracer) {
d.endpoint = endpoint
}
}
func WithOTLPInsecure() TracerOption {
return func(d *customTracer) {
d.insecure = true
}
}
func WithSamplingRatio(samplingRatio float64) TracerOption {
return func(d *customTracer) {
d.samplingRatio = samplingRatio
}
}
func WithAttributes(attrs ...attribute.KeyValue) TracerOption {
return func(d *customTracer) {
d.attributes = attrs
}
}
type customTracer struct {
endpoint string
insecure bool
attributes []attribute.KeyValue
samplingRatio float64
}
func MustNewTracerProvider(opts ...TracerOption) *sdktrace.TracerProvider {
tracer := &customTracer{
endpoint: "",
attributes: []attribute.KeyValue{},
samplingRatio: 0,
}
for _, opt := range opts {
opt(tracer)
}
res, err := resource.Merge(
resource.Default(),
resource.NewSchemaless(tracer.attributes...))
if err != nil {
panic(err)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
options := []otlptracegrpc.Option{
otlptracegrpc.WithEndpoint(tracer.endpoint),
otlptracegrpc.WithDialOption(
// nolint:staticcheck // ignoring gRPC deprecations
grpc.WithBlock(),
),
}
if tracer.insecure {
options = append(options, otlptracegrpc.WithInsecure())
}
var exp sdktrace.SpanExporter
exp, err = otlptracegrpc.New(ctx, options...)
if err != nil {
panic(fmt.Sprintf("failed to establish a connection with the otlp exporter: %v", err))
}
tp := sdktrace.NewTracerProvider(
sdktrace.WithSampler(sdktrace.TraceIDRatioBased(tracer.samplingRatio)),
sdktrace.WithResource(res),
sdktrace.WithSpanProcessor(sdktrace.NewBatchSpanProcessor(exp)),
)
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
otel.SetTracerProvider(tp)
return tp
}
// TraceError marks the span as having an error, except if the error is context.Canceled,
// in which case it does nothing.
func TraceError(span trace.Span, err error) {
if errors.Is(err, context.Canceled) {
return
}
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
}
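// Illustrative sketch (not part of the original source): recording a failure on
// a span via TraceError; doWork is a hypothetical placeholder for real work.
// Cancelled requests are deliberately not recorded as errors.
func exampleTraceError(ctx context.Context) {
	ctx, span := otel.Tracer("example").Start(ctx, "operation")
	defer span.End()
	if err := doWork(ctx); err != nil {
		TraceError(span, err) // no-op when err is context.Canceled
	}
}
// doWork is a hypothetical stand-in used only by the sketch above.
func doWork(ctx context.Context) error {
	return ctx.Err()
}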
package storage
import (
"context"
"io"
"log"
"strings"
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/client"
"github.com/docker/go-connections/nat"
"github.com/go-sql-driver/mysql"
"github.com/oklog/ulid/v2"
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"github.com/openfga/openfga/assets"
)
const (
mySQLImage = "mysql:8"
)
type mySQLTestContainer struct {
addr string
version int64
username string
password string
}
// NewMySQLTestContainer returns an implementation of the DatastoreTestContainer interface
// for MySQL.
func NewMySQLTestContainer() *mySQLTestContainer {
return &mySQLTestContainer{}
}
func (m *mySQLTestContainer) GetDatabaseSchemaVersion() int64 {
return m.version
}
// RunMySQLTestContainer runs a MySQL container, connects to it, and returns a
// bootstrapped implementation of the DatastoreTestContainer interface wired up for the
// MySQL datastore engine.
func (m *mySQLTestContainer) RunMySQLTestContainer(t testing.TB) DatastoreTestContainer {
dockerClient, err := client.NewClientWithOpts(
client.FromEnv,
client.WithAPIVersionNegotiation(),
)
require.NoError(t, err)
t.Cleanup(func() {
dockerClient.Close()
})
allImages, err := dockerClient.ImageList(context.Background(), image.ListOptions{
All: true,
})
require.NoError(t, err)
foundMysqlImage := false
AllImages:
for _, image := range allImages {
for _, tag := range image.RepoTags {
if strings.Contains(tag, mySQLImage) {
foundMysqlImage = true
break AllImages
}
}
}
if !foundMysqlImage {
t.Logf("Pulling image %s", mySQLImage)
reader, err := dockerClient.ImagePull(context.Background(), mySQLImage, image.PullOptions{})
require.NoError(t, err)
_, err = io.Copy(io.Discard, reader) // consume the image pull output to make sure it's done
require.NoError(t, err)
}
containerCfg := container.Config{
Env: []string{
"MYSQL_DATABASE=defaultdb",
"MYSQL_ROOT_PASSWORD=secret",
},
ExposedPorts: nat.PortSet{
nat.Port("3306/tcp"): {},
},
Image: mySQLImage,
}
hostCfg := container.HostConfig{
AutoRemove: true,
PublishAllPorts: true,
Tmpfs: map[string]string{"/var/lib/mysql": ""},
}
name := "mysql-" + ulid.Make().String()
cont, err := dockerClient.ContainerCreate(context.Background(), &containerCfg, &hostCfg, nil, nil, name)
require.NoError(t, err, "failed to create mysql docker container")
t.Cleanup(func() {
t.Logf("stopping container %s", name)
timeoutSec := 5
err := dockerClient.ContainerStop(context.Background(), cont.ID, container.StopOptions{Timeout: &timeoutSec})
if err != nil && !errdefs.IsNotFound(err) {
t.Logf("failed to stop mysql container: %v", err)
}
t.Logf("stopped container %s", name)
})
err = dockerClient.ContainerStart(context.Background(), cont.ID, container.StartOptions{})
require.NoError(t, err, "failed to start mysql container")
containerJSON, err := dockerClient.ContainerInspect(context.Background(), cont.ID)
require.NoError(t, err)
p, ok := containerJSON.NetworkSettings.Ports["3306/tcp"]
if !ok || len(p) == 0 {
require.Fail(t, "failed to get host port mapping from mysql container")
}
mySQLTestContainer := &mySQLTestContainer{
addr: "localhost:" + p[0].HostPort,
username: "root",
password: "secret",
}
uri := mySQLTestContainer.username + ":" + mySQLTestContainer.password + "@tcp(" + mySQLTestContainer.addr + ")/defaultdb?parseTime=true"
err = mysql.SetLogger(log.New(io.Discard, "", 0))
require.NoError(t, err)
goose.SetLogger(goose.NopLogger())
db, err := goose.OpenDBWithDriver("mysql", uri)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
backoffPolicy := backoff.NewExponentialBackOff()
backoffPolicy.MaxElapsedTime = 2 * time.Minute
err = backoff.Retry(
func() error {
return db.Ping()
},
backoffPolicy,
)
require.NoError(t, err, "failed to connect to mysql container")
goose.SetBaseFS(assets.EmbedMigrations)
err = goose.Up(db, assets.MySQLMigrationDir)
require.NoError(t, err)
version, err := goose.GetDBVersion(db)
require.NoError(t, err)
mySQLTestContainer.version = version
return mySQLTestContainer
}
// GetConnectionURI returns the mysql connection uri for the running mysql test container.
func (m *mySQLTestContainer) GetConnectionURI(includeCredentials bool) string {
creds := ""
if includeCredentials {
creds = m.username + ":" + m.password + "@"
}
return creds + "tcp(" + m.addr + ")/defaultdb?parseTime=true"
}
func (m *mySQLTestContainer) GetUsername() string {
return m.username
}
func (m *mySQLTestContainer) GetPassword() string {
return m.password
}
func (m *mySQLTestContainer) CreateSecondary(t testing.TB) error {
return nil
}
func (m *mySQLTestContainer) GetSecondaryConnectionURI(includeCredentials bool) string {
return ""
}
package storage
import (
"context"
"fmt"
"io"
"strings"
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/client"
"github.com/docker/go-connections/nat"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver.
"github.com/oklog/ulid/v2"
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"github.com/openfga/openfga/assets"
)
const (
postgresImage = "postgres:17"
)
type postgresTestContainer struct {
addr string
version int64
username string
password string
replica *postgresReplicaContainer
}
type postgresReplicaContainer struct {
addr string
username string
password string
}
// NewPostgresTestContainer returns an implementation of the DatastoreTestContainer interface
// for Postgres.
func NewPostgresTestContainer() *postgresTestContainer {
return &postgresTestContainer{}
}
func (p *postgresTestContainer) GetDatabaseSchemaVersion() int64 {
return p.version
}
// RunPostgresTestContainer runs a Postgres container, connects to it, and returns a
// bootstrapped implementation of the DatastoreTestContainer interface wired up for the
// Postgres datastore engine.
func (p *postgresTestContainer) RunPostgresTestContainer(t testing.TB) DatastoreTestContainer {
dockerClient, err := client.NewClientWithOpts(
client.FromEnv,
client.WithAPIVersionNegotiation(),
)
require.NoError(t, err)
t.Cleanup(func() {
dockerClient.Close()
})
allImages, err := dockerClient.ImageList(context.Background(), image.ListOptions{
All: true,
})
require.NoError(t, err)
foundPostgresImage := false
AllImages:
for _, image := range allImages {
for _, tag := range image.RepoTags {
if strings.Contains(tag, postgresImage) {
foundPostgresImage = true
break AllImages
}
}
}
if !foundPostgresImage {
t.Logf("Pulling image %s", postgresImage)
reader, err := dockerClient.ImagePull(context.Background(), postgresImage, image.PullOptions{})
require.NoError(t, err)
_, err = io.Copy(io.Discard, reader) // consume the image pull output to make sure it's done
require.NoError(t, err)
}
containerCfg := container.Config{
Env: []string{
"POSTGRES_DB=defaultdb",
"POSTGRES_PASSWORD=secret",
},
ExposedPorts: nat.PortSet{
nat.Port("5432/tcp"): {},
},
Image: postgresImage,
Cmd: []string{
"postgres",
"-c", "wal_level=replica",
"-c", "max_wal_senders=3",
"-c", "max_replication_slots=3",
"-c", "wal_keep_size=64MB",
"-c", "hot_standby=on",
},
}
hostCfg := container.HostConfig{
AutoRemove: true,
PublishAllPorts: true,
ExtraHosts: []string{"host.docker.internal:host-gateway"},
}
name := "postgres-" + ulid.Make().String()
cont, err := dockerClient.ContainerCreate(context.Background(), &containerCfg, &hostCfg, nil, nil, name)
require.NoError(t, err, "failed to create postgres docker container")
t.Cleanup(func() {
t.Logf("stopping container %s", name)
timeoutSec := 5
err := dockerClient.ContainerStop(context.Background(), cont.ID, container.StopOptions{Timeout: &timeoutSec})
if err != nil && !errdefs.IsNotFound(err) {
t.Logf("failed to stop postgres container: %v", err)
}
t.Logf("stopped container %s", name)
})
err = dockerClient.ContainerStart(context.Background(), cont.ID, container.StartOptions{})
require.NoError(t, err, "failed to start postgres container")
containerJSON, err := dockerClient.ContainerInspect(context.Background(), cont.ID)
require.NoError(t, err)
m, ok := containerJSON.NetworkSettings.Ports["5432/tcp"]
if !ok || len(m) == 0 {
require.Fail(t, "failed to get host port mapping from postgres container")
}
pgTestContainer := &postgresTestContainer{
addr: "localhost:" + m[0].HostPort,
username: "postgres",
password: "secret",
}
uri := fmt.Sprintf("postgres://%s:%s@%s/defaultdb?sslmode=disable", pgTestContainer.username, pgTestContainer.password, pgTestContainer.addr)
goose.SetLogger(goose.NopLogger())
db, err := goose.OpenDBWithDriver("pgx", uri)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
backoffPolicy := backoff.NewExponentialBackOff()
backoffPolicy.MaxElapsedTime = 30 * time.Second
err = backoff.Retry(
func() error {
return db.Ping()
},
backoffPolicy,
)
require.NoError(t, err, "failed to connect to postgres container")
goose.SetBaseFS(assets.EmbedMigrations)
err = goose.Up(db, assets.PostgresMigrationDir)
require.NoError(t, err)
version, err := goose.GetDBVersion(db)
require.NoError(t, err)
pgTestContainer.version = version
return pgTestContainer
}
// GetConnectionURI returns the postgres connection uri for the running postgres test container.
func (p *postgresTestContainer) GetConnectionURI(includeCredentials bool) string {
creds := ""
if includeCredentials {
creds = fmt.Sprintf("%s:%s@", p.username, p.password)
}
return fmt.Sprintf(
"postgres://%s%s/%s?sslmode=disable",
creds,
p.addr,
"defaultdb",
)
}
func (p *postgresTestContainer) GetUsername() string {
return p.username
}
func (p *postgresTestContainer) GetPassword() string {
return p.password
}
// CreateSecondary creates a secondary PostgreSQL container.
func (p *postgresTestContainer) CreateSecondary(t testing.TB) error {
dockerClient, err := client.NewClientWithOpts(
client.FromEnv,
client.WithAPIVersionNegotiation(),
)
require.NoError(t, err)
t.Cleanup(func() {
dockerClient.Close()
})
// Configure the master for replication.
masterContainerID, err := p.getMasterContainerID(dockerClient)
require.NoError(t, err)
err = p.configureMasterForReplication(t, dockerClient, masterContainerID)
require.NoError(t, err)
// Wait for the master to be configured.
time.Sleep(3 * time.Second)
// Extract host and port from master for basebackup.
masterHost := "host.docker.internal"
masterPort := strings.Split(p.addr, ":")[1]
// Use standard PostgreSQL approach with docker-entrypoint-initdb.d.
containerCfg := container.Config{
Env: []string{
"POSTGRES_DB=defaultdb",
"POSTGRES_PASSWORD=secret",
"PGPASSWORD=secret",
"POSTGRES_INITDB_ARGS=--auth-host=trust",
"POSTGRES_MASTER_HOST=" + masterHost,
"POSTGRES_MASTER_PORT=" + masterPort,
},
ExposedPorts: nat.PortSet{
nat.Port("5432/tcp"): {},
},
Image: postgresImage,
Entrypoint: []string{"/bin/bash", "-c"},
Cmd: []string{fmt.Sprintf(`
set -e
export PGPASSWORD=secret
echo "Initializing PostgreSQL replica..."
# Wait for master to be ready
until pg_isready -h %s -p %s -U postgres; do
echo "Waiting for master..."
sleep 2
done
echo "Master ready, creating base backup..."
# Remove default PGDATA content
rm -rf $PGDATA/*
# Create base backup
pg_basebackup -h %s -p %s -U postgres -D $PGDATA -Fp -Xs -P -R
# Configure as replica
echo "hot_standby = on" >> $PGDATA/postgresql.conf
touch $PGDATA/standby.signal
echo "Starting PostgreSQL replica..."
exec docker-entrypoint.sh postgres -c hot_standby=on -c wal_level=replica
`, masterHost, masterPort, masterHost, masterPort)},
}
hostCfg := container.HostConfig{
AutoRemove: true,
PublishAllPorts: true,
ExtraHosts: []string{"host.docker.internal:host-gateway"},
}
name := "postgres-replica-" + ulid.Make().String()
cont, err := dockerClient.ContainerCreate(context.Background(), &containerCfg, &hostCfg, nil, nil, name)
require.NoError(t, err, "failed to create postgres replica docker container")
t.Cleanup(func() {
t.Logf("stopping replica container %s", name)
timeoutSec := 5
err := dockerClient.ContainerStop(context.Background(), cont.ID, container.StopOptions{Timeout: &timeoutSec})
if err != nil && !errdefs.IsNotFound(err) {
t.Logf("failed to stop postgres replica container: %v", err)
}
t.Logf("stopped replica container %s", name)
})
err = dockerClient.ContainerStart(context.Background(), cont.ID, container.StartOptions{})
require.NoError(t, err, "failed to start postgres replica container")
containerJSON, err := dockerClient.ContainerInspect(context.Background(), cont.ID)
require.NoError(t, err)
m, ok := containerJSON.NetworkSettings.Ports["5432/tcp"]
if !ok || len(m) == 0 {
require.Fail(t, "failed to get host port mapping from postgres replica container")
}
p.replica = &postgresReplicaContainer{
addr: "localhost:" + m[0].HostPort,
username: "postgres",
password: "secret",
}
// Wait for replica to be ready and synchronized.
err = p.waitForReplicaSync(t)
require.NoError(t, err, "failed to sync replica")
return nil
}
// getMasterContainerID finds the master container ID.
func (p *postgresTestContainer) getMasterContainerID(dockerClient *client.Client) (string, error) {
containers, err := dockerClient.ContainerList(context.Background(), container.ListOptions{})
if err != nil {
return "", err
}
for _, cont := range containers {
for _, name := range cont.Names {
if strings.Contains(name, "postgres-") && !strings.Contains(name, "replica") && !strings.Contains(name, "basebackup") {
return cont.ID, nil
}
}
}
return "", fmt.Errorf("master container not found")
}
// configureMasterForReplication configures the master to accept replication connections.
func (p *postgresTestContainer) configureMasterForReplication(t testing.TB, dockerClient *client.Client, masterContainerID string) error {
// Configuration for streaming replication - only pg_hba.conf and reload.
commands := [][]string{
{"sh", "-c", "echo 'host replication postgres all trust' >> /var/lib/postgresql/data/pg_hba.conf"},
{"psql", "-U", "postgres", "-d", "defaultdb", "-c", "SELECT pg_reload_conf()"},
}
for _, cmd := range commands {
execConfig := container.ExecOptions{
Cmd: cmd,
}
exec, err := dockerClient.ContainerExecCreate(context.Background(), masterContainerID, execConfig)
if err != nil {
return fmt.Errorf("failed to create exec for command %v: %w", cmd, err)
}
err = dockerClient.ContainerExecStart(context.Background(), exec.ID, container.ExecStartOptions{})
if err != nil {
return fmt.Errorf("failed to execute command %v: %w", cmd, err)
}
// Wait for command to complete.
inspect, err := dockerClient.ContainerExecInspect(context.Background(), exec.ID)
if err != nil {
return fmt.Errorf("failed to inspect exec %v: %w", cmd, err)
}
if inspect.ExitCode != 0 {
t.Logf("Command %v completed with exit code %d", cmd, inspect.ExitCode)
}
}
return nil
}
// waitForReplicaSync waits for the replica to be synchronized with the master.
func (p *postgresTestContainer) waitForReplicaSync(t testing.TB) error {
uri := fmt.Sprintf("postgres://%s:%s@%s/defaultdb?sslmode=disable", p.replica.username, p.replica.password, p.replica.addr)
backoffPolicy := backoff.NewExponentialBackOff()
backoffPolicy.MaxElapsedTime = 120 * time.Second // Increase to 2 minutes for initialization.
backoffPolicy.InitialInterval = 2 * time.Second // Start with 2 seconds.
backoffPolicy.MaxInterval = 10 * time.Second // Cap at 10 seconds.
return backoff.Retry(
func() error {
db, err := goose.OpenDBWithDriver("pgx", uri)
if err != nil {
t.Logf("Connection to replica failed (expected during initialization): %v", err)
return fmt.Errorf("failed to connect to replica: %w", err)
}
defer db.Close()
// Check that replica is in recovery mode (standby)
var inRecovery bool
err = db.QueryRow("SELECT pg_is_in_recovery()").Scan(&inRecovery)
if err != nil {
t.Logf("Failed to check recovery status (replica may still be initializing): %v", err)
return fmt.Errorf("failed to check recovery status: %w", err)
}
if !inRecovery {
return fmt.Errorf("replica is not in recovery mode")
}
// Check that replica is receiving WAL
var replicaLSN *string
err = db.QueryRow("SELECT pg_last_wal_receive_lsn()").Scan(&replicaLSN)
if err != nil {
t.Logf("Failed to get replica LSN: %v", err)
return fmt.Errorf("failed to get replica LSN: %w", err)
}
if replicaLSN == nil || *replicaLSN == "" {
return fmt.Errorf("replica has not received any WAL yet")
}
t.Logf("Replica is synchronized and receiving WAL at LSN: %s", *replicaLSN)
return nil
},
backoffPolicy,
)
}
// GetSecondaryConnectionURI returns the connection URI for the read replica.
func (p *postgresTestContainer) GetSecondaryConnectionURI(includeCredentials bool) string {
if p.replica == nil {
return ""
}
creds := ""
if includeCredentials {
creds = fmt.Sprintf("%s:%s@", p.replica.username, p.replica.password)
}
return fmt.Sprintf(
"postgres://%s%s/%s?sslmode=disable",
creds,
p.replica.addr,
"defaultdb",
)
}
package storage
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/pressly/goose/v3"
"github.com/stretchr/testify/require"
"github.com/openfga/openfga/assets"
)
type sqliteTestContainer struct {
path string
version int64
}
// NewSqliteTestContainer returns an implementation of the DatastoreTestContainer interface
// for SQLite.
func NewSqliteTestContainer() *sqliteTestContainer {
return &sqliteTestContainer{}
}
func (m *sqliteTestContainer) GetDatabaseSchemaVersion() int64 {
return m.version
}
// RunSqliteTestDatabase creates a sqlite database file and returns a
// bootstrapped implementation of the DatastoreTestContainer interface wired up for the
// sqlite datastore engine.
func (m *sqliteTestContainer) RunSqliteTestDatabase(t testing.TB) DatastoreTestContainer {
dbDir, err := os.MkdirTemp("", "openfga-test-sqlite-*")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, os.RemoveAll(dbDir)) })
m.path = filepath.Join(dbDir, "database.db")
uri := m.GetConnectionURI(true)
goose.SetLogger(goose.NopLogger())
db, err := goose.OpenDBWithDriver("sqlite", uri)
require.NoError(t, err)
defer db.Close()
goose.SetBaseFS(assets.EmbedMigrations)
err = goose.Up(db, assets.SqliteMigrationDir)
require.NoError(t, err)
version, err := goose.GetDBVersion(db)
require.NoError(t, err)
m.version = version
err = db.Close()
require.NoError(t, err)
return m
}
// GetConnectionURI returns the sqlite connection uri for the sqlite test database.
func (m *sqliteTestContainer) GetConnectionURI(includeCredentials bool) string {
return fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=busy_timeout(100)", m.path)
}
func (m *sqliteTestContainer) GetUsername() string {
return ""
}
func (m *sqliteTestContainer) GetPassword() string {
return ""
}
func (m *sqliteTestContainer) CreateSecondary(t testing.TB) error {
return nil
}
func (m *sqliteTestContainer) GetSecondaryConnectionURI(includeCredentials bool) string {
return ""
}
// Package storage contains containers that can be used to test all available data stores.
package storage
import (
"testing"
)
// DatastoreTestContainer represents a runnable container for testing specific datastore engines.
type DatastoreTestContainer interface {
// GetConnectionURI returns a connection string to the datastore instance running inside
// the container.
GetConnectionURI(includeCredentials bool) string
// GetDatabaseSchemaVersion returns the last migration applied (e.g. 3) when the container was created
GetDatabaseSchemaVersion() int64
GetUsername() string
GetPassword() string
// CreateSecondary creates a secondary datastore if supported.
// Returns an error if the operation fails or if the datastore doesn't support secondary datastores.
CreateSecondary(t testing.TB) error
// GetSecondaryConnectionURI returns the connection URI for the secondary datastore if one exists.
// Returns an empty string if no secondary datastore exists.
GetSecondaryConnectionURI(includeCredentials bool) string
}
type memoryTestContainer struct{}
func (m memoryTestContainer) GetConnectionURI(includeCredentials bool) string {
return ""
}
func (m memoryTestContainer) GetUsername() string {
return ""
}
func (m memoryTestContainer) GetPassword() string {
return ""
}
func (m memoryTestContainer) GetDatabaseSchemaVersion() int64 {
return 1
}
func (m memoryTestContainer) CreateSecondary(t testing.TB) error {
return nil
}
func (m memoryTestContainer) GetSecondaryConnectionURI(includeCredentials bool) string {
return ""
}
// RunDatastoreTestContainer constructs and runs a specific DatastoreTestContainer for the provided
// datastore engine. If applicable, it also runs all existing database migrations.
// The resources used by the test engine will be cleaned up after the test has finished.
func RunDatastoreTestContainer(t testing.TB, engine string) DatastoreTestContainer {
switch engine {
case "mysql":
return NewMySQLTestContainer().RunMySQLTestContainer(t)
case "postgres":
return NewPostgresTestContainer().RunPostgresTestContainer(t)
case "memory":
return memoryTestContainer{}
case "sqlite":
return NewSqliteTestContainer().RunSqliteTestDatabase(t)
default:
t.Fatalf("unsupported datastore engine: %q", engine)
return nil
}
}
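// Illustrative sketch (not part of the original source): a test helper that
// boots a datastore container and reads its connection URI; cleanup is
// registered on t automatically. The engine must be "mysql", "postgres",
// "memory", or "sqlite".
func exampleDatastoreContainerUsage(t *testing.T) {
	c := RunDatastoreTestContainer(t, "postgres")
	uri := c.GetConnectionURI(true) // e.g. "postgres://postgres:secret@localhost:<port>/defaultdb?sslmode=disable"
	_ = uri                         // open the datastore under test with this URI
}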
// Package testutils contains code that is useful in tests.
package testutils
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"sort"
"strconv"
"strings"
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/go-retryablehttp"
"github.com/oklog/ulid/v2"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
grpcbackoff "google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
healthv1pb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
parser "github.com/openfga/language/pkg/go/transformer"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/tuple"
)
const (
AllChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
)
var (
TupleCmpTransformer = cmp.Transformer("Sort", func(in []*openfgav1.Tuple) []*openfgav1.Tuple {
out := append([]*openfgav1.Tuple(nil), in...) // Copy input to avoid mutating it
sort.SliceStable(out, func(i, j int) bool {
if out[i].GetKey().GetObject() != out[j].GetKey().GetObject() {
return out[i].GetKey().GetObject() < out[j].GetKey().GetObject()
}
if out[i].GetKey().GetRelation() != out[j].GetKey().GetRelation() {
return out[i].GetKey().GetRelation() < out[j].GetKey().GetRelation()
}
if out[i].GetKey().GetUser() != out[j].GetKey().GetUser() {
return out[i].GetKey().GetUser() < out[j].GetKey().GetUser()
}
return false // Equal elements must compare as not-less to satisfy sort's contract.
})
return out
})
TupleKeyCmpTransformer = cmp.Transformer("Sort", func(in []*openfgav1.TupleKey) []*openfgav1.TupleKey {
out := append([]*openfgav1.TupleKey(nil), in...) // Copy input to avoid mutating it
sort.SliceStable(out, func(i, j int) bool {
if out[i].GetObject() != out[j].GetObject() {
return out[i].GetObject() < out[j].GetObject()
}
if out[i].GetRelation() != out[j].GetRelation() {
return out[i].GetRelation() < out[j].GetRelation()
}
if out[i].GetUser() != out[j].GetUser() {
return out[i].GetUser() < out[j].GetUser()
}
return false // Equal elements must compare as not-less to satisfy sort's contract.
})
return out
})
)
func ConvertTuplesToTupleKeys(input []*openfgav1.Tuple) []*openfgav1.TupleKey {
converted := make([]*openfgav1.TupleKey, len(input))
for i := range input {
converted[i] = input[i].GetKey()
}
return converted
}
func ConvertTuplesKeysToTuples(input []*openfgav1.TupleKey) []*openfgav1.Tuple {
converted := make([]*openfgav1.Tuple, len(input))
for i := range input {
converted[i] = &openfgav1.Tuple{Key: tuple.NewTupleKey(input[i].GetObject(), input[i].GetRelation(), input[i].GetUser())}
}
return converted
}
// Shuffle returns a copy of the input with the order of elements randomized.
func Shuffle(arr []*openfgav1.TupleKey) []*openfgav1.TupleKey {
// copy array to avoid data races :(
copied := make([]*openfgav1.TupleKey, len(arr))
for i := range arr {
copied[i] = tuple.NewTupleKeyWithCondition(arr[i].GetObject(),
arr[i].GetRelation(),
arr[i].GetUser(),
arr[i].GetCondition().GetName(),
arr[i].GetCondition().GetContext(),
)
}
rand.Shuffle(len(copied), func(i, j int) {
copied[i], copied[j] = copied[j], copied[i]
})
return copied
}
func CreateRandomString(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = AllChars[rand.Intn(len(AllChars))]
}
return string(b)
}
func MustNewStruct(t require.TestingT, v map[string]interface{}) *structpb.Struct {
conditionContext, err := structpb.NewStruct(v)
require.NoError(t, err)
return conditionContext
}
// MakeSliceWithGenerator generates a slice of length 'n' and populates the contents
// with values based on the generator provided.
func MakeSliceWithGenerator[T any](n uint64, generator func(n uint64) any) []T {
s := make([]T, 0, n)
for i := uint64(0); i < n; i++ {
s = append(s, generator(i).(T))
}
return s
}
// NumericalStringGenerator generates a string representation of the provided
// uint value.
func NumericalStringGenerator(n uint64) any {
return strconv.FormatUint(n, 10)
}
func MakeStringWithRuneset(n uint64, runeSet []rune) string {
var sb strings.Builder
for i := uint64(0); i < n; i++ {
sb.WriteRune(runeSet[rand.Intn(len(runeSet))])
}
return sb.String()
}
// MustTransformDSLToProtoWithID interprets the provided string s as an FGA model and
// attempts to parse it using the official OpenFGA language parser. The model returned
// includes an auto-generated model id which assists with producing models for testing
// purposes.
func MustTransformDSLToProtoWithID(s string) *openfgav1.AuthorizationModel {
model := parser.MustTransformDSLToProto(s)
model.Id = ulid.Make().String()
return model
}
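// Illustrative sketch (not part of the original source): building a model for a
// test from the DSL; the returned model carries a freshly generated ULID id.
func exampleTestModel() *openfgav1.AuthorizationModel {
	return MustTransformDSLToProtoWithID(`model
  schema 1.1

type user

type document
  relations
    define viewer: [user]`)
}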
// CreateGrpcConnection creates a grpc connection to an address and closes it when the test ends.
func CreateGrpcConnection(t *testing.T, grpcAddress string, opts ...grpc.DialOption) *grpc.ClientConn {
t.Helper()
defaultOptions := []grpc.DialOption{
grpc.WithConnectParams(grpc.ConnectParams{Backoff: grpcbackoff.DefaultConfig}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
}
defaultOptions = append(defaultOptions, opts...)
// nolint:staticcheck // ignoring gRPC deprecations
conn, err := grpc.Dial(
grpcAddress, defaultOptions...,
)
require.NoError(t, err)
t.Cleanup(func() {
conn.Close()
})
return conn
}
// EnsureServiceHealthy is a test helper that ensures that a service's grpc and http health endpoints are responding OK.
// If the http address is empty, it doesn't check the http health endpoint.
// If the service doesn't respond healthy in 30 seconds it fails the test.
func EnsureServiceHealthy(t testing.TB, grpcAddr, httpAddr string, transportCredentials credentials.TransportCredentials) {
t.Helper()
creds := insecure.NewCredentials()
if transportCredentials != nil {
creds = transportCredentials
}
dialOpts := []grpc.DialOption{
grpc.WithTransportCredentials(creds),
grpc.WithConnectParams(grpc.ConnectParams{Backoff: grpcbackoff.DefaultConfig}),
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
t.Log("creating connection to address", grpcAddr)
// nolint:staticcheck // ignoring gRPC deprecations
conn, err := grpc.DialContext(
ctx,
grpcAddr,
dialOpts...,
)
require.NoError(t, err, "error creating grpc connection to server")
t.Cleanup(func() {
conn.Close()
})
client := healthv1pb.NewHealthClient(conn)
policy := backoff.NewExponentialBackOff()
policy.MaxElapsedTime = 30 * time.Second
err = backoff.Retry(func() error {
resp, err := client.Check(context.Background(), &healthv1pb.HealthCheckRequest{
Service: openfgav1.OpenFGAService_ServiceDesc.ServiceName,
})
if err != nil {
t.Log(time.Now(), "not serving yet at address", grpcAddr, err)
return err
}
if resp.GetStatus() != healthv1pb.HealthCheckResponse_SERVING {
t.Log(time.Now(), resp.GetStatus())
return errors.New("not serving")
}
return nil
}, policy)
require.NoError(t, err, "server did not reach healthy status")
if httpAddr != "" {
resp, err := retryablehttp.Get(fmt.Sprintf("http://%s/healthz", httpAddr))
require.NoError(t, err, "http endpoint not healthy")
t.Cleanup(func() {
err := resp.Body.Close()
require.NoError(t, err)
})
require.Equal(t, http.StatusOK, resp.StatusCode, "unexpected status code received from server")
}
}
// MustDefaultConfigWithRandomPorts returns default server config but with random ports for the grpc and http addresses
// and with the playground, tracing and metrics turned off.
// This function may panic if somehow a random port cannot be chosen.
func MustDefaultConfigWithRandomPorts() *serverconfig.Config {
config := serverconfig.MustDefaultConfig()
config.Experimentals = append(config.Experimentals, "enable-check-optimizations", "enable-list-objects-optimizations")
httpPort, httpPortReleaser := TCPRandomPort()
defer httpPortReleaser()
grpcPort, grpcPortReleaser := TCPRandomPort()
defer grpcPortReleaser()
config.GRPC.Addr = fmt.Sprintf("localhost:%d", grpcPort)
config.HTTP.Addr = fmt.Sprintf("localhost:%d", httpPort)
return config
}
// TCPRandomPort tries to find a random TCP Port. If it can't find one, it panics. Else, it returns the port and a function that releases the port.
// It is the responsibility of the caller to call the release function right before trying to listen on the given port.
func TCPRandomPort() (int, func()) {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic(err)
}
return l.Addr().(*net.TCPAddr).Port, func() {
l.Close()
}
}
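// Illustrative sketch (not part of the original source): honoring the
// TCPRandomPort contract by releasing the reserved port immediately before
// listening on it.
func exampleTCPRandomPort() (net.Listener, error) {
	port, release := TCPRandomPort()
	release() // free the port right before binding to it
	return net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
}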
// Package tuple contains code to manipulate tuples and errors related to tuples.
package tuple
import (
"fmt"
"regexp"
"strings"
"google.golang.org/protobuf/types/known/structpb"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
type Tuple openfgav1.TupleKey
func (t *Tuple) GetObject() string {
return (*openfgav1.TupleKey)(t).GetObject()
}
func (t *Tuple) GetRelation() string {
return (*openfgav1.TupleKey)(t).GetRelation()
}
func (t *Tuple) GetUser() string {
return (*openfgav1.TupleKey)(t).GetUser()
}
func (t *Tuple) String() string {
tk := (*openfgav1.TupleKey)(t)
return tk.GetObject() +
"#" +
tk.GetRelation() +
"@" +
tk.GetUser()
}
func From(tk *openfgav1.TupleKey) *Tuple {
return (*Tuple)(tk)
}
type TupleKeys []*openfgav1.TupleKey
// Len is a method that is required to implement the
// sort.Interface interface. Len returns the number
// of elements in the slice.
func (tk TupleKeys) Len() int {
return len(tk)
}
// Less is a method that is required to implement the
// sort.Interface interface. Less returns true when the
// value at index i is less than the value at index j.
// Tuples are compared first by their object, then their
// relation, then their user, and finally their condition.
// If Less(i, j) returns false and Less(j, i) returns false,
// then the tuples are equal.
func (tk TupleKeys) Less(i, j int) bool {
if tk[i].GetObject() != tk[j].GetObject() {
return tk[i].GetObject() < tk[j].GetObject()
}
if tk[i].GetRelation() != tk[j].GetRelation() {
return tk[i].GetRelation() < tk[j].GetRelation()
}
if tk[i].GetUser() != tk[j].GetUser() {
return tk[i].GetUser() < tk[j].GetUser()
}
cond1 := tk[i].GetCondition()
cond2 := tk[j].GetCondition()
if (cond1 != nil || cond2 != nil) && cond1.GetName() != cond2.GetName() {
return cond1.GetName() < cond2.GetName()
}
// Note: conditions also optionally have context structs, but we aren't sorting by context
return false // Equal tuples compare as not-less, per the contract documented above.
}
// Swap is a method that is required to implement the
// sort.Interface interface. Swap exchanges the values
// at slice indexes i and j.
func (tk TupleKeys) Swap(i, j int) {
tk[i], tk[j] = tk[j], tk[i]
}
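// Illustrative sketch (not part of the original source): because TupleKeys
// implements sort.Interface, a slice can be ordered with sort.Sort(TupleKeys(keys)).
// Shown here without importing sort, by comparing two keys directly via Less.
func exampleTupleKeysOrdering() bool {
	keys := TupleKeys{
		NewTupleKey("document:2", "viewer", "user:anne"),
		NewTupleKey("document:1", "viewer", "user:anne"),
	}
	return keys.Less(1, 0) // true: document:1 sorts before document:2
}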
type TupleWithCondition interface {
TupleWithoutCondition
GetCondition() *openfgav1.RelationshipCondition
}
type TupleWithoutCondition interface {
GetUser() string
GetObject() string
GetRelation() string
String() string
}
type UserType string
const (
User UserType = "user"
UserSet UserType = "userset"
Wildcard = "*"
)
var (
userIDRegex = regexp.MustCompile(`^[^:#\s]+$`)
objectRegex = regexp.MustCompile(`^[^:#\s]+:[^#:\s]+$`)
userSetRegex = regexp.MustCompile(`^[^:#\s]+:[^#:*\s]+#[^:#*\s]+$`)
relationRegex = regexp.MustCompile(`^[^:#@\s]+$`)
)
func ConvertCheckRequestTupleKeyToTupleKey(tk *openfgav1.CheckRequestTupleKey) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
func ConvertAssertionTupleKeyToTupleKey(tk *openfgav1.AssertionTupleKey) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
func ConvertReadRequestTupleKeyToTupleKey(tk *openfgav1.ReadRequestTupleKey) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
func TupleKeyToTupleKeyWithoutCondition(tk *openfgav1.TupleKey) *openfgav1.TupleKeyWithoutCondition {
return &openfgav1.TupleKeyWithoutCondition{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
func TupleKeyWithoutConditionToTupleKey(tk *openfgav1.TupleKeyWithoutCondition) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: tk.GetObject(),
Relation: tk.GetRelation(),
User: tk.GetUser(),
}
}
func TupleKeysWithoutConditionToTupleKeys(tks ...*openfgav1.TupleKeyWithoutCondition) []*openfgav1.TupleKey {
converted := make([]*openfgav1.TupleKey, 0, len(tks))
for _, tk := range tks {
converted = append(converted, TupleKeyWithoutConditionToTupleKey(tk))
}
return converted
}
func NewTupleKey(object, relation, user string) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: object,
Relation: relation,
User: user,
}
}
func NewTupleKeyWithCondition(
object, relation, user, conditionName string,
context *structpb.Struct,
) *openfgav1.TupleKey {
return &openfgav1.TupleKey{
Object: object,
Relation: relation,
User: user,
Condition: NewRelationshipCondition(conditionName, context),
}
}
func NewRelationshipCondition(name string, context *structpb.Struct) *openfgav1.RelationshipCondition {
if name == "" {
return nil
}
if context == nil {
return &openfgav1.RelationshipCondition{
Name: name,
Context: &structpb.Struct{},
}
}
return &openfgav1.RelationshipCondition{
Name: name,
Context: context,
}
}
func NewAssertionTupleKey(object, relation, user string) *openfgav1.AssertionTupleKey {
return &openfgav1.AssertionTupleKey{
Object: object,
Relation: relation,
User: user,
}
}
func NewCheckRequestTupleKey(object, relation, user string) *openfgav1.CheckRequestTupleKey {
return &openfgav1.CheckRequestTupleKey{
Object: object,
Relation: relation,
User: user,
}
}
func NewExpandRequestTupleKey(object, relation string) *openfgav1.ExpandRequestTupleKey {
return &openfgav1.ExpandRequestTupleKey{
Object: object,
Relation: relation,
}
}
// ObjectKey returns the canonical key for the provided Object. The ObjectKey of an object
// is the string 'objectType:objectId'.
func ObjectKey(obj *openfgav1.Object) string {
return BuildObject(obj.GetType(), obj.GetId())
}
type UserString = string
// UserProtoToString returns a string from a User proto. Ex: 'user:maria' or 'group:fga#member'. It is
// the opposite of StringToUserProto function.
func UserProtoToString(obj *openfgav1.User) UserString {
switch obj.GetUser().(type) {
case *openfgav1.User_Wildcard:
return obj.GetWildcard().GetType() + ":*"
case *openfgav1.User_Userset:
us := obj.GetUser().(*openfgav1.User_Userset)
return us.Userset.GetType() + ":" + us.Userset.GetId() + "#" + us.Userset.GetRelation()
case *openfgav1.User_Object:
us := obj.GetUser().(*openfgav1.User_Object)
return us.Object.GetType() + ":" + us.Object.GetId()
default:
panic("unsupported type")
}
}
// StringToUserProto returns a User proto from a string. Ex: 'user:maria#member'.
// It is the opposite of UserProtoToString function.
func StringToUserProto(userKey UserString) *openfgav1.User {
userObj, userRel := SplitObjectRelation(userKey)
userObjType, userObjID := SplitObject(userObj)
if userRel == "" && userObjID == "*" {
return &openfgav1.User{User: &openfgav1.User_Wildcard{
Wildcard: &openfgav1.TypedWildcard{
Type: userObjType,
},
}}
}
if userRel == "" {
return &openfgav1.User{User: &openfgav1.User_Object{Object: &openfgav1.Object{
Type: userObjType,
Id: userObjID,
}}}
}
return &openfgav1.User{User: &openfgav1.User_Userset{Userset: &openfgav1.UsersetUser{
Type: userObjType,
Id: userObjID,
Relation: userRel,
}}}
}
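// Illustrative sketch (not part of the original source): the two conversions
// are inverses of each other for object, userset, and wildcard users.
func exampleUserProtoRoundTrip() string {
	u := StringToUserProto("group:fga#member") // a User_Userset
	return UserProtoToString(u)                // "group:fga#member"
}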
// SplitObject splits an object into an objectType (which may include an objectRelation) and an objectID.
// E.g.
// 1. "group:fga" returns "group" and "fga".
// 2. "group#member:fga" returns "group#member" and "fga".
// 3. "anne" returns "" and "anne".
func SplitObject(object string) (string, string) {
switch i := strings.IndexByte(object, ':'); i {
case -1:
return "", object
case len(object) - 1:
return object[0:i], ""
default:
return object[0:i], object[i+1:]
}
}
func BuildObject(objectType, objectID string) string {
return objectType + ":" + objectID
}
// GetObjectRelationAsString returns a string like "object#relation". If there is no relation it returns "object".
func GetObjectRelationAsString(objectRelation *openfgav1.ObjectRelation) string {
if objectRelation.GetRelation() != "" {
return objectRelation.GetObject() + "#" + objectRelation.GetRelation()
}
return objectRelation.GetObject()
}
// SplitObjectRelation splits an object relation string into its object and relation parts. If no relation is present,
// it returns the original string and an empty relation.
func SplitObjectRelation(objectRelation string) (string, string) {
switch i := strings.LastIndexByte(objectRelation, '#'); i {
case -1:
return objectRelation, ""
case len(objectRelation) - 1:
return objectRelation[0:i], ""
default:
return objectRelation[0:i], objectRelation[i+1:]
}
}
// GetType returns the type from a supplied Object identifier or an empty string if the object id does not contain a
// type.
func GetType(objectID string) string {
t, _ := SplitObject(objectID)
return t
}
// GetRelation returns the 'relation' portion of an object relation string (e.g. `object#relation`), which may be empty if the input is malformed
// (or does not contain a relation).
func GetRelation(objectRelation string) string {
_, relation := SplitObjectRelation(objectRelation)
return relation
}
// IsObjectRelation returns true if the given string specifies a valid object and relation.
func IsObjectRelation(userset string) bool {
return GetType(userset) != "" && GetRelation(userset) != ""
}
// ToObjectRelationString formats an object/relation pair as an object#relation string. This is the inverse of
// SplitObjectRelation.
func ToObjectRelationString(object, relation string) string {
return object + "#" + relation
}
// GetUserTypeFromUser returns the type of user (userset or user).
func GetUserTypeFromUser(user string) UserType {
if IsObjectRelation(user) || IsWildcard(user) {
return UserSet
}
return User
}
// TupleKeyToString converts a tuple key into its string representation. It assumes the tupleKey is valid
// (i.e. no forbidden characters).
func TupleKeyToString(tk TupleWithoutCondition) string {
return tk.GetObject() +
"#" +
tk.GetRelation() +
"@" +
tk.GetUser()
}
// TupleKeyWithConditionToString converts a tuple key with condition into its string representation. It assumes the tupleKey is valid
// (i.e. no forbidden characters).
func TupleKeyWithConditionToString(tk TupleWithCondition) string {
var sb strings.Builder
sb.WriteString(TupleKeyToString(tk))
if tk.GetCondition() != nil {
sb.WriteString(" (condition " + tk.GetCondition().GetName() + ")")
}
return sb.String()
}
// IsValidObject determines if a string s is a valid object. A valid object contains exactly one `:` and no `#` or spaces.
func IsValidObject(s string) bool {
return objectRegex.MatchString(s)
}
// IsValidRelation determines if a string s is a valid relation. This means it does not contain any `:`, `#`, or spaces.
func IsValidRelation(s string) bool {
return relationRegex.MatchString(s)
}
// IsValidUser determines if a string is a valid user. A valid user contains at most one `:`, at most one `#` and no spaces.
func IsValidUser(user string) bool {
return user == Wildcard || userIDRegex.MatchString(user) || objectRegex.MatchString(user) || userSetRegex.MatchString(user)
}
// IsWildcard returns true if the string 's' could be interpreted as a typed or untyped wildcard (e.g. '*' or 'type:*').
func IsWildcard(s string) bool {
return s == Wildcard || IsTypedWildcard(s)
}
// IsTypedWildcard returns true if the string 's' is a typed wildcard. A typed wildcard
// has the form 'type:*'.
func IsTypedWildcard(s string) bool {
t, id := SplitObject(s)
return t != "" && id == Wildcard
}
// TypedPublicWildcard returns the string tuple representation for a given object type (ex: "user:*").
func TypedPublicWildcard(objectType string) string {
return BuildObject(objectType, Wildcard)
}
// MustParseTupleString attempts to parse a relationship tuple specified
// in string notation and return the protobuf TupleKey for it. If parsing
// of the string fails this function will panic. It is meant for testing
// purposes.
//
// Given string 'document:1#viewer@user:jon', return the protobuf TupleKey
// for it.
func MustParseTupleString(s string) *openfgav1.TupleKey {
t, err := ParseTupleString(s)
if err != nil {
panic(err)
}
return t
}
func MustParseTupleStrings(tupleStrs ...string) []*openfgav1.TupleKey {
tuples := make([]*openfgav1.TupleKey, 0, len(tupleStrs))
for _, tupleStr := range tupleStrs {
tuples = append(tuples, MustParseTupleString(tupleStr))
}
return tuples
}
// ParseTupleString attempts to parse a relationship tuple specified
// in string notation and return the protobuf TupleKey for it. If parsing
// of the string fails this function returns an err.
//
// Given string 'document:1#viewer@user:jon', return the protobuf TupleKey
// for it or an error.
func ParseTupleString(s string) (*openfgav1.TupleKey, error) {
object, rhs, found := strings.Cut(s, "#")
if !found {
return nil, fmt.Errorf("expected at least one '#' separating the object and relation")
}
if !IsValidObject(object) {
return nil, fmt.Errorf("invalid tuple 'object' field format")
}
relation, user, found := strings.Cut(rhs, "@")
if !found {
return nil, fmt.Errorf("expected at least one '@' separating the relation and user")
}
if !IsValidRelation(relation) {
return nil, fmt.Errorf("invalid tuple 'relation' field format")
}
if !IsValidUser(user) {
return nil, fmt.Errorf("invalid tuple 'user' field format")
}
return &openfgav1.TupleKey{
Object: object,
Relation: relation,
User: user,
}, nil
}
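// Illustrative sketch (not part of the original source): parsing tuple string
// notation and handling the error that malformed input produces.
func exampleParseTupleString() {
	tk, err := ParseTupleString("document:1#viewer@user:jon")
	if err != nil {
		return // missing '#' or '@', or an invalid object/relation/user field
	}
	_ = tk.GetObject()   // "document:1"
	_ = tk.GetRelation() // "viewer"
	_ = tk.GetUser()     // "user:jon"
}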
func ToUserPartsFromObjectRelation(u *openfgav1.ObjectRelation) (string, string, string) {
userObjectType, userObjectID := SplitObject(u.GetObject())
return userObjectType, userObjectID, u.GetRelation()
}
func ToUserParts(user string) (string, string, string) {
userObject, userRelation := SplitObjectRelation(user) // e.g. (person:bob, "") or (group:abc, member) or (person:*, "")
userObjectType, userObjectID := SplitObject(userObject)
return userObjectType, userObjectID, userRelation
}
func FromUserParts(userObjectType, userObjectID, userRelation string) string {
user := userObjectID
if userObjectType != "" {
user = userObjectType + ":" + userObjectID
}
if userRelation != "" {
user = user + "#" + userRelation
}
return user
}
// IsSelfDefining returns true if the tuple is reflexive/self-defining, e.g. document:1#viewer@document:1#viewer.
// See https://github.com/openfga/rfcs/blob/main/20240328-queries-with-usersets.md
func IsSelfDefining(tuple *openfgav1.TupleKey) bool {
userObject, userRelation := SplitObjectRelation(tuple.GetUser())
return tuple.GetRelation() == userRelation && tuple.GetObject() == userObject
}
// UsersetMatchTypeAndRelation returns true if the type and relation of a userset match the inputs.
func UsersetMatchTypeAndRelation(userset, relation, typee string) bool {
userObjectType, _, userRelation := ToUserParts(userset)
return relation == userRelation && typee == userObjectType
}
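// Illustrative sketch (not part of the original source): ToUserParts and
// FromUserParts are inverses for object, userset, and wildcard users.
func exampleUserPartsRoundTrip() string {
	objectType, objectID, relation := ToUserParts("group:abc#member") // ("group", "abc", "member")
	return FromUserParts(objectType, objectID, relation)              // "group:abc#member"
}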
package tuple
import (
"fmt"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
)
// InvalidConditionalTupleError is returned if the tuple's condition is invalid.
type InvalidConditionalTupleError struct {
Cause error
TupleKey TupleWithCondition
}
func (i *InvalidConditionalTupleError) Error() string {
return fmt.Sprintf("Invalid tuple '%s'. Reason: %s", TupleKeyWithConditionToString(i.TupleKey), i.Cause)
}
func (i *InvalidConditionalTupleError) Is(target error) bool {
_, ok := target.(*InvalidConditionalTupleError)
return ok
}
// InvalidTupleError is returned if the tuple is invalid.
type InvalidTupleError struct {
Cause error
TupleKey TupleWithoutCondition
}
func (i *InvalidTupleError) Error() string {
return fmt.Sprintf("Invalid tuple '%s'. Reason: %s", TupleKeyToString(i.TupleKey), i.Cause)
}
func (i *InvalidTupleError) Is(target error) bool {
_, ok := target.(*InvalidTupleError)
return ok
}
type TypeNotFoundError struct {
TypeName string
}
func (i *TypeNotFoundError) Error() string {
return fmt.Sprintf("type '%s' not found", i.TypeName)
}
func (i *TypeNotFoundError) Is(target error) bool {
_, ok := target.(*TypeNotFoundError)
return ok
}
type RelationNotFoundError struct {
TupleKey *openfgav1.TupleKey
Relation string
TypeName string
}
func (i *RelationNotFoundError) Error() string {
msg := fmt.Sprintf("relation '%s#%s' not found", i.TypeName, i.Relation)
if i.TupleKey != nil {
msg += fmt.Sprintf(" for tuple '%s'", TupleKeyToString(i.TupleKey))
}
return msg
}
func (i *RelationNotFoundError) Is(target error) bool {
_, ok := target.(*RelationNotFoundError)
return ok
}
package typesystem
import (
"errors"
"fmt"
"github.com/openfga/openfga/pkg/tuple"
)
var (
// ErrModelNotFound is returned when an authorization model is not found.
ErrModelNotFound = errors.New("authorization model not found")
// ErrDuplicateTypes is returned when an authorization model contains duplicate types.
ErrDuplicateTypes = errors.New("an authorization model cannot contain duplicate types")
// ErrInvalidSchemaVersion is returned for an invalid schema version in the authorization model.
ErrInvalidSchemaVersion = errors.New("invalid schema version")
// ErrInvalidModel is returned when encountering an invalid authorization model.
ErrInvalidModel = errors.New("invalid authorization model encountered")
// ErrRelationUndefined is returned when encountering an undefined relation in the authorization model.
ErrRelationUndefined = errors.New("undefined relation")
// ErrObjectTypeUndefined is returned when encountering an undefined object type in the authorization model.
ErrObjectTypeUndefined = errors.New("undefined object type")
// ErrInvalidUsersetRewrite is returned for an invalid userset rewrite definition.
ErrInvalidUsersetRewrite = errors.New("invalid userset rewrite definition")
// ErrReservedKeywords is returned when using reserved keywords "self" and "this".
ErrReservedKeywords = errors.New("self and this are reserved keywords")
// ErrCycle is returned when a cycle is detected in an authorization model.
// This occurs if an objectType and relation in the model define a rewrite
// rule that is self-referencing through computed relationships.
ErrCycle = errors.New("an authorization model cannot contain a cycle")
// ErrNoEntrypoints is returned when a particular objectType and relation in an authorization
// model are not accessible via a direct edge, for example from another objectType.
ErrNoEntrypoints = errors.New("no entrypoints defined")
// ErrNoEntryPointsLoop is returned when an authorization model contains a cycle
// because at least one objectType and relation returned ErrNoEntrypoints.
ErrNoEntryPointsLoop = errors.New("potential loop")
// ErrNoConditionForRelation is returned when no condition is defined for a relation in the authorization model.
ErrNoConditionForRelation = errors.New("no condition defined for relation")
)
// InvalidTypeError represents an error indicating an invalid object type.
type InvalidTypeError struct {
ObjectType string
Cause error
}
// Error implements the error interface for InvalidTypeError.
func (e *InvalidTypeError) Error() string {
return fmt.Sprintf("the definition of type '%s' is invalid", e.ObjectType)
}
// Unwrap returns the underlying cause of the error.
func (e *InvalidTypeError) Unwrap() error {
return e.Cause
}
// InvalidRelationError represents an error indicating an invalid relation definition.
type InvalidRelationError struct {
ObjectType string
Relation string
Cause error
}
// Error implements the error interface for InvalidRelationError.
func (e *InvalidRelationError) Error() string {
return fmt.Sprintf("the definition of relation '%s' in object type '%s' is invalid: %s", e.Relation, e.ObjectType, e.Cause)
}
// Unwrap returns the underlying cause of the error.
func (e *InvalidRelationError) Unwrap() error {
return e.Cause
}
// ObjectTypeUndefinedError represents an error indicating an undefined object type.
type ObjectTypeUndefinedError struct {
ObjectType string
Err error
}
// Error implements the error interface for ObjectTypeUndefinedError.
func (e *ObjectTypeUndefinedError) Error() string {
return fmt.Sprintf("'%s' is an undefined object type", e.ObjectType)
}
// Unwrap returns the underlying cause of the error.
func (e *ObjectTypeUndefinedError) Unwrap() error {
return e.Err
}
// RelationUndefinedError represents an error indicating an undefined relation.
type RelationUndefinedError struct {
ObjectType string
Relation string
Err error
}
// Error implements the error interface for RelationUndefinedError.
func (e *RelationUndefinedError) Error() string {
if e.ObjectType != "" {
return fmt.Sprintf("'%s#%s' relation is undefined", e.ObjectType, e.Relation)
}
return fmt.Sprintf("'%s' relation is undefined", e.Relation)
}
// Unwrap returns the underlying cause of the error.
func (e *RelationUndefinedError) Unwrap() error {
return e.Err
}
// RelationConditionError represents an error indicating an undefined condition for a relation.
type RelationConditionError struct {
Condition string
Relation string
Err error
}
// Error implements the error interface for RelationConditionError.
func (e *RelationConditionError) Error() string {
return fmt.Sprintf("condition %s is undefined for relation %s", e.Condition, e.Relation)
}
// Unwrap returns the underlying cause of the error.
func (e *RelationConditionError) Unwrap() error {
return e.Err
}
// AssignableRelationError returns an error for an assignable relation with no relation types defined.
func AssignableRelationError(objectType, relation string) error {
return fmt.Errorf("the assignable relation '%s' in object type '%s' must contain at least one relation type", relation, objectType)
}
// NonAssignableRelationError returns an error for a non-assignable relation with a relation type defined.
func NonAssignableRelationError(objectType, relation string) error {
return fmt.Errorf("the non-assignable relation '%s' in object type '%s' should not contain a relation type", relation, objectType)
}
// InvalidRelationTypeError returns an error for an invalid relation type in a relation definition.
func InvalidRelationTypeError(objectType, relation, relatedObjectType, relatedRelation string) error {
relationType := relatedObjectType
if relatedRelation != "" {
relationType = tuple.ToObjectRelationString(relatedObjectType, relatedRelation)
}
return fmt.Errorf("the relation type '%s' on '%s' in object type '%s' is not valid", relationType, relation, objectType)
}
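// Example (illustrative sketch, not part of the package): the typed errors above
// wrap the sentinel errors, so callers can branch with errors.Is and recover
// structured fields with errors.As. The err value here is hypothetical.
func exampleInspectModelError(err error) {
	var relErr *RelationUndefinedError
	if errors.As(err, &relErr) {
		fmt.Printf("undefined relation: %s#%s\n", relErr.ObjectType, relErr.Relation)
	}
	if errors.Is(err, ErrReservedKeywords) {
		fmt.Println("the model uses a reserved keyword as a type or relation name")
	}
}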
package typesystem
import (
"context"
"errors"
"fmt"
"time"
"github.com/oklog/ulid/v2"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/pkg/storage"
)
// TODO there is a duplicate cache of models elsewhere: https://github.com/openfga/openfga/issues/1045
const (
typesystemCacheTTL = 168 * time.Hour // 7 days.
)
type TypesystemResolverFunc func(ctx context.Context, storeID, modelID string) (*TypeSystem, error)
// MemoizedTypesystemResolverFunc does several things.
//
// If given a model ID: validates the model ID, and tries to fetch it from the cache.
// If not found in the cache, fetches from the datastore, validates it, stores in cache, and returns it.
//
// If not given a model ID: fetches the latest model from the datastore, then checks whether that model's ID is in the cache.
// If it is, returns the cached TypeSystem. Otherwise, validates the model, stores it in the cache, and returns it.
func MemoizedTypesystemResolverFunc(datastore storage.AuthorizationModelReadBackend) (TypesystemResolverFunc, func(), error) {
lookupGroup := singleflight.Group{}
// cache holds models that have already been validated.
cache, err := storage.NewInMemoryLRUCache[*TypeSystem]()
if err != nil {
return nil, nil, err
}
return func(ctx context.Context, storeID, modelID string) (*TypeSystem, error) {
ctx, span := tracer.Start(ctx, "resolveTypesystem", trace.WithAttributes(
attribute.String("store_id", storeID),
))
defer func() {
span.SetAttributes(attribute.String("authorization_model_id", modelID))
span.End()
}()
var err error
if modelID != "" {
if _, err := ulid.Parse(modelID); err != nil {
return nil, ErrModelNotFound
}
}
var model *openfgav1.AuthorizationModel
var key string
if modelID == "" {
v, err, _ := lookupGroup.Do("FindLatestAuthorizationModel:"+storeID, func() (interface{}, error) {
return datastore.FindLatestAuthorizationModel(ctx, storeID)
})
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, ErrModelNotFound
}
return nil, fmt.Errorf("failed to FindLatestAuthorizationModel: %w", err)
}
model = v.(*openfgav1.AuthorizationModel)
modelID = model.GetId()
}
key = fmt.Sprintf("%s/%s", storeID, modelID)
item := cache.Get(key)
if item != nil {
return item, nil
}
if model == nil {
v, err, _ := lookupGroup.Do(fmt.Sprintf("ReadAuthorizationModel:%s/%s", storeID, modelID), func() (interface{}, error) {
return datastore.ReadAuthorizationModel(ctx, storeID, modelID)
})
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, ErrModelNotFound
}
return nil, fmt.Errorf("failed to ReadAuthorizationModel: %w", err)
}
model = v.(*openfgav1.AuthorizationModel)
}
typesys, err := NewAndValidate(ctx, model)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidModel, err)
}
cache.Set(key, typesys, typesystemCacheTTL)
return typesys, nil
}, cache.Stop, nil
}
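// Example (illustrative sketch): wiring up the memoized resolver. The datastore
// value is assumed to implement storage.AuthorizationModelReadBackend, and the
// store ID below is a hypothetical placeholder. Passing an empty model ID
// resolves the latest model for the store.
func exampleResolveTypesystem(ctx context.Context, datastore storage.AuthorizationModelReadBackend) error {
	resolve, stop, err := MemoizedTypesystemResolverFunc(datastore)
	if err != nil {
		return err
	}
	defer stop() // releases the underlying cache
	typesys, err := resolve(ctx, "store-id-placeholder", "")
	if err != nil {
		return err
	}
	_ = typesys.GetAuthorizationModelID()
	return nil
}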
package typesystem
import (
"context"
"errors"
"fmt"
"maps"
"reflect"
"slices"
"sort"
"strings"
"sync"
"github.com/emirpasic/gods/sets/hashset"
"go.opentelemetry.io/otel"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/language/pkg/go/graph"
"github.com/openfga/openfga/internal/condition"
"github.com/openfga/openfga/internal/utils"
"github.com/openfga/openfga/pkg/server/config"
serverErrors "github.com/openfga/openfga/pkg/server/errors"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/tuple"
)
var tracer = otel.Tracer("openfga/pkg/typesystem")
type ctxKey string
const (
// SchemaVersion1_0 for the authorization models.
SchemaVersion1_0 string = "1.0"
// SchemaVersion1_1 for the authorization models.
SchemaVersion1_1 string = "1.1"
// SchemaVersion1_2 for the authorization models.
SchemaVersion1_2 string = "1.2"
typesystemCtxKey ctxKey = "typesystem-context-key"
)
// IsSchemaVersionSupported checks if the provided schema version is supported.
func IsSchemaVersionSupported(version string) bool {
switch version {
case SchemaVersion1_1,
SchemaVersion1_2:
return true
default:
return false
}
}
// ContextWithTypesystem creates a copy of the parent context with the provided TypeSystem.
func ContextWithTypesystem(parent context.Context, typesys *TypeSystem) context.Context {
return context.WithValue(parent, typesystemCtxKey, typesys)
}
// TypesystemFromContext returns the TypeSystem from the provided context (if any).
func TypesystemFromContext(ctx context.Context) (*TypeSystem, bool) {
typesys, ok := ctx.Value(typesystemCtxKey).(*TypeSystem)
return typesys, ok
}
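// Example (illustrative sketch): request middleware can attach a resolved
// TypeSystem to the context so downstream handlers avoid re-resolving the
// model. The typesys value is assumed to have been resolved earlier.
func exampleContextRoundTrip(ctx context.Context, typesys *TypeSystem) {
	ctx = ContextWithTypesystem(ctx, typesys)
	if ts, ok := TypesystemFromContext(ctx); ok {
		_ = ts.GetSchemaVersion()
	}
}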
// DirectRelationReference creates a direct RelationReference for the given object type and relation.
func DirectRelationReference(objectType, relation string) *openfgav1.RelationReference {
relationReference := &openfgav1.RelationReference{
Type: objectType,
}
if relation != "" {
relationReference.RelationOrWildcard = &openfgav1.RelationReference_Relation{
Relation: relation,
}
}
return relationReference
}
// WildcardRelationReference creates a RelationReference for a wildcard relation of the given object type.
func WildcardRelationReference(objectType string) *openfgav1.RelationReference {
return &openfgav1.RelationReference{
Type: objectType,
RelationOrWildcard: &openfgav1.RelationReference_Wildcard{
Wildcard: &openfgav1.Wildcard{},
},
}
}
// This creates a Userset representing the special "this" userset.
func This() *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_This{},
}
}
// ComputedUserset creates a Userset representing a computed userset based on the specified relation.
func ComputedUserset(relation string) *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_ComputedUserset{
ComputedUserset: &openfgav1.ObjectRelation{
Relation: relation,
},
},
}
}
// TupleToUserset creates a Userset based on the provided tupleset and computed userset.
func TupleToUserset(tupleset, computedUserset string) *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_TupleToUserset{
TupleToUserset: &openfgav1.TupleToUserset{
Tupleset: &openfgav1.ObjectRelation{
Relation: tupleset,
},
ComputedUserset: &openfgav1.ObjectRelation{
Relation: computedUserset,
},
},
},
}
}
// Union creates a Userset representing the union of the provided child Usersets.
func Union(children ...*openfgav1.Userset) *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Union{
Union: &openfgav1.Usersets{
Child: children,
},
},
}
}
// Intersection creates a new Userset representing the intersection of the provided Usersets.
func Intersection(children ...*openfgav1.Userset) *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Intersection{
Intersection: &openfgav1.Usersets{
Child: children,
},
},
}
}
// Difference creates a new Userset representing the difference between the two Usersets 'base' and 'sub'.
func Difference(base *openfgav1.Userset, sub *openfgav1.Userset) *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Difference{
Difference: &openfgav1.Difference{
Base: base,
Subtract: sub,
},
},
}
}
// ConditionedRelationReference assigns a condition to a given
// RelationReference and returns the modified RelationReference.
func ConditionedRelationReference(rel *openfgav1.RelationReference, condition string) *openfgav1.RelationReference {
rel.Condition = condition
return rel
}
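// Example (illustrative sketch): the builder helpers above compose userset
// rewrites programmatically. This constructs the rewrite for a hypothetical
// relation `define viewer: [user] or editor`: direct assignment unioned with
// a computed userset.
func exampleViewerRewrite() *openfgav1.Userset {
	return Union(
		This(),                    // the directly assignable part, e.g. [user]
		ComputedUserset("editor"), // viewers also include editors
	)
}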
var _ storage.CacheItem = (*TypeSystem)(nil)
// TypeSystem is a wrapper over an [openfgav1.AuthorizationModel].
type TypeSystem struct {
// [objectType] => typeDefinition.
typeDefinitions map[string]*openfgav1.TypeDefinition
// [objectType] => [relationName] => relation.
relations map[string]map[string]*openfgav1.Relation
// [conditionName] => condition.
conditions map[string]*condition.EvaluableCondition
// [objectType] => [relationName] => TTU relation.
ttuRelations map[string]map[string][]*openfgav1.TupleToUserset
computedRelations sync.Map
modelID string
schemaVersion string
authorizationModelGraph *graph.AuthorizationModelGraph
authzWeightedGraph *graph.WeightedAuthorizationModelGraph
}
// GetWeightedGraph returns the weighted authorization model graph for the TypeSystem, if one was built.
func (t *TypeSystem) GetWeightedGraph() *graph.WeightedAuthorizationModelGraph {
return t.authzWeightedGraph
}
// New creates a *TypeSystem from an *openfgav1.AuthorizationModel.
// It assumes that the input model is valid. If you need to run validations, use NewAndValidate.
func New(model *openfgav1.AuthorizationModel) (*TypeSystem, error) {
tds := make(map[string]*openfgav1.TypeDefinition, len(model.GetTypeDefinitions()))
relations := make(map[string]map[string]*openfgav1.Relation, len(model.GetTypeDefinitions()))
ttuRelations := make(map[string]map[string][]*openfgav1.TupleToUserset, len(model.GetTypeDefinitions()))
for _, td := range model.GetTypeDefinitions() {
typeName := td.GetType()
tds[typeName] = td
tdRelations := make(map[string]*openfgav1.Relation, len(td.GetRelations()))
ttuRelations[typeName] = make(map[string][]*openfgav1.TupleToUserset, len(td.GetRelations()))
for relation, rewrite := range td.GetRelations() {
r := &openfgav1.Relation{
Name: relation,
Rewrite: rewrite,
TypeInfo: &openfgav1.RelationTypeInfo{},
}
if metadata, ok := td.GetMetadata().GetRelations()[relation]; ok {
r.TypeInfo.DirectlyRelatedUserTypes = metadata.GetDirectlyRelatedUserTypes()
}
tdRelations[relation] = r
ttuRelations[typeName][relation] = flattenUserset(rewrite)
}
relations[typeName] = tdRelations
}
uncompiledConditions := make(map[string]*condition.EvaluableCondition, len(model.GetConditions()))
for name, cond := range model.GetConditions() {
uncompiledConditions[name] = condition.NewUncompiled(cond).
WithTrackEvaluationCost().
WithMaxEvaluationCost(config.MaxConditionEvaluationCost()).
WithInterruptCheckFrequency(config.DefaultInterruptCheckFrequency)
}
authorizationModelGraph, err := graph.NewAuthorizationModelGraph(model)
if err != nil {
return nil, err
}
if authorizationModelGraph.GetDrawingDirection() != graph.DrawingDirectionListObjects {
// By default this should not happen; however, the check is here in case the default order is changed.
authorizationModelGraph, err = authorizationModelGraph.Reversed()
if err != nil {
return nil, err
}
}
wgb := graph.NewWeightedAuthorizationModelGraphBuilder()
// TODO: this will require a deprecation so that the error is no longer ignored and the nil checks can be removed
weightedGraph, _ := wgb.Build(model)
return &TypeSystem{
modelID: model.GetId(),
schemaVersion: model.GetSchemaVersion(),
typeDefinitions: tds,
relations: relations,
conditions: uncompiledConditions,
ttuRelations: ttuRelations,
authorizationModelGraph: authorizationModelGraph,
authzWeightedGraph: weightedGraph,
}, nil
}
func (t *TypeSystem) CacheEntityType() string {
return "typesystem"
}
// GetAuthorizationModelID returns the ID for the authorization
// model this TypeSystem was constructed for.
func (t *TypeSystem) GetAuthorizationModelID() string {
return t.modelID
}
// GetSchemaVersion returns the schema version associated with the TypeSystem instance.
func (t *TypeSystem) GetSchemaVersion() string {
return t.schemaVersion
}
// GetAllRelations returns a map [objectType] => [relationName] => relation.
func (t *TypeSystem) GetAllRelations() map[string]map[string]*openfgav1.Relation {
return t.relations
}
// GetConditions retrieves a map of condition names to their corresponding
// EvaluableCondition instances within the TypeSystem.
func (t *TypeSystem) GetConditions() map[string]*condition.EvaluableCondition {
return t.conditions
}
// GetTypeDefinition searches for a TypeDefinition in the TypeSystem based on the given objectType string.
func (t *TypeSystem) GetTypeDefinition(objectType string) (*openfgav1.TypeDefinition, bool) {
if typeDefinition, ok := t.typeDefinitions[objectType]; ok {
return typeDefinition, true
}
return nil, false
}
// ResolveComputedRelation traverses the typesystem until finding the final resolution of a computed relationship.
// Subsequent calls to this method are resolved from a cache.
func (t *TypeSystem) ResolveComputedRelation(objectType, relation string) (string, error) {
memoizeKey := fmt.Sprintf("%s-%s", objectType, relation)
if val, ok := t.computedRelations.Load(memoizeKey); ok {
return val.(string), nil
}
rel, err := t.GetRelation(objectType, relation)
if err != nil {
return "", err
}
rewrite := rel.GetRewrite()
switch rewrite.GetUserset().(type) {
case *openfgav1.Userset_ComputedUserset:
return t.ResolveComputedRelation(objectType, rewrite.GetComputedUserset().GetRelation())
case *openfgav1.Userset_This:
t.computedRelations.Store(memoizeKey, relation)
return relation, nil
default:
return "", fmt.Errorf("unsupported rewrite %s", rewrite.String())
}
}
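// Example (illustrative sketch): under a hypothetical model where
// `define viewer: editor` and `define editor: [user]`, the computed chain for
// document#viewer terminates at "editor", the first directly assignable
// relation in the chain.
func exampleResolveComputed(t *TypeSystem) {
	resolved, err := t.ResolveComputedRelation("document", "viewer")
	if err == nil {
		fmt.Println(resolved) // "editor" under the model sketched above
	}
}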
// GetRelations returns all relations in the TypeSystem for a given type.
func (t *TypeSystem) GetRelations(objectType string) (map[string]*openfgav1.Relation, error) {
_, ok := t.GetTypeDefinition(objectType)
if !ok {
return nil, &ObjectTypeUndefinedError{
ObjectType: objectType,
Err: ErrObjectTypeUndefined,
}
}
return t.relations[objectType], nil
}
// GetRelation retrieves a specific Relation from the TypeSystem
// based on the provided objectType and relation strings.
// It can return ErrObjectTypeUndefined and ErrRelationUndefined.
func (t *TypeSystem) GetRelation(objectType, relation string) (*openfgav1.Relation, error) {
relations, err := t.GetRelations(objectType)
if err != nil {
return nil, err
}
r, ok := relations[relation]
if !ok {
return nil, &RelationUndefinedError{
ObjectType: objectType,
Relation: relation,
Err: ErrRelationUndefined,
}
}
return r, nil
}
// GetCondition searches for an EvaluableCondition in the TypeSystem by its name.
func (t *TypeSystem) GetCondition(name string) (*condition.EvaluableCondition, bool) {
if _, ok := t.conditions[name]; !ok {
return nil, false
}
return t.conditions[name], true
}
// GetRelationReferenceAsString returns team#member, or team:*, or an empty string if the input is nil.
func GetRelationReferenceAsString(rr *openfgav1.RelationReference) string {
if rr == nil {
return ""
}
if _, ok := rr.GetRelationOrWildcard().(*openfgav1.RelationReference_Relation); ok {
return fmt.Sprintf("%s#%s", rr.GetType(), rr.GetRelation())
}
if _, ok := rr.GetRelationOrWildcard().(*openfgav1.RelationReference_Wildcard); ok {
return tuple.TypedPublicWildcard(rr.GetType())
}
panic("unexpected relation reference")
}
// GetDirectlyRelatedUserTypes fetches user types directly related to a specified objectType-relation pair.
func (t *TypeSystem) GetDirectlyRelatedUserTypes(objectType, relation string) ([]*openfgav1.RelationReference, error) {
r, err := t.GetRelation(objectType, relation)
if err != nil {
return nil, err
}
return r.GetTypeInfo().GetDirectlyRelatedUserTypes(), nil
}
// DirectlyRelatedUsersets returns a list of the directly related user types that are usersets.
func (t *TypeSystem) DirectlyRelatedUsersets(objectType, relation string) ([]*openfgav1.RelationReference, error) {
refs, err := t.GetDirectlyRelatedUserTypes(objectType, relation)
var usersetRelationReferences []*openfgav1.RelationReference
if err != nil {
return usersetRelationReferences, err
}
for _, ref := range refs {
if ref.GetRelation() != "" {
usersetRelationReferences = append(usersetRelationReferences, ref)
}
}
return usersetRelationReferences, nil
}
// RelationEquals returns true if the two RelationReferences have the same type and the same
// relation, wildcard, or absence of both.
func RelationEquals(a *openfgav1.RelationReference, b *openfgav1.RelationReference) bool {
if a.GetType() != b.GetType() {
return false
}
// Type with no relation or wildcard (e.g. 'user').
if a.GetRelationOrWildcard() == nil && b.GetRelationOrWildcard() == nil {
return true
}
// Typed wildcard (e.g. 'user:*').
if a.GetWildcard() != nil && b.GetWildcard() != nil {
return true
}
return a.GetRelation() != "" && b.GetRelation() != "" && a.GetRelation() == b.GetRelation()
}
// IsDirectlyRelated determines whether the target relation's directly related user types include the source RelationReference.
func (t *TypeSystem) IsDirectlyRelated(target *openfgav1.RelationReference, source *openfgav1.RelationReference) (bool, error) {
relation, err := t.GetRelation(target.GetType(), target.GetRelation())
if err != nil {
return false, err
}
for _, typeRestriction := range relation.GetTypeInfo().GetDirectlyRelatedUserTypes() {
if RelationEquals(source, typeRestriction) {
return true, nil
}
}
return false, nil
}
// UsersetUseWeight2Resolver returns true if the weight-2 resolver can be used to evaluate the given
// userset: the objectType#relation node must not be part of a tuple cycle or a recursive relation,
// and the userset node itself must have weight 1 for the user type.
func (t *TypeSystem) UsersetUseWeight2Resolver(objectType, relation, userType string, userset *openfgav1.RelationReference) bool {
if t.authzWeightedGraph == nil {
return false
}
node, ok := t.authzWeightedGraph.GetNodeByID(tuple.ToObjectRelationString(objectType, relation))
if !ok {
return false
}
if node.IsPartOfTupleCycle() || len(node.GetRecursiveRelation()) > 0 {
// if there is a tuple cycle, we have to go through default resolver (or recursive one)
return false
}
usersetNodeID := tuple.ToObjectRelationString(userset.GetType(), userset.GetRelation())
usersetNode, ok := t.authzWeightedGraph.GetNodeByID(usersetNodeID)
if !ok {
return false
}
// The node itself has to be weight 1 (not 2, because it's the userset node that we are verifying at this point); the edge pointing to it would be weight 2.
weight, ok := usersetNode.GetWeight(userType)
if !ok {
return false
}
return weight == 1
}
// UsersetUseWeight2Resolvers returns true if every provided userset can be resolved with the
// weight-2 resolver and the directly related userset types are all distinct.
// TODO: Deprecate once userset refactor is complete.
func (t *TypeSystem) UsersetUseWeight2Resolvers(objectType, relation, userType string, usersets []*openfgav1.RelationReference) bool {
allowedType := hashset.New()
for _, u := range usersets {
if allowedType.Contains(u.GetType()) {
// If more than one directly related userset shares the same type, we cannot apply the userset
// optimization: the object ID alone no longer disambiguates them, so the relation would need
// to be taken into consideration as well.
return false
}
if !t.UsersetUseWeight2Resolver(objectType, relation, userType, u) {
return false
}
allowedType.Add(u.GetType())
}
return true
}
// TTUUseWeight2Resolver returns true if the weight-2 resolver can be used to evaluate the given TTU:
// every TTU edge matching the tupleset and computed relation must have weight at most 2 for the user type.
func (t *TypeSystem) TTUUseWeight2Resolver(objectType, relation, userType string, ttu *openfgav1.TupleToUserset) bool {
if t.authzWeightedGraph == nil {
return false
}
objRel := tuple.ToObjectRelationString(objectType, relation)
tuplesetRelationKey := tuple.ToObjectRelationString(objectType, ttu.GetTupleset().GetRelation())
computedRelation := ttu.GetComputedUserset().GetRelation()
node, ok := t.authzWeightedGraph.GetNodeByID(objRel)
if !ok {
return false
}
// Verifying the weight here is not enough: the relation from the parent might be weight 2, but we
// do not explicitly know which TTU is being resolved (we are not traversing the weighted graph itself),
// and that TTU could have no weight for the terminal type. We therefore have to fully inspect the
// edges to match the context of what is being resolved.
_, ok = node.GetWeight(userType)
if !ok {
return false
}
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(node)
if !ok {
return false
}
ttuEdges := make([]*graph.WeightedAuthorizationModelEdge, 0)
// Find all TTU edges with a valid weight for the user type,
// but exit immediately if any of them is above weight 2.
for len(ttuEdges) == 0 {
innerEdges := make([]*graph.WeightedAuthorizationModelEdge, 0)
for _, edge := range edges {
// The edge is a rewrite into a set operator, so we have to inspect each edge of the operator node.
if edge.GetEdgeType() == graph.RewriteEdge {
operationalEdges, ok := t.authzWeightedGraph.GetEdgesFromNode(edge.GetTo())
if !ok {
return false
}
innerEdges = append(innerEdges, operationalEdges...)
}
// a TuplesetRelation may have multiple parents and these need to be visited to ensure their weight does not
// exceed weight 2
if edge.GetEdgeType() == graph.TTUEdge &&
edge.GetTuplesetRelation() == tuplesetRelationKey &&
strings.HasSuffix(edge.GetTo().GetUniqueLabel(), "#"+computedRelation) {
w, ok := edge.GetWeight(userType)
if ok {
if w > 2 {
return false
}
ttuEdges = append(ttuEdges, edge)
}
}
}
if len(innerEdges) == 0 {
break
}
edges = innerEdges
}
return len(ttuEdges) != 0
}
// TTUUseRecursiveResolver returns true if the fast path can be applied to the user/relation.
// For it to return true, all of these conditions must hold:
// 1. Node[objectType#relation].weights[userType] = infinite
// 2. Node[objectType#relation].RecursiveRelation = objectType#relation
// 3. Node[objectType#relation].IsPartOfTupleCycle == false
// 4. Node[objectType#relation] has only 1 edge, and it's to an OR node
// 5. The OR node has one or more TTU edges with weight infinite for the terminal type, and the computed relation for the TTU is the same
// 6. Any other edge coming out of the OR node that has a weight for the terminal type must be weight 1
func (t *TypeSystem) TTUUseRecursiveResolver(objectType, relation, userType string, ttu *openfgav1.TupleToUserset) bool {
if t.authzWeightedGraph == nil {
return false
}
objRel := tuple.ToObjectRelationString(objectType, relation)
objRelNode, ok := t.authzWeightedGraph.GetNodeByID(objRel)
if !ok {
return false
}
w, ok := objRelNode.GetWeight(userType)
if !ok || w != graph.Infinite {
return false
}
// if we are not in the presence of a recursive relation or it is part of a tuple cycle, return false
if objRelNode.GetRecursiveRelation() != objRelNode.GetUniqueLabel() || objRelNode.IsPartOfTupleCycle() {
return false
}
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(objRelNode)
if !ok {
return false
}
recursiveTTUFound := false
for len(edges) != 0 {
innerEdges := make([]*graph.WeightedAuthorizationModelEdge, 0)
for _, edge := range edges {
w, ok := edge.GetWeight(userType)
if !ok {
// if the edge does not have a weight for the terminal type, we can skip it
continue
}
// if the edge is part of the recursive path
if edge.GetRecursiveRelation() == objRel {
// if the edge is a TTUEdge and points to the original node, and we haven't found any other recursive edge
if edge.GetEdgeType() == graph.TTUEdge && edge.GetTo() == objRelNode && !recursiveTTUFound {
recursiveTTUFound = true
continue
}
// Because we are not in the presence of a tuple cycle, the only rewrite edges that can exist
// in a recursive path, by definition of the weighted graph, are operational edges or logical TTU edges.
if edge.GetEdgeType() == graph.RewriteEdge || edge.GetEdgeType() == graph.TTULogicalEdge {
newEdges, okEdge := t.authzWeightedGraph.GetEdgesFromNode(edge.GetTo())
if !okEdge {
return false
}
// these edges will need to be evaluated in subsequent iterations
innerEdges = append(innerEdges, newEdges...)
continue
}
}
if w > 1 {
// Any other edge that is not part of the recursive path must be weight 1; likewise, a second
// edge for the same TTU beyond the recursive one disqualifies the fast path.
return false
}
}
if len(innerEdges) == 0 {
break
}
edges = innerEdges
}
return recursiveTTUFound
}
// UsersetUseRecursiveResolver returns true if all these conditions apply:
// 1. Node[objectType#relation].weights[userType] = infinite
// 2. Any other direct type, userset or computed relation used in the relation needs to be weight = 1 for the usertype
// Example:
// type doc
// rel1 = [doc#rel1, user, user with cond, employee, doc#rel8] or ( (rel2 but not rel7) or rel8)
// rel2 = rel4 but not rel5
// rel4 = [user]
// rel5 = [user]
// rel7 = [user]
// rel8 = [employee]
// calling UsersetUseRecursiveResolver(doc, rel1, user) should return TRUE
// calling UsersetUseRecursiveResolver(doc, rel1, employee) should return FALSE because there is a doc#rel8 that has weight = 2 for employee.
func (t *TypeSystem) UsersetUseRecursiveResolver(objectType, relation, userType string) bool {
if t.authzWeightedGraph == nil {
return false
}
objRel := tuple.ToObjectRelationString(objectType, relation)
objRelNode, ok := t.authzWeightedGraph.GetNodeByID(objRel)
if !ok {
return false
}
w, ok := objRelNode.GetWeight(userType)
if !ok || w != graph.Infinite {
return false
}
// if we are not in the presence of a recursive relation or it is part of a tuple cycle, return false
if objRelNode.GetRecursiveRelation() != objRelNode.GetUniqueLabel() || objRelNode.IsPartOfTupleCycle() {
return false
}
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(objRelNode)
if !ok {
return false
}
recursiveUsersetFound := false
for len(edges) != 0 {
innerEdges := make([]*graph.WeightedAuthorizationModelEdge, 0)
for _, edge := range edges {
w, ok := edge.GetWeight(userType)
if !ok {
// if the edge does not have a weight for the terminal type, we can skip it
continue
}
if edge.GetRecursiveRelation() == objRel {
if edge.GetEdgeType() == graph.DirectEdge && edge.GetTo() == objRelNode && !recursiveUsersetFound {
recursiveUsersetFound = true
continue
}
if edge.GetEdgeType() == graph.RewriteEdge || edge.GetEdgeType() == graph.DirectLogicalEdge {
newEdges, okEdge := t.authzWeightedGraph.GetEdgesFromNode(edge.GetTo())
if !okEdge {
return false
}
// these edges will need to be evaluated in subsequent iterations
innerEdges = append(innerEdges, newEdges...)
continue
}
}
// Catch-all: everything else has to be weight 1, regardless of direct, computed, or rewrite edges;
// any infinite weight that was not handled above exits through here.
if w > 1 {
return false
}
}
if len(innerEdges) == 0 {
break
}
edges = innerEdges
}
return recursiveUsersetFound // whether the recursive userset edge was found
}
// PathExists returns true if:
// - the `user` type is a subject e.g. `user`, and there is a path from `user` to `objectType#relation`, or there is a path from `user:*` to `objectType#relation`
// or
// - the `user` type is a userset e.g. `group#member`, and there is a path from `group#member` to `objectType#relation`.
func (t *TypeSystem) PathExists(user, relation, objectType string) (bool, error) {
userType, _, userRelation := tuple.ToUserParts(user)
isUserset := userRelation != ""
userTypeRelation := userType
if isUserset {
userTypeRelation = tuple.ToObjectRelationString(userType, userRelation)
}
// first check
fromLabel := userTypeRelation
toLabel := tuple.ToObjectRelationString(objectType, relation)
normalPathExists, err := t.authorizationModelGraph.PathExists(fromLabel, toLabel)
if err != nil {
return false, err
}
if normalPathExists {
return true, nil
}
// Skip the second check if the user is a userset, since a userset cannot have a public wildcard.
if isUserset {
return false, nil
}
// second check
fromLabel = tuple.TypedPublicWildcard(userType)
wildcardPathExists, err := t.authorizationModelGraph.PathExists(fromLabel, toLabel)
if err != nil {
// The only possible error is graph.ErrQueryingGraph, which means the wildcard node cannot
// be found. Given this, we are safe to conclude there is no path.
return false, nil
}
return wildcardPathExists, nil
}
// IsPubliclyAssignable checks if the provided objectType is part
// of a typed wildcard type restriction on the target relation.
//
// Example:
//
// type user
//
// type document
// relations
// define viewer: [user:*]
//
// In the example above, the 'user' objectType is publicly assignable to the 'document#viewer' relation.
// If the input target is not a defined relation, it returns false and RelationUndefinedError.
func (t *TypeSystem) IsPubliclyAssignable(target *openfgav1.RelationReference, objectType string) (bool, error) {
ref, err := t.PubliclyAssignableReferences(target, objectType)
if err != nil {
return false, err
}
return ref != nil, nil
}
// PubliclyAssignableReferences returns the publicly assignable references with the specified objectType.
func (t *TypeSystem) PubliclyAssignableReferences(target *openfgav1.RelationReference, objectType string) (*openfgav1.RelationReference, error) {
relation, err := t.GetRelation(target.GetType(), target.GetRelation())
if err != nil {
return nil, err
}
for _, typeRestriction := range relation.GetTypeInfo().GetDirectlyRelatedUserTypes() {
if typeRestriction.GetType() == objectType {
if typeRestriction.GetWildcard() != nil {
return typeRestriction, nil
}
}
}
return nil, nil
}
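// Example (illustrative sketch): under the document/viewer model shown in the
// IsPubliclyAssignable comment above, 'user' is publicly assignable because
// the relation's type restrictions include the typed wildcard user:*.
func examplePublicAssignability(t *TypeSystem) {
	ok, err := t.IsPubliclyAssignable(DirectRelationReference("document", "viewer"), "user")
	if err == nil && ok {
		fmt.Println("user:* satisfies document#viewer")
	}
}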
// HasTypeInfo determines if a given objectType-relation pair has associated type information.
// It checks against the specific schema version and the existence of type information in the relation.
// Returns true if type information is present and an error if the relation is not found.
func (t *TypeSystem) HasTypeInfo(objectType, relation string) (bool, error) {
r, err := t.GetRelation(objectType, relation)
if err != nil {
return false, err
}
if IsSchemaVersionSupported(t.GetSchemaVersion()) && r.GetTypeInfo() != nil {
return true, nil
}
return false, nil
}
// RelationInvolvesIntersection returns true if the provided relation's userset rewrite
// is defined by one or more direct or indirect intersections or any of the types related to
// the provided relation are defined by one or more direct or indirect intersections.
func (t *TypeSystem) RelationInvolvesIntersection(objectType, relation string) (bool, error) {
visited := map[string]struct{}{}
return t.relationInvolves(objectType, relation, visited, intersectionSetOperator)
}
// RelationInvolvesExclusion returns true if the provided relation's userset rewrite
// is defined by one or more direct or indirect exclusions or any of the types related to
// the provided relation are defined by one or more direct or indirect exclusions.
func (t *TypeSystem) RelationInvolvesExclusion(objectType, relation string) (bool, error) {
visited := map[string]struct{}{}
return t.relationInvolves(objectType, relation, visited, exclusionSetOperator)
}
const (
intersectionSetOperator uint = iota
exclusionSetOperator
)
func (t *TypeSystem) relationInvolves(objectType, relation string, visited map[string]struct{}, target uint) (bool, error) {
key := tuple.ToObjectRelationString(objectType, relation)
if _, ok := visited[key]; ok {
return false, nil
}
visited[key] = struct{}{}
rel, err := t.GetRelation(objectType, relation)
if err != nil {
return false, err
}
rewrite := rel.GetRewrite()
result, err := WalkUsersetRewrite(rewrite, func(r *openfgav1.Userset) interface{} {
switch rw := r.GetUserset().(type) {
case *openfgav1.Userset_ComputedUserset:
rewrittenRelation := rw.ComputedUserset.GetRelation()
rewritten, err := t.GetRelation(objectType, rewrittenRelation)
if err != nil {
return err
}
containsTarget, err := t.relationInvolves(objectType, rewritten.GetName(), visited, target)
if err != nil {
return err
}
if containsTarget {
return true
}
case *openfgav1.Userset_TupleToUserset:
tupleset := rw.TupleToUserset.GetTupleset().GetRelation()
rewrittenRelation := rw.TupleToUserset.GetComputedUserset().GetRelation()
tuplesetRel, err := t.GetRelation(objectType, tupleset)
if err != nil {
return err
}
directlyRelatedTypes := tuplesetRel.GetTypeInfo().GetDirectlyRelatedUserTypes()
for _, relatedType := range directlyRelatedTypes {
// Must be of the form 'objectType' by this point since we disallow `tupleset` relations of the form `objectType:id#relation`.
r := relatedType.GetRelation()
if r != "" {
return fmt.Errorf(
"invalid type restriction '%s#%s' specified on tupleset relation '%s#%s': %w",
relatedType.GetType(),
relatedType.GetRelation(),
objectType,
tupleset,
ErrInvalidModel,
)
}
rel, err := t.GetRelation(relatedType.GetType(), rewrittenRelation)
if err != nil {
if errors.Is(err, ErrObjectTypeUndefined) || errors.Is(err, ErrRelationUndefined) {
continue
}
return err
}
containsTarget, err := t.relationInvolves(relatedType.GetType(), rel.GetName(), visited, target)
if err != nil {
return err
}
if containsTarget {
return true
}
}
return nil
case *openfgav1.Userset_Intersection:
return target == intersectionSetOperator
case *openfgav1.Userset_Difference:
return target == exclusionSetOperator
}
return nil
})
if err != nil {
return false, err
}
if result != nil && result.(bool) {
return true, nil
}
for _, typeRestriction := range rel.GetTypeInfo().GetDirectlyRelatedUserTypes() {
if typeRestriction.GetRelation() != "" {
key := tuple.ToObjectRelationString(typeRestriction.GetType(), typeRestriction.GetRelation())
if _, ok := visited[key]; ok {
continue
}
containsTarget, err := t.relationInvolves(typeRestriction.GetType(), typeRestriction.GetRelation(), visited, target)
if err != nil {
return false, err
}
if containsTarget {
return true, nil
}
}
}
return false, nil
}
// hasEntrypoints recursively walks the rewrite definition for the given relation to determine if there is at least
// one path in the rewrite rule that could relate to at least one concrete object type. If there is no such path that
// could lead to at least one relationship with some object type, then false is returned along with an error indicating
// no entrypoints were found. If at least one relationship with a specific object type is found while walking the rewrite,
// then true is returned along with a nil error.
// This function assumes that all other model validations have run.
func hasEntrypoints(
typedefs map[string]map[string]*openfgav1.Relation,
typeName, relationName string,
rewrite *openfgav1.Userset,
visitedRelations map[string]map[string]bool,
) (bool, bool, error) {
v := maps.Clone(visitedRelations)
// Presence of a key represents that we've visited that object and relation. We keep track of this to avoid stack overflows.
// The value of the key represents hasEntrypoints for that relation. We set this to true only when the relation is directly assignable.
if val, ok := v[typeName]; ok {
val[relationName] = false
} else {
v[typeName] = map[string]bool{
relationName: false,
}
}
relation, ok := typedefs[typeName][relationName]
if !ok {
return false, false, fmt.Errorf("undefined type definition for '%s#%s'", typeName, relationName)
}
switch rw := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This:
// At least one type must have an entrypoint.
for _, assignableType := range relation.GetTypeInfo().GetDirectlyRelatedUserTypes() {
if assignableType.GetRelationOrWildcard() == nil || assignableType.GetWildcard() != nil {
v[typeName][relationName] = true
return true, false, nil
}
assignableTypeName := assignableType.GetType()
assignableRelationName := assignableType.GetRelation()
assignableRelation, ok := typedefs[assignableTypeName][assignableRelationName]
if !ok {
return false, false, fmt.Errorf("undefined type definition for '%s#%s'", assignableTypeName, assignableRelationName)
}
if _, ok := v[assignableTypeName][assignableRelationName]; ok {
continue
}
hasEntrypoint, _, err := hasEntrypoints(typedefs, assignableTypeName, assignableRelationName, assignableRelation.GetRewrite(), v)
if err != nil {
return false, false, err
}
if hasEntrypoint {
return true, false, nil
}
}
return false, false, nil
case *openfgav1.Userset_ComputedUserset:
computedRelationName := rw.ComputedUserset.GetRelation()
computedRelation, ok := typedefs[typeName][computedRelationName]
if !ok {
return false, false, fmt.Errorf("undefined type definition for '%s#%s'", typeName, computedRelationName)
}
if hasEntrypoint, ok := v[typeName][computedRelationName]; ok {
return hasEntrypoint, true, nil
}
hasEntrypoint, loop, err := hasEntrypoints(typedefs, typeName, computedRelationName, computedRelation.GetRewrite(), v)
if err != nil {
return false, false, err
}
return hasEntrypoint, loop, nil
case *openfgav1.Userset_TupleToUserset:
tuplesetRelationName := rw.TupleToUserset.GetTupleset().GetRelation()
computedRelationName := rw.TupleToUserset.GetComputedUserset().GetRelation()
tuplesetRelation, ok := typedefs[typeName][tuplesetRelationName]
if !ok {
return false, false, fmt.Errorf("undefined type definition for '%s#%s'", typeName, tuplesetRelationName)
}
// At least one type must have an entrypoint.
for _, assignableType := range tuplesetRelation.GetTypeInfo().GetDirectlyRelatedUserTypes() {
assignableTypeName := assignableType.GetType()
if assignableRelation, ok := typedefs[assignableTypeName][computedRelationName]; ok {
if hasEntrypoint, ok := v[assignableTypeName][computedRelationName]; ok {
if hasEntrypoint {
return true, false, nil
}
continue
}
hasEntrypoint, _, err := hasEntrypoints(typedefs, assignableTypeName, computedRelationName, assignableRelation.GetRewrite(), v)
if err != nil {
return false, false, err
}
if hasEntrypoint {
return true, false, nil
}
}
}
return false, false, nil
case *openfgav1.Userset_Union:
// At least one type must have an entrypoint.
loop := false
for _, child := range rw.Union.GetChild() {
hasEntrypoints, childLoop, err := hasEntrypoints(typedefs, typeName, relationName, child, visitedRelations)
if err != nil {
return false, false, err
}
if hasEntrypoints {
return true, false, nil
}
loop = loop || childLoop
}
return false, loop, nil
case *openfgav1.Userset_Intersection:
for _, child := range rw.Intersection.GetChild() {
// All the children must have an entrypoint.
hasEntrypoints, childLoop, err := hasEntrypoints(typedefs, typeName, relationName, child, visitedRelations)
if err != nil {
return false, false, err
}
if !hasEntrypoints {
return false, childLoop, nil
}
}
return true, false, nil
case *openfgav1.Userset_Difference:
// All the children must have an entrypoint.
hasEntrypoint, loop, err := hasEntrypoints(typedefs, typeName, relationName, rw.Difference.GetBase(), visitedRelations)
if err != nil {
return false, false, err
}
if !hasEntrypoint {
return false, loop, nil
}
hasEntrypoint, loop, err = hasEntrypoints(typedefs, typeName, relationName, rw.Difference.GetSubtract(), visitedRelations)
if err != nil {
return false, false, err
}
if !hasEntrypoint {
return false, loop, nil
}
return true, false, nil
}
// This should never be reached: it means rewrite.GetUserset() has an unknown type (or rewrite itself is nil).
rwString := "rewrite_nil"
if rewrite != nil {
rwString = rewrite.String()
}
return false, false, serverErrors.HandleError("error validating model", fmt.Errorf("hasEntrypoints unknown rewrite %s for '%s#%s'", rwString, typeName, relationName))
}
// NewAndValidate is like New but also validates the model according to the following rules:
// 1. Checks that the *TypeSystem has a valid schema version.
// 2. For every rewrite the relations in the rewrite must:
// a) Be valid relations on the same type in the *TypeSystem (in cases of computedUserset)
// b) Be valid relations on another existing type (in cases of tupleToUserset)
// 3. Do not allow duplicate types or duplicate relations (only need to check types as relations are
// in a map so cannot contain duplicates)
//
// If the *TypeSystem has a v1.1 schema version (with types on relations), then additionally
// validate the *TypeSystem according to the following rules:
// 4. Every type restriction on a relation must be a valid type:
// a) For a type (e.g. user) this means checking that this type is in the *TypeSystem
// b) For a type#relation this means checking that this type with this relation is in the *TypeSystem
// 5. Check that a relation is assignable if and only if it has a non-zero list of types
func NewAndValidate(ctx context.Context, model *openfgav1.AuthorizationModel) (*TypeSystem, error) {
_, span := tracer.Start(ctx, "typesystem.NewAndValidate")
defer span.End()
t, err := New(model)
if err != nil {
return nil, err
}
schemaVersion := t.GetSchemaVersion()
if !IsSchemaVersionSupported(schemaVersion) {
return nil, ErrInvalidSchemaVersion
}
if containsDuplicateType(model) {
return nil, ErrDuplicateTypes
}
if err := t.validateNames(); err != nil {
return nil, err
}
typedefsMap := t.typeDefinitions
typeNames := make([]string, 0, len(typedefsMap))
for typeName := range typedefsMap {
typeNames = append(typeNames, typeName)
}
// Range over the type definitions in sorted order to produce a deterministic outcome.
sort.Strings(typeNames)
for _, typeName := range typeNames {
typedef := typedefsMap[typeName]
relationMap := typedef.GetRelations()
relationNames := make([]string, 0, len(relationMap))
for relationName := range relationMap {
relationNames = append(relationNames, relationName)
}
// Range over the relations in sorted order to produce a deterministic outcome.
sort.Strings(relationNames)
for _, relationName := range relationNames {
err := t.validateRelation(typeName, relationName, relationMap)
if err != nil {
return nil, err
}
}
}
if err := t.validateConditions(); err != nil {
return nil, err
}
return t, nil
}
// validateRelation applies all the validation rules to a relation definition in a model. A relation
// must meet all the rewrite validation, type restriction validation, and entrypoint validation criteria
// for it to be valid. Otherwise, an error is returned.
func (t *TypeSystem) validateRelation(typeName, relationName string, relationMap map[string]*openfgav1.Userset) error {
rewrite := relationMap[relationName]
err := t.isUsersetRewriteValid(typeName, relationName, rewrite)
if err != nil {
return err
}
err = t.validateTypeRestrictions(typeName, relationName)
if err != nil {
return err
}
visitedRelations := map[string]map[string]bool{}
hasEntrypoints, loop, err := hasEntrypoints(t.relations, typeName, relationName, rewrite, visitedRelations)
if err != nil {
return err
}
if !hasEntrypoints {
cause := ErrNoEntrypoints
if loop {
cause = ErrNoEntryPointsLoop
}
return &InvalidRelationError{
ObjectType: typeName,
Relation: relationName,
Cause: cause,
}
}
hasCycle, err := t.HasCycle(typeName, relationName)
if err != nil {
return err
}
if hasCycle {
return &InvalidRelationError{
ObjectType: typeName,
Relation: relationName,
Cause: ErrCycle,
}
}
return nil
}
func containsDuplicateType(model *openfgav1.AuthorizationModel) bool {
seen := make(map[string]struct{}, len(model.GetTypeDefinitions()))
for _, td := range model.GetTypeDefinitions() {
objectType := td.GetType()
if _, ok := seen[objectType]; ok {
return true
}
seen[objectType] = struct{}{}
}
return false
}
// validateNames ensures that a model doesn't have empty type or relation names,
// nor object types or relations called "self" or "this".
func (t *TypeSystem) validateNames() error {
for _, td := range t.typeDefinitions {
objectType := td.GetType()
if objectType == "" {
return fmt.Errorf("the type name of a type definition cannot be an empty string")
}
if objectType == "self" || objectType == "this" {
return &InvalidTypeError{ObjectType: objectType, Cause: ErrReservedKeywords}
}
for relation := range td.GetRelations() {
if relation == "" {
return fmt.Errorf("type '%s' defines a relation with an empty string for a name", objectType)
}
if relation == "self" || relation == "this" {
return &InvalidRelationError{ObjectType: objectType, Relation: relation, Cause: ErrReservedKeywords}
}
}
}
return nil
}
// isUsersetRewriteValid checks if the rewrite on objectType#relation is valid.
func (t *TypeSystem) isUsersetRewriteValid(objectType, relation string, rewrite *openfgav1.Userset) error {
if rewrite.GetUserset() == nil {
return &InvalidRelationError{ObjectType: objectType, Relation: relation, Cause: ErrInvalidUsersetRewrite}
}
switch r := rewrite.GetUserset().(type) {
case *openfgav1.Userset_ComputedUserset:
computedUserset := r.ComputedUserset.GetRelation()
if computedUserset == relation {
return &InvalidRelationError{ObjectType: objectType, Relation: relation, Cause: ErrInvalidUsersetRewrite}
}
if _, err := t.GetRelation(objectType, computedUserset); err != nil {
return &RelationUndefinedError{ObjectType: objectType, Relation: computedUserset, Err: ErrRelationUndefined}
}
case *openfgav1.Userset_TupleToUserset:
tupleset := r.TupleToUserset.GetTupleset().GetRelation()
tuplesetRelation, err := t.GetRelation(objectType, tupleset)
if err != nil {
return &RelationUndefinedError{ObjectType: objectType, Relation: tupleset, Err: ErrRelationUndefined}
}
// Tupleset relations must only be direct relationships, no rewrites are allowed on them.
tuplesetRewrite := tuplesetRelation.GetRewrite()
if reflect.TypeOf(tuplesetRewrite.GetUserset()) != reflect.TypeOf(&openfgav1.Userset_This{}) {
return fmt.Errorf("the '%s#%s' relation is referenced in at least one tupleset and thus must be a direct relation", objectType, tupleset)
}
computedUserset := r.TupleToUserset.GetComputedUserset().GetRelation()
if IsSchemaVersionSupported(t.GetSchemaVersion()) {
// For 1.1 models, relation `computedUserset` has to be defined in one of the types declared by the tupleset's list of allowed types.
userTypes := tuplesetRelation.GetTypeInfo().GetDirectlyRelatedUserTypes()
for _, rr := range userTypes {
if _, err := t.GetRelation(rr.GetType(), computedUserset); err == nil {
return nil
}
}
return fmt.Errorf("%w: %s does not appear as a relation in any of the directly related user types %s", ErrRelationUndefined, computedUserset, userTypes)
}
// For 1.0 models, relation `computedUserset` has to be defined _somewhere_ in the model.
for typeName := range t.relations {
if _, err := t.GetRelation(typeName, computedUserset); err == nil {
return nil
}
}
return &RelationUndefinedError{ObjectType: "", Relation: computedUserset, Err: ErrRelationUndefined}
case *openfgav1.Userset_Union:
for _, child := range r.Union.GetChild() {
err := t.isUsersetRewriteValid(objectType, relation, child)
if err != nil {
return err
}
}
case *openfgav1.Userset_Intersection:
for _, child := range r.Intersection.GetChild() {
err := t.isUsersetRewriteValid(objectType, relation, child)
if err != nil {
return err
}
}
case *openfgav1.Userset_Difference:
err := t.isUsersetRewriteValid(objectType, relation, r.Difference.GetBase())
if err != nil {
return err
}
err = t.isUsersetRewriteValid(objectType, relation, r.Difference.GetSubtract())
if err != nil {
return err
}
}
return nil
}
// validateTypeRestrictions validates the type restrictions of a given relation using the following rules:
// 1. An assignable relation must have one or more type restrictions.
// 2. A non-assignable relation must not have any type restrictions.
// 3. For each type restriction referenced for an assignable relation, each of the referenced types and relations
// must be defined in the model.
// 4. If the provided relation is a tupleset relation, then the type restriction must be on a direct object.
func (t *TypeSystem) validateTypeRestrictions(objectType string, relationName string) error {
relation, err := t.GetRelation(objectType, relationName)
if err != nil {
return err
}
relatedTypes := relation.GetTypeInfo().GetDirectlyRelatedUserTypes()
assignable := t.IsDirectlyAssignable(relation)
if assignable && len(relatedTypes) == 0 {
return AssignableRelationError(objectType, relationName)
}
if !assignable && len(relatedTypes) != 0 {
return NonAssignableRelationError(objectType, relationName)
}
for _, related := range relatedTypes {
relatedObjectType := related.GetType()
relatedRelation := related.GetRelation()
if _, err := t.GetRelations(relatedObjectType); err != nil {
return InvalidRelationTypeError(objectType, relationName, relatedObjectType, relatedRelation)
}
if related.GetRelationOrWildcard() != nil {
// The type of the relation cannot contain a userset or wildcard if the relation is a tupleset relation.
if ok, _ := t.IsTuplesetRelation(objectType, relationName); ok {
return InvalidRelationTypeError(objectType, relationName, relatedObjectType, relatedRelation)
}
if relatedRelation != "" {
if _, err := t.GetRelation(relatedObjectType, relatedRelation); err != nil {
return InvalidRelationTypeError(objectType, relationName, relatedObjectType, relatedRelation)
}
}
}
if related.GetCondition() != "" {
// Validate that the conditions referenced by the relations are included in the model.
if _, ok := t.conditions[related.GetCondition()]; !ok {
return &RelationConditionError{
Relation: relationName,
Condition: related.GetCondition(),
Err: ErrNoConditionForRelation,
}
}
}
}
return nil
}
// validateConditions validates the conditions provided in the model.
func (t *TypeSystem) validateConditions() error {
for key, c := range t.conditions {
if key != c.Name {
return fmt.Errorf("condition key '%s' does not match condition name '%s'", key, c.Name)
}
if err := c.Compile(); err != nil {
return err
}
}
return nil
}
// IsDirectlyAssignable returns true if the relation's rewrite contains 'this',
// i.e. tuples may be assigned directly to the relation.
func (t *TypeSystem) IsDirectlyAssignable(relation *openfgav1.Relation) bool {
return RewriteContainsSelf(relation.GetRewrite())
}
// RewriteContainsSelf returns true if the provided userset rewrite
// is defined by one or more self referencing definitions.
func RewriteContainsSelf(rewrite *openfgav1.Userset) bool {
result, err := WalkUsersetRewrite(rewrite, func(r *openfgav1.Userset) interface{} {
if _, ok := r.GetUserset().(*openfgav1.Userset_This); ok {
return true
}
return nil
})
if err != nil {
panic("unexpected error during rewrite evaluation")
}
return result != nil && result.(bool) // Type-cast matches the return from the WalkUsersetRewriteHandler above.
}
func (t *TypeSystem) hasCycle(
objectType, relationName string,
rewrite *openfgav1.Userset,
visited map[string]struct{},
) (bool, error) {
visited[fmt.Sprintf("%s#%s", objectType, relationName)] = struct{}{}
visitedCopy := maps.Clone(visited)
var children []*openfgav1.Userset
switch rw := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This, *openfgav1.Userset_TupleToUserset:
return false, nil
case *openfgav1.Userset_ComputedUserset:
rewrittenRelation := rw.ComputedUserset.GetRelation()
if _, ok := visited[fmt.Sprintf("%s#%s", objectType, rewrittenRelation)]; ok {
return true, nil
}
rewrittenRewrite, err := t.GetRelation(objectType, rewrittenRelation)
if err != nil {
return false, err
}
return t.hasCycle(objectType, rewrittenRelation, rewrittenRewrite.GetRewrite(), visitedCopy)
case *openfgav1.Userset_Union:
children = append(children, rw.Union.GetChild()...)
case *openfgav1.Userset_Intersection:
children = append(children, rw.Intersection.GetChild()...)
case *openfgav1.Userset_Difference:
children = append(children, rw.Difference.GetBase(), rw.Difference.GetSubtract())
}
for _, child := range children {
hasCycle, err := t.hasCycle(objectType, relationName, child, visitedCopy)
if err != nil {
return false, err
}
if hasCycle {
return true, nil
}
}
return false, nil
}
// HasCycle runs a cycle detection test on the provided `objectType#relation` to see if the relation
// defines a rewrite rule that is self-referencing in any way (through computed relationships).
func (t *TypeSystem) HasCycle(objectType, relationName string) (bool, error) {
visited := map[string]struct{}{}
relation, err := t.GetRelation(objectType, relationName)
if err != nil {
return false, err
}
return t.hasCycle(objectType, relationName, relation.GetRewrite(), visited)
}
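// Example (illustrative sketch): in a hypothetical model where two relations
// rewrite to each other, e.g. `define a: b` and `define b: a`, each relation
// is self-referencing through computed relationships, so HasCycle reports true.
func exampleCycleDetection(t *TypeSystem) {
	hasCycle, err := t.HasCycle("document", "a")
	if err == nil && hasCycle {
		fmt.Println("document#a participates in a computed-userset cycle")
	}
}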
// IsTuplesetRelation returns a boolean indicating if the provided relation is defined under a
// TupleToUserset rewrite as a tupleset relation (i.e. the right hand side of a `X from Y`).
func (t *TypeSystem) IsTuplesetRelation(objectType, relation string) (bool, error) {
_, err := t.GetRelation(objectType, relation)
if err != nil {
return false, err
}
for _, ttuDefinitions := range t.ttuRelations[objectType] {
for _, ttuDef := range ttuDefinitions {
if ttuDef.GetTupleset().GetRelation() == relation {
return true, nil
}
}
}
return false, nil
}
// GetEdgesFromNode first checks whether the node can reach the source type,
// then returns all outgoing edges for the node.
func (t *TypeSystem) GetEdgesFromNode(
node *graph.WeightedAuthorizationModelNode,
sourceType string,
) ([]*graph.WeightedAuthorizationModelEdge, error) {
if t.authzWeightedGraph == nil {
return nil, fmt.Errorf("weighted graph is nil")
}
wg := t.authzWeightedGraph
// This means we cannot reach the source type requested, so there are no relevant edges.
if !hasPathTo(node, sourceType) {
return nil, nil
}
edges, ok := wg.GetEdgesFromNode(node)
if !ok {
// Note: this should not happen, but adding the guard nonetheless
return nil, fmt.Errorf("no outgoing edges from node: %s", node.GetUniqueLabel())
}
return edges, nil
}
// GetInternalEdges returns all the edges linked to a grouping logical node; for any other edge,
// the returned slice contains just the original edge.
func (t *TypeSystem) GetInternalEdges(edge *graph.WeightedAuthorizationModelEdge, sourceType string) ([]*graph.WeightedAuthorizationModelEdge, error) {
var edges []*graph.WeightedAuthorizationModelEdge
if edge.GetEdgeType() == graph.DirectLogicalEdge || edge.GetEdgeType() == graph.TTULogicalEdge {
logicalEdges, err := t.GetConnectedEdges(edge.GetTo().GetUniqueLabel(), sourceType)
if err != nil {
return nil, err
}
edges = append(edges, logicalEdges...)
} else {
edges = append(edges, edge)
}
return edges, nil
}
// GetConnectedEdges returns all edges which have a path to the source type.
func (t *TypeSystem) GetConnectedEdges(targetTypeRelation string, sourceType string) ([]*graph.WeightedAuthorizationModelEdge, error) {
currentNode, ok := t.GetNode(targetTypeRelation)
if !ok {
return nil, fmt.Errorf("could not find node with label: %s", targetTypeRelation)
}
edges, err := t.GetEdgesFromNode(currentNode, sourceType)
if err != nil {
return nil, err
}
// Filter to only return edges which have a path to the sourceType
relevantEdges := slices.Collect(utils.Filter(edges, func(edge *graph.WeightedAuthorizationModelEdge) bool {
return hasPathTo(edge, sourceType)
}))
return relevantEdges, nil
}
// GetNode returns the node with the given unique label from the weighted graph, if present.
func (t *TypeSystem) GetNode(uniqueID string) (*graph.WeightedAuthorizationModelNode, bool) {
if t.authzWeightedGraph == nil {
return nil, false
}
return t.authzWeightedGraph.GetNodeByID(uniqueID)
}
// flattenUserset collects every TupleToUserset rewrite contained anywhere in the given userset rewrite tree.
func flattenUserset(relationDef *openfgav1.Userset) []*openfgav1.TupleToUserset {
output := make([]*openfgav1.TupleToUserset, 0)
userset := relationDef.GetUserset()
switch x := userset.(type) {
case *openfgav1.Userset_TupleToUserset:
if x.TupleToUserset != nil {
output = append(output, x.TupleToUserset)
}
case *openfgav1.Userset_Union:
if x.Union != nil {
for _, child := range x.Union.GetChild() {
output = append(output, flattenUserset(child)...)
}
}
case *openfgav1.Userset_Intersection:
if x.Intersection != nil {
for _, child := range x.Intersection.GetChild() {
output = append(output, flattenUserset(child)...)
}
}
case *openfgav1.Userset_Difference:
if x.Difference != nil {
output = append(output, flattenUserset(x.Difference.GetBase())...)
output = append(output, flattenUserset(x.Difference.GetSubtract())...)
}
}
return output
}
// WalkUsersetRewriteHandler is a userset rewrite handler that is applied to a node in a userset rewrite
// tree. Implementations of the WalkUsersetRewriteHandler should return a non-nil value when the traversal
// over the rewrite tree should terminate and nil if traversal should proceed to other nodes in the tree.
type WalkUsersetRewriteHandler func(rewrite *openfgav1.Userset) interface{}
// WalkUsersetRewrite recursively walks the provided userset rewrite and invokes the provided WalkUsersetRewriteHandler
// on each node in the userset rewrite tree until the first non-nil response is encountered.
func WalkUsersetRewrite(rewrite *openfgav1.Userset, handler WalkUsersetRewriteHandler) (interface{}, error) {
var children []*openfgav1.Userset
if result := handler(rewrite); result != nil {
return result, nil
}
switch t := rewrite.GetUserset().(type) {
case *openfgav1.Userset_This:
return handler(rewrite), nil
case *openfgav1.Userset_ComputedUserset:
return handler(rewrite), nil
case *openfgav1.Userset_TupleToUserset:
return handler(rewrite), nil
case *openfgav1.Userset_Union:
children = t.Union.GetChild()
case *openfgav1.Userset_Intersection:
children = t.Intersection.GetChild()
case *openfgav1.Userset_Difference:
children = append(children, t.Difference.GetBase(), t.Difference.GetSubtract())
default:
return nil, fmt.Errorf("unexpected userset rewrite type encountered")
}
for _, child := range children {
result, err := WalkUsersetRewrite(child, handler)
if err != nil {
return nil, err
}
if result != nil {
return result, nil
}
}
return nil, nil
}
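// A minimal sketch (illustrative, not part of the package): using
// WalkUsersetRewrite to detect whether a rewrite contains direct
// relationships. Returning a non-nil sentinel from the handler stops the walk
// at the first Userset_This node.
func exampleContainsThis(rewrite *openfgav1.Userset) (bool, error) {
result, err := WalkUsersetRewrite(rewrite, func(r *openfgav1.Userset) interface{} {
if _, ok := r.GetUserset().(*openfgav1.Userset_This); ok {
return true // non-nil terminates the traversal
}
return nil // nil continues to the remaining nodes
})
if err != nil {
return false, err
}
return result != nil, nil
}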
package typesystem
import (
"fmt"
"math"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/language/pkg/go/graph"
"github.com/openfga/openfga/pkg/tuple"
)
type weightedGraphItem interface {
GetWeight(destinationType string) (int, bool)
}
// hasPathTo reports whether a path exists from a node or edge to a terminal type,
// e.g., whether we can reach "user" from "document".
func hasPathTo(nodeOrEdge weightedGraphItem, destinationType string) bool {
_, ok := nodeOrEdge.GetWeight(destinationType)
return ok
}
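// A minimal sketch (illustrative, not part of the package): any type with a
// GetWeight method satisfies weightedGraphItem, so hasPathTo can be exercised
// with a simple stub; stubItem and its weights are hypothetical.
type stubItem map[string]int

func (s stubItem) GetWeight(destinationType string) (int, bool) {
w, ok := s[destinationType]
return w, ok
}

func exampleHasPathTo() {
item := stubItem{"user": 2}
fmt.Println(hasPathTo(item, "user"))  // true: a weight to "user" exists
fmt.Println(hasPathTo(item, "group")) // false: no path to "group"
}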
type IntersectionEdges struct {
LowestEdge *graph.WeightedAuthorizationModelEdge // lowest edge to apply list objects
SiblingEdges []*graph.WeightedAuthorizationModelEdge // the rest of the edges to apply intersection
}
type ExclusionEdges struct {
BaseEdge *graph.WeightedAuthorizationModelEdge // base edge to apply list objects
ExcludedEdge *graph.WeightedAuthorizationModelEdge // excluded edge to apply exclusion
}
// GetEdgesForIntersection returns the lowest-weighted edge and
// its sibling edges for an intersection, based on the weighted graph.
// If a direct edge has the same weight as a sibling edge, the direct
// edge is chosen as the preference.
// If any of the children are not connected, it returns an empty IntersectionEdges.
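// For example (illustrative): given child edges with weights 3, 1, and 2 to
// the source type, LowestEdge is the weight-1 edge and SiblingEdges contains
// the weight-3 and weight-2 edges.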
func GetEdgesForIntersection(edges []*graph.WeightedAuthorizationModelEdge, sourceType string) (IntersectionEdges, error) {
if len(edges) < 2 {
// Intersection by definition must have at least 2 children
return IntersectionEdges{}, fmt.Errorf("invalid edges for source type %s", sourceType)
}
// Find the edge with the lowest weight to the source type; the remaining edges become siblings.
lowestWeight := math.MaxInt32
var lowestEdge *graph.WeightedAuthorizationModelEdge
siblingEdges := make([]*graph.WeightedAuthorizationModelEdge, 0, len(edges))
for _, edge := range edges {
weight, ok := edge.GetWeight(sourceType)
if !ok {
// A child with no path to the source type means the intersection cannot be satisfied.
return IntersectionEdges{}, nil
}
if weight < lowestWeight {
if lowestEdge != nil {
siblingEdges = append(siblingEdges, lowestEdge)
}
lowestWeight = weight
lowestEdge = edge
} else {
siblingEdges = append(siblingEdges, edge)
}
}
return IntersectionEdges{
LowestEdge: lowestEdge,
SiblingEdges: siblingEdges,
}, nil
}
// GetEdgesForExclusion returns the base edge (i.e., edge A in "A but not B") and the
// excluded edge (edge B in "A but not B") based on the weighted graph for an exclusion.
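// For example (illustrative): for "define viewer: editor but not blocked",
// BaseEdge is the edge toward editor and ExcludedEdge is the edge toward
// blocked; if blocked has no path to the source type, ExcludedEdge is nil.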
func GetEdgesForExclusion(
edges []*graph.WeightedAuthorizationModelEdge,
sourceType string,
) (ExclusionEdges, error) {
if len(edges) != 2 {
return ExclusionEdges{}, fmt.Errorf("invalid number of edges in an exclusion operation: expected 2, got %d", len(edges))
}
if !hasPathTo(edges[0], sourceType) {
return ExclusionEdges{}, fmt.Errorf("the base edge does not have a path to the source type %s", sourceType)
}
baseEdge := edges[0]
excludedEdge := edges[1]
if !hasPathTo(excludedEdge, sourceType) {
excludedEdge = nil
}
return ExclusionEdges{
BaseEdge: baseEdge,
ExcludedEdge: excludedEdge,
}, nil
}
// ConstructUserset returns the openfgav1.Userset used to run CheckRewrite against a list objects
// candidate when the model has intersection or exclusion.
func (t *TypeSystem) ConstructUserset(currentEdge *graph.WeightedAuthorizationModelEdge, sourceUserType string) (*openfgav1.Userset, error) {
currentNode := currentEdge.GetTo()
edgeType := currentEdge.GetEdgeType()
uniqueLabel := currentNode.GetUniqueLabel()
switch currentNode.GetNodeType() {
case graph.SpecificType, graph.SpecificTypeWildcard, graph.LogicalDirectGrouping:
return This(), nil
case graph.LogicalTTUGrouping:
edges, err := t.GetEdgesFromNode(currentNode, sourceUserType)
if err != nil {
return nil, fmt.Errorf("failed to get edges from node %s: %w", currentNode.GetUniqueLabel(), err)
}
if len(edges) == 0 {
return nil, fmt.Errorf("no edges found for node %s", currentNode.GetUniqueLabel())
}
// We only need the TTU information from the first edge of the TTU logical group, as all edges belong to the same TTU definition.
return t.ConstructUserset(edges[0], sourceUserType)
case graph.SpecificTypeAndRelation:
switch edgeType {
case graph.DirectEdge:
// userset use case
return This(), nil
case graph.RewriteEdge, graph.ComputedEdge:
_, relation := tuple.SplitObjectRelation(uniqueLabel)
return &openfgav1.Userset{
Userset: &openfgav1.Userset_ComputedUserset{
ComputedUserset: &openfgav1.ObjectRelation{
Relation: relation,
},
},
}, nil
case graph.TTUEdge:
_, parent := tuple.SplitObjectRelation(currentEdge.GetTuplesetRelation())
_, relation := tuple.SplitObjectRelation(uniqueLabel)
return &openfgav1.Userset{
Userset: &openfgav1.Userset_TupleToUserset{
TupleToUserset: &openfgav1.TupleToUserset{
Tupleset: &openfgav1.ObjectRelation{
Relation: parent, // the tupleset relation
},
ComputedUserset: &openfgav1.ObjectRelation{
Relation: relation,
},
},
},
}, nil
default:
// This should never happen.
return nil, fmt.Errorf("unknown edge type: %v for node: %s", edgeType, currentNode.GetUniqueLabel())
}
case graph.OperatorNode:
switch currentNode.GetLabel() {
case graph.ExclusionOperator:
return t.ConstructExclusionUserset(currentNode, sourceUserType)
case graph.IntersectionOperator:
return t.ConstructIntersectionUserset(currentNode, sourceUserType)
case graph.UnionOperator:
return t.ConstructUnionUserset(currentNode, sourceUserType)
default:
// This should never happen.
return nil, fmt.Errorf("unknown operator node label %s for node %s", currentNode.GetLabel(), currentNode.GetUniqueLabel())
}
default:
// This should never happen.
return nil, fmt.Errorf("unknown node type %v for node %s", currentNode.GetNodeType(), currentNode.GetUniqueLabel())
}
}
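// A minimal sketch (illustrative, not part of the package): for a TTU edge
// whose tupleset relation is "document#parent" and whose target is
// "folder#viewer" (hypothetical names), the TTUEdge case above produces the
// equivalent of:
func exampleTTUUserset() *openfgav1.Userset {
return &openfgav1.Userset{
Userset: &openfgav1.Userset_TupleToUserset{
TupleToUserset: &openfgav1.TupleToUserset{
Tupleset: &openfgav1.ObjectRelation{Relation: "parent"},
ComputedUserset: &openfgav1.ObjectRelation{Relation: "viewer"},
},
},
}
}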
func (t *TypeSystem) ConstructExclusionUserset(node *graph.WeightedAuthorizationModelNode, sourceUserType string) (*openfgav1.Userset, error) {
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(node)
if !ok || node.GetLabel() != graph.ExclusionOperator {
// This should never happen.
return nil, fmt.Errorf("incorrect exclusion node: %s", node.GetUniqueLabel())
}
exclusionEdges, err := GetEdgesForExclusion(edges, sourceUserType)
if err != nil {
return nil, fmt.Errorf("error getting the edges for operation: exclusion: %s", err.Error())
}
baseUserset, err := t.ConstructUserset(exclusionEdges.BaseEdge, sourceUserType)
if err != nil {
return nil, fmt.Errorf("failed to construct userset for edge %s: %w", exclusionEdges.BaseEdge.GetTo().GetUniqueLabel(), err)
}
if exclusionEdges.ExcludedEdge == nil {
return baseUserset, nil
}
excludedUserset, err := t.ConstructUserset(exclusionEdges.ExcludedEdge, sourceUserType)
if err != nil {
return nil, fmt.Errorf("failed to construct userset for edge %s: %w", exclusionEdges.ExcludedEdge.GetTo().GetUniqueLabel(), err)
}
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Difference{
Difference: &openfgav1.Difference{
Base: baseUserset,
Subtract: excludedUserset,
}}}, nil
}
func (t *TypeSystem) ConstructIntersectionUserset(node *graph.WeightedAuthorizationModelNode, sourceUserType string) (*openfgav1.Userset, error) {
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(node)
if !ok || node.GetLabel() != graph.IntersectionOperator {
// This should never happen.
return nil, fmt.Errorf("incorrect intersection node: %s", node.GetUniqueLabel())
}
var usersets []*openfgav1.Userset
if len(edges) < 2 {
return nil, fmt.Errorf("no valid edges found for intersection")
}
for _, edge := range edges {
userset, err := t.ConstructUserset(edge, sourceUserType)
if err != nil {
return nil, fmt.Errorf("failed to construct userset for edge %s: %w", edge.GetTo().GetUniqueLabel(), err)
}
usersets = append(usersets, userset)
}
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Intersection{
Intersection: &openfgav1.Usersets{
Child: usersets,
}}}, nil
}
func (t *TypeSystem) ConstructUnionUserset(node *graph.WeightedAuthorizationModelNode, sourceUserType string) (*openfgav1.Userset, error) {
edges, ok := t.authzWeightedGraph.GetEdgesFromNode(node)
if !ok || node.GetLabel() != graph.UnionOperator {
// This should never happen.
return nil, fmt.Errorf("incorrect union node: %s", node.GetUniqueLabel())
}
var usersets []*openfgav1.Userset
if len(edges) < 2 {
return nil, fmt.Errorf("no valid edges found for union")
}
for _, edge := range edges {
userset, err := t.ConstructUserset(edge, sourceUserType)
if err != nil {
return nil, fmt.Errorf("failed to construct userset for edge %s: %w", edge.GetTo().GetUniqueLabel(), err)
}
usersets = append(usersets, userset)
}
return &openfgav1.Userset{
Userset: &openfgav1.Userset_Union{
Union: &openfgav1.Usersets{
Child: usersets,
}}}, nil
}
package tests
import (
"context"
"fmt"
"testing"
"time"
"google.golang.org/grpc"
openfgav1 "github.com/openfga/api/proto/openfga/v1"
"github.com/openfga/openfga/cmd/run"
"github.com/openfga/openfga/pkg/logger"
serverconfig "github.com/openfga/openfga/pkg/server/config"
"github.com/openfga/openfga/pkg/testfixtures/storage"
"github.com/openfga/openfga/pkg/testutils"
)
// TestClientBootstrapper defines a client interface that can be used by tests
// to bootstrap the OpenFGA resources (stores, models, relationship tuples, etc.)
// needed to execute tests.
type TestClientBootstrapper interface {
CreateStore(ctx context.Context, in *openfgav1.CreateStoreRequest, opts ...grpc.CallOption) (*openfgav1.CreateStoreResponse, error)
WriteAuthorizationModel(ctx context.Context, in *openfgav1.WriteAuthorizationModelRequest, opts ...grpc.CallOption) (*openfgav1.WriteAuthorizationModelResponse, error)
Write(ctx context.Context, in *openfgav1.WriteRequest, opts ...grpc.CallOption) (*openfgav1.WriteResponse, error)
}
// ClientInterface defines the client interface for running tests.
type ClientInterface interface {
TestClientBootstrapper
Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error)
ListUsers(ctx context.Context, in *openfgav1.ListUsersRequest, opts ...grpc.CallOption) (*openfgav1.ListUsersResponse, error)
ListObjects(ctx context.Context, in *openfgav1.ListObjectsRequest, opts ...grpc.CallOption) (*openfgav1.ListObjectsResponse, error)
StreamedListObjects(ctx context.Context, in *openfgav1.StreamedListObjectsRequest, opts ...grpc.CallOption) (openfgav1.OpenFGAService_StreamedListObjectsClient, error)
}
// StartServer calls StartServerWithContext. See the docs for that.
func StartServer(t testing.TB, cfg *serverconfig.Config) {
logger := logger.MustNewLogger(cfg.Log.Format, cfg.Log.Level, cfg.Log.TimestampFormat)
serverCtx := &run.ServerContext{Logger: logger}
StartServerWithContext(t, cfg, serverCtx)
}
// StartServerWithContext starts a server on random ports with a specific ServerContext and waits until it is healthy.
// When the test ends, all resources are cleaned up.
func StartServerWithContext(t testing.TB, cfg *serverconfig.Config, serverCtx *run.ServerContext) {
container := storage.RunDatastoreTestContainer(t, cfg.Datastore.Engine)
cfg.Datastore.URI = container.GetConnectionURI(true)
ctx, cancel := context.WithCancel(context.Background())
httpPort, httpPortReleaser := testutils.TCPRandomPort()
cfg.HTTP.Addr = fmt.Sprintf("localhost:%d", httpPort)
grpcPort, grpcPortReleaser := testutils.TCPRandomPort()
cfg.GRPC.Addr = fmt.Sprintf("localhost:%d", grpcPort)
// these two functions release the ports so that the server can start listening on them
httpPortReleaser()
grpcPortReleaser()
serverDone := make(chan error)
go func() {
serverDone <- serverCtx.Run(ctx, cfg)
}()
t.Cleanup(func() {
t.Log("waiting for server to stop")
cancel()
serverErr := <-serverDone
t.Log("server stopped with error: ", serverErr)
})
testutils.EnsureServiceHealthy(t, cfg.GRPC.Addr, cfg.HTTP.Addr, nil)
}
// BuildClientInterface sets up the test client interface used for the matrix tests.
func BuildClientInterface(t *testing.T, engine string, experimentals []string) ClientInterface {
cfg := serverconfig.MustDefaultConfig()
if len(experimentals) > 0 {
cfg.Experimentals = append(cfg.Experimentals, experimentals...)
}
cfg.Log.Level = "error"
cfg.Datastore.Engine = engine
cfg.ListUsersDeadline = 0 // no deadline
cfg.ListObjectsDeadline = 0 // no deadline
// extend the timeout for the tests; coverage makes them slower
cfg.RequestTimeout = 10 * time.Second
cfg.SharedIterator.Enabled = true
cfg.CheckIteratorCache.Enabled = true
cfg.ListObjectsIteratorCache.Enabled = true
cfg.ContextPropagationToDatastore = true
StartServer(t, cfg)
conn := testutils.CreateGrpcConnection(t, cfg.GRPC.Addr)
return openfgav1.NewOpenFGAServiceClient(conn)
}
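// A minimal sketch (illustrative, not part of the package): a matrix-style
// test bootstrapping a store through the client interface. The "memory"
// engine and the store name are hypothetical.
func exampleMatrixBootstrap(t *testing.T) {
client := BuildClientInterface(t, "memory", nil)
resp, err := client.CreateStore(context.Background(), &openfgav1.CreateStoreRequest{Name: "example-store"})
if err != nil {
t.Fatal(err)
}
t.Logf("created store %s", resp.GetId())
}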