Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (a ActivityWithImpl) validate(q *Queue, v *validationState) error {
}
_, ok := v.activitiesValidated[a.activityName]
if ok {
return fmt.Errorf("duplicate activtity name %s for queue %s", a.activityName, q.name)
return fmt.Errorf("duplicate activity name %s for queue %s", a.activityName, q.name)
}
v.activitiesValidated[a.activityName] = struct{}{}
return nil
Expand All @@ -43,9 +43,11 @@ type Activity[Param, Return any] struct {

// NewActivity declares the existence of an activity on a given queue with a given name.
func NewActivity[Param, Return any](q *Queue, name string) Activity[Param, Return] {
validateName(name)
panicIfNotStruct[Param]("NewActivity")
q.registerActivity(name, func(ctx context.Context, param Param) (Return, error) {
panic(fmt.Sprintf("Activity %s execution not mocked", name))
q.registerActivity(name, func(_ context.Context, _ Param) (Return, error) {
var zero Return
return zero, fmt.Errorf("Activity %s execution not mocked", name)
})
return Activity[Param, Return]{Name: name, queue: q}
}
Expand All @@ -54,6 +56,7 @@ func NewActivity[Param, Return any](q *Queue, name string) Activity[Param, Retur
// Instead of passing the Param struct directly to the activity, it passes each field of the struct
// as a separate positional argument in the order they are defined.
func NewActivityPositional[Param, Return any](q *Queue, name string) Activity[Param, Return] {
validateName(name)
panicIfNotStruct[Param]("NewActivityPositional")

// Get the type information for the Param struct
Expand All @@ -74,9 +77,11 @@ func NewActivityPositional[Param, Return any](q *Queue, name string) Activity[Pa
errorType := reflect.TypeOf((*error)(nil)).Elem()
fnType := reflect.FuncOf(paramTypes, []reflect.Type{returnType, errorType}, false)

// Create a function that panics with the message "Function execution not mocked"
// Create a function that returns an error instead of panicking
mockFn := reflect.MakeFunc(fnType, func(args []reflect.Value) []reflect.Value {
panic(fmt.Sprintf("Activity %s execution not mocked", name))
zero := reflect.New(returnType).Elem()
errVal := reflect.ValueOf(fmt.Errorf("Activity %s execution not mocked", name))
return []reflect.Value{zero, errVal}
})

// Register the mock function
Expand All @@ -85,6 +90,12 @@ func NewActivityPositional[Param, Return any](q *Queue, name string) Activity[Pa
return Activity[Param, Return]{Name: name, queue: q, positional: true}
}

// panicIfNotStruct enforces that the Param type parameter is a struct (or *struct).
//
// Why panic: Go's generics cannot express a "must be struct" constraint, so this
// is a runtime enforcement of a compile-time invariant. It fires during package
// init (these constructors are called in var declarations), never at request time.
// If Go adds structural type constraints, this function becomes unnecessary.
func panicIfNotStruct[Param any](funcName string) {
paramType := reflect.TypeOf((*Param)(nil)).Elem()
if paramType.Kind() == reflect.Ptr {
Expand All @@ -95,6 +106,12 @@ func panicIfNotStruct[Param any](funcName string) {
}
}

// extractFieldTypes returns the types of all exported fields in a struct type.
//
// Why panic: This is a defensive assertion in an internal helper. It is only
// reachable if the caller bypasses panicIfNotStruct, which is a programming
// error internal to this package. Returning an error here would propagate
// complexity to every call site for a condition that cannot occur in correct code.
func extractFieldTypes(structType reflect.Type) []reflect.Type {
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
Expand Down
215 changes: 215 additions & 0 deletions activity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package tempts_test

import (
"context"
"strings"
"testing"
"time"

"github.com/vikstrous/tempts"
"go.temporal.io/sdk/testsuite"
"go.temporal.io/sdk/workflow"
)

func TestNewActivity(t *testing.T) {
q := tempts.NewQueue("act-new-q")
type P struct{ V string }
type R struct{ V string }

act := tempts.NewActivity[P, R](q, "test-act")
if act.Name != "test-act" {
t.Fatalf("expected name 'test-act', got %q", act.Name)
}
}

func TestNewActivity_EmptyName(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic for empty activity name")
}
}()
q := tempts.NewQueue("act-empty-q")
type P struct{ V string }
type R struct{ V string }
tempts.NewActivity[P, R](q, "")
}

func TestNewActivity_NonStructParam(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic for non-struct param")
}
}()
q := tempts.NewQueue("act-nonstruct-q")
tempts.NewActivity[string, string](q, "non-struct-act")
}

func TestNewActivityPositional(t *testing.T) {
q := tempts.NewQueue("act-pos-q")
type P struct {
First string
Last string
}
type R struct{ Full string }

act := tempts.NewActivityPositional[P, R](q, "pos-act")
if act.Name != "pos-act" {
t.Fatalf("expected name 'pos-act', got %q", act.Name)
}
}

func TestActivity_WithImplementation(t *testing.T) {
q := tempts.NewQueue("act-impl-q")
type P struct{ V string }
type R struct{ V string }

act := tempts.NewActivity[P, R](q, "impl-act")
impl := act.WithImplementation(func(_ context.Context, p P) (R, error) {
return R{V: p.V}, nil
})
if impl == nil {
t.Fatal("expected non-nil ActivityWithImpl")
}
}

func TestActivity_WithImplementationPositional(t *testing.T) {
q := tempts.NewQueue("act-impl-pos-q")
type P struct {
First string
Last string
}
type R struct{ Full string }

act := tempts.NewActivityPositional[P, R](q, "impl-pos-act")
impl := act.WithImplementation(func(_ context.Context, p P) (R, error) {
return R{Full: p.First + " " + p.Last}, nil
})
if impl == nil {
t.Fatal("expected non-nil ActivityWithImpl")
}
}

func TestActivity_RunInWorkflow(t *testing.T) {
q := tempts.NewQueue("act-run-q")
type AP struct{ Name string }
type AR struct{ Name string }
type WP struct{ Name string }
type WR struct{ Result string }

act := tempts.NewActivity[AP, AR](q, "run-act")
wf := tempts.NewWorkflow[WP, WR](q, "run-wf")

wfImpl := wf.WithImplementation(func(ctx workflow.Context, p WP) (WR, error) {
ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: 5 * time.Second,
})
r, err := act.Run(ctx, AP{Name: p.Name})
return WR{Result: r.Name}, err
})

wrk, err := tempts.NewWorker(q, []tempts.Registerable{
act.WithImplementation(func(_ context.Context, p AP) (AR, error) {
return AR{Name: strings.ToUpper(p.Name)}, nil
}),
wfImpl,
})
if err != nil {
t.Fatal(err)
}

ts := testsuite.WorkflowTestSuite{}
ts.SetDisableRegistrationAliasing(true)
env := ts.NewTestWorkflowEnvironment()
t.Cleanup(func() { env.AssertExpectations(t) })
wrk.Register(env)

result, err := wfImpl.ExecuteInTest(env, WP{Name: "test"})
if err != nil {
t.Fatal(err)
}
if result.Result != "TEST" {
t.Fatalf("expected 'TEST', got %q", result.Result)
}
}

func TestActivity_RunPositionalInWorkflow(t *testing.T) {
q := tempts.NewQueue("act-run-pos-q")
type AP struct {
First string
Last string
}
type AR struct{ Full string }
type WP struct{ Name string }
type WR struct{ Result string }

act := tempts.NewActivityPositional[AP, AR](q, "run-pos-act")
wf := tempts.NewWorkflow[WP, WR](q, "run-pos-wf")

wfImpl := wf.WithImplementation(func(ctx workflow.Context, p WP) (WR, error) {
ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: 5 * time.Second,
})
r, err := act.Run(ctx, AP{First: "John", Last: "Doe"})
return WR{Result: r.Full}, err
})

wrk, err := tempts.NewWorker(q, []tempts.Registerable{
act.WithImplementation(func(_ context.Context, p AP) (AR, error) {
return AR{Full: p.First + " " + p.Last}, nil
}),
wfImpl,
})
if err != nil {
t.Fatal(err)
}

ts := testsuite.WorkflowTestSuite{}
ts.SetDisableRegistrationAliasing(true)
env := ts.NewTestWorkflowEnvironment()
t.Cleanup(func() { env.AssertExpectations(t) })
wrk.Register(env)

result, err := wfImpl.ExecuteInTest(env, WP{Name: "test"})
if err != nil {
t.Fatal(err)
}
if result.Result != "John Doe" {
t.Fatalf("expected 'John Doe', got %q", result.Result)
}
}

