Skip to content

Instantly share code, notes, and snippets.

@jmbarbone
Last active October 18, 2025 23:31
Show Gist options
  • Select an option

  • Save jmbarbone/c7a8059edaed74cddb608b88a6577088 to your computer and use it in GitHub Desktop.

Select an option

Save jmbarbone/c7a8059edaed74cddb608b88a6577088 to your computer and use it in GitHub Desktop.
typed functions in R
#' Declare a typed argument for `typed()`
#'
#' @param sym The argument name, given unquoted; it is captured with
#'   substitute() and stored as a character string.
#' @param type The type object the argument's value is later checked against.
#' @param default Optional default, captured unevaluated with substitute().
#'   When omitted, the argument is required.
#' @return A list of class `"arg"` with elements `sym`, `type`, `missing`
#'   (`TRUE` when no default was given) and, only when a default was given,
#'   `default` (the unevaluated expression, which may be `NULL`).
arg <- function(sym, type, default) {
  spec <- structure(
    list(
      sym = as.character(substitute(sym)),
      type = type
    ),
    class = "arg"
  )

  spec$missing <- missing(default)
  if (!spec$missing) {
    # `[<-` with a list keeps an explicit `NULL` default; `spec$default <-
    # NULL` would silently drop the element instead.
    spec["default"] <- list(substitute(default))
  }
  spec
}

# A required argument: no default is given, so the spec records
# `missing = TRUE` and has no `default` element.
arg(a, S7::class_logical)
#> $sym
#> [1] "a"
#> 
#> $type
#> <S7_base_class>: <logical>
#> 
#> $missing
#> [1] TRUE
#> 
#> attr(,"class")
#> [1] "arg"
# An optional argument: the default is captured unevaluated, so `b` is
# stored as the symbol `b`, not as the value of a variable named `b`.
arg(a, S7::class_integer, b)
#> $sym
#> [1] "a"
#> 
#> $type
#> <S7_base_class>: <integer>
#> 
#> $missing
#> [1] FALSE
#> 
#> $default
#> b
#> 
#> attr(,"class")
#> [1] "arg"

# Build a typed function from `arg()` specs, a body and a return type
# (first prototype; `typed` is defined again further down with a different
# interface).
#
# Each argument in `...` is classified by the head of its unevaluated call:
#   * `arg(...)`        -> declares a formal. The spec's default (if any)
#                          becomes the formal's default, and an active binding
#                          named after the argument is installed in
#                          `function_env` that type-checks every assignment.
#   * `expression(...)` -> the body, stored as `..body`.
#   * anything else     -> the return type, stored as `..return`.
#
# Returns a function of class "typed_function" whose formals come from the
# `arg()` specs.
#
# NOTE(review): the per-argument active bindings keep their last value
# (`.value <<- value`) and every call reuses them. Arguments omitted in a
# later call therefore silently take the value supplied in an earlier call.
typed <- function(...) {
  # One active binding per declared argument lives here; its parent is the
  # environment typed() was called from.
  function_env <- new.env(parent = parent.frame())
  # Holds ..body / ..return / ..result plus a handle back to function_env.
  typed_env <- new.env(parent = function_env)
  typed_env$function_env <- function_env
  # Unevaluated `...` arguments, used to classify each one by its call head.
  args <- as.list(substitute({...}))[-1L]
  
  forms <- list()
  for (i in seq_along(args)) {
    # Compares the call's head symbol against the string "arg".
    if (is.call(args[[i]]) && args[[i]][[1]] == "arg") {
      # ...elt(i) evaluates the i-th argument, i.e. runs arg().
      a <- ...elt(i)
      
      # substitute() with no argument yields the empty symbol: a formal
      # with no default.
      forms <- c(
        forms, 
        structure(
          list(if (a$missing) substitute() else a$default), 
          names = a$sym
        )
      )
      
      makeActiveBinding(
        sym = as.character(a$sym),
        env = function_env,
        fun = local({
          .spec <- a
          # TODO custom default value handling
          # `.spec` (a list) is passed positionally as get()'s `pos` and
          # converted with as.environment(). When the spec has no `default`
          # element, forcing `.value` errors with "object 'default' not
          # found". The value is the expression captured by arg(), so e.g. a
          # symbol default stays a symbol.
          delayedAssign(".value", get("default", .spec))
          function(value) {
            if (missing(value)) {
              return(.value)
            } else {
              # Checked against class() of the declared type object, e.g.
              # class(integer()) is "integer".
              if (!inherits(value, class(.spec$type))) {
                stop(
                  "Argument '", 
                  as.character(.spec$sym), 
                  "' must be of type ", 
                  deparse1(.spec$type)
                )
              }
              # Updates this closure's `.value`, which persists across calls.
              .value <<- value
            }
          }
        })
      )
    } else if (is.call(args[[i]]) && args[[i]][[1]] == "expression") {
      typed_env$..body <- ...elt(i)
    } else {
      typed_env$..return <- ...elt(i)
    }
  }
  
  typed_function <- local(
    envir = typed_env,
    {
      # `.value` holds the last checked return value. `.evaluated` is set
      # below but never read in this block.
      .value <- NULL
      .evaluated <- FALSE
      
      # Assigning to `..result` runs the return-type check.
      makeActiveBinding(
        sym = "..result",
        env = environment(),
        fun = function(value) {
          if (missing(value)) {
            return(.value)
          }
          
          # currently some issues with S7 inherits(x, S7)
          # NOTE(review): class() of an S7 class object is presumably not the
          # class name the value carries -- confirm before using S7 types.
          if (!inherits(value, class(..return))) {
            # Prints the offending value before failing (looks like leftover
            # debugging output).
            print(value)
            stop("Return value must be of type ", deparse1(..return))
          }
          
          .evaluated <<- TRUE
          .value <<- value
        }
      )
      
      function(...) {
        # Per-call evaluation env whose parent is typed()'s calling env.
        e <- new.env(parent = parent.env(function_env))
        # Mirror every argument binding into `e`. They reuse the same binding
        # functions, and therefore the same stored state, as function_env.
        lapply(
          ls(function_env, all.names = TRUE),
          function(nm) {
            makeActiveBinding(
              nm,
              activeBindingFunction(nm, function_env),
              e
            )
          }
        )
        # Assigning the supplied arguments into `e` triggers the active
        # bindings, and therefore the per-argument type checks.
        # NOTE(review): `eval` uses its default `envir` here, not the
        # caller's frame, so symbol arguments may not resolve as a caller
        # would expect -- verify.
        list2env(lapply(as.list(match.call())[-1L], eval), envir = e)
        # `<<-` finds the `..result` binding in typed_env, so this assignment
        # runs the return-type check.
        ..result <<- eval(..body, e)
        ..result
      }
    }
  )
  
  # The collected arg() formals replace the placeholder `...`.
  formals(typed_function) <- as.pairlist(forms)
  class(typed_function) <- "typed_function"
  typed_function
}

