diff --git a/connection.go b/connection.go index c297d5bd..0d4af39e 100644 --- a/connection.go +++ b/connection.go @@ -323,6 +323,11 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req.Parameters = parameters } + // Add enforce embedded schema correctness if enabled + if c.cfg.EnforceEmbeddedSchemaCorrectness { + req.EnforceEmbeddedSchemaCorrectness = &c.cfg.EnforceEmbeddedSchemaCorrectness + } + resp, err := c.client.ExecuteStatement(ctx, &req) var log *logger.DBSQLLogger log, ctx = client.LoggerAndContext(ctx, resp) diff --git a/connector.go b/connector.go index 1f77ac3f..7a4fe993 100644 --- a/connector.go +++ b/connector.go @@ -291,6 +291,15 @@ func WithEnableMetricViewMetadata(enable bool) ConnOption { } } +// WithEnforceEmbeddedSchemaCorrectness enables enforcement of embedded schema correctness +// in query execution. When set to true, the server will enforce embedded schema correctness. +// Default is false. +func WithEnforceEmbeddedSchemaCorrectness(enforce bool) ConnOption { + return func(c *config.Config) { + c.EnforceEmbeddedSchemaCorrectness = enforce + } +} + // Setup of Oauth M2m authentication func WithClientCredentials(clientID, clientSecret string) ConnOption { return func(c *config.Config) { diff --git a/internal/cli_service/cli_service.go b/internal/cli_service/cli_service.go index 71952c69..d923bea6 100644 --- a/internal/cli_service/cli_service.go +++ b/internal/cli_service/cli_service.go @@ -11776,6 +11776,7 @@ func (p *TSparkArrowTypes) Validate() error { // - Parameters // - MaxBytesPerBatch // - StatementConf +// - EnforceEmbeddedSchemaCorrectness type TExecuteStatementReq struct { SessionHandle *TSessionHandle `thrift:"sessionHandle,1,required" db:"sessionHandle" json:"sessionHandle"` Statement string `thrift:"statement,2,required" db:"statement" json:"statement"` @@ -11794,6 +11795,8 @@ type TExecuteStatementReq struct { MaxBytesPerBatch *int64 `thrift:"maxBytesPerBatch,1289" db:"maxBytesPerBatch" json:"maxBytesPerBatch,omitempty"` // unused fields # 1290 to 1295 StatementConf *TStatementConf `thrift:"statementConf,1296" db:"statementConf" json:"statementConf,omitempty"` + // unused fields # 1297 to 3352 + EnforceEmbeddedSchemaCorrectness *bool `thrift:"enforceEmbeddedSchemaCorrectness,3353" db:"enforceEmbeddedSchemaCorrectness" json:"enforceEmbeddedSchemaCorrectness,omitempty"` } func NewTExecuteStatementReq() *TExecuteStatementReq { @@ -11894,6 +11897,13 @@ func (p *TExecuteStatementReq) GetStatementConf() *TStatementConf { } return p.StatementConf } +var TExecuteStatementReq_EnforceEmbeddedSchemaCorrectness_DEFAULT bool +func (p *TExecuteStatementReq) GetEnforceEmbeddedSchemaCorrectness() bool { + if !p.IsSetEnforceEmbeddedSchemaCorrectness() { + return TExecuteStatementReq_EnforceEmbeddedSchemaCorrectness_DEFAULT + } +return *p.EnforceEmbeddedSchemaCorrectness +} func (p *TExecuteStatementReq) IsSetSessionHandle() bool { return p.SessionHandle != nil } @@ -11950,6 +11960,10 @@ func (p *TExecuteStatementReq) IsSetStatementConf() bool { return p.StatementConf != nil } +func (p *TExecuteStatementReq) IsSetEnforceEmbeddedSchemaCorrectness() bool { + return p.EnforceEmbeddedSchemaCorrectness != nil +} + func (p *TExecuteStatementReq) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) @@ -12117,6 +12131,16 @@ func (p *TExecuteStatementReq) Read(ctx context.Context, iprot thrift.TProtocol) return err } } + case 3353: + if fieldTypeId == thrift.BOOL { + if err := p.ReadField3353(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } default: if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -12299,6 +12323,15 @@ func (p *TExecuteStatementReq) ReadField1296(ctx context.Context, iprot thrift. return nil } +func (p *TExecuteStatementReq) ReadField3353(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 3353: ", err) +} else { + p.EnforceEmbeddedSchemaCorrectness = &v +} + return nil +} + func (p *TExecuteStatementReq) Write(ctx context.Context, oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin(ctx, "TExecuteStatementReq"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } @@ -12318,6 +12351,7 @@ func (p *TExecuteStatementReq) Write(ctx context.Context, oprot thrift.TProtocol if err := p.writeField1288(ctx, oprot); err != nil { return err } if err := p.writeField1289(ctx, oprot); err != nil { return err } if err := p.writeField1296(ctx, oprot); err != nil { return err } + if err := p.writeField3353(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) } @@ -12525,6 +12559,18 @@ func (p *TExecuteStatementReq) writeField1296(ctx context.Context, oprot thrift. return err } +func (p *TExecuteStatementReq) writeField3353(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetEnforceEmbeddedSchemaCorrectness() { + if err := oprot.WriteFieldBegin(ctx, "enforceEmbeddedSchemaCorrectness", thrift.BOOL, 3353); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 3353:enforceEmbeddedSchemaCorrectness: ", p), err) } + if err := oprot.WriteBool(ctx, bool(*p.EnforceEmbeddedSchemaCorrectness)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.enforceEmbeddedSchemaCorrectness (3353) field write error: ", p), err) } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 3353:enforceEmbeddedSchemaCorrectness: ", p), err) } + } + return err +} + func (p *TExecuteStatementReq) Equals(other *TExecuteStatementReq) bool { if p == other { return true @@ -12584,6 +12630,12 @@ func (p *TExecuteStatementReq) Equals(other *TExecuteStatementReq) bool { if (*p.MaxBytesPerBatch) != (*other.MaxBytesPerBatch) { return false } } if !p.StatementConf.Equals(other.StatementConf) { return false } + if p.EnforceEmbeddedSchemaCorrectness != other.EnforceEmbeddedSchemaCorrectness { + if p.EnforceEmbeddedSchemaCorrectness == nil || other.EnforceEmbeddedSchemaCorrectness == nil { + return false + } + if (*p.EnforceEmbeddedSchemaCorrectness) != (*other.EnforceEmbeddedSchemaCorrectness) { return false } + } return true } diff --git a/internal/config/config.go b/internal/config/config.go index e13cb98f..1f008403 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -99,8 +99,9 @@ type UserConfig struct { RetryWaitMax time.Duration RetryMax int Transport http.RoundTripper - UseLz4Compression bool - EnableMetricViewMetadata bool + UseLz4Compression bool + EnableMetricViewMetadata bool + EnforceEmbeddedSchemaCorrectness bool CloudFetchConfig } @@ -282,6 +283,13 @@ func ParseDSN(dsn string) (UserConfig, error) { ucfg.EnableMetricViewMetadata = enableMetricViewMetadata } + if enforceEmbeddedSchemaCorrectness, ok, err := params.extractAsBool("enforceEmbeddedSchemaCorrectness"); ok { + if err != nil { + return UserConfig{}, err + } + ucfg.EnforceEmbeddedSchemaCorrectness = enforceEmbeddedSchemaCorrectness + } + // for timezone we do a case insensitive key match. // We use getNoCase because we want to leave timezone in the params so that it will also // be used as a session param.