diff --git a/types/null.go b/types/null.go new file mode 100644 index 000000000..d339e0e17 --- /dev/null +++ b/types/null.go @@ -0,0 +1,53 @@ +package types + +import ( + "encoding/json" +) + +// Null represents a value of any type T that may be null. +type Null[T any] struct { + value T + valid bool +} + +// NullWithValue creates a new, valid Null[T]. +func NullWithValue[T any](value T) Null[T] { + return Null[T]{value: value, valid: true} +} + +// Value returns the value of the Type[T] if set, otherwise it returns the provided default value or the zero value of T. +func (n Null[T]) Value(defaultValue ...T) T { + if n.valid { + return n.value + } + if len(defaultValue) > 0 { + return defaultValue[0] + } + var zero T + return zero +} + +// SetValue sets the value of the Type[T]. +func (n *Null[T]) SetValue(t T) { + n.value = t + n.valid = true +} + +// MarshalJSON implements [json.Marshaler]. +func (n Null[T]) MarshalJSON() ([]byte, error) { + if n.valid { + return json.Marshal(n.value) + } + return []byte("null"), nil +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (n *Null[T]) UnmarshalJSON(data []byte) error { + if string(data) != "null" { + if err := json.Unmarshal(data, &n.value); err != nil { + return err + } + n.valid = true + } + return nil +} diff --git a/types/null_test.go b/types/null_test.go new file mode 100644 index 000000000..f99e722ba --- /dev/null +++ b/types/null_test.go @@ -0,0 +1,53 @@ +package types_test + +import ( + "encoding/json" + "testing" + + "github.com/ollama/ollama/types" +) + +func TestNull(t *testing.T) { + var s types.Null[string] + if val := s.Value(); val != "" { + t.Errorf("expected Value to return zero value '', got '%s'", val) + } + + if val := s.Value("default"); val != "default" { + t.Errorf("expected Value to return default value 'default', got '%s'", val) + } + + if bts, err := json.Marshal(s); err != nil { + t.Errorf("unexpected error during MarshalJSON: %v", err) + } else if want := "null"; string(bts) != want { + t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts)) + } + + s.SetValue("foo") + if val := s.Value(); val != "foo" { + t.Errorf("expected Value to return 'foo', got '%s'", val) + } + + s = types.NullValue("bar") + if val := s.Value(); val != "bar" { + t.Errorf("expected Value to return 'bar', got '%s'", val) + } + + if bts, err := json.Marshal(s); err != nil { + t.Errorf("unexpected error during MarshalJSON: %v", err) + } else if want := `"bar"`; string(bts) != want { + t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts)) + } + + if err := json.Unmarshal([]byte(`null`), &s); err != nil { + t.Errorf("unexpected error during UnmarshalJSON: %v", err) + } + + if err := json.Unmarshal([]byte(`"baz"`), &s); err != nil { + t.Errorf("unexpected error during UnmarshalJSON: %v", err) + } + + if err := json.Unmarshal([]byte(`1.2345`), &s); err == nil { + t.Error("expected error during UnmarshalJSON with invalid JSON, got nil") + } +}