#' Print a typed function
#'
#' Rebuilds a plain function from the typed function's formals and the body
#' expression stored as `..body` in its environment, prints that function,
#' and returns `x` invisibly.
#'
#' @param x A `"typed_function"` object.
#' @param ... Ignored; accepted for compatibility with the print() generic.
#' @return `x`, invisibly.
#' @export
print.typed_function <- function(x, ...) {
  shown <- function() NULL
  formals(shown) <- formals(x)
  # body<- uses the first element when given an expression vector.
  body(shown) <- environment(x)$..body
  print(shown)
  # cat("Return type: ", format(environment(x)$..return), "\n", sep = "")
  invisible(x)
}

# Design sketches for alternative declaration syntaxes. Each one is wrapped in
# expression() so it is only parsed, never evaluated; helpers such as `.A()`,
# `..()`, `%?%`, `%=%` and `.E()` are not defined in this section.
#
# NOTE(review): as the deparsed output shows,
# `if (a) 1 else 0 + if (b) 2 else 0 + c` nests each `if` inside the previous
# `else` branch rather than summing three terms. Parenthesize each `if` if a
# sum is intended.

# Sketch 1: use arg syntax?
expression({
  typed(
    .A(a, class_logical, 1),
    .A(b, class_logical, 1),
    .A(c, class_numeric, 1),
    ..(if (a) 1 else 0 + if (b) 2 else 0 + c),
    class_numeric
  )
})
#> expression({
#>     typed(.A(a, class_logical, 1), .A(b, class_logical, 1), .A(c, 
#>         class_numeric, 1), ..(if (a) 
#>         1
#>     else 0 + if (b) 
#>         2
#>     else 0 + c), class_numeric)
#> })

# Sketch 2: use fun (infix operator) syntax?
expression({
  typed(
    a %?% class_logical %=% 1,
    b %?% class_logical %=% 1,
    c %?% class_numeric %=% 1,
    ...(.E(if (a) 1 else 0 + if (b) 2 else 0 + c)),
    class_numeric
  )
})
#> expression({
#>     typed(a %?% class_logical %=% 1, b %?% class_logical %=% 
#>         1, c %?% class_numeric %=% 1, ...(.E(if (a) 
#>         1
#>     else 0 + if (b) 
#>         2
#>     else 0 + c)), class_numeric)
#> })

# debugonce(typed)
# `a` and `b` are required integers and `c` is a logical defaulting to TRUE.
# The body is `a + b + c`, and the result must inherit from class(integer()).
tf <- typed(
  arg(a, integer()),
  arg(b, integer()),
  arg(c, logical(), TRUE),
  expression(a + b + c),
  integer()
)

