Skip to content

Instantly share code, notes, and snippets.

@julien-c
Last active October 15, 2020 20:46
Show Gist options
  • Save julien-c/857ba86a6c6a895ecd90e7f7cab48046 to your computer and use it in GitHub Desktop.
Save julien-c/857ba86a6c6a895ecd90e7f7cab48046 to your computer and use it in GitHub Desktop.
Model tag <=> pipeline type logic, for public reference
export class ModelInfo {
/**
* Key to config.json file.
*/
key: string;
etag: string;
lastModified: Date;
size: number;
modelId: ModelId;
author?: AuthorId;
siblings: IS3ObjectWRelativeFilename[];
config: Obj;
configTxt?: string; /// if flag is set when fetching.
downloads?: number; /// if flag is set when fetching.
naturalIdx: number;
cardSource?: Source;
cardData?: Obj;
constructor(o: Partial<ModelInfo>) {
return Object.assign(this, o);
}
get jsonUrl(): string {
return Bucket.R.models.urlForKey(this.key);
}
get cdnJsonUrl(): string {
return Bucket.R.models.cdnUrlForKey(this.key);
}
async validate(): Promise<Ajv.ErrorObject[] | undefined> {
const jsonSchema = JSON.parse(
await fs.promises.readFile(CONFIG_JSON_SCHEMA, 'utf8')
);
const ajv = new Ajv();
ajv.validate(jsonSchema, this.config);
return ajv.errors ?? undefined;
}
/**
* Readme key, w. and w/o S3 prefix.
*/
get readmeKey(): string {
return this.key.replace("config.json", "README.md");
}
get readmeTrimmedKey(): string {
return Utils.trimPrefix(this.readmeKey, S3_MODELS_PREFIX);
}
/**
* ["pytorch", "tf", ...]
*/
get mlFrameworks(): string[] {
return Object.keys(FileType).filter(k => {
const filename = FileType[k];
const isExtension = filename.startsWith(".");
return isExtension
? this.siblings.some(sibling => sibling.rfilename.endsWith(filename))
: this.siblings.some(sibling => sibling.rfilename === filename);
});
}
/**
* What to display in the code sample.
*/
get autoArchitecture(): string {
const useTF = this.mlFrameworks.includes("tf") && ! this.mlFrameworks.includes("pytorch");
const arch = this.autoArchType[0];
return useTF ? `TF${arch}` : arch;
}
get autoArchType(): [string, string | undefined] {
const architectures = this.config.architectures;
if (!architectures || architectures.length === 0) {
return ["AutoModel", undefined];
}
const architecture = architectures[0].toString() as string;
if (architecture.endsWith("ForQuestionAnswering")) {
return ["AutoModelForQuestionAnswering", "question-answering"];
}
else if (architecture.endsWith("ForTokenClassification")) {
return ["AutoModelForTokenClassification", "token-classification"];
}
else if (architecture.endsWith("ForSequenceClassification")) {
return ["AutoModelForSequenceClassification", "text-classification"];
}
else if (architecture.endsWith("ForMultipleChoice")) {
return ["AutoModelForMultipleChoice", "multiple-choice"];
}
else if (architecture.endsWith("ForPreTraining")) {
return ["AutoModelForPreTraining", "pretraining"];
}
else if (architecture.endsWith("ForMaskedLM")) {
return ["AutoModelForMaskedLM", "masked-lm"];
}
else if (architecture.endsWith("ForCausalLM")) {
return ["AutoModelForCausalLM", "causal-lm"];
}
else if (
architecture.endsWith("ForConditionalGeneration")
|| architecture.endsWith("MTModel")
|| architecture == "EncoderDecoderModel"
) {
return ["AutoModelForSeq2SeqLM", "seq2seq"];
}
else if (architecture.includes("LMHead")) {
return ["AutoModelWithLMHead", "lm-head"];
}
else if (architecture.endsWith("Model")) {
return ["AutoModel", undefined];
}
else {
return [architecture, undefined];
}
}
/**
* All tags
*/
get tags(): string[] {
const x = [
...this.mlFrameworks,
];
if (this.config.model_type) {
x.push(this.config.model_type);
}
const arch = this.autoArchType[1];
if (arch) {
x.push(arch);
}
if (arch === "lm-head" && this.config.model_type) {
if ([
"t5",
"bart",
"marian",
].includes(this.config.model_type)) {
x.push("seq2seq");
}
else if ([
"gpt2",
"ctrl",
"openai-gpt",
"xlnet",
"transfo-xl",
"reformer",
].includes(this.config.model_type)) {
x.push("causal-lm");
}
else {
x.push("masked-lm");
}
}
x.push(
...this.languages() ?? []
);
x.push(
...this.datasets().map(k => `dataset:${k}`)
);
for (let [k, v] of Object.entries(this.cardData ?? {})) {
if (!['tags', 'license'].includes(k)) {
/// ^^ whitelist of other accepted keys
continue;
}
if (typeof v === 'string') {
v = [ v ];
} else if (Utils.isStrArray(v)) {
/// ok
} else {
c.error(`Invalid ${k} tag type`, v);
c.debug(this.modelId);
continue;
}
if (k === 'license') {
x.push(...v.map(x => `license:${x.toLowerCase()}`));
} else {
x.push(...v);
}
}
if (this.config.task_specific_params) {
const keys = Object.keys(this.config.task_specific_params);
for (const key of keys) {
x.push(`pipeline:${key}`);
}
}
const explicit_ptag = this.cardData?.pipeline_tag;
if (explicit_ptag) {
if (typeof explicit_ptag === 'string') {
x.push(`pipeline_tag:${explicit_ptag}`);
} else {
x.push(`pipeline_tag:invalid`);
}
}
return [...new Set(x)];
}
get pipeline_tag(): (keyof typeof PipelineType) | undefined {
if (isBlacklisted(this.modelId) || this.cardData?.inference === false) {
return undefined;
}
const explicit_ptag = this.cardData?.pipeline_tag;
if (explicit_ptag) {
if (typeof explicit_ptag == 'string') {
return explicit_ptag as keyof typeof PipelineType;
} else {
c.error(`Invalid explicit pipeline_tag`, explicit_ptag);
return undefined;
}
}
const tags = this.tags;
/// Special case for translation
/// Get the first of the explicit tags that matches.
const EXPLICIT_PREFIX = "pipeline:";
const explicit_tag = tags.find(x => x.startsWith(EXPLICIT_PREFIX + `translation`));
if (!!explicit_tag) {
return "translation";
}
/// Otherwise, get the first (most specific) match **from the mapping**.
for (const ptag of ALL_PIPELINE_TYPES) {
if (tags.includes(ptag)) {
return ptag;
}
}
/// Extra mapping
const mapping = new Map<string, keyof typeof PipelineType>([
["seq2seq", "text-generation"],
["causal-lm", "text-generation"],
["masked-lm", "fill-mask"],
]);
for (const [tag, ptag] of mapping) {
if (tags.includes(tag)) {
return ptag;
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment