// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Code generated from the elasticsearch-specification DO NOT EDIT.
// https://github.com/elastic/elasticsearch-specification/tree/5fb8f1ce9c4605abcaa44aa0f17dbfc60497a757

package types

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strconv"

	"github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/tokenizationtruncate"
)

// NlpRobertaTokenizationConfig type.
//
// https://github.com/elastic/elasticsearch-specification/blob/5fb8f1ce9c4605abcaa44aa0f17dbfc60497a757/specification/ml/_types/inference.ts#L160-L187
type NlpRobertaTokenizationConfig struct {
	// AddPrefixSpace Should the tokenizer prefix input with a space character
	AddPrefixSpace *bool `json:"add_prefix_space,omitempty"`
	// MaxSequenceLength Maximum input sequence length for the model
	MaxSequenceLength *int `json:"max_sequence_length,omitempty"`
	// Span Tokenization spanning options. Special value of -1 indicates no spanning
	// takes place
	Span *int `json:"span,omitempty"`
	// Truncate Should tokenization input be automatically truncated before sending to the
	// model for inference
	Truncate *tokenizationtruncate.TokenizationTruncate `json:"truncate,omitempty"`
	// WithSpecialTokens Is tokenization completed with special tokens
	WithSpecialTokens *bool `json:"with_special_tokens,omitempty"`
}
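// Hypothetical usage sketch, not part of the generated API surface: it shows
// how the pointer fields plus `omitempty` keep unset options out of the
// serialized body. The option values below are illustrative assumptions, not
// server defaults.
func sketchMarshalNlpRobertaTokenizationConfig() {
	addPrefixSpace := true
	maxSequenceLength := 512

	cfg := NlpRobertaTokenizationConfig{
		AddPrefixSpace:    &addPrefixSpace,
		MaxSequenceLength: &maxSequenceLength,
	}

	// Marshalling uses the standard struct tags; nil pointer fields are omitted.
	out, err := json.Marshal(cfg)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out)) // {"add_prefix_space":true,"max_sequence_length":512}
}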
func (s *NlpRobertaTokenizationConfig) UnmarshalJSON(data []byte) error {

	dec := json.NewDecoder(bytes.NewReader(data))

	for {
		t, err := dec.Token()
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}

		switch t {

		case "add_prefix_space":
			var tmp interface{}
			dec.Decode(&tmp)
			switch v := tmp.(type) {
			case string:
				value, err := strconv.ParseBool(v)
				if err != nil {
					return fmt.Errorf("%s | %w", "AddPrefixSpace", err)
				}
				s.AddPrefixSpace = &value
			case bool:
				s.AddPrefixSpace = &v
			}

		case "max_sequence_length":
			var tmp interface{}
			dec.Decode(&tmp)
			switch v := tmp.(type) {
			case string:
				value, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("%s | %w", "MaxSequenceLength", err)
				}
				s.MaxSequenceLength = &value
			case float64:
				f := int(v)
				s.MaxSequenceLength = &f
			}

		case "span":
			var tmp interface{}
			dec.Decode(&tmp)
			switch v := tmp.(type) {
			case string:
				value, err := strconv.Atoi(v)
				if err != nil {
					return fmt.Errorf("%s | %w", "Span", err)
				}
				s.Span = &value
			case float64:
				f := int(v)
				s.Span = &f
			}

		case "truncate":
			if err := dec.Decode(&s.Truncate); err != nil {
				return fmt.Errorf("%s | %w", "Truncate", err)
			}

		case "with_special_tokens":
			var tmp interface{}
			dec.Decode(&tmp)
			switch v := tmp.(type) {
			case string:
				value, err := strconv.ParseBool(v)
				if err != nil {
					return fmt.Errorf("%s | %w", "WithSpecialTokens", err)
				}
				s.WithSpecialTokens = &value
			case bool:
				s.WithSpecialTokens = &v
			}

		}
	}
	return nil
}

// NewNlpRobertaTokenizationConfig returns a NlpRobertaTokenizationConfig.
func NewNlpRobertaTokenizationConfig() *NlpRobertaTokenizationConfig {
	r := &NlpRobertaTokenizationConfig{}

	return r
}
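// Hypothetical decoding sketch, not part of the generated API surface: the
// custom UnmarshalJSON above is lenient, so string-encoded booleans and
// integers decode the same way as their native JSON forms. The payloads below
// are illustrative assumptions, not captured Elasticsearch output.
func sketchUnmarshalNlpRobertaTokenizationConfig() {
	payloads := []string{
		`{"span": -1, "with_special_tokens": true}`,
		`{"span": "-1", "with_special_tokens": "true"}`,
	}

	for _, payload := range payloads {
		cfg := NewNlpRobertaTokenizationConfig()
		if err := json.Unmarshal([]byte(payload), cfg); err != nil {
			panic(err)
		}
		fmt.Println(*cfg.Span, *cfg.WithSpecialTokens) // -1 true
	}
}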