func TestActivity_MockStubReturnsError(t *testing.T) {
q := tempts.NewQueue("act-mock-stub-q")
type AP struct{ V string }
type AR struct{ V string }
type WP struct{ V string }
type WR struct{ V string }

act := tempts.NewActivity[AP, AR](q, "mock-stub-act")
wf := tempts.NewWorkflow[WP, WR](q, "mock-stub-wf")

// Workflow calls the unmocked activity
wfImpl := wf.WithImplementation(func(ctx workflow.Context, p WP) (WR, error) {
ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: 5 * time.Second,
})
r, err := act.Run(ctx, AP{V: p.V})
return WR{V: r.V}, err
})

ts := testsuite.WorkflowTestSuite{}
ts.SetDisableRegistrationAliasing(true)
env := ts.NewTestWorkflowEnvironment()

// Register mock fallbacks (not real implementations)
q.RegisterMockFallbacks(env)

// Execute the workflow — the activity stub should return an error, not panic
_, err := wfImpl.ExecuteInTest(env, WP{V: "test"})
if err == nil {
t.Fatal("expected error from unmocked activity")
}
if !strings.Contains(err.Error(), "not mocked") {
t.Fatalf("expected 'not mocked' in error, got: %v", err)
}
}
13 changes: 10 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package tempts