# No arguments have been assigned yet. Reading `a` forces its delayed
# default, but `a` was declared without one, so get("default", ...) fails.
try(tf()) 
#> Error in get("default", .spec) : object 'default' not found
tf(0L, 0L, FALSE)
#> [1] 0
# The argument bindings still hold 0L, 0L, FALSE from the previous call
# (they are stored with `<<-` and reused on every call), so this returns 0
# instead of failing.
try(tf()) # FIXME this should fail
#> [1] 0
tf(0L, 1L, FALSE)
#> [1] 1
tf(0L, 1L, TRUE)
#> [1] 2
# Integer + NA yields NA_integer_, which still passes the integer return check.
tf(0L, 1L, NA)
#> [1] NA
tf(1L, -2L, TRUE)
#> [1] 0
# debugonce(tf)
# 2.0 is a double, so the binding for `a` rejects it.
try(tf(2.0))
#> Error in (function (value)  : Argument 'a' must be of type integer(0)

Created on 2025-10-18 with reprex v2.1.1

library(S7)

#' Wrap a function whose formals declare S7 types as a type-checked function
#'
#' Each formal of `f` is written either as `name = class_x` (a required
#' argument of S7 class `class_x`) or as `name = class_x(value)` (an optional
#' argument whose default is `value`). Class names are looked up in the S7
#' namespace with getExportedValue().
#'
#' @param f A function, or a name that match.fun() can resolve.
#' @param ret Optional S7 class for the return value. When missing, the last
#'   formal of `f` must have no default. Its *name* (e.g. `class_character`)
#'   is taken as the return class, and it is dropped from the arguments.
#' @return The checking wrapper built by `returner()`.
typed <- function(f, ret) {
  f <- match.fun(f)
  expr <- body(f)
  forms <- formals(f)
  # Keep f's closure environment so free variables in its body resolve where
  # `f` was defined, not among typed()'s own locals (`n`, `expr`, ...).
  home <- environment(f)
  n <- length(forms)

  if (missing(ret)) {
    if (n == 0L) {
      stop(
        "`f` has no formals: supply `ret` or end the formals with a return class",
        call. = FALSE
      )
    }
    stopifnot(identical(forms[[n]], substitute()))
    ret <- getExportedValue("S7", names(forms)[n])
    forms <- forms[-n]
    n <- n - 1L
  }

  # substitute() with no argument is the empty symbol, i.e. "no default".
  types <- rep(list(substitute()), n)
  params <- types

  for (i in seq_len(n)) {
    # `class_x` is a bare symbol (one part); `class_x(value)` is a call whose
    # second part is the default value.
    parts <- as.list(forms[[i]])
    types[[i]] <- getExportedValue("S7", as.character(parts[[1L]]))
    if (length(parts) > 2L) {
      stop(
        sprintf("'%s' must declare at most one default value", names(forms)[i]),
        call. = FALSE
      )
    }
    if (length(parts) == 2L) {
      # `[<-` with list() keeps an explicit NULL default; `[[<-` would delete
      # the element and misalign the names assigned below.
      params[i] <- list(parts[[2L]])
    }
  }

  rebuilt <- function() NULL
  formals(rebuilt) <- `names<-`(params, names(forms))
  body(rebuilt) <- expr
  # Primitives have no closure environment; only closures carry one over.
  if (!is.null(home)) {
    environment(rebuilt) <- home
  }

  returner(rebuilt, types, ret, parent.frame())
}

