diff --git a/cluster/cluster/forking/cluster_invoker.go b/cluster/cluster/forking/cluster_invoker.go index 326a89f682..f110dc991b 100644 --- a/cluster/cluster/forking/cluster_invoker.go +++ b/cluster/cluster/forking/cluster_invoker.go @@ -71,10 +71,15 @@ func (invoker *forkingClusterInvoker) Invoke(ctx context.Context, invocation pro } } + // forkCtx is canceled when Invoke returns, signaling losing parallel + // goroutines to stop rather than running to completion. + forkCtx, cancel := context.WithCancel(ctx) + defer cancel() + resultQ := queue.New(1) for _, ivk := range selected { go func(k protocolbase.Invoker) { - result := k.Invoke(ctx, invocation) + result := k.Invoke(forkCtx, invocation) if err := resultQ.Put(result); err != nil { logger.Errorf("[Cluster][Forking] resultQ put failed with exception err=%v", err) } diff --git a/cluster/cluster/forking/cluster_test.go b/cluster/cluster/forking/cluster_test.go index 3eec6dce0b..aac90caa08 100644 --- a/cluster/cluster/forking/cluster_test.go +++ b/cluster/cluster/forking/cluster_test.go @@ -163,3 +163,38 @@ func TestForkingInvokeHalfTimeout(t *testing.T) { assert.Equal(t, mockResult, result) wg.Wait() } + +func TestForkingInvokeCancelsLosingBranches(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + forkingUrl.AddParam(constant.ForksKey, strconv.Itoa(2)) + + loserCtxCanceled := make(chan struct{}) + + winner := mock.NewMockInvoker(ctrl) + winner.EXPECT().IsAvailable().Return(true).AnyTimes() + winner.EXPECT().Invoke(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ base.Invocation) result.Result { + return &result.RPCResult{} + }) + + loser := mock.NewMockInvoker(ctrl) + loser.EXPECT().IsAvailable().Return(true).AnyTimes() + loser.EXPECT().Invoke(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, _ base.Invocation) result.Result { + <-ctx.Done() + close(loserCtxCanceled) + return &result.RPCResult{} + }) + + clusterInvoker := registerForking(winner, loser) + res := clusterInvoker.Invoke(context.Background(), &invocation.RPCInvocation{}) + require.NoError(t, res.Error()) + + select { + case <-loserCtxCanceled: + case <-time.After(2 * time.Second): + t.Fatal("loser branch context was not canceled after winner succeeded") + } +} diff --git a/cluster/loadbalance/roundrobin/loadbalance.go b/cluster/loadbalance/roundrobin/loadbalance.go index fcd151d185..4859ddb2cc 100644 --- a/cluster/loadbalance/roundrobin/loadbalance.go +++ b/cluster/loadbalance/roundrobin/loadbalance.go @@ -69,6 +69,12 @@ func (lb *rrLoadBalance) Select(invokers []base.Invoker, invocation base.Invocat cache, _ := methodWeightMap.LoadOrStore(key, &cachedInvokers{}) cachedInvokers := cache.(*cachedInvokers) + // Serialize the full select+update sequence per service+method key so that + // concurrent callers cannot observe each other's intermediate currentWeight + // state and skew the weighted distribution. + cachedInvokers.mu.Lock() + defer cachedInvokers.mu.Unlock() + var ( clean = false totalWeight = int64(0) @@ -166,5 +172,6 @@ func (robin *weightedRoundRobin) Current(delta int64) { } type cachedInvokers struct { + mu sync.Mutex sync.Map /*[string]weightedRoundRobin*/ } diff --git a/cluster/loadbalance/roundrobin/loadbalance_test.go b/cluster/loadbalance/roundrobin/loadbalance_test.go index 056e670858..5499f70234 100644 --- a/cluster/loadbalance/roundrobin/loadbalance_test.go +++ b/cluster/loadbalance/roundrobin/loadbalance_test.go @@ -20,6 +20,7 @@ package roundrobin import ( "fmt" "strconv" + "sync" "testing" ) @@ -75,3 +76,27 @@ func TestRoundRobinByWeight(t *testing.T) { assert.Equal(t, w, selected[i]) } } + +func TestRoundRobinByWeightConcurrent(t *testing.T) { + loadBalance := NewRRLoadBalance() + + var invokers []base.Invoker + for i := 1; i <= 5; i++ { + url, _ := common.NewURL(fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService?weight=%v", i, i)) + invokers = append(invokers, base.NewBaseInvoker(url)) + } + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < 30; j++ { + invoker := loadBalance.Select(invokers, &invocation.RPCInvocation{}) + assert.NotNil(t, invoker) + } + }() + } + wg.Wait() +} diff --git a/cluster/router/script/instance/js_instance.go b/cluster/router/script/instance/js_instance.go index 58e51693b3..364d3ae631 100644 --- a/cluster/router/script/instance/js_instance.go +++ b/cluster/router/script/instance/js_instance.go @@ -46,6 +46,10 @@ type jsInstances struct { type jsInstance struct { rt *goja.Runtime + // baseGlobals is the set of global names present in a freshly created + // runtime (built-ins). Any global outside this set is treated as + // request/script state and removed on reset. + baseGlobals map[string]struct{} } type program struct { @@ -82,6 +86,10 @@ func (i *jsInstances) Run(rawScript string, invokers []base.Invoker, invocation return invokers, nil } matcher := i.insPool.Get().(*jsInstance) + defer func() { + matcher.reset() + i.insPool.Put(matcher) + }() packInvokers := make([]base.Invoker, 0, len(invokers)) for _, invoker := range invokers { @@ -179,9 +187,30 @@ func (j jsInstance) initReplyVar() { } } +// reset removes every global added since the runtime was created — the request +// bindings (invokers/invocation/context/result) as well as any globals defined +// by the executed user script — before returning the instance to the pool. This +// prevents cross-request state leakage via pooled goja.Runtime globals. Built-in +// globals captured at construction time are preserved. +func (j jsInstance) reset() { + global := j.rt.GlobalObject() + for _, key := range global.Keys() { + if _, isBase := j.baseGlobals[key]; isBase { + continue + } + _ = global.Delete(key) + } +} + func newJsInstance() *jsInstance { + rt := goja.New() + base := make(map[string]struct{}) + for _, key := range rt.GlobalObject().Keys() { + base[key] = struct{}{} + } return &jsInstance{ - rt: goja.New(), + rt: rt, + baseGlobals: base, } } diff --git a/cluster/router/script/instance/js_instance_test.go b/cluster/router/script/instance/js_instance_test.go index c91d17433c..888c5cc59e 100644 --- a/cluster/router/script/instance/js_instance_test.go +++ b/cluster/router/script/instance/js_instance_test.go @@ -716,3 +716,37 @@ func TestRunScriptInPanic(t *testing.T) { ` wontPanic(scriptCallWrongArgs3) } + +func TestJsInstanceResetClearsBindings(t *testing.T) { + inst := newJsInstance() + + testify_require.NoError(t, inst.rt.Set("invokers", []string{"test"})) + testify_require.NoError(t, inst.rt.Set("invocation", "test-inv")) + testify_require.NoError(t, inst.rt.Set("context", "test-ctx")) + testify_require.NoError(t, inst.rt.Set(jsScriptResultName, "test-result")) + // Global defined by a user script must also be cleared on reset. + _, err := inst.rt.RunString("userDefinedGlobal = 42;") + testify_require.NoError(t, err) + assert.NotNil(t, inst.rt.Get("userDefinedGlobal")) + + inst.reset() + + // Get returns nil for undefined globals after reset. + assert.Nil(t, inst.rt.Get("invokers")) + assert.Nil(t, inst.rt.Get("invocation")) + assert.Nil(t, inst.rt.Get("context")) + assert.Nil(t, inst.rt.Get(jsScriptResultName)) + assert.Nil(t, inst.rt.Get("userDefinedGlobal")) +} + +func TestJsInstanceRunReturnsRuntimeToPool(t *testing.T) { + instances := newJsInstances() + testify_require.NoError(t, instances.Compile(Func_Script)) + + invokers, inv, _ := getRouteArgs() + for i := 0; i < 5; i++ { + result, err := instances.Run(Func_Script, invokers, inv) + testify_require.NoError(t, err) + assert.NotEmpty(t, result) + } +} diff --git a/filter/tps/limiter/method_service.go b/filter/tps/limiter/method_service.go index 4d0896310b..6909d27713 100644 --- a/filter/tps/limiter/method_service.go +++ b/filter/tps/limiter/method_service.go @@ -20,12 +20,12 @@ package limiter import ( "strconv" "sync" + "sync/atomic" + "time" ) import ( "github.com/dubbogo/gost/log/logger" - - "github.com/modern-go/concurrent" ) import ( @@ -38,8 +38,30 @@ import ( const ( name = "method-service" + + tpsLimiterStateTTL = 10 * time.Minute + tpsLimiterCleanupInterval = 5 * time.Minute ) +// tpsLimitEntry wraps a TpsLimitStrategy with a last-access timestamp so that +// stale entries (services/methods no longer invoked) can be evicted periodically. +type tpsLimitEntry struct { + strategy filter.TpsLimitStrategy + lastAccess int64 // unix nanoseconds, updated atomically +} + +func newTpsLimitEntry(s filter.TpsLimitStrategy) *tpsLimitEntry { + return &tpsLimitEntry{ + strategy: s, + lastAccess: time.Now().UnixNano(), + } +} + +func (e *tpsLimitEntry) IsAllowable() bool { + atomic.StoreInt64(&e.lastAccess, time.Now().UnixNano()) + return e.strategy.IsAllowable() +} + func init() { extension.SetTpsLimiter(constant.DefaultKey, GetMethodServiceTpsLimiter) extension.SetTpsLimiter(name, GetMethodServiceTpsLimiter) @@ -112,7 +134,7 @@ func init() { * In this case, only UpdateUser will be limited by its configuration (70 times in 40000ms) */ type MethodServiceTpsLimiter struct { - tpsState *concurrent.Map + tpsState sync.Map // map[string]*tpsLimitEntry } // IsAllowable based on method-level and service-level. @@ -121,7 +143,7 @@ type MethodServiceTpsLimiter struct { // The key point is how to keep thread-safe // This implementation use concurrent map + loadOrStore to make implementation thread-safe // You can image that even multiple threads create limiter, but only one could store the limiter into tpsState -func (limiter MethodServiceTpsLimiter) IsAllowable(url *common.URL, invocation base.Invocation) bool { +func (limiter *MethodServiceTpsLimiter) IsAllowable(url *common.URL, invocation base.Invocation) bool { methodConfigPrefix := "methods." + invocation.MethodName() + "." methodLimitRateConfig := url.GetParam(methodConfigPrefix+constant.TPSLimitRateKey, "") @@ -140,7 +162,7 @@ func (limiter MethodServiceTpsLimiter) IsAllowable(url *common.URL, invocation b limitState, found := limiter.tpsState.Load(limitTarget) if found { // the limiter has been cached, we return its result - return limitState.(filter.TpsLimitStrategy).IsAllowable() + return limitState.(*tpsLimitEntry).IsAllowable() } // we could not find the limiter, and try to create one. @@ -172,10 +194,27 @@ func (limiter MethodServiceTpsLimiter) IsAllowable(url *common.URL, invocation b return true } - // we using loadOrStore to ensure thread-safe - limitState, _ = limiter.tpsState.LoadOrStore(limitTarget, limitStateCreator.Create(int(limitRate), int(limitInterval))) + // we using LoadOrStore to ensure thread-safe; wrap in tpsLimitEntry for TTL tracking + entry := newTpsLimitEntry(limitStateCreator.Create(int(limitRate), int(limitInterval))) + actual, _ := limiter.tpsState.LoadOrStore(limitTarget, entry) + return actual.(*tpsLimitEntry).IsAllowable() +} - return limitState.(filter.TpsLimitStrategy).IsAllowable() +// runCleanup periodically evicts limiter entries that have not been accessed +// within tpsLimiterStateTTL. This prevents unbounded accumulation when +// services or methods are removed dynamically. +func (limiter *MethodServiceTpsLimiter) runCleanup() { + ticker := time.NewTicker(tpsLimiterCleanupInterval) + defer ticker.Stop() + for range ticker.C { + cutoff := time.Now().Add(-tpsLimiterStateTTL).UnixNano() + limiter.tpsState.Range(func(key, val any) bool { + if atomic.LoadInt64(&val.(*tpsLimitEntry).lastAccess) < cutoff { + limiter.tpsState.Delete(key) + } + return true + }) + } } // getLimitConfig will try to fetch the configuration from url. @@ -215,9 +254,9 @@ var ( // GetMethodServiceTpsLimiter will return an MethodServiceTpsLimiter instance. func GetMethodServiceTpsLimiter() filter.TpsLimiter { methodServiceTpsLimiterOnce.Do(func() { - methodServiceTpsLimiterInstance = &MethodServiceTpsLimiter{ - tpsState: concurrent.NewMap(), - } + inst := &MethodServiceTpsLimiter{} + go inst.runCleanup() + methodServiceTpsLimiterInstance = inst }) return methodServiceTpsLimiterInstance } diff --git a/filter/tps/limiter/method_service_test.go b/filter/tps/limiter/method_service_test.go index f376e3e538..900b7e4ef5 100644 --- a/filter/tps/limiter/method_service_test.go +++ b/filter/tps/limiter/method_service_test.go @@ -19,7 +19,9 @@ package limiter import ( "net/url" + "sync/atomic" "testing" + "time" ) import ( @@ -156,3 +158,40 @@ func (creator *mockStrategyCreator) Create(rate int, interval int) filter.TpsLim assert.Equal(creator.t, creator.interval, interval) return creator.strategy } + +type stubAllowStrategy struct{} + +func (s *stubAllowStrategy) IsAllowable() bool { return true } + +func TestTpsLimitEntryUpdatesLastAccess(t *testing.T) { + entry := newTpsLimitEntry(&stubAllowStrategy{}) + before := atomic.LoadInt64(&entry.lastAccess) + time.Sleep(time.Millisecond) + entry.IsAllowable() + after := atomic.LoadInt64(&entry.lastAccess) + assert.Greater(t, after, before) +} + +func TestTpsLimiterStaleEntryEviction(t *testing.T) { + limiter := &MethodServiceTpsLimiter{} + + staleEntry := newTpsLimitEntry(&stubAllowStrategy{}) + atomic.StoreInt64(&staleEntry.lastAccess, time.Now().Add(-tpsLimiterStateTTL-time.Second).UnixNano()) + limiter.tpsState.Store("stale-key", staleEntry) + + activeEntry := newTpsLimitEntry(&stubAllowStrategy{}) + limiter.tpsState.Store("active-key", activeEntry) + + cutoff := time.Now().Add(-tpsLimiterStateTTL).UnixNano() + limiter.tpsState.Range(func(key, val any) bool { + if atomic.LoadInt64(&val.(*tpsLimitEntry).lastAccess) < cutoff { + limiter.tpsState.Delete(key) + } + return true + }) + + _, staleExists := limiter.tpsState.Load("stale-key") + _, activeExists := limiter.tpsState.Load("active-key") + assert.False(t, staleExists, "stale entry should be evicted") + assert.True(t, activeExists, "active entry should remain") +} diff --git a/graceful_shutdown/shutdown.go b/graceful_shutdown/shutdown.go index 7b07ed369e..900855bcb2 100644 --- a/graceful_shutdown/shutdown.go +++ b/graceful_shutdown/shutdown.go @@ -64,10 +64,11 @@ var ( shutdownConfigMu sync.RWMutex shutdownConfig *global.ShutdownConfig - shutdownOnce sync.Once - shutdownStarted atomic.Bool - shutdownDone = make(chan struct{}) - shutdownResult error + shutdownOnce sync.Once + shutdownStarted atomic.Bool + shutdownDone = make(chan struct{}) + shutdownResultMu sync.Mutex + shutdownResult error signalNotify = signal.Notify ) @@ -152,7 +153,10 @@ func Shutdown(ctx context.Context) error { select { case <-shutdownDone: - return shutdownResult + shutdownResultMu.Lock() + err := shutdownResult + shutdownResultMu.Unlock() + return err case <-ctx.Done(): return ctx.Err() } @@ -214,14 +218,18 @@ func startShutdownOnce() { defer func() { if recovered := recover(); recovered != nil { logger.Warnf("[GracefulShutdown] shutdown panicked, err=%v", recovered) + shutdownResultMu.Lock() shutdownResult = fmt.Errorf("graceful shutdown panic: %v", recovered) + shutdownResultMu.Unlock() } close(shutdownDone) }() cfg := loadShutdownConfig() beforeShutdown(cfg) + shutdownResultMu.Lock() shutdownResult = nil + shutdownResultMu.Unlock() }() }) } diff --git a/graceful_shutdown/shutdown_test.go b/graceful_shutdown/shutdown_test.go index e69280f48c..776d1b6bf1 100644 --- a/graceful_shutdown/shutdown_test.go +++ b/graceful_shutdown/shutdown_test.go @@ -113,6 +113,7 @@ func resetShutdownTestState() { shutdownOnce = sync.Once{} shutdownStarted = atomic.Bool{} shutdownDone = make(chan struct{}) + shutdownResultMu = sync.Mutex{} shutdownResult = nil signalNotify = signal.Notify } @@ -630,3 +631,58 @@ func TestInvokeCustomShutdownCallbackDoesNotBlockForever(t *testing.T) { assert.Less(t, elapsed, time.Second) } + +func TestShutdownResultConsistentUnderConcurrentReaders(t *testing.T) { + resetShutdownTestState() + + cfg := global.DefaultShutdownConfig() + internalSignal := false + cfg.InternalSignal = &internalSignal + cfg.ConsumerUpdateWaitTime = "0s" + cfg.StepTimeout = "0s" + cfg.NotifyTimeout = "10ms" + cfg.OfflineRequestWindowTimeout = "0s" + Init(SetShutdownConfig(cfg)) + + const readers = 50 + var wg sync.WaitGroup + wg.Add(readers) + errs := make([]error, readers) + for i := 0; i < readers; i++ { + idx := i + go func() { + defer wg.Done() + errs[idx] = Shutdown(context.Background()) + }() + } + wg.Wait() + for _, err := range errs { + assert.NoError(t, err) + } +} + +func TestShutdownPanicSetsResultError(t *testing.T) { + resetShutdownTestState() + + testProtocolName := "shutdown-panic-protocol" + extension.SetProtocol(testProtocolName, func() base.Protocol { + return &testProtocol{destroy: func() { panic("simulated shutdown panic") }} + }) + t.Cleanup(func() { extension.UnregisterProtocol(testProtocolName) }) + + cfg := global.DefaultShutdownConfig() + internalSignal := false + cfg.InternalSignal = &internalSignal + cfg.ConsumerUpdateWaitTime = "0s" + cfg.StepTimeout = "0s" + cfg.NotifyTimeout = "10ms" + cfg.OfflineRequestWindowTimeout = "0s" + Init(SetShutdownConfig(cfg)) + + // RegisterProtocol must come after Init — Init resets protocols on first call. + RegisterProtocol(testProtocolName) + + err := Shutdown(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "panic") +} diff --git a/remoting/getty/getty_server.go b/remoting/getty/getty_server.go index 17c039dec4..016cb499f1 100644 --- a/remoting/getty/getty_server.go +++ b/remoting/getty/getty_server.go @@ -97,21 +97,35 @@ func initServer(url *common.URL) { return } - gettyServerConfigBytes, err := yaml.Marshal(gettyServerConfig) + gettyServerConfigBytes, err := safeYAMLMarshal(gettyServerConfig) if err != nil { - panic(err) + logger.Errorf("[Remoting][Getty] failed to marshal getty server config, err=%v", err) + return } err = yaml.Unmarshal(gettyServerConfigBytes, srvConf) if err != nil { - panic(err) + logger.Errorf("[Remoting][Getty] failed to unmarshal getty server config, err=%v", err) + return } } if err := srvConf.CheckValidity(); err != nil { - panic(err) + logger.Errorf("[Remoting][Getty] server config is invalid, err=%v", err) + return } } +// safeYAMLMarshal wraps yaml.Marshal with a recover so that types unsupported by +// the yaml encoder (e.g. channels, funcs) return an error instead of panicking. +func safeYAMLMarshal(v any) (out []byte, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("yaml marshal panic: %v", r) + } + }() + return yaml.Marshal(v) +} + // SetServerConfig set dubbo server config. func SetServerConfig(s ServerConfig) { srvConf = &s @@ -177,7 +191,7 @@ func (s *Server) newSession(session getty.Session) error { return nil } if _, ok = session.Conn().(*net.TCPConn); !ok { - panic(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection\n", session.Stat(), session.Conn())) + return perrors.New(fmt.Sprintf("%s, session.conn{%#v} is not tcp connection", session.Stat(), session.Conn())) } if _, ok = session.Conn().(*tls.Conn); !ok { diff --git a/remoting/getty/getty_server_test.go b/remoting/getty/getty_server_test.go index d0ef024241..03909867e7 100644 --- a/remoting/getty/getty_server_test.go +++ b/remoting/getty/getty_server_test.go @@ -93,3 +93,20 @@ func TestInitServerTLS(t *testing.T) { assert.False(t, srvConf.SSLEnabled) }) } + +func TestInitServerDoesNotPanicOnUnmarshalableParams(t *testing.T) { + url, err := common.NewURL("dubbo://127.0.0.1:20003/test") + require.NoError(t, err) + url.SetAttribute(constant.ProtocolConfigKey, map[string]*global.ProtocolConfig{ + "dubbo": { + Name: "dubbo", + Ip: "127.0.0.1", + Port: "20003", + Params: map[string]any{"conn-pool-size": make(chan int)}, // channel is not YAML-serializable + }, + }) + url.SetAttribute(constant.ApplicationKey, global.ApplicationConfig{}) + assert.NotPanics(t, func() { + initServer(url) + }) +}