import "go.temporal.io/sdk/client"
import (
"fmt"

// Client is a wrapper for the temporal SDK client that keeps track of which namepace the client is connected to to return more useful errors if the wrong namespace is used.
"go.temporal.io/sdk/client"
)

// Client is a wrapper for the temporal SDK client that keeps track of which namespace the client is connected to to return more useful errors if the wrong namespace is used.
type Client struct {
namespace string
Client client.Client
Expand Down Expand Up @@ -39,8 +43,11 @@ func Dial(opts client.Options) (*Client, error) {
return &Client{Client: c, namespace: namespace}, nil
}

// NewFromSDK allows the caller to pass in an existing temporal SDK client and manually specify which name that client was connected to.
// NewFromSDK allows the caller to pass in an existing temporal SDK client and manually specify which namespace that client was connected to.
func NewFromSDK(c client.Client, namespace string) (*Client, error) {
if c == nil {
return nil, fmt.Errorf("client must not be nil")
}
if namespace == "" {
namespace = client.DefaultNamespace
}
Expand Down
40 changes: 40 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package tempts_test

import (
"testing"

"github.com/vikstrous/tempts"
"go.temporal.io/sdk/client"
)

// mockSDKClient satisfies client.Client without connecting to a real server.
type mockSDKClient struct {
client.Client
}

func TestNewFromSDK_NilClient(t *testing.T) {
_, err := tempts.NewFromSDK(nil, "default")
if err == nil {
t.Fatal("expected error for nil client")
}
}

func TestNewFromSDK_ValidClient(t *testing.T) {
c, err := tempts.NewFromSDK(mockSDKClient{}, "test-ns")
if err != nil {
t.Fatal(err)
}
if c == nil {
t.Fatal("expected non-nil client")
}
}

func TestNewFromSDK_EmptyNamespace(t *testing.T) {
c, err := tempts.NewFromSDK(mockSDKClient{}, "")
if err != nil {
t.Fatal(err)
}
if c == nil {
t.Fatal("expected non-nil client")
}
}
7 changes: 7 additions & 0 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func workflowFormatAndGreet(ctx workflow.Context, params FormatAndGreetParams) (
newName := "unknown"
suffix := ""

ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: time.Second * 10,
})

workflowTypeFormatAndGreetGetName.SetHandler(ctx, func(struct{}) (FormatAndGreetGetNameResult, error) {
return FormatAndGreetGetNameResult{Name: newName + suffix}, nil
})
Expand Down Expand Up @@ -186,6 +190,9 @@ func workflowFormatAndGreet(ctx workflow.Context, params FormatAndGreetParams) (
}

func workflowJustGreet(ctx workflow.Context, params JustGreetParams) (JustGreetResult, error) {
ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
StartToCloseTimeout: time.Second * 10,
})
name, err := activityTypeGreet.Run(ctx, GreetParams{Name: params.Name})
return JustGreetResult{Name: name.Name}, err
}
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module github.com/vikstrous/tempts

// TODO: Upgrade to go 1.25 for testing/synctest, sync.WaitGroup.Go(), t.Context(), maps.Keys(), reflect.TypeAssert
go 1.22.0

require (
github.com/gogo/protobuf v1.3.2
github.com/stretchr/testify v1.8.4
go.temporal.io/api v1.24.0
go.temporal.io/sdk v1.25.1
)
Expand All @@ -22,7 +24,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron v1.2.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/stretchr/testify v1.8.4 // indirect
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/net v0.14.0 // indirect
golang.org/x/sys v0.11.0 // indirect
Expand Down
Loading