#' Build the type-checking wrapper around a rebuilt function
#'
#' @param f The function to call once its arguments pass their checks; its
#'   formals define the accepted argument names.
#' @param types List of classes, one per formal of `f`, in formal order.
#' @param ret Class the return value must satisfy.
#' @param .envir Environment passed to do.call() when calling `f`.
#' @return A function taking `...` that matches arguments to the formals of
#'   `f` (by name, then by position), checks each one and the result with
#'   `is_class_ok()`, and signals a `typedError` condition on failure.
returner <- function(f, types, ret, .envir) {
  # Formal names in order; character(0) (not NULL) when `f` has none, so the
  # positional filling below needs no special case.
  arg_names <- as.character(names(formals(f)))

  # Signal a classed `typedError`. sys.call(-1L) is the call of err()'s caller
  # -- the generated wrapper -- so the error reads "Error in f(...)".
  err <- function(...) {
    cond <- list(message = .makeMessage(...), call = sys.call(-1L))
    class(cond) <- c("typedError", "error", "condition")
    stop(cond)
  }

  # Declared class of an argument, looked up by name so that arguments given
  # out of order are checked against the right type.
  .types <- function(name) {
    types[[match(name, arg_names)]]
  }

  force(types)
  force(ret)
  force(.envir)
  out <- function(...) NULL
  body(out) <- substitute(
    {
      args <- list(...)
      nms <- names(args)
      if (is.null(nms)) {
        nms <- character(length(args))
      }

      # Unnamed arguments fill the formals not already matched by name, in
      # order -- the same way R matches positional arguments.
      unnamed <- !nzchar(nms)
      if (any(unnamed)) {
        nms[unnamed] <- setdiff(..args.., nms[!unnamed])[seq_len(sum(unnamed))]
      }
      if (anyNA(nms)) {
        stop(
          "too many arguments: expected at most ",
          length(..args..),
          call. = FALSE
        )
      }
      names(args) <- nms

      bad <- setdiff(nms, ..args..)
      if (length(bad)) {
        stop(
          "invalid params: ",
          toString(bad),
          call. = FALSE
        )
      }

      for (i in seq_along(args)) {
        expected <- .types(nms[[i]])
        if (!is_class_ok(args[[i]], expected)) {
          err(
            sprintf(
              "'%s' must be of class '%s', not '%s'",
              nms[[i]],
              expected$class,
              toString(class(args[[i]]))
            )
          )
        }
      }

      res <- do.call(..f.., args, envir = .envir)
      if (!is_class_ok(res, ret)) {
        err(
          sprintf(
            "return type must be of class '%s', not '%s'",
            ret$class,
            toString(class(res))
          )
        )
      }
      res
    },
    list(
      ..f.. = match.fun(f),
      ..args.. = arg_names
    )
  )
  out
}

# Does `x` satisfy the declared `class`?
# S7's class_any accepts every value; any other class defers to inherits().
is_class_ok <- function(x, class) {
  identical(class, class_any) || inherits(x, class)
}

# Example spec for typed(): `a` is a required integer, `b` and `c` have the
# defaults "b" and TRUE, and the trailing `class_character` names the return
# class.
foo <- function(
  a = class_integer,
  b = class_character("b"),
  c = class_logical(TRUE),
  class_character
) {
  template <- "%i: %s (%s)"
  sprintf(template, a, b, format(c))
}

# Wrap foo(). The trailing formal `class_character` becomes the return class,
# and the outer parentheses print the generated wrapper.
(f <- typed(foo))
#> function (...) 
#> {
#>     ls <- list(...)
#>     nms <- names(ls)
#>     if (is.null(nms)) {
#>         names(ls) <- c("a", "b", "c")[seq_along(ls)]
#>     }
#>     else if (any(z <- !nzchar(z))) {
#>         names(ls)[z] <- c("a", "b", "c")[which(z)]
#>     }
#>     bad <- setdiff(names(ls), c("a", "b", "c"))
#>     if (length(bad)) {
#>         stop("invalid params: ", toString(bad), call. = FALSE)
#>     }
#>     for (i in ...length()) {
#>         if (!is_class_ok(ls[[i]], .types(i))) {
#>             stop(err(sprintf("'%s' must be of class '%s', not '%s'", 
#>                 names(ls)[i], .types(i)$class, toString(class(ls[[i]])))))
#>         }
#>     }
#>     res <- do.call(function (a, b = "b", c = TRUE) 
#>     {
#>         sprintf("%i: %s (%s)", a, b, format(c))
#>     }, ls, envir = .envir)
#>     if (!is_class_ok(res, list(class = "character", constructor_name = "character", 
#>         constructor = function (.data = character(0)) 
#>         .data, validator = function (object) 
#>         {
#>             if (base_class(object) != name) {
#>                 sprintf("Underlying data must be <%s> not <%s>", 
#>                   name, base_class(object))
#>             }
#>         }))) {
#>         stop(err(sprintf("return type must of a class '%s' not '%s'", 
#>             list(class = "character", constructor_name = "character", 
#>                 constructor = function (.data = character(0)) 
#>                 .data, validator = function (object) 
#>                 {
#>                   if (base_class(object) != name) {
#>                     sprintf("Underlying data must be <%s> not <%s>", 
#>                       name, base_class(object))
#>                   }
#>                 })$class, toString(class(res)))))
#>     }
#>     res
#> }
#> <environment: 0x594b57f1d488>
# TRUE is logical, but `a` is declared class_integer.
try(f(TRUE))
#> Error in base::tryCatch(base::withCallingHandlers({ : 
#>   'a' must be of class 'integer', not 'logical'
# Only `a` is supplied; `b` and `c` fall back to their defaults "b" and TRUE.
f(1L)
#> [1] "1: b (TRUE)"


# debugonce(typed)
# `a` accepts anything (class_any), but `1` is a double while the declared
# return class is class_integer.
try(typed(\(a = class_any, class_integer) a)(1))
#> Error in base::tryCatch(base::withCallingHandlers({ : 
#>   return type must of a class 'integer' not 'numeric'

Created on 2025-10-03 with reprex v2.1.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment